Shortcuts

준비 운동: NumPy

\(y=\sin(x)\) 을 예측할 수 있도록, \(-\pi\) 부터 \(pi\) 까지 유클리드 거리(Euclidean distance)를 최소화하도록 3차 다항식을 학습합니다.

이 구현은 NumPy를 사용하여 순전파 단계와 손실(loss), 역전파 단계를 직접 계산합니다.

NumPy 배열은 일반적인 n-차원 배열로, 딥러닝이나 변화도(gradient), 연산 그래프(computational graph)는 알지 못하며 일반적인 수치 연산을 수행합니다.

99 500.7650889686714
199 339.1473461575507
299 230.8076343607912
399 158.12735053874192
499 109.33109456243791
599 76.54344183711746
699 54.49382839696272
799 39.65255458368566
899 29.654088206511524
999 22.91191710817654
1099 18.361176643949378
1199 15.286555844126006
1299 13.207151396979782
1399 11.799372745740467
1499 10.845287559695974
1599 10.197987140593028
1699 9.758346013969259
1799 9.459414821586174
1899 9.255930718425928
1999 9.117261244083245
Result: y = 0.0130464524240786 + 0.8449078388334846 x + -0.002250730137058965 x^2 + -0.0916472472301237 x^3

import numpy as np
import math

# 무작위로 입력과 출력 데이터를 생성합니다
x = np.linspace(-math.pi, math.pi, 2000)
y = np.sin(x)

# 무작위로 가중치를 초기화합니다
a = np.random.randn()
b = np.random.randn()
c = np.random.randn()
d = np.random.randn()

learning_rate = 1e-6
for t in range(2000):
    # 순전파 단계: 예측값 y를 계산합니다
    # y = a + b x + c x^2 + d x^3
    y_pred = a + b * x + c * x ** 2 + d * x ** 3

    # 손실(loss)을 계산하고 출력합니다
    loss = np.square(y_pred - y).sum()
    if t % 100 == 99:
        print(t, loss)

    # 손실에 따른 a, b, c, d의 변화도(gradient)를 계산하고 역전파합니다.
    grad_y_pred = 2.0 * (y_pred - y)
    grad_a = grad_y_pred.sum()
    grad_b = (grad_y_pred * x).sum()
    grad_c = (grad_y_pred * x ** 2).sum()
    grad_d = (grad_y_pred * x ** 3).sum()

    # 가중치를 갱신합니다.
    a -= learning_rate * grad_a
    b -= learning_rate * grad_b
    c -= learning_rate * grad_c
    d -= learning_rate * grad_d

print(f'Result: y = {a} + {b} x + {c} x^2 + {d} x^3')

Total running time of the script: ( 0 minutes 0.290 seconds)

Gallery generated by Sphinx-Gallery


이 튜토리얼이 어떠셨나요?

© Copyright 2022, PyTorch & 파이토치 한국 사용자 모임(PyTorch Korea User Group).

Built with Sphinx using a theme provided by Read the Docs.

PyTorchKorea @ GitHub

파이토치 한국 사용자 모임을 GitHub에서 만나보세요.

GitHub로 이동

한국어 튜토리얼

한국어로 번역 중인 PyTorch 튜토리얼입니다.

튜토리얼로 이동

커뮤니티

다른 사용자들과 의견을 나누고, 도와주세요!

커뮤니티로 이동