• Tutorials >
  • 공간 변형 네트워크(Spatial Transformer Networks) 튜토리얼
Shortcuts

공간 변형 네트워크(Spatial Transformer Networks) 튜토리얼

저자: Ghassen HAMROUNI 번역: 황성수 , 정신유 .. figure:: /_static/img/stn/FSeq.png 이 튜토리얼에서는 공간 변형 네트워크(spatial transformer networks, 이하 STN)이라 불리는 비주얼 어텐션 메커니즘을 이용해 신경망을 증강(augment)시키는 방법에 대해 학습합니다. 이 방법에 대한 자세한 내용은 DeepMind paper 에서 확인할 수 있습니다. STN은 어떠한 공간적 변형(spatial transformation)에도 적용할 수 있는 미분 가능한 어텐션의 일반화입니다. 따라서 STN은 신경망의 기하학적 불변성(geometric invariance)을 강화하기 위해 입력 이미지를 대상으로 어떠한 공간적 변형을 수행해야 하는지 학습하도록 합니다. 예를 들어 이미지의 관심 영역을 잘라내거나, 크기를 조정하거나, 방향(orientation)을 수정할 수 있습니다. CNN은 이러한 회전, 크기 조정 등의 일반적인 아핀(affine) 변환된 입력에 대해 결과의 변동이 크기 때문에 (민감하기 때문에), STN은 이를 극복하는데 매우 유용한 메커니즘이 될 수 있습니다. STN이 가진 장점 중 하나는 아주 작은 수정만으로 기존에 사용하던 CNN에 간단하게 연결 시킬 수 있다는 것입니다.

# 라이센스: BSD
# 저자: Ghassen Hamrouni

from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

plt.ion()   # 대화형 모드

데이터 불러오기

이 튜토리얼에서는 MNIST 데이터셋을 이용해 실험합니다. 실험에는 STN으로 증강된 일반적인 CNN을 사용합니다.

from six.moves import urllib
opener = urllib.request.build_opener()
opener.addheaders = [('User-agent', 'Mozilla/5.0')]
urllib.request.install_opener(opener)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 학습용 데이터셋
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])), batch_size=64, shuffle=True, num_workers=4)
# 테스트용 데이터셋
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])), batch_size=64, shuffle=True, num_workers=4)

Out:

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz
Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw

Spatial Transformer Networks(STN) 구성하기

STN은 다음의 세 가지 주요 구성 요소로 요약됩니다.

  • 위치 결정 네트워크(localization network)는 공간 변환 파라미터를 예측(regress) 하는 일반적인 CNN 입니다. 공간 변환은 데이터 셋으로부터 명시적으로 학습되지 않고, 신경망이 전체 정확도를 향상 시키도록 공간 변환을 자동으로 학습합니다.

  • 그리드 생성기(grid generator)는 출력 이미지로부터 각 픽셀에 대응하는 입력 이미지 내 좌표 그리드를 생성합니다.

  • 샘플러(sampler)는 공간 변환 파라미터를 입력 이미지에 적용합니다.

../_images/stn-arch.png

Note

affine_grid 및 grid_sample 모듈이 포함된 최신 버전의 PyTorch가 필요합니다.

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

        # 공간 변환을 위한 위치 결정 네트워크 (localization-network)
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )

        # [3 * 2] 크기의 아핀(affine) 행렬에 대해 예측
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 3 * 3, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)
        )

        # 항등 변환(identity transformation)으로 가중치/바이어스 초기화
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    # STN의 forward 함수
    def stn(self, x):
        xs = self.localization(x)
        xs = xs.view(-1, 10 * 3 * 3)
        theta = self.fc_loc(xs)
        theta = theta.view(-1, 2, 3)

        grid = F.affine_grid(theta, x.size())
        x = F.grid_sample(x, grid)

        return x

    def forward(self, x):
        # 입력을 변환
        x = self.stn(x)

        # 일반적인 forward pass를 수행
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


model = Net().to(device)

모델 학습하기

