Shortcuts

DCGAN 튜토리얼

저자: Nathan Inkawhich 번역: 조민성

개요

본 튜토리얼에서는 예제를 통해 DCGAN을 알아보겠습니다. 우리는 실제 유명인들의 사진들로 적대적 생성 신경망(GAN)을 학습시켜, 새로운 유명인의 사진을 만들어보겠습니다. 사용할 대부분의 코드는 pytorch/examples 의 DCGAN 구현에서 가져왔으며, 본 문서는 구현에 대한 설명과 함께, 어째서 이 모델이 작동하는지에 대해 설명을 해줄 것입니다. 처음 읽었을때는, 실제로 모델에 무슨일이 일어나고 있는지에 대해 이해하는 것이 조금 시간을 소요할 수 있으나, 그래도 GAN에 대한 사전지식이 필요하지는 않으니 걱정하지 않으셔도 됩니다. 추가로, GPU 1-2개를 사용하는 것이 시간절약에 도움이 될겁니다. 그럼 처음부터 천천히 시작해봅시다!

적대적 생성 신경망(Generative Adversarial Networks)

그래서 GAN이 뭘까요?

GAN이란 학습 데이터들의 분포를 학습한 뒤, 동일한 분포를 갖는 새로운 데이터를 생성하도록 딥러닝 모델을 학습시키는 프레임워크입니다. GAN은 2014년 Ian Goodfellow가 개발했으며, Generative Adversarial Nets 논문에서 처음 소개되었습니다. GAN은 생성자(Generator)구분자(Discriminator) 라는 두 개의 서로 다른(distinct) 모델들로 구성되어 있습니다. 생성자(Generator)의 역할은 학습한 이미지들과 같아 보이는 가짜(fake) 이미지를 만드는 것이고, 구분자(Discriminator)는 이미지를 보고 이것이 실제 학습 데이터에서 가져온 것인지, 또는 생성자에 의해 만들어진 가짜 이미지인지 판별하는 것입니다. 모델을 학습하는 동안 생성자는 더 진짜 같은 가짜 이미지를 만들어내며 구분자를 속이려 하고, 구분자는 진짜 이미지와 가짜 이미지를 더 정확히 판별할 수 있도록 노력합니다. 이러한 과정은 생성자가 마치 학습 데이터에서 가져온 것처럼 보이는 완벽한 가짜 이미지를 생성해내고, 판별자는 항상 50%의 신뢰도로 생성자의 출력이 진짜인지 가짜인지 판별할 수 있을 때 균형 상태(equilbrium)에 도달하게 됩니다.

그럼 이제부터 본 튜토리얼에서 사용할 표기들을 구분자부터 정의해보겠습니다. \(x\) 는 이미지로 표현되는 데이터라고 하겠습니다. \(D(x)\) 는 구분자의 신경망을 나타내며, 실제 학습 데이터에서 가져온 \(x\) 를 통과시켜 확률 값(scalar)을 결과로 출력합니다. 여기에서는 이미지 데이터를 다루고 있으므로, \(D(x)\) 의 입력으로는 3x64x64 크기의 CHW 이미지가 주어집니다. 직관적으로 \(D(x)\)\(x\) 가 학습 데이터에서 가져왔을 때 출력이 크고(HIGH), 생성자가 만들어낸 \(x\) 일 때는 작을(LOW) 것입니다. \(D(x)\) 는 전통적인 이진 분류기(binary classification)로도 생각할 수도 있습니다.

이번엔 생성자의 표기들을 살펴보겠습니다. \(z\) 를 정규분포에서 뽑은 잠재공간 벡터(laten space vector)라고 하겠습니다 (번역 주. laten space vector는 쉽게 생각해 정규분포를 따르는 n개의 원소를 가진 vector라 볼 수 있습니다. 다르게 얘기하면 정규분포에서 n개의 원소를 추출한 것과 같습니다). \(G(z)\)\(z\) 벡터를 원하는 데이터 차원으로 대응시키는 신경망으로 둘 수 있습니다. 이때 \(G\) 의 목적은 \(p_{data}\) 에서 얻을 수 있는 학습 데이터들의 분포를 추정하여, 모사한 \(p_g\) 의 분포를 이용해 가짜 데이터들을 만드는 것입니다.

이어서, \(D(G(z))\)\(G\) 가 출력한 결과물이 실제 이미지 여부를 나타내는 0~1 사이의 확률 값(scalar)입니다. Goodfellow의 논문 에 기술되어 있듯이, \(D\)\(G\) 는 일종의 최대-최소 게임(minimax game)을 하고 있는 것과 같습니다. 이는 \(D\) 는 이미지가 진짜인지 가짜인지 여부를 판별하는 확률인 \(logD(x)\) 를 최대화하려고 하고, \(G\)\(D\) 가 가짜라고 판별할 확률인 \(log(1-D(G(z)))\) 를 최소화시키려고 하기 때문입니다. 논문에 따르면, GAN의 손실함수는 아래와 같습니다.

\[\underset{G}{\text{min}} \underset{D}{\text{max}}V(D,G) = \mathbb{E}_{x\sim p_{data}(x)}\big[logD(x)\big] + \mathbb{E}_{z\sim p_{z}(z)}\big[log(1-D(G(z)))\big] \]

이론적으로는, 이 최대-최소 게임의 답(solution)은 \(p_g = p_{data}\) 일 때이며, 이 때 구분자는 입력이 진짜인지 가짜인지를 무작위로 추측하게 됩니다. 하지만 GAN의 수렴 이론(convergence theory)에 대해서는 아직도 활발히 연구가 진행 중이며, 실제 모델들을 학습할 때에는 항상 이러한 이론적인 최적 상태에 도달하지는 못합니다.

그렇다면 DCGAN은 뭘까요?

DCGAN은 위에서 기술한 GAN에서 직접적으로 파생된 모델로, 생성자와 구분자에서 합성곱 신경망(convolution)과 전치 합성곱 신경망(convolution-transpose)을 사용했다는 것이 차이점입니다 Radford와 그 외가 저술한 Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks 논문에서 처음 모델이 소개되었고, 지금은 대부분의 GAN모델이 DCGAN을 기반으로 만들어지는 중입니다. 이전 GAN과 모델의 구조가 실제로 어떻게 다른지 확인을 해보자면, 먼저 구분자에서는 convolution 계층, batch norm 계층, 그리고 LeakyReLU 활성함수가 사용되었습니다. 클래식한 GAN과 마찬가지로, 구분자의 입력 데이터는 3x64x64 의 이미지이고, 출력값은 입력 데이터가 실제 데이터일 0~1사이의 확률값입니다. 다음으로, 생성자는 convolutional-transpose 계층, 배치 정규화(batch norm) 계층, 그리고 ReLU 활성함수가 사용되었습니다. 입력값은 역시나 정규분포에서 추출한 잠재공간 벡터 \(z\) 이고, 출력값은 3x64x64 RGB 이미지입니다. 이 때, 전치 합성곱 계층(strided conv-transpose layer)은 잠재공간 벡터로 하여금 이미지와 같은 차원을 갖도록 변환시켜주는 역할을 합니다. (번역 주. 전치 합성곱 신경망은 합성곱 신경망의 반대적인 개념이라 이해하면 쉽습니다. 입력된 작은 CHW 데이터를 가중치들을 이용해 더 큰 CHW로 업샘플링해주는 계층입니다.) 논문에서는 각종 최적화 방법이나 손실함수의 계산, 모델의 가중치 초기화 방법등에 관한 추가적인 정보들도 적어두었는데, 이 부분은 다음 섹션에서 설명하도록 하겠습니다.

