참고
Click here to download the full example code
NumPy와 SciPy를 사용한 확장(Extension) 만들기¶
Author: Adam Paszke
Updated by: Adam Dziedzic
번역: Ajin Jeong
이번 튜토리얼에서는 두 가지 작업을 수행할 것입니다:
- 매개 변수가 없는 신경망 계층(layer) 만들기
이는 구현의 일부로 NumPy 를 호출합니다.
- 학습 가능한 가중치가 있는 신경망 계층(layer) 만들기
이는 구현의 일부로 Scipy 를 호출합니다.
import torch
from torch.autograd import Function
매개 변수가 없는(Parameter-less) 예시¶
이 계층(layer)은 특별히 유용하거나 수학적으로 올바른 작업을 수행하지 않습니다.
이름은 대충 BadFFTFunction
으로 지었습니다.
계층(layer) 구현
from numpy.fft import rfft2, irfft2
class BadFFTFunction(Function):
@staticmethod
def forward(ctx, input):
numpy_input = input.detach().numpy()
result = abs(rfft2(numpy_input))
return input.new(result)
@staticmethod
def backward(ctx, grad_output):
numpy_go = grad_output.numpy()
result = irfft2(numpy_go)
return grad_output.new(result)
# 이 계층에는 매개 변수가 없으므로 ``nn.Module`` 클래스가 아닌 함수로 간단히 선언할 수 있습니다.
def incorrect_fft(input):
return BadFFTFunction.apply(input)
생성된 계층(layer)의 사용 예시:
input = torch.randn(8, 8, requires_grad=True)
result = incorrect_fft(input)
print(result)
result.backward(torch.randn(result.size()))
print(input)
tensor([[ 0.7493, 5.3761, 7.4498, 5.2065, 8.1245],
[10.7571, 9.3190, 13.9146, 6.2218, 1.9285],
[ 5.8930, 1.9813, 11.8968, 6.0770, 8.5321],
[17.1539, 9.4069, 11.7599, 1.9252, 11.2858],
[11.2505, 5.9365, 3.1825, 6.0164, 3.0514],
[17.1539, 7.1099, 3.6442, 8.3454, 11.2858],
[ 5.8930, 7.0157, 8.8019, 9.0743, 8.5321],
[10.7571, 5.9901, 3.6158, 5.7762, 1.9285]],
grad_fn=<BadFFTFunctionBackward>)
tensor([[ 0.0834, 0.0469, 0.2830, -1.0414, 2.5665, -0.2704, 0.5082, 0.7925],
[-0.3522, 0.4684, 0.2900, -0.9710, -0.6960, 0.5306, -1.6445, 0.8292],
[-0.0250, -1.1805, -1.3117, 0.5479, 0.7102, -1.0263, -0.9084, -0.5743],
[ 0.9343, -1.5087, 1.4279, 1.0486, 0.1963, -1.4240, 0.1408, -0.1087],
[ 0.0217, 1.1997, 0.1217, -1.1010, -2.2295, 0.1683, -0.2580, -0.7808],
[-0.2993, 0.9932, 0.6913, -0.8400, 0.4267, 0.2024, -2.1530, -2.0117],
[ 1.9086, 0.9026, 1.6754, 0.3198, 0.5785, 1.1717, 1.6946, 0.6565],
[ 1.1715, -0.9260, -1.4667, -0.3593, 0.1675, 0.1235, -0.5661, -0.3147]],
requires_grad=True)
매개 변수가 있는(Parameterized) 예시¶
딥러닝 문헌에서 이 계층(layer)의 실제 연산은 상호 상관(cross-correlation)이지만 합성곱(convolution)이라고 헷갈리게 부르고 있습니다. (합성곱은 필터를 뒤집어서 연산을 하는 반면, 상호 상관은 그렇지 않은 차이가 있습니다)
학습 가능한 가중치를 가는 필터(커널)를 갖는 상호 상관 계층을 구현해보겠습니다.
역전파 단계(backward pass)에서는 입력에 대한 기울기(gradient)와 필터에 대한 기울기를 계산합니다.
from numpy import flip
import numpy as np
from scipy.signal import convolve2d, correlate2d
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter
class ScipyConv2dFunction(Function):
@staticmethod
def forward(ctx, input, filter, bias):
# 분리(detach)하여 NumPy로 변환(cast)할 수 있습니다.
input, filter, bias = input.detach(), filter.detach(), bias.detach()
result = correlate2d(input.numpy(), filter.numpy(), mode='valid')
result += bias.numpy()
ctx.save_for_backward(input, filter, bias)
return torch.as_tensor(result, dtype=input.dtype)
@staticmethod
def backward(ctx, grad_output):
grad_output = grad_output.detach()
input, filter, bias = ctx.saved_tensors
grad_output = grad_output.numpy()
grad_bias = np.sum(grad_output, keepdims=True)
grad_input = convolve2d(grad_output, filter.numpy(), mode='full')
# 윗줄은 다음과 같이 표현할 수도 있습니다:
# grad_input = correlate2d(grad_output, flip(flip(filter.numpy(), axis=0), axis=1), mode='full')
grad_filter = correlate2d(input.numpy(), grad_output, mode='valid')
return torch.from_numpy(grad_input), torch.from_numpy(grad_filter).to(torch.float), torch.from_numpy(grad_bias).to(torch.float)
class ScipyConv2d(Module):
def __init__(self, filter_width, filter_height):
super(ScipyConv2d, self).__init__()
self.filter = Parameter(torch.randn(filter_width, filter_height))
self.bias = Parameter(torch.randn(1, 1))
def forward(self, input):
return ScipyConv2dFunction.apply(input, self.filter, self.bias)
사용 예시:
module = ScipyConv2d(3, 3)
print("Filter and bias: ", list(module.parameters()))
input = torch.randn(10, 10, requires_grad=True)
output = module(input)
print("Output from the convolution: ", output)
output.backward(torch.randn(8, 8))
print("Gradient for the input map: ", input.grad)
Filter and bias: [Parameter containing:
tensor([[-0.0685, 0.5146, -0.9103],
[-0.9190, 0.7659, 1.4033],
[-1.5628, -0.4356, 0.1513]], requires_grad=True), Parameter containing:
tensor([[0.0500]], requires_grad=True)]
Output from the convolution: tensor([[ 1.0434, 0.5532, 1.4637, 2.4000, -2.9941, -3.0063, -0.8389, -0.2737],
[-2.0167, 1.2345, -3.4750, 0.9726, 3.4054, -0.9997, 1.6345, -5.7391],
[ 0.8466, 1.3429, -4.8645, -1.8883, 2.5089, 0.6507, 5.4805, -0.4945],
[-0.6900, -0.4684, -1.6632, -1.6022, -1.6148, -4.8085, -1.3337, 3.9142],
[ 3.7322, -0.9291, 1.1793, 4.3954, 1.6756, -0.4551, 4.0487, -2.3517],
[ 1.3401, -0.4182, -2.7344, 2.6672, -1.1287, -1.3467, 2.7797, 10.2257],
[-1.6213, -0.9084, -1.0380, -0.1645, -2.4001, 0.8624, 0.8780, 5.1931],
[ 1.2913, 1.1509, -4.2846, 5.4657, -2.2595, 0.5128, -0.7484, -1.2321]],
grad_fn=<ScipyConv2dFunctionBackward>)
Gradient for the input map: tensor([[-0.0479, 0.5676, -2.0581, 1.6713, 2.2732, -1.0390, 0.6838, -0.7023,
1.2946, -0.6725],
[-0.7185, 3.9055, -0.6547, -6.2065, -1.9312, 1.2286, -0.0152, -2.3588,
0.4810, 1.5156],
[-2.0281, 5.0570, 6.3651, 0.0545, -0.4638, -0.4837, 4.1343, -3.3259,
-2.6426, -0.6377],
[-0.7372, 0.3809, -3.2023, 0.7439, -2.4857, 0.5112, 6.0474, 3.1520,
-0.3647, -0.7803],
[ 1.2595, 4.1878, -1.7010, -1.6706, -2.2941, -5.0591, -0.4691, 0.6966,
0.6100, 2.1561],
[-1.6727, 2.7531, 1.2086, 0.5137, -0.2360, 1.9270, 5.1387, -3.4446,
-2.6745, -0.9914],
[-2.4838, -1.5068, 1.6128, -0.5395, -1.3657, -1.3266, 3.4149, 0.5285,
-0.9051, 0.5061],
[-1.9095, -2.3272, 3.8304, -0.7444, 1.8315, -1.8599, 3.1576, 4.0086,
-2.1298, -2.1147],
[ 0.3715, -0.0954, 0.2900, -2.3744, -0.1899, -2.7285, 0.9018, 2.3479,
-1.1129, -0.5725],
[ 2.5708, 2.5453, 1.7140, 0.2467, 0.9614, 0.7750, 1.6201, 0.7669,
-0.0519, -0.0356]])
기울기(gradient) 확인:
from torch.autograd.gradcheck import gradcheck
moduleConv = ScipyConv2d(3, 3)
input = [torch.randn(20, 20, dtype=torch.double, requires_grad=True)]
test = gradcheck(moduleConv, input, eps=1e-6, atol=1e-4)
print("Are the gradients correct: ", test)
Are the gradients correct: True
Total running time of the script: ( 0 minutes 0.164 seconds)