1515"""
1616
1717import logging
18+ from itertools import count
1819from typing import Callable , List , Optional , Sequence , Tuple
1920
2021import torch
3536)
3637from executorch .exir .backend .utils import tag_constant_data , WhyNoPartitionReporter
3738from executorch .exir .dialects ._ops import ops as exir_ops
39+ from executorch .exir .graph_module import get_control_flow_submodules
3840from torch .export .exported_program import ExportedProgram
3941from torch .fx import GraphModule
4042from torch .fx .passes .infra .partitioner import CapabilityBasedPartitioner , Partition
@@ -185,6 +187,7 @@ def _tag_module( # noqa
185187 module : GraphModule ,
186188 containing_program : ExportedProgram ,
187189 reporter : WhyNoPartitionReporter ,
190+ tag_iterator : count | None = None ,
188191 ) -> set [str ]:
189192 """Tag nodes in a module, possibly a submodule, from the containing program.
190193
@@ -196,6 +199,17 @@ def _tag_module( # noqa
196199 A set of strings with the partition tags.
197200 """
198201 tags : set [str ] = set ()
202+ if tag_iterator is None :
203+ tag_iterator = count (0 )
204+ for _ , submodule , _ in get_control_flow_submodules (module ):
205+ submodule_tags = self ._tag_module (
206+ submodule , containing_program , reporter , tag_iterator
207+ )
208+ if len (tags & submodule_tags ) != 0 :
209+ raise RuntimeError (
210+ "Got overlapping tags in two different modules, this shouldn't happen."
211+ )
212+ tags = tags | submodule_tags
199213 operator_support = tosa_support_factory (
200214 self .tosa_spec , containing_program , reporter , self .additional_checks
201215 )
@@ -207,7 +221,7 @@ def _tag_module( # noqa
207221 partition_list = capability_partitioner .propose_partitions ()
208222
209223 for partition in partition_list :
210- tag = f"tag{ partition . id } "
224+ tag = f"tag{ next ( tag_iterator ) } "
211225 tags .add (tag )
212226
213227 for node in partition .nodes :
0 commit comments