(beta) Compiling the optimizer with torch.compile

Author: Michael Lazos

The optimizer is a key algorithm for training any deep learning model. Since it is responsible for updating every model parameter, it can often become the bottleneck in training performance for large models. In this recipe, we will apply torch.compile to the optimizer to observe the GPU performance improvement.


This tutorial requires PyTorch 2.2.0 or later.
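If you want to check the version requirement programmatically, here is a minimal stdlib-only sketch. The helpers `release_tuple` and `meets_minimum` are hypothetical, not part of PyTorch. The sketch assumes that comparing the numeric release components is good enough, so pre-release suffixes such as `rc1` are ignored.

```python
import re


def release_tuple(version: str) -> tuple[int, ...]:
    """Hypothetical helper: numeric release part of a version string.

    '2.2.0+cu121' -> (2, 2, 0). Local (+cu121) and pre-release (a0, rc1)
    suffixes are dropped, so '2.2.0rc1' is treated as 2.2.0.
    """
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(version: str, minimum: tuple[int, ...] = (2, 2, 0)) -> bool:
    """True if `version` is at least `minimum`, padding short versions with zeros."""
    parts = release_tuple(version)
    width = max(len(parts), len(minimum))
    padded = parts + (0,) * (width - len(parts))
    target = minimum + (0,) * (width - len(minimum))
    return padded >= target


print(meets_minimum("2.2.0+cu121"))  # True
print(meets_minimum("2.1.2"))        # False
print(meets_minimum("2.10.0"))       # True: compared numerically, not as text
```

With PyTorch installed, `meets_minimum(torch.__version__)` would perform the check, since `torch.__version__` is a string subclass.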

Model Setup

For this example, we’ll use a simple sequence of linear layers. Because we are benchmarking only the optimizer, the choice of model doesn’t matter: optimizer performance is a function of the number of parameters, not of the model architecture.

Depending on what machine you are using, your exact results may vary.

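import torch

model = torch.nn.Sequential(
    *[torch.nn.Linear(1024, 1024, bias=False, device="cuda") for _ in range(10)]
)
input = torch.rand(1024, device="cuda")
output = model(input)
# populate .grad on every parameter; the optimizer skips parameters without gradients
output.sum().backward()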
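Each optimizer step touches every parameter, its gradient, and Adam’s per-parameter state, so it helps to know how large this model is. Here is a quick back-of-the-envelope calculation. It assumes float32 parameters (4 bytes each) and Adam with its default `amsgrad=False`, which keeps two running averages per parameter element. The helper names are illustrative only.

```python
# Back-of-the-envelope size of the benchmark model above.
# Assumptions: float32 parameters (4 bytes each) and Adam with amsgrad=False,
# which keeps two running averages (exp_avg, exp_avg_sq) per parameter element.
NUM_LAYERS = 10
IN_FEATURES = 1024
OUT_FEATURES = 1024
BYTES_PER_ELEMENT = 4  # float32
MIB = 2**20


def parameter_count(num_layers: int, in_features: int, out_features: int) -> int:
    """Weights only: every Linear layer above is created without a bias."""
    return num_layers * in_features * out_features


def adam_state_bytes(num_params: int, bytes_per_element: int = BYTES_PER_ELEMENT) -> int:
    """Two state values per parameter element; the per-tensor step counter is ignored."""
    return 2 * num_params * bytes_per_element


n = parameter_count(NUM_LAYERS, IN_FEATURES, OUT_FEATURES)
print(f"parameters: {n:,}")                                  # parameters: 10,485,760
print(f"weights:    {n * BYTES_PER_ELEMENT / MIB:.0f} MiB")  # weights:    40 MiB
print(f"Adam state: {adam_state_bytes(n) / MIB:.0f} MiB")    # Adam state: 80 MiB
```

Roughly 10.5 million parameters means that every call to step() reads and writes on the order of a hundred mebibytes of weights, gradients, and optimizer state.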

Setting up and running the optimizer benchmark

In this example, we’ll use the Adam optimizer and create a helper function that wraps step() in torch.compile().
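To see why the optimizer step is a good candidate for compilation, here is a minimal pure-Python sketch of the per-element Adam update. It uses the standard algorithm with torch.optim.Adam’s default hyperparameters (no weight decay, no amsgrad). It is not PyTorch’s implementation. Each element goes through the same short chain of arithmetic. In eager mode those operations are dispatched one after another, which is the kind of work a compiler can fuse into fewer kernels.

```python
import math


def adam_step(params, grads, m, v, t, lr=0.01, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update over flat lists of floats; t is the 1-based step count.

    params, m and v are updated in place; params is also returned.
    """
    bias1 = 1 - beta1 ** t
    bias2 = 1 - beta2 ** t
    for i, g in enumerate(grads):
        m[i] = beta1 * m[i] + (1 - beta1) * g      # first moment (exp_avg)
        v[i] = beta2 * v[i] + (1 - beta2) * g * g  # second moment (exp_avg_sq)
        m_hat = m[i] / bias1                       # bias-corrected first moment
        v_hat = v[i] / bias2                       # bias-corrected second moment
        params[i] -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return params


params = [1.0, -2.0]
grads = [1.0, 0.5]
m = [0.0, 0.0]
v = [0.0, 0.0]
adam_step(params, grads, m, v, t=1)
print(params)  # on the first step each parameter moves by about lr against its gradient
```

The learning rate of 0.01 matches the one passed to torch.optim.Adam in this recipe.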


torch.compile is only supported on CUDA devices with compute capability 7.0 or higher.

# exit cleanly if we are on a device that doesn't support torch.compile
if torch.cuda.get_device_capability() < (7, 0):
    print("Exiting because torch.compile is not supported on this device.")
    import sys

    sys.exit(0)

opt = torch.optim.Adam(model.parameters(), lr=0.01)

# compiled wrapper around the optimizer step
@torch.compile(fullgraph=False)
def fn():
    opt.step()

# Let's define a helpful benchmarking function:
import torch.utils.benchmark as benchmark

def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6

# Warmup runs to compile the function
for _ in range(5):
    fn()

eager_runtime = benchmark_torch_function_in_microseconds(opt.step)
compiled_runtime = benchmark_torch_function_in_microseconds(fn)

assert eager_runtime > compiled_runtime

print(f"eager runtime: {eager_runtime}us")
print(f"compiled runtime: {compiled_runtime}us")
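The warmup loop is there because the first calls to fn() compile the function, and that one-off cost would otherwise be mixed into the measurement. Here is a stdlib-only sketch of the same warmup-then-measure pattern. `fake_compiled_step` is a hypothetical stand-in whose first call sleeps to imitate compilation. `timeit.Timer.autorange` plays a role loosely similar to `blocked_autorange`, but the two are not equivalent.

```python
import time
import timeit

_compiled = False  # hypothetical stand-in for a compilation cache


def fake_compiled_step() -> int:
    """Hypothetical step: the first call pays a one-off 'compilation' cost."""
    global _compiled
    if not _compiled:
        time.sleep(0.05)  # simulate one-time compilation work
        _compiled = True
    return sum(i * i for i in range(100))  # simulate steady-state work


def mean_runtime_us(f) -> float:
    """Mean wall-clock time per call in microseconds, via timeit's autorange."""
    number, total_seconds = timeit.Timer(f).autorange()
    return total_seconds / number * 1e6


start = time.perf_counter()
fake_compiled_step()  # warmup: triggers the one-off cost
first_call_us = (time.perf_counter() - start) * 1e6

steady_us = mean_runtime_us(fake_compiled_step)  # measured after warmup
print(f"first call: {first_call_us:.0f}us, steady state: {steady_us:.2f}us")
```

Timing the first call together with the steady-state calls would inflate the average, which is why the benchmark above warms up fn() before measuring it.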

Sample Results:

  • Eager runtime: 747.2437149845064us

  • Compiled runtime: 392.07384741178us
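To put the sample numbers in perspective, the ratio below is computed from the two results listed above. As noted earlier, results on your machine may differ.

```python
# Sample results reported above, in microseconds.
eager_us = 747.2437149845064
compiled_us = 392.07384741178

speedup = eager_us / compiled_us         # how many times faster the compiled step is
time_saved = 1 - compiled_us / eager_us  # fraction of per-step optimizer time removed

print(f"speedup: {speedup:.2f}x")        # speedup: 1.91x
print(f"time saved: {time_saved:.1%}")   # time saved: 47.5%
```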

See Also

  • For an in-depth technical overview, see Compiling the optimizer with PT2.


© Copyright 2018-2024, PyTorch & PyTorch Korea User Group.
