Using User-Defined Triton Kernels with torch.compile
Author: Oguz Ulgen
User-defined Triton kernels can be used to optimize specific parts of your
model's computation. These kernels are written in Triton's language, which is designed
to make it easier to achieve peak hardware performance. By using user-defined Triton
kernels with torch.compile, you can integrate these optimized computations into
your PyTorch model, potentially achieving significant performance improvements.
This recipe demonstrates how you can use user-defined Triton kernels with torch.compile.
Prerequisites
Before starting this recipe, make sure that you have the following:
Basic understanding of torch.compile and Triton (see the torch.compile and Triton documentation)
PyTorch 2.3 or later
A GPU that supports Triton
import torch
from torch.utils._triton import has_triton
Basic Usage
In this example, we use a simple vector addition kernel from the Triton documentation
with torch.compile.
For reference, see the Triton documentation.
if not has_triton():
    print("Skipping because triton is not supported on this device.")
else:
    import triton
    from triton import language as tl

    @triton.jit
    def add_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        # Each program instance handles one contiguous chunk of BLOCK_SIZE elements.
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        # Mask out lanes past the end of the input (the last block may be partial).
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @torch.compile(fullgraph=True)
    def add_fn(x, y):
        output = torch.zeros_like(x)
        n_elements = output.numel()
        # Launch enough program instances to cover all n_elements.
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4)
        return output

    x = torch.randn(4, device="cuda")
    y = torch.randn(4, device="cuda")
    out = add_fn(x, y)
    print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")
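To make the kernel's indexing concrete, here is a pure-Python sketch (no GPU needed) of what the launch above computes: the grid creates cdiv(n_elements, BLOCK_SIZE) program instances, each handles one BLOCK_SIZE-wide chunk, and the mask discards out-of-range lanes. cdiv and emulate_add_kernel are hypothetical names for this illustration; real Triton programs run in parallel on the GPU rather than in a Python loop.

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division for positive integers, the same arithmetic triton.cdiv performs.
    return (a + b - 1) // b


def emulate_add_kernel(x, y, block_size):
    """Pure-Python stand-in for add_kernel[grid](...) (illustration only)."""
    n_elements = len(x)
    output = [0] * n_elements
    num_programs = cdiv(n_elements, block_size)  # the grid size computed by `grid`
    for pid in range(num_programs):  # each iteration plays the role of one program instance
        block_start = pid * block_size
        offsets = [block_start + i for i in range(block_size)]
        mask = [off < n_elements for off in offsets]  # guards the last, partial block
        for off, in_bounds in zip(offsets, mask):
            if in_bounds:
                output[off] = x[off] + y[off]
    return output


result = emulate_add_kernel([1, 2, 3, 4, 5], [10, 20, 30, 40, 50], block_size=4)
# result == [11, 22, 33, 44, 55]; two "programs" run, and the second masks off 3 of its 4 lanes
```

Without the mask, the second program would read and write indices 5 to 7, past the end of the five-element inputs.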
Advanced Usage
Triton's autotune feature automatically tunes the configuration parameters of your Triton kernels. It benchmarks the candidate configurations you provide and selects the one that performs best for your specific use case.
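The core idea can be modeled in a few lines of plain Python. This is a deliberately simplified sketch, not Triton's implementation: ToyAutotuner, blocked_sum, and key_fn are hypothetical names, and the real autotuner benchmarks GPU kernel launches. It mirrors the two ideas described above: every candidate config is timed, and the fastest one is cached per key, so a key function that always returns the same value behaves roughly like an empty key list (tune once).

```python
import time
from typing import Any, Callable, Dict, Hashable, List


class ToyAutotuner:
    """Hypothetical, simplified model of autotuning (not Triton's implementation).

    Benchmarks every candidate config the first time a new key is seen,
    caches the fastest one, and reuses it for later calls with the same key.
    """

    def __init__(self, fn: Callable[..., Any], configs: List[Dict[str, Any]],
                 key_fn: Callable[..., Hashable], reps: int = 3):
        self.fn = fn
        self.configs = configs
        self.key_fn = key_fn
        self.reps = reps
        self.cache: Dict[Hashable, Dict[str, Any]] = {}

    def _bench(self, config: Dict[str, Any], *args: Any) -> float:
        # Wall-clock time for a few repetitions of fn with this config.
        start = time.perf_counter()
        for _ in range(self.reps):
            self.fn(*args, **config)
        return time.perf_counter() - start

    def __call__(self, *args: Any) -> Any:
        key = self.key_fn(*args)
        if key not in self.cache:
            # Tune only on a cache miss: keep the config with the lowest measured time.
            self.cache[key] = min(self.configs, key=lambda cfg: self._bench(cfg, *args))
        return self.fn(*args, **self.cache[key])


def blocked_sum(values, BLOCK_SIZE):
    # Sums values chunk by chunk; the result does not depend on BLOCK_SIZE, only the speed does.
    total = 0
    for start in range(0, len(values), BLOCK_SIZE):
        total += sum(values[start:start + BLOCK_SIZE])
    return total


tuner = ToyAutotuner(blocked_sum, [{"BLOCK_SIZE": 1}, {"BLOCK_SIZE": 64}],
                     key_fn=lambda values: len(values))
result = tuner(list(range(1000)))  # result == 499500
```

Keying on the input length means a new length triggers a fresh round of benchmarking, while repeated calls with the same length reuse the cached winner.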
When used with torch.compile, triton.autotune can help ensure that your PyTorch
model runs as efficiently as possible. Here is an example of using torch.compile
and triton.autotune.
Note: torch.compile only supports the configs and key arguments to triton.autotune.
if not has_triton():
    print("Skipping because triton is not supported on this device.")
else:
    import triton
    from triton import language as tl

    # Each triton.Config pairs meta-parameters (BLOCK_SIZE) with launch options
    # (num_stages, num_warps); the autotuner benchmarks every config listed here.
    # An empty key list means no argument value triggers re-tuning.
    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 4}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 4}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_SIZE": 2}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 2}, num_stages=4, num_warps=4),
        ],
        key=[],
    )
    @triton.jit
    def add_kernel_autotuned(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @torch.compile(fullgraph=True)
    def add_fn(x, y):
        output = torch.zeros_like(x)
        n_elements = output.numel()
        # BLOCK_SIZE is not passed here: it comes from the config the autotuner
        # selects, and the grid callable reads it from `meta`.
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        add_kernel_autotuned[grid](x, y, output, n_elements)
        return output

    x = torch.randn(4, device="cuda")
    y = torch.randn(4, device="cuda")
    out = add_fn(x, y)
    print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")
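Because add_kernel_autotuned is launched without an explicit BLOCK_SIZE, the grid is written as a callable that receives the kernel's meta-parameters, including the BLOCK_SIZE of whichever config is being run. The sketch below evaluates the same grid expression for each candidate BLOCK_SIZE in the example, using plain dictionaries as stand-ins for those meta-parameters (an assumption for illustration, with no Triton calls), to show that different configs launch different numbers of program instances.

```python
# Stand-ins mirroring the autotune example: 4 elements, candidate BLOCK_SIZE values 4 and 2.
n_elements = 4

# Same expression as the example's grid, with triton.cdiv written out as ceiling division.
grid = lambda meta: ((n_elements + meta["BLOCK_SIZE"] - 1) // meta["BLOCK_SIZE"],)

candidate_meta = [{"BLOCK_SIZE": 4}, {"BLOCK_SIZE": 2}]
launch_grids = {meta["BLOCK_SIZE"]: grid(meta) for meta in candidate_meta}
# launch_grids == {4: (1,), 2: (2,)}
```

A fixed tuple such as (1,) would only be correct for one of the configs, which is why the grid is computed from `meta` rather than hard-coded.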
Composability and Limitations
As of PyTorch 2.3, torch.compile's support for user-defined Triton kernels
includes dynamic shapes, torch.autograd.Function, JIT inductor, and AOT inductor.
You can combine these features to build complex, high-performance models.
However, there are certain limitations to be aware of:
Tensor Subclasses: Currently, there is no support for tensor subclasses and other advanced features.
Triton Features: While triton.heuristics can be used either standalone or before
triton.autotune, it cannot be used after triton.autotune. This implies that if
triton.heuristics and triton.autotune are to be used together, triton.heuristics
must be used first.
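When reasoning about "first" with stacked decorators, it helps to recall that Python applies decorators from the bottom up: the decorator written closest to the function is applied first. Assuming the limitation above refers to this application order, triton.heuristics would be written below triton.autotune, closer to @triton.jit. The tag decorator below is a hypothetical toy, not a Triton API; it only demonstrates Python's ordering rule.

```python
applied = []


def tag(name):
    # Hypothetical toy decorator factory: records the order in which decorators wrap the function.
    def decorator(fn):
        applied.append(name)
        return fn
    return decorator


@tag("outer")  # written on top, applied last
@tag("inner")  # written closest to the function, applied first
def kernel():
    pass

# applied == ["inner", "outer"]
```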
Conclusion
In this recipe, we explored how to use user-defined Triton kernels
with torch.compile. We covered the basic usage of a simple
vector addition kernel and advanced usage involving Triton's autotune
feature. We also discussed the composability of user-defined Triton
kernels with other PyTorch features and highlighted some current limitations.