diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 4c75a38c66..20a6acb7ff 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -10,3 +10,4 @@ OPTIMIZATION_LEVEL = None TRUNCATE_LONG_AND_DOUBLE = False USE_PYTHON_RUNTIME = False +USE_FAST_PARTITIONER = True diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index e0eef45eb2..4be44cd779 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -10,6 +10,7 @@ PASS_THROUGH_BUILD_FAILURES, PRECISION, TRUNCATE_LONG_AND_DOUBLE, + USE_FAST_PARTITIONER, USE_PYTHON_RUNTIME, VERSION_COMPATIBLE, WORKSPACE_SIZE, @@ -29,3 +30,4 @@ class CompilationSettings: optimization_level: Optional[int] = OPTIMIZATION_LEVEL use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE + use_fast_partitioner: bool = USE_FAST_PARTITIONER diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index ca14ad264b..65f0907464 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -5,13 +5,12 @@ import torch import torch._dynamo as td from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler -from torch_tensorrt.dynamo import CompilationSettings +from torch_tensorrt.dynamo import CompilationSettings, partitioning from torch_tensorrt.dynamo.conversion import ( convert_module, repair_long_or_double_inputs, ) from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions -from torch_tensorrt.dynamo.lowering._partition import get_submod_inputs, partition from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs @@ -109,24 +108,68 @@ def _compile_module( Returns: Compiled FX GraphModule """ - # Partition module into components that can be TRT-accelerated - partitioned_module = partition( - gm, - verbose=settings.debug, - min_block_size=settings.min_block_size, - torch_executed_ops=settings.torch_executed_ops, + # Check the number of supported operations in the graph + num_supported_ops, total_ops = partitioning.get_graph_converter_support( + gm, settings.debug, settings.torch_executed_ops ) + # If the number of supported operations is 0 or less than the block size, skip the subgraph + # TODO: Add condition to second expression below when require_full_compilation is added + if num_supported_ops == 0 or (num_supported_ops < settings.min_block_size): + logger.warning( + f"{num_supported_ops} supported operations detected in subgraph containing {total_ops} computational nodes. " + f"Skipping this subgraph, since min_block_size was detected to be {settings.min_block_size}" + ) + return gm + else: + logger.debug( + f"Detected support for {num_supported_ops} operators out of {total_ops} in subgraph." 
+ ) + + # Partition module into components that can be TRT-accelerated + fast_partitioner_failed = False + + # If specified, try using the fast partitioner and fall back to the global one on failure + if settings.use_fast_partitioner: + try: + partitioned_module = partitioning.fast_partition( + gm, + verbose=settings.debug, + min_block_size=settings.min_block_size, + torch_executed_ops=settings.torch_executed_ops, + ) + except torch.fx.passes.splitter_base.FxNetSplitterInternalError: + logger.error( + "Partitioning failed on the subgraph with fast partition. See trace above. " + + "Retrying with global partition.", + exc_info=True, + ) + + fast_partitioner_failed = True + settings.use_fast_partitioner = False + + if not settings.use_fast_partitioner: + partitioned_module = partitioning.global_partition( + gm, + verbose=settings.debug, + min_block_size=settings.min_block_size, + torch_executed_ops=settings.torch_executed_ops, + ) + # Store TRT replicas of Torch subgraphs trt_modules = {} # Iterate over all components that can be accelerated # Generate the corresponding TRT Module for those for name, _ in partitioned_module.named_children(): + # Criteria for a module to be convertible to TRT + if settings.use_fast_partitioner and "_run_on_acc" not in name: + continue + submodule = getattr(partitioned_module, name) # Get submodule inputs - submodule_inputs = get_submod_inputs( + submodule_inputs = partitioning.get_submod_inputs( partitioned_module, submodule, sample_inputs ) @@ -151,4 +194,8 @@ def _compile_module( for name, trt_mod in trt_modules.items(): setattr(partitioned_module, name, trt_mod) + # Reset settings object to user specification after fallback to global partitioning mode + if fast_partitioner_failed: + settings.use_fast_partitioner = True + return partitioned_module diff --git a/py/torch_tensorrt/dynamo/compile.py b/py/torch_tensorrt/dynamo/compile.py index 0402a6af43..6df1420074 100644 --- a/py/torch_tensorrt/dynamo/compile.py +++ b/py/torch_tensorrt/dynamo/compile.py @@ -5,7 +5,6 @@ import torch import torch_tensorrt from torch.fx.passes.pass_manager import PassManager -from torch.fx.passes.splitter_base import SplitResult from torch_tensorrt._Device import Device from torch_tensorrt._enums import ( # TODO: Should probabably be the TRT EngineCapability Enum EngineCapability, @@ -19,18 +18,17 @@ PASS_THROUGH_BUILD_FAILURES, PRECISION, TRUNCATE_LONG_AND_DOUBLE, + USE_FAST_PARTITIONER, USE_PYTHON_RUNTIME, VERSION_COMPATIBLE, WORKSPACE_SIZE, ) from torch_tensorrt.dynamo.backend.backends import _compile_module -from torch_tensorrt.dynamo.conversion import convert_module from torch_tensorrt.dynamo.lowering._fusers import ( fuse_permute_linear, fuse_permute_matmul, ) from torch_tensorrt.dynamo.utils import prepare_device, prepare_inputs -from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting logger = logging.getLogger(__name__) @@ -62,6 +60,7 @@ def compile( version_compatible: bool = VERSION_COMPATIBLE, optimization_level: Optional[int] = OPTIMIZATION_LEVEL, use_python_runtime: bool = USE_PYTHON_RUNTIME, + use_fast_partitioner: bool = USE_FAST_PARTITIONER, **kwargs: Any, ) -> torch.fx.GraphModule: if debug: @@ -73,7 +72,7 @@ def compile( "The Dynamo backend is an experimental feature, for which only the " + "following arguments are supported: " + "{enabled_precisions, debug, workspace_size, min_block_size, " - + "torch_executed_ops, pass_through_build_failures}" + + "torch_executed_ops, pass_through_build_failures, use_fast_partitioner}" ) if not 
isinstance(inputs, collections.abc.Sequence): @@ -113,55 +112,12 @@ def compile( "optimization_level": optimization_level, "use_python_runtime": use_python_runtime, "truncate_long_and_double": truncate_long_and_double, + "use_fast_partitioner": use_fast_partitioner, } settings = CompilationSettings(**compilation_options) - if kwargs.get("use_capability_partitioner", None): - model = lower_model(gm, torch_inputs) - return _compile_module(model, torch_inputs, settings) - else: - split_result = lower_model_using_trt_splitter(gm, torch_inputs) - trt_module = _compile_graph(split_result, torch_inputs, settings) - - return trt_module - - -def _compile_graph( - split_result: SplitResult, - inputs: Any, - settings: CompilationSettings = CompilationSettings(), - **kwargs: Any, -) -> torch.fx.GraphModule: - for submod_name, submod_inputs in split_result.submodule_inputs.items(): - submod = getattr(split_result.split_module, submod_name) - # Only acc submodules will be lowered. - if not submod_name.startswith(split_result.non_acc_submodule_prefix): - # Create TRT Module from submodule - trt_mod = convert_module( - submod, - submod_inputs, - settings=settings, - name=submod_name, - ) - setattr(split_result.split_module, submod_name, trt_mod) - - return split_result.split_module - - -def lower_model_using_trt_splitter( - model: torch.nn.Module, inputs: Any, **kwargs: Any -) -> SplitResult: - # Perform basic lowering - model = lower_model(model, inputs) - splitter_setting = TRTSplitterSetting() - splitter_setting.use_implicit_batch_dim = False - splitter_setting.min_acc_module_size = 1 - splitter_setting.use_experimental_rt = False - splitter = TRTSplitter(model, inputs, settings=splitter_setting) - splitter.node_support_preview() - split_result = splitter.generate_split_results() - return split_result + return _compile_module(gm, torch_inputs, settings) def lower_model( diff --git a/py/torch_tensorrt/dynamo/conversion/converter_registry.py b/py/torch_tensorrt/dynamo/conversion/converter_registry.py index 7275844500..47e8621ac3 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_registry.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_registry.py @@ -347,8 +347,8 @@ def unique_targets(self) -> Set[Target]: """Returns the set of unique converter targets stored across all registries""" return set.union(*[set(registry.keys()) for registry in self.registries]) - # TODO: Make this a static method since it does not need state - def qualified_name_or_str(self, target: Target) -> str: + @staticmethod + def qualified_name_or_str(target: Target) -> str: """Returns string representation of an FX Node target""" if isinstance(target, str): return target diff --git a/py/torch_tensorrt/dynamo/lowering/__init__.py b/py/torch_tensorrt/dynamo/lowering/__init__.py index 0e13125fc6..6eda61a6fd 100644 --- a/py/torch_tensorrt/dynamo/lowering/__init__.py +++ b/py/torch_tensorrt/dynamo/lowering/__init__.py @@ -1,10 +1,5 @@ from ._decompositions import get_decompositions # noqa: F401 -from ._fusers import * # noqa: F403 -from ._partition import ( # noqa: F401 - DEFAULT_SINGLE_NODE_PARTITIONS, - get_submod_inputs, - partition, -) +from ._fusers import * # noqa: F401 from ._pre_aot_lowering import SUBSTITUTION_REGISTRY # noqa: F401 from ._pre_aot_lowering import register_substitution # noqa: F401 -from .substitutions import * # noqa: F403 +from .substitutions import * # noqa: F401 diff --git a/py/torch_tensorrt/dynamo/partitioning/__init__.py b/py/torch_tensorrt/dynamo/partitioning/__init__.py new file mode 100644 
index 0000000000..1f9d11b14b --- /dev/null +++ b/py/torch_tensorrt/dynamo/partitioning/__init__.py @@ -0,0 +1,3 @@ +from ._adjacency_partitioner import partition as fast_partition +from ._global_partitioner import partition as global_partition +from .common import get_graph_converter_support, get_submod_inputs diff --git a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py new file mode 100644 index 0000000000..481c851916 --- /dev/null +++ b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py @@ -0,0 +1,246 @@ +import logging +from typing import Collection, Dict, List, Optional, Tuple + +import torch +import torch.fx.passes.operator_support as ops +from torch.fx.node import Target +from torch.fx.passes.splitter_base import ( + FxNetAccFusionsFinder, + FxNetAccNodesFinder, + Subgraph, + _SplitterBase, + _SplitterSettingBase, +) +from torch.fx.passes.tools_common import CALLABLE_NODE_OPS, NodeSet +from torch_tensorrt.dynamo._defaults import DEBUG, MIN_BLOCK_SIZE +from torch_tensorrt.dynamo.conversion.converter_registry import ( + DYNAMO_CONVERTERS as CONVERTERS, +) +from torch_tensorrt.dynamo.conversion.converter_registry import ConverterRegistry + +from .common import DEFAULT_SINGLE_NODE_PARTITIONS + +logger = logging.getLogger(__name__) + + +class OpSupportTester(ops.OperatorSupportBase): # type: ignore + """Class to determine whether operators within a module are supported""" + + def __init__(self, torch_executed_ops: Collection[Target] = set()) -> None: + super().__init__() + + # Initialize sets of supported/unsupported operators + self.supported_operators: Dict[str, int] = {} + self.unsupported_operators: Dict[str, int] = {} + self.torch_executed_ops = torch_executed_ops + + def is_node_supported( + self, submodules: Dict[str, torch.nn.Module], node: torch.fx.Node + ) -> bool: + node_name = ConverterRegistry.qualified_name_or_str(node.target) + + if node in CONVERTERS and node_name not in self.torch_executed_ops: + # If node is a proper, supported computational node, store the operator + if not node.is_impure(): + if node_name not in self.supported_operators: + self.supported_operators[node_name] = 1 + else: + self.supported_operators[node_name] += 1 + + return True + else: + if not node.is_impure(): + if node_name not in self.unsupported_operators: + self.unsupported_operators[node_name] = 1 + else: + self.unsupported_operators[node_name] += 1 + + return False + + def print_support_overview(self, num_trt_blocks: Optional[int] = None) -> None: + if num_trt_blocks is not None: + logger.debug( + f"\nNumber of TensorRT-Accelerated Engines Generated: {num_trt_blocks}" + ) + + # Reformat support messages for debugger to print node overview as a single string + supported_nodes_str = "\nSupported Nodes:\n" + for node_name, count in self.supported_operators.items(): + supported_nodes_str += f"- {node_name} + Operator Count: {count}\n" + + logger.debug(supported_nodes_str) + + if self.unsupported_operators: + unsupported_nodes_str = "\nUnsupported or Excluded Nodes:\n" + for node_name, count in self.unsupported_operators.items(): + unsupported_nodes_str += f"- {node_name} + Operator Count: {count}\n" + + logger.debug(unsupported_nodes_str) + else: + logger.debug("\nAll Nodes Supported\n") + + +class TRTPartitioner(_SplitterBase): # type: ignore + """Partitioner to split an FX graph into subgraphs based on operator support + + Adapted from, and modified for the Torch-TensorRT Dynamo case: + 
https://github.com/pytorch/pytorch/blob/93f538db355ea10c684a57f7a632ed03292ef98f/torch/fx/passes/splitter_base.py#L256C9-L871 + + Args: + module: FX GraphModule to partition + operator_support: OperatorSupport class describing allowed operators + allowed_single_node_partition_ops: Nodes which can be included in single-node partitons. + Generally useful for module-level exclusion ops which are intensive despite being single functions + min_block_size: Minimum number of computational operators per block + Returns: + torch.fx.GraphModule + """ + + def __init__( + self, + module: torch.fx.GraphModule, + operator_support: ops.OperatorSupportBase, + allowed_single_node_partition_ops: Optional[ + Collection[str] + ] = DEFAULT_SINGLE_NODE_PARTITIONS, + min_block_size: int = MIN_BLOCK_SIZE, + ): + """ + Preprocesses graph before splitting: + - finds nodes supported by ACC, + - finds fusion groups for ACC nodes having non-tensor IO, + - builds a graph of direct dependencies, + - builds a map of fused nodes to their fusions. + As a result we get self.acc_nodes, self.deps and self.fusions. + """ + assert isinstance(module, torch.fx.GraphModule) + + self.module = module + + self.settings = _SplitterSettingBase( + min_acc_module_size=min_block_size, + allow_non_tensor=True, + ) + self.operator_support = operator_support + + # Get all accelerated nodes based on operator support conditions + self.acc_nodes = FxNetAccNodesFinder( + self.module, self.operator_support, self.settings.allow_non_tensor + )() + + if self.settings.skip_fusion: + self.fusions = {} + else: + self.fusions = FxNetAccFusionsFinder(module, set(self.acc_nodes))() + + # Modify deps to add more deps for fused nodes + self.deps = self.find_deps() + self.update_deps_for_fusions() + + self.non_acc_submodule_name = "_run_on_gpu_" + self._node_submodule_map: Dict[str, str] = {} + + self.num_trt_accelerated_subgraphs: Optional[int] = None + self.allowed_single_node_partition_ops = allowed_single_node_partition_ops + + def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]: + """ + This pass finds ACC submodules with less than specified size and merges + them with adjacent GPU submodules. 
+ """ + result: List[Subgraph] = [] + for subgraph in subgraphs: + if subgraph.is_acc: + if len(subgraph.nodes) >= self.settings.min_acc_module_size or ( + self.allowed_single_node_partition_ops is not None + and any( + ConverterRegistry.qualified_name_or_str(node.target) + in self.allowed_single_node_partition_ops + for node in subgraph.nodes + ) + ): + result.append(subgraph) + else: + logger.debug( + "Eliminating acc subgraph because it's smaller than the threshold: " + f"{len(subgraph.nodes)} < {self.settings.min_acc_module_size}" + ) + if result: + result[-1].nodes.extend(subgraph.nodes) + else: + subgraph.is_acc = False + result.append(subgraph) + else: + if result and not result[-1].is_acc: + result[-1].nodes.extend(subgraph.nodes) + else: + result.append(subgraph) + return result + + def partition_graph(self) -> torch.fx.GraphModule: + """Partitions the GraphModule into subgraphs based on operator support + + Returns a GraphModule with submodules for each segment + """ + # Delegate nodes based on operator coverage + subgraphs = self.put_nodes_into_subgraphs() + + # Remove segments smaller than the block size (with exceptions) + subgraphs = self.remove_small_acc_subgraphs(subgraphs) + + # Set the number of TRT engines to be generated + self.num_trt_accelerated_subgraphs = len([s for s in subgraphs if s.is_acc]) + + # Tag the accelerated nodes and split the graph accordingly + self.tag(subgraphs) + return self.split() + + def starter_nodes(self) -> Tuple[NodeSet, NodeSet]: + """Generates starter nodes for partitioning + segmentation""" + # Starter accelerated nodes are all callable accelerated ops + starter_acc_nodes = { + node for node in self.acc_nodes if node.op in CALLABLE_NODE_OPS + } + + # Started non-accelerated nodes are the rest of the callable nodes + starter_non_acc_nodes = { + node + for node in self.module.graph.nodes + if (node not in starter_acc_nodes and node.op in CALLABLE_NODE_OPS) + } + + return starter_non_acc_nodes, starter_acc_nodes + + +def partition( + gm: torch.fx.GraphModule, + verbose: bool = DEBUG, + min_block_size: int = MIN_BLOCK_SIZE, + torch_executed_ops: Collection[Target] = set(), +) -> torch.fx.GraphModule: + """Partition an FX GraphModule with aten ops into TRT engines + Partitioning is based on converter operator support + + Args: + gm: FX GraphModule to partition + verbose: Bool representing whether to print operator support + min_block_size: Minimum number of operators per TRT-Engine Block + torch_executed_ops: Collection of operations to run in Torch, regardless of converter coverage + Returns: + torch.fx.GraphModule + """ + # Ensure graph is clean prior to partitioning + gm.graph.eliminate_dead_code() + gm.graph.lint() + gm.recompile() + + # Construct + supported_ops = OpSupportTester(torch_executed_ops=torch_executed_ops) + partitioner = TRTPartitioner(gm, supported_ops, min_block_size=min_block_size) + + partitioned_graph = partitioner.partition_graph() + + if verbose: + supported_ops.print_support_overview(partitioner.num_trt_accelerated_subgraphs) + + return partitioned_graph diff --git a/py/torch_tensorrt/dynamo/lowering/_partition.py b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py similarity index 74% rename from py/torch_tensorrt/dynamo/lowering/_partition.py rename to py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py index 246549461a..500b6bb9fe 100644 --- a/py/torch_tensorrt/dynamo/lowering/_partition.py +++ b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py @@ -1,23 +1,19 @@ import logging -from 
typing import Dict, List, Mapping, Optional, Sequence, Set +from typing import Collection, Dict, List, Mapping, Optional, Sequence, Set import torch from torch.fx.graph_module import GraphModule -from torch.fx.node import _get_qualified_name from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition from torch.fx.passes.operator_support import OperatorSupport, SupportDict -from torch_tensorrt.dynamo._defaults import MIN_BLOCK_SIZE +from torch_tensorrt.dynamo._defaults import DEBUG, MIN_BLOCK_SIZE from torch_tensorrt.dynamo.conversion.converter_registry import ( DYNAMO_CONVERTERS as CONVERTERS, ) -from torch_tensorrt.dynamo.lowering._pre_aot_lowering import SUBSTITUTION_REGISTRY +from torch_tensorrt.dynamo.conversion.converter_registry import ConverterRegistry -logger = logging.getLogger(__name__) +from .common import DEFAULT_SINGLE_NODE_PARTITIONS -DEFAULT_SINGLE_NODE_PARTITIONS: List[str] = [ - _get_qualified_name(to_replace.new_operator) - for to_replace in SUBSTITUTION_REGISTRY.values() -] +logger = logging.getLogger(__name__) class TRTPartitioner(CapabilityBasedPartitioner): # type: ignore[misc] @@ -41,7 +37,7 @@ def __init__( *, non_compute_ops: Optional[Sequence[str]] = None, allowed_single_node_partition_ops: Optional[ - Sequence[str] + Collection[str] ] = DEFAULT_SINGLE_NODE_PARTITIONS, min_block_size: int = MIN_BLOCK_SIZE, ) -> None: @@ -73,14 +69,15 @@ def propose_partitions(self) -> List[Partition]: # Partitions are exempted from min_block_size if they contain an allowed single-node op if ( node.op == "call_function" - and _get_qualified_name(node.target) + and ConverterRegistry.qualified_name_or_str(node.target) in self.allowed_single_node_partition_ops ): exempted_partition = True break elif ( node.op == "call_function" - and _get_qualified_name(node.target) not in non_compute_ops + and ConverterRegistry.qualified_name_or_str(node.target) + not in non_compute_ops ): compute_node_count += 1 @@ -122,11 +119,7 @@ def __init__( def is_node_supported( self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node ) -> bool: - node_name = ( - _get_qualified_name(node.target) - if not isinstance(node.target, str) - else node.target - ) + node_name = ConverterRegistry.qualified_name_or_str(node.target) if node in CONVERTERS and node_name not in self.torch_executed_ops: # If node is a proper, supported computational node, store the operator @@ -146,32 +139,37 @@ def is_node_supported( return False - def print_support_overview(self, num_trt_blocks: Optional[int] = None) -> None: + def print_support_overview( + self, num_trt_blocks: Optional[int] = None, print_node_support: bool = False + ) -> None: if num_trt_blocks is not None: logger.debug( f"\nNumber of TensorRT-Accelerated Engines Generated: {num_trt_blocks}" ) - # Reformat support messages for debugger to print node overview as a single string - supported_nodes_str = "\nSupported Nodes:\n" - for node_name, count in self.supported_operators.items(): - supported_nodes_str += f"- {node_name} + Operator Count: {count}\n" + if print_node_support: + # Reformat support messages for debugger to print node overview as a single string + supported_nodes_str = "\nSupported Nodes:\n" + for node_name, count in self.supported_operators.items(): + supported_nodes_str += f"- {node_name} + Operator Count: {count}\n" - logger.debug(supported_nodes_str) + logger.debug(supported_nodes_str) - if self.unsupported_operators: - unsupported_nodes_str = "\nUnsupported or Excluded Nodes:\n" - for node_name, count in 
self.unsupported_operators.items(): - unsupported_nodes_str += f"- {node_name} + Operator Count: {count}\n" + if self.unsupported_operators: + unsupported_nodes_str = "\nUnsupported or Excluded Nodes:\n" + for node_name, count in self.unsupported_operators.items(): + unsupported_nodes_str += ( + f"- {node_name} + Operator Count: {count}\n" + ) - logger.debug(unsupported_nodes_str) - else: - logger.debug("\nAll Nodes Supported\n") + logger.debug(unsupported_nodes_str) + else: + logger.debug("\nAll Nodes Supported\n") def partition( gm: torch.fx.GraphModule, - verbose: bool = True, + verbose: bool = DEBUG, min_block_size: int = MIN_BLOCK_SIZE, torch_executed_ops: Optional[Set[str]] = None, ) -> torch.fx.GraphModule: @@ -202,29 +200,3 @@ def partition( supported_ops.print_support_overview(len(partitions)) return fused_graph - - -def get_submod_inputs( - mod: torch.fx.GraphModule, - submod: torch.fx.GraphModule, - inputs: Sequence[torch.Tensor], -) -> Optional[Sequence[torch.Tensor]]: - """Helper function to get inputs to a Torch submodule - - Args: - mod: Parent FX GraphModule - submod: Child FX GraphModule - inputs: Sample inputs to parent module - Returns: - Sequence of Tensors representing inputs to child module - """ - acc_inputs: Optional[Sequence[torch.Tensor]] = None - - def get_input(_: torch.fx.GraphModule, inputs: Sequence[torch.Tensor]) -> None: - nonlocal acc_inputs - acc_inputs = inputs - - handle = submod.register_forward_pre_hook(get_input) - mod(*inputs) - handle.remove() - return acc_inputs diff --git a/py/torch_tensorrt/dynamo/partitioning/common.py b/py/torch_tensorrt/dynamo/partitioning/common.py new file mode 100644 index 0000000000..8c36668d00 --- /dev/null +++ b/py/torch_tensorrt/dynamo/partitioning/common.py @@ -0,0 +1,76 @@ +from typing import Any, Optional, Sequence, Set, Tuple + +import torch +from torch.fx.node import _get_qualified_name +from torch_tensorrt.dynamo._defaults import DEBUG +from torch_tensorrt.dynamo.lowering import SUBSTITUTION_REGISTRY + +DEFAULT_SINGLE_NODE_PARTITIONS: Set[str] = { + _get_qualified_name(to_replace.new_operator) + for to_replace in SUBSTITUTION_REGISTRY.values() +} + + +def get_submod_inputs( + mod: torch.fx.GraphModule, + submod: torch.fx.GraphModule, + inputs: Sequence[torch.Tensor], +) -> Optional[Sequence[torch.Tensor]]: + """Helper function to get inputs to a Torch submodule + + Args: + mod: Parent FX GraphModule + submod: Child FX GraphModule + inputs: Sample inputs to parent module + Returns: + Sequence of Tensors representing inputs to child module + """ + acc_inputs = None + + def get_input(self: Any, inputs: Sequence[torch.Tensor]) -> None: + nonlocal acc_inputs + acc_inputs = inputs + return + + handle = submod.register_forward_pre_hook(get_input) + mod(*inputs) + handle.remove() + return acc_inputs + + +def get_graph_converter_support( + graph_module: torch.fx.GraphModule, + verbose: bool = DEBUG, + torch_executed_ops: Optional[Set[str]] = None, +) -> Tuple[int, int]: + """Helper function to get converter support overview pre-partitioning + + Args: + graph_module: FX GraphModule to determine support for + verbose: Bool representing whether to print operator support + torch_executed_ops: Collection of operations to run in Torch, regardless of converter coverage + Returns: + The number of supported call_function nodes in the graph + """ + from ._global_partitioner import TorchTensorRTOperatorSupport + + # Instantiate operator support object and module dictionary + op_support = 
TorchTensorRTOperatorSupport(torch_executed_ops=torch_executed_ops) + module_dict = dict(graph_module.named_modules()) + + number_of_supported_nodes = 0 + total_functional_nodes = 0 + + # Iterate over all nodes in the graph, enumerating call_function nodes + for node in graph_module.graph.nodes: + if node.op == "call_function": + total_functional_nodes += 1 + + if op_support.is_node_supported(module_dict, node): + number_of_supported_nodes += 1 + + # Print node support overview prior to partitioning + if verbose: + op_support.print_support_overview(print_node_support=True) + + return number_of_supported_nodes, total_functional_nodes diff --git a/setup.py b/setup.py index f9329cce7e..0964cc53fe 100644 --- a/setup.py +++ b/setup.py @@ -348,6 +348,7 @@ def run(self): "torch_tensorrt.dynamo.conversion.impl.unary", "torch_tensorrt.dynamo.lowering", "torch_tensorrt.dynamo.lowering.substitutions", + "torch_tensorrt.dynamo.partitioning", "torch_tensorrt.dynamo.runtime", "torch_tensorrt.dynamo.tools", "torch_tensorrt.fx", @@ -373,6 +374,7 @@ def run(self): "torch_tensorrt.dynamo.conversion.impl.unary": "py/torch_tensorrt/dynamo/conversion/impl/unary", "torch_tensorrt.dynamo.lowering": "py/torch_tensorrt/dynamo/lowering", "torch_tensorrt.dynamo.lowering.substitutions": "py/torch_tensorrt/dynamo/lowering/substitutions", + "torch_tensorrt.dynamo.partitioning": "py/torch_tensorrt/dynamo/partitioning", "torch_tensorrt.dynamo.runtime": "py/torch_tensorrt/dynamo/runtime", "torch_tensorrt.dynamo.tools": "py/torch_tensorrt/dynamo/tools", "torch_tensorrt.fx": "py/torch_tensorrt/fx", diff --git a/tests/py/dynamo/backend/test_backend_compiler.py b/tests/py/dynamo/backend/test_backend_compiler.py index 9c0cf18d25..043aa72276 100644 --- a/tests/py/dynamo/backend/test_backend_compiler.py +++ b/tests/py/dynamo/backend/test_backend_compiler.py @@ -1,6 +1,6 @@ import torch import torch_tensorrt -from torch_tensorrt.dynamo.lowering import partition +from torch_tensorrt.dynamo.partitioning import fast_partition from torch.testing._internal.common_utils import run_tests, TestCase from copy import deepcopy from utils import lower_graph_testing, DECIMALS_OF_AGREEMENT @@ -17,10 +17,16 @@ def forward(self, x, y): return torch.mean(out, dim=1) fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp()) - partitioned_graph = partition(deepcopy(fx_graph), min_block_size=3) + partitioned_graph = fast_partition(deepcopy(fx_graph), min_block_size=3) self.assertEquals( - len(list(partitioned_graph.named_children())), + len( + [ + 1 + for submod in list(partitioned_graph.named_children()) + if "_run_on_acc" in submod[0] + ] + ), 1, "All operators are supported, there should be one segment", ) @@ -98,7 +104,13 @@ def forward(self, x, y): "Without control flow breaks, there should only be a single graph", ) self.assertEquals( - len(list(partitioned_graphs[0].named_children())), + len( + [ + 1 + for submod in list(partitioned_graphs[0].named_children()) + if "_run_on_acc" in submod[0] + ] + ), 2, "Certain operators are set to run in Torch, expected 2 segments", ) @@ -184,7 +196,7 @@ def forward(self, x, y): ) fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp()) - partitioned_graph = partition(deepcopy(fx_graph), min_block_size=3) + partitioned_graph = fast_partition(deepcopy(fx_graph), min_block_size=3) self.assertEquals( len(list(partitioned_graph.named_children())), @@ -261,7 +273,13 @@ def forward(self, x, y): "Without control flow breaks, there should only be a single graph", ) self.assertEquals( - 
len(list(partitioned_graphs[0].named_children())), + len( + [ + 1 + for submod in list(partitioned_graphs[0].named_children()) + if "_run_on_acc" in submod[0] + ] + ), 1, "Certain operators are set to run in Torch, expected 1 segment", ) diff --git a/tests/py/dynamo/backend/test_decompositions.py b/tests/py/dynamo/backend/test_decompositions.py index 0e11bfd2b1..e7cd7ac589 100644 --- a/tests/py/dynamo/backend/test_decompositions.py +++ b/tests/py/dynamo/backend/test_decompositions.py @@ -45,7 +45,7 @@ def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) def forward(self, x): - y = torch.ops.aten.alias.default(x) + y = torch.ops.aten.alias.default(x) + 1 return y # Operations expected to be removed in the traced graph after decompositions diff --git a/tests/py/dynamo/backend/test_partitioning.py b/tests/py/dynamo/backend/test_partitioning.py index a5d6495754..c416c55e5c 100644 --- a/tests/py/dynamo/backend/test_partitioning.py +++ b/tests/py/dynamo/backend/test_partitioning.py @@ -1,12 +1,13 @@ -from torch_tensorrt.dynamo.lowering import partition -from torch.testing._internal.common_utils import run_tests, TestCase -from utils import lower_graph_testing -import torch from copy import deepcopy + import numpy as np +import torch +from torch.testing._internal.common_utils import TestCase, run_tests +from torch_tensorrt.dynamo import partitioning +from utils import lower_graph_testing -class TestPartitioning(TestCase): +class TestFastPartitioning(TestCase): def test_partition_fully_supported_one_op(self): class FullySupportedOneOp(torch.nn.Module): def __init__(self, *args, **kwargs) -> None: @@ -16,7 +17,154 @@ def forward(self, x, y): return torch.ops.aten.add.Tensor(x, y) fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp()) - partitioned_graph = partition(deepcopy(fx_graph)) + partitioned_graph = partitioning.fast_partition(deepcopy(fx_graph)) + self.assertEquals( + len( + [ + 1 + for submod in list(partitioned_graph.named_children()) + if "_run_on_acc" in submod[0] + ] + ), + 0, + "Single operators should not be segmented", + ) + + def test_partition_fully_supported_multi_op(self): + class FullySupportedMultiOp(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def forward(self, x, y): + sum_ = torch.ops.aten.sub.Tensor(x, y) + concat_ = torch.ops.aten.cat.default(x, sum_) + relu_ = torch.ops.aten.relu.default(concat_) + pow_ = torch.ops.aten.pow.Tensor_Scalar(relu_, 2) + return pow_ + + fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp()) + partitioned_graph = partitioning.fast_partition( + deepcopy(fx_graph), min_block_size=2 + ) + self.assertEquals( + len( + [ + 1 + for submod in list(partitioned_graph.named_children()) + if "_run_on_acc" in submod[0] + ] + ), + 1, + "All operators are supported, there should be one segment", + ) + + def test_partition_partially_supported_multi_op(self): + class PartiallySupportedMultiOp(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def forward(self, x, y): + sum_1 = torch.ops.aten.add.Tensor(x, y) + sum_2 = torch.ops.aten.add.Tensor(x, sum_1) + sum_ = np.sum(sum_1) + np.sum(sum_2) + relu_ = torch.ops.aten.relu.default(sum_) + pow_ = torch.ops.aten.pow.Tensor_Scalar(relu_, 2) + return pow_ + + fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp()) + partitioned_graph = partitioning.fast_partition( + deepcopy(fx_graph), min_block_size=2 + ) + self.assertEquals( + len( + [ + 1 + for submod in 
list(partitioned_graph.named_children()) + if "_run_on_acc" in submod[0] + ] + ), + 2, + "Unsupported operators interleave supported ones, expected 2 segments", + ) + + def test_partition_partially_supported_with_torch_executed_ops(self): + class PartiallySupportedMultiOp(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def forward(self, x, y): + sum_1 = torch.ops.aten.add.Tensor(x, y) + sum_2 = torch.ops.aten.add.Tensor(x, sum_1) + sum_ = torch.ops.aten.add.Tensor(sum_1, sum_2) + relu_ = torch.ops.aten.relu.default(sum_) + pow_ = torch.ops.aten.pow.Tensor_Scalar(relu_, 2) + return pow_ + + unexpected_ops = {torch.ops.aten.add.Tensor} + + inputs = [ + torch.randint( + 1, + 10, + (5,), + ), + torch.randint( + 1, + 10, + (5,), + ), + ] + + fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp()) + ( + unexpected_ops_seen, + _, + partitioned_graphs, + ) = lower_graph_testing( + fx_graph, + inputs, + unexpected_ops=unexpected_ops, + min_block_size=2, + torch_executed_ops={"torch.ops.aten.add.Tensor"}, + testing_partitioning=True, + use_fast_partitioner=True, + ) + + self.assertEquals( + len(unexpected_ops_seen), + 0, + f"The following unexpected ops were encountered: {unexpected_ops_seen}", + ) + + self.assertEquals( + len(partitioned_graphs), + 1, + "Without control flow breaks, there should only be a single graph", + ) + self.assertEquals( + len( + [ + 1 + for submod in list(partitioned_graphs[0].named_children()) + if "_run_on_acc" in submod[0] + ] + ), + 1, + "Certain operators are set to run in Torch, expected 1 segment", + ) + + +class TestGlobalPartitioning(TestCase): + def test_partition_fully_supported_one_op(self): + class FullySupportedOneOp(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def forward(self, x, y): + return torch.ops.aten.add.Tensor(x, y) + + fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp()) + partitioned_graph = partitioning.global_partition(deepcopy(fx_graph)) self.assertEquals( len(list(partitioned_graph.named_children())), 0, @@ -36,7 +184,9 @@ def forward(self, x, y): return pow_ fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp()) - partitioned_graph = partition(deepcopy(fx_graph), min_block_size=2) + partitioned_graph = partitioning.global_partition( + deepcopy(fx_graph), min_block_size=2 + ) self.assertEquals( len(list(partitioned_graph.named_children())), 1, @@ -57,7 +207,9 @@ def forward(self, x, y): return pow_ fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp()) - partitioned_graph = partition(deepcopy(fx_graph), min_block_size=2) + partitioned_graph = partitioning.global_partition( + deepcopy(fx_graph), min_block_size=2 + ) self.assertEquals( len(list(partitioned_graph.named_children())), 2, @@ -104,6 +256,7 @@ def forward(self, x, y): min_block_size=2, torch_executed_ops={"torch.ops.aten.add.Tensor"}, testing_partitioning=True, + use_fast_partitioner=False, ) self.assertEquals( diff --git a/tests/py/dynamo/backend/utils.py b/tests/py/dynamo/backend/utils.py index 0eaba4aeea..4a5466d3a2 100644 --- a/tests/py/dynamo/backend/utils.py +++ b/tests/py/dynamo/backend/utils.py @@ -5,9 +5,7 @@ from torch_tensorrt.dynamo.lowering._decompositions import ( get_decompositions, ) -from torch_tensorrt.dynamo.lowering._partition import ( - partition, -) +from torch_tensorrt.dynamo import partitioning from torch_tensorrt.dynamo.lowering._pre_aot_lowering import ( pre_aot_substitutions, ) @@ -27,6 +25,7 @@ def 
fx_dynamo_testing_backend( store_intermediate_graphs: List, min_block_size: int = 3, torch_executed_ops: Sequence[str] = set(), + use_fast_partitioner: bool = True, ): """Helper Dynamo backend exclusively for testing""" custom_backend = partial( @@ -34,6 +33,7 @@ def fx_dynamo_testing_backend( store_intermediate_graphs=store_intermediate_graphs, min_block_size=min_block_size, torch_executed_ops=torch_executed_ops, + use_fast_partitioner=use_fast_partitioner, ) gm = pre_aot_substitutions(gm) @@ -54,11 +54,21 @@ def compile_module_testing( store_intermediate_graphs: List, min_block_size: int = 3, torch_executed_ops: Sequence[str] = str(), + use_fast_partitioner: bool = True, ) -> torch.fx.GraphModule: """Helper compiler exclusively for testing""" - partitioned_module = partition( - gm, min_block_size=min_block_size, torch_executed_ops=torch_executed_ops - ) + if use_fast_partitioner: + partitioned_module = partitioning.fast_partition( + gm, + min_block_size=min_block_size, + torch_executed_ops=torch_executed_ops, + ) + else: + partitioned_module = partitioning.global_partition( + gm, + min_block_size=min_block_size, + torch_executed_ops=torch_executed_ops, + ) # Store intermediate graph from partitioned module store_intermediate_graphs.append(deepcopy(partitioned_module)) @@ -130,6 +140,7 @@ def lower_graph_testing( min_block_size: int = 3, torch_executed_ops: Sequence[str] = set(), testing_partitioning: bool = False, + use_fast_partitioner: bool = True, ): """Helper function to assist with graph lowering for testing of Dynamo compile @@ -141,6 +152,7 @@ def lower_graph_testing( min_block_size: Minimum number of operators per TRT-Engine Block torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage testing_partitioning: Whether partitioning is being tested (to analyze only TRT-supported ops) + use_fast_partitioner: Whether to use the fast or global partitioner Returns: If testing_partitioning: List[torch.fx.GraphModule], Set, Set: List of partitioned graph outputs, unexpected ops seen, expected ops unseen @@ -154,6 +166,7 @@ def lower_graph_testing( store_intermediate_graphs=partitioned_graphs, min_block_size=min_block_size, torch_executed_ops=torch_executed_ops, + use_fast_partitioner=use_fast_partitioner, ) # Invoke compilation @@ -176,7 +189,11 @@ def classify_node(node: torch.fx.Node): for top_level_node in fx_module.graph.nodes: if top_level_node.op == "call_function" and not testing_partitioning: classify_node(top_level_node) - elif top_level_node.op == "call_module": + elif top_level_node.op == "call_module" and ( + not testing_partitioning + or not use_fast_partitioner + or ("_run_on_acc_" in top_level_node.target) + ): for node in fx_module.get_submodule(top_level_node.target).graph.nodes: classify_node(node)
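
Usage sketch for the new `use_fast_partitioner` option (illustrative, not part of the patch): kwargs supplied via `torch.compile` options are parsed by `parse_dynamo_kwargs` into `CompilationSettings`, so the flag reaches `_compile_module` as below. The model, shapes, and option values are assumptions, and it is assumed the Dynamo backend is registered under the name `"torch_tensorrt"` on import, as in `backends.py`.

```python
# Minimal sketch (model/inputs are placeholders): the flag flows from
# torch.compile options -> parse_dynamo_kwargs -> CompilationSettings -> _compile_module.
import torch
import torch_tensorrt  # assumed to register the "torch_tensorrt" Dynamo backend

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).cuda().eval()
inputs = [torch.randn(4, 8).cuda()]

compiled = torch.compile(
    model,
    backend="torch_tensorrt",
    options={
        "min_block_size": 1,
        "use_fast_partitioner": True,  # default; set False to force the global partitioner
    },
)
compiled(*inputs)  # triggers partitioning and TRT engine construction
```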
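
The new partitioning tests exercise both entry points directly on an aten-level trace; a condensed sketch of that pattern follows (the toy module is illustrative). The fast (adjacency) partitioner names TRT-bound children `_run_on_acc_*` and Torch-bound children `_run_on_gpu_*`, while the global (capability-based) partitioner only creates children for fused TRT segments.

```python
# Condensed from tests/py/dynamo/backend/test_partitioning.py (toy module is illustrative)
from copy import deepcopy

import torch
from torch_tensorrt.dynamo import partitioning


class SupportedMultiOp(torch.nn.Module):
    def forward(self, x, y):
        sub_ = torch.ops.aten.sub.Tensor(x, y)
        relu_ = torch.ops.aten.relu.default(sub_)
        return torch.ops.aten.pow.Tensor_Scalar(relu_, 2)


gm = torch.fx.symbolic_trace(SupportedMultiOp())

# Fast (adjacency) partitioner: TRT segments appear as "_run_on_acc_*" submodules
fast = partitioning.fast_partition(deepcopy(gm), min_block_size=2)
trt_segments = [name for name, _ in fast.named_children() if "_run_on_acc" in name]
print(len(trt_segments))  # expected: 1, since all three ops are supported

# Global (capability-based) partitioner: each child is a fused TRT segment
global_partitioned = partitioning.global_partition(deepcopy(gm), min_block_size=2)
print(len(list(global_partitioned.named_children())))  # expected: 1
```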
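
A hedged sketch of the helper flow that `_compile_module` now follows: `get_graph_converter_support` counts supported `call_function` nodes before deciding whether partitioning is worthwhile, and `get_submod_inputs` captures sample inputs for each accelerated child via a forward pre-hook. The toy module, threshold, and printed fields are assumptions for illustration.

```python
# Sketch of the pre-partitioning support check and per-submodule input capture
import torch
from torch_tensorrt.dynamo import partitioning


class TwoAdds(torch.nn.Module):
    def forward(self, x, y):
        first = torch.ops.aten.add.Tensor(x, y)
        return torch.ops.aten.add.Tensor(first, y)


gm = torch.fx.symbolic_trace(TwoAdds())
sample_inputs = [torch.randn(5), torch.randn(5)]
min_block_size = 2

# Support overview over all call_function nodes, prior to any partitioning
num_supported_ops, total_ops = partitioning.get_graph_converter_support(
    gm, verbose=True, torch_executed_ops=set()
)

if num_supported_ops == 0 or num_supported_ops < min_block_size:
    print(f"Only {num_supported_ops}/{total_ops} ops supported; leaving graph in Torch")
else:
    # get_submod_inputs registers a forward pre-hook on the child and runs the parent once
    partitioned = partitioning.fast_partition(gm, min_block_size=min_block_size)
    for name, submodule in partitioned.named_children():
        if "_run_on_acc" not in name:  # fast partitioner marks TRT-bound children this way
            continue
        submod_inputs = partitioning.get_submod_inputs(partitioned, submodule, sample_inputs)
        print(name, [tuple(t.shape) for t in submod_inputs])
```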