Skip to content

Commit b24c39a

Browse files
authored
Arm backend: Tag control flow submodules in partitioner (#15364)
Control flow operators are implemented with get_attr/ placeholders nodes pointing to submodules in the graph_module. To know if we can handle such an operators, we need to figure out whether we can partition all of the submodule the operator points to. We can do this simply by applying the partition logic we already have to submodules. This has the added benefit that we will partition parts of submodules into normal partitions, even if we can't partition the full submodule. Signed-off-by: Erik Lundell <erik.lundell@arm.com>
1 parent 9075855 commit b24c39a

File tree

1 file changed

+106
-67
lines changed

1 file changed

+106
-67
lines changed

backends/arm/tosa/partitioner.py

Lines changed: 106 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""
1616

1717
import logging
18+
from itertools import count
1819
from typing import Callable, List, Optional, Sequence, Tuple
1920

2021
import torch
@@ -35,8 +36,10 @@
3536
)
3637
from executorch.exir.backend.utils import tag_constant_data, WhyNoPartitionReporter
3738
from executorch.exir.dialects._ops import ops as exir_ops
39+
from executorch.exir.graph_module import get_control_flow_submodules
3840
from torch.export.exported_program import ExportedProgram
39-
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
41+
from torch.fx import GraphModule
42+
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
4043
from torch.fx.passes.operator_support import OperatorSupportBase
4144

4245
logger = logging.getLogger(__name__)
@@ -110,6 +113,43 @@ def is_noop_expand(node: torch.fx.node.Node) -> bool:
110113
return all(m == 1 for m in multiples)
111114

112115

116+
def is_partitioned(
117+
node: torch.fx.Node,
118+
tag: str,
119+
) -> bool:
120+
"""Return True if the node currently belongs to the partition ``tag``.
121+
122+
Args:
123+
node (torch.fx.Node): FX node to check.
124+
tag (str): Delegation tag identifying the partition.
125+
126+
Returns:
127+
bool: True if the node carries the matching delegation tag.
128+
129+
"""
130+
return "delegation_tag" in node.meta and node.meta["delegation_tag"] == tag
131+
132+
133+
def reject_partition(
134+
reason: str, partition: Partition, reporter: WhyNoPartitionReporter
135+
) -> None:
136+
"""Remove a proposed partition and record the rejection reason.
137+
138+
Args:
139+
reason (str): Human-readable explanation for rejection.
140+
partition (object): Proposed partition object from the
141+
capability partitioner.
142+
reporter (WhyNoPartitionReporter): used to report why nodes were rejected.
143+
"""
144+
for node in partition.nodes:
145+
if "delegation_tag" in node.meta:
146+
del node.meta["delegation_tag"]
147+
reporter.report_reject(
148+
node,
149+
reason,
150+
)
151+
152+
113153
class TOSAPartitioner(Partitioner):
114154
"""Partition an exported program into TOSA-delegable subgraphs.
115155
@@ -142,107 +182,76 @@ def __init__(
142182
self.additional_checks = additional_checks
143183
self.tosa_spec = compile_spec.tosa_spec
144184

145-
def partition(self, exported_program: ExportedProgram) -> PartitionResult: # noqa
146-
"""Partition the program and tag TOSA-compatible subgraphs.
147-
148-
Run the FX capability-based partitioner to propose subgraphs, then
149-
refine tags by removing boundary-only quantize/dequantize nodes and by
150-
rejecting partitions that would lower to no-ops. Emit a detailed report
151-
of rejected nodes and their reasons.
185+
def _tag_module( # noqa
186+
self,
187+
module: GraphModule,
188+
containing_program: ExportedProgram,
189+
reporter: WhyNoPartitionReporter,
190+
tag_iterator: count | None = None,
191+
) -> set[str]:
192+
"""Tag nodes in a module, possibly a submodule, from the containing program.
152193
153194
Args:
154-
exported_program (ExportedProgram): Program to analyze and
155-
partition.
156-
195+
module: a GraphModule from `containing_program` to tag nodes in.
196+
containing_program: The ExportedProgram that contains the module.
197+
reporter: A reporter to report why nodes were rejected.
157198
Returns:
158-
PartitionResult: The input program with nodes tagged for delegation
159-
and a mapping of partition tags to delegation specs.
160-
199+
A set of strings with the partition tags.
161200
"""
162-
logger.info("TOSAPartitioner::partition")
163-
partition_tags: dict[str, DelegationSpec] = {}
164-
165-
logger.info(
166-
f"Partitioning for {self.delegation_spec.backend_id}: {self.tosa_spec}"
167-
)
168-
169-
reporter = WhyNoPartitionReporter()
201+
tags: set[str] = set()
202+
if tag_iterator is None:
203+
tag_iterator = count(0)
204+
for _, submodule, _ in get_control_flow_submodules(module):
205+
submodule_tags = self._tag_module(
206+
submodule, containing_program, reporter, tag_iterator
207+
)
208+
if len(tags & submodule_tags) != 0:
209+
raise RuntimeError(
210+
"Got overlapping tags in two different modules, this shouldn't happen."
211+
)
212+
tags = tags | submodule_tags
170213
operator_support = tosa_support_factory(
171-
self.tosa_spec, exported_program, reporter, self.additional_checks
214+
self.tosa_spec, containing_program, reporter, self.additional_checks
172215
)
173216
capability_partitioner = CapabilityBasedPartitioner(
174-
exported_program.graph_module,
217+
module,
175218
operator_support,
176219
allows_single_node_partition=True,
177220
)
178221
partition_list = capability_partitioner.propose_partitions()
179222

