Shortcuts

분류기(Classifier) 학습하기

지금까지 어떻게 신경망을 정의하고, 손실을 계산하며 또 가중치를 갱신하는지에 대해서 배웠습니다.

이제 아마도 이런 생각을 하고 계실텐데요,

데이터는 어떻게 하나요?

일반적으로 이미지나 텍스트, 오디오나 비디오 데이터를 다룰 때는 표준 Python 패키지를 이용하여 NumPy 배열로 불러오면 됩니다. 그 후 그 배열을 torch.*Tensor 로 변환합니다.

  • 이미지는 Pillow나 OpenCV 같은 패키지가 유용합니다.

  • 오디오를 처리할 때는 SciPy와 LibROSA가 유용하고요.

  • 텍스트의 경우에는 그냥 Python이나 Cython을 사용해도 되고, NLTK나 SpaCy도 유용합니다.

특별히 영상 분야를 위한 torchvision 이라는 패키지가 만들어져 있는데, 여기에는 ImageNet이나 CIFAR10, MNIST 등과 같이 일반적으로 사용하는 데이터셋을 위한 데이터 로더(data loader), 즉 torchvision.datasets 과 이미지용 데이터 변환기 (data transformer), 즉 torch.utils.data.DataLoader 가 포함되어 있습니다.

이러한 기능은 엄청나게 편리하며, 매번 유사한 코드(boilerplate code)를 반복해서 작성하는 것을 피할 수 있습니다.

이 튜토리얼에서는 CIFAR10 데이터셋을 사용합니다. 여기에는 다음과 같은 분류들이 있습니다: 〈비행기(airplane)〉, 〈자동차(automobile)〉, 〈새(bird)〉, 〈고양이(cat)〉, 〈사슴(deer)〉, 〈개(dog)〉, 〈개구리(frog)〉, 〈말(horse)〉, 〈배(ship)〉, 〈트럭(truck)〉. 그리고 CIFAR10에 포함된 이미지의 크기는 3x32x32로, 이는 32x32 픽셀 크기의 이미지가 3개 채널(channel)의 색상으로 이뤄져 있다는 것을 뜻합니다.

cifar10

cifar10

이미지 분류기 학습하기

다음과 같은 단계로 진행해보겠습니다:

  1. torchvision 을 사용하여 CIFAR10의 학습용 / 시험용 데이터셋을 불러오고, 정규화(nomarlizing)합니다.

  2. 합성곱 신경망(Convolution Neural Network)을 정의합니다.

  3. 손실 함수를 정의합니다.

  4. 학습용 데이터를 사용하여 신경망을 학습합니다.

  5. 시험용 데이터를 사용하여 신경망을 검사합니다.

1. CIFAR10을 불러오고 정규화하기

torchvision 을 사용하여 매우 쉽게 CIFAR10을 불러올 수 있습니다.

import torch
import torchvision
import torchvision.transforms as transforms

torchvision 데이터셋의 출력(output)은 [0, 1] 범위를 갖는 PILImage 이미지입니다. 이를 [-1, 1]의 범위로 정규화된 Tensor로 변환합니다.

참고

