diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 4d9547d3ed..790acc9481 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2505,3 +2505,27 @@ def upsample_bilinear2d(
         resize_mode="bilinear",
         align_corners=args_bounds_check(args, 2),
     )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.sort.default)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+def aten_ops_sort(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.topk.sort(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        dim=args_bounds_check(args, 1, -1),
+        descending=args_bounds_check(args, 2, False),
+    )
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/topk.py b/py/torch_tensorrt/dynamo/conversion/impl/topk.py
index a9e11cc537..41f6f990f2 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/topk.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/topk.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Tuple, Union
 
 import tensorrt as trt
 from torch.fx.node import Target
@@ -101,3 +101,36 @@ def argmin(
     return argmax_argmin(
         ctx, target, source_ir, name, input, trt.TopKOperation.MIN, dim, keep_dim
     )
+
+
+def sort(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    dim: int,
+    descending: bool,
+    return_indices: bool = True,
+) -> Union[TRTTensor, Tuple[TRTTensor, TRTTensor]]:
+    if descending:
+        topk_layer = ctx.net.add_topk(
+            input,
+            trt.TopKOperation.MAX,
+            input.shape[dim],
+            get_axes_for_reduce_op(get_positive_dim(dim, len(input.shape))),
+        )
+    else:
+        topk_layer = ctx.net.add_topk(
+            input,
+            trt.TopKOperation.MIN,
+            input.shape[dim],
+            get_axes_for_reduce_op(get_positive_dim(dim, len(input.shape))),
+        )
+
+    set_layer_name(topk_layer, target, name, source_ir)
+
+    if return_indices:
+        return topk_layer.get_output(0), topk_layer.get_output(1)
+    else:
+        return topk_layer.get_output(0)
diff --git a/tests/py/dynamo/conversion/test_sort_aten.py b/tests/py/dynamo/conversion/test_sort_aten.py
new file mode 100644
index 0000000000..8bb9bc214e
--- /dev/null
+++ b/tests/py/dynamo/conversion/test_sort_aten.py
@@ -0,0 +1,34 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestSortConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ((3, 2, 4), 0, True),
+            ((2, 3, 4, 5), 1, True),
+            ((2, 3, 4, 5), 2, False),
+            ((6, 7, 5, 4, 5), 4, False),
+            ((1, 5, 2, 1), -1, True),
+            ((1, 2, 5, 3), -2, False),
+            ((6, 2, 1, 3), -4, True),
+        ]
+    )
+    def test_sort(self, input_shape, dim, descending):
+        class Sort(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.sort.default(x, dim, descending)
+
+        inputs = [torch.randn(*input_shape)]
+        self.run_test(
+            Sort(),
+            inputs,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
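
The new `impl.topk.sort` lowers `aten.sort` onto TensorRT's TopK layer: sorting a dimension is just a top-k over the full width of that dimension, with `descending=True` mapped to `TopKOperation.MAX` and `descending=False` to `TopKOperation.MIN`. A minimal pure-PyTorch sketch of the same reduction, handy for sanity-checking the converter's semantics (the helper name `sort_via_topk` is illustrative, not part of this patch):

```python
import torch

def sort_via_topk(x: torch.Tensor, dim: int = -1, descending: bool = False):
    # k equals the full size of `dim`, so top-k returns every element,
    # ordered descending (largest=True, like TopKOperation.MAX) or
    # ascending (largest=False, like TopKOperation.MIN).
    return torch.topk(x, k=x.shape[dim], dim=dim, largest=descending)

x = torch.randn(3, 5)
for descending in (False, True):
    values, _ = sort_via_topk(x, dim=1, descending=descending)
    expected, _ = torch.sort(x, dim=1, descending=descending)
    assert torch.equal(values, expected)
```

On tied values the indices from `topk` and `sort` may legitimately differ, which is why the check above compares values only; the `torch.randn` inputs used in the new test file make ties vanishingly unlikely.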