 import logging
-from typing import Dict, List, Optional, Sequence, Set
+from typing import Dict, List, Optional, Sequence, Set, Tuple
 
 import torch
 
+from torch.fx.passes.splitter_base import (
+    Subgraph,
+    _SplitterBase,
+    _SplitterSettingBase,
+    FxNetAccNodesFinder,
+    FxNetAccFusionsFinder,
+)
+import torch.fx.passes.operator_support as ops
+from torch.fx.passes.tools_common import NodeSet, CALLABLE_NODE_OPS
+from torch.fx.node import Target
+
+from torch_tensorrt.dynamo.conversion.converter_registry import ConverterRegistry
 from torch_tensorrt.dynamo.lowering import SUBSTITUTION_REGISTRY
 from torch_tensorrt.dynamo._defaults import MIN_BLOCK_SIZE
-from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
-from torch.fx.graph_module import GraphModule
 from torch.fx.node import _get_qualified_name
-from torch.fx.passes.operator_support import OperatorSupport
 
 from torch_tensorrt.dynamo import DYNAMO_CONVERTERS as CONVERTERS
 
 )
 
 
-class TRTPartitioner(CapabilityBasedPartitioner):
-    """Partitioner to split an FX graph into subgraphs based on operator support
-
-    Args:
-        graph_module: FX GraphModule to partition
-        operator_support: OperatorSupport class describing allowed operators
-        non_compute_ops: Operators which are not considered computational (e.g. getattr)
-        allowed_single_node_partition_ops: Nodes which can be included in single-node partitions.
-            Generally useful for module-level exclusion ops which are intensive despite being single functions
-        min_block_size: Minimum number of computational operators per block
-    Returns:
-        torch.fx.GraphModule
-    """
-
-    def __init__(
-        self,
-        graph_module: GraphModule,
-        operator_support: OperatorSupport,
-        *,
-        non_compute_ops: Optional[Sequence[str]] = None,
-        allowed_single_node_partition_ops: Optional[
-            Sequence[str]
-        ] = DEFAULT_SINGLE_NODE_PARTITIONS,
-        min_block_size=MIN_BLOCK_SIZE,
-    ) -> None:
-        super().__init__(
-            graph_module,
-            operator_support,
-            allows_single_node_partition=True,
-            non_compute_ops=non_compute_ops,
-            allowed_single_node_partition_ops=allowed_single_node_partition_ops,
-        )
-
-        self.min_block_size = min_block_size
-
-    def propose_partitions(self) -> List[Partition]:
-        # Propose partitions using the default, then refine the results
-        initial_proposed_partitions = super().propose_partitions()
-        partitions = {i: part for i, part in enumerate(initial_proposed_partitions)}
-
-        # For each partition, determine whether or not the number of computational operators
-        # exceeds the threshold, and if not, remove that partition
-        partitions_to_remove = {}
-        for id, partition in partitions.items():
-            default_non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"}
-            non_compute_ops = default_non_compute_ops.union(set(self.non_compute_ops))
-            exempted_partition = False
-
-            compute_node_count = 0
-            for node in partition.nodes:
-                # Partitions are exempted from min_block_size if they contain an allowed single-node op
-                if (
-                    node.op == "call_function"
-                    and _get_qualified_name(node.target)
-                    in self.allowed_single_node_partition_ops
-                ):
-                    exempted_partition = True
-                    break
-                elif (
-                    node.op == "call_function"
-                    and _get_qualified_name(node.target) not in non_compute_ops
-                ):
-                    compute_node_count += 1
-
-            if compute_node_count < self.min_block_size and not exempted_partition:
-                partitions_to_remove[id] = compute_node_count
-
-        # Remove any nodes violating the criteria specified by the user
-        for id, count in partitions_to_remove.items():
-            logger.debug(
-                f"Removing partition which has {count} < {self.min_block_size} computational operators"
-            )
-            del partitions[id]
-
-        return [partitions[k] for k in sorted(partitions.keys())]
-
-    def partition_and_fuse(self) -> GraphModule:
-        partitions = self.propose_partitions()
-        fused_gm = self.fuse_partitions(partitions)
-        return fused_gm
-
-
-class TorchTensorRTOperatorSupport(OperatorSupport):
+class OpSupportTester(ops.OperatorSupportBase):
     """Class to determine whether operators within a module are supported"""
 
-    def __init__(self, support_dict=None, torch_executed_ops=set()):
-        super().__init__(support_dict)
+    def __init__(self, torch_executed_ops: Sequence[Target] = set()) -> None:
+        super().__init__()
 
         # Initialize sets of supported/unsupported operators
         self.supported_operators = {}
@@ -117,11 +44,7 @@ def __init__(self, support_dict=None, torch_executed_ops=set()):
     def is_node_supported(
         self, submodules: Dict[str, torch.nn.Module], node: torch.fx.Node
     ) -> bool:
-        node_name = (
-            _get_qualified_name(node.target)
-            if not isinstance(node.target, str)
-            else node.target
-        )
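+        # Resolve the node target to its qualified string name
+        # (handles both string and callable targets)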
+        node_name = ConverterRegistry.qualified_name_or_str(node.target)
 
         if node in CONVERTERS and node_name not in self.torch_executed_ops:
             # If node is a proper, supported computational node, store the operator
@@ -164,11 +87,139 @@ def print_support_overview(self, num_trt_blocks: Optional[int] = None):
         logger.debug("\nAll Nodes Supported\n")
 
 
+class TRTPartitioner(_SplitterBase):
+    """Partitioner to split an FX graph into subgraphs based on operator support
+
+    Adapted from, and modified for the Torch-TensorRT Dynamo case:
+    https://github.com/pytorch/pytorch/blob/93f538db355ea10c684a57f7a632ed03292ef98f/torch/fx/passes/splitter_base.py#L256C9-L871
+
+    Args:
+        graph_module: FX GraphModule to partition
+        operator_support: OperatorSupport class describing allowed operators
+        allowed_single_node_partition_ops: Nodes which can be included in single-node partitions.
+            Generally useful for module-level exclusion ops which are intensive despite being single functions
+        min_block_size: Minimum number of computational operators per block
+    Returns:
+        torch.fx.GraphModule
+    """
+
+    def __init__(
+        self,
+        graph_module: torch.fx.GraphModule,
+        operator_support: ops.OperatorSupportBase,
+        allowed_single_node_partition_ops: Optional[
+            Sequence[str]
+        ] = DEFAULT_SINGLE_NODE_PARTITIONS,
+        min_block_size: int = MIN_BLOCK_SIZE,
+    ):
+        """
+        Preprocesses graph before splitting:
+        - finds nodes supported by ACC,
+        - finds fusion groups for ACC nodes having non-tensor IO,
+        - builds a graph of direct dependencies,
+        - builds a map of fused nodes to their fusions.
+        As a result we get self.acc_nodes, self.deps and self.fusions.
+        """
+        assert isinstance(graph_module, torch.fx.GraphModule)
+
+        self.graph_module = graph_module
+
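+        # Splitter settings: enforce the minimum block size and permit
+        # non-tensor values at subgraph boundaries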
+        self.settings = _SplitterSettingBase(
+            min_acc_module_size=min_block_size, allow_non_tensor=True
+        )
+        self.operator_support = operator_support
+
+        # Get all accelerated nodes based on operator support conditions
+        self.acc_nodes = FxNetAccNodesFinder(
+            self.graph_module, self.operator_support, self.settings.allow_non_tensor
+        )()
+
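+        # Group ACC nodes with non-tensor inputs/outputs into fusion groups,
+        # so each group stays within a single submodule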
+        if self.settings.skip_fusion:
+            self.fusions = {}
+        else:
+            self.fusions = FxNetAccFusionsFinder(graph_module, self.acc_nodes)()
+
+        # Modify deps to add more deps for fused nodes
+        self.deps = self.find_deps()
+        self.update_deps_for_fusions()
+
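+        # Segments not handled by TRT are emitted as "_run_on_gpu_*" submodules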
+        self.non_acc_submodule_name = "_run_on_gpu_"
+        self._node_submodule_map: Dict[str, str] = {}
+
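+        # Populated by partition_graph() with the number of TRT engines to build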
+        self.num_trt_accelerated_subgraphs = None
+        self.allowed_single_node_partition_ops = allowed_single_node_partition_ops
+
+    def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]:
+        """
+        This pass finds ACC submodules smaller than the specified size and
+        merges them with adjacent GPU submodules.
+        """
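+        # For example (illustrative), with min_acc_module_size=3 the segment
+        # sequence [GPU(4), ACC(2), GPU(3)] collapses to a single GPU(9) segment:
+        # the undersized ACC segment is folded into the preceding GPU segment;
+        # a leading undersized ACC segment is instead re-tagged to run on GPU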
+        result: List[Subgraph] = []
+        for subgraph in subgraphs:
+            if subgraph.is_acc:
+                if len(subgraph.nodes) >= self.settings.min_acc_module_size or any(
+                    ConverterRegistry.qualified_name_or_str(node.target)
+                    in self.allowed_single_node_partition_ops
+                    for node in subgraph.nodes
+                ):
+                    result.append(subgraph)
+                else:
+                    logger.debug(
+                        "Eliminating acc subgraph because it's smaller than the threshold: "
+                        f"{len(subgraph.nodes)} < {self.settings.min_acc_module_size}"
+                    )
+                    if result:
+                        result[-1].nodes.extend(subgraph.nodes)
+                    else:
+                        subgraph.is_acc = False
+                        result.append(subgraph)
+            else:
+                if result and not result[-1].is_acc:
+                    result[-1].nodes.extend(subgraph.nodes)
+                else:
+                    result.append(subgraph)
+        return result
+
+    def partition_graph(self) -> torch.fx.GraphModule:
+        """Partitions the GraphModule into subgraphs based on operator support
+
+        Returns a GraphModule with submodules for each segment
+        """
+        # Delegate nodes based on operator coverage
+        subgraphs = self.put_nodes_into_subgraphs()
+
+        # Remove segments smaller than the block size (with exceptions)
+        subgraphs = self.remove_small_acc_subgraphs(subgraphs)
+
+        # Set the number of TRT engines to be generated
+        self.num_trt_accelerated_subgraphs = len([s for s in subgraphs if s.is_acc])
+
+        # Tag the accelerated nodes and split the graph accordingly
+        self.tag(subgraphs)
+        return self.split()
+
+    def starter_nodes(self) -> Tuple[NodeSet, NodeSet]:
+        """Generates starter nodes for partitioning + segmentation"""
+        # Starter accelerated nodes are all callable accelerated ops
+        starter_acc_nodes = {
+            node for node in self.acc_nodes if node.op in CALLABLE_NODE_OPS
+        }
+
+        # Starter non-accelerated nodes are the rest of the callable nodes
+        starter_non_acc_nodes = {
+            node
+            for node in self.graph_module.graph.nodes
+            if (node not in starter_acc_nodes and node.op in CALLABLE_NODE_OPS)
+        }
+
+        return starter_non_acc_nodes, starter_acc_nodes
+
+
 def partition(
     gm: torch.fx.GraphModule,
     verbose: bool = True,
     min_block_size: int = MIN_BLOCK_SIZE,
-    torch_executed_ops: Sequence[str] = set(),
+    torch_executed_ops: Sequence[Target] = set(),
 ) -> torch.fx.GraphModule:
     """Partition an FX GraphModule with aten ops into TRT engines
     Partitioning is based on converter operator support
@@ -181,18 +232,21 @@ def partition(
     Returns:
         torch.fx.GraphModule
     """
-    supported_ops = TorchTensorRTOperatorSupport(torch_executed_ops=torch_executed_ops)
+    # Ensure graph is clean prior to partitioning
+    gm.graph.eliminate_dead_code()
+    gm.graph.lint()
+    gm.recompile()
+
+    # Construct the operator support checker and partitioner
+    supported_ops = OpSupportTester(torch_executed_ops=torch_executed_ops)
     partitioner = TRTPartitioner(gm, supported_ops, min_block_size=min_block_size)
 
-    # Determine partitions based on user specifications and operator support
-    # Then, fuse partitions and display overview of supported/unsupported operators
-    partitions = partitioner.propose_partitions()
-    fused_graph = partitioner.fuse_partitions(partitions)
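+    # Split the module: TRT-supported segments become accelerated submodules,
+    # remaining segments run in Torch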
+    partitioned_graph = partitioner.partition_graph()
 
     if verbose:
-        supported_ops.print_support_overview(len(partitions))
+        supported_ops.print_support_overview(partitioner.num_trt_accelerated_subgraphs)
 
-    return fused_graph
+    return partitioned_graph
 
 
 def get_submod_inputs(