Shortcuts

파이토치(PyTorch) 기본 익히기 || 빠른 시작 || 텐서(Tensor) || Dataset과 Dataloader || 변형(Transform) || 신경망 모델 구성하기 || Autograd || 최적화(Optimization) || 모델 저장하고 불러오기

모델 저장하고 불러오기

이번 장에서는 저장하기나 불러오기를 통해 모델의 상태를 유지(persist)하고 모델의 예측을 실행하는 방법을 알아보겠습니다.

import torch
import torchvision.models as models

모델 가중치 저장하고 불러오기

PyTorch 모델은 학습한 매개변수를 state_dict라고 불리는 내부 상태 사전(internal state dictionary)에 저장합니다. 이 상태 값들은 torch.save 메소드를 사용하여 저장(persist)할 수 있습니다:

model = models.vgg16(weights='IMAGENET1K_V1')
torch.save(model.state_dict(), 'model_weights.pth')
Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth

  0%|          | 0.00/528M [00:00<?, ?B/s]
  2%|1         | 8.38M/528M [00:00<00:06, 87.4MB/s]
  4%|3         | 18.9M/528M [00:00<00:05, 100MB/s]
  6%|5         | 29.2M/528M [00:00<00:05, 104MB/s]
  8%|7         | 40.5M/528M [00:00<00:04, 109MB/s]
 10%|9         | 50.9M/528M [00:00<00:04, 105MB/s]
 12%|#1        | 62.4M/528M [00:00<00:04, 110MB/s]
 14%|#3        | 72.9M/528M [00:00<00:04, 109MB/s]
 16%|#5        | 83.4M/528M [00:00<00:04, 107MB/s]
 18%|#8        | 95.5M/528M [00:00<00:04, 113MB/s]
 20%|##        | 107M/528M [00:01<00:03, 114MB/s]
 22%|##2       | 118M/528M [00:01<00:03, 114MB/s]
 24%|##4       | 129M/528M [00:01<00:03, 110MB/s]
 27%|##6       | 141M/528M [00:01<00:03, 112MB/s]
 29%|##8       | 153M/528M [00:01<00:03, 115MB/s]
 31%|###1      | 164M/528M [00:01<00:03, 111MB/s]
 33%|###3      | 176M/528M [00:01<00:03, 116MB/s]
 36%|###5      | 188M/528M [00:01<00:03, 114MB/s]
 38%|###7      | 198M/528M [00:01<00:03, 113MB/s]
 40%|###9      | 210M/528M [00:01<00:02, 116MB/s]
 42%|####1     | 222M/528M [00:02<00:02, 116MB/s]
 44%|####4     | 233M/528M [00:02<00:02, 116MB/s]
 46%|####6     | 244M/528M [00:02<00:02, 114MB/s]
 48%|####8     | 255M/528M [00:02<00:02, 116MB/s]
 50%|#####     | 266M/528M [00:02<00:02, 116MB/s]
 53%|#####2    | 278M/528M [00:02<00:02, 113MB/s]
 55%|#####4    | 289M/528M [00:02<00:02, 114MB/s]
 57%|#####7    | 301M/528M [00:02<00:02, 118MB/s]
 59%|#####9    | 312M/528M [00:02<00:01, 115MB/s]
 61%|######1   | 324M/528M [00:03<00:01, 115MB/s]
 64%|######3   | 336M/528M [00:03<00:01, 117MB/s]
 66%|######5   | 347M/528M [00:03<00:01, 117MB/s]
 68%|######7   | 358M/528M [00:03<00:01, 113MB/s]
 70%|#######   | 370M/528M [00:03<00:01, 114MB/s]
 72%|#######2  | 382M/528M [00:03<00:01, 118MB/s]
 75%|#######4  | 393M/528M [00:03<00:01, 117MB/s]
 77%|#######6  | 405M/528M [00:03<00:01, 116MB/s]
 79%|#######8  | 416M/528M [00:03<00:01, 116MB/s]
 81%|########  | 427M/528M [00:03<00:00, 115MB/s]
 83%|########2 | 438M/528M [00:04<00:00, 111MB/s]
 85%|########5 | 449M/528M [00:04<00:00, 113MB/s]
 87%|########7 | 461M/528M [00:04<00:00, 117MB/s]
 90%|########9 | 473M/528M [00:04<00:00, 116MB/s]
 92%|#########1| 484M/528M [00:04<00:00, 115MB/s]
 94%|#########3| 495M/528M [00:04<00:00, 115MB/s]
 96%|#########5| 506M/528M [00:04<00:00, 115MB/s]
 98%|#########7| 517M/528M [00:04<00:00, 111MB/s]
100%|##########| 528M/528M [00:04<00:00, 114MB/s]

모델 가중치를 불러오기 위해서는, 먼저 동일한 모델의 인스턴스(instance)를 생성한 다음에 load_state_dict() 메소드를 사용하여 매개변수들을 불러옵니다.

model = models.vgg16() # 여기서는 ``weights`` 를 지정하지 않았으므로, 학습되지 않은 모델을 생성합니다.
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

참고

추론(inference)을 하기 전에 model.eval() 메소드를 호출하여 드롭아웃(dropout)과 배치 정규화(batch normalization)를 평가 모드(evaluation mode)로 설정해야 합니다. 그렇지 않으면 일관성 없는 추론 결과가 생성됩니다.

모델의 형태를 포함하여 저장하고 불러오기

모델의 가중치를 불러올 때, 신경망의 구조를 정의하기 위해 모델 클래스를 먼저 생성(instantiate)해야 했습니다. 이 클래스의 구조를 모델과 함께 저장하고 싶으면, (model.state_dict()가 아닌) model 을 저장 함수에 전달합니다:

torch.save(model, 'model.pth')

다음과 같이 모델을 불러올 수 있습니다:

model = torch.load('model.pth')

참고

이 접근 방식은 Python pickle 모듈을 사용하여 모델을 직렬화(serialize)하므로, 모델을 불러올 때 실제 클래스 정의(definition)를 적용(rely on)합니다.


더 궁금하시거나 개선할 내용이 있으신가요? 커뮤니티에 참여해보세요!


이 튜토리얼이 어떠셨나요? 평가해주시면 이후 개선에 참고하겠습니다! :)

© Copyright 2018-2024, PyTorch & 파이토치 한국 사용자 모임(PyTorch Korea User Group).

Built with Sphinx using a theme provided by Read the Docs.

PyTorchKorea @ GitHub

파이토치 한국 사용자 모임을 GitHub에서 만나보세요.

GitHub로 이동

한국어 튜토리얼

한국어로 번역 중인 PyTorch 튜토리얼입니다.

튜토리얼로 이동

커뮤니티

다른 사용자들과 의견을 나누고, 도와주세요!

커뮤니티로 이동