만약 Windows 환경에서 BrokenPipeError가 발생한다면, torch.utils.data.DataLoader()의 num_worker를 0으로 설정해보세요.

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 4

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz

  0%|          | 0/170498071 [00:00<?, ?it/s]
  0%|          | 32768/170498071 [00:00<20:00, 141986.64it/s]
  0%|          | 65536/170498071 [00:00<15:37, 181821.51it/s]
  0%|          | 98304/170498071 [00:00<15:44, 180339.77it/s]
  0%|          | 229376/170498071 [00:00<07:21, 385955.05it/s]
  0%|          | 458752/170498071 [00:00<04:05, 693026.28it/s]
  1%|          | 917504/170498071 [00:01<02:10, 1294563.97it/s]
  1%|1         | 1835008/170498071 [00:01<01:07, 2481190.16it/s]
  2%|2         | 3670016/170498071 [00:01<00:34, 4818519.96it/s]
  4%|3         | 6815744/170498071 [00:01<00:19, 8565329.45it/s]
  6%|5         | 9568256/170498071 [00:01<00:13, 12202240.42it/s]
  6%|6         | 11010048/170498071 [00:01<00:13, 11598964.54it/s]
  8%|7         | 13074432/170498071 [00:02<00:12, 12668050.99it/s]
  9%|9         | 15597568/170498071 [00:02<00:10, 15477128.17it/s]
 10%|#         | 17334272/170498071 [00:02<00:10, 14124258.27it/s]
 11%|#1        | 19103744/170498071 [00:02<00:11, 12792397.82it/s]
 13%|#2        | 21987328/170498071 [00:02<00:10, 13608958.28it/s]
 15%|#4        | 24936448/170498071 [00:02<00:10, 14184053.29it/s]
 16%|#6        | 27852800/170498071 [00:03<00:09, 14513083.58it/s]
 18%|#7        | 30670848/170498071 [00:03<00:09, 14616155.04it/s]
 20%|#9        | 33521664/170498071 [00:03<00:09, 14692382.11it/s]
 21%|##1       | 36470784/170498071 [00:03<00:08, 14922168.00it/s]
 23%|##3       | 39321600/170498071 [00:03<00:08, 15011451.75it/s]
 25%|##4       | 42172416/170498071 [00:04<00:08, 15004802.28it/s]
 26%|##6       | 45056000/170498071 [00:04<00:08, 15114239.80it/s]
 28%|##8       | 48168960/170498071 [00:04<00:07, 15593692.85it/s]
 30%|###       | 51216384/170498071 [00:04<00:07, 15525275.35it/s]
 32%|###1      | 54165504/170498071 [00:04<00:07, 15467393.39it/s]
 33%|###3      | 57081856/170498071 [00:04<00:07, 15376664.42it/s]
 35%|###5      | 59998208/170498071 [00:05<00:07, 15334058.99it/s]
 37%|###6      | 63045632/170498071 [00:05<00:06, 15547437.65it/s]
 39%|###8      | 66093056/170498071 [00:05<00:06, 15642183.69it/s]
 40%|####      | 69009408/170498071 [00:05<00:06, 15546311.06it/s]
 42%|####2     | 71925760/170498071 [00:05<00:06, 15483565.62it/s]
 44%|####3     | 74776576/170498071 [00:06<00:06, 15421406.99it/s]
 46%|####5     | 77922304/170498071 [00:06<00:05, 15851383.10it/s]
 48%|####7     | 81068032/170498071 [00:06<00:05, 16101136.75it/s]
 49%|####9     | 84180992/170498071 [00:06<00:05, 16300559.54it/s]
 51%|#####1    | 87228416/170498071 [00:06<00:05, 15762503.81it/s]
 53%|#####3    | 90374144/170498071 [00:07<00:04, 16048523.37it/s]
 55%|#####4    | 93487104/170498071 [00:07<00:04, 16306638.45it/s]
 57%|#####6    | 96632832/170498071 [00:07<00:04, 16336088.33it/s]
 59%|#####8    | 99745792/170498071 [00:07<00:04, 16546310.60it/s]
 60%|######    | 102891520/170498071 [00:07<00:04, 16553855.70it/s]
 62%|######2   | 106004480/170498071 [00:07<00:03, 16652662.97it/s]
 64%|######3   | 109084672/170498071 [00:08<00:04, 15232278.07it/s]
 66%|######5   | 112001024/170498071 [00:08<00:03, 15246210.66it/s]
 67%|######7   | 114851840/170498071 [00:08<00:03, 15167998.44it/s]
 69%|######9   | 117702656/170498071 [00:08<00:03, 15174223.67it/s]
 71%|#######   | 120553472/170498071 [00:08<00:03, 15180491.53it/s]
 73%|#######2  | 123699200/170498071 [00:09<00:02, 15673095.92it/s]
 74%|#######4  | 126844928/170498071 [00:09<00:02, 16035226.18it/s]
 76%|#######6  | 129957888/170498071 [00:09<00:02, 16258144.85it/s]
 77%|#######7  | 131596288/170498071 [00:09<00:03, 10661812.02it/s]
 79%|#######8  | 134021120/170498071 [00:10<00:03, 11240218.36it/s]
 80%|#######9  | 135954432/170498071 [00:10<00:03, 11046689.10it/s]
 82%|########1 | 139067392/170498071 [00:10<00:02, 12603025.28it/s]
 83%|########3 | 142180352/170498071 [00:10<00:02, 13746925.94it/s]
 85%|########5 | 145326080/170498071 [00:10<00:01, 14603169.47it/s]
 87%|########7 | 148439040/170498071 [00:11<00:01, 15263591.00it/s]
 89%|########8 | 151584768/170498071 [00:11<00:01, 15707637.47it/s]
 91%|######### | 154697728/170498071 [00:11<00:00, 16074518.51it/s]
 93%|#########2| 157843456/170498071 [00:11<00:00, 16189479.63it/s]
 94%|#########4| 160989184/170498071 [00:11<00:00, 16351803.09it/s]
 96%|#########6| 164102144/170498071 [00:12<00:00, 15019325.45it/s]
 98%|#########8| 167215104/170498071 [00:12<00:00, 17221414.44it/s]
 99%|#########9| 169082880/170498071 [00:12<00:00, 17494346.90it/s]
