• Tutorials >
  • 컴퓨터 비전(Vision)을 위한 전이학습(Transfer Learning)
Shortcuts

컴퓨터 비전(Vision)을 위한 전이학습(Transfer Learning)

Author: Sasank Chilamkurthy

번역: 박정환

이 튜토리얼에서는 전이학습(Transfer Learning)을 이용하여 이미지 분류를 위한 합성곱 신경망을 어떻게 학습시키는지 배워보겠습니다. 전이학습에 대해서는 CS231n 노트 에서 더 많은 내용을 읽어보실 수 있습니다.

위 노트를 인용해보면,

실제로 충분한 크기의 데이터셋을 갖추기는 상대적으로 드물기 때문에, (무작위 초기화를 통해) 맨 처음부터 합성곱 신경망(Convolutional Network) 전체를 학습하는 사람은 매우 적습니다. 대신, 매우 큰 데이터셋(예. 100가지 분류에 대해 120만개의 이미지가 포함된 ImageNet)에서 합성곱 신경망(ConvNet)을 미리 학습한 후, 이 합성곱 신경망을 관심있는 작업 을 위한 초기 설정 또는 고정된 특징 추출기(fixed feature extractor)로 사용합니다.

이러한 전이학습 시나리오의 주요한 2가지는 다음과 같습니다:

  • 합성곱 신경망의 미세조정(finetuning): 무작위 초기화 대신, 신경망을 ImageNet 1000 데이터셋 등으로 미리 학습한 신경망으로 초기화합니다. 학습의 나머지 과정들은 평상시와 같습니다.

  • 고정된 특징 추출기로써의 합성곱 신경망: 여기서는 마지막에 완전히 연결 된 계층을 제외한 모든 신경망의 가중치를 고정합니다. 이 마지막의 완전히 연결된 계층은 새로운 무작위의 가중치를 갖는 계층으로 대체되어 이 계층만 학습합니다.

# License: BSD
# Author: Sasank Chilamkurthy

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
from PIL import Image
from tempfile import TemporaryDirectory

cudnn.benchmark = True
plt.ion()   # 대화형 모드
<contextlib.ExitStack object at 0x7efd80764ee0>

데이터 불러오기

데이터를 불러오기 위해 torchvision과 torch.utils.data 패키지를 사용하겠습니다.

여기서 풀고자 하는 문제는 개미 을 분류하는 모델을 학습하는 것입니다. 개미와 벌 각각의 학습용 이미지는 대략 120장 정도 있고, 75개의 검증용 이미지가 있습니다. 일반적으로 맨 처음부터 학습을 한다면 이는 일반화하기에는 아주 작은 데이터셋입니다. 하지만 우리는 전이학습을 할 것이므로, 일반화를 제법 잘 할 수 있을 것입니다.

이 데이터셋은 ImageNet의 아주 작은 일부입니다.

참고

데이터를 여기 에서 다운로드 받아 현재 디렉토리에 압축을 푸십시오.

# 학습을 위해 데이터 증가(augmentation) 및 일반화(normalization)
# 검증을 위한 일반화
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

data_dir = 'data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

일부 이미지 시각화하기

데이터 증가를 이해하기 위해 일부 학습용 이미지를 시각화해보겠습니다.

