From d7069d90002a592b62e2f2d0200dd6843aea5805 Mon Sep 17 00:00:00 2001 From: ahuang11 Date: Wed, 4 Jan 2023 13:48:55 -0800 Subject: [PATCH 1/5] Add test for allow_failure and quote --- .../standard_test_suites/task_runners.py | 47 ++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/src/prefect/testing/standard_test_suites/task_runners.py b/src/prefect/testing/standard_test_suites/task_runners.py index e346e7fefa9a..579b39bde550 100644 --- a/src/prefect/testing/standard_test_suites/task_runners.py +++ b/src/prefect/testing/standard_test_suites/task_runners.py @@ -12,10 +12,12 @@ from prefect import flow, task from prefect.client.schemas import TaskRun from prefect.deprecated.data_documents import DataDocument +from prefect.logging import get_run_logger from prefect.orion.schemas.states import StateType -from prefect.states import State +from prefect.states import Crashed, State from prefect.task_runners import BaseTaskRunner, TaskConcurrencyType from prefect.testing.utilities import exceptions_equal +from prefect.utilities.annotations import allow_failure, quote class TaskRunnerStandardTestSuite(ABC): @@ -654,3 +656,46 @@ async def test_flow(): await test_flow() assert tmp_file.read_text() == "foo" + + def test_allow_failure(self, task_runner, caplog): + @task + def failing_task(): + raise ValueError("This is expected") + + @task + def depdendent_task(): + logger = get_run_logger() + logger.info("Dependent task still runs!") + return 1 + + @task + def another_dependent_task(): + logger = get_run_logger() + logger.info("Sub-dependent task still runs!") + return 1 + + @flow(task_runner=task_runner) + def test_flow(): + ft = failing_task.submit() + dt = depdendent_task.submit(wait_for=[allow_failure(ft)]) + adt = another_dependent_task.submit(wait_for=[dt]) + + with pytest.raises(ValueError, match="This is expected"): + test_flow() + assert len(caplog.records) == 2 + assert caplog.records[0].msg == "Dependent task still runs!" + assert caplog.records[1].msg == "Sub-dependent task still runs!" + + def test_passing_quoted_state(self, task_runner): + @task + def test_task(): + state = Crashed() + return quote(state) + + @flow(task_runner=task_runner) + def test_flow(): + return test_task() + + result = test_flow() + assert isinstance(result, quote) + assert isinstance(result.unquote(), State) From 31fac5081d5a874b1bcd526a70fd4caadd1a3627 Mon Sep 17 00:00:00 2001 From: Michael Adkins Date: Fri, 6 Jan 2023 13:32:46 -0600 Subject: [PATCH 2/5] Expand check --- src/prefect/states.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/prefect/states.py b/src/prefect/states.py index 13ffa5ebbe51..50d90bf0e707 100644 --- a/src/prefect/states.py +++ b/src/prefect/states.py @@ -27,6 +27,7 @@ from prefect.orion.schemas.states import StateDetails, StateType from prefect.results import BaseResult, R, ResultFactory from prefect.settings import PREFECT_ASYNC_FETCH_STATE_RESULT +from prefect.utilities.annotations import BaseAnnotation from prefect.utilities.asyncutils import in_async_main_thread, sync_compatible from prefect.utilities.collections import ensure_iterable @@ -391,8 +392,11 @@ def is_state_iterable(obj: Any) -> TypeGuard[Iterable[State]]: """ # We do not check for arbitary iterables because this is not intended to be used # for things like dictionaries, dataframes, or pydantic models - - if isinstance(obj, (list, set, tuple)) and obj: + if ( + not isinstance(obj, BaseAnnotation) + and isinstance(obj, (list, set, tuple)) + and obj + ): return all([is_state(o) for o in obj]) else: return False From dff9ccc53fc3efde5a7eaf844ee94e1d6e8faf54 Mon Sep 17 00:00:00 2001 From: ahuang11 Date: Thu, 5 Jan 2023 17:36:36 -0800 Subject: [PATCH 3/5] Fix recursion error --- src/prefect/utilities/annotations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/prefect/utilities/annotations.py b/src/prefect/utilities/annotations.py index 8890821ab53e..1577b179d00a 100644 --- a/src/prefect/utilities/annotations.py +++ b/src/prefect/utilities/annotations.py @@ -31,7 +31,7 @@ class unmapped(BaseAnnotation[T]): def __getitem__(self, _) -> T: # Internally, this acts as an infinite array where all items are the same value - return self.unwrap() + return super().__getitem__(_) class allow_failure(BaseAnnotation[T]): From f68c720f1e0540d3ec324b5d8257f2b77e5d0b45 Mon Sep 17 00:00:00 2001 From: ahuang11 Date: Thu, 5 Jan 2023 17:49:20 -0800 Subject: [PATCH 4/5] Fix recursion error and expected res --- src/prefect/utilities/annotations.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/prefect/utilities/annotations.py b/src/prefect/utilities/annotations.py index 1577b179d00a..4c615913556f 100644 --- a/src/prefect/utilities/annotations.py +++ b/src/prefect/utilities/annotations.py @@ -13,7 +13,10 @@ def __init__(self, value: T): self.value = value def unwrap(self) -> T: - return self.value + # cannot simply return self.value due to recursion error in Python 3.7 + # also _asdict does not follow convention; it's not an internal method + # https://stackoverflow.com/a/26180604 + return self._asdict()["value"] def __eq__(self, other: object) -> bool: if not type(self) == type(other): @@ -31,7 +34,7 @@ class unmapped(BaseAnnotation[T]): def __getitem__(self, _) -> T: # Internally, this acts as an infinite array where all items are the same value - return super().__getitem__(_) + return self.unwrap() class allow_failure(BaseAnnotation[T]): From d2fa2e518aad9f712f10caf3df56c4c753a14be3 Mon Sep 17 00:00:00 2001 From: Andrew <15331990+ahuang11@users.noreply.github.com> Date: Fri, 6 Jan 2023 11:56:35 -0800 Subject: [PATCH 5/5] Update src/prefect/testing/standard_test_suites/task_runners.py --- src/prefect/testing/standard_test_suites/task_runners.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/prefect/testing/standard_test_suites/task_runners.py b/src/prefect/testing/standard_test_suites/task_runners.py index 579b39bde550..06442be2c4ca 100644 --- a/src/prefect/testing/standard_test_suites/task_runners.py +++ b/src/prefect/testing/standard_test_suites/task_runners.py @@ -678,7 +678,7 @@ def another_dependent_task(): def test_flow(): ft = failing_task.submit() dt = depdendent_task.submit(wait_for=[allow_failure(ft)]) - adt = another_dependent_task.submit(wait_for=[dt]) + another_dependent_task.submit(wait_for=[dt]) with pytest.raises(ValueError, match="This is expected"): test_flow()