Note
Go to the end to download the full example code.
Save and Load the Model
In this section we will look at how to persist model state by saving and loading models, and how to run model predictions.
import torch
import torchvision.models as models
Saving and Loading Model Weights
PyTorch models store their learned parameters in an internal state dictionary called state_dict.
These values can be persisted with the torch.save method:
model = models.vgg16(weights='IMAGENET1K_V1')
torch.save(model.state_dict(), 'model_weights.pth')
Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:05<00:00, 102MB/s]
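The state_dict is simply a dictionary mapping each layer's parameter name to its tensor, so you can inspect what was just saved. A quick, purely illustrative peek (reusing the model created above):

# print the names and shapes of the first few entries in the state_dict (illustrative only)
for name, param in list(model.state_dict().items())[:3]:
    print(name, param.shape)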
To load model weights, you first need to create an instance of the same model, and then load the parameters with the load_state_dict() method.
In the code below we set weights_only=True, which restricts unpickling to the functions needed for loading weights; this is considered best practice when loading weights.
model = models.vgg16()  # we do not specify ``weights`` here, i.e. we create an untrained model
model.load_state_dict(torch.load('model_weights.pth', weights_only=True))
model.eval()
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)
Note
Be sure to call the model.eval() method before inferencing to set the dropout and batch normalization layers to evaluation mode. Failing to do this will yield inconsistent inference results.
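As a rough illustration of evaluation-mode inference (the random tensor below is only a stand-in for a real batch of 3x224x224 images, which is what VGG-16 expects):

model.eval()                      # switch dropout / batch-norm layers to eval behaviour
x = torch.rand(1, 3, 224, 224)    # dummy input standing in for a real image batch
with torch.no_grad():             # no gradients are needed for inference
    logits = model(x)
    predicted_class = logits.argmax(dim=1)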
Saving and Loading Models with Shapes
When loading model weights, we needed to instantiate the model class first, because the class defines the structure of the network.
If we want to save the structure of the class together with the model, we can pass model (rather than model.state_dict()) to the saving function:
torch.save(model, 'model.pth')
We can then load the model as shown below. Because this loads the whole pickled model rather than just a state_dict, we pass weights_only=False; since PyTorch 2.6 the default value of weights_only in torch.load is True. Only do this with checkpoints from a trusted source, as unpickling can execute arbitrary code.
model = torch.load('model.pth', weights_only=False)
Note
This approach uses the Python pickle module when serializing the model, so it relies on the actual class definition being available when loading the model.
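Alternatively, if you want to keep the default weights_only=True behaviour while loading a full model you trust, recent PyTorch versions (2.5 and later) let you allowlist the pickled class with torch.serialization.safe_globals. A minimal sketch, assuming the torchvision VGG class saved above (depending on your PyTorch/torchvision versions, further classes referenced by the checkpoint may need to be allowlisted as well):

from torchvision.models.vgg import VGG

# allowlist the VGG class for the weights-only unpickler; only do this for trusted files
with torch.serialization.safe_globals([VGG]):
    model = torch.load('model.pth')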
Related Tutorials
Saving and Loading a General Checkpoint
Tips for Loading an nn.Module from a Checkpoint
Total running time of the script: (0 minutes 9.040 seconds)