참고
Click here to download the full example code
파이토치(PyTorch) 기본 익히기 || 빠른 시작 || 텐서(Tensor) || Dataset과 Dataloader || 변형(Transform) || 신경망 모델 구성하기 || Autograd || 최적화(Optimization) || 모델 저장하고 불러오기
모델 저장하고 불러오기¶
이번 장에서는 저장하기나 불러오기를 통해 모델의 상태를 유지(persist)하고 모델의 예측을 실행하는 방법을 알아보겠습니다.
import torch
import torchvision.models as models
모델 가중치 저장하고 불러오기¶
PyTorch 모델은 학습한 매개변수를 state_dict
라고 불리는 내부 상태 사전(internal state dictionary)에 저장합니다.
이 상태 값들은 torch.save
메소드를 사용하여 저장(persist)할 수 있습니다:
model = models.vgg16(weights='IMAGENET1K_V1')
torch.save(model.state_dict(), 'model_weights.pth')
Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
0%| | 0.00/528M [00:00<?, ?B/s]
2%|1 | 8.38M/528M [00:00<00:06, 87.4MB/s]
4%|3 | 18.9M/528M [00:00<00:05, 100MB/s]
6%|5 | 29.2M/528M [00:00<00:05, 104MB/s]
8%|7 | 40.5M/528M [00:00<00:04, 109MB/s]
10%|9 | 50.9M/528M [00:00<00:04, 105MB/s]
12%|#1 | 62.4M/528M [00:00<00:04, 110MB/s]
14%|#3 | 72.9M/528M [00:00<00:04, 109MB/s]
16%|#5 | 83.4M/528M [00:00<00:04, 107MB/s]
18%|#8 | 95.5M/528M [00:00<00:04, 113MB/s]
20%|## | 107M/528M [00:01<00:03, 114MB/s]
22%|##2 | 118M/528M [00:01<00:03, 114MB/s]
24%|##4 | 129M/528M [00:01<00:03, 110MB/s]
27%|##6 | 141M/528M [00:01<00:03, 112MB/s]
29%|##8 | 153M/528M [00:01<00:03, 115MB/s]
31%|###1 | 164M/528M [00:01<00:03, 111MB/s]
33%|###3 | 176M/528M [00:01<00:03, 116MB/s]
36%|###5 | 188M/528M [00:01<00:03, 114MB/s]
38%|###7 | 198M/528M [00:01<00:03, 113MB/s]
40%|###9 | 210M/528M [00:01<00:02, 116MB/s]
42%|####1 | 222M/528M [00:02<00:02, 116MB/s]
44%|####4 | 233M/528M [00:02<00:02, 116MB/s]
46%|####6 | 244M/528M [00:02<00:02, 114MB/s]
48%|####8 | 255M/528M [00:02<00:02, 116MB/s]
50%|##### | 266M/528M [00:02<00:02, 116MB/s]
53%|#####2 | 278M/528M [00:02<00:02, 113MB/s]
55%|#####4 | 289M/528M [00:02<00:02, 114MB/s]
57%|#####7 | 301M/528M [00:02<00:02, 118MB/s]
59%|#####9 | 312M/528M [00:02<00:01, 115MB/s]
61%|######1 | 324M/528M [00:03<00:01, 115MB/s]
64%|######3 | 336M/528M [00:03<00:01, 117MB/s]
66%|######5 | 347M/528M [00:03<00:01, 117MB/s]
68%|######7 | 358M/528M [00:03<00:01, 113MB/s]
70%|####### | 370M/528M [00:03<00:01, 114MB/s]
72%|#######2 | 382M/528M [00:03<00:01, 118MB/s]
75%|#######4 | 393M/528M [00:03<00:01, 117MB/s]
77%|#######6 | 405M/528M [00:03<00:01, 116MB/s]
79%|#######8 | 416M/528M [00:03<00:01, 116MB/s]
81%|######## | 427M/528M [00:03<00:00, 115MB/s]
83%|########2 | 438M/528M [00:04<00:00, 111MB/s]
85%|########5 | 449M/528M [00:04<00:00, 113MB/s]
87%|########7 | 461M/528M [00:04<00:00, 117MB/s]
90%|########9 | 473M/528M [00:04<00:00, 116MB/s]
92%|#########1| 484M/528M [00:04<00:00, 115MB/s]
94%|#########3| 495M/528M [00:04<00:00, 115MB/s]
96%|#########5| 506M/528M [00:04<00:00, 115MB/s]
98%|#########7| 517M/528M [00:04<00:00, 111MB/s]
100%|##########| 528M/528M [00:04<00:00, 114MB/s]
모델 가중치를 불러오기 위해서는, 먼저 동일한 모델의 인스턴스(instance)를 생성한 다음에 load_state_dict()
메소드를 사용하여
매개변수들을 불러옵니다.
model = models.vgg16() # 여기서는 ``weights`` 를 지정하지 않았으므로, 학습되지 않은 모델을 생성합니다.
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()
VGG(
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU(inplace=True)
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(6): ReLU(inplace=True)
(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): ReLU(inplace=True)
(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(inplace=True)
(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(13): ReLU(inplace=True)
(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): ReLU(inplace=True)
(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(18): ReLU(inplace=True)
(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(20): ReLU(inplace=True)
(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(22): ReLU(inplace=True)
(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(25): ReLU(inplace=True)
(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(27): ReLU(inplace=True)
(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(29): ReLU(inplace=True)
(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
(classifier): Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace=True)
(2): Dropout(p=0.5, inplace=False)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace=True)
(5): Dropout(p=0.5, inplace=False)
(6): Linear(in_features=4096, out_features=1000, bias=True)
)
)
참고
추론(inference)을 하기 전에 model.eval()
메소드를 호출하여 드롭아웃(dropout)과 배치 정규화(batch normalization)를 평가 모드(evaluation mode)로 설정해야 합니다. 그렇지 않으면 일관성 없는 추론 결과가 생성됩니다.
모델의 형태를 포함하여 저장하고 불러오기¶
모델의 가중치를 불러올 때, 신경망의 구조를 정의하기 위해 모델 클래스를 먼저 생성(instantiate)해야 했습니다.
이 클래스의 구조를 모델과 함께 저장하고 싶으면, (model.state_dict()
가 아닌) model
을 저장 함수에
전달합니다:
torch.save(model, 'model.pth')
다음과 같이 모델을 불러올 수 있습니다:
model = torch.load('model.pth')
참고
이 접근 방식은 Python pickle 모듈을 사용하여 모델을 직렬화(serialize)하므로, 모델을 불러올 때 실제 클래스 정의(definition)를 적용(rely on)합니다.
관련 튜토리얼¶
PyTorch에서 일반적인 체크포인트(checkpoint) 저장하기 & 불러오기 Tips for Loading an nn.Module from a Checkpoint
Total running time of the script: ( 0 minutes 13.053 seconds)