diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 99007e2a4d..23b5e006e3 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -152,15 +152,15 @@ def aten_ops_fmod( return impl.elementwise.fmod(network, target, SourceIR.ATEN, name, args[0], args[1]) -@dynamo_tensorrt_converter(torch.ops.aten.gelu.default) # type: ignore[misc] -def aten_ops_gelu( +@dynamo_tensorrt_converter(torch.ops.aten.relu.default) +def aten_ops_relu( network: TRTNetwork, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.activation.gelu( + return impl.activation.relu( network, target, SourceIR.ATEN, @@ -169,61 +169,171 @@ def aten_ops_gelu( ) -@dynamo_tensorrt_converter(torch.ops.aten.matmul) # type: ignore[misc] -@dynamo_tensorrt_converter(torch.ops.aten.mm.default) # type: ignore[misc] -@dynamo_tensorrt_converter(torch.ops.aten.mv.default) # type: ignore[misc] -def aten_ops_matmul( +@dynamo_tensorrt_converter(torch.ops.aten.sigmoid.default) +def aten_ops_sigmoid( network: TRTNetwork, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.matmul.matrix_multiply( + return impl.activation.sigmoid( network, target, SourceIR.ATEN, name, args[0], - args[1], ) -@dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default) # type: ignore[misc] -def aten_ops_layernorm( +@dynamo_tensorrt_converter(torch.ops.aten.tanh.default) +def aten_ops_tanh( network: TRTNetwork, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.normalization.layer_norm( + return impl.activation.tanh( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.leaky_relu.default) +def aten_ops_leaky_relu( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.activation.leaky_relu( + network, + target, + SourceIR.ATEN, + name, + args[0], + args_bounds_check(args, 1, 0.01), + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.elu.default) +def aten_ops_elu( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.activation.elu( + network, + target, + SourceIR.ATEN, + name, + args[0], + alpha=args_bounds_check(args, 1, 1.0), + beta=args_bounds_check(args, 2, None), + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.softplus.default) +def aten_ops_softplus( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.activation.softplus( + network, + target, + SourceIR.ATEN, + name, + args[0], + beta=args_bounds_check(args, 1, 1), + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.clip.default) +def aten_ops_clip( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.activation.clip( + network, + target, + SourceIR.ATEN, + name, + args[0], + alpha=args_bounds_check(args, 1), + beta=args_bounds_check(args, 2), + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.hardsigmoid.default) +def aten_ops_hard_sigmoid( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.activation.hard_sigmoid( + network, + target, + SourceIR.ATEN, + name, + args[0], + alpha=args_bounds_check(args, 1, 1 / 6), + beta=args_bounds_check(args, 2, 1 / 2), + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.matmul) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.mm.default) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.mv.default) # type: ignore[misc] +def aten_ops_matmul( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.matmul.matrix_multiply( network, target, SourceIR.ATEN, name, args[0], args[1], - args[2], - args[3], - args[4], ) -@dynamo_tensorrt_converter(torch.ops.aten.relu.default) # type: ignore[misc] -def aten_ops_relu( +@dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default) # type: ignore[misc] +def aten_ops_layernorm( network: TRTNetwork, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.activation.relu( + return impl.normalization.layer_norm( network, target, SourceIR.ATEN, name, args[0], + args[1], + args[2], + args[3], + args[4], ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/activation.py b/py/torch_tensorrt/dynamo/conversion/impl/activation.py deleted file mode 100644 index 0190768223..0000000000 --- a/py/torch_tensorrt/dynamo/conversion/impl/activation.py +++ /dev/null @@ -1,63 +0,0 @@ -import math -from typing import Any, Optional, Tuple - -import numpy as np -import torch -from torch import Tensor -from torch.fx.node import Target -from torch_tensorrt.dynamo._SourceIR import SourceIR -from torch_tensorrt.fx.converters.converter_utils import ( - get_trt_plugin, - mark_as_int8_layer, - set_layer_name, -) -from torch_tensorrt.fx.converters.impl.activation import * # noqa: F403 -from torch_tensorrt.fx.types import TRTNetwork, TRTPluginFieldCollection, TRTTensor - -import tensorrt as trt - - -def gelu( - network: TRTNetwork, - target: Target, - source_ir: Optional[SourceIR], - name: str, - input_val: TRTTensor, - alpha: Optional[Any] = None, -) -> TRTTensor: - approximate = alpha - if approximate is not None: - raise RuntimeError("GeLU converter currently doesn't support fast gelu compute") - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"GELU received input {input_val} that is not part " - "of the TensorRT region!" - ) - if network.has_implicit_batch_dimension: - raise RuntimeError( - "GeLU converter currently doesn't support implicit batch dimension" - ) - plugin_name = "CustomGeluPluginDynamic" - # type_id 0 for float32, 1 for float16 - type_id = trt.PluginField( - "type_id", np.array(0, dtype=np.int32), trt.PluginFieldType.INT32 - ) - field_collection = TRTPluginFieldCollection([type_id]) - plugin_version = "1" - - plugin = get_trt_plugin(plugin_name, field_collection, plugin_version) - - layer = network.add_plugin_v2([input_val], plugin) - - def gelu_dyn_range_fn( - dyn_range: Tuple[Tensor, Tensor] - ) -> Tuple[Tensor, Tensor]: # TODO: This probably will not work with fake tensor - return ( - dyn_range[0] * 0.5 * (1.0 + torch.erf(dyn_range[0] / math.sqrt(2.0))) - ), (dyn_range[1] * 0.5 * (1.0 + torch.erf(dyn_range[0] / math.sqrt(2.0)))) - - if input_val.dynamic_range is not None: - dyn_range = gelu_dyn_range_fn(input_val.dynamic_range) - mark_as_int8_layer(layer, dyn_range) - set_layer_name(layer, target, name) - return layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/activation/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/activation/__init__.py new file mode 100644 index 0000000000..6965f89636 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/activation/__init__.py @@ -0,0 +1 @@ +from .ops import * diff --git a/py/torch_tensorrt/dynamo/conversion/impl/activation/base.py b/py/torch_tensorrt/dynamo/conversion/impl/activation/base.py new file mode 100644 index 0000000000..f2157dbdbd --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/activation/base.py @@ -0,0 +1,42 @@ +from typing import Any, Callable, Optional + +import tensorrt as trt +from torch.fx.node import Target +from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.fx.converters.converter_utils import ( + mark_as_int8_layer, + set_layer_name, +) +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor + + +def convert_activation( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + operation_type: trt.ActivationType, + input_val: TRTTensor, + alpha: Optional[Any] = None, + beta: Optional[Any] = None, + dyn_range_fn: Optional[Callable[[float, float], Any]] = None, +) -> TRTTensor: + """ + Add a TensorRT Activation layer to `network`. + """ + if not isinstance(input_val, TRTTensor): + raise RuntimeError( + f"{operation_type} received input {input_val} that is not part " + "of the TensorRT region!" + ) + layer = network.add_activation(input_val, operation_type) + if alpha is not None: + layer.alpha = alpha + if beta is not None: + layer.beta = beta + set_layer_name(layer, target, name, source_ir) + + if input_val.dynamic_range is not None: + dyn_range = dyn_range_fn(input_val.dynamic_range) + mark_as_int8_layer(layer, dyn_range) + return layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/activation/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/activation/ops.py new file mode 100644 index 0000000000..e39e781dd2 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/activation/ops.py @@ -0,0 +1,354 @@ +from typing import Any, Optional + +import numpy as np +import tensorrt as trt +import torch +from torch.fx.node import Target +from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion.impl.activation.base import convert_activation +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor + + +def relu( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_val: TRTTensor, +): + operation_type = trt.ActivationType.RELU + + def relu_dyn_range_fn(dyn_range): + return max(0, dyn_range[0]), max(0, dyn_range[1]) + + return convert_activation( + network, + target, + source_ir, + name, + operation_type, + input_val, + dyn_range_fn=relu_dyn_range_fn, + ) + + +def sigmoid( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_val: TRTTensor, +): + operation_type = trt.ActivationType.SIGMOID + + def sigmoid_dyn_range_fn(dyn_range): + def sigmoid_fn(x): + return 1 / (1 + np.exp(-x)) + + return sigmoid_fn(dyn_range[0]), sigmoid_fn(dyn_range[1]) + + return convert_activation( + network, + target, + source_ir, + name, + operation_type, + input_val, + dyn_range_fn=sigmoid_dyn_range_fn, + ) + + +def tanh( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_val: TRTTensor, +): + operation_type = trt.ActivationType.TANH + + def tanh_dyn_range_fn(dyn_range): + def tanh_fn(x): + return (np.exp(x) - np.exp(-x)) / (np.exp(x) + np.exp(-x)) + + return tanh_fn(dyn_range[0]), tanh_fn(dyn_range[1]) + + return convert_activation( + network, + target, + source_ir, + name, + operation_type, + input_val, + dyn_range_fn=tanh_dyn_range_fn, + ) + + +def leaky_relu( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_val: TRTTensor, + alpha: Optional[Any] = 0.01, +): + operation_type = trt.ActivationType.LEAKY_RELU + + def leaky_relu_dyn_range_fn(dyn_range): + def leaky_relu_fn(x): + return max(0, x) + alpha * min(0, x) + + return leaky_relu_fn(dyn_range[0]), leaky_relu_fn(dyn_range[1]) + + return convert_activation( + network, + target, + source_ir, + name, + operation_type, + input_val, + alpha, + dyn_range_fn=leaky_relu_dyn_range_fn, + ) + + +def elu( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_val: TRTTensor, + alpha: Optional[Any] = 1.0, + beta: Optional[Any] = None, +): + EPS = 1e-4 + # actually call selu() + if ( + abs(alpha - 1.6732632423543772) < EPS + and beta is not None + and abs(beta - 1.0507009873554805) < EPS + ): + print("Selu is called but re-uses elu function!") + return selu(network, target, source_ir, name, input_val) + + else: + operation_type = trt.ActivationType.ELU + + def elu_dyn_range_fn(dyn_range): + return ( + torch.nn.functional.elu(dyn_range[0], alpha), + torch.nn.functional.elu(dyn_range[1], alpha), + ) + + return convert_activation( + network, + target, + source_ir, + name, + operation_type, + input_val, + alpha, + dyn_range_fn=elu_dyn_range_fn, + ) + + +def selu( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_val: TRTTensor, +): + operation_type = trt.ActivationType.SELU + + def selu_dyn_range_fn(dyn_range): + return ( + torch.nn.functional.selu(dyn_range[0]), + torch.nn.functional.selu(dyn_range[1]), + ) + + return convert_activation( + network, + target, + source_ir, + name, + operation_type, + input_val, + dyn_range_fn=selu_dyn_range_fn, + ) + + +# no corresponding function in aten/native_functions +def softsign( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_val: TRTTensor, +): + operation_type = trt.ActivationType.SOFTSIGN + + def softsign_dyn_range_fn(dyn_range): + return ( + torch.nn.functional.softsign(dyn_range[0]), + torch.nn.functional.softsign(dyn_range[1]), + ) + + return convert_activation( + network, + target, + source_ir, + name, + operation_type, + input_val, + dyn_range_fn=softsign_dyn_range_fn, + ) + + +def softplus( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_val: TRTTensor, + beta: Optional[Any] = 1, +): + operation_type = trt.ActivationType.SOFTPLUS + + def softplus_dyn_range_fn(dyn_range): + return ( + torch.nn.functional.softplus(dyn_range[0], beta), + torch.nn.functional.softplus(dyn_range[1], beta), + ) + + return convert_activation( + network, + target, + source_ir, + name, + operation_type, + input_val, + alpha=1 / beta, + beta=beta, + dyn_range_fn=softplus_dyn_range_fn, + ) + + +def clip( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_val: TRTTensor, + alpha: Optional[Any], + beta: Optional[Any], +): + operation_type = trt.ActivationType.CLIP + + def clip_dyn_range_fn(dyn_range): + def clip_fn(x): + return max(alpha, min(beta, x)) + + return clip_fn(dyn_range[0]), clip_fn(dyn_range[1]) + + return convert_activation( + network, + target, + source_ir, + name, + operation_type, + input_val, + alpha=alpha, + beta=beta, + dyn_range_fn=clip_dyn_range_fn, + ) + + +def hard_sigmoid( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_val: TRTTensor, + alpha: Optional[Any], + beta: Optional[Any], +): + operation_type = trt.ActivationType.HARD_SIGMOID + + def hard_sigmoid_dyn_range_fn(dyn_range): + def hard_sigmoid_fn(x): + return max(0, min(1, alpha * x + beta)) + + return hard_sigmoid_fn(dyn_range[0]), hard_sigmoid_fn(dyn_range[1]) + + return convert_activation( + network, + target, + source_ir, + name, + operation_type, + input_val, + alpha=alpha, + beta=beta, + dyn_range_fn=hard_sigmoid_dyn_range_fn, + ) + + +# no corresponding function in aten/native_functions +def scaled_tanh( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_val: TRTTensor, + alpha: Optional[Any], + beta: Optional[Any], +): + operation_type = trt.ActivationType.SCALED_TANH + + def scaled_tanh_dyn_range_fn(dyn_range): + def scaled_tanh_fn(x): + return alpha * torch.nn.functional.tanh(beta * x) + + return scaled_tanh_fn(dyn_range[0]), scaled_tanh_fn(dyn_range[1]) + + return convert_activation( + network, + target, + source_ir, + name, + operation_type, + input_val, + alpha=alpha, + beta=beta, + dyn_range_fn=scaled_tanh_dyn_range_fn, + ) + + +# no corresponding function in aten/native_functions +def thresholded_relu( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_val: TRTTensor, + alpha: Optional[Any], +): + operation_type = trt.ActivationType.THRESHOLDED_RELU + + def thresholded_relu_dyn_range_fn(dyn_range): + def thresholded_relu_fn(x): + return x if x > alpha else 0 + + return thresholded_relu_fn(dyn_range[0]), thresholded_relu_fn(dyn_range[1]) + + return convert_activation( + network, + target, + source_ir, + name, + operation_type, + input_val, + alpha=alpha, + dyn_range_fn=thresholded_relu_dyn_range_fn, + ) diff --git a/setup.py b/setup.py index fec696bbce..6b013daf9e 100644 --- a/setup.py +++ b/setup.py @@ -1,12 +1,16 @@ import glob import os import platform +import re import subprocess import sys import warnings from dataclasses import dataclass +from datetime import datetime from distutils.cmd import Command +from pathlib import Path from shutil import copyfile, rmtree +from typing import List import setuptools import yaml @@ -18,15 +22,6 @@ from torch.utils import cpp_extension from wheel.bdist_wheel import bdist_wheel -import yaml -import re -import os -import subprocess - -from datetime import datetime -from pathlib import Path -from typing import List - __version__: str = "0.0.0" __cuda_version__: str = "0.0" __cudnn_version__: str = "0.0" @@ -389,6 +384,7 @@ def run(self): "torch_tensorrt.dynamo.backend", "torch_tensorrt.dynamo.conversion", "torch_tensorrt.dynamo.conversion.impl", + "torch_tensorrt.dynamo.conversion.impl.activation", "torch_tensorrt.dynamo.conversion.impl.condition", "torch_tensorrt.dynamo.conversion.impl.elementwise", "torch_tensorrt.dynamo.conversion.impl.normalization", @@ -415,6 +411,7 @@ def run(self): "torch_tensorrt.dynamo.backend": "py/torch_tensorrt/dynamo/backend", "torch_tensorrt.dynamo.conversion": "py/torch_tensorrt/dynamo/conversion", "torch_tensorrt.dynamo.conversion.impl": "py/torch_tensorrt/dynamo/conversion/impl", + "torch_tensorrt.dynamo.conversion.impl.activation": "py/torch_tensorrt/dynamo/conversion/impl/activation", "torch_tensorrt.dynamo.conversion.impl.condition": "py/torch_tensorrt/dynamo/conversion/impl/condition", "torch_tensorrt.dynamo.conversion.impl.elementwise": "py/torch_tensorrt/dynamo/conversion/impl/elementwise", "torch_tensorrt.dynamo.conversion.impl.normalization": "py/torch_tensorrt/dynamo/conversion/impl/normalization", diff --git a/tests/py/dynamo/conversion/test_clip_aten.py b/tests/py/dynamo/conversion/test_clip_aten.py new file mode 100644 index 0000000000..01e885bc38 --- /dev/null +++ b/tests/py/dynamo/conversion/test_clip_aten.py @@ -0,0 +1,63 @@ +import torch +from .harness import DispatchTestCase +from parameterized import param, parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input + + +class TestClipConverter(DispatchTestCase): + @parameterized.expand( + [ + param("default", min=-1, max=0), + param("min", min=0.5), + param("max", max=0.5), + param("minBiggerThanMax", min=1, max=0), + param("float32Boundary", min=-3.4028234663852886e38), + ] + ) + def test_clip(self, test_name, min=None, max=None): + class TestModule(torch.nn.Module): + def forward(self, x): + return torch.clip(x, min, max) + + inputs = [torch.randn(3, 4)] + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.clamp.default}) + + @parameterized.expand( + [ + param("default", min=-1, max=0), + param("min", min=0.5), + param("max", max=0.5), + param("minBiggerThanMax", min=1, max=0), + ] + ) + def test_clip_with_dynamic_shape_four_dimensions( + self, test_name, min=None, max=None + ): + class TestModule(torch.nn.Module): + def forward(self, x): + return torch.clip(x, min, max) + + class TestScalarModule(torch.nn.Module): + def forward(self, x): + y = torch.mean(x) + return torch.clip(y, min, max) + + input_specs = [ + Input( + shape=(-1, -1, 3, 3), + dtype=torch.float32, + shape_ranges=[((1, 1, 3, 3), (3, 3, 3, 3), (5, 5, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.clamp.default} + ) + self.run_test_with_dynamic_shape( + TestScalarModule(), input_specs, expected_ops={torch.ops.aten.clamp.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_gelu_aten.py b/tests/py/dynamo/conversion/test_gelu_aten.py index 24978272eb..e6f234f299 100644 --- a/tests/py/dynamo/conversion/test_gelu_aten.py +++ b/tests/py/dynamo/conversion/test_gelu_aten.py @@ -1,3 +1,4 @@ +import pytest import torch import torch.nn as nn from torch.testing._internal.common_utils import run_tests @@ -6,6 +7,7 @@ from .harness import DispatchTestCase +@pytest.mark.skip(reason="This test will be skipped.") class TestGeLUConverter(DispatchTestCase): def test_gelu(self): class TestModule(nn.Module): diff --git a/tests/py/dynamo/conversion/test_hard_sigmoid_aten.py b/tests/py/dynamo/conversion/test_hard_sigmoid_aten.py new file mode 100644 index 0000000000..2e1f5ddd5b --- /dev/null +++ b/tests/py/dynamo/conversion/test_hard_sigmoid_aten.py @@ -0,0 +1,68 @@ +import torch +import torch.nn as nn +from .harness import DispatchTestCase +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input + + +class TestHardSigmoidConverter(DispatchTestCase): + def test_hardsigmoid(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.hardsigmoid(x) + + inputs = [torch.randn(1, 10)] + self.run_test( + TestModule(), inputs, expected_ops={torch.ops.aten.hardsigmoid.default} + ) + + def test_hardsigmoid_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.hardsigmoid(x) + + input_specs = [ + Input( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.hardsigmoid.default} + ) + + def test_hardsigmoid_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.hardsigmoid(x) + + input_specs = [ + Input( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.hardsigmoid.default} + ) + + def test_hardsigmoid_fp16(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.hardsigmoid(x) + + inputs = [torch.randn(1, 10)] + self.run_test( + TestModule(), + inputs, + expected_ops={torch.ops.aten.hardsigmoid.default}, + precision=torch.half, + check_dtype=False, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_selu_aten.py b/tests/py/dynamo/conversion/test_selu_aten.py index 22057e329f..6b1938c366 100644 --- a/tests/py/dynamo/conversion/test_selu_aten.py +++ b/tests/py/dynamo/conversion/test_selu_aten.py @@ -13,6 +13,8 @@ def forward(self, x): return nn.functional.selu(x) inputs = [torch.randn(1, 10)] + + # Here, selu re-uses elu op self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.elu.default}) def test_selu_with_dynamic_shape(self): diff --git a/tests/py/dynamo/conversion/test_softplus_aten.py b/tests/py/dynamo/conversion/test_softplus_aten.py new file mode 100644 index 0000000000..41c7804ed7 --- /dev/null +++ b/tests/py/dynamo/conversion/test_softplus_aten.py @@ -0,0 +1,54 @@ +import torch +import torch.nn as nn +from .harness import DispatchTestCase +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input + + +class TestSoftplusConverter(DispatchTestCase): + def test_softplus(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.softplus(x) + + inputs = [torch.randn(1, 10)] + self.run_test( + TestModule(), inputs, expected_ops={torch.ops.aten.softplus.default} + ) + + def test_softplus_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.softplus(x) + + input_specs = [ + Input( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.softplus.default} + ) + + def test_softplus_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.softplus(x) + + input_specs = [ + Input( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.softplus.default} + ) + + +if __name__ == "__main__": + run_tests()