[custom_op] triton_op API V0 (#130637)
This is the initial version of an API to create custom operators whose
implementations are backed by triton kernels. While user-defined triton
kernels work out-of-the-box with torch.compile, you may wish to
construct a custom operator if you need to compose with other PyTorch
subsystems, like Tensor subclasses or vmap.

I'm hoping to get design feedback on this and ship it so that we can
begin experimenting with customers.
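
For reference, a minimal usage sketch of the new API, adapted from the docstring added in torch/_library/triton.py below (assumes a CUDA device with triton installed; "mylib::add" is a placeholder name, not part of this PR):

import torch
import triton
from triton import language as tl
from torch._library import capture_triton, triton_op

@triton.jit
def add_kernel(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    y = tl.load(in_ptr1 + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

@triton_op("mylib::add", mutates_args={})
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    output = torch.empty_like(x)
    n_elements = output.numel()

    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    # Triton kernels called inside a triton_op must be wrapped in
    # capture_triton so make_fx/torch.compile can trace them.
    capture_triton(add_kernel)[grid](x, y, output, n_elements, 16)
    return output

@torch.compile
def f(x, y):
    return add(x, y)

x = torch.randn(3, device="cuda")
y = torch.randn(3, device="cuda")
assert torch.allclose(f(x, y), x + y)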

Test Plan:
- new tests

Pull Request resolved: #130637
Approved by: https://github.com/albanD
zou3519 authored and pytorchmergebot committed Jul 15, 2024
1 parent 6beec34 commit ee039c0
Showing 4 changed files with 240 additions and 86 deletions.
47 changes: 46 additions & 1 deletion test/inductor/test_triton_kernels.py
@@ -10,13 +10,13 @@
import torch._inductor.test_case

from torch._higher_order_ops.triton_kernel_wrap import (
capture_triton,
generate_ttir,
triton_kernel_wrapper_functional,
triton_kernel_wrapper_mutation,
)
from torch._inductor import metrics
from torch._inductor.utils import run_and_get_code
from torch._library import capture_triton
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import skipIfRocm, skipIfXpu, TEST_WITH_ROCM

@@ -2261,7 +2261,52 @@ def fwd_kernel(
setattr(MutationTests, name, fn)


class CustomOpTests(torch._inductor.test_case.TestCase):
"""Tests for custom ops wrapping triton kernels"""

@requires_gpu
@common_utils.parametrize("autotuned", [False, True])
def test_add_kernel(self, autotuned):
from torch._inductor.utils import run_and_get_code

libname = "my_cool_namespace"
opname = "my_triton_operator"

@torch._library.triton_op(f"{libname}::{opname}", mutates_args={})
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
output = torch.empty_like(x)
n_elements = output.numel()

def grid(meta):
return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

if autotuned:
capture_triton(add_kernel_autotuned)[grid](x, y, output, n_elements)
else:
capture_triton(add_kernel)[grid](x, y, output, n_elements, 16)
return output

def f(x, y):
return add(x, y)

x = torch.randn(3, device="cuda")
y = torch.randn(3, device="cuda")

out = f(x, y)
expected = x + y
self.assertEqual(out, expected)
out_compiled, codes = run_and_get_code(torch.compile(f), x, y)
self.assertEqual(out_compiled, expected)
self.assertEqual(len(codes), 1)

# Check that we decomposed the operator away
code = "\n".join(codes[0])
self.assertNotIn(libname, code)
self.assertNotIn(opname, code)


common_utils.instantiate_parametrized_tests(KernelTests)
common_utils.instantiate_parametrized_tests(CustomOpTests)


if __name__ == "__main__":
86 changes: 1 addition & 85 deletions torch/_higher_order_ops/triton_kernel_wrap.py
@@ -940,83 +940,10 @@ def call_triton_kernel(self, variable, args, kwargs, tx):


###############################################################################
# capture_triton API that makes a user-defined triton kernel traceable into
# Helpers for capture_triton API that makes a user-defined triton kernel traceable into
# a graph via make_fx or non-strict export (coming soon)


def capture_triton(triton_kernel, /):
"""Allows capture of a triton kernel into a graph via make_fx or
non-strict export (coming soon).
These technologies perform Dispatcher-based tracing (via
``__torch_dispatch__``) and cannot see calls to raw triton kernels.
The ``capture_triton`` API returns a new callable that can actually
be traced into a graph.
Examples:
>>> # xdoctest: +SKIP
>>> import torch
>>> import triton
>>> from triton import language as tl
>>> from torch.fx.experimental.proxy_tensor import make_fx
>>> from torch._higher_order_ops.triton_kernel_wrap import capture_triton
>>>
>>> @triton.jit
>>> def add_kernel(
>>> in_ptr0,
>>> in_ptr1,
>>> out_ptr,
>>> n_elements,
>>> BLOCK_SIZE: "tl.constexpr",
>>> ):
>>> pid = tl.program_id(axis=0)
>>> block_start = pid * BLOCK_SIZE
>>> offsets = block_start + tl.arange(0, BLOCK_SIZE)
>>> mask = offsets < n_elements
>>> x = tl.load(in_ptr0 + offsets, mask=mask)
>>> y = tl.load(in_ptr1 + offsets, mask=mask)
>>> output = x + y
>>> tl.store(out_ptr + offsets, output, mask=mask)
>>>
>>> def add(x, y):
>>> output = torch.empty_like(x)
>>> n_elements = output.numel()
>>>
>>> def grid_fn(meta):
>>> return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
>>>
>>> capture_triton(add_kernel)[grid_fn](x, y, output, n_elements, 16)
>>> return output
>>>
>>> x = torch.randn(3, device="cuda")
>>> y = torch.randn(3, device="cuda")
>>> gm = make_fx(add)(x, y)
>>> print(gm.code)
>>> # def forward(self, x_1, y_1):
>>> # empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False)
>>> # triton_kernel_wrapper_mutation_proxy = triton_kernel_wrapper_mutation(
>>> # kernel_idx = 0, constant_args_idx = 0,
>>> # grid = [(1, 1, 1)], kwargs = {
>>> # 'in_ptr0': x_1, 'in_ptr1': y_1, 'out_ptr': empty_like,
>>> # 'n_elements': 3, 'BLOCK_SIZE': 16
>>> # })
>>> # return empty_like
"""
from triton.runtime.autotuner import Autotuner
from triton.runtime.jit import JITFunction

if not isinstance(triton_kernel, (JITFunction, Autotuner)):
raise RuntimeError(
"capture_triton only works on functions annotated with triton.jit or triton.autotune"
)
return TraceableTritonKernelWrapper(triton_kernel, None, None)


from ..fx._symbolic_trace import is_fx_tracing


class TracingTritonHOPifier(TritonHOPifier):
def raise_unsupported(self, msg):
raise RuntimeError(msg)
@@ -1071,20 +998,9 @@ def __getitem__(self, *args):
return tracing_triton_hopifier_singleton.call_getitem(self, args)

def run(self, *args, **kwargs):
import torch._dynamo

if not is_fx_tracing() or torch._dynamo.is_compiling():
assert self.kernel is not None
return self.kernel.run(*args, **kwargs)
return tracing_triton_hopifier_singleton.call_run(self, args, kwargs, None)

def __call__(self, *args, **kwargs):
import torch._dynamo

if not is_fx_tracing() or torch._dynamo.is_compiling():
assert self.kernel is not None
return self.kernel.run(*args, **kwargs, grid=self.grid, warmup=False)

return tracing_triton_hopifier_singleton.call_triton_kernel(
self, args, kwargs, None
)
1 change: 1 addition & 0 deletions torch/_library/__init__.py
@@ -4,3 +4,4 @@
import torch._library.utils

from torch._library.fake_class_registry import register_fake_class
from torch._library.triton import capture_triton, triton_op
192 changes: 192 additions & 0 deletions torch/_library/triton.py
@@ -0,0 +1,192 @@
from typing import Callable, Iterable, Optional, Union

from .custom_ops import custom_op


def triton_op(
name: str,
fn: Optional[Callable] = None,
/,
*,
mutates_args: Union[str, Iterable[str]],
schema: Optional[str] = None,
) -> Callable:
"""Create a custom operator whose implementation is backed by 1+ triton kernels.
Use this instead of :func:`torch.library.custom_op` when the implementation
consists of 1+ triton kernels. :func:`torch.library.custom_op` treats
custom operators as opaque (:func:`torch.compile` and
:func:`torch.export.export` will never trace into them), but ``triton_op``
makes the implementation visible to these subsystems, allowing them
to optimize the triton kernel(s).
Note that ``fn`` must only consist of calls to PyTorch-understood
operators and triton kernels. Any triton kernels called inside ``fn``
must be wrapped in a call to :func:`torch._library.capture_triton`.
Args:
name (str): A name for the custom op that looks like "{namespace}::{name}",
e.g. "mylib::my_linear". The name is used as the op's stable identifier
in PyTorch subsystems (e.g. torch.export, FX graphs).
To avoid name collisions, please use your project name as the namespace;
e.g. all custom ops in pytorch/fbgemm use "fbgemm" as the namespace.
mutates_args (Iterable[str] or "unknown"): The names of args that the function mutates.
This MUST be accurate; otherwise, the behavior is undefined. If "unknown",
it pessimistically assumes that all inputs to the operator are being mutated.
schema (None | str): A schema string for the operator. If None
(recommended) we'll infer a schema for the operator from its type
annotations. We recommend letting us infer a schema unless you
have a specific reason not to.
Example: "(Tensor x, int y) -> (Tensor, Tensor)".
Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
>>> import torch
>>> from torch._library import triton_op, capture_triton
>>>
>>> import triton
>>> from triton import language as tl
>>>
>>> @triton.jit
>>> def add_kernel(
>>> in_ptr0,
>>> in_ptr1,
>>> out_ptr,
>>> n_elements,
>>> BLOCK_SIZE: "tl.constexpr",
>>> ):
>>> pid = tl.program_id(axis=0)
>>> block_start = pid * BLOCK_SIZE
>>> offsets = block_start + tl.arange(0, BLOCK_SIZE)
>>> mask = offsets < n_elements
>>> x = tl.load(in_ptr0 + offsets, mask=mask)
>>> y = tl.load(in_ptr1 + offsets, mask=mask)
>>> output = x + y
>>> tl.store(out_ptr + offsets, output, mask=mask)
>>>
>>> @triton_op("mylib::add", mutates_args={})
>>> def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
>>> output = torch.empty_like(x)
>>> n_elements = output.numel()
>>>
>>> def grid(meta):
>>> return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
>>>
>>> # NB: we need to wrap the triton kernel in a call to capture_triton
>>> capture_triton(add_kernel)[grid](x, y, output, n_elements, 16)
>>> return output
>>>
>>> @torch.compile
>>> def f(x, y):
>>> return add(x, y)
>>>
>>> x = torch.randn(3, device="cuda")
>>> y = torch.randn(3, device="cuda")
>>>
>>> z = f(x, y)
>>> assert torch.allclose(z, x + y)
"""

def dec(fn: Callable) -> Callable:
result = custom_op(name, fn, mutates_args=mutates_args)
from .._subclasses.functional_tensor import FunctionalTensorMode

# We require that the user pass us a function that is make_fx traceable,
# so we can just register it as the Fake/meta kernel.
result.register_fake(fn)

# We decompose the operator when FunctionalTensorMode is active.
# The goal is to decompose the operator in AOTDispatcher.
# - With torch.compile, this means that the backend (usually Inductor)
# can see a call to the triton kernel(s) and so it can directly optimize
# them by inlining them into the lowering process.
# - With post-dispatch torch.export, this means that there will
# be one or more calls to the triton_kernel_wrapper_functional HOP in the
# graph (that we have yet to figure out how to serialize).
def functional_decomp( # type: ignore[no-untyped-def]
mode, _, types, args, kwargs
):
with mode:
return fn(*args, **kwargs)

result.register_torch_dispatch(FunctionalTensorMode, functional_decomp)
return result

if fn is None:
return dec
else:
return dec(fn)


def capture_triton(triton_kernel: Callable, /) -> Callable:
"""Allows capture of a triton kernel into a graph via make_fx or
non-strict export (coming soon).
These technologies perform Dispatcher-based tracing (via
``__torch_dispatch__``) and cannot see calls to raw triton kernels.
The ``capture_triton`` API returns a new callable that can actually
be traced into a graph.
Examples:
>>> # xdoctest: +SKIP
>>> import torch
>>> import triton
>>> from triton import language as tl
>>> from torch.fx.experimental.proxy_tensor import make_fx
>>> from torch._library import capture_triton
>>>
>>> @triton.jit
>>> def add_kernel(
>>> in_ptr0,
>>> in_ptr1,
>>> out_ptr,
>>> n_elements,
>>> BLOCK_SIZE: "tl.constexpr",
>>> ):
>>> pid = tl.program_id(axis=0)
>>> block_start = pid * BLOCK_SIZE
>>> offsets = block_start + tl.arange(0, BLOCK_SIZE)
>>> mask = offsets < n_elements
>>> x = tl.load(in_ptr0 + offsets, mask=mask)
>>> y = tl.load(in_ptr1 + offsets, mask=mask)
>>> output = x + y
>>> tl.store(out_ptr + offsets, output, mask=mask)
>>>
>>> def add(x, y):
>>> output = torch.empty_like(x)
>>> n_elements = output.numel()
>>>
>>> def grid_fn(meta):
>>> return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
>>>
>>> capture_triton(add_kernel)[grid_fn](x, y, output, n_elements, 16)
>>> return output
>>>
>>> x = torch.randn(3, device="cuda")
>>> y = torch.randn(3, device="cuda")
>>> gm = make_fx(add)(x, y)
>>> print(gm.code)
>>> # def forward(self, x_1, y_1):
>>> # empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False)
>>> # triton_kernel_wrapper_mutation_proxy = triton_kernel_wrapper_mutation(
>>> # kernel_idx = 0, constant_args_idx = 0,
>>> # grid = [(1, 1, 1)], kwargs = {
>>> # 'in_ptr0': x_1, 'in_ptr1': y_1, 'out_ptr': empty_like,
>>> # 'n_elements': 3, 'BLOCK_SIZE': 16
>>> # })
>>> # return empty_like
"""
from triton.runtime.autotuner import Autotuner
from triton.runtime.jit import JITFunction

from torch._higher_order_ops.triton_kernel_wrap import TraceableTritonKernelWrapper

if not isinstance(triton_kernel, (JITFunction, Autotuner)):
raise RuntimeError(
"capture_triton only works on functions annotated with triton.jit or triton.autotune"
)
return TraceableTritonKernelWrapper(triton_kernel, None, None)
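
The commit message above motivates composing triton-backed operators with other PyTorch subsystems. As a hedged illustration that is not part of this diff, one way to do so is to attach an autograd rule to the "mylib::add" operator from the docstring example via torch.library.register_autograd; the names and the choice of API here are assumptions, not something this PR prescribes:

import torch

# Assumes the "mylib::add" triton_op from the docstring example above has
# already been defined in this process.

def backward(ctx, grad_out):
    # For add(x, y), the gradient w.r.t. each input is grad_out unchanged.
    return grad_out, grad_out

def setup_context(ctx, inputs, output):
    # Elementwise add needs nothing saved for backward.
    pass

torch.library.register_autograd("mylib::add", backward, setup_context=setup_context)

x = torch.randn(3, device="cuda", requires_grad=True)
y = torch.randn(3, device="cuda", requires_grad=True)
torch.ops.mylib.add(x, y).sum().backward()
assert torch.allclose(x.grad, torch.ones_like(x))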
