❓ [Question] How do you export a Triton kernel with a model to a serialized engine that can be run in C++? #3469


Open
cmgreen210 opened this issue Apr 11, 2025 · 8 comments
Labels
question Further information is requested

Comments

@cmgreen210

❓ Question

How do you export a Triton kernel with a model to a serialized engine that can be run in C++?

What you have already tried

Read through the Python examples.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • PyTorch Version (e.g., 1.0):
  • CPU Architecture:
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, libtorch, source):
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version:
  • CUDA version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context

@cmgreen210 cmgreen210 added the question Further information is requested label Apr 11, 2025
@narendasan
Collaborator

narendasan commented Apr 11, 2025

You can use this API from TensorRT: https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/python-api/infer/tensorrt.plugin/trt_plugin_aot_impl/index.html

Tutorial here: https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/python-api/pluginGuide.html#providing-an-ahead-of-time-aot-implementation

We have some auto-generation systems for this sort of plugin in progress that you can look at here: https://pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/auto_generate_plugins.html, but they do not directly support AOT plugins currently. We will update the examples with this workflow once it's ready, but in short the generated plugin converter needs an additional `aot=True` flag, so for now you might be able to copy the converter generation code (https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py) and add the patch yourself.

The expected workflow will be something like:

from typing import Tuple, Union

import tensorrt_bindings.plugin as trtp
import torch
import torch_tensorrt
import triton
import triton.language as tl


@triton.jit
def elementwise_scale_mul_kernel(X, Y, Z, a, b, BLOCK_SIZE: tl.constexpr):
     ...


@torch.library.custom_op("torchtrt_ex::elementwise_scale_mul", mutates_args=())  # type: ignore[misc]
def elementwise_scale_mul(
    X: torch.Tensor, Y: torch.Tensor, b: float = 0.2, a: int = 2
) -> torch.Tensor:
     ...

@torch.library.register_fake("torchtrt_ex::elementwise_scale_mul")
def _(x: torch.Tensor, y: torch.Tensor, b: float = 0.2, a: int = 2) -> torch.Tensor:
    ...

# All of the above is required to enable torch.export

torch_tensorrt.dynamo.conversion.plugins.custom_op(
    "torchtrt_ex::elementwise_scale_mul",
    supports_dynamic_shapes=True,
    requires_output_allocator=False,
)  # Final workflow

# As a workaround (WAR) for now, the call above is replaced by

torch_tensorrt.dynamo.conversion.plugins.generate_plugin(
    "torchtrt_ex::elementwise_scale_mul"
)

your_patched_generate_plugin_converter(
    "torchtrt_ex::elementwise_scale_mul",
    supports_dynamic_shapes=True,
    requires_output_allocator=False,
) # Generates basically the same code but has the `aot=True` flag 


# Tell TRT how to call the PTX of your kernel inside the engine without Python 

@trtp.aot_impl("torchtrt_ex::elementwise_scale_mul")
def elementwise_scale_mul_aot(
    inp0: trtp.TensorDesc,
    inp1: trtp.TensorDesc,
    b: float,
    a: int,
    outputs: Tuple[trtp.TensorDesc],
    tactic: int,
) -> Tuple[Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs]:
    ...

compiled_module = torch_tensorrt.compile(my_module, ...)

ts_module = torch.jit.trace(compiled_module, example_inputs)  # example_inputs: representative input tensors for tracing
torch.jit.save(ts_module, "my_module.ts") # Executable in C++ 

# Alternatively 

serialized_engine = torch_tensorrt.convert_method_to_trt_engine(compiled_module)

cc: @bowang007

@cmgreen210
Author

Thanks @narendasan for the detailed response. I was also looking to avoid using libtorch and instead just rely on the nvinfer libraries at inference time. Is this possible too?

@cmgreen210
Author

cmgreen210 commented Apr 12, 2025

Tried a pretty simple example:

from typing import Tuple, Union

import tensorrt as trt
import tensorrt.plugin as trtp
import torch
import torch_tensorrt
import triton
import triton.language as tl

from custom_generate_plugin_converter import generate_plugin_converter


@triton.jit
def add_one_kernel(x_ptr, n_elements, y_ptr, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(y_ptr + offsets, x + 1, mask=mask)


@torch.library.custom_op("torchtrt_ex::add_one", mutates_args=())  # type: ignore[misc]
def add_one(
    X: torch.Tensor,
) -> torch.Tensor:
    # Ensure the tensors are on the GPU
    assert X.is_cuda

    # Create output tensor
    Y = torch.empty_like(X)

    # Define block size
    BLOCK_SIZE = 1024

    # Grid of programs
    grid = lambda meta: (triton.cdiv(X.numel(), meta["BLOCK_SIZE"]),)

    # Launch the kernel
    add_one_kernel[grid](X, Y,  BLOCK_SIZE=BLOCK_SIZE)

    return Y


@torch.library.register_fake("torchtrt_ex::add_one")
def _(X: torch.Tensor) -> torch.Tensor:
    return X


torch_tensorrt.dynamo.conversion.plugins.generate_plugin(
    "torchtrt_ex::add_one"
)

@trtp.aot_impl("torchtrt_ex::add_one")
def add_plugin_aot_impl(
    X: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc], tactic: int
) -> Tuple[Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs]:

    type_str = "fp32" if X.dtype == trt.float32 else "fp16"

    src = triton.compiler.ASTSource(
        fn=add_one_kernel,
        signature=f"*{type_str},i32,*{type_str}",
        constexprs={
            "BLOCK_SIZE": 256,
        },
    )

    compiled_kernel = triton.compile(src)

    N = X.shape_expr.numel()
    launch_params = trtp.KernelLaunchParams()

    # grid dims
    launch_params.grid_x = trtp.cdiv(N, 256)
    # block dims
    launch_params.block_x = compiled_kernel.metadata.num_warps * 32
    # shared memory
    launch_params.shared_mem = compiled_kernel.metadata.shared

    extra_args = trtp.SymIntExprs(1)
    extra_args[0] = trtp.SymInt32(N)

    return compiled_kernel.metadata.name, compiled_kernel.asm["ptx"], launch_params, extra_args

generate_plugin_converter(
    "torchtrt_ex::add_one",
    supports_dynamic_shapes=True,
    requires_output_allocator=False,
)

class MyModel(torch.nn.Module):  # type: ignore[misc]
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        res = torch.ops.torchtrt_ex.add_one.default(x)

        return res


my_model = MyModel().to("cuda")
m = torch.full((64, 64), 2, device="cuda", dtype=torch.float)

with torch_tensorrt.logging.errors():
    model_trt = torch_tensorrt.compile(
        my_model, inputs=[m], debug=True, min_block_size=1
    )
    for i in range(300):
        res = model_trt(m)
        assert torch.allclose(res, my_model(m))

print("Ran with custom plugin!")

running into:

DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Marking output output0 [shape=(64, 64), dtype=DataType.FLOAT]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node output [output] (Inputs: (add_one: (64, 64)@torch.float32) | Outputs: (output: ))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.137904
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Not found cached TRT engines. Start building engine.
[ERROR] Exception caught in get_launch_params(): CompilationError: at 1:0:
def add_one_kernel(x_ptr, n_elements, y_ptr, BLOCK_SIZE: tl.constexpr):
^
IndexError('list assignment index out of range')

The custom converter is the same code as in torch_tensorrt, with `aot=True` added to the `add_plugin` call.
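
For reference, that patch amounts to a one-line change in a local copy of _generate_plugin_converter.py; a minimal sketch of the relevant excerpt (variable names taken from the upstream file, everything else assumed unchanged):

# custom_generate_plugin_converter.py: local copy of
# py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py,
# where, inside the generated converter body, the upstream line
#     layer = ctx.net.add_plugin(plugin(*itensor_args, **kwargs))
# is patched to request the AOT implementation registered via @trtp.aot_impl:
layer = ctx.net.add_plugin(plugin(*itensor_args, **kwargs), aot=True)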

My Python env is set up with uv and this config:

[project]
name = "triton-exp"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "ipython>=9.1.0",
    "torch>=2.6.0",
    "triton-windows>=3.3.0a0.post17",
    "torch-tensorrt==2.7.0.dev20250323+cu128 ; sys_platform == 'linux' or sys_platform == 'win32'",
    "tensorrt>=10.7.0.post1; sys_platform == 'linux' or sys_platform == 'win32'",
    "black>=25.1.0",
]

[tool.uv.sources]
torch = [
    {index = "pytorch-cuda-nightly", marker = "sys_platform == 'win32' or sys_platform == 'linux'" },
]
torch-tensorrt = [
    {index = "pytorch-cuda-nightly", marker = "sys_platform == 'win32' or sys_platform == 'linux'" },
]

[[tool.uv.index]]
name = "pytorch-cuda-nightly"
url = "https://download.pytorch.org/whl/nightly/cu128"
explicit = true

@cmgreen210
Author

I've been able to fix the error above with the code below (I messed with the signatures a bit):

from typing import Tuple, Union

import tensorrt as trt
import tensorrt.plugin as trtp
import torch
import torch_tensorrt
import triton
import triton.language as tl

from custom_generate_plugin_converter import generate_plugin_converter


@triton.jit
def add_one_kernel(x_ptr, n_elements, y_ptr, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    output = x + 1
    tl.store(y_ptr + offsets, output, mask=mask)


@torch.library.custom_op("torchtrt_ex::add_one", mutates_args=())  # type: ignore[misc]
def add_one(
    X: torch.Tensor,
) -> torch.Tensor:
    # Ensure the tensors are on the GPU
    assert X.is_cuda

    # Create output tensor
    Y = torch.empty_like(X)

    # Define block size
    BLOCK_SIZE = 1024

    # Grid of programs
    grid = lambda meta: (triton.cdiv(X.numel(), meta["BLOCK_SIZE"]),)

    # Launch the kernel
    add_one_kernel[grid](X, X.numel(), Y,  BLOCK_SIZE=BLOCK_SIZE)

    return Y


@torch.library.register_fake("torchtrt_ex::add_one")
def _(X: torch.Tensor) -> torch.Tensor:
    return X


torch_tensorrt.dynamo.conversion.plugins.generate_plugin(
    "torchtrt_ex::add_one"
)

@trtp.aot_impl("torchtrt_ex::add_one")
def add_plugin_aot_impl(
    X: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc], tactic: int
) -> Tuple[Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs]:

    type_str = "fp32" if X.dtype == trt.float32 else "fp16"

    src = triton.compiler.ASTSource(
        fn=add_one_kernel,
        signature=f"*{type_str},i32,*{type_str},i32",
        constexprs={
            "BLOCK_SIZE": 1024,
        },
    )

    compiled_kernel = triton.compile(src)

    N = X.shape_expr.numel()
    launch_params = trtp.KernelLaunchParams()

    # grid dims
    launch_params.grid_x = trtp.cdiv(N, 1024)
    # block dims
    launch_params.block_x = compiled_kernel.metadata.num_warps * 32
    # shared memory
    launch_params.shared_mem = compiled_kernel.metadata.shared

    extra_args = trtp.SymIntExprs(1)
    extra_args[0] = trtp.SymInt32(N)

    return compiled_kernel.metadata.name, compiled_kernel.asm["ptx"], launch_params, extra_args

generate_plugin_converter(
    "torchtrt_ex::add_one",
    supports_dynamic_shapes=True,
    requires_output_allocator=False,
)

class MyModel(torch.nn.Module):  # type: ignore[misc]
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        res = torch.ops.torchtrt_ex.add_one.default(x)

        return res


my_model = MyModel().to("cuda")
m = torch.full((64, 64), 2, device="cuda", dtype=torch.float)
assert my_model(m)[0][0] == 3.0

with torch_tensorrt.logging.debug():
    model_trt = torch_tensorrt.compile(
        my_model, inputs=[m], debug=True, min_block_size=1
    )
    for i in range(300):
        res = model_trt(m)
        assert torch.allclose(res, my_model(m))

print("Ran with custom plugin!")

But now, changing this to `layer = ctx.net.add_plugin(plugin(*itensor_args, **kwargs), aot=True)` is not sufficient to get the code to run as written:

DEBUG: [Torch-TensorRT] - Attempting to run engine (ID: _run_on_acc_0_engine); Hardware Compatible: 0
DEBUG: [Torch-TensorRT] - Using the standard execution runtime mode with cudagraphs=0.
DEBUG: [Torch-TensorRT] - Input shape changed None -> (64,64)
DEBUG: [Torch-TensorRT] - Input Name: x Shape: [64, 64]
DEBUG: [Torch-TensorRT] - Output Name: output0 Shape: [64, 64]
ERROR: [Torch-TensorRT] - Error Code: 2: invalid argument
ERROR: [Torch-TensorRT] - [pluginV3Runner.cpp::nvinfer1::rt::cuda::PluginV3Runner::onShapeChange::98] Error Code 2: Internal Error (Assertion pluginUtils::isSuccess(status) failed. )

But if `aot=False` it all works fine.

@cmgreen210
Author

@narendasan @bowang007 I made a reproducible example, https://github.com/cmgreen210/tmp-pytorch-trt-plugin-example, that works on Linux. `uv run main.py` works fine, but `uv run main.py --aot` hangs at the inference step of the compiled model. In the repo I've included logs from each of these runs.

My goal is to be able to run serialized TensorRT engines on end-user NVIDIA devices (so not in a data center) without having to install a lot of extra packages/libraries, and ideally I'd also like to easily use Triton kernels to speed things up, hence this issue.

Any help is greatly appreciated!

@narendasan
Collaborator

Thanks @narendasan for the detailed response. I was also looking to avoid using libtorch and instead just rely on the nvinfer libraries at inference time. Is this possible too?

Yes, by using this API you will get only an engine out:

serialized_engine = torch_tensorrt.convert_method_to_trt_engine(compiled_module)
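
A minimal sketch of persisting that engine for a plain-nvinfer deployment, assuming convert_method_to_trt_engine returns the serialized engine as bytes (the my_module.engine filename is just an example):

import tensorrt as trt

# Write the serialized engine produced above to disk.
with open("my_module.engine", "wb") as f:
    f.write(serialized_engine)

# Optional sanity check on the Python side: the file deserializes with the
# plain TensorRT runtime. In C++ the same bytes can be loaded with
# nvinfer1::createInferRuntime(logger)->deserializeCudaEngine(data, size),
# so only the nvinfer libraries are needed at inference time.
logger = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(logger)
with open("my_module.engine", "rb") as f:
    engine = runtime.deserialize_cuda_engine(f.read())
assert engine is not None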

@bowang007 can you look at this repro?

@bowang007
Collaborator

Hi @cmgreen210
I tried to run your example; I cannot run it either, and I found some API usage that you should probably update:

  1. The existing automatic plugin generation infrastructure clearly doesn't support AOT TRT plugins, so maybe you shouldn't use it for this anymore.
  2. The AOT plugin impl signature in the example doesn't follow the argument requirements demonstrated in the TensorRT docs; that's why there is an invalid argument error.

I'm now creating a new demo for AOT plugins.

@cmgreen210
Author

Thanks for your response @bowang007 - I think I'll wait for your AOT plugin demo before making any changes. Hopefully it will be available soon! 😄