100%|##########| 170498071/170498071 [00:12<00:00, 13701475.77it/s]
Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified

재미삼아 학습용 이미지 몇 개를 보겠습니다.

import matplotlib.pyplot as plt
import numpy as np

# 이미지를 보여주기 위한 함수

def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# 학습용 이미지를 무작위로 가져오기
dataiter = iter(trainloader)
images, labels = next(dataiter)

# 이미지 보여주기
imshow(torchvision.utils.make_grid(images))
# 정답(label) 출력
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(batch_size)))
cifar10 tutorial
frog  car   deer  cat

2. 합성곱 신경망(Convolution Neural Network) 정의하기

이전의 신경망 섹션에서 신경망을 복사한 후, (기존에 1채널 이미지만 처리하도록 정의된 것을) 3채널 이미지를 처리할 수 있도록 수정합니다.

import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # 배치를 제외한 모든 차원을 평탄화(flatten)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()

3. 손실 함수와 Optimizer 정의하기

교차 엔트로피 손실(Cross-Entropy loss)과 모멘텀(momentum) 값을 갖는 SGD를 사용합니다.

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

4. 신경망 학습하기

이제 재미있는 부분이 시작됩니다. 단순히 데이터를 반복해서 신경망에 입력으로 제공하고, 최적화(Optimize)만 하면 됩니다.

for epoch in range(2):   # 데이터셋을 수차례 반복합니다.

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # [inputs, labels]의 목록인 data로부터 입력을 받은 후;
        inputs, labels = data

        # 변화도(Gradient) 매개변수를 0으로 만들고
        optimizer.zero_grad()

        # 순전파 + 역전파 + 최적화를 한 후
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # 통계를 출력합니다.
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

print('Finished Training')
[1,  2000] loss: 2.164
[1,  4000] loss: 1.825
[1,  6000] loss: 1.661
[1,  8000] loss: 1.580
[1, 10000] loss: 1.531
[1, 12000] loss: 1.471
[2,  2000] loss: 1.412
[2,  4000] loss: 1.359
[2,  6000] loss: 1.333
[2,  8000] loss: 1.335
[2, 10000] loss: 1.316
[2, 12000] loss: 1.270
Finished Training

학습한 모델을 저장해보겠습니다:

PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)

PyTorch 모델을 저장하는 자세한 방법은 여기 를 참조해주세요.

5. 시험용 데이터로 신경망 검사하기

지금까지 학습용 데이터셋을 2회 반복하며 신경망을 학습시켰습니다. 신경망이 전혀 배운게 없을지도 모르니 확인해봅니다.

신경망이 예측한 출력과 진짜 정답(Ground-truth)을 비교하는 방식으로 확인합니다. 만약 예측이 맞다면 샘플을 〈맞은 예측값(correct predictions)〉 목록에 넣겠습니다.

첫번째로 시험용 데이터를 좀 보겠습니다.

dataiter = iter(testloader)
images, labels = next(dataiter)

# 이미지를 출력합니다.
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))
cifar10 tutorial
GroundTruth:  cat   ship  ship  plane

이제, 저장했던 모델을 불러오도록 하겠습니다 (주: 모델을 저장하고 다시 불러오는 작업은 여기에서는 불필요하지만, 어떻게 하는지 설명을 위해 해보겠습니다):

net = Net()
net.load_state_dict(torch.load(PATH))
<All keys matched successfully>

좋습니다, 이제 이 예제들을 신경망이 어떻게 예측했는지를 보겠습니다:

outputs = net(images)

출력은 10개 분류 각각에 대한 값으로 나타납니다. 어떤 분류에 대해서 더 높은 값이 나타난다는 것은, 신경망이 그 이미지가 해당 분류에 더 가깝다고 생각한다는 것입니다. 따라서, 가장 높은 값을 갖는 인덱스(index)를 뽑아보겠습니다:

_, predicted = torch.max(outputs, 1)

print('Predicted: ', ' '.join(f'{classes[predicted[j]]:5s}'
                              for j in range(4)))
Predicted:  cat   car   plane plane

결과가 괜찮아보이네요.

