@@ -569,39 +569,39 @@ def tasks(self) -> ty.Generator["Task[DefType]", None, None]:
569
569
self ._tasks = {t .state_index : t for t in self ._generate_tasks ()}
570
570
return self ._tasks .values ()
571
571
572
- def get_jobs (
573
- self , index : int | None = None , as_array : bool = False
574
- ) -> "Task | StateArray[Task]" :
572
+ def get_jobs (self , final_index : int | None = None ) -> "Task | StateArray[Task]" :
575
573
"""Get the jobs that match a given state index.
576
574
577
575
Parameters
578
576
----------
579
- index : int, optional
580
- The index of the state of the task to get, by default None
581
- as_array : bool, optional
582
- Whether to return the tasks in a state-array object, by default if the index
583
- matches
577
+ final_index : int, optional
578
+ The index of the output state array (i.e. after any combinations) of the
579
+ job to get, by default None
584
580
585
581
Returns
586
582
-------
587
583
matching : Task | StateArray[Task]
588
584
The task or tasks that match the given index
589
585
"""
590
- matching = StateArray ()
591
- if self .tasks :
592
- try :
593
- task = self ._tasks [index ]
594
- except KeyError :
595
- if index is None :
596
- return StateArray (self ._tasks .values ())
597
- # Select matching tasks and return them in nested state-array objects
598
- for ind , task in self ._tasks .items ():
599
- matching .append (task )
600
- else :
601
- if not as_array :
602
- return task
603
- matching .append (task )
604
- return matching
586
+ if not self .tasks : # No jobs, return empty state array
587
+ return StateArray ()
588
+ if not self .node .state : # Return the singular job
589
+ assert final_index is None
590
+ task = self ._tasks [None ]
591
+ return task
592
+ if final_index is None : # return all jobs in a state array
593
+ return StateArray (self ._tasks .values ())
594
+ if not self .node .state .combiner : # Select the job that matches the index
595
+ task = self ._tasks [final_index ]
596
+ return task
597
+ # Get a slice of the tasks that match the given index of the state array of the
598
+ # combined values
599
+ final_index = set (self .node .state .states_ind_final [final_index ].items ())
600
+ return StateArray (
601
+ self ._tasks [i ]
602
+ for i , ind in enumerate (self .node .state .states_ind )
603
+ if set (ind .items ()).issuperset (final_index )
604
+ )
605
605
606
606
@property
607
607
def started (self ) -> bool :
@@ -762,9 +762,23 @@ def get_runnable_tasks(self, graph: DiGraph) -> list["Task[DefType]"]:
762
762
for index , task in list (self .blocked .items ()):
763
763
pred : NodeExecution
764
764
is_runnable = True
765
+ states_ind = (
766
+ list (self .node .state .states_ind [index ].items ())
767
+ if self .node .state
768
+ else []
769
+ )
765
770
for pred in graph .predecessors [self .node .name ]:
766
- pred_jobs : StateArray [Task ] = pred .get_jobs (index , as_array = True )
767
- pred_inds = [j .state_index for j in pred_jobs ]
771
+ if pred .node .state :
772
+ pred_states_ind = {
773
+ (k , i ) for k , i in states_ind if k .startswith (pred .name + "." )
774
+ }
775
+ pred_inds = [
776
+ i
777
+ for i , ind in enumerate (pred .node .state .states_ind )
778
+ if set (ind .items ()).issuperset (pred_states_ind )
779
+ ]
780
+ else :
781
+ pred_inds = [None ]
768
782
if not all (i in pred .successful for i in pred_inds ):
769
783
is_runnable = False
770
784
blocked = True
0 commit comments