이제 SGD 알고리즘을 이용해 모델을 학습시켜 봅시다. 앞서 구성한 신경망은 감독 학습 방식(supervised way)으로 분류 문제를 학습합니다. 또한 이 모델은 end-to-end 방식으로 STN을 자동으로 학습합니다.

optimizer = optim.SGD(model.parameters(), lr=0.01)


def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 500 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
#
# MNIST 데이터셋에서 STN의 성능을 측정하기 위한 간단한 테스트 절차
#


def test():
    with torch.no_grad():
        model.eval()
        test_loss = 0
        correct = 0
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)

            # 배치 손실 합하기
            test_loss += F.nll_loss(output, target, size_average=False).item()
            # 로그-확률의 최대값에 해당하는 인덱스 가져오기
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(test_loader.dataset)
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
              .format(test_loss, correct, len(test_loader.dataset),
                      100. * correct / len(test_loader.dataset)))

STN 결과 시각화하기

이제 학습한 비주얼 어텐션 메커니즘의 결과를 살펴보겠습니다.

학습하는 동안 변환된 결과를 시각화하기 위해 작은 도움(helper) 함수를 정의합니다.

def convert_image_np(inp):
    """Convert a Tensor to numpy image."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    return inp

# 학습 후 공간 변환 계층의 출력을 시각화하고, 입력 이미지 배치 데이터 및
# STN을 사용해 변환된 배치 데이터를 시각화 합니다.


def visualize_stn():
    with torch.no_grad():
        # 학습 데이터의 배치 가져오기
        data = next(iter(test_loader))[0].to(device)

        input_tensor = data.cpu()
        transformed_input_tensor = model.stn(data).cpu()

        in_grid = convert_image_np(
            torchvision.utils.make_grid(input_tensor))

        out_grid = convert_image_np(
            torchvision.utils.make_grid(transformed_input_tensor))

        # 결과를 나란히 표시하기
        f, axarr = plt.subplots(1, 2)
        axarr[0].imshow(in_grid)
        axarr[0].set_title('Dataset Images')

        axarr[1].imshow(out_grid)
        axarr[1].set_title('Transformed Images')

for epoch in range(1, 20 + 1):
    train(epoch)
    test()

# 일부 입력 배치 데이터에서 STN 변환 결과를 시각화
visualize_stn()

plt.ioff()
plt.show()
../_images/sphx_glr_spatial_transformer_tutorial_001.png

Out:

Train Epoch: 1 [0/60000 (0%)]   Loss: 2.339093
Train Epoch: 1 [32000/60000 (53%)]      Loss: 0.806752

Test set: Average loss: 0.2708, Accuracy: 9280/10000 (93%)

Train Epoch: 2 [0/60000 (0%)]   Loss: 0.677045
Train Epoch: 2 [32000/60000 (53%)]      Loss: 0.317856

Test set: Average loss: 0.1509, Accuracy: 9547/10000 (95%)

Train Epoch: 3 [0/60000 (0%)]   Loss: 0.134617
Train Epoch: 3 [32000/60000 (53%)]      Loss: 0.286342

Test set: Average loss: 0.1011, Accuracy: 9708/10000 (97%)

Train Epoch: 4 [0/60000 (0%)]   Loss: 0.180843
Train Epoch: 4 [32000/60000 (53%)]      Loss: 0.181035

Test set: Average loss: 0.0821, Accuracy: 9753/10000 (98%)

Train Epoch: 5 [0/60000 (0%)]   Loss: 0.141197
Train Epoch: 5 [32000/60000 (53%)]      Loss: 0.101543

Test set: Average loss: 0.0838, Accuracy: 9735/10000 (97%)

Train Epoch: 6 [0/60000 (0%)]   Loss: 0.128471
Train Epoch: 6 [32000/60000 (53%)]      Loss: 0.351965

Test set: Average loss: 0.0978, Accuracy: 9721/10000 (97%)

Train Epoch: 7 [0/60000 (0%)]   Loss: 0.133197
Train Epoch: 7 [32000/60000 (53%)]      Loss: 0.160272

Test set: Average loss: 0.0579, Accuracy: 9815/10000 (98%)

Train Epoch: 8 [0/60000 (0%)]   Loss: 0.223597
Train Epoch: 8 [32000/60000 (53%)]      Loss: 0.311896

Test set: Average loss: 0.0641, Accuracy: 9807/10000 (98%)

Train Epoch: 9 [0/60000 (0%)]   Loss: 0.196691
Train Epoch: 9 [32000/60000 (53%)]      Loss: 0.056436

Test set: Average loss: 0.0523, Accuracy: 9834/10000 (98%)

Train Epoch: 10 [0/60000 (0%)]  Loss: 0.096855
Train Epoch: 10 [32000/60000 (53%)]     Loss: 0.121499

Test set: Average loss: 0.0776, Accuracy: 9772/10000 (98%)

Train Epoch: 11 [0/60000 (0%)]  Loss: 0.085548
Train Epoch: 11 [32000/60000 (53%)]     Loss: 0.296922

Test set: Average loss: 0.0437, Accuracy: 9869/10000 (99%)

Train Epoch: 12 [0/60000 (0%)]  Loss: 0.248544
Train Epoch: 12 [32000/60000 (53%)]     Loss: 0.101525

Test set: Average loss: 0.0446, Accuracy: 9870/10000 (99%)

Train Epoch: 13 [0/60000 (0%)]  Loss: 0.044299
Train Epoch: 13 [32000/60000 (53%)]     Loss: 0.015816

Test set: Average loss: 0.0427, Accuracy: 9877/10000 (99%)

Train Epoch: 14 [0/60000 (0%)]  Loss: 0.263352
Train Epoch: 14 [32000/60000 (53%)]     Loss: 0.327864

Test set: Average loss: 0.0713, Accuracy: 9782/10000 (98%)

Train Epoch: 15 [0/60000 (0%)]  Loss: 0.116053
Train Epoch: 15 [32000/60000 (53%)]     Loss: 0.043850

Test set: Average loss: 0.0459, Accuracy: 9862/10000 (99%)

Train Epoch: 16 [0/60000 (0%)]  Loss: 0.344532
Train Epoch: 16 [32000/60000 (53%)]     Loss: 0.140879

Test set: Average loss: 0.0480, Accuracy: 9856/10000 (99%)

Train Epoch: 17 [0/60000 (0%)]  Loss: 0.050756
Train Epoch: 17 [32000/60000 (53%)]     Loss: 0.121107

Test set: Average loss: 0.0620, Accuracy: 9798/10000 (98%)

Train Epoch: 18 [0/60000 (0%)]  Loss: 0.055140
Train Epoch: 18 [32000/60000 (53%)]     Loss: 0.030762

Test set: Average loss: 0.0417, Accuracy: 9874/10000 (99%)

Train Epoch: 19 [0/60000 (0%)]  Loss: 0.107585
Train Epoch: 19 [32000/60000 (53%)]     Loss: 0.062222

Test set: Average loss: 0.0372, Accuracy: 9885/10000 (99%)

Train Epoch: 20 [0/60000 (0%)]  Loss: 0.113024
Train Epoch: 20 [32000/60000 (53%)]     Loss: 0.096670

Test set: Average loss: 0.0472, Accuracy: 9864/10000 (99%)

Total running time of the script: ( 2 minutes 40.842 seconds)

Gallery generated by Sphinx-Gallery


이 튜토리얼이 어떠셨나요?

© Copyright 2021, PyTorch & PyTorch Korea Community.

Built with Sphinx using a theme provided by Read the Docs.

PyTorchKorea @ GitHub

파이토치 한국 사용자 모임을 GitHub에서 만나보세요.

GitHub로 이동

한국어 튜토리얼

한국어로 번역 중인 PyTorch 튜토리얼입니다.

튜토리얼로 이동

커뮤니티

다른 사용자들과 의견을 나누고, 도와주세요!

커뮤니티로 이동