Shortcuts

준비 운동: NumPy

\(y=\sin(x)\) 을 예측할 수 있도록, \(-\pi\) 부터 \(pi\) 까지 유클리드 거리(Euclidean distance)를 최소화하도록 3차 다항식을 학습합니다.

이 구현은 NumPy를 사용하여 순전파 단계와 손실(loss), 역전파 단계를 직접 계산합니다.

NumPy 배열은 일반적인 n-차원 배열로, 딥러닝이나 변화도(gradient), 연산 그래프(computational graph)는 알지 못하며 일반적인 수치 연산을 수행합니다.

99 1546.0365145892183
199 1095.8992907857576
299 777.6096767345516
399 552.5352611528601
499 393.3692169177716
599 280.8066427084271
699 201.19883404382796
799 144.89548639133625
899 105.07298180275443
999 76.90616893555887
1099 56.98289647761568
1199 42.89011969702014
1299 32.9212810307606
1399 25.86941791878095
1499 20.880874497663395
1599 17.35185928109817
1699 14.855296144964393
1799 13.089094751222989
1899 11.839566794279794
1999 10.955552494109163
Result: y = 0.04887728059779718 + 0.854335392917212 x + -0.008432144224579812 x^2 + -0.09298823470267067 x^3

import numpy as np
import math

# 무작위로 입력과 출력 데이터를 생성합니다
x = np.linspace(-math.pi, math.pi, 2000)
y = np.sin(x)

# 무작위로 가중치를 초기화합니다
a = np.random.randn()
b = np.random.randn()
c = np.random.randn()
d = np.random.randn()

learning_rate = 1e-6
for t in range(2000):
    # 순전파 단계: 예측값 y를 계산합니다
    # y = a + b x + c x^2 + d x^3
    y_pred = a + b * x + c * x ** 2 + d * x ** 3

    # 손실(loss)을 계산하고 출력합니다
    loss = np.square(y_pred - y).sum()
    if t % 100 == 99:
        print(t, loss)

    # 손실에 따른 a, b, c, d의 변화도(gradient)를 계산하고 역전파합니다.
    grad_y_pred = 2.0 * (y_pred - y)
    grad_a = grad_y_pred.sum()
    grad_b = (grad_y_pred * x).sum()
    grad_c = (grad_y_pred * x ** 2).sum()
    grad_d = (grad_y_pred * x ** 3).sum()

    # 가중치를 갱신합니다.
    a -= learning_rate * grad_a
    b -= learning_rate * grad_b
    c -= learning_rate * grad_c
    d -= learning_rate * grad_d

print(f'Result: y = {a} + {b} x + {c} x^2 + {d} x^3')

Total running time of the script: ( 0 minutes 0.563 seconds)

Gallery generated by Sphinx-Gallery


더 궁금하시거나 개선할 내용이 있으신가요? 커뮤니티에 참여해보세요!


이 튜토리얼이 어떠셨나요? 평가해주시면 이후 개선에 참고하겠습니다! :)

© Copyright 2018-2024, PyTorch & 파이토치 한국 사용자 모임(PyTorch Korea User Group).

Built with Sphinx using a theme provided by Read the Docs.

PyTorchKorea @ GitHub

파이토치 한국 사용자 모임을 GitHub에서 만나보세요.

GitHub로 이동

한국어 튜토리얼

한국어로 번역 중인 PyTorch 튜토리얼입니다.

튜토리얼로 이동

커뮤니티

다른 사용자들과 의견을 나누고, 도와주세요!

커뮤니티로 이동