diff --git a/py/torch_tensorrt/dynamo/lowering/substitutions/einsum.py b/py/torch_tensorrt/dynamo/lowering/substitutions/einsum.py
index cfcbdce761..ea44a88be5 100644
--- a/py/torch_tensorrt/dynamo/lowering/substitutions/einsum.py
+++ b/py/torch_tensorrt/dynamo/lowering/substitutions/einsum.py
@@ -1,26 +1,21 @@
 from typing import Any, Dict, Optional, Sequence, Tuple
 
 import torch
-from torch._custom_op.impl import custom_op
+import torch._custom_ops as library
 from torch.fx.node import Argument, Target
 from torch_tensorrt.dynamo.lowering._pre_aot_lowering import register_substitution
 from torch_tensorrt.fx.converter_registry import tensorrt_converter
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 
-
-@custom_op(
-    qualname="tensorrt::einsum",
-    manual_schema="(str equation, Tensor[] tensors) -> Tensor",
+library.custom_op(
+    "tensorrt::einsum",
+    "(str equation, Tensor[] tensors) -> Tensor",
 )
-def einsum(equation, tensors):  # type: ignore[no-untyped-def]
-    # Defines operator schema, name, namespace, and function header
-    ...
 
 
-@einsum.impl("cpu")  # type: ignore[misc]
-@einsum.impl("cuda")  # type: ignore[misc]
-@einsum.impl_abstract()  # type: ignore[misc]
+@library.impl("tensorrt::einsum")  # type: ignore[misc]
+@library.impl_abstract("tensorrt::einsum")  # type: ignore[misc]
 def einsum_generic(
     *args: Any,
     **kwargs: Any,
diff --git a/py/torch_tensorrt/dynamo/lowering/substitutions/maxpool1d.py b/py/torch_tensorrt/dynamo/lowering/substitutions/maxpool1d.py
index 6db2664efb..0fb8e89414 100644
--- a/py/torch_tensorrt/dynamo/lowering/substitutions/maxpool1d.py
+++ b/py/torch_tensorrt/dynamo/lowering/substitutions/maxpool1d.py
@@ -1,7 +1,7 @@
 from typing import Any, Dict, Optional, Tuple
 
 import torch
-from torch._custom_op.impl import custom_op
+import torch._custom_ops as library
 from torch.fx.node import Argument, Target
 from torch_tensorrt.dynamo.lowering._pre_aot_lowering import register_substitution
 from torch_tensorrt.fx.converter_registry import tensorrt_converter
@@ -20,13 +20,10 @@
 # types. The namespace, such as tensorrt, will cause the op to be registered as torch.ops.tensorrt.your_op
 # Then, create a placeholder function with no operations, but having the same schema and naming as that
 # used in the decorator
-@custom_op(
-    qualname="tensorrt::maxpool1d",
-    manual_schema="(Tensor x, int[1] kernel_size, int[1] stride, int[1] padding, int[1] dilation, bool ceil_mode) -> Tensor",
+library.custom_op(
+    "tensorrt::maxpool1d",
+    "(Tensor x, int[1] kernel_size, int[1] stride, int[1] padding, int[1] dilation, bool ceil_mode) -> Tensor",
 )
-def maxpool1d(x, kernel_size, stride, padding, dilation, ceil_mode):  # type: ignore[no-untyped-def]
-    # Defines operator schema, name, namespace, and function header
-    ...
 
 
 # 2. The Generic Implementation
@@ -36,9 +33,8 @@ def maxpool1d(x, kernel_size, stride, padding, dilation, ceil_mode):  # type: ig
 # is desirable. If the operator to replace is a custom module you've written, then add its Torch
 # implementation here. Note that the function header to the generic function can have specific arguments
 # as in the above placeholder
-@maxpool1d.impl("cpu")  # type: ignore[misc]
-@maxpool1d.impl("cuda")  # type: ignore[misc]
-@maxpool1d.impl_abstract()  # type: ignore[misc]
+@library.impl("tensorrt::maxpool1d")  # type: ignore[misc]
+@library.impl_abstract("tensorrt::maxpool1d")  # type: ignore[misc]
 def maxpool1d_generic(
     *args: Any,
     **kwargs: Any,
@@ -69,7 +65,7 @@ def maxpool1d_generic(
 #     "bias": bias,
 #     ...
 #
-@register_substitution(torch.nn.MaxPool1d, torch.ops.tensorrt.maxpool1d)
+@register_substitution(torch.nn.MaxPool1d, torch.ops.tensorrt.maxpool1d)  # type: ignore
 def maxpool1d_insertion_fn(
     gm: torch.fx.GraphModule,
     node: torch.fx.Node,
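For reference, the change above replaces the instance-style @custom_op(...) / .impl("cpu") / .impl("cuda") / .impl_abstract() decorators with the module-level torch._custom_ops API: the schema is registered once by qualified name, and a single generic function is registered as both the eager kernel and the abstract (fake-tensor) implementation. Below is a minimal standalone sketch of the resulting pattern, assuming PyTorch 2.1+ where torch._custom_ops exposes custom_op, impl, and impl_abstract; the einsum body is illustrative only, not the exact implementation in the file.

# Minimal, self-contained sketch of the torch._custom_ops registration pattern.
# Assumption: PyTorch >= 2.1; the op body below is a stand-in for illustration.
from typing import Any

import torch
import torch._custom_ops as library

# 1. Register the operator schema under the "tensorrt" namespace; the op then
#    becomes callable as torch.ops.tensorrt.einsum.
library.custom_op(
    "tensorrt::einsum",
    "(str equation, Tensor[] tensors) -> Tensor",
)


# 2. One generic function registered as both the eager (CPU/CUDA) kernel and the
#    abstract/fake-tensor implementation used during tracing.
@library.impl("tensorrt::einsum")
@library.impl_abstract("tensorrt::einsum")
def einsum_generic(*args: Any, **kwargs: Any) -> torch.Tensor:
    equation = args[0]
    tensors = args[1]
    return torch.einsum(equation, *tensors)


if __name__ == "__main__":
    a, b = torch.randn(3, 4), torch.randn(4, 5)
    out = torch.ops.tensorrt.einsum("ij,jk->ik", [a, b])
    print(out.shape)  # torch.Size([3, 5])

Because the schema string is registered directly against the qualified name, the empty einsum/maxpool1d placeholder functions required by the old API are no longer needed (the diff removes them), and callers continue to resolve the ops through torch.ops.tensorrt.<name> as before.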