그럼 전체 데이터셋에 대해서는 어떻게 동작하는지 보겠습니다.

correct = 0
total = 0
# 학습 중이 아니므로, 출력에 대한 변화도를 계산할 필요가 없습니다
with torch.no_grad():
    for data in testloader:
        images, labels = data
        # 신경망에 이미지를 통과시켜 출력을 계산합니다
        outputs = net(images)
        # 가장 높은 값(energy)를 갖는 분류(class)를 정답으로 선택하겠습니다
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')
Accuracy of the network on the 10000 test images: 55 %

(10가지 분류 중에 하나를 무작위로) 찍었을 때의 정확도인 10% 보다는 나아보입니다. 신경망이 뭔가 배우긴 한 것 같네요.

그럼 어떤 것들을 더 잘 분류하고, 어떤 것들을 더 못했는지 알아보겠습니다:

# 각 분류(class)에 대한 예측값 계산을 위해 준비
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}

# 변화도는 여전히 필요하지 않습니다
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predictions = torch.max(outputs, 1)
        # 각 분류별로 올바른 예측 수를 모읍니다
        for label, prediction in zip(labels, predictions):
            if label == prediction:
                correct_pred[classes[label]] += 1
            total_pred[classes[label]] += 1


# 각 분류별 정확도(accuracy)를 출력합니다
for classname, correct_count in correct_pred.items():
    accuracy = 100 * float(correct_count) / total_pred[classname]
    print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')
Accuracy for class: plane is 62.4 %
Accuracy for class: car   is 72.6 %
Accuracy for class: bird  is 36.6 %
Accuracy for class: cat   is 37.1 %
Accuracy for class: deer  is 35.3 %
Accuracy for class: dog   is 54.4 %
Accuracy for class: frog  is 83.3 %
Accuracy for class: horse is 54.5 %
Accuracy for class: ship  is 60.5 %
Accuracy for class: truck is 56.9 %

자, 이제 다음으로 무엇을 해볼까요?

이러한 신경망들을 GPU에서 실행하려면 어떻게 해야 할까요?

GPU에서 학습하기

Tensor를 GPU로 이동했던 것처럼, 신경망 또한 GPU로 옮길 수 있습니다.

먼저 (CUDA를 사용할 수 있다면) 첫번째 CUDA 장치를 사용하도록 설정합니다:

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# CUDA 기기가 존재한다면, 아래 코드가 CUDA 장치를 출력합니다:

print(device)
cuda:0

이 섹션의 나머지 부분에서는 device 를 CUDA 장치라고 가정하겠습니다.

그리고 이 메소드(Method)들은 재귀적으로 모든 모듈의 매개변수와 버퍼를 CUDA tensor로 변경합니다:

net.to(device)

또한, 각 단계에서 입력(input)과 정답(target)도 GPU로 보내야 한다는 것도 기억해야 합니다:

inputs, labels = data[0].to(device), data[1].to(device)

CPU와 비교했을 때 어마어마한 속도 차이가 나지 않는 것은 왜 그럴까요? 그 이유는 바로 신경망이 너무 작기 때문입니다.

연습: 신경망의 크기를 키워보고, 얼마나 빨라지는지 확인해보세요. (첫번째 nn.Conv2d 의 2번째 인자와 두번째 nn.Conv2d 의 1번째 인자는 같은 숫자여야 합니다.)

다음 목표들을 달성했습니다:

  • 높은 수준에서 PyTorch의 Tensor library와 신경망을 이해합니다.

  • 이미지를 분류하는 작은 신경망을 학습시킵니다.

여러개의 GPU에서 학습하기

모든 GPU를 활용해서 더욱 더 속도를 올리고 싶다면, 선택 사항: 데이터 병렬 처리 (Data Parallelism) 을 참고하세요.


더 궁금하시거나 개선할 내용이 있으신가요? 커뮤니티에 참여해보세요!


이 튜토리얼이 어떠셨나요? 평가해주시면 이후 개선에 참고하겠습니다! :)

© Copyright 2018-2023, PyTorch & 파이토치 한국 사용자 모임(PyTorch Korea User Group).

Built with Sphinx using a theme provided by Read the Docs.

PyTorchKorea @ GitHub

파이토치 한국 사용자 모임을 GitHub에서 만나보세요.

GitHub로 이동

한국어 튜토리얼

한국어로 번역 중인 PyTorch 튜토리얼입니다.

튜토리얼로 이동

커뮤니티

다른 사용자들과 의견을 나누고, 도와주세요!

커뮤니티로 이동