diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 42d6165256..dac526c7e0 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -1,10 +1,13 @@
 import logging
-from typing import Any, Dict, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
 
 import torch
 from torch.fx.node import Argument, Node, Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    is_only_operator_on_placeholder,
+)
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 
 from .converter_registry import dynamo_tensorrt_converter
@@ -441,29 +444,59 @@ def aten_ops_permute(
     )
 
 
-def to_copy_dtype_validator(to_copy_node: Node) -> bool:
-    allowed_casts = {torch.float, torch.int32, torch.bool, torch.int8, torch.float16}
-
-    # Validate input node has convertible kwargs
-    if "dtype" in to_copy_node.kwargs:
-        if to_copy_node.kwargs["dtype"] in allowed_casts:
-            return True
+def to_copy_dtype_validator(placeholder_only: bool) -> Callable[[Node], bool]:
+    """Return a validator for to_copy nodes with placeholder restrictions"""
+
+    def validate_dtype(to_copy_node: Node) -> bool:
+        """Returns true if the to_copy node can be converted to TRT
+
+        Based on the data type being cast to
+        """
+        allowed_casts = {
+            torch.float,
+            torch.int32,
+            torch.bool,
+            torch.int8,
+            torch.float16,
+        }
+
+        # Validate input node has convertible kwargs
+        if "dtype" in to_copy_node.kwargs:
+            if to_copy_node.kwargs["dtype"] in allowed_casts:
+                return True
+            else:
+                _LOGGER.debug(
+                    f"_to_copy converter rejected node {to_copy_node} with dtype {to_copy_node.kwargs['dtype']}"
+                )
+                return False
         else:
             _LOGGER.debug(
-                f"_to_copy converter rejected node {to_copy_node} with dtype {to_copy_node.kwargs['dtype']}"
+                f"_to_copy converter rejected node {to_copy_node} with kwargs {to_copy_node.kwargs}"
             )
             return False
-    else:
-        _LOGGER.debug(
-            f"_to_copy converter rejected node {to_copy_node} with kwargs {to_copy_node.kwargs}"
+
+    def validator(to_copy_node: Node) -> bool:
+        """Returns true if the to_copy node can be converted to TRT
+        and the placeholder restriction is satisfied
+        """
+        # The placeholder restriction is satisfied if placeholder_only has the same
+        # truth value as is_only_operator_on_placeholder(to_copy_node)
+        return validate_dtype(to_copy_node) and (
+            (not placeholder_only) ^ is_only_operator_on_placeholder(to_copy_node)
        )
-        return False
+
+    return validator
 
 
 @dynamo_tensorrt_converter(
-    torch.ops.aten._to_copy.default, capability_validator=to_copy_dtype_validator
+    torch.ops.aten.clone.default,
+    capability_validator=lambda node: not is_only_operator_on_placeholder(node),
 )  # type: ignore[misc]
-def aten_ops_to_copy_dtype(
+@dynamo_tensorrt_converter(
+    torch.ops.aten._to_copy.default,
+    capability_validator=to_copy_dtype_validator(placeholder_only=False),
+)  # type: ignore[misc]
+def aten_ops_clone_copy_dtype(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
@@ -476,24 +509,37 @@ def aten_ops_to_copy_dtype(
         SourceIR.ATEN,
         name,
         args[0],
-        kwargs["dtype"],
+        kwargs.get("dtype", args[0].dtype),
+        force_layer=False,
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.clone.default)  # type: ignore[misc]
-def aten_ops_clone(
+@dynamo_tensorrt_converter(
+    torch.ops.aten.clone.default,
+    capability_validator=is_only_operator_on_placeholder,
+)  # type: ignore[misc]
+@dynamo_tensorrt_converter(
+    torch.ops.aten._to_copy.default,
+    capability_validator=to_copy_dtype_validator(placeholder_only=True),
+)  # type: ignore[misc]
+def aten_ops_clone_copy_placeholder(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.cast.clone(
+    # For clone or copy nodes where the input is also the output,
+    # we need to force cast to ensure a layer is added to the TRT engine
+    # since TRT engine inputs cannot also be TRT engine outputs
+    return impl.cast.to_copy(
         network,
         target,
         SourceIR.ATEN,
         name,
         args[0],
+        kwargs.get("dtype", args[0].dtype),
+        force_layer=True,
     )
 
 
diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
index 1d8dfecf3b..99cf2fa85a 100644
--- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py
+++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -45,6 +45,20 @@ def get_node_name(node: torch.fx.Node) -> str:
     return node_name
 
 
+def is_only_operator_on_placeholder(node: torch.fx.Node) -> bool:
+    """Detects whether a call_function node is the only operator on a placeholder"""
+    # Returns true if the node operates on a placeholder and is a direct output
+    return (
+        node.op == "call_function"
+        and any(
+            arg.op == "placeholder"
+            for arg in node.args
+            if isinstance(arg, torch.fx.Node)
+        )
+        and any(user.op == "output" for user in list(node.users.keys()))
+    )
+
+
 def dynamic_unsupported(node: torch.fx.Node) -> bool:
     # Validate that none of the inputs to the node have Dynamic shapes
     assert isinstance(
@@ -52,12 +66,17 @@ def dynamic_unsupported(node: torch.fx.Node) -> bool:
     ), "Inputs to validator functions must be FX Nodes"
 
     # Check node value itself
-    if getattr(node.meta["val"], "_has_symbolic_sizes_strides", False):
+    if ("val" in node.meta) and getattr(
+        node.meta["val"], "_has_symbolic_sizes_strides", False
+    ):
         return False
 
     # Check node arguments individually
     if any(
-        getattr(arg.meta["val"], "_has_symbolic_sizes_strides", False)
+        (
+            ("val" in arg.meta)
+            and getattr(arg.meta["val"], "_has_symbolic_sizes_strides", False)
+        )
         for arg in node.args
         if isinstance(arg, torch.fx.Node)
     ):
@@ -65,7 +84,10 @@ def dynamic_unsupported(node: torch.fx.Node) -> bool:
 
     # Check node keyword arguments individually
     if any(
-        getattr(kwarg.meta["val"], "_has_symbolic_sizes_strides", False)
+        (
+            ("val" in kwarg.meta)
+            and getattr(kwarg.meta["val"], "_has_symbolic_sizes_strides", False)
+        )
         for kwarg in node.kwargs.values()
         if isinstance(kwarg, torch.fx.Node)
     ):
@@ -82,9 +104,12 @@ def cast_trt_tensor(
     target: Target = "",
     source_ir: Optional[SourceIR] = None,
 ) -> TRTTensor:
-    """
-    Given a TRT Tensor, convert that Tensor to the specified dtype
+    """Given a TRT Tensor, convert that Tensor to the specified dtype
 
+    Adds an Identity layer to the network which performs the conversion
+    if the input's dtype is different from the cast type. Otherwise, the
+    input is returned unchanged.
+
     Args:
         network (TRTNetwork): A TensorRT network
         input_val (TRTTensor): A TRT Tensor to cast to a new data type
@@ -191,7 +216,7 @@ def extend_attr_to_tuple(
     if isinstance(val, tuple):
         return val
     else:
-        raise AssertionError(f"Could not extend attribute {val}")
+        raise AssertionError(f"Object {val} could not be extended to tuple")
 
 
 def cast_int_or_float_to_bool(
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/cast.py b/py/torch_tensorrt/dynamo/conversion/impl/cast.py
index 0c55731169..f31fd9a396 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/cast.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/cast.py
@@ -3,7 +3,12 @@
 
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion.converter_registry import ConverterRegistry
 from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
+from torch_tensorrt.fx.converters.converter_utils import (
+    Frameworks,
+    unified_dtype_converter,
+)
 from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor
 
 LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -16,28 +21,25 @@ def to_copy(
     name: str,
     input: TRTTensor,
     dtype: TRTDataType,
+    force_layer: bool = False,
 ) -> TRTTensor:
     if not isinstance(input, TRTTensor):
         raise RuntimeError(
             f"to_copy received input {input} that is not a TensorRT ITensor"
         )
 
-    casted_tensor = cast_trt_tensor(network, input, dtype, name, target, source_ir)
-    return casted_tensor
-
-
-def clone(
-    network: TRTNetwork,
-    target: Target,
-    source_ir: Optional[SourceIR],
-    name: str,
-    input: TRTTensor,
-) -> TRTTensor:
-    if not isinstance(input, TRTTensor):
-        raise RuntimeError(
-            f"clone received input {input} that is not a TensorRT ITensor"
-        )
-
-    LOGGER.debug(f"Evaluating clone on object with name: {name}")
-
-    return input
+    # If the cast is forced, insert an identity layer regardless of whether
+    # the dtype actually changes
+    if force_layer:
+        trt_dtype = unified_dtype_converter(dtype, Frameworks.TRT)
+        source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN
+        target_str = ConverterRegistry.qualified_name_or_str(target)
+        target_name = f"{source_ir}_ops{('.' + target_str) if target_str else ''}"
+
+        identity_layer = network.add_identity(input)
+        identity_layer.set_output_type(0, trt_dtype)
+        identity_layer.name = f"Forced Cast ITensor {input.name} from {input.dtype} to {trt_dtype} - [{target_name}]-[{name}]"
+        return identity_layer.get_output(0)
+    else:
+        casted_tensor = cast_trt_tensor(network, input, dtype, name, target, source_ir)
+        return casted_tensor
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
index 95dcd88a75..b2176653d1 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
@@ -11,11 +11,7 @@
     cast_trt_tensor,
     get_trt_tensor,
 )
-from torch_tensorrt.fx.converters.converter_utils import (
-    broadcast,
-    set_layer_name,
-    squeeze_left,
-)
+from torch_tensorrt.fx.converters.converter_utils import broadcast, set_layer_name
 from torch_tensorrt.fx.types import TRTElementWiseOp, TRTNetwork, TRTTensor
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
 
@@ -96,10 +92,10 @@ def convert_binary_elementwise(
     is_rhs_trt_tensor = False
 
     if isinstance(lhs_val, TRTTensor):
-        lhs_dtype = unified_dtype_converter(lhs_val.dtype, Frameworks.NUMPY)
+        lhs_dtype = lhs_val.dtype
         is_lhs_trt_tensor = True
     if isinstance(rhs_val, TRTTensor):
-        rhs_dtype = unified_dtype_converter(rhs_val.dtype, Frameworks.NUMPY)
+        rhs_dtype = rhs_val.dtype
         is_rhs_trt_tensor = True
 
     if not is_lhs_trt_tensor and not is_rhs_trt_tensor:
@@ -124,23 +120,13 @@
     # dtype but we don't have a way to detect whether it makes sense for the
     # scalar to be float or half. Hence we go with the lhs dtype.
     if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)):
-        rhs_val = np.array([rhs_val], dtype=lhs_dtype)
+        rhs_val = np.array(
+            [rhs_val], dtype=unified_dtype_converter(lhs_dtype, Frameworks.NUMPY)
+        )
     if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)):
-        lhs_val = np.array([lhs_val], dtype=rhs_dtype)
-
-    # When lhs is scalar, and rhs has shape [1,], then currently the assert
-    # will fail because lhs shape has fewer dimensions than rhs shape. This
-    # happens when using implicit batch dimension, when we removed the 1st
-    # dimension from input tensor, causing it to have shape [] - a scalar. We
-    # fix it by reducing the rhs constant with a squeeze_left, so it becomes a
-    # scalar too. More generally, we squeeze_left on input if it's a constant
-    # tensor. This is safe because broadcast will pad dimensions on the left
-    # (prepend) to make lhs and rhs shape compatible.
-    if network.has_implicit_batch_dimension:
-        if isinstance(lhs_val, torch.Tensor):
-            lhs_val = squeeze_left(lhs_val)
-        if isinstance(rhs_val, torch.Tensor):
-            rhs_val = squeeze_left(rhs_val)
+        lhs_val = np.array(
+            [lhs_val], dtype=unified_dtype_converter(rhs_dtype, Frameworks.NUMPY)
+        )
 
     lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", lhs_dtype)
     rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs", rhs_dtype)
diff --git a/tests/py/dynamo/conversion/test_casts.py b/tests/py/dynamo/conversion/test_casts.py
index d1893b1c46..50d94713c5 100644
--- a/tests/py/dynamo/conversion/test_casts.py
+++ b/tests/py/dynamo/conversion/test_casts.py
@@ -35,6 +35,19 @@ def forward(self, x):
             disable_passes=True,
         )
 
+    def test_clone_direct(self):
+        class Clone(nn.Module):
+            def forward(self, x):
+                return x.clone()
+
+        inputs = [torch.randn((8, 2, 10))]
+        self.run_test(
+            Clone(),
+            inputs,
+            expected_ops={torch.ops.aten.clone.default},
+            disable_passes=True,
+        )
+
 
 class TestToCopyConverter(DispatchTestCase):
     def test_to_copy_half(self):
@@ -83,6 +96,20 @@ def forward(self, x):
             disable_passes=True,
         )
 
+    def test_to_copy_direct(self):
+        class ToCopyFloat(nn.Module):
+            def forward(self, x):
+                return x.to(dtype=torch.float, copy=True)
+
+        inputs = [torch.rand((1, 3, 10)).float()]
+        self.run_test(
+            ToCopyFloat(),
+            inputs,
+            expected_ops={torch.ops.aten._to_copy.default},
+            precision=torch.float,
+            disable_passes=True,
+        )
+
 
 if __name__ == "__main__":
     run_tests()
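
A minimal illustrative sketch (not part of the patch; module names and shapes are hypothetical) of the two graph shapes the converter registrations above distinguish:

# Which converter handles clone/_to_copy depends on whether the node is the
# only operator between a placeholder and the graph output.
import torch
import torch.nn as nn


class CloneOnly(nn.Module):
    # clone consumes the placeholder and feeds the output directly, so
    # is_only_operator_on_placeholder would be True and the placeholder-specific
    # converter (force_layer=True) inserts an identity layer, keeping the
    # TRT engine input distinct from its output.
    def forward(self, x):
        return x.clone()


class CloneThenAdd(nn.Module):
    # clone feeds further computation, so the standard converter applies and
    # no layer needs to be forced when the dtype is unchanged.
    def forward(self, x):
        return x.clone() + 1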