#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# 코드 실행결과의 동일성을 위해 무작위 시드를 설정합니다
manualSeed = 999
#manualSeed = random.randint(1, 10000) # 만일 새로운 결과를 원한다면 주석을 없애면 됩니다
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
torch.use_deterministic_algorithms(True) # 결과 재현을 위해 필요합니다
Random Seed:  999

설정 값

몇 가지 설정 값들을 살펴보겠습니다:

  • dataroot - 데이터셋 폴더의 경로입니다. 데이터셋에 대해서는 다음 섹션에서 더 자세히 설명하겠습니다.

  • workers - DataLoader 에서 데이터를 불러올 때 사용할 워커 쓰레드의 수입니다.

  • batch_size - 학습에 사용할 배치 크기입니다. DCGAN에서는 배치 크기를 128으로 사용했습니다.

  • image_size - 학습에 사용하는 이미지의 크기입니다. 이 튜토리얼에서는 64x64의 크기를 기본으로 하나, 만일 다른 크기의 이미지를 사용한다면 D와 G의 구조 또한 변경되어야 합니다. 이에 대해서는 여기 를 참고하여 더 자세한 정보를 확인할 수 있습니다.

  • nc - 입력 이미지의 색상의 채널 수입니다. RGB 컬러 이미지의 경우 이 값은 3입니다.

  • nz - 잠재공간 벡터의 원소들의 수입니다.

  • ngf - 생성자를 통과할 때 만들어질 특징 데이터의 채널 수입니다.

  • ndf - 구분자를 통과할 때 만들어질 특징 데이터의 채널 수입니다.

  • num_epochs - 학습시킬 에폭(epoch) 수입니다. 학습을 길게하는 경우 대부분 좋은 결과를 보이지만, 이러한 경우 시간 또한 오래 걸립니다.

  • lr - 모델의 학습률(learning rate)입니다. DCGAN 논문에서와 같이 0.0002로 설정합니다.

  • beta1 - Adam 옵티마이저에서 사용할 beta1 하이퍼파라미터 값입니다. 논문에서와 같이 0.5로 설정했습니다.

  • ngpu - 사용 가능한 GPU의 개수입니다. 0인 경우에는 코드는 CPU에서 동작합니다. 만약 이 값이 0보다 큰 경우에는 주어진 수 만큼의 GPU를 사용하여 학습을 진행합니다.

# 데이터셋의 경로
dataroot = "data/celeba"

# dataloader에서 사용할 쓰레드 수
workers = 2

# 배치 크기
batch_size = 128

# 이미지의 크기입니다. 모든 이미지를 변환하여 64로 크기가 통일됩니다.
image_size = 64

# 이미지의 채널 수로, RGB 이미지이기 때문에 3으로 설정합니다.
nc = 3

# 잠재공간 벡터의 크기 (예. 생성자의 입력값 크기)
nz = 100

# 생성자를 통과하는 특징 데이터들의 채널 크기
ngf = 64

# 구분자를 통과하는 특징 데이터들의 채널 크기
ndf = 64

# 학습할 에폭 수
num_epochs = 5

# 옵티마이저의 학습률
lr = 0.0002

# Adam 옵티마이저의 beta1 하이퍼파라미터
beta1 = 0.5

# 사용가능한 gpu 번호. CPU를 사용해야 하는경우 0으로 설정하세요
ngpu = 1

데이터

본 튜토리얼에서 사용할 데이터는 Celeb-A Faces dataset 로, 해당 링크를 이용하거나 Google Drive 에서 데이터를 받을 수 있습니다. 데이터를 받으면 img_align_celeba.zip 라는 파일을 보게될 겁니다. 다운로드가 끝나면 celeba 이라는 폴더를 새로 만들고, 해당 폴더에 해당 zip 파일을 압축해제 해주시면 됩니다. 압축 해제 후, 위에서 정의한 dataroot 변수에 방금 만든 celeba 폴더의 경로를 넣어주세요. 위의 작업이 끝나면 celeba 폴더의 구조는 다음과 같아야 합니다:

/path/to/celeba
    -> img_align_celeba
        -> 188242.jpg
        -> 173822.jpg
        -> 284702.jpg
        -> 537394.jpg
           ...

이 과정들은 프로그램이 정상적으로 구동하기 위해서는 중요한 부분입니다. 이때 celeba 폴더 안에 다시 폴더를 두는 이유는, ImageFolder 클래스가 데이터셋의 최상위 폴더에 서브폴더를 요구하기 때문입니다. 이제 DatasetDataLoader 의 설정을 끝냈습니다. 최종적으로 학습 데이터들을 시각화해봅시다.

# 우리가 설정한 대로 이미지 데이터셋을 불러와 봅시다
# 먼저 데이터셋을 만듭니다
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
# dataloader를 정의해봅시다
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

# GPU 사용여부를 결정해 줍니다
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# 학습 데이터들 중 몇가지 이미지들을 화면에 띄워봅시다
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
plt.show()
Training Images

구현

모델의 설정값들과 데이터들이 준비되었기 때문에, 드디어 모델의 구현으로 들어갈 수 있을 것 같습니다. 먼저 가중치 초기화에 대해 이야기 해보고, 순서대로 생성자, 구분자, 손실 함수, 학습 방법들을 알아보겠습니다.

가중치 초기화

DCGAN 논문에서는, 평균이 0( mean=0 )이고 분산이 0.02( stdev=0.02 )인 정규분포을 시용해, 구분자와 생성자 모두 무작위 초기화를 진행하는 것이 좋다고 합니다. weights_init 함수는 매개변수로 모델을 입력받아, 모든 합성곱 계층, 전치 합성곱 계층, 배치 정규화 계층을, 위에서 말한 조건대로 가중치들을 다시 초기화 시킵니다. 이 함수는 모델이 만들어지자 마자 바로 적용을 시키게 됩니다.

# ``netG`` 와 ``netD`` 에 적용시킬 커스텀 가중치 초기화 함수
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

생성자

생성자 \(G\) 는 잠재 공간 벡터 \(z\) 를, 데이터 공간으로 변환시키도록 설계되었습니다. 우리에게 데이터라 함은 이미지이기 때문에, \(z\) 를 데이터공간으로 변환한다는 뜻은, 학습이미지와 같은 사이즈를 가진 RGB 이미지를 생성하는것과 같습니다 (예. 3x64x64). 실제 모델에서는 스트라이드(stride) 2를 가진 전치 합성곱 계층들을 이어서 구성하는데, 각 전치 합성곱 계층 하나당 2차원 배치 정규화 계층과 relu 활성함수를 한 쌍으로 묶어서 사용합니다. 생성자의 마지막 출력 계층에서는 데이터를 tanh 함수에 통과시키는데, 이는 출력 값을 \([-1,1]\) 사이의 범위로 조정하기 위해서 입니다. 이때 배치 정규화 계층을 주목할 필요가 있는데, DCGAN 논문에 의하면, 이 계층이 경사하강법(gradient-descent)의 흐름에 중요한 영향을 미치는 것으로 알려져 있습니다. 아래의 그림은 DCGAN 논문에서 가져온 생성자의 모델 아키텍쳐입니다.

