Shortcuts

PyTorch: 제어 흐름(Control Flow) + 가중치 공유(Weight Sharing)

PyTorch 동적 그래프의 강력함을 보여주기 위해, 매우 이상한 모델을 구현해보겠습니다: 각 순전파 단계에서 4 ~ 5 사이의 임의의 숫자를 선택하여 다차항들에서 사용하고, 동일한 가중치를 여러번 재사용하여 4차항과 5차항을 계산하는 3-5차 다항식입니다.

import random
import torch
import math


class DynamicNet(torch.nn.Module):
    def __init__(self):
        """
        생성자에서 5개의 매개변수를 생성(instantiate)하고 멤버 변수로 지정합니다.
        """
        super().__init__()
        self.a = torch.nn.Parameter(torch.randn(()))
        self.b = torch.nn.Parameter(torch.randn(()))
        self.c = torch.nn.Parameter(torch.randn(()))
        self.d = torch.nn.Parameter(torch.randn(()))
        self.e = torch.nn.Parameter(torch.randn(()))

    def forward(self, x):
        """
        모델의 순전파 단계에서는 무작위로 4, 5 중 하나를 선택한 뒤 매개변수 e를 재사용하여
        이 차수들의의 기여도(contribution)를 계산합니다.

        각 순전파 단계는 동적 연산 그래프를 구성하기 때문에, 모델의 순전파 단계를 정의할 때
        반복문이나 조건문과 같은 일반적인 Python 제어-흐름 연산자를 사용할 수 있습니다.

        여기에서 연산 그래프를 정의할 때 동일한 매개변수를 여러번 사용하는 것이 완벽히 안전하다는
        것을 알 수 있습니다.
        """
        y = self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3
        for exp in range(4, random.randint(4, 6)):
            y = y + self.e * x ** exp
        return y

    def string(self):
        """
        Python의 다른 클래스(class)처럼, PyTorch 모듈을 사용해서 사용자 정의 메소드를 정의할 수 있습니다.
        """
        return f'y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 + {self.d.item()} x^3 + {self.e.item()} x^4 ? + {self.e.item()} x^5 ?'


# 입력값과 출력값을 갖는 텐서들을 생성합니다.
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)

# 위에서 정의한 클래스로 모델을 생성합니다.
model = DynamicNet()

# 손실 함수와 optimizer를 생성합니다. 이 이상한 모델을 순수한 확률적 경사하강법(SGD; Stochastic Gradient Descent)으로
# 학습하는 것은 어려우므로, 모멘텀(momentum)을 사용합니다.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-8, momentum=0.9)
for t in range(30000):
    # 순전파 단계: 모델에 x를 전달하여 예측값 y를 계산합니다.
    y_pred = model(x)

    # 손실을 계산하고 출력합니다.
    loss = criterion(y_pred, y)
    if t % 2000 == 1999:
        print(t, loss.item())

    # 변화도를 0으로 만들고, 역전파 단계를 수행하고, 가중치를 갱신합니다.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(f'Result: {model.string()}')

Total running time of the script: ( 0 minutes 0.000 seconds)

Gallery generated by Sphinx-Gallery


더 궁금하시거나 개선할 내용이 있으신가요? 커뮤니티에 참여해보세요!


이 튜토리얼이 어떠셨나요? 평가해주시면 이후 개선에 참고하겠습니다! :)

© Copyright 2018-2024, PyTorch & 파이토치 한국 사용자 모임(PyTorch Korea User Group).

Built with Sphinx using a theme provided by Read the Docs.

PyTorchKorea @ GitHub

파이토치 한국 사용자 모임을 GitHub에서 만나보세요.

GitHub로 이동

한국어 튜토리얼

한국어로 번역 중인 PyTorch 튜토리얼입니다.

튜토리얼로 이동

커뮤니티

다른 사용자들과 의견을 나누고, 도와주세요!

커뮤니티로 이동