1515"""
1616
1717import logging
18+ from itertools import count
1819from typing import Callable , List , Optional , Sequence , Tuple
1920
2021import torch
3536)
3637from executorch .exir .backend .utils import tag_constant_data , WhyNoPartitionReporter
3738from executorch .exir .dialects ._ops import ops as exir_ops
39+ from executorch .exir .graph_module import get_control_flow_submodules
3840from torch .export .exported_program import ExportedProgram
39- from torch .fx .passes .infra .partitioner import CapabilityBasedPartitioner
41+ from torch .fx import GraphModule
42+ from torch .fx .passes .infra .partitioner import CapabilityBasedPartitioner , Partition
4043from torch .fx .passes .operator_support import OperatorSupportBase
4144
4245logger = logging .getLogger (__name__ )
@@ -110,6 +113,43 @@ def is_noop_expand(node: torch.fx.node.Node) -> bool:
110113 return all (m == 1 for m in multiples )
111114
112115
116+ def is_partitioned (
117+ node : torch .fx .Node ,
118+ tag : str ,
119+ ) -> bool :
120+ """Return True if the node currently belongs to the partition ``tag``.
121+
122+ Args:
123+ node (torch.fx.Node): FX node to check.
124+ tag (str): Delegation tag identifying the partition.
125+
126+ Returns:
127+ bool: True if the node carries the matching delegation tag.
128+
129+ """
130+ return "delegation_tag" in node .meta and node .meta ["delegation_tag" ] == tag
131+
132+
133+ def reject_partition (
134+ reason : str , partition : Partition , reporter : WhyNoPartitionReporter
135+ ) -> None :
136+ """Remove a proposed partition and record the rejection reason.
137+
138+ Args:
139+ reason (str): Human-readable explanation for rejection.
140+ partition (object): Proposed partition object from the
141+ capability partitioner.
142+ reporter (WhyNoPartitionReporter): used to report why nodes were rejected.
143+ """
144+ for node in partition .nodes :
145+ if "delegation_tag" in node .meta :
146+ del node .meta ["delegation_tag" ]
147+ reporter .report_reject (
148+ node ,
149+ reason ,
150+ )
151+
152+
113153class TOSAPartitioner (Partitioner ):
114154 """Partition an exported program into TOSA-delegable subgraphs.
115155
@@ -142,107 +182,76 @@ def __init__(
142182 self .additional_checks = additional_checks
143183 self .tosa_spec = compile_spec .tosa_spec
144184
145- def partition (self , exported_program : ExportedProgram ) -> PartitionResult : # noqa
146- """Partition the program and tag TOSA-compatible subgraphs.
147-
148- Run the FX capability-based partitioner to propose subgraphs, then
149- refine tags by removing boundary-only quantize/dequantize nodes and by
150- rejecting partitions that would lower to no-ops. Emit a detailed report
151- of rejected nodes and their reasons.
185+ def _tag_module ( # noqa
186+ self ,
187+ module : GraphModule ,
188+ containing_program : ExportedProgram ,
189+ reporter : WhyNoPartitionReporter ,
190+ tag_iterator : count | None = None ,
191+ ) -> set [str ]:
192+ """Tag nodes in a module, possibly a submodule, from the containing program.
152193
153194 Args:
154- exported_program (ExportedProgram): Program to analyze and
155- partition .
156-
195+ module: a GraphModule from `containing_program` to tag nodes in.
196+ containing_program: The ExportedProgram that contains the module .
197+ reporter: A reporter to report why nodes were rejected.
157198 Returns:
158- PartitionResult: The input program with nodes tagged for delegation
159- and a mapping of partition tags to delegation specs.
160-
199+ A set of strings with the partition tags.
161200 """
162- logger .info ("TOSAPartitioner::partition" )
163- partition_tags : dict [str , DelegationSpec ] = {}
164-
165- logger .info (
166- f"Partitioning for { self .delegation_spec .backend_id } : { self .tosa_spec } "
167- )
168-
169- reporter = WhyNoPartitionReporter ()
201+ tags : set [str ] = set ()
202+ if tag_iterator is None :
203+ tag_iterator = count (0 )
204+ for _ , submodule , _ in get_control_flow_submodules (module ):
205+ submodule_tags = self ._tag_module (
206+ submodule , containing_program , reporter , tag_iterator
207+ )
208+ if len (tags & submodule_tags ) != 0 :
209+ raise RuntimeError (
210+ "Got overlapping tags in two different modules, this shouldn't happen."
211+ )
212+ tags = tags | submodule_tags
170213 operator_support = tosa_support_factory (
171- self .tosa_spec , exported_program , reporter , self .additional_checks
214+ self .tosa_spec , containing_program , reporter , self .additional_checks
172215 )
173216 capability_partitioner = CapabilityBasedPartitioner (
174- exported_program . graph_module ,
217+ module ,
175218 operator_support ,
176219 allows_single_node_partition = True ,
177220 )
178221 partition_list = capability_partitioner .propose_partitions ()
179222
180- def reject_partition (reason : str , partition , tag ) -> None :
181- """Remove a proposed partition and record the rejection reason.
182-
183- Args:
184- reason (str): Human-readable explanation for rejection.
185- partition (object): Proposed partition object from the
186- capability partitioner.
187- tag (str): Delegation tag associated with the partition.
188-
189- """
190- for node in partition .nodes :
191- if "delegation_tag" in node .meta :
192- del node .meta ["delegation_tag" ]
193- reporter .report_reject (
194- node ,
195- reason ,
196- )
197- partition_tags .pop (tag , None )
198-
199223 for partition in partition_list :
200- tag = f"tag{ partition .id } "
201-
202- def is_partitioned (node : torch .fx .Node , tag = tag ) -> bool :
203- """Return True if the node currently belongs to the partition ``tag``.
204-
205- Args:
206- node (torch.fx.Node): FX node to check.
207- tag (str): Delegation tag identifying the partition.
208-
209- Returns:
210- bool: True if the node carries the matching delegation tag.
211-
212- """
213- return (
214- "delegation_tag" in node .meta and node .meta ["delegation_tag" ] == tag
215- )
224+ tag = f"tag{ next (tag_iterator )} "
225+ tags .add (tag )
216226
217227 for node in partition .nodes :
218228 node .meta ["delegation_tag" ] = tag
219- partition_tags [tag ] = self .delegation_spec
220229
221230 # De-tag outermost q-nodes upwards and dq-nodes downwards.
222231 # De-tag if at least one input/output is not part of the partition.
223- for node in exported_program . graph_module .graph .nodes :
224- if not is_partitioned (node ):
232+ for node in module .graph .nodes :
233+ if not is_partitioned (node , tag ):
225234 continue
226235 if node .target in Q_OPS :
227236 for input in node .all_input_nodes :
228- if not is_partitioned (input ):
237+ if not is_partitioned (input , tag ):
229238 del node .meta ["delegation_tag" ]
230239 break
231240 continue
232241
233242 if node .target in DQ_OPS :
234243 for user in node .users :
235- if not is_partitioned (user ):
244+ if not is_partitioned (user , tag ):
236245 del node .meta ["delegation_tag" ]
237246 break
238247 continue
239248
240249 if self .tosa_spec .support_float ():
241250 continue
242251
243- if is_partitioned (node ):
252+ if is_partitioned (node , tag ):
244253 for input in node .all_input_nodes :
245- if is_partitioned (input ):
254+ if is_partitioned (input , tag ):
246255 continue
247256 if get_first_fake_tensor (input ).dtype .is_floating_point :
248257 reporter .report_reject (
@@ -265,8 +274,38 @@ def is_partitioned(node: torch.fx.Node, tag=tag) -> bool:
265274 reject_partition (
266275 "Partition contained only ops which are removed in the TOSA lowering, leading to an empty partition." ,
267276 partition ,
268- tag ,
277+ reporter ,
269278 )
279+ tags .remove (tag )
280+ return tags
281+
282+ def partition (self , exported_program : ExportedProgram ) -> PartitionResult :
283+ """Partition the program and tag TOSA-compatible subgraphs.
284+
285+ Run the FX capability-based partitioner to propose subgraphs, then
286+ refine tags by removing boundary-only quantize/dequantize nodes and by
287+ rejecting partitions that would lower to no-ops. Emit a detailed report
288+ of rejected nodes and their reasons.
289+
290+ Args:
291+ exported_program (ExportedProgram): Program to analyze and
292+ partition.
293+
294+ Returns:
295+ PartitionResult: The input program with nodes tagged for delegation
296+ and a mapping of partition tags to delegation specs.
297+
298+ """
299+ logger .info ("TOSAPartitioner::partition" )
300+ logger .info (
301+ f"Partitioning for { self .delegation_spec .backend_id } : { self .tosa_spec } "
302+ )
303+
304+ reporter = WhyNoPartitionReporter ()
305+ tags = self ._tag_module (
306+ exported_program .graph_module , exported_program , reporter
307+ )
308+ partition_tags = {tag : self .delegation_spec for tag in tags }
270309
271310 tag_constant_data (exported_program )
272311 logger .info (f"The following nodes were rejected for { self .tosa_spec } :" )
0 commit comments