180-
def reject_partition(reason: str, partition, tag) -> None:
181-
"""Remove a proposed partition and record the rejection reason.
182-
183-
Args:
184-
reason (str): Human-readable explanation for rejection.
185-
partition (object): Proposed partition object from the
186-
capability partitioner.
187-
tag (str): Delegation tag associated with the partition.
188-
189-
"""
190-
for node in partition.nodes:
191-
if "delegation_tag" in node.meta:
192-
del node.meta["delegation_tag"]
193-
reporter.report_reject(
194-
node,
195-
reason,
196-
)
197-
partition_tags.pop(tag, None)
198-
199223
for partition in partition_list:
200-
tag = f"tag{partition.id}"
201-
202-
def is_partitioned(node: torch.fx.Node, tag=tag) -> bool:
203-
"""Return True if the node currently belongs to the partition ``tag``.
204-
205-
Args:
206-
node (torch.fx.Node): FX node to check.
207-
tag (str): Delegation tag identifying the partition.
208-
209-
Returns:
210-
bool: True if the node carries the matching delegation tag.
211-
212-
"""
213-
return (
214-
"delegation_tag" in node.meta and node.meta["delegation_tag"] == tag
215-
)
224+
tag = f"tag{next(tag_iterator)}"
225+
tags.add(tag)
216226

217227
for node in partition.nodes:
218228
node.meta["delegation_tag"] = tag
219-
partition_tags[tag] = self.delegation_spec
220229

221230
# De-tag outermost q-nodes upwards and dq-nodes downwards.
222231
# De-tag if at least one input/output is not part of the partition.
223-
for node in exported_program.graph_module.graph.nodes:
224-
if not is_partitioned(node):
232+
for node in module.graph.nodes:
233+
if not is_partitioned(node, tag):
225234
continue
226235
if node.target in Q_OPS:
227236
for input in node.all_input_nodes:
228-
if not is_partitioned(input):
237+
if not is_partitioned(input, tag):
229238
del node.meta["delegation_tag"]
230239
break
231240
continue
232241

233242
if node.target in DQ_OPS:
234243
for user in node.users:
235-
if not is_partitioned(user):
244+
if not is_partitioned(user, tag):
236245
del node.meta["delegation_tag"]
237246
break
238247
continue
239248

240249
if self.tosa_spec.support_float():
241250
continue
242251

243-
if is_partitioned(node):
252+
if is_partitioned(node, tag):
244253
for input in node.all_input_nodes:
245-
if is_partitioned(input):
254+
if is_partitioned(input, tag):
246255
continue
247256
if get_first_fake_tensor(input).dtype.is_floating_point:
248257
reporter.report_reject(
@@ -265,8 +274,38 @@ def is_partitioned(node: torch.fx.Node, tag=tag) -> bool:
265274
reject_partition(
266275
"Partition contained only ops which are removed in the TOSA lowering, leading to an empty partition.",
267276
partition,
268-
tag,
277+
reporter,
269278
)
279+
tags.remove(tag)
280+
return tags
281+
282+
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
283+
"""Partition the program and tag TOSA-compatible subgraphs.
284+
285+
Run the FX capability-based partitioner to propose subgraphs, then
286+
refine tags by removing boundary-only quantize/dequantize nodes and by
287+
rejecting partitions that would lower to no-ops. Emit a detailed report
288+
of rejected nodes and their reasons.
289+
290+
Args:
291+
exported_program (ExportedProgram): Program to analyze and
292+
partition.
293+
294+
Returns:
295+
PartitionResult: The input program with nodes tagged for delegation
296+
and a mapping of partition tags to delegation specs.
297+
298+
"""
299+
logger.info("TOSAPartitioner::partition")
300+
logger.info(
301+
f"Partitioning for {self.delegation_spec.backend_id}: {self.tosa_spec}"
302+
)
303+
304+
reporter = WhyNoPartitionReporter()
305+
tags = self._tag_module(
306+
exported_program.graph_module, exported_program, reporter
307+
)
308+
partition_tags = {tag: self.delegation_spec for tag in tags}
270309

271310
tag_constant_data(exported_program)
272311
logger.info(f"The following nodes were rejected for {self.tosa_spec}:")

0 commit comments

Comments
 (0)