dcgan_generator

우리가 설정값 섹션에서 정의한 값들이 (nz, ngf, 그리고 nc) 생성자 모델 아키텍쳐에 어떻게 영향을 끼치는지 주목해주세요. nz 는 z 입력 벡터의 길이, ngf 는 생성자를 통과하는 특징 데이터의 크기, 그리고 nc 는 출력 이미지의 채널 개수입니다 (RGB 이미지이기 때문에 3으로 설정을 했습니다). 아래는 생성자의 코드입니다.

# 생성자 코드

class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # 입력데이터 Z가 가장 처음 통과하는 전치 합성곱 계층입니다.
            nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # 위의 계층을 통과한 데이터의 크기. ``(ngf*8) x 4 x 4``
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # 위의 계층을 통과한 데이터의 크기. ``(ngf*4) x 8 x 8``
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # 위의 계층을 통과한 데이터의 크기. ``(ngf*2) x 16 x 16``
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # 위의 계층을 통과한 데이터의 크기. ``(ngf) x 32 x 32``
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # 위의 계층을 통과한 데이터의 크기. ``(nc) x 64 x 64``
        )

    def forward(self, input):
        return self.main(input)

좋습니다. 이제 우리는 생성자의 인스턴스를 만들고 weights_init 함수를 적용시킬 수 있습니다. 모델의 인스턴스를 출력해서 생성자가 어떻게 구성되어있는지 확인해봅시다.

# 생성자를 만듭니다
netG = Generator(ngpu).to(device)

# 필요한 경우 multi-GPU를 설정 해주세요
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))

# 모든 가중치의 평균을 0( ``mean=0`` ), 분산을 0.02( ``stdev=0.02`` )로 초기화하기 위해
# ``weight_init`` 함수를 적용시킵니다
netG.apply(weights_init)

# 모델의 구조를 출력합니다
print(netG)
Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)

구분자

앞서 언급했듯, 구분자 \(D\) 는 입력 이미지가 진짜 이미지인지 (혹은 반대로 가짜 이미지인지) 판별하는 전통적인 이진 분류 신경망으로 볼 수 있습니다. 이때 \(D\) 는 3x64x64 이미지를 입력받아, Conv2d, BatchNorm2d, 그리고 LeakyReLU 계층을 통과시켜 데이터를 가공시키고, 마지막 출력에서 Sigmoid 함수를 이용하여 0~1 사이의 확률값으로 조정합니다. 이 아키텍쳐는 필요한 경우 더 다양한 레이어를 쌓을 수 있지만, 배치 정규화와 LeakyReLU, 특히 보폭이 있는 (strided) 합성곱 계층을 사용하는 것에는 이유가 있습니다. DCGAN 논문에서는 보폭이 있는 합성곱 계층을 사용하는 것이 신경망 내에서 스스로의 풀링(Pooling) 함수를 학습하기 때문에, 데이터를 처리하는 과정에서 직접적으로 풀링 계층( MaxPool or AvgPooling)을 사용하는 것보다 더 유리하다고 합니다. 또한 배치 정규화와 leaky relu 함수는 학습과정에서 \(G\)\(D\) 가 더 효과적인 경사도(gradient)를 얻을 수 있습니다.

# 구분자 코드

class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # 입력 데이터의 크기는 ``(nc) x 64 x 64`` 입니다
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 위의 계층을 통과한 데이터의 크기. ``(ndf) x 32 x 32``
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 위의 계층을 통과한 데이터의 크기. ``(ndf*2) x 16 x 16``
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 위의 계층을 통과한 데이터의 크기. ``(ndf*4) x 8 x 8``
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # 위의 계층을 통과한 데이터의 크기. ``(ndf*8) x 4 x 4``
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)

이제 우리는 생성자에 한 것처럼 구분자의 인스턴스를 만들고, weights_init 함수를 적용시킨 다음, 모델의 구조를 출력해볼 수 있습니다.

# 구분자를 만듭니다
netD = Discriminator(ngpu).to(device)

# 필요한 경우 multi-GPU를 설정 해주세요
if (device.type == 'cuda') and (ngpu > 1):
    netD = nn.DataParallel(netD, list(range(ngpu)))

# 모든 가중치의 평균을 0( ``mean=0`` ), 분산을 0.02( ``stdev=0.02`` )로 초기화하기 위해
# ``weight_init`` 함수를 적용시킵니다
netD.apply(weights_init)

# 모델의 구조를 출력합니다
print(netD)
Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid()
  )
)

손실함수와 옵티마이저

\(D\)\(G\) 의 설정을 끝냈으니, 이제 손실함수와 옵티마이저를 정하여 학습을 구체화시킬 시간입니다. 손실함수로는 Binary Cross Entropy loss (BCELoss) 를 사용할겁니다. 해당함수는 아래의 식으로 파이토치에 구현되어 있습니다:

\[\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad l_n = - \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right] \]

이때, 위의 함수가 로그함수 요소를 정의한 방식을 주의깊게 봐주세요 (예. \(log(D(x))\)\(log(1-D(G(z)))\)). 우린 \(y\) 을 조정을 조정하여, BCE 함수에서 사용할 요소를 고를 수 있습니다. 이 부분은 이후에 서술할 학습 섹션에서 다루겠지만, 어떻게 \(y\) 를 이용하여 우리가 원하는 요소들만 골라낼 수 있는지 이해하는 것이 먼저입니다 (예. GT labels).

좋습니다. 다음으로 넘어가겠습니다. 참 라벨 (혹은 정답)은 1로 두고, 거짓 라벨 (혹은 오답)은 0으로 두겠습니다. 각 라벨의 값을 정한건 GAN 논문에서 사용된 값들로, GAN을 구성할때의 관례라 할 수 있습니다. 방금 정한 라벨 값들은 추후에 손실값을 계산하는 과정에서 사용될겁니다. 마지막으로, 서로 구분되는 두 옵티마이저를 구성하겠습니다. 하나는 \(D\) 를 위한 것, 다른 하나는 \(G\) 를 위한 것입니다. DCGAN에 서술된 대로, 두 옵티마이저는 모두 Adam을 사용하고, 학습률은 0.0002, Beta1 값은 0.5로 둡니다. 추가적으로, 학습이 진행되는 동안 생성자의 상태를 알아보기 위하여, 프로그램이 끝날때까지 고정된 잠재공간 벡터를 생성하겠습니다 (예. fixed_noise). 이 벡터들 역시 가우시안 분포에서 추출합니다. 학습 과정을 반복하면서 \(G\) 에 주기적으로 같은 잠재공간 벡터를 입력하면, 그 출력값을 기반으로 생성자의 상태를 확인 할 수 있습니다.

# ``BCELoss`` 함수의 인스턴스를 초기화합니다
criterion = nn.BCELoss()

# 생성자의 학습상태를 확인할 잠재 공간 벡터를 생성합니다
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# 학습에 사용되는 참/거짓의 라벨을 정합니다
real_label = 1.
fake_label = 0.

# G와 D에서 사용할 Adam옵티마이저를 생성합니다
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

