diff --git a/py/torch_tensorrt/dynamo/conversion/impl/cat.py b/py/torch_tensorrt/dynamo/conversion/impl/cat.py
index 68bbcc31d0..e50e57c22c 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/cat.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/cat.py
@@ -1,4 +1,4 @@
-from typing import Optional, Sequence, Union
+from typing import List, Optional, Sequence, Union
 
 import numpy as np
 import tensorrt as trt
@@ -16,6 +16,63 @@
 )
 
 
+def unify_and_concat_trt_tensors(
+    ctx: ConversionContext,
+    target: Target,
+    name: str,
+    inputs: Sequence[Union[int, np.ndarray, torch.Tensor, TRTTensor]],
+    concat_axis: int,
+    cast_dtype: Optional[Union[_enums.dtype, trt.DataType, np.dtype]] = None,
+    force_trt_output: bool = False,
+) -> Union[TRTTensor, List[int]]:
+    """
+    Normalize inputs to TRT tensors as needed, optionally cast them, and concatenate if any input is dynamic.
+
+    Args:
+        ctx: TensorRT conversion context.
+        target: Operation target.
+        name: Operation name.
+        inputs: Sequence of ints / numpy arrays / torch tensors / TRT tensors.
+        concat_axis: Axis along which to concatenate tensors if dynamic.
+        cast_dtype: Optional target dtype for casting TRT tensors.
+        force_trt_output: If True, return a TRT tensor even if all inputs are static ints (True for concat operations).
+    """
+    has_dynamic = any(not isinstance(x, int) for x in inputs)
+    trt_tensors = []
+
+    for i, x in enumerate(inputs):
+        # Convert each input to a TRTTensor, keeping plain ints on the static path
+        if isinstance(x, TRTTensor):
+            t = x
+        elif isinstance(x, int) and not has_dynamic and not force_trt_output:
+            t = x  # pure static path
+        else:
+            t = ctx.net.add_constant((1,), np.array([x], dtype=np.int32))
+            set_layer_name(t, target, f"{name}_dim{i}_const")
+            t = t.get_output(0)
+
+        # Optional cast
+        if cast_dtype and isinstance(t, TRTTensor):
+            t = cast_trt_tensor(ctx, t, cast_dtype, f"{name}_cast_{i}")
+
+        trt_tensors.append(t)
+
+    if not has_dynamic and not force_trt_output:
+        return trt_tensors  # all ints
+
+    # Promote remaining ints to TRT constants before concatenation
+    for i, t in enumerate(trt_tensors):
+        if isinstance(t, int):
+            const = ctx.net.add_constant((1,), np.array([t], dtype=np.int32))
+            set_layer_name(const, target, f"{name}_static_{i}_const")
+            trt_tensors[i] = const.get_output(0)
+
+    concat = ctx.net.add_concatenation(trt_tensors)
+    concat.axis = concat_axis
+    set_layer_name(concat, target, f"{name}_concat")
+    return concat.get_output(0)
+
+
 def cat(
     ctx: ConversionContext,
     target: Target,
@@ -54,9 +111,16 @@ def cat(
             )
             trt_casted_inputs.append(casted_input)
         trt_inputs = trt_casted_inputs
+    else:
+        trt_promoted_type = None
 
-    concat_layer = ctx.net.add_concatenation(trt_inputs)
     dim = get_positive_dim(dim, len(trt_inputs[0].shape))
-    concat_layer.axis = dim
-    set_layer_name(concat_layer, target, f"{name}_gather", source_ir)
-    return concat_layer.get_output(0)
+    return unify_and_concat_trt_tensors(
+        ctx,
+        target,
+        name,
+        trt_inputs,
+        concat_axis=dim,
+        cast_dtype=trt_promoted_type,
+        force_trt_output=True,
+    )
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/upsample.py b/py/torch_tensorrt/dynamo/conversion/impl/upsample.py
index 4b47ca5dec..ac54e18f3a 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/upsample.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/upsample.py
@@ -9,7 +9,12 @@
     has_dynamic_shape,
     set_layer_name,
 )
-from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
+from torch_tensorrt.dynamo.conversion.impl.cat import (
+    unify_and_concat_trt_tensors as unify_trt_shape_tensors,
+)
+from torch_tensorrt.dynamo.conversion.impl.shape import (
+    get_shape_with_dynamic_shape,
+)
 
 
 def upsample(
@@ -28,14 +33,22 @@ def upsample(
     if scale_factor is not None:
         layer.scales = [1.0, 1.0] + list(scale_factor)
     else:
-        shape = list(input.shape)[:2] + list(size)
+        shape = list(input.shape)[:2]
+        if size is not None:
+            shape += list(size)
         if has_dynamic_shape(shape):
             shape = get_shape_with_dynamic_shape(
                 ctx, target, source_ir, name, shape, input
             )
             layer.set_input(1, shape)
         else:
-            layer.shape = shape
+            trt_shape = unify_trt_shape_tensors(
+                ctx, target, name, shape, concat_axis=0, force_trt_output=False
+            )
+            if isinstance(trt_shape, list):
+                layer.shape = trt_shape
+            else:
+                layer.set_input(1, trt_shape)
 
     if mode == "nearest":
         layer.resize_mode = trt.InterpolationMode.NEAREST
diff --git a/tests/py/dynamo/conversion/test_upsample_aten.py b/tests/py/dynamo/conversion/test_upsample_aten.py
index 44c4af2a92..6646cfa63e 100644
--- a/tests/py/dynamo/conversion/test_upsample_aten.py
+++ b/tests/py/dynamo/conversion/test_upsample_aten.py
@@ -296,6 +296,50 @@ def forward(self, x):
         ]
         self.run_test_with_dynamic_shape(TestModule(), input_specs)
 
+    @parameterized.expand(
+        [
+            ([torch.tensor(3), 3], None),
+            (None, [torch.tensor(0.5), 1.5]),
+        ]
+    )
+    def test_nearest2d_mixed_dynamic_shape(self, output_size, scale_factors):
+        class TestModule(torch.nn.Module):
+            def forward(self, x):
+                out_size = output_size
+                scale = scale_factors
+
+                return torch.ops.aten.upsample_nearest2d.vec(x, out_size, scale)
+
+        input_specs = [
+            Input(
+                min_shape=(1, 1, 1, 1),
+                opt_shape=(5, 5, 5, 5),
+                max_shape=(9, 9, 9, 9),
+                dtype=torch.float32,
+            )
+        ]
+        self.run_test_with_dynamic_shape(TestModule(), input_specs)
+
+    @parameterized.expand(
+        [
+            # Mix of Tensor and int in output_size
+            ([torch.tensor(3), 3], None),
+            # Mix of Tensor and float in scale_factors
+            (None, [torch.tensor(0.5), 1.5]),
+        ]
+    )
+    def test_nearest2d_mixed_static_input(self, output_size, scale_factors):
+        class TestModule(torch.nn.Module):
+            def forward(self, x):
+                out_size = output_size
+                scale = scale_factors
+                return torch.ops.aten.upsample_nearest2d.vec(x, out_size, scale)
+
+        input_size = [7, 7]  # H, W
+        inputs = [torch.randn([1, 1] + input_size)]  # shape [1, 1, 7, 7]
+
+        self.run_test(TestModule(), inputs)
+
 
 if __name__ == "__main__":
     run_tests()
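
For context, a minimal, framework-free sketch of the dispatch rule that `unify_and_concat_trt_tensors` implements (the `FakeShapeTensor` class and `unify` function below are illustrative stand-ins, not part of the PR or the TensorRT API): when every input is a static int and `force_trt_output` is False, the helper returns a plain Python list (so the upsample converter can keep assigning `layer.shape`); as soon as any input is a tensor, everything is promoted and concatenated into a single shape tensor that must be fed via `layer.set_input(1, ...)`.

```python
from typing import List, Sequence, Union


class FakeShapeTensor:
    """Illustrative stand-in for a TRT ITensor carrying part of a shape."""

    def __init__(self, values: List[int]) -> None:
        self.values = values


def unify(inputs: Sequence[Union[int, "FakeShapeTensor"]], force_trt_output: bool):
    has_dynamic = any(not isinstance(x, int) for x in inputs)
    if not has_dynamic and not force_trt_output:
        # Pure static path: caller can assign the list directly (layer.shape = ...).
        return list(inputs)
    # Dynamic (or forced) path: promote ints to one-element "constants" and
    # concatenate them, mirroring add_constant + add_concatenation in the real helper.
    pieces = [x if isinstance(x, FakeShapeTensor) else FakeShapeTensor([x]) for x in inputs]
    return FakeShapeTensor([v for p in pieces for v in p.values])


print(unify([1, 1, 7, 7], force_trt_output=False))                             # [1, 1, 7, 7]
print(unify([1, 1, FakeShapeTensor([3]), 3], force_trt_output=False).values)   # [1, 1, 3, 3]
```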