
Commit ee039c0

zou3519 authored and pytorchmergebot committed

[custom_op] triton_op API V0 (#130637)

This is the initial version of an API to create custom operators whose implementations are backed by triton kernels. While user-defined triton kernels work out-of-the-box with torch.compile, you may wish to construct a custom operator if you need to compose with other PyTorch subsystems, like Tensor subclasses or vmap. I'm hoping to get design feedback on this and ship it so that we can begin experimenting with customers.

Test Plan:
- new tests

Pull Request resolved: #130637
Approved by: https://github.com/albanD
1 parent 6beec34 commit ee039c0
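
For orientation, the pattern this commit enables looks like the following condensed sketch, assembled from the docstring example and the test added below. It assumes a CUDA device and a working triton install; the "mylib::add" namespace is illustrative.

    import torch
    import triton
    from triton import language as tl
    from torch._library import triton_op, capture_triton  # private entry points in this V0


    @triton.jit
    def add_kernel(in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: "tl.constexpr"):
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        tl.store(out_ptr + offsets, x + y, mask=mask)


    @triton_op("mylib::add", mutates_args={})
    def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        output = torch.empty_like(x)
        n_elements = output.numel()

        def grid(meta):
            return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

        # Triton kernels called inside a triton_op body must be wrapped in capture_triton
        capture_triton(add_kernel)[grid](x, y, output, n_elements, 16)
        return output


    @torch.compile
    def f(x, y):
        return add(x, y)


    x = torch.randn(3, device="cuda")
    y = torch.randn(3, device="cuda")
    assert torch.allclose(f(x, y), x + y)

Unlike an opaque torch.library.custom_op, the operator above is decomposed during compilation, so Inductor sees the underlying triton kernel and can optimize it directly.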

File tree

4 files changed (+240, -86 lines)


test/inductor/test_triton_kernels.py

Lines changed: 46 additions & 1 deletion

@@ -10,13 +10,13 @@
 import torch._inductor.test_case

 from torch._higher_order_ops.triton_kernel_wrap import (
-    capture_triton,
     generate_ttir,
     triton_kernel_wrapper_functional,
     triton_kernel_wrapper_mutation,
 )
 from torch._inductor import metrics
 from torch._inductor.utils import run_and_get_code
+from torch._library import capture_triton
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import skipIfRocm, skipIfXpu, TEST_WITH_ROCM

@@ -2261,7 +2261,52 @@ def fwd_kernel(
     setattr(MutationTests, name, fn)


+class CustomOpTests(torch._inductor.test_case.TestCase):
+    """Tests for custom ops wrapping triton kernels"""
+
+    @requires_gpu
+    @common_utils.parametrize("autotuned", [False, True])
+    def test_add_kernel(self, autotuned):
+        from torch._inductor.utils import run_and_get_code
+
+        libname = "my_cool_namespace"
+        opname = "my_triton_operator"
+
+        @torch._library.triton_op(f"{libname}::{opname}", mutates_args={})
+        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            output = torch.empty_like(x)
+            n_elements = output.numel()
+
+            def grid(meta):
+                return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+
+            if autotuned:
+                capture_triton(add_kernel_autotuned)[grid](x, y, output, n_elements)
+            else:
+                capture_triton(add_kernel)[grid](x, y, output, n_elements, 16)
+            return output
+
+        def f(x, y):
+            return add(x, y)
+
+        x = torch.randn(3, device="cuda")
+        y = torch.randn(3, device="cuda")
+
+        out = f(x, y)
+        expected = x + y
+        self.assertEqual(out, expected)
+        out_compiled, codes = run_and_get_code(torch.compile(f), x, y)
+        self.assertEqual(out_compiled, expected)
+        self.assertEqual(len(codes), 1)
+
+        # Check that we decomposed the operator away
+        code = "\n".join(codes[0])
+        self.assertNotIn(libname, code)
+        self.assertNotIn(opname, code)
+
+
 common_utils.instantiate_parametrized_tests(KernelTests)
+common_utils.instantiate_parametrized_tests(CustomOpTests)


 if __name__ == "__main__":
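
The autotuned=True branch of the test launches add_kernel_autotuned, which is defined elsewhere in this test file and not shown in this diff. For context, an autotuned kernel of that shape would look roughly like the sketch below; the configs and tuning key are illustrative assumptions, not the values the test actually uses. Because the autotuner supplies BLOCK_SIZE, the launch omits the trailing 16 that the non-autotuned call passes.

    import triton
    from triton import language as tl


    # Hypothetical sketch of an autotuned add kernel; the real add_kernel_autotuned
    # may use different configs and keys.
    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 128}, num_warps=4),
            triton.Config({"BLOCK_SIZE": 1024}, num_warps=8),
        ],
        key=["n_elements"],
    )
    @triton.jit
    def add_kernel_autotuned(in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: "tl.constexpr"):
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        tl.store(out_ptr + offsets, x + y, mask=mask)

capture_triton accepts both plain triton.jit functions and Autotuner objects, which is why the same wrapping works in both branches of the parametrized test.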

torch/_higher_order_ops/triton_kernel_wrap.py

Lines changed: 1 addition & 85 deletions

@@ -940,83 +940,10 @@ def call_triton_kernel(self, variable, args, kwargs, tx):


 ###############################################################################
-# capture_triton API that makes a user-defined triton kernel traceable into
+# Helpers for capture_triton API that makes a user-defined triton kernel traceable into
 # a graph via make_fx or non-strict export (coming soon)


-def capture_triton(triton_kernel, /):
-    """Allows capture of a triton kernel into a graph via make_fx or
-    non-strict export (coming soon).
-
-    These technologies perform Dispatcher-based tracing (via
-    ``__torch_dispatch__``) and cannot see calls to raw triton kernels.
-    The ``capture_triton`` API returns a new callable that can actually
-    be traced into a graph.
-
-    Examples:
-
-        >>> # xdoctest: +SKIP
-        >>> import torch
-        >>> import triton
-        >>> from triton import language as tl
-        >>> from torch.fx.experimental.proxy_tensor import make_fx
-        >>> from torch._higher_order_ops.triton_kernel_wrap import capture_triton
-        >>>
-        >>> @triton.jit
-        >>> def add_kernel(
-        >>>     in_ptr0,
-        >>>     in_ptr1,
-        >>>     out_ptr,
-        >>>     n_elements,
-        >>>     BLOCK_SIZE: "tl.constexpr",
-        >>> ):
-        >>>     pid = tl.program_id(axis=0)
-        >>>     block_start = pid * BLOCK_SIZE
-        >>>     offsets = block_start + tl.arange(0, BLOCK_SIZE)
-        >>>     mask = offsets < n_elements
-        >>>     x = tl.load(in_ptr0 + offsets, mask=mask)
-        >>>     y = tl.load(in_ptr1 + offsets, mask=mask)
-        >>>     output = x + y
-        >>>     tl.store(out_ptr + offsets, output, mask=mask)
-        >>>
-        >>> def add(x, y):
-        >>>     output = torch.empty_like(x)
-        >>>     n_elements = output.numel()
-        >>>
-        >>>     def grid_fn(meta):
-        >>>         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
-        >>>
-        >>>     capture_triton(add_kernel)[grid_fn](x, y, output, n_elements, 16)
-        >>>     return output
-        >>>
-        >>> x = torch.randn(3, device="cuda")
-        >>> y = torch.randn(3, device="cuda")
-        >>> gm = make_fx(add)(x, y)
-        >>> print(gm.code)
-        >>> # def forward(self, x_1, y_1):
-        >>> #     empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False)
-        >>> #     triton_kernel_wrapper_mutation_proxy = triton_kernel_wrapper_mutation(
-        >>> #         kernel_idx = 0, constant_args_idx = 0,
-        >>> #         grid = [(1, 1, 1)], kwargs = {
-        >>> #             'in_ptr0': x_1, 'in_ptr1': y_1, 'out_ptr': empty_like,
-        >>> #             'n_elements': 3, 'BLOCK_SIZE': 16
-        >>> #         })
-        >>> #     return empty_like
-
-    """
-    from triton.runtime.autotuner import Autotuner
-    from triton.runtime.jit import JITFunction
-
-    if not isinstance(triton_kernel, (JITFunction, Autotuner)):
-        raise RuntimeError(
-            "capture_triton only works on functions annotated with triton.jit or triton.autotune"
-        )
-    return TraceableTritonKernelWrapper(triton_kernel, None, None)
-
-
-from ..fx._symbolic_trace import is_fx_tracing
-
-
 class TracingTritonHOPifier(TritonHOPifier):
     def raise_unsupported(self, msg):
         raise RuntimeError(msg)

@@ -1071,20 +998,9 @@ def __getitem__(self, *args):
         return tracing_triton_hopifier_singleton.call_getitem(self, args)

     def run(self, *args, **kwargs):
-        import torch._dynamo
-
-        if not is_fx_tracing() or torch._dynamo.is_compiling():
-            assert self.kernel is not None
-            return self.kernel.run(*args, **kwargs)
         return tracing_triton_hopifier_singleton.call_run(self, args, kwargs, None)

     def __call__(self, *args, **kwargs):
-        import torch._dynamo
-
-        if not is_fx_tracing() or torch._dynamo.is_compiling():
-            assert self.kernel is not None
-            return self.kernel.run(*args, **kwargs, grid=self.grid, warmup=False)
-
         return tracing_triton_hopifier_singleton.call_triton_kernel(
             self, args, kwargs, None
         )
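
The capture_triton(kernel)[grid](...) launch syntax used throughout the examples mirrors triton's native kernel[grid](...) protocol: indexing binds the grid, and the subsequent call performs the launch. The toy sketch below illustrates only that calling convention; it is not the real TraceableTritonKernelWrapper, whose run and __call__ now always route through the tracing HOPifier as shown in the diff above.

    class ToyGridBoundKernel:
        """Illustration only: obj[grid] returns a bound launcher, and calling it launches."""

        def __init__(self, kernel, grid):
            self.kernel = kernel
            self.grid = grid

        def __call__(self, *args, **kwargs):
            # The real wrapper hands off to tracing_triton_hopifier_singleton here,
            # so the launch becomes a triton_kernel_wrapper_mutation node in the graph.
            return self.kernel.run(*args, grid=self.grid, warmup=False, **kwargs)


    class ToyTraceableKernelWrapper:
        def __init__(self, kernel):
            self.kernel = kernel

        def __getitem__(self, grid):
            return ToyGridBoundKernel(self.kernel, grid)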

torch/_library/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -4,3 +4,4 @@
 import torch._library.utils

 from torch._library.fake_class_registry import register_fake_class
+from torch._library.triton import capture_triton, triton_op

torch/_library/triton.py

Lines changed: 192 additions & 0 deletions

@@ -0,0 +1,192 @@
+from typing import Callable, Iterable, Optional, Union
+
+from .custom_ops import custom_op
+
+
+def triton_op(
+    name: str,
+    fn: Optional[Callable] = None,
+    /,
+    *,
+    mutates_args: Union[str, Iterable[str]],
+    schema: Optional[str] = None,
+) -> Callable:
+    """Create a custom operator whose implementation is backed by 1+ triton kernels.
+
+    Use this instead of :func:`torch.library.custom_op` when the implementation
+    consists of 1+ triton kernels. :func:`torch.library.custom_op` treats
+    custom operators as opaque (:func:`torch.compile` and
+    :func:`torch.export.export` will never trace into them), but ``triton_op``
+    makes the implementation visible to these subsystems, allowing them
+    to optimize the triton kernel(s).
+
+    Note that ``fn`` must only consist of calls to PyTorch-understood
+    operators and triton kernels. Any triton kernels called inside ``fn``
+    must be wrapped in a call to :func:`torch._library.capture_triton`.
+
+    Args:
+        name (str): A name for the custom op that looks like "{namespace}::{name}",
+            e.g. "mylib::my_linear". The name is used as the op's stable identifier
+            in PyTorch subsystems (e.g. torch.export, FX graphs).
+            To avoid name collisions, please use your project name as the namespace;
+            e.g. all custom ops in pytorch/fbgemm use "fbgemm" as the namespace.
+        mutates_args (Iterable[str] or "unknown"): The names of args that the function mutates.
+            This MUST be accurate, otherwise, the behavior is undefined. If "unknown",
+            it pessimistically assumes that all inputs to the operator are being mutated.
+        schema (None | str): A schema string for the operator. If None
+            (recommended) we'll infer a schema for the operator from its type
+            annotations. We recommend letting us infer a schema unless you
+            have a specific reason not to.
+            Example: "(Tensor x, int y) -> (Tensor, Tensor)".
+
+    Example::
+
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
+        >>> import torch
+        >>> from torch._library import triton_op, capture_triton
+        >>>
+        >>> import triton
+        >>> from triton import language as tl
+        >>>
+        >>> @triton.jit
+        >>> def add_kernel(
+        >>>     in_ptr0,
+        >>>     in_ptr1,
+        >>>     out_ptr,
+        >>>     n_elements,
+        >>>     BLOCK_SIZE: "tl.constexpr",
+        >>> ):
+        >>>     pid = tl.program_id(axis=0)
+        >>>     block_start = pid * BLOCK_SIZE
+        >>>     offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        >>>     mask = offsets < n_elements
+        >>>     x = tl.load(in_ptr0 + offsets, mask=mask)
+        >>>     y = tl.load(in_ptr1 + offsets, mask=mask)
+        >>>     output = x + y
+        >>>     tl.store(out_ptr + offsets, output, mask=mask)
+        >>>
+        >>> @triton_op("mylib::add", mutates_args={})
+        >>> def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        >>>     output = torch.empty_like(x)
+        >>>     n_elements = output.numel()
+        >>>
+        >>>     def grid(meta):
+        >>>         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+        >>>
+        >>>     # NB: we need to wrap the triton kernel in a call to capture_triton
+        >>>     capture_triton(add_kernel)[grid](x, y, output, n_elements, 16)
+        >>>     return output
+        >>>
+        >>> @torch.compile
+        >>> def f(x, y):
+        >>>     return add(x, y)
+        >>>
+        >>> x = torch.randn(3, device="cuda")
+        >>> y = torch.randn(3, device="cuda")
+        >>>
+        >>> z = f(x, y)
+        >>> assert torch.allclose(z, x + y)
+
+    """
+
+    def dec(fn: Callable) -> Callable:
+        result = custom_op(name, fn, mutates_args=mutates_args)
+        from .._subclasses.functional_tensor import FunctionalTensorMode
+
+        # We require that the user pass us a function that is make_fx traceable,
+        # so we can just register it as the Fake/meta kernel.
+        result.register_fake(fn)
+
+        # We decompose the operator when FunctionalTensorMode is active.
+        # The goal is to decompose the operator in AOTDispatcher.
+        # - With torch.compile, this means that the backend (usually Inductor)
+        #   can see a call to the triton kernel(s) and so it can directly optimize
+        #   them by inlining them into the lowering process.
+        # - With post-dispatch torch.export, this means that there will
+        #   be a call(s) to the triton_kernel_wrapper_functional HOP in the
+        #   graph (that we have yet to figure out how to serialize).
+        def functional_decomp(  # type: ignore[no-untyped-def]
+            mode, _, types, args, kwargs
+        ):
+            with mode:
+                return fn(*args, **kwargs)
+
+        result.register_torch_dispatch(FunctionalTensorMode, functional_decomp)
+        return result
+
+    if fn is None:
+        return dec
+    else:
+        return dec(fn)
+
+
+def capture_triton(triton_kernel: Callable, /) -> Callable:
+    """Allows capture of a triton kernel into a graph via make_fx or
+    non-strict export (coming soon).
+
+    These technologies perform Dispatcher-based tracing (via
+    ``__torch_dispatch__``) and cannot see calls to raw triton kernels.
+    The ``capture_triton`` API returns a new callable that can actually
+    be traced into a graph.
+
+    Examples:
+
+        >>> # xdoctest: +SKIP
+        >>> import torch
+        >>> import triton
+        >>> from triton import language as tl
+        >>> from torch.fx.experimental.proxy_tensor import make_fx
+        >>> from torch._higher_order_ops.triton_kernel_wrap import capture_triton
+        >>>
+        >>> @triton.jit
+        >>> def add_kernel(
+        >>>     in_ptr0,
+        >>>     in_ptr1,
+        >>>     out_ptr,
+        >>>     n_elements,
+        >>>     BLOCK_SIZE: "tl.constexpr",
+        >>> ):
+        >>>     pid = tl.program_id(axis=0)
+        >>>     block_start = pid * BLOCK_SIZE
+        >>>     offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        >>>     mask = offsets < n_elements
+        >>>     x = tl.load(in_ptr0 + offsets, mask=mask)
+        >>>     y = tl.load(in_ptr1 + offsets, mask=mask)
+        >>>     output = x + y
+        >>>     tl.store(out_ptr + offsets, output, mask=mask)
+        >>>
+        >>> def add(x, y):
+        >>>     output = torch.empty_like(x)
+        >>>     n_elements = output.numel()
+        >>>
+        >>>     def grid_fn(meta):
+        >>>         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+        >>>
+        >>>     capture_triton(add_kernel)[grid_fn](x, y, output, n_elements, 16)
+        >>>     return output
+        >>>
+        >>> x = torch.randn(3, device="cuda")
+        >>> y = torch.randn(3, device="cuda")
+        >>> gm = make_fx(add)(x, y)
+        >>> print(gm.code)
+        >>> # def forward(self, x_1, y_1):
+        >>> #     empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False)
+        >>> #     triton_kernel_wrapper_mutation_proxy = triton_kernel_wrapper_mutation(
+        >>> #         kernel_idx = 0, constant_args_idx = 0,
+        >>> #         grid = [(1, 1, 1)], kwargs = {
+        >>> #             'in_ptr0': x_1, 'in_ptr1': y_1, 'out_ptr': empty_like,
+        >>> #             'n_elements': 3, 'BLOCK_SIZE': 16
+        >>> #         })
+        >>> #     return empty_like
+
+    """
+    from triton.runtime.autotuner import Autotuner
+    from triton.runtime.jit import JITFunction
+
+    from torch._higher_order_ops.triton_kernel_wrap import TraceableTritonKernelWrapper
+
+    if not isinstance(triton_kernel, (JITFunction, Autotuner)):
+        raise RuntimeError(
+            "capture_triton only works on functions annotated with triton.jit or triton.autotune"
+        )
+    return TraceableTritonKernelWrapper(triton_kernel, None, None)
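
Since mutates_args is a required keyword argument of triton_op, a natural follow-up is an operator that writes into one of its inputs. The sketch below extrapolates from the triton_op and custom_op signatures in this diff; the kernel, the "mylib::add_" name, and the behavior of mutating triton-backed ops under compilation are assumptions of this example rather than something this PR tests.

    import torch
    import triton
    from triton import language as tl
    from torch._library import triton_op, capture_triton


    @triton.jit
    def add_inplace_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: "tl.constexpr"):
        # Hypothetical kernel: x += y, written back into x.
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        tl.store(x_ptr + offsets, x + y, mask=mask)


    # mutates_args must name every argument the op mutates; here, the first input.
    @triton_op("mylib::add_", mutates_args={"x"})
    def add_(x: torch.Tensor, y: torch.Tensor) -> None:
        n_elements = x.numel()

        def grid(meta):
            return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

        capture_triton(add_inplace_kernel)[grid](x, y, n_elements, 16)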