학습

드디어 최종입니다. GAN 프레임워크에 필요한 부분들은 모두 가졌으니, 실제 모델을 학습시키는 방법을 알아보겠습니다. 주의를 기울일 것은, GAN을 학습시키는 건 관례적인 기술들의 집합이기 때문에, 잘못된 하이퍼파라미터의 설정은 모델의 학습을 망가뜨릴 수 있습니다. 무엇이 잘못되었는지 알아내는 것 조차도 힘들죠. 그러한 이유로, 본 튜토리얼에서는 Goodfellow’s paper 에서 서술된 Algorithm 1을 기반으로, ganhacks 에서 사용된 몇가지 괜찮은 테크닉들을 더할 것입니다. 앞서 몇번 설명했지만, 우리의 의도는 “진짜 혹은 가짜 이미지를 구성”하고, \(log(D(G(z)))\) 를 최대화하는 G의 목적함수를 최적화 시키는 겁니다. 학습과정은 크게 두가지로 나눕니다. Part 1은 구분자를, Part 2는 생성자를 업데이트하는 과정입니다.

Part 1 - 구분자의 학습

구분자의 목적은 주어진 입력값이 진짜인지 가짜인지 판별하는 것임을 상기합시다. Goodfellow의 말을 빌리자면, 구분자는 “변화도(gradient)를 상승(ascending)시키며 훈련”하게 됩니다. 실전적으로 얘기하면, \(log(D(x)) + log(1-D(G(z)))\) 를 최대화시키는 것과 같습니다. ganhacks 에서 미니 배치(mini-batch)를 분리하여 사용한 개념을 가져와서, 우리 역시 두가지 스텝으로 분리해 계산을 해보겠습니다. 먼저, 진짜 데이터들로만 이루어진 배치를 만들어 \(D\) 에 통과시킵니다. 그 출력값으로 (\(log(D(x))\)) 의 손실값을 계산하고, 역전파 과정에서의 변화도들을 계산합니다. 여기까지가 첫번째 스텝입니다. 두번째 스텝에서는, 오로지 가짜 데이터들로만 이루어진 배치를 만들어 \(D\) 에 통과시키고, 그 출력값으로 (\(log(1-D(G(z)))\)) 의 손실값을 계산해 역전파 변화도를 구하면 됩니다. 이때 두가지 스텝에서 나오는 변화도들은 축적(accumulate) 시켜야 합니다. 변화도까지 구했으니, 이제 옵티마이저를 사용해야겠죠. 파이토치의 함수를 호출해주면 알아서 변화도가 적용될겁니다.

Part 2 - 생성자의 학습

오리지널 GAN 논문에 명시되어 있듯, 생성자는 \(log(1-D(G(z)))\) 을 최소화시키는 방향으로 학습합니다. 하지만 이 방식은 충분한 변화도를 제공하지 못함을 Goodfellow가 보여줬습니다. 특히 학습초기에는 더욱 문제를 일으키죠. 이를 해결하기 위해 \(log(D(G(z)))\) 를 최대화 하는 방식으로 바꿔서 학습을 하겠습니다. 코드에서 구현하기 위해서는 : Part 1에서 한대로 구분자를 이용해 생성자의 출력값을 판별해주고, 진짜 라벨값 을 이용해 G의 손실값을 구해줍니다. 그러면 구해진 손실값으로 변화도를 구하고, 최종적으로는 옵티마이저를 이용해 G의 가중치들을 업데이트시켜주면 됩니다. 언뜻 볼때는, 생성자가 만들어낸 가짜 이미지에 진짜 라벨을 사용하는것이 직관적으로 위배가 될테지만, 이렇게 라벨을 바꿈으로써 \(log(x)\) 라는 BCELoss 의 일부분을 사용할 수 있게 합니다 (앞서 우리는 BCELoss에서 라벨을 이용해 원하는 로그 계산 요소를 고를 수 있음을 알아봤습니다).

마무리로 G의 훈련 상태를 알아보기 위하여, 몇가지 통계적인 수치들과, fixed_noise를 통과시킨 결과를 화면에 출력하는 코드를 추가하겠습니다. 이때 통계적인 수치들이라 함은:

  • Loss_D - 진짜 데이터와 가짜 데이터들 모두에서 구해진 손실값. (\(log(D(x)) + log(1 - D(G(z)))\)).

  • Loss_G - 생성자의 손실값. \(log(D(G(z)))\)

  • D(x) - 구분자가 데이터를 판별한 확률값입니다. 처음에는 1에 가까운 값이다가, G가 학습할수록 0.5값에 수렴하게 됩니다.

  • D(G(z)) - 가짜데이터들에 대한 구분자의 출력값입니다. 처음에는 0에 가까운 값이다가, G가 학습할수록 0.5에 수렴하게 됩니다

Note: 이후의 과정은 epoch의 수와 데이터의 수에 따라 시간이 좀 걸릴 수 있습니다

# 학습 과정

