diff --git a/metaflow/task.py b/metaflow/task.py index e74302e2f2c..2a9e60ea5a2 100644 --- a/metaflow/task.py +++ b/metaflow/task.py @@ -180,6 +180,10 @@ def _init_foreach(self, step_name, join_type, inputs, split_index): # 2) join - pop the topmost frame from the stack # 3) step following a split - push a new frame in the stack + # We have a non-modifying case (case 4)) where we propagate the + # foreach-stack information to all tasks in the foreach. This is + # then used later to write the foreach-stack metadata for that task + # case 1) - reset the stack if step_name == "start": self.flow._foreach_stack = [] @@ -264,6 +268,9 @@ def lineage(): stack = inputs[0]["_foreach_stack"] stack.append(frame) self.flow._foreach_stack = stack + # case 4) - propagate in the foreach nest + elif "_foreach_stack" in inputs[0]: + self.flow._foreach_stack = inputs[0]["_foreach_stack"] def _clone_flow(self, datastore): x = self.flow.__class__(use_cli=False)