(Beta) Implementing High-Performance Transformers with Scaled Dot Product Attention (SDPA)

Author: Driss Guessous

Summary

In this tutorial, we want to highlight a new torch.nn.functional function that can be helpful for implementing transformer architectures. The function is named torch.nn.functional.scaled_dot_product_attention. For a detailed description of the function, see the PyTorch documentation. This function has already been incorporated into torch.nn.MultiheadAttention and torch.nn.TransformerEncoderLayer.

Overview

At a high level, this PyTorch function calculates the scaled dot product attention (SDPA) between query, key, and value according to the definition found in the paper Attention is all you need. While this function can be written in PyTorch using existing functions, a fused implementation can provide large performance benefits over a naive implementation.
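
To make the definition concrete, here is a minimal sketch of the naive, unfused computation softmax(Q K^T / sqrt(d)) V, checked against the built-in function on small CPU tensors. The helper name naive_sdpa is hypothetical and exists only for illustration. It ignores masking and dropout, and it materializes the full attention matrix, which fused kernels such as FlashAttention are designed to avoid.

```python
import math

import torch
import torch.nn.functional as F


def naive_sdpa(query, key, value):
    """Unfused reference: softmax(Q K^T / sqrt(d)) V (hypothetical helper)."""
    scale = 1.0 / math.sqrt(query.size(-1))
    # Attention scores between every query and every key: (..., L, S)
    scores = (query @ key.transpose(-2, -1)) * scale
    # Normalize over the key dimension so each query's weights sum to 1
    weights = torch.softmax(scores, dim=-1)
    return weights @ value


torch.manual_seed(0)
q, k, v = (torch.randn(2, 3, 8) for _ in range(3))

expected = naive_sdpa(q, k, v)
actual = F.scaled_dot_product_attention(q, k, v)

max_abs_diff = (expected - actual).abs().max().item()
print(f"max abs difference: {max_abs_diff:.2e}")
```

The two results should agree up to floating-point rounding; the difference between them lies in speed and memory use, not in the mathematical result.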

Fused implementations

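For CUDA tensor inputs, the function will dispatch into one of the following three implementations, which correspond to the SDPBackend values used later in this tutorial:

- FlashAttention (SDPBackend.FLASH_ATTENTION)
- Memory-efficient attention (SDPBackend.EFFICIENT_ATTENTION)
- The math implementation (SDPBackend.MATH)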

Note

This tutorial requires PyTorch 2.0.0 or later.

import torch
import torch.nn as nn
import torch.nn.functional as F
device = "cuda" if torch.cuda.is_available() else "cpu"

# Example Usage:
query, key, value = torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device)
F.scaled_dot_product_attention(query, key, value)
tensor([[[ 0.2945, -0.4806, -1.3947,  0.0966,  0.4128,  1.0451,  0.5873,
           0.0469],
         [ 0.3673, -0.9850, -2.4275, -0.1616,  0.1597,  0.7109,  1.0664,
          -0.2402],
         [ 0.1968,  0.0102, -0.5221,  0.5455,  0.4381,  1.2384,  0.4569,
           0.2166]],

        [[ 0.3063,  0.7608, -0.3551,  0.4263, -1.3290, -0.0218,  1.1308,
          -0.0193],
         [ 0.2526,  1.1456, -0.2168,  0.7778, -1.3467,  0.0417,  0.7623,
           0.5591],
         [ 0.3538,  1.5476, -0.1805,  0.4686, -1.4972, -0.6577,  1.0304,
           1.5759]]], device='cuda:0')

Explicit Dispatcher Control

While the function will implicitly dispatch to one of the three implementations, the user can also explicitly control the dispatch via the use of a context manager. This context manager allows users to explicitly disable certain implementations. If a user wants to ensure the function is indeed using the fastest implementation for their specific inputs, the context manager can be used to sweep through the implementations and measure the performance of each.

# Let's define a helpful benchmarking function:
import torch.utils.benchmark as benchmark
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6

# Let's define the hyper-parameters of our input
batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32

dtype = torch.float16

query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)

print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")

# Let's explore the speed of each of the 3 implementations
from torch.backends.cuda import sdp_kernel, SDPBackend

# Helpful arguments mapper
backend_map = {
    SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
    SDPBackend.FLASH_ATTENTION: {"enable_math": False, "enable_flash": True, "enable_mem_efficient": False},
    SDPBackend.EFFICIENT_ATTENTION: {
        "enable_math": False, "enable_flash": False, "enable_mem_efficient": True}
}

with sdp_kernel(**backend_map[SDPBackend.MATH]):
    print(f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")


with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
    try:
        print(f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")

with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
    try:
        print(f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
    except RuntimeError:
        print("EfficientAttention is not supported. See warnings for reasons.")
The default implementation runs in 45113.683 microseconds
The math implementation runs in 56292.578 microseconds
FlashAttention is not supported. See warnings for reasons.
The memory efficient implementation runs in 45090.366 microseconds

Hardware dependence

Depending on what machine you ran the above cell on and what hardware is available, your results might be different.

- If you don’t have a GPU and are running on CPU, then the context manager will have no effect and all three runs should return similar timings.
- Depending on what compute capability your graphics card supports, flash attention or memory efficient attention might have failed.
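
To see which situation applies to your machine, you can query the device before benchmarking. The following is a small sketch; the helper name describe_cuda_device is hypothetical, and it deliberately does not try to decide which kernels are supported, since the warnings emitted by PyTorch give the exact reason when a backend is rejected.

```python
import torch


def describe_cuda_device():
    """Return (name, (major, minor)) for the current GPU, or None without CUDA.

    Hypothetical helper for illustration only.
    """
    if not torch.cuda.is_available():
        return None
    name = torch.cuda.get_device_name()
    capability = torch.cuda.get_device_capability()  # (major, minor)
    return name, capability


info = describe_cuda_device()
if info is None:
    print("No CUDA device found; SDPA will run on CPU.")
else:
    name, (major, minor) = info
    print(f"{name}: compute capability {major}.{minor}")
```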

Causal Self Attention

Below is an example implementation of a multi-headed causal self attention block inspired by Andrej Karpathy's NanoGPT repository.

class CausalSelfAttention(nn.Module):

    def __init__(self, num_heads: int, embed_dimension: int, bias: bool=False, is_causal: bool=False, dropout:float=0.0):
        super().__init__()
        assert embed_dimension % num_heads == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(embed_dimension, 3 * embed_dimension, bias=bias)
        # output projection
        self.c_proj = nn.Linear(embed_dimension, embed_dimension, bias=bias)
        # regularization
        self.dropout = dropout
        self.resid_dropout = nn.Dropout(dropout)
        self.num_heads = num_heads
        self.embed_dimension = embed_dimension
        # Perform causal masking
        self.is_causal = is_causal

    def forward(self, x):
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        query_projected = self.c_attn(x)

        batch_size = query_projected.size(0)
        embed_dim = query_projected.size(2)
        head_dim = embed_dim // (self.num_heads * 3)

        query, key, value = query_projected.chunk(3, -1)
        query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)

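        # Dropout and causal masking are applied only in training mode;
        # in eval mode this module runs attention without dropout or a causal mask.
        if self.training:
            dropout = self.dropout
            is_causal = self.is_causal
        else:
            dropout = 0.0
            is_causal = False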

        y = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=dropout, is_causal=is_causal)
        y = y.transpose(1, 2).view(batch_size, -1, self.num_heads * head_dim)

        y = self.resid_dropout(self.c_proj(y))
        return y


num_heads = 8
dim_per_head = 64
embed_dimension = num_heads * dim_per_head
dtype = torch.float16
model = CausalSelfAttention(num_heads=num_heads, embed_dimension=embed_dimension, bias=False, is_causal=True, dropout=0.1).to("cuda").to(dtype).eval()
print(model)
CausalSelfAttention(
  (c_attn): Linear(in_features=512, out_features=1536, bias=False)
  (c_proj): Linear(in_features=512, out_features=512, bias=False)
  (resid_dropout): Dropout(p=0.1, inplace=False)
)
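
The is_causal flag passed to scaled_dot_product_attention in the module above can be read as shorthand for a lower-triangular attention mask. The sketch below, which assumes small CPU tensors with illustrative shapes, compares the flag against an explicit boolean mask in which True means "this position may be attended to".

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, head_dim = 4, 8
# Shape: (batch, num_heads, seq_len, head_dim)
q, k, v = (torch.randn(1, 2, seq_len, head_dim) for _ in range(3))

# Lower triangle: each position sees itself and earlier positions only
causal_mask = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()

out_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)

print(torch.allclose(out_flag, out_mask, atol=1e-5))
```

Because the first position can only attend to itself, its output row is simply the first row of value. Passing the flag instead of a mask tensor leaves the choice of how to apply the masking to the selected implementation.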

NestedTensor and Dense tensor support

SDPA supports both NestedTensor and Dense tensor inputs. NestedTensors handle the case where the input is a batch of variable length sequences without needing to pad each sequence to the maximum length in the batch. For more information about NestedTensors see torch.nested and NestedTensors Tutorial.

import random
def generate_rand_batch(
    batch_size,
    max_sequence_len,
    embed_dimension,
    pad_percentage=None,
    dtype=torch.float16,
    device="cuda",
):
    if not pad_percentage:
        return (
            torch.randn(
                batch_size,
                max_sequence_len,
                embed_dimension,
                dtype=dtype,
                device=device,
            ),
            None,
        )
    # Random sequence lengths
    seq_len_list = [
        int(max_sequence_len * (1 - random.gauss(pad_percentage, 0.01)))
        for _ in range(batch_size)
    ]
    # Make random entry in the batch have max sequence length
    seq_len_list[random.randint(0, batch_size - 1)] = max_sequence_len
    return (
        torch.nested.nested_tensor(
            [
                torch.randn(seq_len, embed_dimension,
                            dtype=dtype, device=device)
                for seq_len in seq_len_list
            ]
        ),
        seq_len_list,
    )

random_nt, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=0.5, dtype=dtype, device=device)
random_dense, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=None, dtype=dtype, device=device)

# Currently the fused implementations don't support ``NestedTensor`` for training
model.eval()

with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
    try:
        print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds")
        print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")
FlashAttention is not supported. See warnings for reasons.
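
To illustrate what a NestedTensor saves compared with padding, here is a minimal sketch on CPU with three short sequences. The sequence lengths and embedding size are made up for illustration, and nested tensors may emit a warning that the API is still a prototype.

```python
import torch

torch.manual_seed(0)
embed_dim = 4
seq_lens = [2, 5, 3]
sequences = [torch.randn(n, embed_dim) for n in seq_lens]

# A nested tensor keeps each sequence at its own length
nt = torch.nested.nested_tensor(sequences)

# The dense equivalent pads every sequence to the longest one in the batch
padded = torch.nested.to_padded_tensor(nt, 0.0)

stored = sum(t.numel() for t in nt.unbind())
print(f"nested tensor stores {stored} elements")
print(f"padded tensor stores {padded.numel()} elements, shape {tuple(padded.shape)}")
```

The gap between the two element counts is the padding that attention over a dense batch would otherwise have to process.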

Using SDPA with torch.compile

With the release of PyTorch 2.0, a new feature called torch.compile() has been introduced, which can provide significant performance improvements over eager mode. Scaled dot product attention is fully composable with torch.compile(). To demonstrate this, let’s compile the CausalSelfAttention module using torch.compile() and observe the resulting performance improvements.

batch_size = 32
max_sequence_len = 256
x = torch.rand(batch_size, max_sequence_len,
               embed_dimension, device=device, dtype=dtype)
print(
    f"The non compiled module runs in  {benchmark_torch_function_in_microseconds(model, x):.3f} microseconds")


compiled_model = torch.compile(model)
# The first call triggers compilation, so run it once before benchmarking
compiled_model(x)
print(
    f"The compiled module runs in  {benchmark_torch_function_in_microseconds(compiled_model, x):.3f} microseconds")
The non compiled module runs in  3189.925 microseconds
The compiled module runs in  3190.352 microseconds

The exact execution time depends on the machine; however, the results on my machine were:

The non compiled module runs in 166.616 microseconds
The compiled module runs in 166.726 microseconds

That is not what we were expecting. Let’s dig a little deeper. PyTorch comes with an amazing built-in profiler that you can use to inspect the performance characteristics of your code.

from torch.profiler import profile, record_function, ProfilerActivity
activities = [ProfilerActivity.CPU]
if device == 'cuda':
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, record_shapes=False) as prof:
    with record_function(" Non-Compilied Causal Attention"):
        for _ in range(25):
            model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))


with profile(activities=activities, record_shapes=False) as prof:
    with record_function("Compiled Causal Attention"):
        for _ in range(25):
            compiled_model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

# For even more insights, you can export the trace and use ``chrome://tracing`` to view the results:
#
#    prof.export_chrome_trace("compiled_causal_attention_trace.json")
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                         Non-Compilied Causal Attention         2.84%       2.598ms        21.24%      19.409ms      19.409ms       0.000us         0.00%      79.885ms      79.885ms             1
                                           aten::matmul         0.36%     326.000us        14.78%      13.507ms     270.140us       0.000us         0.00%      54.317ms       1.086ms            50
                                               aten::mm         1.95%       1.786ms        13.98%      12.771ms     255.420us      54.317ms        67.99%      54.317ms       1.086ms            50
                                           aten::linear         0.23%     213.000us        15.42%      14.090ms     281.800us       0.000us         0.00%      52.773ms       1.055ms            50
                                 hgemm_128x128x8_NT_vec         0.00%       0.000us         0.00%       0.000us       0.000us      38.596ms        48.31%      38.596ms       1.544ms            25
          aten::_scaled_dot_product_efficient_attention         0.50%     459.000us         1.49%       1.364ms      54.560us       0.000us         0.00%      25.568ms       1.023ms            25
                     aten::_efficient_attention_forward         0.27%     245.000us         0.87%     792.000us      31.680us      25.568ms        32.01%      25.568ms       1.023ms            25
void attention_kernel_batched<AttentionKernel<cutlas...         0.00%       0.000us         0.00%       0.000us       0.000us      25.568ms        32.01%      25.568ms       1.023ms            25
                     aten::scaled_dot_product_attention         0.36%     331.000us         1.74%       1.586ms      63.440us       0.000us         0.00%      23.521ms     940.840us            25
                                  hgemm_32x32x32_NT_vec         0.00%       0.000us         0.00%       0.000us       0.000us      15.721ms        19.68%      15.721ms     628.840us            25
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 91.359ms
Self CUDA time total: 79.885ms

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                              Compiled Causal Attention         4.05%       3.269ms        15.91%      12.854ms      12.854ms       0.000us         0.00%      79.842ms      79.842ms             1
                                       CompiledFunction         6.74%       5.446ms        11.67%       9.429ms     377.160us       0.000us         0.00%      79.842ms       3.194ms            25
                                               aten::mm         0.93%     748.000us         1.70%       1.375ms      27.500us      54.306ms        68.02%      54.306ms       1.086ms            50
                                 hgemm_128x128x8_NT_vec         0.00%       0.000us         0.00%       0.000us       0.000us      38.591ms        48.33%      38.591ms       1.544ms            25
          aten::_scaled_dot_product_efficient_attention         0.38%     303.000us         1.65%       1.332ms      53.280us       0.000us         0.00%      25.536ms       1.021ms            25
                     aten::_efficient_attention_forward         0.37%     298.000us         1.08%     875.000us      35.000us      25.536ms        31.98%      25.536ms       1.021ms            25
void attention_kernel_batched<AttentionKernel<cutlas...         0.00%       0.000us         0.00%       0.000us       0.000us      25.536ms        31.98%      25.536ms       1.021ms            25
                                  hgemm_32x32x32_NT_vec         0.00%       0.000us         0.00%       0.000us       0.000us      15.715ms        19.68%      15.715ms     628.600us            25
                                    aten::empty_strided         0.49%     397.000us         0.49%     397.000us       7.940us       0.000us         0.00%       0.000us       0.000us            50
                                       aten::as_strided         0.95%     767.000us         0.95%     767.000us       1.615us       0.000us         0.00%       0.000us       0.000us           475
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 80.792ms
Self CUDA time total: 79.842ms

The previous code snippet generates a report of the top 10 PyTorch functions that consumed the most GPU execution time, for both the compiled and non-compiled module. The analysis reveals that the majority of time spent on the GPU is concentrated on the same set of functions for both modules. The reason is that torch.compile is very good at removing the framework overhead associated with PyTorch. If your model is launching large, efficient CUDA kernels, which in this case CausalSelfAttention is, then the overhead of PyTorch is already hidden behind GPU execution and there is little left for torch.compile to remove.

In reality, your module does not normally consist of a single CausalSelfAttention block. When experimenting with Andrej Karpathy's NanoGPT repository, compiling the module took the time per train step from 6090.49ms to 3273.17ms! This was done on commit ae3a8d5 of NanoGPT, training on the Shakespeare dataset.

Conclusion

In this tutorial, we have demonstrated the basic usage of torch.nn.functional.scaled_dot_product_attention. We have shown how the sdp_kernel context manager can be used to assert that a certain implementation is used on GPU. We also built a simple CausalSelfAttention module that works with NestedTensor and is torch compilable. In the process, we have shown how the profiling tools can be used to explore the performance characteristics of a user-defined module.

© Copyright 2018-2023, PyTorch & PyTorch Korea User Group.