# 학습상태를 체크하기 위해 손실값들을 저장합니다
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# 에폭(epoch) 반복
for epoch in range(num_epochs):
    # 한 에폭 내에서 배치 반복
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) D 신경망을 업데이트 합니다: log(D(x)) + log(1 - D(G(z)))를 최대화 합니다
        ###########################
        ## 진짜 데이터들로 학습을 합니다
        netD.zero_grad()
        # 배치들의 사이즈나 사용할 디바이스에 맞게 조정합니다
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label,
                           dtype=torch.float, device=device)
        # 진짜 데이터들로 이루어진 배치를 D에 통과시킵니다
        output = netD(real_cpu).view(-1)
        # 손실값을 구합니다
        errD_real = criterion(output, label)
        # 역전파의 과정에서 변화도를 계산합니다
        errD_real.backward()
        D_x = output.mean().item()

        ## 가짜 데이터들로 학습을 합니다
        # 생성자에 사용할 잠재공간 벡터를 생성합니다
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # G를 이용해 가짜 이미지를 생성합니다
        fake = netG(noise)
        label.fill_(fake_label)
        # D를 이용해 데이터의 진위를 판별합니다
        output = netD(fake.detach()).view(-1)
        # D의 손실값을 계산합니다
        errD_fake = criterion(output, label)
        # 역전파를 통해 변화도를 계산합니다. 이때 앞서 구한 변화도에 더합니다(accumulate)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # 가짜 이미지와 진짜 이미지 모두에서 구한 손실값들을 더합니다
        # 이때 errD는 역전파에서 사용되지 않고, 이후 학습 상태를 리포팅(reporting)할 때 사용합니다
        errD = errD_real + errD_fake
        # D를 업데이트 합니다
        optimizerD.step()

        ############################
        # (2) G 신경망을 업데이트 합니다: log(D(G(z)))를 최대화 합니다
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # 생성자의 손실값을 구하기 위해 진짜 라벨을 이용할 겁니다
        # 우리는 방금 D를 업데이트했기 때문에, D에 다시 가짜 데이터를 통과시킵니다.
        # 이때 G는 업데이트되지 않았지만, D가 업데이트 되었기 때문에 앞선 손실값가 다른 값이 나오게 됩니다
        output = netD(fake).view(-1)
        # G의 손실값을 구합니다
        errG = criterion(output, label)
        # G의 변화도를 계산합니다
        errG.backward()
        D_G_z2 = output.mean().item()
        # G를 업데이트 합니다
        optimizerG.step()

        # 훈련 상태를 출력합니다
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # 이후 그래프를 그리기 위해 손실값들을 저장해둡니다
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # fixed_noise를 통과시킨 G의 출력값을 저장해둡니다
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1
Starting Training Loop...
[0/5][0/1583]   Loss_D: 1.8299  Loss_G: 6.4844  D(x): 0.6585    D(G(z)): 0.6694 / 0.0028
[0/5][50/1583]  Loss_D: 0.0216  Loss_G: 33.2432 D(x): 0.9869    D(G(z)): 0.0000 / 0.0000
[0/5][100/1583] Loss_D: 0.3433  Loss_G: 7.8246  D(x): 0.9662    D(G(z)): 0.2164 / 0.0012
[0/5][150/1583] Loss_D: 1.0695  Loss_G: 11.0265 D(x): 0.9838    D(G(z)): 0.5948 / 0.0001
[0/5][200/1583] Loss_D: 0.4283  Loss_G: 6.5706  D(x): 0.8856    D(G(z)): 0.2199 / 0.0024
[0/5][250/1583] Loss_D: 0.2737  Loss_G: 4.5078  D(x): 0.8721    D(G(z)): 0.0598 / 0.0273
[0/5][300/1583] Loss_D: 0.4181  Loss_G: 3.8659  D(x): 0.7803    D(G(z)): 0.0764 / 0.0387
[0/5][350/1583] Loss_D: 0.5303  Loss_G: 3.3211  D(x): 0.9255    D(G(z)): 0.3102 / 0.0530
[0/5][400/1583] Loss_D: 0.7111  Loss_G: 7.4542  D(x): 0.9197    D(G(z)): 0.4159 / 0.0017
[0/5][450/1583] Loss_D: 0.8862  Loss_G: 6.2063  D(x): 0.9006    D(G(z)): 0.4330 / 0.0078
[0/5][500/1583] Loss_D: 0.6696  Loss_G: 3.4424  D(x): 0.6466    D(G(z)): 0.1162 / 0.0507
[0/5][550/1583] Loss_D: 0.2583  Loss_G: 5.4076  D(x): 0.9120    D(G(z)): 0.1254 / 0.0088
[0/5][600/1583] Loss_D: 4.6331  Loss_G: 8.0574  D(x): 0.1101    D(G(z)): 0.0023 / 0.0100
[0/5][650/1583] Loss_D: 0.9090  Loss_G: 9.1068  D(x): 0.9169    D(G(z)): 0.4637 / 0.0004
[0/5][700/1583] Loss_D: 0.4734  Loss_G: 3.7394  D(x): 0.7394    D(G(z)): 0.0469 / 0.0463
[0/5][750/1583] Loss_D: 0.4682  Loss_G: 3.4495  D(x): 0.7475    D(G(z)): 0.0644 / 0.0568
[0/5][800/1583] Loss_D: 0.3883  Loss_G: 4.7214  D(x): 0.8470    D(G(z)): 0.1477 / 0.0175
[0/5][850/1583] Loss_D: 0.4575  Loss_G: 3.5411  D(x): 0.7866    D(G(z)): 0.1385 / 0.0479
[0/5][900/1583] Loss_D: 0.6576  Loss_G: 2.5494  D(x): 0.6231    D(G(z)): 0.0393 / 0.1290
[0/5][950/1583] Loss_D: 0.6311  Loss_G: 3.6275  D(x): 0.7112    D(G(z)): 0.1132 / 0.0599
[0/5][1000/1583]        Loss_D: 0.2676  Loss_G: 4.6449  D(x): 0.8513    D(G(z)): 0.0672 / 0.0182
[0/5][1050/1583]        Loss_D: 0.7362  Loss_G: 1.7370  D(x): 0.6617    D(G(z)): 0.1551 / 0.2599
[0/5][1100/1583]        Loss_D: 0.6021  Loss_G: 4.9925  D(x): 0.8432    D(G(z)): 0.2748 / 0.0130
[0/5][1150/1583]        Loss_D: 1.3287  Loss_G: 8.2732  D(x): 0.8937    D(G(z)): 0.6096 / 0.0009
[0/5][1200/1583]        Loss_D: 0.9292  Loss_G: 2.7833  D(x): 0.5178    D(G(z)): 0.0162 / 0.1025
[0/5][1250/1583]        Loss_D: 0.4783  Loss_G: 4.5441  D(x): 0.8840    D(G(z)): 0.2438 / 0.0171
[0/5][1300/1583]        Loss_D: 0.5730  Loss_G: 5.0385  D(x): 0.9025    D(G(z)): 0.2919 / 0.0149
[0/5][1350/1583]        Loss_D: 0.4775  Loss_G: 2.6882  D(x): 0.7436    D(G(z)): 0.1013 / 0.0999
[0/5][1400/1583]        Loss_D: 1.3429  Loss_G: 1.9566  D(x): 0.4023    D(G(z)): 0.0206 / 0.2382
[0/5][1450/1583]        Loss_D: 0.5661  Loss_G: 3.3978  D(x): 0.7607    D(G(z)): 0.1949 / 0.0579
[0/5][1500/1583]        Loss_D: 0.9623  Loss_G: 6.4973  D(x): 0.8771    D(G(z)): 0.4471 / 0.0042
[0/5][1550/1583]        Loss_D: 0.5302  Loss_G: 2.6180  D(x): 0.7400    D(G(z)): 0.1364 / 0.1116
[1/5][0/1583]   Loss_D: 1.7302  Loss_G: 7.7747  D(x): 0.9382    D(G(z)): 0.7493 / 0.0014
[1/5][50/1583]  Loss_D: 0.4628  Loss_G: 4.7022  D(x): 0.8964    D(G(z)): 0.2541 / 0.0181
[1/5][100/1583] Loss_D: 0.5329  Loss_G: 4.4852  D(x): 0.8343    D(G(z)): 0.2569 / 0.0186
[1/5][150/1583] Loss_D: 0.6291  Loss_G: 4.9310  D(x): 0.8720    D(G(z)): 0.3368 / 0.0124
[1/5][200/1583] Loss_D: 0.3658  Loss_G: 3.0937  D(x): 0.8020    D(G(z)): 0.1003 / 0.0718
[1/5][250/1583] Loss_D: 0.5866  Loss_G: 5.0687  D(x): 0.9359    D(G(z)): 0.3344 / 0.0157
[1/5][300/1583] Loss_D: 0.8897  Loss_G: 5.6717  D(x): 0.9146    D(G(z)): 0.4821 / 0.0060
[1/5][350/1583] Loss_D: 0.2999  Loss_G: 4.0035  D(x): 0.9047    D(G(z)): 0.1563 / 0.0279
[1/5][400/1583] Loss_D: 0.5837  Loss_G: 6.5790  D(x): 0.8611    D(G(z)): 0.2864 / 0.0030
[1/5][450/1583] Loss_D: 0.5315  Loss_G: 5.3788  D(x): 0.9588    D(G(z)): 0.3365 / 0.0088
[1/5][500/1583] Loss_D: 0.5127  Loss_G: 4.7425  D(x): 0.9434    D(G(z)): 0.3224 / 0.0178
[1/5][550/1583] Loss_D: 0.4125  Loss_G: 3.3058  D(x): 0.8001    D(G(z)): 0.1330 / 0.0555
[1/5][600/1583] Loss_D: 0.7605  Loss_G: 2.0931  D(x): 0.5807    D(G(z)): 0.0348 / 0.1766
[1/5][650/1583] Loss_D: 0.3706  Loss_G: 4.0916  D(x): 0.9049    D(G(z)): 0.2018 / 0.0268
[1/5][700/1583] Loss_D: 1.0406  Loss_G: 2.1585  D(x): 0.4803    D(G(z)): 0.0651 / 0.1869
[1/5][750/1583] Loss_D: 0.9149  Loss_G: 2.5708  D(x): 0.4985    D(G(z)): 0.0303 / 0.1209
[1/5][800/1583] Loss_D: 0.4332  Loss_G: 5.0259  D(x): 0.9286    D(G(z)): 0.2666 / 0.0109
[1/5][850/1583] Loss_D: 0.4627  Loss_G: 3.3560  D(x): 0.7649    D(G(z)): 0.1155 / 0.0621
[1/5][900/1583] Loss_D: 0.6375  Loss_G: 3.4163  D(x): 0.7703    D(G(z)): 0.2535 / 0.0549
[1/5][950/1583] Loss_D: 0.9569  Loss_G: 5.3410  D(x): 0.9114    D(G(z)): 0.5052 / 0.0086
[1/5][1000/1583]        Loss_D: 0.3181  Loss_G: 3.3453  D(x): 0.8359    D(G(z)): 0.1109 / 0.0536
[1/5][1050/1583]        Loss_D: 2.0014  Loss_G: 0.1647  D(x): 0.2356    D(G(z)): 0.0123 / 0.8661
[1/5][1100/1583]        Loss_D: 0.5862  Loss_G: 1.8887  D(x): 0.6958    D(G(z)): 0.1310 / 0.1968
[1/5][1150/1583]        Loss_D: 0.4217  Loss_G: 2.4708  D(x): 0.8179    D(G(z)): 0.1656 / 0.1093
[1/5][1200/1583]        Loss_D: 0.3932  Loss_G: 3.4813  D(x): 0.8218    D(G(z)): 0.1450 / 0.0512
[1/5][1250/1583]        Loss_D: 0.4992  Loss_G: 2.7227  D(x): 0.7986    D(G(z)): 0.1896 / 0.0925
[1/5][1300/1583]        Loss_D: 2.0641  Loss_G: 9.2162  D(x): 0.9834    D(G(z)): 0.8083 / 0.0004
[1/5][1350/1583]        Loss_D: 0.7828  Loss_G: 1.6472  D(x): 0.5522    D(G(z)): 0.0500 / 0.2422
[1/5][1400/1583]        Loss_D: 0.8685  Loss_G: 3.2494  D(x): 0.7671    D(G(z)): 0.3577 / 0.0672
[1/5][1450/1583]        Loss_D: 0.4420  Loss_G: 3.1004  D(x): 0.7554    D(G(z)): 0.1068 / 0.0721
[1/5][1500/1583]        Loss_D: 0.5182  Loss_G: 1.7615  D(x): 0.8111    D(G(z)): 0.1945 / 0.2513
[1/5][1550/1583]        Loss_D: 0.5423  Loss_G: 3.0761  D(x): 0.7869    D(G(z)): 0.2183 / 0.0650
[2/5][0/1583]   Loss_D: 0.4555  Loss_G: 2.7615  D(x): 0.7995    D(G(z)): 0.1666 / 0.0869
[2/5][50/1583]  Loss_D: 0.5835  Loss_G: 3.5953  D(x): 0.8609    D(G(z)): 0.3045 / 0.0399
[2/5][100/1583] Loss_D: 0.6120  Loss_G: 4.5291  D(x): 0.9432    D(G(z)): 0.3849 / 0.0164
[2/5][150/1583] Loss_D: 0.6065  Loss_G: 4.4620  D(x): 0.8814    D(G(z)): 0.3448 / 0.0190
[2/5][200/1583] Loss_D: 0.6349  Loss_G: 2.6065  D(x): 0.8119    D(G(z)): 0.2904 / 0.1027
[2/5][250/1583] Loss_D: 0.5720  Loss_G: 3.0309  D(x): 0.8024    D(G(z)): 0.2585 / 0.0674
[2/5][300/1583] Loss_D: 0.5026  Loss_G: 3.0147  D(x): 0.8588    D(G(z)): 0.2623 / 0.0650
[2/5][350/1583] Loss_D: 0.5240  Loss_G: 2.5972  D(x): 0.8285    D(G(z)): 0.2501 / 0.1012
[2/5][400/1583] Loss_D: 0.8058  Loss_G: 3.7930  D(x): 0.7832    D(G(z)): 0.3701 / 0.0349
[2/5][450/1583] Loss_D: 0.5139  Loss_G: 3.4010  D(x): 0.9135    D(G(z)): 0.3196 / 0.0456
[2/5][500/1583] Loss_D: 1.0348  Loss_G: 2.2396  D(x): 0.5974    D(G(z)): 0.2844 / 0.1645
[2/5][550/1583] Loss_D: 0.5964  Loss_G: 3.5267  D(x): 0.8761    D(G(z)): 0.3248 / 0.0423
[2/5][600/1583] Loss_D: 0.4706  Loss_G: 2.1557  D(x): 0.7317    D(G(z)): 0.0947 / 0.1494
[2/5][650/1583] Loss_D: 0.5109  Loss_G: 2.2192  D(x): 0.8313    D(G(z)): 0.2457 / 0.1458
[2/5][700/1583] Loss_D: 0.5387  Loss_G: 3.2269  D(x): 0.8020    D(G(z)): 0.2388 / 0.0569
[2/5][750/1583] Loss_D: 0.8508  Loss_G: 2.7999  D(x): 0.7463    D(G(z)): 0.3702 / 0.0804
[2/5][800/1583] Loss_D: 0.7991  Loss_G: 4.3823  D(x): 0.8847    D(G(z)): 0.4408 / 0.0202
[2/5][850/1583] Loss_D: 0.4814  Loss_G: 1.9543  D(x): 0.7200    D(G(z)): 0.1047 / 0.1674
[2/5][900/1583] Loss_D: 0.6199  Loss_G: 2.7804  D(x): 0.8462    D(G(z)): 0.3158 / 0.0880
[2/5][950/1583] Loss_D: 0.5574  Loss_G: 2.1244  D(x): 0.7050    D(G(z)): 0.1442 / 0.1542
[2/5][1000/1583]        Loss_D: 0.8899  Loss_G: 0.8985  D(x): 0.5147    D(G(z)): 0.0914 / 0.4514
[2/5][1050/1583]        Loss_D: 0.6494  Loss_G: 2.7185  D(x): 0.7583    D(G(z)): 0.2650 / 0.0885
[2/5][1100/1583]        Loss_D: 0.6650  Loss_G: 3.3982  D(x): 0.9368    D(G(z)): 0.3988 / 0.0493
[2/5][1150/1583]        Loss_D: 0.5672  Loss_G: 3.1718  D(x): 0.8599    D(G(z)): 0.3025 / 0.0561
[2/5][1200/1583]        Loss_D: 0.5750  Loss_G: 1.9404  D(x): 0.6829    D(G(z)): 0.1250 / 0.1863
[2/5][1250/1583]        Loss_D: 0.5117  Loss_G: 2.6210  D(x): 0.8118    D(G(z)): 0.2297 / 0.0985
[2/5][1300/1583]        Loss_D: 0.9280  Loss_G: 3.8165  D(x): 0.8545    D(G(z)): 0.4870 / 0.0329
[2/5][1350/1583]        Loss_D: 0.5364  Loss_G: 1.7776  D(x): 0.7241    D(G(z)): 0.1515 / 0.2089
[2/5][1400/1583]        Loss_D: 0.5405  Loss_G: 2.7218  D(x): 0.8620    D(G(z)): 0.2881 / 0.0893
[2/5][1450/1583]        Loss_D: 0.5621  Loss_G: 1.8932  D(x): 0.6670    D(G(z)): 0.1063 / 0.1854
[2/5][1500/1583]        Loss_D: 0.6461  Loss_G: 2.8116  D(x): 0.8105    D(G(z)): 0.3209 / 0.0756
[2/5][1550/1583]        Loss_D: 1.4484  Loss_G: 3.9103  D(x): 0.9190    D(G(z)): 0.6744 / 0.0335
[3/5][0/1583]   Loss_D: 0.7030  Loss_G: 1.9966  D(x): 0.6745    D(G(z)): 0.2117 / 0.1653
[3/5][50/1583]  Loss_D: 0.6820  Loss_G: 4.1527  D(x): 0.8829    D(G(z)): 0.3859 / 0.0227
[3/5][100/1583] Loss_D: 0.6641  Loss_G: 1.6542  D(x): 0.6526    D(G(z)): 0.1316 / 0.2381
[3/5][150/1583] Loss_D: 0.9852  Loss_G: 1.1994  D(x): 0.4490    D(G(z)): 0.0487 / 0.3598
[3/5][200/1583] Loss_D: 0.7155  Loss_G: 2.5735  D(x): 0.7670    D(G(z)): 0.3191 / 0.0981
[3/5][250/1583] Loss_D: 0.9242  Loss_G: 0.5643  D(x): 0.4705    D(G(z)): 0.0573 / 0.5965
[3/5][300/1583] Loss_D: 1.3664  Loss_G: 3.6245  D(x): 0.9150    D(G(z)): 0.6430 / 0.0427
[3/5][350/1583] Loss_D: 0.7057  Loss_G: 1.7950  D(x): 0.6912    D(G(z)): 0.2298 / 0.2066
[3/5][400/1583] Loss_D: 1.0043  Loss_G: 0.7726  D(x): 0.4385    D(G(z)): 0.0198 / 0.5052
[3/5][450/1583] Loss_D: 0.8219  Loss_G: 4.0062  D(x): 0.8956    D(G(z)): 0.4650 / 0.0255
[3/5][500/1583] Loss_D: 1.4533  Loss_G: 5.0873  D(x): 0.9550    D(G(z)): 0.7052 / 0.0100
[3/5][550/1583] Loss_D: 0.5347  Loss_G: 2.3179  D(x): 0.8256    D(G(z)): 0.2587 / 0.1266
[3/5][600/1583] Loss_D: 0.7652  Loss_G: 2.2733  D(x): 0.6233    D(G(z)): 0.1874 / 0.1405
[3/5][650/1583] Loss_D: 0.5741  Loss_G: 1.7345  D(x): 0.6381    D(G(z)): 0.0628 / 0.2171
[3/5][700/1583] Loss_D: 0.8164  Loss_G: 1.3518  D(x): 0.5420    D(G(z)): 0.1088 / 0.3113
[3/5][750/1583] Loss_D: 0.6063  Loss_G: 2.3610  D(x): 0.7897    D(G(z)): 0.2655 / 0.1197
[3/5][800/1583] Loss_D: 0.5658  Loss_G: 2.2550  D(x): 0.7682    D(G(z)): 0.2262 / 0.1310
[3/5][850/1583] Loss_D: 0.5685  Loss_G: 2.2438  D(x): 0.7491    D(G(z)): 0.2049 / 0.1318
[3/5][900/1583] Loss_D: 0.7884  Loss_G: 4.5239  D(x): 0.9341    D(G(z)): 0.4602 / 0.0166
[3/5][950/1583] Loss_D: 0.8442  Loss_G: 1.0097  D(x): 0.5230    D(G(z)): 0.0899 / 0.4212
[3/5][1000/1583]        Loss_D: 1.4648  Loss_G: 0.9003  D(x): 0.3213    D(G(z)): 0.0587 / 0.4914
[3/5][1050/1583]        Loss_D: 0.7254  Loss_G: 1.2147  D(x): 0.6394    D(G(z)): 0.1858 / 0.3498
[3/5][1100/1583]        Loss_D: 0.5572  Loss_G: 2.4531  D(x): 0.7870    D(G(z)): 0.2467 / 0.1112
[3/5][1150/1583]        Loss_D: 0.9237  Loss_G: 0.9882  D(x): 0.4921    D(G(z)): 0.0812 / 0.4370
[3/5][1200/1583]        Loss_D: 0.6010  Loss_G: 1.5046  D(x): 0.6975    D(G(z)): 0.1673 / 0.2655
[3/5][1250/1583]        Loss_D: 0.7804  Loss_G: 3.4338  D(x): 0.8536    D(G(z)): 0.4181 / 0.0428
[3/5][1300/1583]        Loss_D: 0.5154  Loss_G: 2.6825  D(x): 0.8009    D(G(z)): 0.2234 / 0.0880
[3/5][1350/1583]        Loss_D: 0.7472  Loss_G: 1.0414  D(x): 0.5534    D(G(z)): 0.0797 / 0.4004
[3/5][1400/1583]        Loss_D: 1.0085  Loss_G: 1.1159  D(x): 0.4628    D(G(z)): 0.0725 / 0.3859
[3/5][1450/1583]        Loss_D: 0.5658  Loss_G: 1.8971  D(x): 0.7118    D(G(z)): 0.1645 / 0.1795
[3/5][1500/1583]        Loss_D: 0.6255  Loss_G: 2.5736  D(x): 0.7579    D(G(z)): 0.2610 / 0.0964
[3/5][1550/1583]        Loss_D: 0.5519  Loss_G: 2.5823  D(x): 0.7956    D(G(z)): 0.2443 / 0.0969
[4/5][0/1583]   Loss_D: 0.6206  Loss_G: 2.2027  D(x): 0.6095    D(G(z)): 0.0477 / 0.1542
[4/5][50/1583]  Loss_D: 0.6424  Loss_G: 2.4764  D(x): 0.7730    D(G(z)): 0.2748 / 0.1126
[4/5][100/1583] Loss_D: 0.7422  Loss_G: 1.9242  D(x): 0.5594    D(G(z)): 0.0494 / 0.1917
[4/5][150/1583] Loss_D: 0.4719  Loss_G: 3.5467  D(x): 0.8632    D(G(z)): 0.2523 / 0.0387
[4/5][200/1583] Loss_D: 0.6210  Loss_G: 3.3302  D(x): 0.8217    D(G(z)): 0.3070 / 0.0499
[4/5][250/1583] Loss_D: 0.8848  Loss_G: 3.7179  D(x): 0.8995    D(G(z)): 0.4866 / 0.0379
[4/5][300/1583] Loss_D: 0.6426  Loss_G: 1.6399  D(x): 0.6463    D(G(z)): 0.1357 / 0.2307
[4/5][350/1583] Loss_D: 0.7443  Loss_G: 3.6400  D(x): 0.8688    D(G(z)): 0.3996 / 0.0372
[4/5][400/1583] Loss_D: 0.6992  Loss_G: 1.5366  D(x): 0.5837    D(G(z)): 0.0760 / 0.2606
[4/5][450/1583] Loss_D: 0.5514  Loss_G: 2.3351  D(x): 0.8079    D(G(z)): 0.2550 / 0.1184
[4/5][500/1583] Loss_D: 0.9832  Loss_G: 0.9129  D(x): 0.4597    D(G(z)): 0.0474 / 0.4450
[4/5][550/1583] Loss_D: 0.9030  Loss_G: 1.2620  D(x): 0.5228    D(G(z)): 0.1259 / 0.3395
[4/5][600/1583] Loss_D: 1.2299  Loss_G: 4.7300  D(x): 0.9337    D(G(z)): 0.6157 / 0.0140
[4/5][650/1583] Loss_D: 0.5246  Loss_G: 2.1015  D(x): 0.7479    D(G(z)): 0.1719 / 0.1489
[4/5][700/1583] Loss_D: 0.6665  Loss_G: 2.6220  D(x): 0.7907    D(G(z)): 0.3138 / 0.0923
[4/5][750/1583] Loss_D: 0.6834  Loss_G: 2.9899  D(x): 0.8768    D(G(z)): 0.3822 / 0.0706
[4/5][800/1583] Loss_D: 0.6991  Loss_G: 1.8664  D(x): 0.5872    D(G(z)): 0.0748 / 0.2072
[4/5][850/1583] Loss_D: 0.5925  Loss_G: 2.9061  D(x): 0.7768    D(G(z)): 0.2429 / 0.0785
[4/5][900/1583] Loss_D: 0.4953  Loss_G: 2.3709  D(x): 0.7916    D(G(z)): 0.1964 / 0.1232
[4/5][950/1583] Loss_D: 0.5543  Loss_G: 2.4166  D(x): 0.7826    D(G(z)): 0.2301 / 0.1174
[4/5][1000/1583]        Loss_D: 0.4875  Loss_G: 2.5188  D(x): 0.8052    D(G(z)): 0.2107 / 0.1046
[4/5][1050/1583]        Loss_D: 0.4684  Loss_G: 2.7462  D(x): 0.8045    D(G(z)): 0.1911 / 0.0801
[4/5][1100/1583]        Loss_D: 0.3519  Loss_G: 2.4601  D(x): 0.8582    D(G(z)): 0.1649 / 0.1038
[4/5][1150/1583]        Loss_D: 0.6213  Loss_G: 2.2838  D(x): 0.7539    D(G(z)): 0.2399 / 0.1361
[4/5][1200/1583]        Loss_D: 0.7621  Loss_G: 3.3453  D(x): 0.9571    D(G(z)): 0.4498 / 0.0577
[4/5][1250/1583]        Loss_D: 0.4809  Loss_G: 3.2195  D(x): 0.8985    D(G(z)): 0.2862 / 0.0507
[4/5][1300/1583]        Loss_D: 1.2771  Loss_G: 5.6843  D(x): 0.9838    D(G(z)): 0.6521 / 0.0053
[4/5][1350/1583]        Loss_D: 0.6455  Loss_G: 2.8243  D(x): 0.8675    D(G(z)): 0.3457 / 0.0792
[4/5][1400/1583]        Loss_D: 1.8702  Loss_G: 5.3888  D(x): 0.9753    D(G(z)): 0.7787 / 0.0084
[4/5][1450/1583]        Loss_D: 1.6706  Loss_G: 3.6516  D(x): 0.9304    D(G(z)): 0.7298 / 0.0420
[4/5][1500/1583]        Loss_D: 0.4382  Loss_G: 2.4364  D(x): 0.8277    D(G(z)): 0.1912 / 0.1128
[4/5][1550/1583]        Loss_D: 0.6895  Loss_G: 1.4018  D(x): 0.5896    D(G(z)): 0.0819 / 0.3066

