Skip to content

Commit 008a014

Browse files
committed
Revert "Arm backend: Tag control flow submodules in partitioner (pytorch#15364)"
This reverts commit b24c39a.
1 parent 8167327 commit 008a014

File tree

1 file changed

+67
-106
lines changed

1 file changed

+67
-106
lines changed

backends/arm/tosa/partitioner.py

Lines changed: 67 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
"""
1616

1717
import logging
18-
from itertools import count
1918
from typing import Callable, List, Optional, Sequence, Tuple
2019

2120
import torch
@@ -36,10 +35,8 @@
3635
)
3736
from executorch.exir.backend.utils import tag_constant_data, WhyNoPartitionReporter
3837
from executorch.exir.dialects._ops import ops as exir_ops
39-
from executorch.exir.graph_module import get_control_flow_submodules
4038
from torch.export.exported_program import ExportedProgram
41-
from torch.fx import GraphModule
42-
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
39+
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
4340
from torch.fx.passes.operator_support import OperatorSupportBase
4441

4542
logger = logging.getLogger(__name__)
@@ -113,43 +110,6 @@ def is_noop_expand(node: torch.fx.node.Node) -> bool:
113110
return all(m == 1 for m in multiples)
114111

115112

116-
def is_partitioned(
117-
node: torch.fx.Node,
118-
tag: str,
119-
) -> bool:
120-
"""Return True if the node currently belongs to the partition ``tag``.
121-
122-
Args:
123-
node (torch.fx.Node): FX node to check.
124-
tag (str): Delegation tag identifying the partition.
125-
126-
Returns:
127-
bool: True if the node carries the matching delegation tag.
128-
129-
"""
130-
return "delegation_tag" in node.meta and node.meta["delegation_tag"] == tag
131-
132-
133-
def reject_partition(
134-
reason: str, partition: Partition, reporter: WhyNoPartitionReporter
135-
) -> None:
136-
"""Remove a proposed partition and record the rejection reason.
137-
138-
Args:
139-
reason (str): Human-readable explanation for rejection.
140-
partition (object): Proposed partition object from the
141-
capability partitioner.
142-
reporter (WhyNoPartitionReporter): used to report why nodes were rejected.
143-
"""
144-
for node in partition.nodes:
145-
if "delegation_tag" in node.meta:
146-
del node.meta["delegation_tag"]
147-
reporter.report_reject(
148-
node,
149-
reason,
150-
)
151-
152-
153113
class TOSAPartitioner(Partitioner):
154114
"""Partition an exported program into TOSA-delegable subgraphs.
155115
@@ -182,76 +142,107 @@ def __init__(
182142
self.additional_checks = additional_checks
183143
self.tosa_spec = compile_spec.tosa_spec
184144

185-
def _tag_module( # noqa
186-
self,
187-
module: GraphModule,
188-
containing_program: ExportedProgram,
189-
reporter: WhyNoPartitionReporter,
190-
tag_iterator: count | None = None,
191-
) -> set[str]:
192-
"""Tag nodes in a module, possibly a submodule, from the containing program.
145+
def partition(self, exported_program: ExportedProgram) -> PartitionResult: # noqa
146+
"""Partition the program and tag TOSA-compatible subgraphs.
147+
148+
Run the FX capability-based partitioner to propose subgraphs, then
149+
refine tags by removing boundary-only quantize/dequantize nodes and by
150+
rejecting partitions that would lower to no-ops. Emit a detailed report
151+
of rejected nodes and their reasons.
193152
194153
Args:
195-
module: a GraphModule from `containing_program` to tag nodes in.
196-
containing_program: The ExportedProgram that contains the module.
197-
reporter: A reporter to report why nodes were rejected.
154+
exported_program (ExportedProgram): Program to analyze and
155+
partition.
156+
198157
Returns:
199-
A set of strings with the partition tags.
158+
PartitionResult: The input program with nodes tagged for delegation
159+
and a mapping of partition tags to delegation specs.
160+
200161
"""
201-
tags: set[str] = set()
202-
if tag_iterator is None:
203-
tag_iterator = count(0)
204-
for _, submodule, _ in get_control_flow_submodules(module):
205-
submodule_tags = self._tag_module(
206-
submodule, containing_program, reporter, tag_iterator
207-
)
208-
if len(tags & submodule_tags) != 0:
209-
raise RuntimeError(
210-
"Got overlapping tags in two different modules, this shouldn't happen."
211-
)
212-
tags = tags | submodule_tags
162+
logger.info("TOSAPartitioner::partition")
163+
partition_tags: dict[str, DelegationSpec] = {}
164+
165+
logger.info(
166+
f"Partitioning for {self.delegation_spec.backend_id}: {self.tosa_spec}"
167+
)
168+
169+
reporter = WhyNoPartitionReporter()
213170
operator_support = tosa_support_factory(
214-
self.tosa_spec, containing_program, reporter, self.additional_checks
171+
self.tosa_spec, exported_program, reporter, self.additional_checks
215172
)
216173
capability_partitioner = CapabilityBasedPartitioner(
217-
module,
174+
exported_program.graph_module,
218175
operator_support,
219176
allows_single_node_partition=True,
220177
)
221178
partition_list = capability_partitioner.propose_partitions()
222179

180+
def reject_partition(reason: str, partition, tag) -> None:
181+
"""Remove a proposed partition and record the rejection reason.
182+
183+
Args:
184+
reason (str): Human-readable explanation for rejection.
185+
partition (object): Proposed partition object from the
186+
capability partitioner.
187+
tag (str): Delegation tag associated with the partition.
188+
189+
"""
190+
for node in partition.nodes:
191+
if "delegation_tag" in node.meta:
192+
del node.meta["delegation_tag"]
193+
reporter.report_reject(
194+
node,
195+
reason,
196+
)
197+
partition_tags.pop(tag, None)
198+
223199
for partition in partition_list:
224-
tag = f"tag{next(tag_iterator)}"
225-
tags.add(tag)
200+
tag = f"tag{partition.id}"
201+
202+
def is_partitioned(node: torch.fx.Node, tag=tag) -> bool:
203+
"""Return True if the node currently belongs to the partition ``tag``.
204+
205+
Args:
206+
node (torch.fx.Node): FX node to check.
207+
tag (str): Delegation tag identifying the partition.
208+
209+
Returns:
210+
bool: True if the node carries the matching delegation tag.
211+
212+
"""
213+
return (
214+
"delegation_tag" in node.meta and node.meta["delegation_tag"] == tag
215+
)
226216

227217
for node in partition.nodes:
228218
node.meta["delegation_tag"] = tag
219+
partition_tags[tag] = self.delegation_spec
229220

230221
# De-tag outermost q-nodes upwards and dq-nodes downwards.
231222
# De-tag if at least one input/output is not part of the partition.
232-
for node in module.graph.nodes:
233-
if not is_partitioned(node, tag):
223+
for node in exported_program.graph_module.graph.nodes:
224+
if not is_partitioned(node):
234225
continue
235226
if node.target in Q_OPS:
236227
for input in node.all_input_nodes:
237-
if not is_partitioned(input, tag):
228+
if not is_partitioned(input):
238229
del node.meta["delegation_tag"]
239230
break
240231
continue
241232

242233
if node.target in DQ_OPS:
243234
for user in node.users:
244-
if not is_partitioned(user, tag):
235+
if not is_partitioned(user):
245236
del node.meta["delegation_tag"]
246237
break
247238
continue
248239

249240
if self.tosa_spec.support_float():
250241
continue
251242

252-
if is_partitioned(node, tag):
243+
if is_partitioned(node):
253244
for input in node.all_input_nodes:
254-
if is_partitioned(input, tag):
245+
if is_partitioned(input):
255246
continue
256247
if get_first_fake_tensor(input).dtype.is_floating_point:
257248
reporter.report_reject(
@@ -274,38 +265,8 @@ def _tag_module( # noqa
274265
reject_partition(
275266
"Partition contained only ops which are removed in the TOSA lowering, leading to an empty partition.",
276267
partition,
277-
reporter,
268+
tag,
278269
)
279-
tags.remove(tag)
280-
return tags
281-
282-
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
283-
"""Partition the program and tag TOSA-compatible subgraphs.
284-
285-
Run the FX capability-based partitioner to propose subgraphs, then
286-
refine tags by removing boundary-only quantize/dequantize nodes and by
287-
rejecting partitions that would lower to no-ops. Emit a detailed report
288-
of rejected nodes and their reasons.
289-
290-
Args:
291-
exported_program (ExportedProgram): Program to analyze and
292-
partition.
293-
294-
Returns:
295-
PartitionResult: The input program with nodes tagged for delegation
296-
and a mapping of partition tags to delegation specs.
297-
298-
"""
299-
logger.info("TOSAPartitioner::partition")
300-
logger.info(
301-
f"Partitioning for {self.delegation_spec.backend_id}: {self.tosa_spec}"
302-
)
303-
304-
reporter = WhyNoPartitionReporter()
305-
tags = self._tag_module(
306-
exported_program.graph_module, exported_program, reporter
307-
)
308-
partition_tags = {tag: self.delegation_spec for tag in tags}
309270

310271
tag_constant_data(exported_program)
311272
logger.info(f"The following nodes were rejected for {self.tosa_spec}:")

0 commit comments

Comments
 (0)