def imshow(inp, title=None):
    """tensor를 입력받아 일반적인 이미지로 보여줍니다."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # 갱신이 될 때까지 잠시 기다립니다.


# 학습 데이터의 배치를 얻습니다.
inputs, classes = next(iter(dataloaders['train']))

# 배치로부터 격자 형태의 이미지를 만듭니다.
out = torchvision.utils.make_grid(inputs)

imshow(out, title=[class_names[x] for x in classes])
['ants', 'ants', 'ants', 'ants']

모델 학습하기

이제 모델을 학습하기 위한 일반 함수를 작성해보겠습니다. 여기서는 다음 내용들을 설명합니다:

  • 학습률(learning rate) 관리(scheduling)

  • 최적의 모델 구하기

아래에서 scheduler 매개변수는 torch.optim.lr_scheduler 의 LR 스케쥴러 객체(Object)입니다.

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    # Create a temporary directory to save training checkpoints
    with TemporaryDirectory() as tempdir:
        best_model_params_path = os.path.join(tempdir, 'best_model_params.pt')

        torch.save(model.state_dict(), best_model_params_path)
        best_acc = 0.0

        for epoch in range(num_epochs):
            print(f'Epoch {epoch}/{num_epochs - 1}')
            print('-' * 10)

            # 각 에폭(epoch)은 학습 단계와 검증 단계를 갖습니다.
            for phase in ['train', 'val']:
                if phase == 'train':
                    model.train()  # 모델을 학습 모드로 설정
                else:
                    model.eval()   # 모델을 평가 모드로 설정

                running_loss = 0.0
                running_corrects = 0

                # 데이터를 반복
                for inputs, labels in dataloaders[phase]:
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    # 매개변수 경사도를 0으로 설정
                    optimizer.zero_grad()

                    # 순전파
                    # 학습 시에만 연산 기록을 추적
                    with torch.set_grad_enabled(phase == 'train'):
                        outputs = model(inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = criterion(outputs, labels)

                        # 학습 단계인 경우 역전파 + 최적화
                        if phase == 'train':
                            loss.backward()
                            optimizer.step()

                    # 통계
                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += torch.sum(preds == labels.data)
                if phase == 'train':
                    scheduler.step()

                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = running_corrects.double() / dataset_sizes[phase]

                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

                # 모델을 깊은 복사(deep copy)함
                if phase == 'val' and epoch_acc > best_acc:
                    best_acc = epoch_acc
                    torch.save(model.state_dict(), best_model_params_path)

            print()

        time_elapsed = time.time() - since
        print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
        print(f'Best val Acc: {best_acc:4f}')

        # 가장 나은 모델 가중치를 불러오기
        model.load_state_dict(torch.load(best_model_params_path))
    return model

모델 예측값 시각화하기

일부 이미지에 대한 예측값을 보여주는 일반화된 함수입니다.

def visualize_model(model, num_images=6):
    was_training = model.training
    model.eval()
    images_so_far = 0
    fig = plt.figure()

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloaders['val']):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            for j in range(inputs.size()[0]):
                images_so_far += 1
                ax = plt.subplot(num_images//2, 2, images_so_far)
                ax.axis('off')
                ax.set_title(f'predicted: {class_names[preds[j]]}')
                imshow(inputs.cpu().data[j])

                if images_so_far == num_images:
                    model.train(mode=was_training)
                    return
        model.train(mode=was_training)

합성곱 신경망 미세조정(finetuning)

미리 학습한 모델을 불러온 후 마지막의 완전히 연결된 계층을 초기화합니다.

model_ft = models.resnet18(weights='IMAGENET1K_V1')
num_ftrs = model_ft.fc.in_features
# 여기서 각 출력 샘플의 크기는 2로 설정합니다.
# 또는, ``nn.Linear(num_ftrs, len (class_names))`` 로 일반화할 수 있습니다.
model_ft.fc = nn.Linear(num_ftrs, 2)

model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# 모든 매개변수들이 최적화되었는지 관찰
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# 7 에폭마다 0.1씩 학습률 감소
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth

  0%|          | 0.00/44.7M [00:00<?, ?B/s]
 13%|#3        | 6.00M/44.7M [00:00<00:00, 62.4MB/s]
 34%|###4      | 15.2M/44.7M [00:00<00:00, 82.2MB/s]
 54%|#####3    | 24.0M/44.7M [00:00<00:00, 83.2MB/s]
 79%|#######8  | 35.1M/44.7M [00:00<00:00, 87.3MB/s]
100%|##########| 44.7M/44.7M [00:00<00:00, 91.2MB/s]

학습 및 평가하기

CPU에서는 15-25분 가량, GPU에서는 1분 이내의 시간이 걸립니다.

model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=25)
Epoch 0/24
----------
train Loss: 0.4752 Acc: 0.7623
val Loss: 0.3116 Acc: 0.8431

Epoch 1/24
----------
train Loss: 0.5310 Acc: 0.7951
val Loss: 0.6484 Acc: 0.7320

Epoch 2/24
----------
train Loss: 0.4358 Acc: 0.8279
val Loss: 0.2984 Acc: 0.8954

Epoch 3/24
----------
train Loss: 0.6276 Acc: 0.7582
val Loss: 0.2483 Acc: 0.8824

Epoch 4/24
----------
train Loss: 0.4285 Acc: 0.8361
val Loss: 0.4031 Acc: 0.8301

Epoch 5/24
----------
train Loss: 0.4698 Acc: 0.8033
val Loss: 0.3513 Acc: 0.8627

Epoch 6/24
----------
train Loss: 0.3734 Acc: 0.8320
val Loss: 0.3077 Acc: 0.8824

Epoch 7/24
----------
train Loss: 0.3927 Acc: 0.8320
val Loss: 0.2228 Acc: 0.9216

Epoch 8/24
----------
train Loss: 0.2311 Acc: 0.9221
val Loss: 0.2048 Acc: 0.9346

Epoch 9/24
----------
train Loss: 0.2775 Acc: 0.8689
val Loss: 0.2213 Acc: 0.9412

Epoch 10/24
----------
train Loss: 0.3539 Acc: 0.8648
val Loss: 0.1914 Acc: 0.9346

Epoch 11/24
----------
train Loss: 0.3226 Acc: 0.8402
val Loss: 0.2916 Acc: 0.8824

Epoch 12/24
----------
train Loss: 0.2390 Acc: 0.8975
val Loss: 0.2258 Acc: 0.9281

Epoch 13/24
----------
train Loss: 0.2939 Acc: 0.8689
val Loss: 0.1980 Acc: 0.9346

Epoch 14/24
----------
train Loss: 0.2836 Acc: 0.8975
val Loss: 0.2552 Acc: 0.9085

Epoch 15/24
----------
train Loss: 0.2990 Acc: 0.8648
val Loss: 0.2777 Acc: 0.8693

Epoch 16/24
----------
train Loss: 0.2145 Acc: 0.9221
val Loss: 0.2124 Acc: 0.9150

Epoch 17/24
----------
train Loss: 0.2535 Acc: 0.8852
val Loss: 0.1935 Acc: 0.9346

Epoch 18/24
----------
train Loss: 0.2838 Acc: 0.8770
val Loss: 0.2387 Acc: 0.9020

Epoch 19/24
----------
train Loss: 0.2181 Acc: 0.9057
val Loss: 0.1948 Acc: 0.9216

Epoch 20/24
----------
train Loss: 0.2790 Acc: 0.8607
val Loss: 0.2142 Acc: 0.9216

Epoch 21/24
----------
train Loss: 0.2718 Acc: 0.8852
val Loss: 0.2594 Acc: 0.8889

Epoch 22/24
----------
train Loss: 0.3405 Acc: 0.8607
val Loss: 0.2186 Acc: 0.9216

Epoch 23/24
----------
train Loss: 0.2734 Acc: 0.8648
val Loss: 0.2258 Acc: 0.9281

Epoch 24/24
----------
train Loss: 0.3175 Acc: 0.8689
val Loss: 0.2012 Acc: 0.9281

Training complete in 0m 31s
Best val Acc: 0.941176
visualize_model(model_ft)
predicted: ants, predicted: bees, predicted: ants, predicted: bees, predicted: bees, predicted: ants

고정된 특징 추출기로써의 합성곱 신경망

이제, 마지막 계층을 제외한 신경망의 모든 부분을 고정해야 합니다. requires_grad = False 로 설정하여 매개변수를 고정하여 backward() 중에 경사도가 계산되지 않도록 해야합니다.

이에 대한 문서는 여기 에서 확인할 수 있습니다.

model_conv = torchvision.models.resnet18(weights='IMAGENET1K_V1')
for param in model_conv.parameters():
    param.requires_grad = False

# 새로 생성된 모듈의 매개변수는 기본값이 requires_grad=True 임
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)

model_conv = model_conv.to(device)

criterion = nn.CrossEntropyLoss()

# 이전과는 다르게 마지막 계층의 매개변수들만 최적화되는지 관찰
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

# 7 에폭마다 0.1씩 학습률 감소
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

학습 및 평가하기

CPU에서 실행하는 경우 이전과 비교했을 때 약 절반 가량의 시간만이 소요될 것입니다. 이는 대부분의 신경망에서 경사도를 계산할 필요가 없기 때문입니다. 하지만, 순전파는 계산이 필요합니다.

model_conv = train_model(model_conv, criterion, optimizer_conv,
                         exp_lr_scheduler, num_epochs=25)
Epoch 0/24
----------
train Loss: 0.6996 Acc: 0.6516
val Loss: 0.2014 Acc: 0.9346

Epoch 1/24
----------
train Loss: 0.4233 Acc: 0.8033
val Loss: 0.2656 Acc: 0.8758

Epoch 2/24
----------
train Loss: 0.4603 Acc: 0.7869
val Loss: 0.1847 Acc: 0.9477

Epoch 3/24
----------
train Loss: 0.3096 Acc: 0.8566
val Loss: 0.1747 Acc: 0.9477

Epoch 4/24
----------
train Loss: 0.4427 Acc: 0.8156
val Loss: 0.1631 Acc: 0.9477

Epoch 5/24
----------
train Loss: 0.5505 Acc: 0.7828
val Loss: 0.1643 Acc: 0.9477

Epoch 6/24
----------
train Loss: 0.3004 Acc: 0.8607
val Loss: 0.1744 Acc: 0.9542

Epoch 7/24
----------
train Loss: 0.4082 Acc: 0.8361
val Loss: 0.1893 Acc: 0.9412

Epoch 8/24
----------
train Loss: 0.4484 Acc: 0.7910
val Loss: 0.1984 Acc: 0.9477

Epoch 9/24
----------
train Loss: 0.3335 Acc: 0.8279
val Loss: 0.1942 Acc: 0.9412

Epoch 10/24
----------
train Loss: 0.2413 Acc: 0.8934
val Loss: 0.2001 Acc: 0.9477

Epoch 11/24
----------
train Loss: 0.3107 Acc: 0.8689
val Loss: 0.1801 Acc: 0.9412

Epoch 12/24
----------
train Loss: 0.3032 Acc: 0.8689
val Loss: 0.1669 Acc: 0.9477

Epoch 13/24
----------
train Loss: 0.3587 Acc: 0.8525
val Loss: 0.1900 Acc: 0.9477

Epoch 14/24
----------
train Loss: 0.2771 Acc: 0.8893
val Loss: 0.2317 Acc: 0.9216

Epoch 15/24
----------
train Loss: 0.3064 Acc: 0.8852
val Loss: 0.1909 Acc: 0.9477

Epoch 16/24
----------
train Loss: 0.4243 Acc: 0.8238
val Loss: 0.2227 Acc: 0.9346

Epoch 17/24
----------
train Loss: 0.3297 Acc: 0.8238
val Loss: 0.1917 Acc: 0.9412

Epoch 18/24
----------
train Loss: 0.4235 Acc: 0.8238
val Loss: 0.1766 Acc: 0.9477

Epoch 19/24
----------
train Loss: 0.2500 Acc: 0.8934
val Loss: 0.2003 Acc: 0.9477

Epoch 20/24
----------
train Loss: 0.2413 Acc: 0.8934
val Loss: 0.1820 Acc: 0.9477

Epoch 21/24
----------
train Loss: 0.3762 Acc: 0.8115
val Loss: 0.1842 Acc: 0.9412

Epoch 22/24
----------
train Loss: 0.3484 Acc: 0.8566
val Loss: 0.2166 Acc: 0.9281

Epoch 23/24
----------
train Loss: 0.3626 Acc: 0.8361
val Loss: 0.1747 Acc: 0.9412

Epoch 24/24
----------
train Loss: 0.3840 Acc: 0.8320
val Loss: 0.1767 Acc: 0.9412

Training complete in 0m 22s
Best val Acc: 0.954248
visualize_model(model_conv)

plt.ioff()
plt.show()
predicted: bees, predicted: ants, predicted: bees, predicted: bees, predicted: ants, predicted: ants

다른 이미지들에 대한 추론

학습된 모델을 사용하여 사용자 지정 이미지에 대해 예측하고, 예측된 클래스 레이블을 이미지와 함께 시각화합니다.

def visualize_model_predictions(model,img_path):
    was_training = model.training
    model.eval()

    img = Image.open(img_path)
    img = data_transforms['val'](img)
    img = img.unsqueeze(0)
    img = img.to(device)

    with torch.no_grad():
        outputs = model(img)
        _, preds = torch.max(outputs, 1)

        ax = plt.subplot(2,2,1)
        ax.axis('off')
        ax.set_title(f'Predicted: {class_names[preds[0]]}')
        imshow(img.cpu().data[0])

        model.train(mode=was_training)
visualize_model_predictions(
    model_conv,
    img_path='data/hymenoptera_data/val/bees/72100438_73de9f17af.jpg'
)

plt.ioff()
plt.show()
Predicted: bees

더 배워볼 내용

전이학습의 응용 사례(application)들을 더 알아보려면, (베타) 컴퓨터 비전 튜토리얼을 위한 양자화된 전이학습(Quantized Transfer Learning) 을 참조해보세요.

Total running time of the script: ( 0 minutes 55.854 seconds)

Gallery generated by Sphinx-Gallery


더 궁금하시거나 개선할 내용이 있으신가요? 커뮤니티에 참여해보세요!


이 튜토리얼이 어떠셨나요? 평가해주시면 이후 개선에 참고하겠습니다! :)

© Copyright 2018-2024, PyTorch & 파이토치 한국 사용자 모임(PyTorch Korea User Group).

Built with Sphinx using a theme provided by Read the Docs.

PyTorchKorea @ GitHub

파이토치 한국 사용자 모임을 GitHub에서 만나보세요.

GitHub로 이동

한국어 튜토리얼

한국어로 번역 중인 PyTorch 튜토리얼입니다.

튜토리얼로 이동

커뮤니티

다른 사용자들과 의견을 나누고, 도와주세요!

커뮤니티로 이동