• Tutorials >
  • 사용자 정의 Triton 커널을 ``torch.compile``과 함께 사용하기
Shortcuts

사용자 정의 Triton 커널을 ``torch.compile``과 함께 사용하기

저자: Oguz Ulgen 번역: 구경선, 이채운

사용자 정의 Triton 커널을 사용하면 모델의 특정 부분의 계산을 최적화할 수 있습니다. 이 커널들은 Triton의 언어로 작성된 것으로 설계되었습니다. 사용자 정의 Triton을 사용하여 하드웨어 성능을 최고로 향상시킵니다. ``torch.compile``를 사용하는 커널은 이러한 최적화된 계산을 통합할 수 있습니다. PyTorch 모델을 통해 상당한 성능 향상을 실현할 수 있습니다.

이 레시피는 사용자 정의 Triton 커널을 ``torch.compile``과 함께 사용할 수 있는 방법을 보여줍니다.

전제조건

이 레시피를 시작하기 전에 다음이 있는지 확인합니다:

import torch
from torch.utils._triton import has_triton

기본 사용법

이 예에서는 Triton 문서의 간단한 벡터 덧셈 커널을 사용합니다. ``torch.compile``과 함께. 참고, Triton 문서를 참고하세요.

if not has_triton():
    print("Skipping because triton is not supported on this device.")
else:
    import triton
    from triton import language as tl

    @triton.jit
    def add_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @torch.compile(fullgraph=True)
    def add_fn(x, y):
        output = torch.zeros_like(x)
        n_elements = output.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4)
        return output

    x = torch.randn(4, device="cuda")
    y = torch.randn(4, device="cuda")
    out = add_fn(x, y)
    print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")
Vector addition of
X:      tensor([ 0.1940,  2.1614, -0.1721,  0.8491], device='cuda:0')
Y:      tensor([ 0.1391, -0.1082, -0.7174,  0.7566], device='cuda:0')
is equal to
tensor([ 0.3332,  2.0532, -0.8895,  1.6057], device='cuda:0')

고급 사용법

Triton의 자동 튜닝 기능은 Triton 커널의 구성 매개변수를 자동으로 최적화해주는 강력한 도구입니다. 다양한 설정을 검토하여 특정 사용 사례에 최적의 성능을 제공하는 구성을 선택합니다.

``torch.compile``과 함께 사용할 경우 ``triton.autotune``을 사용하면 PyTorch 모델을 최대한 효율적으로 실행할 수 있습니다. 아래는 ``torch.compile``과 ``triton.autotune``을 사용하는 예제입니다.

참고

``torch.compile``은 ``triton.autotune``에 대한 configs와 key 인수만 지원합니다.

if not has_triton():
    print("Skipping because triton is not supported on this device.")
else:
    import triton
    from triton import language as tl

    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 4}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 4}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_SIZE": 2}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 2}, num_stages=4, num_warps=4),
        ],
        key=[],
    )
    @triton.jit
    def add_kernel_autotuned(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @torch.compile(fullgraph=True)
    def add_fn(x, y):
        output = torch.zeros_like(x)
        n_elements = output.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        add_kernel_autotuned[grid](x, y, output, n_elements)
        return output

    x = torch.randn(4, device="cuda")
    y = torch.randn(4, device="cuda")
    out = add_fn(x, y)
    print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")
Vector addition of
X:      tensor([-0.5187,  1.2268,  0.6255, -0.9117], device='cuda:0')
Y:      tensor([-0.6974, -1.8688, -0.8832, -1.6627], device='cuda:0')
is equal to
tensor([-1.2161, -0.6421, -0.2577, -2.5744], device='cuda:0')

호환성과 제한사항

PyTorch 2.3 버전 기준으로, torch.compile``의 사용자 정의 Triton 커널에는 동적 모양 ``torch.autograd.Function, JIT inductor, AOT inductor가 지원됩니다. 이 기능들을 조합하여 복잡하고 고성능인 모델을 구축할 수 있습니다.

그러나 알아두어야 할 몇 가지 제한 사항이 있습니다.

  • Tensor Subclasses: 현재로서는 Tensor 하위 클래스 및 기타 고급 기능은 지원되지 않습니다.

  • Triton Features: triton.heuristics``는 단독으로 사용하거나 ``triton.autotune 앞에서

사용할 수 있지만, triton.autotune 뒤에서는 사용할 수 없습니다. 따라서 ``triton.heuristics``와 ``triton.autotune``을 함께 사용하려면 ``triton.heuristics``를 먼저 사용해야 합니다.

결론

이 레시피에서는 사용자 정의 Triton 커널을 ``torch.compile``로 활용하는 방법을 알아보았습니다. 간단한 벡터 덧셈 커널의 기본 사용법과 Triton의 자동 튜닝 기능을 포함한 고급 사용법에 대해 다뤘습니다. 또한 사용자 정의 Triton 커널과 다른 Pytorch 기능의 조합 가능성에 대해 논의하고 현재의 몇 가지 제한 사항을 강조했습니다.


더 궁금하시거나 개선할 내용이 있으신가요? 커뮤니티에 참여해보세요!


이 튜토리얼이 어떠셨나요? 평가해주시면 이후 개선에 참고하겠습니다! :)

© Copyright 2018-2024, PyTorch & 파이토치 한국 사용자 모임(PyTorch Korea User Group).

Built with Sphinx using a theme provided by Read the Docs.

PyTorchKorea @ GitHub

파이토치 한국 사용자 모임을 GitHub에서 만나보세요.

GitHub로 이동

한국어 튜토리얼

한국어로 번역 중인 PyTorch 튜토리얼입니다.

튜토리얼로 이동

커뮤니티

다른 사용자들과 의견을 나누고, 도와주세요!

커뮤니티로 이동