diff --git a/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py b/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py index 4464555261..795b42f879 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py @@ -6,7 +6,7 @@ from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( clean_up_graph_after_modifications, ) -from torch_tensorrt.dynamo.utils import get_metadata, set_metadata +from torch_tensorrt.dynamo.utils import copy_metadata logger = logging.getLogger(__name__) @@ -26,14 +26,14 @@ def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor: def replacement(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor: return replacement_op(input, shape) - # Store metadata of the orig_op - metadata = get_metadata(gm, orig_op) - - if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement): + match_and_replacements = torch.fx.subgraph_rewriter._replace_pattern( + gm, orig, replacement + ) + if match_and_replacements: gm = clean_up_graph_after_modifications(gm) logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}") # Copy the orig_op's metadata to the replacement op - set_metadata(gm, replacement_op, metadata) + copy_metadata(match_and_replacements) return gm diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 95e5f30e4d..85b0163f4b 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -11,8 +11,6 @@ import tensorrt as trt import torch from torch._subclasses.fake_tensor import FakeTensor - -from packaging import version from torch_tensorrt._Device import Device from torch_tensorrt._enums import dtype from torch_tensorrt._features import ENABLED_FEATURES @@ -22,6 +20,8 @@ from torch_tensorrt.dynamo._engine_cache import BaseEngineCache from torch_tensorrt.dynamo._settings import CompilationSettings +from packaging import version + from .types import TRTDataType logger = logging.getLogger(__name__) @@ -720,6 +720,20 @@ def set_metadata( node.meta = metadata[idx] +def copy_metadata(match_and_replacements: List[Any]) -> None: + """ + Copy the metadata from anchor node to the replacement node. This should be used + if the anchor node is replaced with only a single replacement node i.e one-one replacement. + """ + for match_and_replacement in match_and_replacements: + anchor_node = match_and_replacement.nodes_map[match_and_replacement.anchor] + assert ( + len(match_and_replacement.replacements) == 1 + ), "Found more than 1 replacements for the anchor node." + replacement_node = match_and_replacement.replacements[0] + replacement_node.meta = anchor_node.meta + + def flatten_nodes(nodes: Any) -> List[torch.fx.node.Node]: ret = [] if isinstance(nodes, torch.fx.node.Node):