@@ -123,7 +123,7 @@ def _can_merge_partitions(self, p1, p2, partitions_by_id):
123123
124124 return True
125125
126- def _process_node_groups (
126+ def _process_all_nodes (
127127 self ,
128128 new_partition_id ,
129129 partitions_by_id ,
@@ -133,97 +133,60 @@ def _process_node_groups(
133133 partition_users ,
134134 partition_map ,
135135 ):
136- """Process nodes in predefined groups."""
137- group_to_partition_id = {}
138-
139- if not self .node_groups :
140- return group_to_partition_id
141-
142- processed_nodes = set ()
143-
144- # We have to create the partitions in reverse topological order
145- # so we find the groups as we traverse backwards in the graph
146- # this likely needs to be combined with the process_remaining_nodes
147- # TODO: this currently doesn't work with _process_remaining_nodes so
148- # if a user provides grouped nodes with operatorsupport, then this will
149- # faile
136+ """Process nodes into a partition."""
150137 for node in reversed (self .graph_module .graph .nodes ):
151- if node not in self .node_to_group :
138+ if node in assignment or not self ._is_node_supported ( node ) :
152139 continue
153140
154- if node in processed_nodes :
155- continue
141+ if node in self .all_nodes_in_groups :
142+ group_idx = self .node_to_group [node ]
143+ group = self .node_groups [group_idx ]
156144
157- group_idx = self .node_to_group [node ]
158- group = self .node_groups [group_idx ]
159-
160- # Create a partition for group
161- partition_id = next (new_partition_id )
162- partition = Partition (id = partition_id , nodes = set ())
163- partitions_by_id [partition_id ] = partition
164- partitions_order [partition_id ] = partition_id
165- group_to_partition_id [group_idx ] = partition_id
166-
167- # Add all supported nodes from the group to the partition
168- for node in group :
169- if self ._is_node_supported (node ):
170- partition .add_node (node )
171- assignment [node ] = partition_id
172- nodes_order [node ] = partition_id
173-
174- # Set partition users
175- partition_users [partition_id ] = {
176- user
177- for node in partition .nodes
178- for user in node .users
179- if user not in partition .nodes
180- }
181-
182- # Update partition map
183- for node in partition .nodes :
145+ # Create a partition for group
146+ partition_id = next (new_partition_id )
147+ partition = Partition (id = partition_id , nodes = set ())
148+ partitions_by_id [partition_id ] = partition
149+ partitions_order [partition_id ] = partition_id
150+
151+ # Add all supported nodes from the group to the partition
152+ for node in group :
153+ if self ._is_node_supported (node ):
154+ partition .add_node (node )
155+ assignment [node ] = partition_id
156+ nodes_order [node ] = partition_id
157+
158+ # Set partition users
159+ partition_users [partition_id ] = {
160+ user
161+ for node in partition .nodes
162+ for user in node .users
163+ if user not in partition .nodes
164+ }
165+
166+ # Update partition map
167+ for node in partition .nodes :
168+ for user in node .users :
169+ target_id = assignment .get (user , None )
170+ if target_id is not None and target_id != partition_id :
171+ partition_map [partition_id ].add (target_id )
172+ partition_map [partition_id ].update (partition_map [target_id ])
173+ else :
174+ partition_id = next (new_partition_id )
175+ nodes_order [node ] = partition_id
176+ partitions_order [partition_id ] = partition_id
177+ partitions_by_id [partition_id ] = Partition (
178+ id = partition_id , nodes = [node ]
179+ )
180+ assignment [node ] = partition_id
181+ partition_users [partition_id ] = set (node .users )
182+
183+ # Update partition map
184184 for user in node .users :
185185 target_id = assignment .get (user )
186- if target_id is not None and target_id != partition_id :
186+ if target_id is not None :
187187 partition_map [partition_id ].add (target_id )
188188 partition_map [partition_id ].update (partition_map [target_id ])
189189
190- # all the nodes in the group have now been processed
191- # so skip if we encoutner them again in our rev topo
192- # iteration
193- for node in group :
194- processed_nodes .add (node )
195-
196- return group_to_partition_id
197-
198- def _process_remaining_nodes (
199- self ,
200- new_partition_id ,
201- partitions_by_id ,
202- assignment ,
203- nodes_order ,
204- partitions_order ,
205- partition_users ,
206- partition_map ,
207- ):
208- """Process nodes not in any predefined group."""
209- for node in reversed (self .graph_module .graph .nodes ):
210- if node in assignment or not self ._is_node_supported (node ):
211- continue
212-
213- partition_id = next (new_partition_id )
214- nodes_order [node ] = partition_id
215- partitions_order [partition_id ] = partition_id
216- partitions_by_id [partition_id ] = Partition (id = partition_id , nodes = [node ])
217- assignment [node ] = partition_id
218- partition_users [partition_id ] = set (node .users )
219-
220- # Update partition map
221- for user in node .users :
222- target_id = assignment .get (user )
223- if target_id is not None :
224- partition_map [partition_id ].add (target_id )
225- partition_map [partition_id ].update (partition_map [target_id ])
226-
227190 def _merge_partitions (
228191 self ,
229192 partitions_by_id ,
@@ -378,19 +341,8 @@ def propose_partitions(self) -> list[Partition]:
378341 partition_users = {} # Maps partition IDs to partition users
379342 new_partition_id = itertools .count ()
380343
381- # Process nodes in predefined groups
382- self ._process_node_groups (
383- new_partition_id ,
384- partitions_by_id ,
385- assignment ,
386- nodes_order ,
387- partitions_order ,
388- partition_users ,
389- partition_map ,
390- )
391-
392- # Process remaining nodes
393- self ._process_remaining_nodes (
344+ # Process all nodes into partitions
345+ self ._process_all_nodes (
394346 new_partition_id ,
395347 partitions_by_id ,
396348 assignment ,
0 commit comments