From 3ffcd188bce3467701c935bb6cfc9a3c822711e7 Mon Sep 17 00:00:00 2001 From: Michael Adkins <michael@prefect.io> Date: Mon, 10 Oct 2022 12:42:35 -0500 Subject: [PATCH 1/3] Add `allow_failure` annotation to allow failed runs to be passed downstream --- src/prefect/__init__.py | 2 +- src/prefect/engine.py | 18 ++++- src/prefect/utilities/annotations.py | 31 +++++++- tests/test_tasks.py | 106 ++++++++++++++++++++++++++- 4 files changed, 149 insertions(+), 8 deletions(-) diff --git a/src/prefect/__init__.py b/src/prefect/__init__.py index 6aa1181c6265..e930f9b19fb2 100644 --- a/src/prefect/__init__.py +++ b/src/prefect/__init__.py @@ -28,7 +28,7 @@ from prefect.context import tags from prefect.client import get_client from prefect.manifests import Manifest -from prefect.utilities.annotations import unmapped +from prefect.utilities.annotations import unmapped, allow_failure # Import modules that register types import prefect.serializers diff --git a/src/prefect/engine.py b/src/prefect/engine.py index 469cb869cc45..7a3e343788e9 100644 --- a/src/prefect/engine.py +++ b/src/prefect/engine.py @@ -86,7 +86,7 @@ TaskConcurrencyType, ) from prefect.tasks import Task -from prefect.utilities.annotations import Quote, unmapped +from prefect.utilities.annotations import Quote, allow_failure, unmapped from prefect.utilities.asyncutils import ( gather, in_async_main_thread, @@ -825,6 +825,9 @@ async def collect_task_run_inputs(expr: Any, max_depth: int = -1) -> Set[TaskRun inputs = set() def add_futures_and_states_to_inputs(obj): + if isinstance(obj, allow_failure): + obj = obj.unwrap() + if isinstance(obj, PrefectFuture): run_async_from_worker_thread(obj._wait_for_submission) inputs.add(core.TaskRunResult(id=obj.task_run.id)) @@ -1386,6 +1389,11 @@ async def resolve_inputs( def resolve_input(expr): state = None + should_allow_failure = False + + if isinstance(expr, allow_failure): + expr = expr.unwrap() + should_allow_failure = True if isinstance(expr, Quote): return expr.unquote() @@ -1396,13 +1404,17 @@ def resolve_input(expr): else: return expr - if not state.is_completed(): + # Do not allow uncompleted upstreams except failures when `allow_failure` has + # been used + if not state.is_completed() and not ( + should_allow_failure and state.is_failed() + ): raise UpstreamTaskError( f"Upstream task run '{state.state_details.task_run_id}' did not reach a 'COMPLETED' state." ) # Only retrieve the result if requested as it may be expensive - return state.result() if return_data else None + return state.result(raise_on_failure=False) if return_data else None return await run_sync_in_worker_thread( visit_collection, diff --git a/src/prefect/utilities/annotations.py b/src/prefect/utilities/annotations.py index 05861a6988cd..eecc3970f9ab 100644 --- a/src/prefect/utilities/annotations.py +++ b/src/prefect/utilities/annotations.py @@ -4,14 +4,39 @@ class unmapped: - """A container that acts as an infinite array where all items are the same - value. Used for Task.map where there is a need to map over a single - value""" + """ + Wrapper for iterables. + + Indicates that this input should be sent as-is to all runs created during a mapping + operation instead of being split. + """ def __init__(self, value: Any): self.value = value def __getitem__(self, _) -> Any: + # Internally, this acts as an infinite array where all items are the same value + return self.value + + +class allow_failure: + """ + Wrapper for states or futures. + + Indicates that the upstream run for this input can be failed. + + Generally, Prefect will not allow a downstream run to start if any of its inputs + are failed. This annotation allows you to opt into receiving a failed input + downstream. + + If the input is from a failed run, the attached exception will be passed to your + function. + """ + + def __init__(self, value: Any): + self.value = value + + def unwrap(self) -> Any: return self.value diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 7f9e797ecd21..6313e18b5872 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -23,7 +23,7 @@ from prefect.orion.schemas.states import State, StateType from prefect.tasks import Task, task, task_input_hash from prefect.testing.utilities import exceptions_equal, flaky_on_windows -from prefect.utilities.annotations import unmapped +from prefect.utilities.annotations import allow_failure, unmapped from prefect.utilities.collections import quote @@ -331,6 +331,62 @@ def bar(): assert bar() == "bar" + def test_downstream_does_not_run_if_upstream_fails(self): + @task + def fails(): + raise ValueError("Fail task!") + + @task + def bar(y): + return y + + @flow + def test_flow(): + f = fails.submit() + b = bar.submit(f) + return b + + flow_state = test_flow(return_state=True) + task_state = flow_state.result(raise_on_failure=False) + assert task_state.is_pending() + assert task_state.name == "NotReady" + + def test_downstream_runs_if_upstream_succeeds(self): + @task + def foo(x): + return x + + @task + def bar(y): + return y + 1 + + @flow + def test_flow(): + f = foo.submit(1) + b = bar.submit(f) + return b + + assert test_flow() == 2 + + def test_downstream_receives_exception_if_upstream_fails_and_allow_failure(self): + @task + def fails(): + raise ValueError("Fail task!") + + @task + def bar(y): + return y + + @flow + def test_flow(): + f = fails.submit() + b = bar.submit(allow_failure(f)) + return b.result() + + result = test_flow() + assert isinstance(result, ValueError) + assert "Fail task!" in str(result) + class TestTaskStates: @pytest.mark.parametrize("error", [ValueError("Hello"), None]) @@ -1447,6 +1503,35 @@ def test_flow(): x=[TaskRunResult(id=upstream_state.state_details.task_run_id)], ) + async def test_task_inputs_populated_with_state_upstream_wrapped_with_allow_failure( + self, orion_client + ): + @task + def upstream(x): + return x + + @task + def downstream(x): + return x + + @flow + def test_flow(): + upstream_state = upstream(1, return_state=True) + downstream_state = downstream( + allow_failure(upstream_state), return_state=True + ) + return upstream_state, downstream_state + + upstream_state, downstream_state = test_flow() + + task_run = await orion_client.read_task_run( + downstream_state.state_details.task_run_id + ) + + assert task_run.task_inputs == dict( + x=[TaskRunResult(id=upstream_state.state_details.task_run_id)], + ) + @pytest.mark.parametrize("result", [["Fred"], {"one": 1}, {1, 2, 2}, (1, 2)]) async def test_task_inputs_populated_with_collection_result_upstream( self, result, orion_client, flow_with_upstream_downstream @@ -1697,6 +1782,25 @@ def test_using_wait_for_in_task_definition_raises_reserved(self): def foo(wait_for): pass + def test_downstream_runs_if_upstream_fails_with_allow_failure_annotation(self): + @task + def fails(): + raise ValueError("Fail task!") + + @task + def bar(y): + return y + + @flow + def test_flow(): + f = fails.submit() + b = bar(2, wait_for=[allow_failure(f)], return_state=True) + return b + + flow_state = test_flow(return_state=True) + task_state = flow_state.result(raise_on_failure=False) + assert task_state.result() == 2 + @pytest.mark.enable_orion_handler class TestTaskRunLogs: From ac5ba1fd78b8fa26020793076618bc96e199d051 Mon Sep 17 00:00:00 2001 From: Michael Adkins <michael@prefect.io> Date: Mon, 10 Oct 2022 12:58:55 -0500 Subject: [PATCH 2/3] Fix test --- tests/test_tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 6313e18b5872..dc1917e7668b 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -364,7 +364,7 @@ def bar(y): def test_flow(): f = foo.submit(1) b = bar.submit(f) - return b + return b.result() assert test_flow() == 2 From 8aa1d361d2bc0cc1a5ab34d550a24f07b74ea51b Mon Sep 17 00:00:00 2001 From: Michael Adkins <michael@prefect.io> Date: Mon, 10 Oct 2022 16:41:14 -0500 Subject: [PATCH 3/3] Add test coverage for `allow_failure` and subflows --- tests/test_flows.py | 57 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/tests/test_flows.py b/tests/test_flows.py index 4e9e0127ae2b..e8bed44336fe 100644 --- a/tests/test_flows.py +++ b/tests/test_flows.py @@ -33,6 +33,7 @@ flaky_on_windows, get_most_recent_flow_run, ) +from prefect.utilities.annotations import allow_failure, quote from prefect.utilities.collections import flatdict_to_dict from prefect.utilities.hashing import file_hash @@ -1074,6 +1075,62 @@ def parent_flow(): x=[TaskRunResult(id=task_state.state_details.task_run_id)], ) + async def test_subflow_with_one_upstream_task_future_and_allow_failure( + self, orion_client + ): + @task + def child_task(): + raise ValueError() + + @flow + def child_flow(x): + return x + + @flow + def parent_flow(): + future = child_task.submit() + flow_state = child_flow(x=allow_failure(future), return_state=True) + return quote((future.wait(), flow_state)) + + task_state, flow_state = parent_flow().unquote() + assert isinstance(flow_state.result(), ValueError) + flow_tracking_task_run = await orion_client.read_task_run( + flow_state.state_details.task_run_id + ) + + assert task_state.is_failed() + assert flow_tracking_task_run.task_inputs == dict( + x=[TaskRunResult(id=task_state.state_details.task_run_id)], + ) + + async def test_subflow_with_one_upstream_task_state_and_allow_failure( + self, orion_client + ): + @task + def child_task(): + raise ValueError() + + @flow + def child_flow(x): + return x + + @flow + def parent_flow(): + task_state = child_task(return_state=True) + flow_state = child_flow(x=allow_failure(task_state), return_state=True) + return quote((task_state, flow_state)) + + task_state, flow_state = parent_flow().unquote() + assert isinstance(flow_state.result(), ValueError) + flow_tracking_task_run = await orion_client.read_task_run( + flow_state.state_details.task_run_id + ) + + assert task_state.is_failed() + assert flow_tracking_task_run.task_inputs == dict( + x=[TaskRunResult(id=task_state.state_details.task_run_id)], + ) + async def test_subflow_with_no_upstream_tasks(self, orion_client): @flow def bar(x, y):