• Tutorials >
  • 모듈 매개변수 초기화 건너뛰기
Shortcuts

모듈 매개변수 초기화 건너뛰기

소개

모듈이 생성될 때, 모듈 유형과 관련된 기본 초기화 방법에 따라 학습 가능한 매개변수가 초기화됩니다. 예를 들어, torch.nn.Linear 모듈의 weight 매개변수는 uniform(-1/sqrt(in_features), 1/sqrt(in_features)) 분포로 초기화됩니다. 기존에는 다른 초기화 방법이 필요한 경우 모듈 인스턴스화 후 매개변수를 재초기화해야 했습니다.

from torch import nn

# 기본 분포로 가중치를 초기화합니다: uniform(-1/sqrt(10), 1/sqrt(10)).
m = nn.Linear(10, 5)

# 다른 분포로 가중치를 재초기화합니다.
nn.init.orthogonal_(m.weight)

이 경우 구성 중 수행되는 초기화는 계산 낭비이며, weight 매개변수가 크면 사소한 문제가 아닐 수 있습니다.

초기화 건너뛰기

모듈 구성 중 매개변수 초기화를 건너뛰게 되어 낭비되는 계산을 피할 수 있습니다. torch.nn.utils.skip_init() 함수를 사용하면 쉽게 건너뛰기가 가능합니다.

from torch import nn
from torch.nn.utils import skip_init

m = skip_init(nn.Linear, 10, 5)

# 예제 : 기본 이외의 매개변수 초기화를 수정하여 실행합니다.
nn.init.orthogonal_(m.weight)

아래 초기화 건너뛰기를 위한 모듈 업데이트 섹션에 설명된 조건을 충족하는 모듈에 적용할 수 있습니다. torch.nn 에 있는 모든 모듈은 조건을 충족하기 때문에 초기화 건너뛰기를 지원하고 있습니다.

초기화 건너뛰기를 위한 모듈 업데이트

torch.nn.utils.skip_init() 의 구현(참고 구현 세부 사항) 방법에 따라, 모듈이 함수와 호환되기 위한 두 가지 요구사항이 있습니다. 다음의 요구사항을 이행하면 커스텀 모듈의 매개변수 초기화 건너뛰기 기능을 선택할 수 있습니다.

1. 모듈을 생성할 때 매개변수와 버퍼로 전달되는 모듈의 생성자 내 device 키워드 인자(keyword argument)를 사용해야 합니다.

2. 모듈은 초기화를 제외하고 모듈의 생성자 내 매개변수 또는 버퍼 계산을 수행하지 않아야 합니다 (즉, `torch.nn.init`의 함수).

다음은 device 키워드 인자가 생성된 파라미터, 버퍼, 서브모듈로 따라 전달되기 위한 모듈 업데이트를 보여주는 예시입니다.

import torch
from torch import nn

class MyModule(torch.nn.Module):
  def __init__(self, foo, bar, device=None):
    super().__init__()

    # ==== 사례 1: 모듈 매개변수를 직접 생성합니다. ====
    # 생성한 매개변수에 장치(device)를 전달합니다.
    self.param1 = nn.Parameter(torch.empty((foo, bar), device=device))
    self.register_parameter('param2', nn.Parameter(torch.empty(bar, device=device)))

    # meta 장치 지원을 확실히 하기 위해 모듈의 생성자 내 매개변수에
    # torch.nn.init의 ops 외에는 사용하지 마십시오.
    with torch.no_grad():
        nn.init.kaiming_uniform_(self.param1)
        nn.init.uniform_(self.param2)


    # ==== 사례 2: 모듈의 서브 모듈을 생성합니다. ====
    # 모든 서브 모듈이 해당 사항을 지원해야 하기 때문에 장치를 재귀적으로 전달합니다.
    # 이는 torch.nn이 제공하는 모듈들의 경우에 해당합니다.
    self.fc = nn.Linear(bar, 5, device=device)

    # 컨테이너에도 동일하게 적용합니다.
    self.linears = nn.Sequential(
        nn.Linear(5, 5, device=device),
        nn.Linear(5, 1, device=device)
    )


    # ==== 사례 3: 모듈의 버퍼를 생성합니다. ====
    # 버퍼 tensor 생성하는 동안 장치를 전달합니다.
    self.register_buffer('some_buffer', torch.ones(7, device=device))

...

구현 세부 사항

내부적으로 torch.nn.utils.skip_init() 함수는 2단계 패턴으로 구현됩니다.

# 1. meta 장치에서 모듈을 초기화합니다; 모든 torch.nn.init ops는
# meta 장치에서 no-op 동작을 합니다.
m = nn.Linear(10, 5, device='meta')

# 2. 초기화되지 않은(빈) 형태의 모듈을 CPU 장치에 구현합니다.
# 결과는 초기화되지 않은 매개 변수를 가진 모듈 인스턴스입니다.
m.to_empty(device='cpu')

모듈은 《meta》 장치로 인스턴스화하여 동작합니다. tensor shape 정보를 가지고 있지만 저장 공간은 할당하지 않습니다. torch.nn.init ops는 meta 장치를 위해 특별히 구현되어 있고 no-op 동작을 합니다. 이에 따라 매개변수 초기화 로직에서 본질적으로 건너뛰게 됩니다.

초기화 건너뛰기를 위한 모듈 업데이트 에 설명된 대로 이 패턴은 모듈 구성 중 device 키워드 인자를 적절히 지원하는 모듈에서만 작동합니다.


더 궁금하시거나 개선할 내용이 있으신가요? 커뮤니티에 참여해보세요!


이 튜토리얼이 어떠셨나요? 평가해주시면 이후 개선에 참고하겠습니다! :)

© Copyright 2018-2024, PyTorch & 파이토치 한국 사용자 모임(PyTorch Korea User Group).

Built with Sphinx using a theme provided by Read the Docs.

PyTorchKorea @ GitHub

파이토치 한국 사용자 모임을 GitHub에서 만나보세요.

GitHub로 이동

한국어 튜토리얼

한국어로 번역 중인 PyTorch 튜토리얼입니다.

튜토리얼로 이동

커뮤니티

다른 사용자들과 의견을 나누고, 도와주세요!

커뮤니티로 이동