Save and Load the Model#

In this section we will look at how to persist model state by saving and loading models, and how to run model predictions.

import torch
import torchvision.models as models

Saving and Loading Model Weights#

PyTorch models store the learned parameters in an internal state dictionary, called state_dict. These can be persisted via the torch.save method:

model = models.vgg16(weights='IMAGENET1K_V1')
torch.save(model.state_dict(), 'model_weights.pth')
Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth

100%|██████████| 528M/528M [00:05<00:00, 102MB/s]
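
The state_dict itself is just a dictionary mapping each parameter (and buffer) name to its tensor, so you can inspect what will be saved; the key name and shape shown in the comment below are those produced by torchvision's VGG implementation:

state_dict = model.state_dict()
print(len(state_dict))                        # number of parameter/buffer entries
for name, tensor in list(state_dict.items())[:3]:
    print(name, tuple(tensor.shape))          # e.g. features.0.weight (64, 3, 3, 3)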

To load model weights, you need to create an instance of the same model first, and then load the parameters using the load_state_dict() method.

model = models.vgg16() # we do not specify ``weights``, i.e. create an untrained model
model.load_state_dict(torch.load('model_weights.pth', weights_only=True))
model.eval()
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

Note

Be sure to call the model.eval() method before inferencing to set the dropout and batch normalization layers to evaluation mode. Failing to do this will yield inconsistent inference results.
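
For example, a minimal inference sketch; the input here is a random dummy tensor, used only to illustrate the call pattern:

model.eval()                      # switch dropout/batchnorm layers to evaluation mode
x = torch.rand(1, 3, 224, 224)    # dummy input: one 224x224 RGB image
with torch.no_grad():             # gradients are not needed for inference
    logits = model(x)
print(logits.argmax(dim=1))       # index of the predicted ImageNet class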

Saving and Loading Models with Shapes#

When loading model weights, we needed to instantiate the model class first, because the class defines the structure of the network. We might want to save the structure of this class together with the model, in which case we can pass model (and not model.state_dict()) to the saving function:

torch.save(model, 'model.pth')

We can then load the model as shown below. Note that since PyTorch 2.6 the default value of the weights_only argument of torch.load changed from False to True, and in weights-only mode torch.load refuses to unpickle classes (such as torchvision.models.vgg.VGG) that have not been explicitly allowlisted. Because this file contains a whole pickled model object rather than just tensors, we pass weights_only=False; only do this if the checkpoint comes from a source you trust, since unpickling can execute arbitrary code. Alternatively, trusted classes can be allowlisted with torch.serialization.add_safe_globals while keeping weights_only=True.

model = torch.load('model.pth', weights_only=False)

Note

This approach uses the Python pickle module when serializing the model, thus it relies on the actual class definition being available when loading the model.
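
To make this concrete, here is a small sketch with a made-up custom module (TinyNet is purely illustrative): saving the whole object works, but torch.load will fail in any process where the TinyNet class definition cannot be found or imported.

import torch
from torch import nn

class TinyNet(nn.Module):          # hypothetical custom model
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

torch.save(TinyNet(), 'tiny.pth')  # pickles the object, including a reference to its class
# loading works only if TinyNet is defined (or importable) in the loading process
net = torch.load('tiny.pth', weights_only=False)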

Related Tutorials#

Saving And Loading A General Checkpoint
Tips for Loading an nn.Module from a Checkpoint
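
A minimal sketch of the general-checkpoint pattern covered by the first tutorial above; the optimizer, epoch and loss values are illustrative placeholders:

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # illustrative optimizer
epoch, loss = 5, 0.4                                       # illustrative training state

# a checkpoint bundles everything needed to resume training, not just the model weights
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}, 'checkpoint.pth')

checkpoint = torch.load('checkpoint.pth', weights_only=True)  # tensors and plain Python types only
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])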

Total running time of the script: (0 minutes 9.040 seconds)