결과

결과를 알아봅시다. 이 섹션에서는 총 세가지를 확인할겁니다. 첫번째는 G와 D의 손실값들이 어떻게 변했는가, 두번째는 매 에폭마다 fixed_noise를 이용해 G가 만들어낸 이미지들, 마지막은 학습이 끝난 G가 만들어낸 이미지와 진짜 이미지들의 비교입니다

학습하는 동안의 손실값들

아래는 D와 G의 손실값들을 그래프로 그린 모습입니다

plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()
Generator and Discriminator Loss During Training

G의 학습 과정 시각화

매 에폭마다 fixed_noise를 이용해 생성자가 만들어낸 이미지를 저장한 것을 기억할겁니다. 저장한 이미지들을애니메이션 형식으로 확인해 봅시다. play버튼을 누르면 애니매이션이 실행됩니다

fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())
dcgan faces tutorial


진짜 이미지 vs. 가짜 이미지

진짜 이미지들과 가짜 이미지들을 옆으로 두고 비교를 해봅시다

# dataloader에서 진짜 데이터들을 가져옵니다
real_batch = next(iter(dataloader))

# 진짜 이미지들을 화면에 출력합니다
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))

# 가짜 이미지들을 화면에 출력합니다
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()
Real Images, Fake Images

이제 어디로 여행을 떠나볼까요?

