diff --git a/examples/distributed_inference/data_parallel_stable_diffusion.py b/examples/distributed_inference/data_parallel_stable_diffusion.py
index 5c0e3113e5..023d7e8e63 100644
--- a/examples/distributed_inference/data_parallel_stable_diffusion.py
+++ b/examples/distributed_inference/data_parallel_stable_diffusion.py
@@ -53,7 +53,5 @@
 
 # Assume there are 2 processes (2 devices)
 with distributed_state.split_between_processes(["a dog", "a cat"]) as prompt:
-    print("before \n")
     result = pipe(prompt).images[0]
-    print("after ")
     result.save(f"result_{distributed_state.process_index}.png")
diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 164f0c1065..c6e07d481f 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2,7 +2,7 @@
 
 import logging
 import operator
-from typing import Callable, Dict, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 import numpy as np
 import torch
@@ -217,7 +217,57 @@ def aten_ops_native_group_norm(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.cat.default, supports_dynamic_shapes=True)
+def parse_cat_args(
+    args: Tuple[Argument, ...], kwargs: Dict[str, Any]
+) -> Tuple[List[Any], int]:
+    """
+    Process inputs for torch.ops.aten.cat.default.
+
+    Handles these valid patterns:
+    1. args = ((t1, t2, ...), dim)
+    2. args = ((t1, t2, ...),), kwargs = {"dim": X}, where dim is optional
+
+    Returns:
+        (input_tensors, dim)
+        input_tensors: list of tensor arguments
+        dim: integer concatenation dimension (default 0)
+    """
+
+    if len(args) > 1 and isinstance(args[0], (list, tuple)):
+        input_tensors = list(args[0])
+        dim = args_bounds_check(args, 1, 0)
+
+    else:
+        # If single arg is itself a tuple/list, unwrap it
+        if len(args) == 1 and isinstance(args[0], (list, tuple)):
+            input_tensors = list(args[0])
+        else:
+            input_tensors = list(args)
+
+        dim = kwargs.get("dim", 0)
+
+    return input_tensors, dim
+
+
+def cat_validator(node: Node, settings: Optional[CompilationSettings] = None) -> bool:
+    # An empty tensor passed to cat as an ITensor leads to the error
+    # "[RemoveDeadLayers] Input Tensor y is unused or used only at compile-time,
+    # but is not being removed.", so such inputs are rejected here.
+    # NOTE(review): node.args of an FX node normally holds torch.fx.Node objects,
+    # so this TRTTensor isinstance check may never match - confirm whether the
+    # shape should instead be read from the node metadata.
+    inputs, _ = parse_cat_args(node.args, node.kwargs)
+    for each_input in inputs:
+        if isinstance(each_input, TRTTensor) and any(s == 0 for s in each_input.shape):
+            return False
+    return True
+
+
+@dynamo_tensorrt_converter(
+    torch.ops.aten.cat.default,
+    supports_dynamic_shapes=True,
+    capability_validator=cat_validator,
+)
 def aten_ops_cat(
     ctx: ConversionContext,
     target: Target,
@@ -225,13 +275,14 @@ def aten_ops_cat(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    inputs, dim = parse_cat_args(args, kwargs)
     return impl.cat.cat(
         ctx,
         target,
         SourceIR.ATEN,
         name,
-        input=args[0],
-        dim=args_bounds_check(args, 1, 0),
+        input=inputs,
+        dim=dim,
     )
 
 
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/cat.py b/py/torch_tensorrt/dynamo/conversion/impl/cat.py
index 68bbcc31d0..e3dd01477f 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/cat.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/cat.py
@@ -1,3 +1,4 @@
+import logging
 from typing import Optional, Sequence, Union
 
 import numpy as np
@@ -15,6 +16,8 @@
     set_layer_name,
 )
 
+logger = logging.getLogger(__name__)
+
 
 def cat(
     ctx: ConversionContext,
@@ -27,6 +30,13 @@ def cat(
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     trt_inputs = []
     for i, each_input in enumerate(input):
+        if isinstance(each_input, torch.Tensor) and each_input.numel() == 0:
+            logger.warning(
+                f"Empty tensor in cat input {i}; skipping it during concatenation"
+            )
+            # An ITensor with the same condition leads to "[RemoveDeadLayers] Input Tensor y is unused or used only at compile-time, but is not being removed.",
+            # which is why that case is left to the capability validator
+            continue
         if not isinstance(each_input, TRTTensor):
             each_input = get_trt_tensor(ctx, each_input, f"{name}_tensor_{i}")
         if cast_dtype:
diff --git a/tests/py/dynamo/conversion/test_cat_aten.py b/tests/py/dynamo/conversion/test_cat_aten.py
index a9e4a45c81..15aa8b0d80 100644
--- a/tests/py/dynamo/conversion/test_cat_aten.py
+++ b/tests/py/dynamo/conversion/test_cat_aten.py
@@ -25,6 +25,77 @@ def forward(self, x, y, z):
             inputs,
         )
 
+    @parameterized.expand(
+        [
+            ("pos", 1),
+            ("neg", -2),
+        ]
+    )
+    def test_cat_dim_in_kwargs(self, _, dim):
+        class Cat(nn.Module):
+            def forward(self, x, y, z):
+                return torch.ops.aten.cat.default((x, y, z), dim=dim)
+
+        inputs = [torch.randn(1, 2, 3), torch.randn(1, 1, 3), torch.randn(1, 3, 3)]
+        self.run_test(
+            Cat(),
+            inputs,
+        )
+
+    @parameterized.expand(
+        [
+            ("pos", 0),
+            ("neg", -3),
+        ]
+    )
+    def test_cat_with_scalar_inputs(self, _, dim):
+        # Concatenate a tensor with a constant-filled tensor of the same shape
+        class Cat(nn.Module):
+            def forward(self, x, y):
+                # x and y are both full tensors of identical shape
+                return torch.ops.aten.cat.default((x, y), dim)
+
+        x = torch.randn(1, 2, 3, device="cuda")
+        y = torch.ones_like(x) * 5.0  # every element is 5.0, standing in for a broadcast scalar
+        inputs = [x, y]
+        self.run_test(Cat(), inputs)
+
+    @parameterized.expand(
+        [
+            ("pos", 0),
+            ("neg", -3),
+        ]
+    )
+    def test_cat_with_empty_tensor(self, _, dim):
+        # Handle empty tensor in concat
+        class Cat(nn.Module):
+            def forward(self, x):
+                y = torch.empty(0, 2, 3, device="cuda")
+                return torch.ops.aten.cat.default((x, y), dim)
+
+        inputs = [
+            torch.randn(1, 2, 3, device="cuda"),
+        ]
+        self.run_test(Cat(), inputs)
+
+    @parameterized.expand(
+        [
+            ("pos", 2),
+            ("neg", -1),
+        ]
+    )
+    def test_cat_with_different_dtypes(self, _, dim):
+        # check dtype promotion path in concat
+        class Cat(nn.Module):
+            def forward(self, x, y):
+                return torch.ops.aten.cat.default((x, y), dim)
+
+        inputs = [
+            torch.ones(1, 2, 3, dtype=torch.float32, device="cuda"),
+            torch.ones(1, 2, 3, dtype=torch.float16, device="cuda"),
+        ]
+        self.run_test(Cat(), inputs)
+
     @parameterized.expand(
         [
             ("pos", 1),