1515"""
1616
1717import logging
18- from itertools import count
1918from typing import Callable , List , Optional , Sequence , Tuple
2019
2120import torch
3635)
3736from executorch .exir .backend .utils import tag_constant_data , WhyNoPartitionReporter
3837from executorch .exir .dialects ._ops import ops as exir_ops
39- from executorch .exir .graph_module import get_control_flow_submodules
4038from torch .export .exported_program import ExportedProgram
41- from torch .fx import GraphModule
42- from torch .fx .passes .infra .partitioner import CapabilityBasedPartitioner , Partition
39+ from torch .fx .passes .infra .partitioner import CapabilityBasedPartitioner
4340from torch .fx .passes .operator_support import OperatorSupportBase
4441
4542logger = logging .getLogger (__name__ )
@@ -113,43 +110,6 @@ def is_noop_expand(node: torch.fx.node.Node) -> bool:
113110 return all (m == 1 for m in multiples )
114111
115112
116- def is_partitioned (
117- node : torch .fx .Node ,
118- tag : str ,
119- ) -> bool :
120- """Return True if the node currently belongs to the partition ``tag``.
121-
122- Args:
123- node (torch.fx.Node): FX node to check.
124- tag (str): Delegation tag identifying the partition.
125-
126- Returns:
127- bool: True if the node carries the matching delegation tag.
128-
129- """
130- return "delegation_tag" in node .meta and node .meta ["delegation_tag" ] == tag
131-
132-
133- def reject_partition (
134- reason : str , partition : Partition , reporter : WhyNoPartitionReporter
135- ) -> None :
136- """Remove a proposed partition and record the rejection reason.
137-
138- Args:
139- reason (str): Human-readable explanation for rejection.
140- partition (object): Proposed partition object from the
141- capability partitioner.
142- reporter (WhyNoPartitionReporter): used to report why nodes were rejected.
143- """
144- for node in partition .nodes :
145- if "delegation_tag" in node .meta :
146- del node .meta ["delegation_tag" ]
147- reporter .report_reject (
148- node ,
149- reason ,
150- )
151-
152-
153113class TOSAPartitioner (Partitioner ):
154114 """Partition an exported program into TOSA-delegable subgraphs.
155115
@@ -182,76 +142,107 @@ def __init__(
182142 self .additional_checks = additional_checks
183143 self .tosa_spec = compile_spec .tosa_spec
184144
185- def _tag_module ( # noqa
186- self ,
187- module : GraphModule ,
188- containing_program : ExportedProgram ,
189- reporter : WhyNoPartitionReporter ,
190- tag_iterator : count | None = None ,
191- ) -> set [str ]:
192- """Tag nodes in a module, possibly a submodule, from the containing program.
145+ def partition (self , exported_program : ExportedProgram ) -> PartitionResult : # noqa
146+ """Partition the program and tag TOSA-compatible subgraphs.
147+
148+ Run the FX capability-based partitioner to propose subgraphs, then
149+ refine tags by removing boundary-only quantize/dequantize nodes and by
150+ rejecting partitions that would lower to no-ops. Emit a detailed report
151+ of rejected nodes and their reasons.
193152
194153 Args:
195- module: a GraphModule from `containing_program` to tag nodes in.
196- containing_program: The ExportedProgram that contains the module .
197- reporter: A reporter to report why nodes were rejected.
154+ exported_program (ExportedProgram): Program to analyze and
155+ partition .
156+
198157 Returns:
199- A set of strings with the partition tags.
158+ PartitionResult: The input program with nodes tagged for delegation
159+ and a mapping of partition tags to delegation specs.
160+
200161 """
201- tags : set [str ] = set ()
202- if tag_iterator is None :
203- tag_iterator = count (0 )
204- for _ , submodule , _ in get_control_flow_submodules (module ):
205- submodule_tags = self ._tag_module (
206- submodule , containing_program , reporter , tag_iterator
207- )
208- if len (tags & submodule_tags ) != 0 :
209- raise RuntimeError (
210- "Got overlapping tags in two different modules, this shouldn't happen."
211- )
212- tags = tags | submodule_tags
162+ logger .info ("TOSAPartitioner::partition" )
163+ partition_tags : dict [str , DelegationSpec ] = {}
164+
165+ logger .info (
166+ f"Partitioning for { self .delegation_spec .backend_id } : { self .tosa_spec } "
167+ )
168+
169+ reporter = WhyNoPartitionReporter ()
213170 operator_support = tosa_support_factory (
214- self .tosa_spec , containing_program , reporter , self .additional_checks
171+ self .tosa_spec , exported_program , reporter , self .additional_checks
215172 )
216173 capability_partitioner = CapabilityBasedPartitioner (
217- module ,
174+ exported_program . graph_module ,
218175 operator_support ,
219176 allows_single_node_partition = True ,
220177 )
221178 partition_list = capability_partitioner .propose_partitions ()
222179
180+ def reject_partition (reason : str , partition , tag ) -> None :
181+ """Remove a proposed partition and record the rejection reason.
182+
183+ Args:
184+ reason (str): Human-readable explanation for rejection.
185+ partition (object): Proposed partition object from the
186+ capability partitioner.
187+ tag (str): Delegation tag associated with the partition.
188+
189+ """
190+ for node in partition .nodes :
191+ if "delegation_tag" in node .meta :
192+ del node .meta ["delegation_tag" ]
193+ reporter .report_reject (
194+ node ,
195+ reason ,
196+ )
197+ partition_tags .pop (tag , None )
198+
223199 for partition in partition_list :
224- tag = f"tag{ next (tag_iterator )} "
225- tags .add (tag )
200+ tag = f"tag{ partition .id } "
201+
202+ def is_partitioned (node : torch .fx .Node , tag = tag ) -> bool :
203+ """Return True if the node currently belongs to the partition ``tag``.
204+
205+ Args:
206+ node (torch.fx.Node): FX node to check.
207+ tag (str): Delegation tag identifying the partition.
208+
209+ Returns:
210+ bool: True if the node carries the matching delegation tag.
211+
212+ """
213+ return (
214+ "delegation_tag" in node .meta and node .meta ["delegation_tag" ] == tag
215+ )
226216
227217 for node in partition .nodes :
228218 node .meta ["delegation_tag" ] = tag
219+ partition_tags [tag ] = self .delegation_spec
229220
230221 # De-tag outermost q-nodes upwards and dq-nodes downwards.
231222 # De-tag if at least one input/output is not part of the partition.
232- for node in module .graph .nodes :
233- if not is_partitioned (node , tag ):
223+ for node in exported_program . graph_module .graph .nodes :
224+ if not is_partitioned (node ):
234225 continue
235226 if node .target in Q_OPS :
236227 for input in node .all_input_nodes :
237- if not is_partitioned (input , tag ):
228+ if not is_partitioned (input ):
238229 del node .meta ["delegation_tag" ]
239230 break
240231 continue
241232
242233 if node .target in DQ_OPS :
243234 for user in node .users :
244- if not is_partitioned (user , tag ):
235+ if not is_partitioned (user ):
245236 del node .meta ["delegation_tag" ]
246237 break
247238 continue
248239
249240 if self .tosa_spec .support_float ():
250241 continue
251242
252- if is_partitioned (node , tag ):
243+ if is_partitioned (node ):
253244 for input in node .all_input_nodes :
254- if is_partitioned (input , tag ):
245+ if is_partitioned (input ):
255246 continue
256247 if get_first_fake_tensor (input ).dtype .is_floating_point :
257248 reporter .report_reject (
@@ -274,38 +265,8 @@ def _tag_module( # noqa
274265 reject_partition (
275266 "Partition contained only ops which are removed in the TOSA lowering, leading to an empty partition." ,
276267 partition ,
277- reporter ,
268+ tag ,
278269 )
279- tags .remove (tag )
280- return tags
281-
282- def partition (self , exported_program : ExportedProgram ) -> PartitionResult :
283- """Partition the program and tag TOSA-compatible subgraphs.
284-
285- Run the FX capability-based partitioner to propose subgraphs, then
286- refine tags by removing boundary-only quantize/dequantize nodes and by
287- rejecting partitions that would lower to no-ops. Emit a detailed report
288- of rejected nodes and their reasons.
289-
290- Args:
291- exported_program (ExportedProgram): Program to analyze and
292- partition.
293-
294- Returns:
295- PartitionResult: The input program with nodes tagged for delegation
296- and a mapping of partition tags to delegation specs.
297-
298- """
299- logger .info ("TOSAPartitioner::partition" )
300- logger .info (
301- f"Partitioning for { self .delegation_spec .backend_id } : { self .tosa_spec } "
302- )
303-
304- reporter = WhyNoPartitionReporter ()
305- tags = self ._tag_module (
306- exported_program .graph_module , exported_program , reporter
307- )
308- partition_tags = {tag : self .delegation_spec for tag in tags }
309270
310271 tag_constant_data (exported_program )
311272 logger .info (f"The following nodes were rejected for { self .tosa_spec } :" )
0 commit comments