Skip to content

Commit d468f4b

Browse files
committed
Arm backend: Tag submodules in partitioner
Recursively search for control_flow_modules and tag them in _tag_module function in partitioner. Signed-off-by: Erik Lundell <erik.lundell@arm.com> Change-Id: Idb41e7f808013936dae8ca6909a69d053e834ca9
1 parent bb6dbc3 commit d468f4b

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

backends/arm/tosa/partitioner.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""
1616

1717
import logging
18+
from itertools import count
1819
from typing import Callable, List, Optional, Sequence, Tuple
1920

2021
import torch
@@ -35,6 +36,7 @@
3536
)
3637
from executorch.exir.backend.utils import tag_constant_data, WhyNoPartitionReporter
3738
from executorch.exir.dialects._ops import ops as exir_ops
39+
from executorch.exir.graph_module import get_control_flow_submodules
3840
from torch.export.exported_program import ExportedProgram
3941
from torch.fx import GraphModule
4042
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
@@ -185,6 +187,7 @@ def _tag_module( # noqa
185187
module: GraphModule,
186188
containing_program: ExportedProgram,
187189
reporter: WhyNoPartitionReporter,
190+
tag_iterator: count | None = None,
188191
) -> set[str]:
189192
"""Tag nodes in a module, possibly a submodule, from the containing program.
190193
@@ -196,6 +199,17 @@ def _tag_module( # noqa
196199
A set of strings with the partition tags.
197200
"""
198201
tags: set[str] = set()
202+
if tag_iterator is None:
203+
tag_iterator = count(0)
204+
for _, submodule, _ in get_control_flow_submodules(module):
205+
submodule_tags = self._tag_module(
206+
submodule, containing_program, reporter, tag_iterator
207+
)
208+
if len(tags & submodule_tags) != 0:
209+
raise RuntimeError(
210+
"Got overlapping tags in two different modules, this shouldn't happen."
211+
)
212+
tags = tags | submodule_tags
199213
operator_support = tosa_support_factory(
200214
self.tosa_spec, containing_program, reporter, self.additional_checks
201215
)
@@ -207,7 +221,7 @@ def _tag_module( # noqa
207221
partition_list = capability_partitioner.propose_partitions()
208222

209223
for partition in partition_list:
210-
tag = f"tag{partition.id}"
224+
tag = f"tag{next(tag_iterator)}"
211225
tags.add(tag)
212226

213227
for node in partition.nodes:

0 commit comments

Comments
 (0)