from typing import Callable, Iterable, Optional, Union

from .custom_ops import custom_op


def triton_op(
    name: str,
    fn: Optional[Callable] = None,
    /,
    *,
    mutates_args: Union[str, Iterable[str]],
    schema: Optional[str] = None,
) -> Callable:
| 14 | + """Create a custom operator whose implementation is backed by 1+ triton kernels. |
| 15 | +
|
    Use this instead of :func:`torch.library.custom_op` when the implementation
    consists of 1+ triton kernels. :func:`torch.library.custom_op` treats
    custom operators as opaque (:func:`torch.compile` and
    :func:`torch.export.export` will never trace into them), but ``triton_op``
    makes the implementation visible to these subsystems, allowing them
    to optimize the triton kernel(s).

    Note that ``fn`` must only consist of calls to PyTorch-understood
    operators and triton kernels. Any triton kernels called inside ``fn``
    must be wrapped in a call to :func:`torch._library.capture_triton`.

    Args:
        name (str): A name for the custom op that looks like "{namespace}::{name}",
            e.g. "mylib::my_linear". The name is used as the op's stable identifier
            in PyTorch subsystems (e.g. torch.export, FX graphs).
            To avoid name collisions, please use your project name as the namespace;
            e.g. all custom ops in pytorch/fbgemm use "fbgemm" as the namespace.
        mutates_args (Iterable[str] or "unknown"): The names of args that the function mutates.
            This MUST be accurate; otherwise, the behavior is undefined. If "unknown",
            we pessimistically assume that all inputs to the operator are being mutated.
        schema (None | str): A schema string for the operator. If None
            (recommended) we'll infer a schema for the operator from its type
            annotations. We recommend letting us infer a schema unless you
            have a specific reason not to.
            Example: "(Tensor x, int y) -> (Tensor, Tensor)".
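
    If you do pass ``schema`` explicitly, a minimal sketch (reusing the
    ``mylib::add`` operator from the example below, which takes two Tensors
    and returns one and mutates nothing) could look like::

        @triton_op("mylib::add", mutates_args={}, schema="(Tensor x, Tensor y) -> Tensor")
        def add(x, y): ...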
    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> import torch
        >>> from torch._library import triton_op, capture_triton
        >>>
        >>> import triton
        >>> from triton import language as tl
        >>>
        >>> @triton.jit
        >>> def add_kernel(
        >>>     in_ptr0,
        >>>     in_ptr1,
        >>>     out_ptr,
        >>>     n_elements,
        >>>     BLOCK_SIZE: "tl.constexpr",
        >>> ):
        >>>     pid = tl.program_id(axis=0)
        >>>     block_start = pid * BLOCK_SIZE
        >>>     offsets = block_start + tl.arange(0, BLOCK_SIZE)
        >>>     mask = offsets < n_elements
        >>>     x = tl.load(in_ptr0 + offsets, mask=mask)
        >>>     y = tl.load(in_ptr1 + offsets, mask=mask)
        >>>     output = x + y
        >>>     tl.store(out_ptr + offsets, output, mask=mask)
        >>>
        >>> @triton_op("mylib::add", mutates_args={})
        >>> def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        >>>     output = torch.empty_like(x)
        >>>     n_elements = output.numel()
        >>>
        >>>     def grid(meta):
        >>>         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        >>>
        >>>     # NB: we need to wrap the triton kernel in a call to capture_triton
        >>>     capture_triton(add_kernel)[grid](x, y, output, n_elements, 16)
        >>>     return output
        >>>
        >>> @torch.compile
        >>> def f(x, y):
        >>>     return add(x, y)
        >>>
        >>> x = torch.randn(3, device="cuda")
        >>> y = torch.randn(3, device="cuda")
        >>>
        >>> z = f(x, y)
        >>> assert torch.allclose(z, x + y)

    """

    def dec(fn: Callable) -> Callable:
        result = custom_op(name, fn, mutates_args=mutates_args, schema=schema)
        from .._subclasses.functional_tensor import FunctionalTensorMode

        # We require that the user pass us a function that is make_fx traceable,
        # so we can just register it as the Fake/meta kernel.
        result.register_fake(fn)

        # We decompose the operator when FunctionalTensorMode is active.
        # The goal is to decompose the operator in AOTDispatcher.
        # - With torch.compile, this means that the backend (usually Inductor)
        #   can see calls to the triton kernel(s) and can directly optimize
        #   them by inlining them into the lowering process.
        # - With post-dispatch torch.export, this means that there will
        #   be one or more calls to the triton_kernel_wrapper_functional HOP
        #   in the graph (which we have yet to figure out how to serialize).
        def functional_decomp(  # type: ignore[no-untyped-def]
            mode, _, types, args, kwargs
        ):
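            # Re-running ``fn`` under the active mode is what performs the
            # decomposition described above: the ops (and capture_triton'd
            # kernels) inside ``fn`` get traced individually instead of this
            # operator staying opaque.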
            with mode:
                return fn(*args, **kwargs)

        result.register_torch_dispatch(FunctionalTensorMode, functional_decomp)
        return result

    if fn is None:
        return dec
    else:
        return dec(fn)


def capture_triton(triton_kernel: Callable, /) -> Callable:
    """Allows capture of a triton kernel into a graph via make_fx or
    non-strict export (coming soon).

    These technologies perform Dispatcher-based tracing (via
    ``__torch_dispatch__``) and cannot see calls to raw triton kernels.
    The ``capture_triton`` API returns a new callable that can actually
    be traced into a graph.

    Examples:

        >>> # xdoctest: +SKIP
        >>> import torch
        >>> import triton
        >>> from triton import language as tl
        >>> from torch.fx.experimental.proxy_tensor import make_fx
        >>> from torch._higher_order_ops.triton_kernel_wrap import capture_triton
        >>>
        >>> @triton.jit
        >>> def add_kernel(
        >>>     in_ptr0,
        >>>     in_ptr1,
        >>>     out_ptr,
        >>>     n_elements,
        >>>     BLOCK_SIZE: "tl.constexpr",
        >>> ):
        >>>     pid = tl.program_id(axis=0)
        >>>     block_start = pid * BLOCK_SIZE
        >>>     offsets = block_start + tl.arange(0, BLOCK_SIZE)
        >>>     mask = offsets < n_elements
        >>>     x = tl.load(in_ptr0 + offsets, mask=mask)
        >>>     y = tl.load(in_ptr1 + offsets, mask=mask)
        >>>     output = x + y
        >>>     tl.store(out_ptr + offsets, output, mask=mask)
        >>>
        >>> def add(x, y):
        >>>     output = torch.empty_like(x)
        >>>     n_elements = output.numel()
        >>>
        >>>     def grid_fn(meta):
        >>>         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        >>>
        >>>     capture_triton(add_kernel)[grid_fn](x, y, output, n_elements, 16)
        >>>     return output
        >>>
        >>> x = torch.randn(3, device="cuda")
        >>> y = torch.randn(3, device="cuda")
        >>> gm = make_fx(add)(x, y)
        >>> print(gm.code)
        >>> # def forward(self, x_1, y_1):
        >>> #     empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False)
        >>> #     triton_kernel_wrapper_mutation_proxy = triton_kernel_wrapper_mutation(
        >>> #         kernel_idx = 0, constant_args_idx = 0,
        >>> #         grid = [(1, 1, 1)], kwargs = {
        >>> #             'in_ptr0': x_1, 'in_ptr1': y_1, 'out_ptr': empty_like,
        >>> #             'n_elements': 3, 'BLOCK_SIZE': 16
        >>> #         })
        >>> #     return empty_like

    """
    from triton.runtime.autotuner import Autotuner
    from triton.runtime.jit import JITFunction

    from torch._higher_order_ops.triton_kernel_wrap import TraceableTritonKernelWrapper

    if not isinstance(triton_kernel, (JITFunction, Autotuner)):
        raise RuntimeError(
            "capture_triton only works on functions annotated with triton.jit or triton.autotune"
        )
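    # The wrapper preserves the ``kernel[grid](...)`` calling convention but routes
    # calls through the triton_kernel_wrapper HOPs (e.g. triton_kernel_wrapper_mutation,
    # as shown in the example above) so dispatcher-based tracing can record them.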
    return TraceableTritonKernelWrapper(triton_kernel, None, None)