참고
Click here to download the full example code
(Beta) Scaled Dot Product Attention (SDPA)로 고성능 트랜스포머(Transformers) 구현하기¶
저자: Driss Guessous 번역: 이강희
요약¶
이 튜토리얼에서, 트랜스포머(Transformer) 아키텍처 구현에 도움이 되는 새로운
torch.nn.functional
모듈의 함수를 소개합니다. 이 함수의 이름은 torch.nn.functional.scaled_dot_product_attention
입니다. 함수에 대한 자세한 설명은 PyTorch 문서
를 참고하세요. 이 함수는 이미 torch.nn.MultiheadAttention
과 torch.nn.TransformerEncoderLayer
에서 사용되고 있습니다.
개요¶
고수준에서, 이 PyTorch 함수는 쿼리(query), 키(key), 값(value) 사이의 scaled dot product attention (SDPA)을 계산합니다. 이 함수의 정의는 Attention is all you need 논문에서 찾을 수 있습니다. 이 함수는 기존 함수를 사용하여 PyTorch로 작성할 수 있지만, 퓨즈드(fused) 구현은 단순한 구현보다 큰 성능 이점을 제공할 수 있습니다.
퓨즈드 구현¶
이 함수는 CUDA tensor 입력을 다음 중 하나의 구현을 사용합니다.
구현:
FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
A PyTorch implementation defined in C++
참고
이 튜토리얼은 PyTorch 버전 2.0.0 이상이 필요합니다.
import torch
import torch.nn as nn
import torch.nn.functional as F
device = "cuda" if torch.cuda.is_available() else "cpu"
# 사용 예시:
query, key, value = torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device)
F.scaled_dot_product_attention(query, key, value)
tensor([[[ 0.2945, -0.4806, -1.3947, 0.0966, 0.4128, 1.0451, 0.5873,
0.0469],
[ 0.3673, -0.9850, -2.4275, -0.1616, 0.1597, 0.7109, 1.0664,
-0.2402],
[ 0.1968, 0.0102, -0.5221, 0.5455, 0.4381, 1.2384, 0.4569,
0.2166]],
[[ 0.3063, 0.7608, -0.3551, 0.4263, -1.3290, -0.0218, 1.1308,
-0.0193],
[ 0.2526, 1.1456, -0.2168, 0.7778, -1.3467, 0.0417, 0.7623,
0.5591],
[ 0.3538, 1.5476, -0.1805, 0.4686, -1.4972, -0.6577, 1.0304,
1.5759]]], device='cuda:0')
명시적 Dispatcher 제어¶
이 함수는 암시적으로 세 가지 구현 중 하나를 사용합니다. 하지만 컨텍스트 매니저를 사용하면 명시적으로 어떤 구현을 사용할 지 제어할 수 있습니다. 컨텍스트 매니저를 통해 특정 구현을 명시적으로 비활성화 할 수 있습니다. 특정 입력에 대한 가장 빠른 구현을 찾고자 한다면, 컨텍스트 매니저로 모든 구현의 성능을 측정해볼 수 있습니다.
# 벤치마크 함수를 정의합니다
import torch.utils.benchmark as benchmark
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
t0 = benchmark.Timer(
stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
)
return t0.blocked_autorange().mean * 1e6
# 입력의 하이퍼파라미터를 정의합니다
batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32
dtype = torch.float16
query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
# 세 가지 구현의 속도를 측정합니다
from torch.backends.cuda import sdp_kernel, SDPBackend
# Helpful arguments mapper
backend_map = {
SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
SDPBackend.FLASH_ATTENTION: {"enable_math": False, "enable_flash": True, "enable_mem_efficient": False},
SDPBackend.EFFICIENT_ATTENTION: {
"enable_math": False, "enable_flash": False, "enable_mem_efficient": True}
}
with sdp_kernel(**backend_map[SDPBackend.MATH]):
print(f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
try:
print(f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
except RuntimeError:
print("FlashAttention is not supported. See warnings for reasons.")
with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
try:
print(f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
except RuntimeError:
print("EfficientAttention is not supported. See warnings for reasons.")
The default implementation runs in 3795.025 microseconds
The math implementation runs in 13696.378 microseconds
The flash attention implementation runs in 3849.600 microseconds
The memory efficient implementation runs in 4189.878 microseconds
하드웨어 의존성¶
위 셀을 어떤 머신에서 실행했는지와 사용 가능한 하드웨어에 따라 결과가 다를 수 있습니다. - GPU가 없고 CPU에서 실행 중이라면 컨텍스트 매니저는 효과가 없고 세 가지 실행 모두 유사한 시간을 반환할 것입니다. - 그래픽 카드가 지원하는 컴퓨팅 능력에 따라 flash attention 또는 memory efficient 구현이 동작하지 않을 수 있습니다.
Causal Self Attention¶
아래는 multi-head causal self attention 블록의 구현 예시입니다. Andrej Karpathy NanoGPT 저장소를 참고했습니다.
class CausalSelfAttention(nn.Module):
def __init__(self, num_heads: int, embed_dimension: int, bias: bool=False, is_causal: bool=False, dropout:float=0.0):
super().__init__()
assert embed_dimension % num_heads == 0
# key, query, value projections for all heads, but in a batch
self.c_attn = nn.Linear(embed_dimension, 3 * embed_dimension, bias=bias)
# output projection
self.c_proj = nn.Linear(embed_dimension, embed_dimension, bias=bias)
# regularization
self.dropout = dropout
self.resid_dropout = nn.Dropout(dropout)
self.num_heads = num_heads
self.embed_dimension = embed_dimension
# Perform causal masking
self.is_causal = is_causal
def forward(self, x):
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
query_projected = self.c_attn(x)
batch_size = query_projected.size(0)
embed_dim = query_projected.size(2)
head_dim = embed_dim // (self.num_heads * 3)
query, key, value = query_projected.chunk(3, -1)
query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
if self.training:
dropout = self.dropout
is_causal = self.is_causal
else:
dropout = 0.0
is_causal = False
y = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=dropout, is_causal=is_causal)
y = y.transpose(1, 2).view(batch_size, -1, self.num_heads * head_dim)
y = self.resid_dropout(self.c_proj(y))
return y
num_heads = 8
heads_per_dim = 64
embed_dimension = num_heads * heads_per_dim
dtype = torch.float16
model = CausalSelfAttention(num_heads=num_heads, embed_dimension=embed_dimension, bias=False, is_causal=True, dropout=0.1).to("cuda").to(dtype).eval()
print(model)
CausalSelfAttention(
(c_attn): Linear(in_features=512, out_features=1536, bias=False)
(c_proj): Linear(in_features=512, out_features=512, bias=False)
(resid_dropout): Dropout(p=0.1, inplace=False)
)
NestedTensor
및 Dense tensor 지원¶
SDPA는 NestedTensor
와 Dense tensor 입력을 모두 지원합니다.
NestedTensors
는 입력이 가변 길이 시퀀스로 구성된 배치인 경우에
배치 내 시퀀스의 최대 길이에 맞춰 각 시퀀스를 패딩할 필요가 없습니다. NestedTensors
에 대한 자세한 내용은
torch.nested 와 NestedTensors 튜토리얼 을 참고하세요.
import random
def generate_rand_batch(
batch_size,
max_sequence_len,
embed_dimension,
pad_percentage=None,
dtype=torch.float16,
device="cuda",
):
if not pad_percentage:
return (
torch.randn(
batch_size,
max_sequence_len,
embed_dimension,
dtype=dtype,
device=device,
),
None,
)
# Random sequence lengths
seq_len_list = [
int(max_sequence_len * (1 - random.gauss(pad_percentage, 0.01)))
for _ in range(batch_size)
]
# Make random entry in the batch have max sequence length
seq_len_list[random.randint(0, batch_size - 1)] = max_sequence_len
return (
torch.nested.nested_tensor(
[
torch.randn(seq_len, embed_dimension,
dtype=dtype, device=device)
for seq_len in seq_len_list
]
),
seq_len_list,
)
random_nt, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=0.5, dtype=dtype, device=device)
random_dense, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=None, dtype=dtype, device=device)
# 현재 퓨즈드 구현은 ``NestedTensor`` 로 학습하는 것을 지원하지 않습니다.
model.eval()
with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
try:
print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds")
print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds")
except RuntimeError:
print("FlashAttention is not supported. See warnings for reasons.")
/workspace/tutorials-kr/intermediate_source/scaled_dot_product_attention_tutorial.py:222: UserWarning:
The PyTorch API of nested tensors is in prototype stage and will change in the near future. (Triggered internally at ../aten/src/ATen/NestedTensorImpl.cpp:177.)
/workspace/tutorials-kr/intermediate_source/scaled_dot_product_attention_tutorial.py:169: UserWarning:
Memory efficient kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.h:545.)
/workspace/tutorials-kr/intermediate_source/scaled_dot_product_attention_tutorial.py:169: UserWarning:
Memory Efficient attention has been runtime disabled. (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.h:338.)
/workspace/tutorials-kr/intermediate_source/scaled_dot_product_attention_tutorial.py:169: UserWarning:
Flash attention kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.h:547.)
/workspace/tutorials-kr/intermediate_source/scaled_dot_product_attention_tutorial.py:169: UserWarning:
We are not enabling nested Tensors for Flash Attention because of cuda memory errors. (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.h:155.)
FlashAttention is not supported. See warnings for reasons.
torch.compile
과 함께 SDPA 사용하기¶
PyTorch 2.0 릴리즈와 함께 torch.compile()
라는 새로운 기능이 추가되었는데,
이는 eager mode보다 상당한 성능 향상을 제공할 수 있습니다.
Scaled dot product attention은 torch.compile()
로 완전히 구성할 수 있습니다.
이를 확인하기 위해 torch.compile()
을 통해 CausalSelfAttention
모듈을 컴파일하고
결과적으로 얻어지는 성능 향상을 알아봅시다.
batch_size = 32
max_sequence_len = 256
x = torch.rand(batch_size, max_sequence_len,
embed_dimension, device=device, dtype=dtype)
print(
f"The non compiled module runs in {benchmark_torch_function_in_microseconds(model, x):.3f} microseconds")
compiled_model = torch.compile(model)
# Let's compile it
compiled_model(x)
print(
f"The compiled module runs in {benchmark_torch_function_in_microseconds(compiled_model, x):.3f} microseconds")
The non compiled module runs in 366.749 microseconds
The compiled module runs in 365.118 microseconds
정확한 실행 시간은 환경에 따라 다르지만, 다음은 저자의 결과입니다. 컴파일 되지 않은 모듈은 실행에 166.616ms 가 소요되었습니다. 컴파일 된 모듈은 실행에 166.726ms 가 소요되었습니다. 이는 우리의 예상과는 다릅니다. 좀 더 자세히 알아봅시다. PyTorch는 코드의 성능 특성을 점검할 수 있는 놀라운 내장(built-in) 프로파일러를 제공합니다.
from torch.profiler import profile, record_function, ProfilerActivity
activities = [ProfilerActivity.CPU]
if device == 'cuda':
activities.append(ProfilerActivity.CUDA)
with profile(activities=activities, record_shapes=False) as prof:
with record_function(" Non-Compilied Causal Attention"):
for _ in range(25):
model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
with profile(activities=activities, record_shapes=False) as prof:
with record_function("Compiled Causal Attention"):
for _ in range(25):
compiled_model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
# 더 많은 정보를 얻기 위해 추적(trace)를 내보내고 ``chrome://tracing``을 사용하여 결과를 확인해보세요.
# ::
#
# prof.export_chrome_trace("compiled_causal_attention_trace.json").
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Non-Compilied Causal Attention 14.75% 1.433ms 67.32% 6.541ms 6.541ms 0.000us 0.00% 9.169ms 9.169ms 1
aten::matmul 2.12% 206.000us 20.59% 2.001ms 40.020us 0.000us 0.00% 4.945ms 98.900us 50
aten::mm 13.33% 1.295ms 16.68% 1.621ms 32.420us 4.945ms 53.93% 4.945ms 98.900us 50
aten::linear 2.72% 264.000us 23.24% 2.258ms 45.160us 0.000us 0.00% 4.411ms 88.220us 50
aten::scaled_dot_product_attention 1.28% 124.000us 21.87% 2.125ms 85.000us 0.000us 0.00% 4.224ms 168.960us 25
aten::_scaled_dot_product_flash_attention 4.28% 416.000us 20.59% 2.001ms 80.040us 0.000us 0.00% 4.224ms 168.960us 25
aten::_flash_attention_forward 2.69% 261.000us 6.17% 599.000us 23.960us 4.124ms 44.98% 4.124ms 164.960us 25
void fmha_fwd_loop_kernel<FMHA_kernel_traits<256, 64... 0.00% 0.000us 0.00% 0.000us 0.000us 4.124ms 44.98% 4.124ms 164.960us 25
ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_tn 0.00% 0.000us 0.00% 0.000us 0.000us 3.514ms 38.32% 3.514ms 140.560us 25
ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_stages_3... 0.00% 0.000us 0.00% 0.000us 0.000us 1.431ms 15.61% 1.431ms 57.240us 25
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 9.716ms
Self CUDA time total: 9.169ms
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Compiled Causal Attention 13.97% 1.521ms 82.24% 8.957ms 8.957ms 0.000us 0.00% 8.617ms 8.617ms 1
CompiledFunction 38.85% 4.231ms 67.80% 7.384ms 295.360us 0.000us 0.00% 8.617ms 344.680us 25
aten::mm 7.40% 806.000us 10.88% 1.185ms 23.700us 4.478ms 51.97% 4.478ms 89.560us 50
aten::_scaled_dot_product_flash_attention 2.41% 262.000us 14.63% 1.593ms 63.720us 0.000us 0.00% 4.139ms 165.560us 25
aten::_flash_attention_forward 2.06% 224.000us 4.44% 484.000us 19.360us 4.039ms 46.87% 4.039ms 161.560us 25
void fmha_fwd_loop_kernel<FMHA_kernel_traits<256, 64... 0.00% 0.000us 0.00% 0.000us 0.000us 4.039ms 46.87% 4.039ms 161.560us 25
ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_tn 0.00% 0.000us 0.00% 0.000us 0.000us 3.194ms 37.07% 3.194ms 127.760us 25
ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_stages_3... 0.00% 0.000us 0.00% 0.000us 0.000us 1.284ms 14.90% 1.284ms 51.360us 25
aten::arange 3.07% 334.000us 10.28% 1.120ms 11.200us 100.000us 1.16% 184.000us 1.840us 100
void (anonymous namespace)::elementwise_kernel_with_... 0.00% 0.000us 0.00% 0.000us 0.000us 100.000us 1.16% 100.000us 2.000us 50
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 10.891ms
Self CUDA time total: 8.617ms
이전 코드 조각(snippet)은 컴파일 된 모듈과 컴파일되지 않은 모듈 모두에 대해
가장 많은 GPU 실행 시간을 차지한 상위 10개의 PyTorch 함수에 대한 보고서를 생성합니다.
분석 결과, 두 모듈 모두 GPU에서 소요된 시간의 대부분이
동일한 함수들에 집중되어 있음을 보여줍니다.
PyTorch가 프레임워크 오버헤드를 제거하는 데 매우 탁월한 torch.compile
를
제공하기 때문입니다. CausalSelfAttention
같은 경우처럼 크고, 효율적인 CUDA 커널을
사용하는 모델에서 PyTorch 오버헤드는 작아질 것입니다.
사실, 모듈은 보통 CausalSelfAttention
블럭 하나만으로 구성되지 않습니다.
Andrej Karpathy NanoGPT 저장소에서 실험한 경우,
모듈을 컴파일 하는 것은 학습의 각 단계별 소요 시간을 6090.49ms
에서 3273.17ms
로
줄일 수 있었습니다. 이 실험은 NanoGPT 저장소의 ae3a8d5
커밋에서 Shakespeare
데이터셋을 사용하여 진행되었습니다.
결론¶
이 튜토리얼에서, torch.nn.functional.scaled_dot_product_attention
의 기본적인
사용법을 살펴봤습니다. sdp_kernel
컨텍스트 매니저로 GPU가 특정 구현을
사용하도록 할 수 있다는 것을 보았습니다. 또한, 간단한 NestedTensor
에서 작동하고
컴파일 가능한 CausalSelfAttention
모듈을 만들었습니다.
이 과정에서 프로파일링 도구를 사용하여 유저가 정의한 모듈의 성능 특성을 어떻게
확인할 수 있는지도 살펴봤습니다.
Total running time of the script: ( 0 minutes 5.212 seconds)