Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend visit_collection with support for annotations #7263

Merged
merged 15 commits into from
Jan 6, 2023
Merged
22 changes: 9 additions & 13 deletions src/prefect/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@
TaskConcurrencyType,
)
from prefect.tasks import Task
from prefect.utilities.annotations import Quote, allow_failure, unmapped
from prefect.utilities.annotations import allow_failure, quote, unmapped
from prefect.utilities.asyncutils import (
gather,
in_async_main_thread,
Expand Down Expand Up @@ -1015,9 +1015,6 @@ async def collect_task_run_inputs(expr: Any, max_depth: int = -1) -> Set[TaskRun
inputs = set()

def add_futures_and_states_to_inputs(obj):
if isinstance(obj, allow_failure):
obj = obj.unwrap()

if isinstance(obj, PrefectFuture):
run_async_from_worker_thread(obj._wait_for_submission)
inputs.add(TaskRunResult(id=obj.task_run.id))
Expand Down Expand Up @@ -1639,17 +1636,14 @@ async def resolve_inputs(
UpstreamTaskError: If any of the upstream states are not `COMPLETED`
"""

def resolve_input(expr):
def resolve_input(expr, context):
state = None
should_allow_failure = False

if isinstance(expr, allow_failure):
expr = expr.unwrap()
should_allow_failure = True
# Expressions inside quotes should not be modified
if isinstance(context.get("annotation"), quote):
return expr

if isinstance(expr, Quote):
return expr.unquote()
elif isinstance(expr, PrefectFuture):
if isinstance(expr, PrefectFuture):
state = run_async_from_worker_thread(expr._wait)
elif isinstance(expr, State):
state = expr
Expand All @@ -1659,7 +1653,7 @@ def resolve_input(expr):
# Do not allow uncompleted upstreams except failures when `allow_failure` has
# been used
if not state.is_completed() and not (
should_allow_failure and state.is_failed()
isinstance(context.get("annotation"), allow_failure) and state.is_failed()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm. This will break when annotations are nested e.g. allow_failure(quote("foo")) — maybe I should make the context have a list or set of annotations?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense to me. It seems you could even rename the key to `annotations` without breaking backwards compatibility, since the flow is re-parsed on each run.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm going to wait for this to be an actual issue so we can keep this simple and get it out there

):
raise UpstreamTaskError(
f"Upstream task run '{state.state_details.task_run_id}' did not reach a 'COMPLETED' state."
Expand All @@ -1674,6 +1668,8 @@ def resolve_input(expr):
visit_fn=resolve_input,
return_data=return_data,
max_depth=max_depth,
remove_annotations=True,
context={},
)


Expand Down
28 changes: 16 additions & 12 deletions src/prefect/utilities/annotations.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from abc import ABC
from typing import Generic, TypeVar

Expand All @@ -15,11 +16,17 @@ def __init__(self, value: T):
def unwrap(self) -> T:
    """Return the value held in this annotation's `value` attribute."""
    return self.value

def rewrap(self, value: T) -> "BaseAnnotation[T]":
    """
    Build a new annotation of the same concrete type as `self` around `value`.

    The constructor of `type(self)` is called with `value` as its only argument, so
    subclasses are preserved.
    """
    annotation_cls = type(self)
    return annotation_cls(value)

def __eq__(self, other: object) -> bool:
    """
    Compare two annotations.

    Objects of a different exact type are never equal (subclasses do not match their
    parents); otherwise the result is the comparison of the two unwrapped values.
    """
    same_type = type(self) == type(other)
    # Short-circuit so `other.unwrap()` is only called on a matching annotation type
    return same_type and self.unwrap() == other.unwrap()

def __repr__(self) -> str:
    """Render as `<ClassName>(<repr of value>)`, using the concrete class name."""
    return f"{type(self).__name__}({self.value!r})"


class unmapped(BaseAnnotation[T]):
"""
Expand Down Expand Up @@ -49,7 +56,7 @@ class allow_failure(BaseAnnotation[T]):
"""


class Quote(BaseAnnotation[T]):
class quote(BaseAnnotation[T]):
"""
Simple wrapper to mark an expression as a different type so it will not be coerced
by Prefect. For example, if you want to return a state from a flow without having
Expand All @@ -60,17 +67,14 @@ def unquote(self) -> T:
return self.unwrap()


def quote(expr: T) -> Quote[T]:
"""
Create a `Quote` object

Examples:
>>> from prefect.utilities.collections import quote
>>> x = quote(1)
>>> x.unquote()
1
"""
return Quote(expr)
# Backwards compatibility stub for `Quote` class
class Quote(quote):
    """
    Deprecated alias of `quote`, kept so existing `Quote(...)` call sites keep working.

    Instantiating it emits a `DeprecationWarning` and then initializes the object
    through the parent class initializer.
    """

    def __init__(self, expr):
        # `warnings.warn` expects the message first and the warning category second;
        # passing them the other way around raises a `TypeError` instead of warning.
        warnings.warn(
            "Use of `Quote` is deprecated. Use `quote` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        # Delegate to the parent initializer so the wrapped expression is stored and
        # `unwrap()` / `unquote()` keep working on the deprecated class.
        super().__init__(expr)


class NotSet:
Expand Down
39 changes: 36 additions & 3 deletions src/prefect/utilities/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Iterable,
Iterator,
List,
Optional,
Set,
Tuple,
Type,
Expand All @@ -28,8 +29,8 @@

import pydantic

# Moved to `prefect.utilities.annotations` but preserved here for compatibility
from prefect.utilities.annotations import Quote, quote # noqa
# Quote moved to `prefect.utilities.annotations` but preserved here for compatibility
from prefect.utilities.annotations import BaseAnnotation, Quote, quote # noqa

Check notice

Code scanning / CodeQL

Unused import

Import of 'Quote' is not used. Import of 'quote' is not used.


class AutoEnum(str, Enum):
Expand Down Expand Up @@ -212,6 +213,8 @@ def visit_collection(
visit_fn: Callable[[Any], Any],
return_data: bool = False,
max_depth: int = -1,
context: Optional[dict] = None,
remove_annotations: bool = False,
):
"""
This function visits every element of an arbitrary Python collection. If an element
Expand Down Expand Up @@ -244,6 +247,16 @@ def visit_collection(
descend to N layers deep. If set to any negative integer, no limit will be
enforced and recursion will continue until terminal items are reached. By
default, recursion is unlimited.
context: An optional dictionary. If passed, the context will be sent to each
call to the `visit_fn`. The context can be mutated by each visitor and will
be available for later visits to expressions at the given depth. Values
will not be available "up" a level from a given expression.

The context will be automatically populated with an 'annotation' key when
visiting collections within a `BaseAnnotation` type. This requires the
caller to pass `context={}` and will not be activated by default.
remove_annotations: If set, annotations will be replaced by their contents. By
default, annotations are preserved but their contents are visited.
"""

def visit_nested(expr):
Expand All @@ -252,11 +265,21 @@ def visit_nested(expr):
expr,
visit_fn=visit_fn,
return_data=return_data,
remove_annotations=remove_annotations,
max_depth=max_depth - 1,
# Copy the context on nested calls so it does not "propagate up"
context=context.copy() if context is not None else None,
)

def visit_expression(expr):
if context is not None:
return visit_fn(expr, context)
else:
return visit_fn(expr)

# Visit every expression
result = visit_fn(expr)
result = visit_expression(expr)

if return_data:
# Only mutate the expression while returning data, otherwise it could be null
expr = result
Expand All @@ -276,6 +299,16 @@ def visit_nested(expr):
# Do not attempt to recurse into mock objects
result = expr

elif isinstance(expr, BaseAnnotation):
if context is not None:
context["annotation"] = expr
value = visit_nested(expr.unwrap())

if remove_annotations:
result = value if return_data else None
else:
result = expr.rewrap(value) if return_data else None

elif typ in (list, tuple, set):
items = [visit_nested(o) for o in expr]
result = typ(items) if return_data else None
Expand Down
23 changes: 23 additions & 0 deletions tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,29 @@ def test_flow():
assert isinstance(result, ValueError)
assert "Fail task!" in str(result)

def test_downstream_receives_exception_in_collection_if_upstream_fails_and_allow_failure(
    self,
):
    """
    A failed upstream future nested in a list wrapped by `allow_failure` should reach
    the downstream task as an exception, with the other list items left untouched.
    """

    @task
    def fails():
        raise ValueError("Fail task!")

    @task
    def bar(y):
        return y

    @flow
    def test_flow():
        failed_future = fails.submit()
        # The annotation wraps the whole collection, not just the failing future
        downstream = bar.submit(allow_failure([failed_future, 1, 2]))
        return downstream.result()

    outcome = test_flow()
    assert isinstance(outcome, list), f"Expected list; got {type(outcome)}"
    assert isinstance(outcome[0], ValueError)
    assert outcome[1:] == [1, 2]
    assert "Fail task!" in str(outcome)


class TestTaskStates:
@pytest.mark.parametrize("error", [ValueError("Hello"), None])
Expand Down
45 changes: 45 additions & 0 deletions tests/utilities/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pydantic
import pytest

from prefect.utilities.annotations import BaseAnnotation, quote
from prefect.utilities.collections import (
AutoEnum,
dict_to_flatdict,
Expand All @@ -17,6 +18,10 @@
)


class ExampleAnnotation(BaseAnnotation):
    """Bare `BaseAnnotation` subclass used as an input in the `visit_collection` tests."""

    pass


class Color(AutoEnum):
RED = AutoEnum.auto()
BLUE = AutoEnum.auto()
Expand Down Expand Up @@ -214,6 +219,7 @@ class TestVisitCollection:
(SimpleDataclass(x=1, y=2), SimpleDataclass(x=1, y=-2)),
(SimplePydantic(x=1, y=2), SimplePydantic(x=1, y=-2)),
(ExtraPydantic(x=1, y=2, z=3), ExtraPydantic(x=1, y=-2, z=3)),
(ExampleAnnotation(4), ExampleAnnotation(-4)),
],
)
def test_visit_collection_and_transform_data(self, inp, expected):
Expand All @@ -234,6 +240,7 @@ def test_visit_collection_and_transform_data(self, inp, expected):
(SimpleDataclass(x=1, y=2), {2}),
(SimplePydantic(x=1, y=2), {2}),
(ExtraPydantic(x=1, y=2, z=4), {2, 4}),
(ExampleAnnotation(4), {4}),
],
)
def test_visit_collection(self, inp, expected):
Expand Down Expand Up @@ -400,6 +407,44 @@ def test_visit_collection_max_depth(self, inp, depth, expected):
)
assert result == expected

def test_visit_collection_context(self):
    """
    The context passed to `visit_collection` is handed to the visitor, and changes
    made while visiting a list are seen by the items nested below it.
    """
    nested = [1, 2, [3, 4], [5, [6, 7]], 8, 9]

    def visit(expr, context):
        # Non-list items are shifted by the depth recorded so far
        if not isinstance(expr, list):
            return expr + context["depth"]
        # Every list encountered bumps the depth seen by its children
        context["depth"] += 1
        return expr

    transformed = visit_collection(
        nested, visit, context={"depth": 0}, return_data=True
    )
    assert transformed == [2, 3, [5, 6], [7, [9, 10]], 9, 10]
zanieb marked this conversation as resolved.
Show resolved Hide resolved

def test_visit_collection_context_from_annotation(self):
    """
    Expressions nested inside a `quote` should see that annotation under the
    'annotation' key of the context.
    """
    quoted = quote([1, 2, [3]])

    def visit(expr, context):
        # The outermost `quote` itself is exempt from the check
        if isinstance(expr, quote):
            return expr
        assert isinstance(context.get("annotation"), quote)
        return expr

    visited = visit_collection(quoted, visit, context={}, return_data=True)
    assert visited == quote([1, 2, [3]])

def test_visit_collection_remove_annotations(self):
    """
    With `remove_annotations=True`, annotations at every level are replaced by their
    (visited) contents in the returned data.
    """
    quoted = quote([1, 2, quote([3])])

    def visit(expr, context):
        # Increment integers; pass everything else through unchanged
        return expr + 1 if isinstance(expr, int) else expr

    stripped = visit_collection(
        quoted,
        visit,
        context={},
        return_data=True,
        remove_annotations=True,
    )
    assert stripped == [2, 3, [4]]


class TestRemoveKeys:
def test_remove_single_key(self):
Expand Down