-
Notifications
You must be signed in to change notification settings - Fork 23k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[custom_op] triton_op API V0 (#130637)
This is the initial version of an API to create custom operators whose implementations are backed by triton kernels. While user-defined triton kernels work out-of-the-box with triton kernels, you may wish to construct a custom operator if you need to compose with other PyTorch subsystems, like Tensor subclasses or vmap. I'm hoping to get design feedback on this and ship it so that we can begin experimenting with customers. Test Plan: - new tests Pull Request resolved: #130637 Approved by: https://github.com/albanD
- Loading branch information
1 parent
6beec34
commit ee039c0
Showing
4 changed files
with
240 additions
and
86 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
from typing import Callable, Iterable, Optional, Union | ||
|
||
from .custom_ops import custom_op | ||
|
||
|
||
def triton_op( | ||
name: str, | ||
fn: Optional[Callable] = None, | ||
/, | ||
*, | ||
mutates_args: Union[str, Iterable[str]], | ||
schema: Optional[str] = None, | ||
) -> Callable: | ||
"""Create a custom operator whose implementation is backed by 1+ triton kernels. | ||
Use this instead of :func:`torch.library.custom_op` when the implementation | ||
consists of 1+ triton kernels. :func:`torch.library.custom_op` treats | ||
custom operators as opaque (:func:`torch.compile` and | ||
:func:`torch.export.export` will never trace into them), but ``triton_op`` | ||
makes the implementation visible to these subsystems, allowing them | ||
to optimize the triton kernel(s). | ||
Note that ``fn`` must only consist of calls to PyTorch-understood | ||
operators and triton kernels. Any triton kernels called inside ``fn`` | ||
must be wrapped in a call to :func:`torch._library.capture_triton``. | ||
Args: | ||
name (str): A name for the custom op that looks like "{namespace}::{name}", | ||
e.g. "mylib::my_linear". The name is used as the op's stable identifier | ||
in PyTorch subsystems (e.g. torch.export, FX graphs). | ||
To avoid name collisions, please use your project name as the namespace; | ||
e.g. all custom ops in pytorch/fbgemm use "fbgemm" as the namespace. | ||
mutates_args (Iterable[str] or "unknown"): The names of args that the function mutates. | ||
This MUST be accurate, otherwise, the behavior is undefined. If "unknown", | ||
it pessimistically assumes that all inputs to the operator are being mutated. | ||
schema (None | str): A schema string for the operator. If None | ||
(recommended) we'll infer a schema for the operator from its type | ||
annotations. We recommend letting us infer a schema unless you | ||
have a specific reason not to. | ||
Example: "(Tensor x, int y) -> (Tensor, Tensor)". | ||
Example:: | ||
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) | ||
>>> import torch | ||
>>> from torch._library import triton_op, capture_triton | ||
>>> | ||
>>> import triton | ||
>>> from triton import language as tl | ||
>>> | ||
>>> @triton.jit | ||
>>> def add_kernel( | ||
>>> in_ptr0, | ||
>>> in_ptr1, | ||
>>> out_ptr, | ||
>>> n_elements, | ||
>>> BLOCK_SIZE: "tl.constexpr", | ||
>>> ): | ||
>>> pid = tl.program_id(axis=0) | ||
>>> block_start = pid * BLOCK_SIZE | ||
>>> offsets = block_start + tl.arange(0, BLOCK_SIZE) | ||
>>> mask = offsets < n_elements | ||
>>> x = tl.load(in_ptr0 + offsets, mask=mask) | ||
>>> y = tl.load(in_ptr1 + offsets, mask=mask) | ||
>>> output = x + y | ||
>>> tl.store(out_ptr + offsets, output, mask=mask) | ||
>>> | ||
>>> @triton_op("mylib::add", mutates_args={}) | ||
>>> def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | ||
>>> output = torch.empty_like(x) | ||
>>> n_elements = output.numel() | ||
>>> | ||
>>> def grid(meta): | ||
>>> return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) | ||
>>> | ||
>>> # NB: we need to wrap the triton kernel in a call to capture_triton | ||
>>> capture_triton(add_kernel)[grid](x, y, output, n_elements, 16) | ||
>>> return output | ||
>>> | ||
>>> @torch.compile | ||
>>> def f(x, y): | ||
>>> return add(x, y) | ||
>>> | ||
>>> x = torch.randn(3, device="cuda") | ||
>>> y = torch.randn(3, device="cuda") | ||
>>> | ||
>>> z = f(x, y) | ||
>>> assert torch.allclose(z, x + y) | ||
""" | ||
|
||
def dec(fn: Callable) -> Callable: | ||
result = custom_op(name, fn, mutates_args=mutates_args) | ||
from .._subclasses.functional_tensor import FunctionalTensorMode | ||
|
||
# We require that the user pass us a function that is make_fx traceable, | ||
# so we can just register it as the Fake/meta kernel. | ||
result.register_fake(fn) | ||
|
||
# We decompose the operator when FunctionalTensorMode is active. | ||
# The goal is to decompose the operator in AOTDispatcher. | ||
# - With torch.compile, this means that the backend (usually Inductor) | ||
# can see a call to the triton kernel(s) and so it can directly optimize | ||
# them by inlining them into the lowering process. | ||
# - With post-dispatch torch.export, this means that there will | ||
# be a call(s) to the triton_kernel_wrapper_functional HOP in the | ||
# graph (that we have yet to figure out how to serialize). | ||
def functional_decomp( # type: ignore[no-untyped-def] | ||
mode, _, types, args, kwargs | ||
): | ||
with mode: | ||
return fn(*args, **kwargs) | ||
|
||
result.register_torch_dispatch(FunctionalTensorMode, functional_decomp) | ||
return result | ||
|
||
if fn is None: | ||
return dec | ||
else: | ||
return dec(fn) | ||
|
||
|
||
def capture_triton(triton_kernel: Callable, /) -> Callable: | ||
"""Allows capture of a triton kernel into a graph via make_fx or | ||
non-strict export (coming soon). | ||
These technologies perform Dispatcher-based tracing (via | ||
``__torch_dispatch__``) and cannot see calls to raw triton kernels. | ||
The ``capture_triton`` API returns a new callable that can actually | ||
be traced into a graph. | ||
Examples: | ||
>>> # xdoctest: +SKIP | ||
>>> import torch | ||
>>> import triton | ||
>>> from triton import language as tl | ||
>>> from torch.fx.experimental.proxy_tensor import make_fx | ||
>>> from torch._higher_order_ops.triton_kernel_wrap import capture_triton | ||
>>> | ||
>>> @triton.jit | ||
>>> def add_kernel( | ||
>>> in_ptr0, | ||
>>> in_ptr1, | ||
>>> out_ptr, | ||
>>> n_elements, | ||
>>> BLOCK_SIZE: "tl.constexpr", | ||
>>> ): | ||
>>> pid = tl.program_id(axis=0) | ||
>>> block_start = pid * BLOCK_SIZE | ||
>>> offsets = block_start + tl.arange(0, BLOCK_SIZE) | ||
>>> mask = offsets < n_elements | ||
>>> x = tl.load(in_ptr0 + offsets, mask=mask) | ||
>>> y = tl.load(in_ptr1 + offsets, mask=mask) | ||
>>> output = x + y | ||
>>> tl.store(out_ptr + offsets, output, mask=mask) | ||
>>> | ||
>>> def add(x, y): | ||
>>> output = torch.empty_like(x) | ||
>>> n_elements = output.numel() | ||
>>> | ||
>>> def grid_fn(meta): | ||
>>> return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) | ||
>>> | ||
>>> capture_triton(add_kernel)[grid_fn](x, y, output, n_elements, 16) | ||
>>> return output | ||
>>> | ||
>>> x = torch.randn(3, device="cuda") | ||
>>> y = torch.randn(3, device="cuda") | ||
>>> gm = make_fx(add)(x, y) | ||
>>> print(gm.code) | ||
>>> # def forward(self, x_1, y_1): | ||
>>> # empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False) | ||
>>> # triton_kernel_wrapper_mutation_proxy = triton_kernel_wrapper_mutation( | ||
>>> # kernel_idx = 0, constant_args_idx = 0, | ||
>>> # grid = [(1, 1, 1)], kwargs = { | ||
>>> # 'in_ptr0': x_1, 'in_ptr1': y_1, 'out_ptr': empty_like, | ||
>>> # 'n_elements': 3, 'BLOCK_SIZE': 16 | ||
>>> # }) | ||
>>> # return empty_like | ||
""" | ||
from triton.runtime.autotuner import Autotuner | ||
from triton.runtime.jit import JITFunction | ||
|
||
from torch._higher_order_ops.triton_kernel_wrap import TraceableTritonKernelWrapper | ||
|
||
if not isinstance(triton_kernel, (JITFunction, Autotuner)): | ||
raise RuntimeError( | ||
"capture_triton only works on functions annotated with triton.jit or triton.autotune" | ||
) | ||
return TraceableTritonKernelWrapper(triton_kernel, None, None) |