드디어 DCGAN이 끝났습니다! 하지만 더 알아볼 것들이 많이 남아있죠. 무엇을 더 시도해볼 수 있을까요?

  • 결과물이 얼마나 더 좋아지는지 확인해보기 위해서 학습시간을 늘려볼 수 있습니다

  • 다른 데이터셋을 이용해 훈련시켜보거나, 이미지의 사이즈를 다르게 해보거나, 아키텍쳐의 구성을 바꿔볼 수도 있습니다

  • 여기 에서 더욱 멋진 GAN 프로젝트들을 찾을수도 있죠

  • 음악 을 작곡하는 GAN도 만들 수 있습니다

Total running time of the script: ( 4 minutes 54.769 seconds)

Gallery generated by Sphinx-Gallery


더 궁금하시거나 개선할 내용이 있으신가요? 커뮤니티에 참여해보세요!


이 튜토리얼이 어떠셨나요? 평가해주시면 이후 개선에 참고하겠습니다! :)

© Copyright 2018-2024, PyTorch & 파이토치 한국 사용자 모임(PyTorch Korea User Group).

Built with Sphinx using a theme provided by Read the Docs.

PyTorchKorea @ GitHub

파이토치 한국 사용자 모임을 GitHub에서 만나보세요.

GitHub로 이동

한국어 튜토리얼

한국어로 번역 중인 PyTorch 튜토리얼입니다.

튜토리얼로 이동

커뮤니티

다른 사용자들과 의견을 나누고, 도와주세요!

커뮤니티로 이동