Skip to content

Commit

Permalink
Add option to eval workflow condition as workflow only.
Browse files Browse the repository at this point in the history
  • Loading branch information
riga committed Apr 25, 2024
1 parent e22d2bd commit bb7b363
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 9 deletions.
2 changes: 0 additions & 2 deletions law/task/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,6 @@ def input(self):
if self.cache_requirements:
if self._cached_requirements is no_value:
self._cached_requirements = self.requires()
else:
print("BASETASK.INPUT() TAKING REQS FROM CACHE BITCHES")
reqs = self._cached_requirements
else:
reqs = self.requires()
Expand Down
28 changes: 21 additions & 7 deletions law/workflow/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,8 @@ def dynamic_workflow_condition(
create_branch_map_fn=None,
requires_fn=None,
output_fn=None,
condition_as_workflow=False,
cache_met_condition=True,
):
"""
Decorator factory that is meant to wrap a workflow methods that defines a dynamic workflow
Expand All @@ -255,6 +257,8 @@ def decorator(condition_fn):
create_branch_map_fn=create_branch_map_fn,
requires_fn=requires_fn,
output_fn=output_fn,
condition_as_workflow=condition_as_workflow,
cache_met_condition=cache_met_condition,
)

return decorator if condition_fn is None else decorator(condition_fn)
Expand Down Expand Up @@ -319,6 +323,11 @@ def run(self):
As a consequence, the amended workflow is fully dynamic with its exact shape potentially
depending heavily on conditions that are only known at runtime.
Internally, the condition is evaluated by the calling task which is usually a workflow, but it
can also be one of its branch tasks if, for instance, sandboxing is involved. Set
*condition_as_workflow* to *True* to ensure that the condition is always evaluated by the
workflow itself.
In case the ``workflow_condition`` involves a costly computation, it is recommended to cache
evluation of the condition by setting *cache_met_condition* argument to *True* or a string
denoting the task instance attribute where the met condition is stored. In the first case,
Expand All @@ -333,15 +342,17 @@ def __init__(
create_branch_map_fn=None,
requires_fn=None,
output_fn=None,
condition_as_workflow=False,
cache_met_condition=True,
):
super().__init__()
super(DynamicWorkflowCondition, self).__init__()

# attributes
self._condition_fn = condition_fn
self._create_branch_map_fn = create_branch_map_fn
self._requires_fn = requires_fn
self._output_fn = output_fn
self.condition_as_workflow = condition_as_workflow
self.cache_met_condition = cache_met_condition
if self.cache_met_condition and not isinstance(cache_met_condition, str):
self.cache_met_condition = "_dynamic_workflow_condition_met"
Expand All @@ -357,7 +368,8 @@ def condition(inst, *args, **kwargs):
return getattr(inst, self.cache_met_condition)

# evaluate the condition
is_met = self._condition_fn(inst.as_workflow(), *args, **kwargs)
task = inst.as_workflow() if self.condition_as_workflow else inst
is_met = self._condition_fn(task, *args, **kwargs)

# write to cache if requested
if self.cache_met_condition and is_met:
Expand Down Expand Up @@ -680,6 +692,9 @@ def __new__(cls, *args, **kwargs):
bound_condition_fn = condition._wrap_condition_fn().__get__(inst)
setattr(inst, condition_attr, bound_condition_fn)

# store the condition object itself
setattr(inst, condition_attr + "_obj", condition)

# bind wrapped methods that currently correspond to placeholders
for attr, wrapper in condition._iter_wrappers(bound_condition_fn):
if getattr(inst, attr, None) != DynamicWorkflowCondition._decorator_result:
Expand Down Expand Up @@ -827,7 +842,7 @@ def get_branch_value(branch, branch_data, key):
# branch should not be set
if branch != -1:
raise ValueError(
"workflow parameters {} will lead to {} being a workflow, but branch "
"workflow parameters {} will lead to {} being a workflow, but branch "
"{} requested".format(wparams_repr(), cls.__name__, branch),
)

Expand Down Expand Up @@ -1355,12 +1370,11 @@ def workflow_requires(self):

def workflow_input(self):
"""
Returns the output targets if all workflow requirements, comparable to the normal
``input()`` method of plain tasks. When this method is called from a branch task, an
exception is raised.
Returns the output targets of all workflow requirements, comparable to the normal
``input()`` method of plain tasks.
"""
if self.is_branch():
raise Exception("calls to workflow_input are forbidden for branch tasks")
return self.as_workflow().workflow_input()

# get potentially cached workflow requirements
if self.cache_workflow_requirements:
Expand Down

0 comments on commit bb7b363

Please sign in to comment.