참고
Click here to download the full example code
파이토치(PyTorch) 기본 익히기 || 빠른 시작 || 텐서(Tensor) || Dataset과 Dataloader || 변형(Transform) || 신경망 모델 구성하기 || Autograd || 최적화(Optimization) || 모델 저장하고 불러오기
모델 매개변수 최적화하기¶
이제 모델과 데이터가 준비되었으니, 데이터에 매개변수를 최적화하여 모델을 학습하고, 검증하고, 테스트할 차례입니다. 모델을 학습하는 과정은 반복적인 과정을 거칩니다; 각 반복 단계에서 모델은 출력을 추측하고, 추측과 정답 사이의 오류(손실(loss))를 계산하고, (이전 장에서 본 것처럼) 매개변수에 대한 오류의 도함수(derivative)를 수집한 뒤, 경사하강법을 사용하여 이 파라미터들을 최적화(optimize)합니다. 이 과정에 대한 자세한 설명은 3Blue1Brown의 역전파 영상을 참고하세요.
기본(Pre-requisite) 코드¶
이전 장인 Dataset과 DataLoader와 신경망 모델 구성하기에서 코드를 가져왔습니다.
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)
class NeuralNetwork(nn.Module):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28*28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10),
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
model = NeuralNetwork()
하이퍼파라미터(Hyperparameter)¶
하이퍼파라미터(Hyperparameter)는 모델 최적화 과정을 제어할 수 있는 조절 가능한 매개변수입니다. 서로 다른 하이퍼파라미터 값은 모델 학습과 수렴율(convergence rate)에 영향을 미칠 수 있습니다. (하이퍼파라미터 튜닝(tuning)에 대해 더 알아보기)
- 학습 시에는 다음과 같은 하이퍼파라미터를 정의합니다:
에폭(epoch) 수 - 데이터셋을 반복하는 횟수
배치 크기(batch size) - 매개변수가 갱신되기 전 신경망을 통해 전파된 데이터 샘플의 수
학습률(learning rate) - 각 배치/에폭에서 모델의 매개변수를 조절하는 비율. 값이 작을수록 학습 속도가 느려지고, 값이 크면 학습 중 예측할 수 없는 동작이 발생할 수 있습니다.
learning_rate = 1e-3
batch_size = 64
epochs = 5
최적화 단계(Optimization Loop)¶
하이퍼파라미터를 설정한 뒤에는 최적화 단계를 통해 모델을 학습하고 최적화할 수 있습니다. 최적화 단계의 각 반복(iteration)을 에폭이라고 부릅니다.
- 하나의 에폭은 다음 두 부분으로 구성됩니다:
학습 단계(train loop) - 학습용 데이터셋을 반복(iterate)하고 최적의 매개변수로 수렴합니다.
검증/테스트 단계(validation/test loop) - 모델 성능이 개선되고 있는지를 확인하기 위해 테스트 데이터셋을 반복(iterate)합니다.
학습 단계(training loop)에서 일어나는 몇 가지 개념들을 간략히 살펴보겠습니다. 최적화 단계(optimization loop)를 보려면 전체 구현 부분으로 건너뛰시면 됩니다.
손실 함수(loss function)¶
학습용 데이터를 제공하면, 학습되지 않은 신경망은 정답을 제공하지 않을 확률이 높습니다. 손실 함수(loss function)는 획득한 결과와 실제 값 사이의 틀린 정도(degree of dissimilarity)를 측정하며, 학습 중에 이 값을 최소화하려고 합니다. 주어진 데이터 샘플을 입력으로 계산한 예측과 정답(label)을 비교하여 손실(loss)을 계산합니다.
일반적인 손실함수에는 회귀 문제(regression task)에 사용하는 nn.MSELoss(평균 제곱 오차(MSE; Mean Square Error))나
분류(classification)에 사용하는 nn.NLLLoss (음의 로그 우도(Negative Log Likelihood)),
그리고 nn.LogSoftmax
와 nn.NLLLoss
를 합친 nn.CrossEntropyLoss
등이 있습니다.
모델의 출력 로짓(logit)을 nn.CrossEntropyLoss
에 전달하여 로짓(logit)을 정규화하고 예측 오류를 계산합니다.
# 손실 함수를 초기화합니다.
loss_fn = nn.CrossEntropyLoss()
옵티마이저(Optimizer)¶
최적화는 각 학습 단계에서 모델의 오류를 줄이기 위해 모델 매개변수를 조정하는 과정입니다. 최적화 알고리즘은 이 과정이 수행되는 방식(여기에서는 확률적 경사하강법(SGD; Stochastic Gradient Descent))을 정의합니다.
모든 최적화 절차(logic)는 optimizer
객체에 캡슐화(encapsulate)됩니다. 여기서는 SGD 옵티마이저를 사용하고 있으며, PyTorch에는 ADAM이나 RMSProp과 같은 다른 종류의 모델과 데이터에서 더 잘 동작하는
다양한 옵티마이저가 있습니다.
학습하려는 모델의 매개변수와 학습률(learning rate) 하이퍼파라미터를 등록하여 옵티마이저를 초기화합니다.
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
- 학습 단계(loop)에서 최적화는 세단계로 이뤄집니다:
optimizer.zero_grad()
를 호출하여 모델 매개변수의 변화도를 재설정합니다. 기본적으로 변화도는 더해지기(add up) 때문에 중복 계산을 막기 위해 반복할 때마다 명시적으로 0으로 설정합니다.loss.backwards()
를 호출하여 예측 손실(prediction loss)을 역전파합니다. PyTorch는 각 매개변수에 대한 손실의 변화도를 저장합니다.변화도를 계산한 뒤에는
optimizer.step()
을 호출하여 역전파 단계에서 수집된 변화도로 매개변수를 조정합니다.
전체 구현¶
최적화 코드를 반복하여 수행하는 train_loop
와 테스트 데이터로 모델의 성능을 측정하는 test_loop
를 정의하였습니다.
def train_loop(dataloader, model, loss_fn, optimizer):
size = len(dataloader.dataset)
# 모델을 학습(train) 모드로 설정합니다 - 배치 정규화(Batch Normalization) 및 드롭아웃(Dropout) 레이어들에 중요합니다.
# 이 예시에서는 없어도 되지만, 모범 사례를 위해 추가해두었습니다.
model.train()
for batch, (X, y) in enumerate(dataloader):
# 예측(prediction)과 손실(loss) 계산
pred = model(X)
loss = loss_fn(pred, y)
# 역전파
loss.backward()
optimizer.step()
optimizer.zero_grad()
if batch % 100 == 0:
loss, current = loss.item(), batch * batch_size + len(X)
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
def test_loop(dataloader, model, loss_fn):
# 모델을 평가(eval) 모드로 설정합니다 - 배치 정규화(Batch Normalization) 및 드롭아웃(Dropout) 레이어들에 중요합니다.
# 이 예시에서는 없어도 되지만, 모범 사례를 위해 추가해두었습니다.
model.eval()
size = len(dataloader.dataset)
num_batches = len(dataloader)
test_loss, correct = 0, 0
# torch.no_grad()를 사용하여 테스트 시 변화도(gradient)를 계산하지 않도록 합니다.
# 이는 requires_grad=True로 설정된 텐서들의 불필요한 변화도 연산 및 메모리 사용량 또한 줄여줍니다.
with torch.no_grad():
for X, y in dataloader:
pred = model(X)
test_loss += loss_fn(pred, y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= num_batches
correct /= size
print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
손실 함수와 옵티마이저를 초기화하고 train_loop
와 test_loop
에 전달합니다.
모델의 성능 향상을 알아보기 위해 자유롭게 에폭(epoch) 수를 증가시켜 볼 수 있습니다.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
epochs = 10
for t in range(epochs):
print(f"Epoch {t+1}\n-------------------------------")
train_loop(train_dataloader, model, loss_fn, optimizer)
test_loop(test_dataloader, model, loss_fn)
print("Done!")
Epoch 1
-------------------------------
loss: 2.298730 [ 64/60000]
loss: 2.289123 [ 6464/60000]
loss: 2.273286 [12864/60000]
loss: 2.269406 [19264/60000]
loss: 2.249604 [25664/60000]
loss: 2.229407 [32064/60000]
loss: 2.227369 [38464/60000]
loss: 2.204261 [44864/60000]
loss: 2.206193 [51264/60000]
loss: 2.166651 [57664/60000]
Test Error:
Accuracy: 50.9%, Avg loss: 2.166725
Epoch 2
-------------------------------
loss: 2.176751 [ 64/60000]
loss: 2.169596 [ 6464/60000]
loss: 2.117501 [12864/60000]
loss: 2.129273 [19264/60000]
loss: 2.079675 [25664/60000]
loss: 2.032928 [32064/60000]
loss: 2.050115 [38464/60000]
loss: 1.985237 [44864/60000]
loss: 1.987888 [51264/60000]
loss: 1.907163 [57664/60000]
Test Error:
Accuracy: 55.9%, Avg loss: 1.915487
Epoch 3
-------------------------------
loss: 1.951615 [ 64/60000]
loss: 1.928684 [ 6464/60000]
loss: 1.815711 [12864/60000]
loss: 1.841554 [19264/60000]
loss: 1.732469 [25664/60000]
loss: 1.692915 [32064/60000]
loss: 1.701716 [38464/60000]
loss: 1.610631 [44864/60000]
loss: 1.632872 [51264/60000]
loss: 1.514267 [57664/60000]
Test Error:
Accuracy: 58.8%, Avg loss: 1.541527
Epoch 4
-------------------------------
loss: 1.616449 [ 64/60000]
loss: 1.582892 [ 6464/60000]
loss: 1.427596 [12864/60000]
loss: 1.487955 [19264/60000]
loss: 1.359329 [25664/60000]
loss: 1.364820 [32064/60000]
loss: 1.371491 [38464/60000]
loss: 1.298707 [44864/60000]
loss: 1.336200 [51264/60000]
loss: 1.232144 [57664/60000]
Test Error:
Accuracy: 62.2%, Avg loss: 1.260238
Epoch 5
-------------------------------
loss: 1.345540 [ 64/60000]
loss: 1.327799 [ 6464/60000]
loss: 1.153804 [12864/60000]
loss: 1.254832 [19264/60000]
loss: 1.117318 [25664/60000]
loss: 1.153250 [32064/60000]
loss: 1.171764 [38464/60000]
loss: 1.110264 [44864/60000]
loss: 1.154467 [51264/60000]
loss: 1.070921 [57664/60000]
Test Error:
Accuracy: 64.1%, Avg loss: 1.089831
Epoch 6
-------------------------------
loss: 1.166888 [ 64/60000]
loss: 1.170515 [ 6464/60000]
loss: 0.979435 [12864/60000]
loss: 1.113774 [19264/60000]
loss: 0.973409 [25664/60000]
loss: 1.015192 [32064/60000]
loss: 1.051111 [38464/60000]
loss: 0.993591 [44864/60000]
loss: 1.039709 [51264/60000]
loss: 0.971078 [57664/60000]
Test Error:
Accuracy: 65.8%, Avg loss: 0.982441
Epoch 7
-------------------------------
loss: 1.045163 [ 64/60000]
loss: 1.070585 [ 6464/60000]
loss: 0.862304 [12864/60000]
loss: 1.022268 [19264/60000]
loss: 0.885212 [25664/60000]
loss: 0.919530 [32064/60000]
loss: 0.972762 [38464/60000]
loss: 0.918727 [44864/60000]
loss: 0.961630 [51264/60000]
loss: 0.904378 [57664/60000]
Test Error:
Accuracy: 66.9%, Avg loss: 0.910168
Epoch 8
-------------------------------
loss: 0.956964 [ 64/60000]
loss: 1.002171 [ 6464/60000]
loss: 0.779055 [12864/60000]
loss: 0.958410 [19264/60000]
loss: 0.827243 [25664/60000]
loss: 0.850261 [32064/60000]
loss: 0.917320 [38464/60000]
loss: 0.868385 [44864/60000]
loss: 0.905506 [51264/60000]
loss: 0.856354 [57664/60000]
Test Error:
Accuracy: 68.3%, Avg loss: 0.858248
Epoch 9
-------------------------------
loss: 0.889762 [ 64/60000]
loss: 0.951220 [ 6464/60000]
loss: 0.717033 [12864/60000]
loss: 0.911042 [19264/60000]
loss: 0.786091 [25664/60000]
loss: 0.798369 [32064/60000]
loss: 0.874938 [38464/60000]
loss: 0.832791 [44864/60000]
loss: 0.863253 [51264/60000]
loss: 0.819740 [57664/60000]
Test Error:
Accuracy: 69.5%, Avg loss: 0.818778
Epoch 10
-------------------------------
loss: 0.836395 [ 64/60000]
loss: 0.910217 [ 6464/60000]
loss: 0.668505 [12864/60000]
loss: 0.874332 [19264/60000]
loss: 0.754807 [25664/60000]
loss: 0.758451 [32064/60000]
loss: 0.840449 [38464/60000]
loss: 0.806151 [44864/60000]
loss: 0.830361 [51264/60000]
loss: 0.790275 [57664/60000]
Test Error:
Accuracy: 71.0%, Avg loss: 0.787269
Done!
더 읽어보기¶
Total running time of the script: ( 0 minutes 52.220 seconds)