Rename "dataset event" in context to use "outlet" #39397

Merged 1 commit on May 4, 2024
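In short: the dataset_events key in the task execution context, and the matching TaskFlow keyword argument, is renamed to outlet_events. A minimal sketch of the new spelling, mirroring the examples updated in the diff below (the dataset URI and extra values are illustrative only):

    from airflow.datasets import Dataset
    from airflow.decorators import task

    ds = Dataset("s3://bucket/output.csv")  # illustrative URI

    @task(outlets=[ds])
    def producer(*, outlet_events):
        # Attach extra information to the dataset event emitted for this outlet.
        outlet_events[ds].extra = {"row_count": 42}

The accessor behavior is unchanged; only the context key and the related helper and class names move from "dataset event" to "outlet" wording.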
@@ -54,8 +54,8 @@ def dataset_with_extra_by_yield():
):

@task(outlets=[ds])
- def dataset_with_extra_by_context(*, dataset_events=None):
- dataset_events[ds].extra = {"hi": "bye"}
+ def dataset_with_extra_by_context(*, outlet_events=None):
+ outlet_events[ds].extra = {"hi": "bye"}

dataset_with_extra_by_context()

@@ -68,7 +68,7 @@ def dataset_with_extra_by_context(*, dataset_events=None):
):

def _dataset_with_extra_from_classic_operator_post_execute(context):
context["dataset_events"].extra = {"hi": "bye"}
context["outlet_events"].extra = {"hi": "bye"}

BashOperator(
task_id="dataset_with_extra_from_classic_operator",
6 changes: 3 additions & 3 deletions airflow/models/baseoperator.py
@@ -91,7 +91,7 @@
from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep
from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
from airflow.utils import timezone
- from airflow.utils.context import Context, context_get_dataset_events
+ from airflow.utils.context import Context, context_get_outlet_events
from airflow.utils.decorators import fixup_decorator_warning_stack
from airflow.utils.edgemodifier import EdgeModifier
from airflow.utils.helpers import validate_key
@@ -1279,7 +1279,7 @@ def pre_execute(self, context: Any):
return
ExecutionCallableRunner(
self._pre_execute_hook,
- context_get_dataset_events(context),
+ context_get_outlet_events(context),
logger=self.log,
).run(context)

@@ -1304,7 +1304,7 @@ def post_execute(self, context: Any, result: Any = None):
return
ExecutionCallableRunner(
self._post_execute_hook,
- context_get_dataset_events(context),
+ context_get_outlet_events(context),
logger=self.log,
).run(context, result)

12 changes: 6 additions & 6 deletions airflow/models/taskinstance.py
@@ -109,10 +109,10 @@
from airflow.utils.context import (
ConnectionAccessor,
Context,
- DatasetEventAccessors,
InletEventsAccessors,
+ OutletEventAccessors,
VariableAccessor,
- context_get_dataset_events,
+ context_get_outlet_events,
context_merge,
)
from airflow.utils.email import send_email
@@ -440,7 +440,7 @@ def _execute_callable(context: Context, **execute_callable_kwargs):

return ExecutionCallableRunner(
execute_callable,
- context_get_dataset_events(context),
+ context_get_outlet_events(context),
logger=log,
).run(context=context, **execute_callable_kwargs)
except SystemExit as e:
@@ -799,7 +799,7 @@ def get_triggering_events() -> dict[str, list[DatasetEvent | DatasetEventPydanti
"dag_run": dag_run,
"data_interval_end": timezone.coerce_datetime(data_interval.end),
"data_interval_start": timezone.coerce_datetime(data_interval.start),
"dataset_events": DatasetEventAccessors(),
"outlet_events": OutletEventAccessors(),
"ds": ds,
"ds_nodash": ds_nodash,
"execution_date": logical_date,
@@ -2639,7 +2639,7 @@ def _run_raw_task(
session.add(Log(self.state, self))
session.merge(self).task = self.task
if self.state == TaskInstanceState.SUCCESS:
- self._register_dataset_changes(events=context["dataset_events"], session=session)
+ self._register_dataset_changes(events=context["outlet_events"], session=session)

session.commit()
if self.state == TaskInstanceState.SUCCESS:
@@ -2649,7 +2649,7 @@

return None

- def _register_dataset_changes(self, *, events: DatasetEventAccessors, session: Session) -> None:
+ def _register_dataset_changes(self, *, events: OutletEventAccessors, session: Session) -> None:
if TYPE_CHECKING:
assert self.task

4 changes: 2 additions & 2 deletions airflow/operators/python.py
@@ -50,7 +50,7 @@
from airflow.models.variable import Variable
from airflow.operators.branch import BranchMixIn
from airflow.utils import hashlib_wrapper
- from airflow.utils.context import context_copy_partial, context_get_dataset_events, context_merge
+ from airflow.utils.context import context_copy_partial, context_get_outlet_events, context_merge
from airflow.utils.file import get_unique_dag_module_name
from airflow.utils.operator_helpers import ExecutionCallableRunner, KeywordParameters
from airflow.utils.process_utils import execute_in_subprocess
@@ -238,7 +238,7 @@ def __init__(
def execute(self, context: Context) -> Any:
context_merge(context, self.op_kwargs, templates_dict=self.templates_dict)
self.op_kwargs = self.determine_kwargs(context)
- self._dataset_events = context_get_dataset_events(context)
+ self._dataset_events = context_get_outlet_events(context)

return_value = self.execute_callable()
if self.show_return_value_in_logs:
10 changes: 5 additions & 5 deletions airflow/serialization/serialized_objects.py
@@ -68,7 +68,7 @@
)
from airflow.triggers.base import BaseTrigger
from airflow.utils.code_utils import get_python_source
- from airflow.utils.context import Context, DatasetEventAccessor, DatasetEventAccessors
+ from airflow.utils.context import Context, OutletEventAccessor, OutletEventAccessors
from airflow.utils.docs import get_docs_url
from airflow.utils.helpers import exactly_one
from airflow.utils.module_loading import import_string, qualname
@@ -536,12 +536,12 @@ def serialize(
elif var.__class__.__name__ == "V1Pod" and _has_kubernetes() and isinstance(var, k8s.V1Pod):
json_pod = PodGenerator.serialize_pod(var)
return cls._encode(json_pod, type_=DAT.POD)
- elif isinstance(var, DatasetEventAccessors):
+ elif isinstance(var, OutletEventAccessors):
return cls._encode(
cls.serialize(var._dict, strict=strict, use_pydantic_models=use_pydantic_models), # type: ignore[attr-defined]
type_=DAT.DATASET_EVENT_ACCESSORS,
)
- elif isinstance(var, DatasetEventAccessor):
+ elif isinstance(var, OutletEventAccessor):
return cls._encode(
cls.serialize(var.extra, strict=strict, use_pydantic_models=use_pydantic_models),
type_=DAT.DATASET_EVENT_ACCESSOR,
@@ -693,11 +693,11 @@ def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any:
elif type_ == DAT.DICT:
return {k: cls.deserialize(v, use_pydantic_models) for k, v in var.items()}
elif type_ == DAT.DATASET_EVENT_ACCESSORS:
- d = DatasetEventAccessors() # type: ignore[assignment]
+ d = OutletEventAccessors() # type: ignore[assignment]
d._dict = cls.deserialize(var) # type: ignore[attr-defined]
return d
elif type_ == DAT.DATASET_EVENT_ACCESSOR:
- return DatasetEventAccessor(extra=cls.deserialize(var))
+ return OutletEventAccessor(extra=cls.deserialize(var))
elif type_ == DAT.DAG:
return SerializedDAG.deserialize_dag(var)
elif type_ == DAT.OP:
28 changes: 17 additions & 11 deletions airflow/utils/context.py
@@ -62,7 +62,6 @@
"dag_run",
"data_interval_end",
"data_interval_start",
"dataset_events",
"ds",
"ds_nodash",
"execution_date",
@@ -77,6 +76,7 @@
"next_ds_nodash",
"next_execution_date",
"outlets",
"outlet_events",
"params",
"prev_data_interval_start_success",
"prev_data_interval_end_success",
@@ -157,27 +157,33 @@ def get(self, key: str, default_conn: Any = None) -> Any:


@attrs.define()
- class DatasetEventAccessor:
- """Wrapper to access a DatasetEvent instance in template."""
+ class OutletEventAccessor:
+ """Wrapper to access an outlet dataset event in template.
+
+ :meta private:
+ """

extra: dict[str, Any]


- class DatasetEventAccessors(Mapping[str, DatasetEventAccessor]):
- """Lazy mapping of dataset event accessors."""
+ class OutletEventAccessors(Mapping[str, OutletEventAccessor]):
+ """Lazy mapping of outlet dataset event accessors.
+
+ :meta private:
+ """

def __init__(self) -> None:
- self._dict: dict[str, DatasetEventAccessor] = {}
+ self._dict: dict[str, OutletEventAccessor] = {}

def __iter__(self) -> Iterator[str]:
return iter(self._dict)

def __len__(self) -> int:
return len(self._dict)

- def __getitem__(self, key: str | Dataset) -> DatasetEventAccessor:
+ def __getitem__(self, key: str | Dataset) -> OutletEventAccessor:
if (uri := coerce_to_uri(key)) not in self._dict:
- self._dict[uri] = DatasetEventAccessor({})
+ self._dict[uri] = OutletEventAccessor({})
return self._dict[uri]


@@ -448,8 +454,8 @@ def _create_value(k: str, v: Any) -> Any:
return {k: _create_value(k, v) for k, v in source._context.items()}


- def context_get_dataset_events(context: Context) -> DatasetEventAccessors:
+ def context_get_outlet_events(context: Context) -> OutletEventAccessors:
try:
return context["dataset_events"]
return context["outlet_events"]
except KeyError:
- return DatasetEventAccessors()
+ return OutletEventAccessors()
10 changes: 5 additions & 5 deletions airflow/utils/context.pyi
@@ -57,14 +57,14 @@ class VariableAccessor:
class ConnectionAccessor:
def get(self, key: str, default_conn: Any = None) -> Any: ...

- class DatasetEventAccessor:
+ class OutletEventAccessor:
def __init__(self, *, extra: dict[str, Any]) -> None: ...
extra: dict[str, Any]

- class DatasetEventAccessors(Mapping[str, DatasetEventAccessor]):
+ class OutletEventAccessors(Mapping[str, OutletEventAccessor]):
def __iter__(self) -> Iterator[str]: ...
def __len__(self) -> int: ...
- def __getitem__(self, key: str | Dataset) -> DatasetEventAccessor: ...
+ def __getitem__(self, key: str | Dataset) -> OutletEventAccessor: ...

class InletEventsAccessor(Sequence[DatasetEvent]):
@overload
@@ -89,7 +89,7 @@ class Context(TypedDict, total=False):
dag_run: DagRun | DagRunPydantic
data_interval_end: DateTime
data_interval_start: DateTime
- dataset_events: DatasetEventAccessors
+ outlet_events: OutletEventAccessors
ds: str
ds_nodash: str
exception: BaseException | str | None
@@ -143,4 +143,4 @@ def context_merge(context: Context, **kwargs: Any) -> None: ...
def context_update_for_unmapped(context: Context, task: BaseOperator) -> None: ...
def context_copy_partial(source: Context, keys: Container[str]) -> Context: ...
def lazy_mapping_from_context(source: Context) -> Mapping[str, Any]: ...
- def context_get_dataset_events(context: Context) -> DatasetEventAccessors: ...
+ def context_get_outlet_events(context: Context) -> OutletEventAccessors: ...
8 changes: 4 additions & 4 deletions airflow/utils/operator_helpers.py
@@ -25,7 +25,7 @@
from airflow.utils.context import Context, lazy_mapping_from_context

if TYPE_CHECKING:
- from airflow.utils.context import DatasetEventAccessors
+ from airflow.utils.context import OutletEventAccessors

R = TypeVar("R")

@@ -232,12 +232,12 @@ class ExecutionCallableRunner:
def __init__(
self,
func: Callable,
- dataset_events: DatasetEventAccessors,
+ outlet_events: OutletEventAccessors,
*,
logger: logging.Logger | None,
) -> None:
self.func = func
- self.dataset_events = dataset_events
+ self.outlet_events = outlet_events
self.logger = logger or logging.getLogger(__name__)

def run(self, *args, **kwargs) -> Any:
@@ -257,7 +257,7 @@ def _run():

for metadata in _run():
if isinstance(metadata, Metadata):
- self.dataset_events[metadata.uri].extra.update(metadata.extra)
+ self.outlet_events[metadata.uri].extra.update(metadata.extra)
continue
self.logger.warning("Ignoring unknown data of %r received from task", type(metadata))
if self.logger.isEnabledFor(logging.DEBUG):
6 changes: 3 additions & 3 deletions docs/apache-airflow/authoring-and-scheduling/datasets.rst
@@ -251,13 +251,13 @@ Airflow automatically collects all yielded metadata, and populates dataset event

This can also be done in classic operators. The best way is to subclass the operator and override ``execute``. Alternatively, extras can also be added in a task's ``pre_execute`` or ``post_execute`` hook. If you choose to use hooks, however, remember that they are not rerun when a task is retried, and may cause the extra information to not match actual data in certain scenarios.
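For illustration only (not part of this diff), a minimal sketch of the subclass-and-override-execute approach described above, using the renamed accessor; S3WriteOperator, the dataset URI, and the extra values are hypothetical:

    from airflow.datasets import Dataset
    from airflow.operators.bash import BashOperator

    example_s3_dataset = Dataset("s3://bucket/example.csv")  # hypothetical dataset

    class S3WriteOperator(BashOperator):  # hypothetical subclass
        def execute(self, context):
            result = super().execute(context)
            # Attach extras to the event emitted for the declared outlet; the operator
            # is expected to be instantiated with outlets=[example_s3_dataset].
            context["outlet_events"][example_s3_dataset].extra = {"row_count": 42}
            return result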

- Another way to achieve the same is by accessing ``dataset_events`` in a task's execution context directly:
+ Another way to achieve the same is by accessing ``outlet_events`` in a task's execution context directly:

.. code-block:: python

@task(outlets=[example_s3_dataset])
- def write_to_s3(*, dataset_events):
- dataset_events[example_s3_dataset].extras = {"row_count": len(df)}
+ def write_to_s3(*, outlet_events):
+ outlet_events[example_s3_dataset].extras = {"row_count": len(df)}

There's minimal magic here---Airflow simply writes the yielded values to the exact same accessor. This also works in classic operators, including ``execute``, ``pre_execute``, and ``post_execute``.
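As a companion sketch (also not part of this diff), the yield-based form referenced above; it assumes the Metadata helper from airflow.datasets.metadata available in 2.10, and a hypothetical dataset URI:

    from airflow.datasets import Dataset
    from airflow.datasets.metadata import Metadata
    from airflow.decorators import task

    example_s3_dataset = Dataset("s3://bucket/example.csv")  # hypothetical dataset

    @task(outlets=[example_s3_dataset])
    def write_to_s3():
        rows = [{"a": 1}, {"a": 2}]  # placeholder for the real write
        # Yielded Metadata is collected by ExecutionCallableRunner (see
        # airflow/utils/operator_helpers.py above) and merged into
        # outlet_events[example_s3_dataset].extra.
        yield Metadata(example_s3_dataset, {"row_count": len(rows)})

Both forms end up writing to the same accessor, which is what "the exact same accessor" above refers to.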

4 changes: 2 additions & 2 deletions docs/apache-airflow/templates-ref.rst
@@ -64,6 +64,8 @@ Variable Type Description
``{{ inlets }}`` list List of inlets declared on the task.
``{{ inlet_events }}`` dict[str, ...] Access past events of inlet datasets. See :doc:`Datasets <authoring-and-scheduling/datasets>`. Added in version 2.10.
``{{ outlets }}`` list List of outlets declared on the task.
+ ``{{ outlet_events }}`` dict[str, ...] | Accessors to attach information to dataset events that will be emitted by the current task.
+ | See :doc:`Datasets <authoring-and-scheduling/datasets>`. Added in version 2.10.
``{{ dag }}`` DAG The currently running :class:`~airflow.models.dag.DAG`. You can read more about DAGs in :doc:`DAGs <core-concepts/dags>`.
``{{ task }}`` BaseOperator | The currently running :class:`~airflow.models.baseoperator.BaseOperator`. You can read more about Tasks in :doc:`core-concepts/operators`
``{{ macros }}`` | A reference to the macros package. See Macros_ below.
@@ -75,8 +77,6 @@
``{{ var.value }}`` Airflow variables. See `Airflow Variables in Templates`_ below.
``{{ var.json }}`` Airflow variables. See `Airflow Variables in Templates`_ below.
``{{ conn }}`` Airflow connections. See `Airflow Connections in Templates`_ below.
- ``{{ dataset_events }}`` dict[str, ...] | Accessors to attach information to dataset events that will be emitted by the current task.
- | See :doc:`Datasets <authoring-and-scheduling/datasets>`. Added in version 2.10.
``{{ task_instance_key_str }}`` str | A unique, human-readable key to the task instance. The format is
| ``{dag_id}__{task_id}__{ds_nodash}``.
``{{ conf }}`` AirflowConfigParser | The full configuration object representing the content of your
16 changes: 8 additions & 8 deletions tests/models/test_taskinstance.py
@@ -2321,13 +2321,13 @@ def test_outlet_dataset_extra(self, dag_maker, session):
with dag_maker(schedule=None, session=session) as dag:

@task(outlets=Dataset("test_outlet_dataset_extra_1"))
- def write1(*, dataset_events):
- dataset_events["test_outlet_dataset_extra_1"].extra = {"foo": "bar"}
+ def write1(*, outlet_events):
+ outlet_events["test_outlet_dataset_extra_1"].extra = {"foo": "bar"}

write1()

def _write2_post_execute(context, _):
context["dataset_events"]["test_outlet_dataset_extra_2"].extra = {"x": 1}
context["outlet_events"]["test_outlet_dataset_extra_2"].extra = {"x": 1}

BashOperator(
task_id="write2",
@@ -2362,9 +2362,9 @@ def test_outlet_dataset_extra_ignore_different(self, dag_maker, session):
with dag_maker(schedule=None, session=session):

@task(outlets=Dataset("test_outlet_dataset_extra"))
- def write(*, dataset_events):
- dataset_events["test_outlet_dataset_extra"].extra = {"one": 1}
- dataset_events["different_uri"].extra = {"foo": "bar"} # Will be silently dropped.
+ def write(*, outlet_events):
+ outlet_events["test_outlet_dataset_extra"].extra = {"one": 1}
+ outlet_events["different_uri"].extra = {"foo": "bar"} # Will be silently dropped.

write()

@@ -2434,8 +2434,8 @@ def test_inlet_dataset_extra(self, dag_maker, session):
with dag_maker(schedule=None, session=session):

@task(outlets=Dataset("test_inlet_dataset_extra"))
- def write(*, ti, dataset_events):
- dataset_events["test_inlet_dataset_extra"].extra = {"from": ti.task_id}
+ def write(*, ti, outlet_events):
+ outlet_events["test_inlet_dataset_extra"].extra = {"from": ti.task_id}

@task(inlets=Dataset("test_inlet_dataset_extra"))
def read(*, inlet_events):
2 changes: 1 addition & 1 deletion tests/operators/test_python.py
@@ -838,8 +838,8 @@ def test_virtualenv_serializable_context_fields(self, create_task_instance):
"ti",
"var", # Accessor for Variable; var->json and var->value.
"conn", # Accessor for Connection.
"dataset_events", # Accessor for outlet DatasetEvent.
"inlet_events", # Accessor for inlet DatasetEvent.
"outlet_events", # Accessor for outlet DatasetEvent.
]

ti = create_task_instance(dag_id=self.dag_id, task_id=self.task_id, schedule=None)
4 changes: 2 additions & 2 deletions tests/serialization/test_serialized_objects.py
@@ -53,7 +53,7 @@
from airflow.settings import _ENABLE_AIP_44
from airflow.triggers.base import BaseTrigger
from airflow.utils import timezone
- from airflow.utils.context import DatasetEventAccessors
+ from airflow.utils.context import OutletEventAccessors
from airflow.utils.operator_resources import Resources
from airflow.utils.pydantic import BaseModel
from airflow.utils.state import DagRunState, State
@@ -421,7 +421,7 @@ def test_serialized_mapped_operator_unmap(dag_maker):

def test_ser_of_dataset_event_accessor():
# todo: (Airflow 3.0) we should force reserialization on upgrade
- d = DatasetEventAccessors()
+ d = OutletEventAccessors()
d["hi"].extra = "blah1" # todo: this should maybe be forbidden? i.e. can extra be any json or just dict?
d["yo"].extra = {"this": "that", "the": "other"}
ser = BaseSerialization.serialize(var=d)