diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shape.py b/py/torch_tensorrt/dynamo/conversion/impl/shape.py index 27af02e5bb..517d33e0c8 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/shape.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/shape.py @@ -123,3 +123,39 @@ def get_shape_with_dynamic_shape( select_layer = ctx.net.add_select(condition_val, input_shape, scale_res) set_layer_name(select_layer, target, f"{name}_select") return select_layer.get_output(0) + + +def to_trt_shape_tensor( + ctx: ConversionContext, target: Target, name: str, shape_list: List[int | TRTTensor] +) -> TRTTensor: + """ + Convert a mixed shape list (ints + ITensors) into a single ITensor. + + Args: + ctx (ConversionContext): TensorRT ConversionContext object. + target (Target): Target of fx node. + name (str): base name for layer naming. + shape_list (list[int | ITensor]): list containing static ints and/or ITensors. + + Returns: + ITensor if shape_list contains any ITensors, else plain Python list of ints. + """ + trt_tensors = [] + + for i, s in enumerate(shape_list): + if isinstance(s, (int, torch.Tensor)): + const = ctx.net.add_constant((1,), np.array([s], dtype=np.int32)) + set_layer_name(const, target, f"{name}_dim{i}_const") + trt_tensors.append(const.get_output(0)) + else: + trt_tensors.append(s) + + if any(not isinstance(s, int) for s in shape_list): + # Concatenate everything into a single ITensor if there are any ITensors/Tensors + concat_layer = ctx.net.add_concatenation(trt_tensors) + concat_layer.axis = 0 + set_layer_name(concat_layer, target, f"{name}_shape_concat") + return concat_layer.get_output(0) + + # If no ITensor found, return plain list of ints + return shape_list diff --git a/py/torch_tensorrt/dynamo/conversion/impl/upsample.py b/py/torch_tensorrt/dynamo/conversion/impl/upsample.py index 4b47ca5dec..55d1bfe0d7 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/upsample.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/upsample.py @@ -9,7 +9,10 @@ has_dynamic_shape, set_layer_name, ) -from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape +from torch_tensorrt.dynamo.conversion.impl.shape import ( + get_shape_with_dynamic_shape, + to_trt_shape_tensor, +) def upsample( @@ -28,14 +31,20 @@ def upsample( if scale_factor is not None: layer.scales = [1.0, 1.0] + list(scale_factor) else: - shape = list(input.shape)[:2] + list(size) + shape = list(input.shape)[:2] + if size is not None: + shape += list(size) if has_dynamic_shape(shape): shape = get_shape_with_dynamic_shape( ctx, target, source_ir, name, shape, input ) layer.set_input(1, shape) else: - layer.shape = shape + trt_shape = to_trt_shape_tensor(ctx, target, name, shape) + if isinstance(trt_shape, list): + layer.shape = trt_shape + else: + layer.set_input(1, trt_shape) if mode == "nearest": layer.resize_mode = trt.InterpolationMode.NEAREST diff --git a/tests/py/dynamo/conversion/test_upsample_aten.py b/tests/py/dynamo/conversion/test_upsample_aten.py index 44c4af2a92..6646cfa63e 100644 --- a/tests/py/dynamo/conversion/test_upsample_aten.py +++ b/tests/py/dynamo/conversion/test_upsample_aten.py @@ -296,6 +296,50 @@ def forward(self, x): ] self.run_test_with_dynamic_shape(TestModule(), input_specs) + @parameterized.expand( + [ + ([torch.tensor(3), 3], None), + (None, [torch.tensor(0.5), 1.5]), + ] + ) + def test_nearest2d_mixed_dynamic_shape(self, output_size, scale_factors): + class TestModule(torch.nn.Module): + def forward(self, x): + out_size = output_size + scale = scale_factors + + return torch.ops.aten.upsample_nearest2d.vec(x, out_size, scale) + + input_specs = [ + Input( + min_shape=(1, 1, 1, 1), + opt_shape=(5, 5, 5, 5), + max_shape=(9, 9, 9, 9), + dtype=torch.float32, + ) + ] + self.run_test_with_dynamic_shape(TestModule(), input_specs) + + @parameterized.expand( + [ + # Mix of Tensor and int in output_size + ([torch.tensor(3), 3], None), + # Mix of Tensor and float in scale_factors + (None, [torch.tensor(0.5), 1.5]), + ] + ) + def test_nearest2d_mixed_static_input(self, output_size, scale_factors): + class TestModule(torch.nn.Module): + def forward(self, x): + out_size = output_size + scale = scale_factors + return torch.ops.aten.upsample_nearest2d.vec(x, out_size, scale) + + input_size = [7, 7] # H, W + inputs = [torch.randn([1, 1] + input_size)] # shape [1, 1, 7, 7] + + self.run_test(TestModule(), inputs) + if __name__ == "__main__": run_tests()