Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add allow_failure annotation to allow failed runs to be passed downstream #7120

Merged
merged 3 commits into from
Oct 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/prefect/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from prefect.context import tags
from prefect.client import get_client
from prefect.manifests import Manifest
from prefect.utilities.annotations import unmapped
from prefect.utilities.annotations import unmapped, allow_failure

# Import modules that register types
import prefect.serializers
Expand Down
18 changes: 15 additions & 3 deletions src/prefect/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@
TaskConcurrencyType,
)
from prefect.tasks import Task
from prefect.utilities.annotations import Quote, unmapped
from prefect.utilities.annotations import Quote, allow_failure, unmapped
from prefect.utilities.asyncutils import (
gather,
in_async_main_thread,
Expand Down Expand Up @@ -825,6 +825,9 @@ async def collect_task_run_inputs(expr: Any, max_depth: int = -1) -> Set[TaskRun
inputs = set()

def add_futures_and_states_to_inputs(obj):
if isinstance(obj, allow_failure):
obj = obj.unwrap()

if isinstance(obj, PrefectFuture):
run_async_from_worker_thread(obj._wait_for_submission)
inputs.add(core.TaskRunResult(id=obj.task_run.id))
Expand Down Expand Up @@ -1386,6 +1389,11 @@ async def resolve_inputs(

def resolve_input(expr):
state = None
should_allow_failure = False

if isinstance(expr, allow_failure):
expr = expr.unwrap()
should_allow_failure = True

if isinstance(expr, Quote):
return expr.unquote()
Expand All @@ -1396,13 +1404,17 @@ def resolve_input(expr):
else:
return expr

if not state.is_completed():
# Do not allow uncompleted upstreams except failures when `allow_failure` has
# been used
if not state.is_completed() and not (
should_allow_failure and state.is_failed()
):
raise UpstreamTaskError(
f"Upstream task run '{state.state_details.task_run_id}' did not reach a 'COMPLETED' state."
)

# Only retrieve the result if requested as it may be expensive
return state.result() if return_data else None
return state.result(raise_on_failure=False) if return_data else None

return await run_sync_in_worker_thread(
visit_collection,
Expand Down
31 changes: 28 additions & 3 deletions src/prefect/utilities/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,39 @@


class unmapped:
"""A container that acts as an infinite array where all items are the same
value. Used for Task.map where there is a need to map over a single
value"""
"""
Wrapper for iterables.

Indicates that this input should be sent as-is to all runs created during a mapping
operation instead of being split.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for updating this.

"""

def __init__(self, value: Any):
self.value = value

def __getitem__(self, _) -> Any:
# Internally, this acts as an infinite array where all items are the same value
return self.value


class allow_failure:
    """
    Wrapper for states or futures.

    Indicates that the upstream run for this input can be failed.

    Generally, Prefect will not allow a downstream run to start if any of its inputs
    are failed. This annotation allows you to opt into receiving a failed input
    downstream.

    If the input is from a failed run, the attached exception will be passed to your
    function.
    """

    def __init__(self, value: Any):
        # The wrapped object is stored untouched; consumers retrieve it with `unwrap`
        self.value = value

    def unwrap(self) -> Any:
        """Return the wrapped value exactly as it was passed to the constructor."""
        return self.value

    def __repr__(self) -> str:
        # Show the wrapped value so annotated inputs are identifiable in logs and
        # error messages instead of appearing as an opaque object address
        return f"{type(self).__name__}({self.value!r})"


Expand Down
57 changes: 57 additions & 0 deletions tests/test_flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
flaky_on_windows,
get_most_recent_flow_run,
)
from prefect.utilities.annotations import allow_failure, quote
from prefect.utilities.collections import flatdict_to_dict
from prefect.utilities.hashing import file_hash

Expand Down Expand Up @@ -1074,6 +1075,62 @@ def parent_flow():
x=[TaskRunResult(id=task_state.state_details.task_run_id)],
)

async def test_subflow_with_one_upstream_task_future_and_allow_failure(
    self, orion_client
):
    """
    Pass a failed task's future, wrapped in `allow_failure`, into a subflow.

    Checks that the subflow's result is the upstream `ValueError` and that the
    subflow's tracking task run lists the failed task run as its `x` input.
    """

    @task
    def child_task():
        raise ValueError()

    @flow
    def child_flow(x):
        return x

    @flow
    def parent_flow():
        upstream_future = child_task.submit()
        subflow_state = child_flow(
            x=allow_failure(upstream_future), return_state=True
        )
        # Both states are handed back inside `quote` and recovered with `.unquote()`
        return quote((upstream_future.wait(), subflow_state))

    task_state, flow_state = parent_flow().unquote()
    assert isinstance(flow_state.result(), ValueError)

    tracking_run_id = flow_state.state_details.task_run_id
    flow_tracking_task_run = await orion_client.read_task_run(tracking_run_id)

    assert task_state.is_failed()
    assert flow_tracking_task_run.task_inputs == {
        "x": [TaskRunResult(id=task_state.state_details.task_run_id)],
    }

async def test_subflow_with_one_upstream_task_state_and_allow_failure(
    self, orion_client
):
    """
    Pass a failed task's state, wrapped in `allow_failure`, into a subflow.

    Checks that the subflow's result is the upstream `ValueError` and that the
    subflow's tracking task run lists the failed task run as its `x` input.
    """

    @task
    def child_task():
        raise ValueError()

    @flow
    def child_flow(x):
        return x

    @flow
    def parent_flow():
        upstream_state = child_task(return_state=True)
        subflow_state = child_flow(
            x=allow_failure(upstream_state), return_state=True
        )
        # Both states are handed back inside `quote` and recovered with `.unquote()`
        return quote((upstream_state, subflow_state))

    task_state, flow_state = parent_flow().unquote()
    assert isinstance(flow_state.result(), ValueError)

    tracking_run_id = flow_state.state_details.task_run_id
    flow_tracking_task_run = await orion_client.read_task_run(tracking_run_id)

    assert task_state.is_failed()
    assert flow_tracking_task_run.task_inputs == {
        "x": [TaskRunResult(id=task_state.state_details.task_run_id)],
    }

async def test_subflow_with_no_upstream_tasks(self, orion_client):
@flow
def bar(x, y):
Expand Down
106 changes: 105 additions & 1 deletion tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from prefect.orion.schemas.states import State, StateType
from prefect.tasks import Task, task, task_input_hash
from prefect.testing.utilities import exceptions_equal, flaky_on_windows
from prefect.utilities.annotations import unmapped
from prefect.utilities.annotations import allow_failure, unmapped
from prefect.utilities.collections import quote


Expand Down Expand Up @@ -331,6 +331,62 @@ def bar():

assert bar() == "bar"

def test_downstream_does_not_run_if_upstream_fails(self):
    """
    Without `allow_failure`, a task given a failed upstream future is expected to
    be left in a pending state named "NotReady".
    """

    @task
    def fails():
        raise ValueError("Fail task!")

    @task
    def bar(y):
        return y

    @flow
    def test_flow():
        upstream_future = fails.submit()
        downstream_future = bar.submit(upstream_future)
        return downstream_future

    flow_state = test_flow(return_state=True)
    # The flow itself fails, so fetch its result without re-raising
    task_state = flow_state.result(raise_on_failure=False)

    assert task_state.is_pending()
    assert task_state.name == "NotReady"

def test_downstream_runs_if_upstream_succeeds(self):
    """A downstream task receives the value produced by a successful upstream."""

    @task
    def foo(x):
        return x

    @task
    def bar(y):
        return y + 1

    @flow
    def test_flow():
        upstream_future = foo.submit(1)
        downstream_future = bar.submit(upstream_future)
        return downstream_future.result()

    # foo passes 1 through and bar adds one to it
    flow_result = test_flow()
    assert flow_result == 2

def test_downstream_receives_exception_if_upstream_fails_and_allow_failure(self):
    """
    With `allow_failure`, the downstream task runs and is handed the upstream
    task's exception as its argument.
    """

    @task
    def fails():
        raise ValueError("Fail task!")

    @task
    def bar(y):
        return y

    @flow
    def test_flow():
        upstream_future = fails.submit()
        downstream_future = bar.submit(allow_failure(upstream_future))
        return downstream_future.result()

    # `bar` returns its input unchanged, so the flow result is the exception itself
    result = test_flow()
    assert isinstance(result, ValueError)
    assert "Fail task!" in str(result)


class TestTaskStates:
@pytest.mark.parametrize("error", [ValueError("Hello"), None])
Expand Down Expand Up @@ -1447,6 +1503,35 @@ def test_flow():
x=[TaskRunResult(id=upstream_state.state_details.task_run_id)],
)

async def test_task_inputs_populated_with_state_upstream_wrapped_with_allow_failure(
    self, orion_client
):
    """
    An upstream state wrapped in `allow_failure` is still recorded as a task input
    of the downstream task run.
    """

    @task
    def upstream(x):
        return x

    @task
    def downstream(x):
        return x

    @flow
    def test_flow():
        first_state = upstream(1, return_state=True)
        second_state = downstream(allow_failure(first_state), return_state=True)
        return first_state, second_state

    upstream_state, downstream_state = test_flow()

    downstream_run_id = downstream_state.state_details.task_run_id
    task_run = await orion_client.read_task_run(downstream_run_id)

    assert task_run.task_inputs == {
        "x": [TaskRunResult(id=upstream_state.state_details.task_run_id)],
    }

@pytest.mark.parametrize("result", [["Fred"], {"one": 1}, {1, 2, 2}, (1, 2)])
async def test_task_inputs_populated_with_collection_result_upstream(
self, result, orion_client, flow_with_upstream_downstream
Expand Down Expand Up @@ -1697,6 +1782,25 @@ def test_using_wait_for_in_task_definition_raises_reserved(self):
def foo(wait_for):
pass

def test_downstream_runs_if_upstream_fails_with_allow_failure_annotation(self):
    """
    A task whose `wait_for` dependency failed still runs when that dependency is
    wrapped in `allow_failure`.
    """

    @task
    def fails():
        raise ValueError("Fail task!")

    @task
    def bar(y):
        return y

    @flow
    def test_flow():
        failed_future = fails.submit()
        downstream_state = bar(
            2, wait_for=[allow_failure(failed_future)], return_state=True
        )
        return downstream_state

    flow_state = test_flow(return_state=True)
    # Fetch the returned task state without re-raising on a failed flow state
    task_state = flow_state.result(raise_on_failure=False)
    assert task_state.result() == 2


@pytest.mark.enable_orion_handler
class TestTaskRunLogs:
Expand Down