From 3b575058f790d85802a01518303da308353836c6 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Thu, 24 Jul 2025 15:29:11 +0800 Subject: [PATCH] Remove MappedOperator inheritance Now MappedOperator in core is entirely detached from the SDK. I decided to remove the start_from_trigger code entirely (the code path is already entirely unreachable in Airflow 3) since it causes too many Mypy errors. We can always get the implementation back from Git history when we re-implement the feature with the new task runner. --- .../api_fastapi/core_api/services/ui/grid.py | 2 +- airflow-core/src/airflow/models/dagrun.py | 34 +- .../src/airflow/models/expandinput.py | 51 ++- .../src/airflow/models/mappedoperator.py | 395 +++++++++++++++--- .../src/airflow/models/referencemixin.py | 53 +++ .../src/airflow/models/serialized_dag.py | 2 +- .../src/airflow/models/taskinstance.py | 11 +- airflow-core/src/airflow/models/xcom_arg.py | 65 ++- .../serialization/serialized_objects.py | 71 ++-- .../ti_deps/deps/mapped_task_upstream_dep.py | 18 +- .../airflow/ti_deps/deps/prev_dagrun_dep.py | 9 +- .../airflow/ti_deps/deps/trigger_rule_dep.py | 6 +- .../serialization/test_dag_serialization.py | 8 +- .../providers/celery/version_compat.py | 2 +- .../providers/google/version_compat.py | 2 +- .../providers/openlineage/version_compat.py | 4 +- .../providers/standard/utils/skipmixin.py | 2 +- .../unit/standard/decorators/test_python.py | 2 +- .../unit/standard/utils/test_skipmixin.py | 3 +- .../check_base_operator_partial_arguments.py | 5 - task-sdk/src/airflow/sdk/bases/decorator.py | 6 - task-sdk/src/airflow/sdk/bases/operator.py | 11 +- .../definitions/_internal/abstractoperator.py | 15 +- task-sdk/src/airflow/sdk/definitions/dag.py | 3 +- .../airflow/sdk/definitions/mappedoperator.py | 87 ++-- .../src/airflow/sdk/definitions/taskgroup.py | 13 +- .../src/airflow/sdk/definitions/xcom_arg.py | 53 ++- 27 files changed, 652 insertions(+), 281 deletions(-) create mode 100644 airflow-core/src/airflow/models/referencemixin.py diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py index b8e569f877e70..8b209fb324cd8 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py @@ -23,8 +23,8 @@ import structlog from airflow.api_fastapi.common.parameters import state_priority +from airflow.models.mappedoperator import MappedOperator from airflow.models.taskmap import TaskMap -from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup, get_task_group_children_getter from airflow.serialization.serialized_objects import SerializedBaseOperator diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py index 9eed649f0cad3..f112a1ebba98e 100644 --- a/airflow-core/src/airflow/models/dagrun.py +++ b/airflow-core/src/airflow/models/dagrun.py @@ -102,13 +102,11 @@ from airflow.serialization.serialized_objects import SerializedBaseOperator from airflow.utils.types import ArgNotSet - Operator: TypeAlias = MappedOperator | SerializedBaseOperator - CreatedTasks = TypeVar("CreatedTasks", Iterator["dict[str, Any]"], Iterator[TI]) - - AttributeValueType = ( + AttributeValueType: TypeAlias = ( str | bool | int | float | Sequence[str] | Sequence[bool] | Sequence[int] | Sequence[float] ) + Operator: TypeAlias = MappedOperator | SerializedBaseOperator 
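# Illustrative sketch, not part of the patch: the TYPE_CHECKING-only TypeAlias
# pattern used in the dagrun.py hunk above. Decimal and Fraction are generic
# stand-ins here; in the patch the alias unions MappedOperator with
# SerializedBaseOperator.
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from typing import TypeAlias

    # Heavy (or circular) imports placed here never execute at runtime.
    from decimal import Decimal
    from fractions import Fraction

    Number: TypeAlias = Decimal | Fraction


def describe(value: Number) -> str:
    # With `from __future__ import annotations`, the annotation above remains a
    # string at runtime, so the alias only needs to exist for the type checker.
    return type(value).__name__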
RUN_ID_REGEX = r"^(?:manual|scheduled|asset_triggered)__(?:\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\+00:00)$" @@ -1483,15 +1481,15 @@ def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None: If the ti does not need expansion, either because the task is not mapped, or has already been expanded, *None* is returned. """ + from airflow.models.mappedoperator import is_mapped + if TYPE_CHECKING: assert ti.task if ti.map_index >= 0: # Already expanded, we're good. return None - from airflow.sdk.definitions.mappedoperator import MappedOperator as TaskSDKMappedOperator - - if isinstance(ti.task, TaskSDKMappedOperator): + if is_mapped(ti.task): # If we get here, it could be that we are moving from non-mapped to mapped # after task instance clearing or this ti is not yet expanded. Safe to clear # the db references. @@ -1510,7 +1508,7 @@ def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None: revised_map_index_task_ids: set[str] = set() for schedulable in itertools.chain(schedulable_tis, additional_tis): if TYPE_CHECKING: - assert isinstance(schedulable.task, SerializedBaseOperator) + assert isinstance(schedulable.task, Operator) old_state = schedulable.state if not schedulable.are_dependencies_met(session=session, dep_context=dep_context): old_states[schedulable.key] = old_state @@ -1995,25 +1993,23 @@ def schedule_tis( empty_ti_ids: list[str] = [] schedulable_ti_ids: list[str] = [] for ti in schedulable_tis: + task = ti.task if TYPE_CHECKING: - assert isinstance(ti.task, SerializedBaseOperator) + assert isinstance(task, Operator) if ( - ti.task.inherits_from_empty_operator - and not ti.task.on_execute_callback - and not ti.task.on_success_callback - and not ti.task.outlets - and not ti.task.inlets + task.inherits_from_empty_operator + and not task.on_execute_callback + and not task.on_success_callback + and not task.outlets + and not task.inlets ): empty_ti_ids.append(ti.id) # check "start_trigger_args" to see whether the operator supports start execution from triggerer # if so, we'll then check "start_from_trigger" to see whether this feature is turned on and defer # this task. 
# if not, we'll add this "ti" into "schedulable_ti_ids" and later execute it to run in the worker - elif ti.task.start_trigger_args is not None: - context = ti.get_template_context() - start_from_trigger = ti.task.expand_start_from_trigger(context=context, session=session) - - if start_from_trigger: + elif task.start_trigger_args is not None: + if task.expand_start_from_trigger(context=ti.get_template_context()): ti.start_date = timezone.utcnow() if ti.state != TaskInstanceState.UP_FOR_RESCHEDULE: ti.try_number += 1 diff --git a/airflow-core/src/airflow/models/expandinput.py b/airflow-core/src/airflow/models/expandinput.py index 6aa44316f6e77..8487022b05e75 100644 --- a/airflow-core/src/airflow/models/expandinput.py +++ b/airflow-core/src/airflow/models/expandinput.py @@ -19,18 +19,11 @@ import functools import operator -from collections.abc import Iterable, Sized +from collections.abc import Iterable, Mapping, Sequence, Sized from typing import TYPE_CHECKING, Any, ClassVar import attrs -if TYPE_CHECKING: - from typing import TypeGuard - - from sqlalchemy.orm import Session - - from airflow.models.xcom_arg import SchedulerXComArg - from airflow.sdk.definitions._internal.expandinput import ( DictOfListsExpandInput, ListOfDictsExpandInput, @@ -41,6 +34,18 @@ is_mappable, ) +if TYPE_CHECKING: + from typing import TypeAlias, TypeGuard + + from sqlalchemy.orm import Session + + from airflow.models.mappedoperator import MappedOperator + from airflow.models.xcom_arg import SchedulerXComArg + from airflow.serialization.serialized_objects import SerializedBaseOperator + + Operator: TypeAlias = MappedOperator | SerializedBaseOperator + + __all__ = [ "DictOfListsExpandInput", "ListOfDictsExpandInput", @@ -111,6 +116,26 @@ def get_total_map_length(self, run_id: str, *, session: Session) -> int: lengths = self._get_map_lengths(run_id, session=session) return functools.reduce(operator.mul, (lengths[name] for name in self.value), 1) + def iter_references(self) -> Iterable[tuple[Operator, str]]: + from airflow.models.referencemixin import ReferenceMixin + + for x in self.value.values(): + if isinstance(x, ReferenceMixin): + yield from x.iter_references() + + +# To replace tedious isinstance() checks. +def _is_parse_time_mappable(v: OperatorExpandArgument) -> TypeGuard[Mapping | Sequence]: + from airflow.sdk.definitions.xcom_arg import XComArg + + return not isinstance(v, (MappedArgument, XComArg)) + + +def _describe_type(value: Any) -> str: + if value is None: + return "None" + return type(value).__name__ + @attrs.define class SchedulerListOfDictsExpandInput: @@ -133,6 +158,16 @@ def get_total_map_length(self, run_id: str, *, session: Session) -> int: raise NotFullyPopulated({"expand_kwargs() argument"}) return length + def iter_references(self) -> Iterable[tuple[Operator, str]]: + from airflow.models.referencemixin import ReferenceMixin + + if isinstance(self.value, ReferenceMixin): + yield from self.value.iter_references() + else: + for x in self.value: + if isinstance(x, ReferenceMixin): + yield from x.iter_references() + _EXPAND_INPUT_TYPES: dict[str, type[SchedulerExpandInput]] = { "dict-of-lists": SchedulerDictOfListsExpandInput, diff --git a/airflow-core/src/airflow/models/mappedoperator.py b/airflow-core/src/airflow/models/mappedoperator.py index 7dbd3309fd07e..5e5fe9975fdde 100644 --- a/airflow-core/src/airflow/models/mappedoperator.py +++ b/airflow-core/src/airflow/models/mappedoperator.py @@ -15,50 +15,65 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. + from __future__ import annotations import functools import operator -from collections.abc import Mapping -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias, TypeGuard import attrs +import methodtools import structlog from sqlalchemy.orm import Session from airflow.exceptions import AirflowException -from airflow.sdk.bases.operator import BaseOperator as TaskSDKBaseOperator -from airflow.sdk.definitions._internal.abstractoperator import NotMapped +from airflow.sdk import BaseOperator as TaskSDKBaseOperator +from airflow.sdk.definitions._internal.abstractoperator import ( + DEFAULT_EXECUTOR, + DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, + DEFAULT_OWNER, + DEFAULT_POOL_NAME, + DEFAULT_POOL_SLOTS, + DEFAULT_PRIORITY_WEIGHT, + DEFAULT_QUEUE, + DEFAULT_RETRIES, + DEFAULT_RETRY_DELAY, + DEFAULT_TRIGGER_RULE, + DEFAULT_WEIGHT_RULE, + NotMapped, + TaskStateChangeCallbackAttrType, +) +from airflow.sdk.definitions._internal.node import DAGNode from airflow.sdk.definitions.mappedoperator import MappedOperator as TaskSDKMappedOperator from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup from airflow.serialization.serialized_objects import DEFAULT_OPERATOR_DEPS, SerializedBaseOperator +from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy if TYPE_CHECKING: - from collections.abc import Iterator + import datetime + from collections.abc import Collection, Iterator, Sequence + + import pendulum from airflow.models import TaskInstance from airflow.models.dag import DAG as SchedulerDAG + from airflow.models.expandinput import SchedulerExpandInput from airflow.sdk import BaseOperatorLink - from airflow.sdk.definitions._internal.node import DAGNode from airflow.sdk.definitions.context import Context + from airflow.sdk.definitions.param import ParamsDict from airflow.ti_deps.deps.base_ti_dep import BaseTIDep + from airflow.triggers.base import StartTriggerArgs + from airflow.utils.operator_resources import Resources + from airflow.utils.trigger_rule import TriggerRule -log = structlog.get_logger(__name__) + Operator: TypeAlias = "SerializedBaseOperator | MappedOperator" +log = structlog.get_logger(__name__) -def _prevent_duplicates(kwargs1: dict[str, Any], kwargs2: Mapping[str, Any], *, fail_reason: str) -> None: - """ - Ensure *kwargs1* and *kwargs2* do not contain common keys. - :raises TypeError: If common keys are found. - """ - duplicated_keys = set(kwargs1).intersection(kwargs2) - if not duplicated_keys: - return - if len(duplicated_keys) == 1: - raise TypeError(f"{fail_reason} argument: {duplicated_keys.pop()}") - duplicated_keys_display = ", ".join(sorted(duplicated_keys)) - raise TypeError(f"{fail_reason} arguments: {duplicated_keys_display}") +def is_mapped(task: Operator) -> TypeGuard[MappedOperator]: + return task.is_mapped @attrs.define( @@ -73,49 +88,229 @@ def _prevent_duplicates(kwargs1: dict[str, Any], kwargs2: Mapping[str, Any], *, getstate_setstate=False, repr=False, ) -class MappedOperator(TaskSDKMappedOperator): +# TODO (GH-52141): Duplicate DAGNode in the scheduler. +class MappedOperator(DAGNode): """Object representing a mapped operator in a DAG.""" + operator_class: dict[str, Any] + partial_kwargs: dict[str, Any] = attrs.field(init=False, factory=dict) + + # Needed for serialization. 
+ task_id: str + params: ParamsDict | dict = attrs.field(init=False, factory=dict) + operator_extra_links: Collection[BaseOperatorLink] + template_ext: Sequence[str] + template_fields: Collection[str] + template_fields_renderers: dict[str, str] + ui_color: str + ui_fgcolor: str + _is_empty: bool = attrs.field(alias="is_empty", init=False, default=False) + _can_skip_downstream: bool = attrs.field(alias="can_skip_downstream") + _is_sensor: bool = attrs.field(alias="is_sensor", default=False) + _task_module: str + _task_type: str + _operator_name: str + start_trigger_args: StartTriggerArgs | None + start_from_trigger: bool + _needs_expansion: bool = True + + dag: SchedulerDAG = attrs.field(init=False) + task_group: TaskGroup = attrs.field(init=False) + start_date: pendulum.DateTime | None = attrs.field(init=False, default=None) + end_date: pendulum.DateTime | None = attrs.field(init=False, default=None) + upstream_task_ids: set[str] = attrs.field(factory=set, init=False) + downstream_task_ids: set[str] = attrs.field(factory=set, init=False) + + _disallow_kwargs_override: bool + """Whether execution fails if ``expand_input`` has duplicates to ``partial_kwargs``. + + If *False*, values from ``expand_input`` under duplicate keys override those + under corresponding keys in ``partial_kwargs``. + """ + + _expand_input_attr: str + """Where to get kwargs to calculate expansion length against. + + This should be a name to call ``getattr()`` on. + """ + deps: frozenset[BaseTIDep] = attrs.field(init=False, default=DEFAULT_OPERATOR_DEPS) - def expand_start_from_trigger(self, *, context: Context, session: Session) -> bool: - """ - Get the start_from_trigger value of the current abstract operator. + is_mapped: ClassVar[bool] = True - MappedOperator uses this to unmap start_from_trigger to decide whether to start the task - execution directly from triggerer. + @property + def node_id(self) -> str: + return self.task_id - :meta private: - """ - if self.partial_kwargs.get("start_from_trigger", self.start_from_trigger): - log.warning( - "Starting a mapped task from triggerer is currently unsupported", - task_id=self.task_id, - dag_id=self.dag_id, - ) + @property + def roots(self) -> Sequence[DAGNode]: + """Required by DAGNode.""" + return [self] - # This is intentional. start_from_trigger does not work correctly with - # sdk-db separation yet, so it is disabled unconditionally for now. - # TODO: TaskSDK: Implement this properly. - return False + @property + def leaves(self) -> Sequence[DAGNode]: + """Required by DAGNode.""" + return [self] - # start_from_trigger only makes sense when start_trigger_args exists. - if not self.start_trigger_args: - return False + # TODO (GH-52141): Review if any of the properties below are used in the + # SDK and the scheduler, and remove those not needed. - mapped_kwargs, _ = self._expand_mapped_kwargs(context) - if self._disallow_kwargs_override: - _prevent_duplicates( - self.partial_kwargs, - mapped_kwargs, - fail_reason="unmappable or already specified", - ) + @property + def task_type(self) -> str: + """Implementing Operator.""" + return self._task_type + + @property + def operator_name(self) -> str: + return self._operator_name - # Ordering is significant; mapped kwargs should override partial ones. 
- return mapped_kwargs.get( - "start_from_trigger", self.partial_kwargs.get("start_from_trigger", self.start_from_trigger) + @property + def task_display_name(self) -> str: + return self.partial_kwargs.get("task_display_name") or self.task_id + + @property + def doc_md(self) -> str | None: + return self.partial_kwargs.get("doc_md") + + @property + def inherits_from_empty_operator(self) -> bool: + """Implementing an empty Operator.""" + return self._is_empty + + @property + def inherits_from_skipmixin(self) -> bool: + return self._can_skip_downstream + + @property + def owner(self) -> str: + return self.partial_kwargs.get("owner", DEFAULT_OWNER) + + @property + def trigger_rule(self) -> TriggerRule: + return self.partial_kwargs.get("trigger_rule", DEFAULT_TRIGGER_RULE) + + @property + def is_setup(self) -> bool: + return bool(self.partial_kwargs.get("is_setup")) + + @property + def is_teardown(self) -> bool: + return bool(self.partial_kwargs.get("is_teardown")) + + @property + def depends_on_past(self) -> bool: + return bool(self.partial_kwargs.get("depends_on_past")) + + @property + def ignore_first_depends_on_past(self) -> bool: + value = self.partial_kwargs.get("ignore_first_depends_on_past", DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST) + return bool(value) + + @property + def wait_for_downstream(self) -> bool: + return bool(self.partial_kwargs.get("wait_for_downstream")) + + @property + def retries(self) -> int: + return self.partial_kwargs.get("retries", DEFAULT_RETRIES) + + @property + def queue(self) -> str: + return self.partial_kwargs.get("queue", DEFAULT_QUEUE) + + @property + def pool(self) -> str: + return self.partial_kwargs.get("pool", DEFAULT_POOL_NAME) + + @property + def pool_slots(self) -> int: + return self.partial_kwargs.get("pool_slots", DEFAULT_POOL_SLOTS) + + @property + def resources(self) -> Resources | None: + return self.partial_kwargs.get("resources") + + @property + def max_active_tis_per_dag(self) -> int | None: + return self.partial_kwargs.get("max_active_tis_per_dag") + + @property + def max_active_tis_per_dagrun(self) -> int | None: + return self.partial_kwargs.get("max_active_tis_per_dagrun") + + @property + def on_execute_callback(self) -> TaskStateChangeCallbackAttrType: + return self.partial_kwargs.get("on_execute_callback") or [] + + @property + def on_failure_callback(self) -> TaskStateChangeCallbackAttrType: + return self.partial_kwargs.get("on_failure_callback") or [] + + @property + def on_retry_callback(self) -> TaskStateChangeCallbackAttrType: + return self.partial_kwargs.get("on_retry_callback") or [] + + @property + def on_success_callback(self) -> TaskStateChangeCallbackAttrType: + return self.partial_kwargs.get("on_success_callback") or [] + + @property + def on_skipped_callback(self) -> TaskStateChangeCallbackAttrType: + return self.partial_kwargs.get("on_skipped_callback") or [] + + @property + def run_as_user(self) -> str | None: + return self.partial_kwargs.get("run_as_user") + + @property + def priority_weight(self) -> int: + return self.partial_kwargs.get("priority_weight", DEFAULT_PRIORITY_WEIGHT) + + @property + def retry_delay(self) -> datetime.timedelta: + return self.partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY) + + @property + def retry_exponential_backoff(self) -> bool: + return bool(self.partial_kwargs.get("retry_exponential_backoff")) + + @property + def weight_rule(self) -> PriorityWeightStrategy: + return validate_and_load_priority_weight_strategy( + self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE) ) + @property + 
def executor(self) -> str | None: + return self.partial_kwargs.get("executor", DEFAULT_EXECUTOR) + + @property + def executor_config(self) -> dict: + return self.partial_kwargs.get("executor_config", {}) + + @property + def execution_timeout(self) -> datetime.timedelta | None: + return self.partial_kwargs.get("execution_timeout") + + @property + def inlets(self) -> list[Any]: + return self.partial_kwargs.get("inlets", []) + + @property + def outlets(self) -> list[Any]: + return self.partial_kwargs.get("outlets", []) + + @property + def on_failure_fail_dagrun(self) -> bool: + return bool(self.partial_kwargs.get("on_failure_fail_dagrun")) + + @on_failure_fail_dagrun.setter + def on_failure_fail_dagrun(self, v) -> None: + self.partial_kwargs["on_failure_fail_dagrun"] = bool(v) + + def get_serialized_fields(self): + return TaskSDKMappedOperator.get_serialized_fields() + @functools.cached_property def operator_extra_link_dict(self) -> dict[str, BaseOperatorLink]: """Returns dictionary of all extra links for the operator.""" @@ -167,6 +362,104 @@ def get_extra_links(self, ti: TaskInstance, name: str) -> str | None: return None return link.get_link(self, ti_key=ti.key) # type: ignore[arg-type] # TODO: GH-52141 - BaseOperatorLink.get_link expects BaseOperator but receives MappedOperator + # TODO (GH-52141): Copied from sdk. Find a better place for this to live in. + def _get_specified_expand_input(self) -> SchedulerExpandInput: + """Input received from the expand call on the operator.""" + return getattr(self, self._expand_input_attr) + + # TODO (GH-52141): Copied from sdk. Find a better place for this to live in. + def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]: + """ + Return mapped task groups this task belongs to. + + Groups are returned from the innermost to the outmost. + + :meta private: + """ + if (group := self.task_group) is None: + return + yield from group.iter_mapped_task_groups() + + # TODO (GH-52141): Copied from sdk. Find a better place for this to live in. + def get_closest_mapped_task_group(self) -> MappedTaskGroup | None: + """ + Get the mapped task group "closest" to this task in the DAG. + + :meta private: + """ + return next(self.iter_mapped_task_groups(), None) + + # TODO (GH-52141): Copied from sdk. Find a better place for this to live in. + def get_needs_expansion(self) -> bool: + """ + Return true if the task is MappedOperator or is in a mapped task group. + + :meta private: + """ + return self._needs_expansion + + # TODO (GH-52141): Copied from sdk. Find a better place for this to live in. + @methodtools.lru_cache(maxsize=1) + def get_parse_time_mapped_ti_count(self) -> int: + current_count = self._get_specified_expand_input().get_parse_time_mapped_ti_count() + + def _get_parent_count() -> int: + if (group := self.get_closest_mapped_task_group()) is None: + raise NotMapped() + return group.get_parse_time_mapped_ti_count() + + try: + parent_count = _get_parent_count() + except NotMapped: + return current_count + return parent_count * current_count + + def iter_mapped_dependencies(self) -> Iterator[Operator]: + """Upstream dependencies that provide XComs used by this task for task mapping.""" + from airflow.models.xcom_arg import SchedulerXComArg + + for op, _ in SchedulerXComArg.iter_xcom_references(self._get_specified_expand_input()): + yield op + + def expand_start_from_trigger(self, *, context: Context) -> bool: + if not self.partial_kwargs.get("start_from_trigger", self.start_from_trigger): + return False + # TODO (GH-52141): Implement this. 
+ log.warning( + "Starting a mapped task from triggerer is currently unsupported", + task_id=self.task_id, + dag_id=self.dag_id, + ) + return False + + # TODO (GH-52141): Move the implementation in SDK MappedOperator here. + def expand_start_trigger_args(self, *, context: Context) -> StartTriggerArgs | None: + raise NotImplementedError + + def unmap(self, resolve: None) -> SerializedBaseOperator: + """ + Get the "normal" Operator after applying the current mapping. + + The *resolve* argument is never used and should always be *None*. It + exists only to match the signature of the non-serialized implementation. + + The return value is a SerializedBaseOperator that "looks like" the + actual unmapping result. + + :meta private: + """ + # After a mapped operator is serialized, there's no real way to actually + # unmap it since we've lost access to the underlying operator class. + # This tries its best to simply "forward" all the attributes on this + # mapped operator to a new SerializedBaseOperator instance. + sop = SerializedBaseOperator(task_id=self.task_id, params=self.params, _airflow_from_mapped=True) + for partial_attr, value in self.partial_kwargs.items(): + setattr(sop, partial_attr, value) + SerializedBaseOperator.populate_operator(sop, self.operator_class) + if self.dag is not None: # For Mypy; we only serialize tasks in a DAG so the check always satisfies. + SerializedBaseOperator.set_task_dag_references(sop, self.dag) + return sop + @functools.singledispatch def get_mapped_ti_count(task: DAGNode, run_id: str, *, session: Session) -> int: @@ -192,8 +485,6 @@ def _(task: MappedOperator | TaskSDKMappedOperator, run_id: str, *, session: Ses from airflow.serialization.serialized_objects import BaseSerialization, _ExpandInputRef exp_input = task._get_specified_expand_input() - if isinstance(exp_input, _ExpandInputRef): - exp_input = exp_input.deref(task.dag) # TODO (GH-52141): 'task' here should be scheduler-bound and returns scheduler expand input. if not hasattr(exp_input, "get_total_map_length"): if TYPE_CHECKING: diff --git a/airflow-core/src/airflow/models/referencemixin.py b/airflow-core/src/airflow/models/referencemixin.py new file mode 100644 index 0000000000000..19a775417f865 --- /dev/null +++ b/airflow-core/src/airflow/models/referencemixin.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Protocol, runtime_checkable
+
+if TYPE_CHECKING:
+    from collections.abc import Iterable
+    from typing import TypeAlias
+
+    from airflow.models.mappedoperator import MappedOperator
+    from airflow.serialization.serialized_objects import SerializedBaseOperator
+
+    Operator: TypeAlias = MappedOperator | SerializedBaseOperator
+
+
+@runtime_checkable
+class ReferenceMixin(Protocol):
+    """
+    Mixin for things that reference a task.
+
+    This should be implemented by things that reference operators and use them
+    to lazily resolve values at runtime. The most prominent examples are XCom
+    references (XComArg).
+
+    This is a partial interface to the SDK's ResolveMixin with the resolve()
+    method removed, since the scheduler should not need to resolve the reference.
+    """
+
+    def iter_references(self) -> Iterable[tuple[Operator, str]]:
+        """
+        Find the underlying XCom references this object contains.
+
+        This is used by the DAG parser to recursively find task dependencies.
+
+        :meta private:
+        """
+        raise NotImplementedError
diff --git a/airflow-core/src/airflow/models/serialized_dag.py b/airflow-core/src/airflow/models/serialized_dag.py
index 5ea0947f9f07a..6444d5720de20 100644
--- a/airflow-core/src/airflow/models/serialized_dag.py
+++ b/airflow-core/src/airflow/models/serialized_dag.py
@@ -295,7 +295,7 @@ class SerializedDagModel(Base):
 
     dag_runs = relationship(
         DagRun,
-        primaryjoin=dag_id == foreign(DagRun.dag_id),
+        primaryjoin=dag_id == foreign(DagRun.dag_id),  # type: ignore[has-type]
         backref=backref("serialized_dag", uselist=False, innerjoin=True),
     )
diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py
index eac003c816cb6..8b690b417e1ef 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -120,9 +120,9 @@
     from airflow.models.dag import DAG as SchedulerDAG, DagModel
     from airflow.models.dagrun import DagRun
     from airflow.models.mappedoperator import MappedOperator
+    from airflow.sdk import DAG
     from airflow.sdk.api.datamodels._generated import AssetProfile
     from airflow.sdk.definitions.asset import AssetNameRef, AssetUniqueKey, AssetUriRef
-    from airflow.sdk.definitions.dag import DAG
     from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup
     from airflow.sdk.types import RuntimeTaskInstanceProtocol
     from airflow.serialization.serialized_objects import SerializedBaseOperator
@@ -1430,7 +1430,7 @@ def update_heartbeat(self):
     @provide_session
     def defer_task(self, exception: TaskDeferred | None, session: Session = NEW_SESSION) -> None:
         """
-        Mark the task as deferred and sets up the trigger that is needed to resume it when TaskDeferred is raised.
+        Mark the task as deferred and set up the trigger to resume it.
 
         :meta: private
         """
@@ -1590,7 +1590,8 @@ def fetch_handle_failure_context(
         ti.clear_next_method_args()
 
         context = None
-        # In extreme cases (task instance heartbeat timeout in case of dag with parse error) we might _not_ have a Task.
+        # In extreme cases (task instance heartbeat timeout in case of dag with
+        # parse error) we might _not_ have a Task.
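# Illustrative sketch, not part of the patch: how a runtime_checkable Protocol
# like the new ReferenceMixin is consumed (compare the iter_references() loops
# added in expandinput.py). DemoXComArg is a hypothetical stand-in for XComArg;
# it never inherits from the protocol, yet isinstance() still matches it.
from collections.abc import Iterable
from typing import Protocol, runtime_checkable


@runtime_checkable
class DemoReferenceMixin(Protocol):
    def iter_references(self) -> Iterable[tuple[object, str]]: ...


class DemoXComArg:
    def __init__(self, operator: object, key: str) -> None:
        self.operator = operator
        self.key = key

    def iter_references(self) -> Iterable[tuple[object, str]]:
        yield self.operator, self.key


values = {"x": DemoXComArg("task_a", "return_value"), "y": 42}
for v in values.values():
    # Structural check: True for DemoXComArg, False for the plain int.
    if isinstance(v, DemoReferenceMixin):
        print(list(v.iter_references()))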
if getattr(ti, "task", None): context = ti.get_template_context(session) @@ -1853,7 +1854,7 @@ def get_triggering_events() -> dict[str, list[AssetEvent]]: "_upstream_map_indexes", { upstream.task_id: self.get_relevant_upstream_map_indexes( - cast("Operator", upstream), + upstream, expanded_ti_count, session=session, ) @@ -2288,7 +2289,7 @@ def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) -> Mapp def _is_further_mapped_inside(operator: Operator, container: TaskGroup) -> bool: """Whether given operator is *further* mapped inside a task group.""" - from airflow.sdk.definitions.mappedoperator import MappedOperator + from airflow.models.mappedoperator import MappedOperator from airflow.sdk.definitions.taskgroup import MappedTaskGroup if isinstance(operator, MappedOperator): diff --git a/airflow-core/src/airflow/models/xcom_arg.py b/airflow-core/src/airflow/models/xcom_arg.py index 1109f03bb1f99..78021e5043123 100644 --- a/airflow-core/src/airflow/models/xcom_arg.py +++ b/airflow-core/src/airflow/models/xcom_arg.py @@ -17,20 +17,18 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Iterator, Sequence from functools import singledispatch -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypeAlias, cast import attrs from sqlalchemy import func, or_, select from sqlalchemy.orm import Session +from airflow.models.referencemixin import ReferenceMixin from airflow.models.xcom import XCOM_RETURN_KEY from airflow.sdk.definitions._internal.types import ArgNotSet -from airflow.sdk.definitions.mappedoperator import MappedOperator -from airflow.sdk.definitions.xcom_arg import ( - XComArg, -) +from airflow.sdk.definitions.xcom_arg import XComArg from airflow.utils.db import exists_query from airflow.utils.state import State from airflow.utils.types import NOTSET @@ -39,11 +37,13 @@ if TYPE_CHECKING: from airflow.models.dag import DAG as SchedulerDAG - from airflow.models.operator import Operator + from airflow.models.mappedoperator import MappedOperator + from airflow.serialization.serialized_objects import SerializedBaseOperator from airflow.typing_compat import Self + Operator: TypeAlias = MappedOperator | SerializedBaseOperator + -@attrs.define class SchedulerXComArg: @classmethod def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: @@ -57,7 +57,33 @@ def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: dicts to the correct ``_deserialize`` information, so this function does not need to validate whether the incoming data contains correct keys. """ - raise NotImplementedError() + raise NotImplementedError("This class should not be instantiated directly") + + @classmethod + def iter_xcom_references(cls, arg: Any) -> Iterator[tuple[Operator, str]]: + """ + Return XCom references in an arbitrary value. + + Recursively traverse ``arg`` and look for XComArg instances in any + collection objects, and instances with ``template_fields`` set. 
+ """ + from airflow.models.mappedoperator import MappedOperator + from airflow.serialization.serialized_objects import SerializedBaseOperator + + if isinstance(arg, ReferenceMixin): + yield from arg.iter_references() + elif isinstance(arg, (tuple, set, list)): + for elem in arg: + yield from cls.iter_xcom_references(elem) + elif isinstance(arg, dict): + for elem in arg.values(): + yield from cls.iter_xcom_references(elem) + elif isinstance(arg, (MappedOperator, SerializedBaseOperator)): + for attr in arg.template_fields: + yield from cls.iter_xcom_references(getattr(arg, attr)) + + def iter_references(self) -> Iterator[tuple[Operator, str]]: + raise NotImplementedError("This class should not be instantiated directly") @attrs.define @@ -67,7 +93,11 @@ class SchedulerPlainXComArg(SchedulerXComArg): @classmethod def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: - return cls(dag.get_task(data["task_id"]), data["key"]) + # TODO (GH-52141): SchedulerDAG should return scheduler operator instead. + return cls(cast("Operator", dag.get_task(data["task_id"])), data["key"]) + + def iter_references(self) -> Iterator[tuple[Operator, str]]: + yield self.operator, self.key @attrs.define @@ -81,6 +111,9 @@ def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: # in the UI, and displaying a function object is useless. return cls(deserialize_xcom_arg(data["arg"], dag), data["callables"]) + def iter_references(self) -> Iterator[tuple[Operator, str]]: + yield from self.arg.iter_references() + @attrs.define class SchedulerConcatXComArg(SchedulerXComArg): @@ -90,6 +123,10 @@ class SchedulerConcatXComArg(SchedulerXComArg): def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: return cls([deserialize_xcom_arg(arg, dag) for arg in data["args"]]) + def iter_references(self) -> Iterator[tuple[Operator, str]]: + for arg in self.args: + yield from arg.iter_references() + @attrs.define class SchedulerZipXComArg(SchedulerXComArg): @@ -103,6 +140,10 @@ def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: fillvalue=data.get("fillvalue", NOTSET), ) + def iter_references(self) -> Iterator[tuple[Operator, str]]: + for arg in self.args: + yield from arg.iter_references() + @singledispatch def get_task_map_length(xcom_arg: SchedulerXComArg, run_id: str, *, session: Session) -> int | None: @@ -112,15 +153,15 @@ def get_task_map_length(xcom_arg: SchedulerXComArg, run_id: str, *, session: Ses @get_task_map_length.register def _(xcom_arg: SchedulerPlainXComArg, run_id: str, *, session: Session): + from airflow.models.mappedoperator import is_mapped from airflow.models.taskinstance import TaskInstance from airflow.models.taskmap import TaskMap from airflow.models.xcom import XComModel dag_id = xcom_arg.operator.dag_id task_id = xcom_arg.operator.task_id - is_mapped = xcom_arg.operator.is_mapped or isinstance(xcom_arg.operator, MappedOperator) - if is_mapped: + if is_mapped(xcom_arg.operator): unfinished_ti_exists = exists_query( TaskInstance.dag_id == dag_id, TaskInstance.run_id == run_id, diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py index a8a766e553b31..3ea6a8c685c05 100644 --- a/airflow-core/src/airflow/serialization/serialized_objects.py +++ b/airflow-core/src/airflow/serialization/serialized_objects.py @@ -46,9 +46,7 @@ from airflow.exceptions import AirflowException, SerializationError, TaskDeferred from airflow.models.connection import Connection from 
airflow.models.dag import DAG, _get_model_data_interval -from airflow.models.expandinput import ( - create_expand_input, -) +from airflow.models.expandinput import create_expand_input from airflow.models.taskinstancekey import TaskInstanceKey from airflow.models.xcom import XComModel from airflow.models.xcom_arg import SchedulerXComArg, deserialize_xcom_arg @@ -101,14 +99,13 @@ if TYPE_CHECKING: from inspect import Parameter - from sqlalchemy.orm import Session - from airflow.models import DagRun from airflow.models.expandinput import SchedulerExpandInput from airflow.models.mappedoperator import MappedOperator as SchedulerMappedOperator from airflow.models.taskinstance import TaskInstance from airflow.sdk import DAG as SdkDag, BaseOperatorLink from airflow.serialization.json_schema import Validator + from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.timetables.base import DagRunInfo, DataInterval, Timetable from airflow.triggers.base import BaseEventTrigger from airflow.typing_compat import Self @@ -125,7 +122,7 @@ SchedulerOperator: TypeAlias = "SchedulerMappedOperator | SerializedBaseOperator" SdkOperator: TypeAlias = BaseOperator | MappedOperator -DEFAULT_OPERATOR_DEPS = frozenset( +DEFAULT_OPERATOR_DEPS: frozenset[BaseTIDep] = frozenset( ( NotInRetryPeriodDep(), PrevDagrunDep(), @@ -611,12 +608,12 @@ class BaseSerialization: SERIALIZER_VERSION = 2 @classmethod - def to_json(cls, var: DAG | SerializedBaseOperator | dict | list | set | tuple) -> str: + def to_json(cls, var: DAG | SchedulerOperator | dict | list | set | tuple) -> str: """Stringify DAGs and operators contained by var and returns a JSON string of var.""" return json.dumps(cls.to_dict(var), ensure_ascii=True) @classmethod - def to_dict(cls, var: DAG | SerializedBaseOperator | dict | list | set | tuple) -> dict: + def to_dict(cls, var: DAG | SchedulerOperator | dict | list | set | tuple) -> dict: """Stringify DAGs and operators contained by var and returns a dict of var.""" # Don't call on this class directly - only SerializedDAG or # SerializedBaseOperator should be used as the "entrypoint" @@ -671,8 +668,8 @@ def _is_excluded(cls, var: Any, attrname: str, instance: Any) -> bool: @classmethod def serialize_to_json( cls, - # TODO (GH-52141): When can we remove SerializedBaseOperator here? - object_to_serialize: BaseOperator | MappedOperator | SerializedBaseOperator | SdkDag, + # TODO (GH-52141): When can we remove scheduler constructs here? + object_to_serialize: SdkOperator | SchedulerOperator | SdkDag | DAG, decorated_fields: set, ) -> dict[str, Any]: """Serialize an object to JSON.""" @@ -720,6 +717,8 @@ def serialize( :meta private: """ + from airflow.models.mappedoperator import MappedOperator as SchedulerMappedOperator + if cls._is_primitive(var): # enum.IntEnum is an int instance, it causes json dumps error so we use its value. 
if isinstance(var, enum.Enum): @@ -759,9 +758,9 @@ def serialize( return cls._encode(DeadlineAlert.serialize_deadline_alert(var), type_=DAT.DEADLINE_ALERT) elif isinstance(var, Resources): return var.to_dict() - elif isinstance(var, MappedOperator): + elif isinstance(var, (MappedOperator, SchedulerMappedOperator)): return cls._encode(SerializedBaseOperator.serialize_mapped_operator(var), type_=DAT.OP) - elif isinstance(var, BaseOperator): + elif isinstance(var, (BaseOperator, SerializedBaseOperator)): var._needs_expansion = var.get_needs_expansion() return cls._encode(SerializedBaseOperator.serialize_operator(var), type_=DAT.OP) elif isinstance(var, cls._datetime_types): @@ -1252,6 +1251,13 @@ def __init__( self.deps = DEFAULT_OPERATOR_DEPS self._operator_name: str | None = None + def __eq__(self, other: Any) -> bool: + if not isinstance(other, (SerializedBaseOperator, BaseOperator)): + return NotImplemented + return self.task_type == other.task_type and all( + getattr(self, c, None) == getattr(other, c, None) for c in BaseOperator._comps + ) + @property def node_id(self) -> str: return self.task_id @@ -1352,7 +1358,7 @@ def __getattr__(self, name): raise AttributeError(f"'{self.task_type}' object has no attribute '{name}'") @classmethod - def serialize_mapped_operator(cls, op: MappedOperator) -> dict[str, Any]: + def serialize_mapped_operator(cls, op: MappedOperator | SchedulerMappedOperator) -> dict[str, Any]: serialized_op = cls._serialize_node(op) # Handle expand_input and op_kwargs_expand_input. expansion_kwargs = op._get_specified_expand_input() @@ -1378,11 +1384,11 @@ def serialize_mapped_operator(cls, op: MappedOperator) -> dict[str, Any]: return serialized_op @classmethod - def serialize_operator(cls, op: BaseOperator | MappedOperator) -> dict[str, Any]: + def serialize_operator(cls, op: SdkOperator | SchedulerOperator) -> dict[str, Any]: return cls._serialize_node(op) @classmethod - def _serialize_node(cls, op: BaseOperator | MappedOperator) -> dict[str, Any]: + def _serialize_node(cls, op: SdkOperator | SchedulerOperator) -> dict[str, Any]: """Serialize operator into a JSON object.""" serialize_op = cls.serialize_to_json(op, cls._decorated_fields) @@ -1442,11 +1448,7 @@ def _serialize_node(cls, op: BaseOperator | MappedOperator) -> dict[str, Any]: return serialize_op @classmethod - def populate_operator( - cls, - op: SchedulerMappedOperator | SerializedBaseOperator, - encoded_op: dict[str, Any], - ) -> None: + def populate_operator(cls, op: SchedulerOperator, encoded_op: dict[str, Any]) -> None: """ Populate operator attributes with serialized values. @@ -1610,12 +1612,9 @@ def set_task_dag_references(task: SchedulerOperator, dag: DAG) -> None: dag.task_dict[task_id].upstream_task_ids.add(task.task_id) @classmethod - def deserialize_operator( - cls, - encoded_op: dict[str, Any], - ) -> SchedulerMappedOperator | SerializedBaseOperator: + def deserialize_operator(cls, encoded_op: dict[str, Any]) -> SchedulerOperator: """Deserializes an operator from a JSON object.""" - op: SchedulerMappedOperator | SerializedBaseOperator + op: SchedulerOperator if encoded_op.get("_is_mapped", False): # Most of these will be loaded later, these are just some stand-ins. 
             op_data = {k: v for k, v in encoded_op.items() if k in BaseOperator.get_serialized_fields()}
@@ -1629,26 +1628,18 @@ def deserialize_operator(
             op = SchedulerMappedOperator(
                 operator_class=op_data,
-                expand_input=EXPAND_INPUT_EMPTY,
-                partial_kwargs={},
                 task_id=encoded_op["task_id"],
-                params={},
                 operator_extra_links=BaseOperator.operator_extra_links,
                 template_ext=BaseOperator.template_ext,
                 template_fields=BaseOperator.template_fields,
                 template_fields_renderers=BaseOperator.template_fields_renderers,
                 ui_color=BaseOperator.ui_color,
                 ui_fgcolor=BaseOperator.ui_fgcolor,
-                is_empty=False,
                 is_sensor=encoded_op.get("_is_sensor", False),
                 can_skip_downstream=encoded_op.get("_can_skip_downstream", False),
                 task_module=encoded_op["_task_module"],
                 task_type=encoded_op["task_type"],
                 operator_name=operator_name,
-                dag=None,
-                task_group=None,
-                start_date=None,
-                end_date=None,
                 disallow_kwargs_override=encoded_op["_disallow_kwargs_override"],
                 expand_input_attr=encoded_op["_expand_input_attr"],
                 start_trigger_args=encoded_op.get("start_trigger_args", None),
@@ -1753,7 +1744,7 @@ def inherits_from_empty_operator(self) -> bool:
     def inherits_from_skipmixin(self) -> bool:
         return self._can_skip_downstream
 
-    def expand_start_from_trigger(self, *, context: Context, session: Session) -> bool:
+    def expand_start_from_trigger(self, *, context: Context) -> bool:
         """
         Get the start_from_trigger value of the current abstract operator.
 
@@ -2408,7 +2399,11 @@ def get_run_data_interval(self, run: DagRun) -> DataInterval | None:
 
 @attrs.define()
 class XComOperatorLink(LoggingMixin):
-    """A generic operator link class that can retrieve link only using XCOMs. Used while deserializing operators."""
+    """
+    Generic operator link class that can retrieve a link using only XComs.
+
+    Used while deserializing operators.
+    """
 
     name: str
     xcom_key: str
@@ -2449,12 +2444,10 @@ def create_scheduler_operator(op: BaseOperator | SerializedBaseOperator) -> Seri
 def create_scheduler_operator(op: MappedOperator | SchedulerMappedOperator) -> SchedulerMappedOperator: ...
-def create_scheduler_operator( - op: BaseOperator | MappedOperator | SerializedBaseOperator | SchedulerMappedOperator, -) -> SerializedBaseOperator | SchedulerMappedOperator: +def create_scheduler_operator(op: SdkOperator | SchedulerOperator) -> SchedulerOperator: from airflow.models.mappedoperator import MappedOperator as SchedulerMappedOperator - if isinstance(op, (SchedulerMappedOperator, SerializedBaseOperator)): + if isinstance(op, (SerializedBaseOperator, SchedulerMappedOperator)): return op if isinstance(op, BaseOperator): d = SerializedBaseOperator.serialize_operator(op) diff --git a/airflow-core/src/airflow/ti_deps/deps/mapped_task_upstream_dep.py b/airflow-core/src/airflow/ti_deps/deps/mapped_task_upstream_dep.py index 2f3c8015af29c..d5922074ef5c0 100644 --- a/airflow-core/src/airflow/ti_deps/deps/mapped_task_upstream_dep.py +++ b/airflow-core/src/airflow/ti_deps/deps/mapped_task_upstream_dep.py @@ -18,7 +18,7 @@ from __future__ import annotations from collections.abc import Iterator -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypeAlias, cast from sqlalchemy import select @@ -28,10 +28,14 @@ if TYPE_CHECKING: from sqlalchemy.orm import Session + from airflow.models.mappedoperator import MappedOperator from airflow.models.taskinstance import TaskInstance + from airflow.serialization.serialized_objects import SerializedBaseOperator from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.deps.base_ti_dep import TIDepStatus + Operator: TypeAlias = MappedOperator | SerializedBaseOperator + class MappedTaskUpstreamDep(BaseTIDep): """ @@ -51,13 +55,17 @@ def _get_dep_statuses( session: Session, dep_context: DepContext, ) -> Iterator[TIDepStatus]: + from airflow.models.mappedoperator import is_mapped from airflow.models.taskinstance import TaskInstance - from airflow.sdk.definitions.mappedoperator import MappedOperator - if isinstance(ti.task, MappedOperator): + if ti.task is None: + return + elif is_mapped(ti.task): mapped_dependencies = ti.task.iter_mapped_dependencies() - elif ti.task is not None and (task_group := ti.task.get_closest_mapped_task_group()) is not None: - mapped_dependencies = task_group.iter_mapped_dependencies() + elif (task_group := ti.task.get_closest_mapped_task_group()) is not None: + # TODO (GH-52141): Task group in scheduler needs to return scheduler + # types instead, but currently the scheduler uses SDK's TaskGroup. + mapped_dependencies = cast("Iterator[Operator]", task_group.iter_mapped_dependencies()) else: return diff --git a/airflow-core/src/airflow/ti_deps/deps/prev_dagrun_dep.py b/airflow-core/src/airflow/ti_deps/deps/prev_dagrun_dep.py index 1f4363c586ee1..71a6f583f2907 100644 --- a/airflow-core/src/airflow/ti_deps/deps/prev_dagrun_dep.py +++ b/airflow-core/src/airflow/ti_deps/deps/prev_dagrun_dep.py @@ -17,7 +17,7 @@ # under the License. 
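# Illustrative sketch, not part of the patch: the functools.singledispatch
# pattern that get_mapped_ti_count and get_task_map_length (earlier in this
# patch) rely on to select an implementation from the runtime type of the
# first argument, instead of long isinstance chains. The registered types
# below are plain builtins chosen for illustration.
import functools


@functools.singledispatch
def map_length(value: object) -> int:
    raise NotImplementedError(f"no mapping rule for {type(value).__name__}")


@map_length.register
def _(value: list) -> int:
    return len(value)


@map_length.register
def _(value: dict) -> int:
    return len(value)


print(map_length([1, 2, 3]))  # 3, dispatched to the list implementation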
from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypeAlias from sqlalchemy import func, or_, select @@ -32,9 +32,12 @@ if TYPE_CHECKING: from sqlalchemy.orm import Session - from airflow.sdk.types import Operator + from airflow.models.mappedoperator import MappedOperator + from airflow.sdk.types import Operator as SdkOperator from airflow.serialization.serialized_objects import SerializedBaseOperator + SchedulerOperator: TypeAlias = MappedOperator | SerializedBaseOperator + _SUCCESSFUL_STATES = (TaskInstanceState.SKIPPED, TaskInstanceState.SUCCESS) @@ -107,7 +110,7 @@ def _count_unsuccessful_tis(dagrun: DagRun, task_id: str, *, session: Session) - @staticmethod def _has_unsuccessful_dependants( dagrun: DagRun, - task: Operator | SerializedBaseOperator, + task: SdkOperator | SchedulerOperator, *, session: Session, ) -> bool: diff --git a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py index dd62af8324a15..a327b50c6a71d 100644 --- a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py +++ b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py @@ -149,10 +149,10 @@ def _get_expanded_ti_count() -> int: return get_mapped_ti_count(ti.task, ti.run_id, session=session) def _iter_expansion_dependencies(task_group: MappedTaskGroup) -> Iterator[str]: - from airflow.sdk.definitions.mappedoperator import MappedOperator + from airflow.models.mappedoperator import is_mapped - if isinstance(ti.task, MappedOperator): - for op in ti.task.iter_mapped_dependencies(): + if (task := ti.task) is not None and is_mapped(task): + for op in task.iter_mapped_dependencies(): yield op.task_id if task_group and task_group.iter_mapped_task_groups(): yield from ( diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py index c8382df8b0e13..132ebc4b99432 100644 --- a/airflow-core/tests/unit/serialization/test_dag_serialization.py +++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py @@ -728,10 +728,10 @@ def validate_deserialized_task( task, ): """Verify non-Airflow operators are casted to BaseOperator or MappedOperator.""" + from airflow.models.mappedoperator import MappedOperator as SchedulerMappedOperator from airflow.sdk import BaseOperator from airflow.sdk.definitions.mappedoperator import MappedOperator - assert not isinstance(task, SerializedBaseOperator) assert isinstance(task, (BaseOperator, MappedOperator)) # Every task should have a task_group property -- even if it's the DAG's root task group @@ -762,7 +762,7 @@ def validate_deserialized_task( "_is_sensor", } else: # Promised to be mapped by the assert above. 
- assert isinstance(serialized_task, MappedOperator) + assert isinstance(serialized_task, SchedulerMappedOperator) fields_to_check = {f.name for f in attrs.fields(MappedOperator)} fields_to_check -= { "map_index_template", @@ -829,7 +829,9 @@ def validate_deserialized_task( assert serialized_partial_kwargs == original_partial_kwargs # ExpandInputs have different classes between scheduler and definition - assert attrs.asdict(serialized_task.expand_input) == attrs.asdict(task.expand_input) + assert attrs.asdict(serialized_task._get_specified_expand_input()) == attrs.asdict( + task._get_specified_expand_input() + ) @pytest.mark.parametrize( "dag_start_date, task_start_date, expected_task_start_date", diff --git a/providers/celery/src/airflow/providers/celery/version_compat.py b/providers/celery/src/airflow/providers/celery/version_compat.py index 4e4de45c3bad2..9dbb0942ea72f 100644 --- a/providers/celery/src/airflow/providers/celery/version_compat.py +++ b/providers/celery/src/airflow/providers/celery/version_compat.py @@ -31,7 +31,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: try: from airflow.sdk.execution_time.timeout import timeout except ImportError: - from airflow.utils.timeout import timeout # type: ignore[attr-defined,no-redef] + from airflow.utils.timeout import timeout # type: ignore[assignment,attr-defined,no-redef] __all__ = ["AIRFLOW_V_3_0_PLUS", "timeout"] diff --git a/providers/google/src/airflow/providers/google/version_compat.py b/providers/google/src/airflow/providers/google/version_compat.py index 2a77fd0f7163a..764970dcf34c4 100644 --- a/providers/google/src/airflow/providers/google/version_compat.py +++ b/providers/google/src/airflow/providers/google/version_compat.py @@ -61,7 +61,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: try: from airflow.sdk.execution_time.timeout import timeout except ImportError: - from airflow.utils.timeout import timeout # type: ignore[attr-defined,no-redef] + from airflow.utils.timeout import timeout # type: ignore[assignment,attr-defined,no-redef] # Explicitly export these imports to protect them from being removed by linters __all__ = [ diff --git a/providers/openlineage/src/airflow/providers/openlineage/version_compat.py b/providers/openlineage/src/airflow/providers/openlineage/version_compat.py index c1093b109b735..ddf39a1898aaa 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/version_compat.py +++ b/providers/openlineage/src/airflow/providers/openlineage/version_compat.py @@ -43,7 +43,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: from airflow.sdk import timezone from airflow.sdk.execution_time.timeout import timeout except ImportError: - from airflow.utils import timezone # type: ignore[no-redef, attr-defined] - from airflow.utils.timeout import timeout # type: ignore[attr-defined,no-redef] + from airflow.utils import timezone # type: ignore[attr-defined,no-redef] + from airflow.utils.timeout import timeout # type: ignore[assignment,attr-defined,no-redef] __all__ = ["AIRFLOW_V_3_0_PLUS", "BaseOperator", "timeout", "timezone"] diff --git a/providers/standard/src/airflow/providers/standard/utils/skipmixin.py b/providers/standard/src/airflow/providers/standard/utils/skipmixin.py index 56295bae64d07..432b287f3a190 100644 --- a/providers/standard/src/airflow/providers/standard/utils/skipmixin.py +++ b/providers/standard/src/airflow/providers/standard/utils/skipmixin.py @@ -45,7 +45,7 @@ def _ensure_tasks(nodes: Iterable[DAGNode]) -> Sequence[Operator]: from 
airflow.sdk.definitions.mappedoperator import MappedOperator else: from airflow.models.baseoperator import BaseOperator # type: ignore[no-redef] - from airflow.models.mappedoperator import MappedOperator + from airflow.models.mappedoperator import MappedOperator # type: ignore[assignment,no-redef] return [n for n in nodes if isinstance(n, (BaseOperator, MappedOperator))] diff --git a/providers/standard/tests/unit/standard/decorators/test_python.py b/providers/standard/tests/unit/standard/decorators/test_python.py index 49be927664e97..82620170e756f 100644 --- a/providers/standard/tests/unit/standard/decorators/test_python.py +++ b/providers/standard/tests/unit/standard/decorators/test_python.py @@ -48,7 +48,7 @@ from airflow.models.baseoperator import BaseOperator # type: ignore[no-redef] from airflow.models.dag import DAG # type: ignore[assignment] from airflow.models.expandinput import DictOfListsExpandInput - from airflow.models.mappedoperator import MappedOperator + from airflow.models.mappedoperator import MappedOperator # type: ignore[assignment,no-redef] from airflow.models.xcom_arg import XComArg from airflow.utils.task_group import TaskGroup # type: ignore[no-redef] diff --git a/providers/standard/tests/unit/standard/utils/test_skipmixin.py b/providers/standard/tests/unit/standard/utils/test_skipmixin.py index f79b60e967434..db805f5f2a598 100644 --- a/providers/standard/tests/unit/standard/utils/test_skipmixin.py +++ b/providers/standard/tests/unit/standard/utils/test_skipmixin.py @@ -23,7 +23,6 @@ import pytest from airflow.exceptions import AirflowException -from airflow.models.mappedoperator import MappedOperator from airflow.models.taskinstance import TaskInstance as TI from airflow.providers.standard.operators.empty import EmptyOperator from airflow.utils import timezone @@ -40,8 +39,10 @@ from airflow.models.dag_version import DagVersion from airflow.providers.standard.utils.skipmixin import SkipMixin from airflow.sdk import task, task_group + from airflow.sdk.definitions.mappedoperator import MappedOperator else: from airflow.decorators import task, task_group # type: ignore[attr-defined,no-redef] + from airflow.models.mappedoperator import MappedOperator # type: ignore[assignment] from airflow.models.skipmixin import SkipMixin DEFAULT_DATE = timezone.datetime(2016, 1, 1) diff --git a/scripts/ci/pre_commit/check_base_operator_partial_arguments.py b/scripts/ci/pre_commit/check_base_operator_partial_arguments.py index 98128063dd919..d53f09122734a 100755 --- a/scripts/ci/pre_commit/check_base_operator_partial_arguments.py +++ b/scripts/ci/pre_commit/check_base_operator_partial_arguments.py @@ -61,11 +61,6 @@ # Task-SDK migration ones. "deps", "downstream_task_ids", - "on_execute_callback", - "on_failure_callback", - "on_retry_callback", - "on_skipped_callback", - "on_success_callback", "operator_extra_links", "start_from_trigger", "start_trigger_args", diff --git a/task-sdk/src/airflow/sdk/bases/decorator.py b/task-sdk/src/airflow/sdk/bases/decorator.py index 6c56a641014d9..6bd42c30b5f88 100644 --- a/task-sdk/src/airflow/sdk/bases/decorator.py +++ b/task-sdk/src/airflow/sdk/bases/decorator.py @@ -564,12 +564,6 @@ class DecoratedMappedOperator(MappedOperator): def __hash__(self): return id(self) - def __attrs_post_init__(self): - # The magic super() doesn't work here, so we use the explicit form. - # Not using super(..., self) to work around pyupgrade bug. 
- super(DecoratedMappedOperator, DecoratedMappedOperator).__attrs_post_init__(self) - XComArg.apply_upstream_relationship(self, self.op_kwargs_expand_input.value) - def _expand_mapped_kwargs(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]: # We only use op_kwargs_expand_input so this must always be empty. if self.expand_input is not EXPAND_INPUT_EMPTY: diff --git a/task-sdk/src/airflow/sdk/bases/operator.py b/task-sdk/src/airflow/sdk/bases/operator.py index 398ee5649a21e..c35dc8ada7b31 100644 --- a/task-sdk/src/airflow/sdk/bases/operator.py +++ b/task-sdk/src/airflow/sdk/bases/operator.py @@ -1503,14 +1503,6 @@ def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]: return DagAttributeTypes.OP, self.task_id - @property - def inherits_from_empty_operator(self): - """Used to determine if an Operator is inherited from EmptyOperator.""" - # This looks like `isinstance(self, EmptyOperator) would work, but this also - # needs to cope when `self` is a Serialized instance of a EmptyOperator or one - # of its subclasses (which don't inherit from anything but BaseOperator). - return getattr(self, "_is_empty", False) - def unmap(self, resolve: None | Mapping[str, Any]) -> Self: """ Get the "normal" operator from the current operator. @@ -1557,7 +1549,8 @@ def execute(self, context: Context) -> Any: """ Derive when creating an operator. - The main method to execute the task. Context is the same dictionary used as when rendering jinja templates. + The main method to execute the task. Context is the same dictionary used + as when rendering jinja templates. Refer to get_template_context for more context. """ diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py b/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py index b101cebfd82d0..65e15b9bfa1a4 100644 --- a/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py +++ b/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py @@ -26,7 +26,7 @@ Iterable, Iterator, ) -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias import methodtools @@ -49,6 +49,7 @@ from airflow.sdk.definitions.taskgroup import MappedTaskGroup TaskStateChangeCallback = Callable[[Context], None] +TaskStateChangeCallbackAttrType: TypeAlias = TaskStateChangeCallback | list[TaskStateChangeCallback] | None DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner") DEFAULT_POOL_SLOTS: int = 1 @@ -150,10 +151,6 @@ def task_type(self) -> str: def operator_name(self) -> str: raise NotImplementedError() - @property - def inherits_from_empty_operator(self) -> bool: - raise NotImplementedError() - _is_sensor: bool = False _is_mapped: bool = False _can_skip_downstream: bool = False @@ -212,6 +209,14 @@ def on_failure_fail_dagrun(self, value): ) self._on_failure_fail_dagrun = value + @property + def inherits_from_empty_operator(self): + """Used to determine if an Operator is inherited from EmptyOperator.""" + # This looks like `isinstance(self, EmptyOperator) would work, but this also + # needs to cope when `self` is a Serialized instance of a EmptyOperator or one + # of its subclasses (which don't inherit from anything but BaseOperator). 
+ return getattr(self, "_is_empty", False) + @property def inherits_from_skipmixin(self): """Used to determine if an Operator is inherited from SkipMixin or its subclasses (e.g., BranchMixin).""" diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py index a1c612393b7d7..4b56c9bd94e53 100644 --- a/task-sdk/src/airflow/sdk/definitions/dag.py +++ b/task-sdk/src/airflow/sdk/definitions/dag.py @@ -786,11 +786,12 @@ def partial_subset( """ from typing import TypeGuard + from airflow.models.mappedoperator import MappedOperator as DbMappedOperator from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.serialization.serialized_objects import SerializedBaseOperator def is_task(obj) -> TypeGuard[Operator]: - if isinstance(obj, SerializedBaseOperator): + if isinstance(obj, (DbMappedOperator, SerializedBaseOperator)): return True # TODO (GH-52141): Split DAG implementation to straight this up. return isinstance(obj, (BaseOperator, MappedOperator)) diff --git a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py index 551f17865e8c0..ba5595e3b4828 100644 --- a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py +++ b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py @@ -21,7 +21,7 @@ import copy import warnings from collections.abc import Collection, Iterable, Iterator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeAlias, TypeGuard +from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeGuard import attrs import methodtools @@ -43,7 +43,7 @@ DEFAULT_WEIGHT_RULE, AbstractOperator, NotMapped, - TaskStateChangeCallback, + TaskStateChangeCallbackAttrType, ) from airflow.sdk.definitions._internal.expandinput import ( DictOfListsExpandInput, @@ -64,19 +64,13 @@ OperatorExpandArgument, OperatorExpandKwargsArgument, ) - from airflow.sdk.bases.operator import BaseOperator - from airflow.sdk.bases.operatorlink import BaseOperatorLink + from airflow.sdk import DAG, BaseOperator, BaseOperatorLink, Context, TaskGroup, XComArg from airflow.sdk.definitions._internal.expandinput import ExpandInput - from airflow.sdk.definitions.context import Context - from airflow.sdk.definitions.dag import DAG from airflow.sdk.definitions.param import ParamsDict - from airflow.sdk.definitions.taskgroup import TaskGroup - from airflow.sdk.definitions.xcom_arg import XComArg from airflow.triggers.base import StartTriggerArgs from airflow.utils.operator_resources import Resources from airflow.utils.trigger_rule import TriggerRule -TaskStateChangeCallbackAttrType: TypeAlias = TaskStateChangeCallback | list[TaskStateChangeCallback] | None ValidationSource = Literal["expand"] | Literal["partial"] @@ -287,12 +281,7 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator: class MappedOperator(AbstractOperator): """Object representing a mapped operator in a DAG.""" - # This attribute serves double purpose. For a "normal" operator instance - # loaded from DAG, this holds the underlying non-mapped operator class that - # can be used to create an unmapped operator for execution. For an operator - # recreated from a serialized DAG, however, this holds the serialized data - # that can be used to unmap this into a SerializedBaseOperator. 
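A note on the attribute change just below: in the SDK class, operator_class no
longer carries the serialized-dict shape described in the removed comment
above; after this patch a task recreated from a serialized DAG is handled by
the db-backed MappedOperator instead, so operator_class here is always a real
class. As a minimal sketch (the helper name below is hypothetical, not from
the patch), the old dual shape forced callers into branching like this:

    # Illustrative only: "describe_unmap_target" is a hypothetical helper
    # showing the shape-dependent handling the removed comment describes.
    from typing import Any

    def describe_unmap_target(operator_class: type | dict[str, Any]) -> str:
        if isinstance(operator_class, type):
            # Parsed in-process: a real class that can be instantiated.
            return f"class {operator_class.__name__}"
        # Recreated from a serialized DAG: only serialized data remains.
        return f"serialized data for {operator_class.get('task_type')!r}"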
- operator_class: type[BaseOperator] | dict[str, Any] + operator_class: type[BaseOperator] _is_mapped: bool = attrs.field(init=False, default=True) @@ -358,7 +347,7 @@ def __attrs_post_init__(self): self.task_group.add(self) if self.dag: self.dag.add_task(self) - XComArg.apply_upstream_relationship(self, self.expand_input.value) + XComArg.apply_upstream_relationship(self, self._get_specified_expand_input().value) for k, v in self.partial_kwargs.items(): if k in self.template_fields: XComArg.apply_upstream_relationship(self, v) @@ -390,11 +379,6 @@ def task_type(self) -> str: def operator_name(self) -> str: return self._operator_name - @property - def inherits_from_empty_operator(self) -> bool: - """Implementing an empty Operator.""" - return self._is_empty - @property def roots(self) -> Sequence[AbstractOperator]: """Implementing DAGNode.""" @@ -743,50 +727,26 @@ def unmap(self, resolve: None | Mapping[str, Any]) -> BaseOperator: """ Get the "normal" Operator after applying the current mapping. - The *resolve* argument is only used if ``operator_class`` is a real - class, i.e. if this operator is not serialized. If ``operator_class`` is - not a class (i.e. this DAG has been deserialized), this returns a - SerializedBaseOperator that "looks like" the actual unmapping result. - :meta private: """ - if isinstance(self.operator_class, type): - if isinstance(resolve, Mapping): - kwargs = resolve - elif resolve is not None: - kwargs, _ = self._expand_mapped_kwargs(*resolve) - else: - raise RuntimeError("cannot unmap a non-serialized operator without context") - kwargs = self._get_unmap_kwargs(kwargs, strict=self._disallow_kwargs_override) - is_setup = kwargs.pop("is_setup", False) - is_teardown = kwargs.pop("is_teardown", False) - on_failure_fail_dagrun = kwargs.pop("on_failure_fail_dagrun", False) - kwargs["task_id"] = self.task_id - op = self.operator_class(**kwargs, _airflow_from_mapped=True) - op.is_setup = is_setup - op.is_teardown = is_teardown - op.on_failure_fail_dagrun = on_failure_fail_dagrun - op.downstream_task_ids = self.downstream_task_ids - op.upstream_task_ids = self.upstream_task_ids - return op - - # TODO (GH-52141): Move this bottom part to the db-backed mapped operator implementation. - - # After a mapped operator is serialized, there's no real way to actually - # unmap it since we've lost access to the underlying operator class. - # This tries its best to simply "forward" all the attributes on this - # mapped operator to a new SerializedBaseOperator instance. - from typing import cast - - from airflow.serialization.serialized_objects import SerializedBaseOperator - - sop = SerializedBaseOperator(task_id=self.task_id, params=self.params, _airflow_from_mapped=True) - for partial_attr, value in self.partial_kwargs.items(): - setattr(sop, partial_attr, value) - SerializedBaseOperator.populate_operator(sop, self.operator_class) - if self.dag is not None: # For Mypy; we only serialize tasks in a DAG so the check always satisfies. 
- SerializedBaseOperator.set_task_dag_references(sop, self.dag) # type: ignore[arg-type] - return cast("BaseOperator", sop) + if isinstance(resolve, Mapping): + kwargs = resolve + elif resolve is not None: + kwargs, _ = self._expand_mapped_kwargs(*resolve) + else: + raise RuntimeError("cannot unmap a non-serialized operator without context") + kwargs = self._get_unmap_kwargs(kwargs, strict=self._disallow_kwargs_override) + is_setup = kwargs.pop("is_setup", False) + is_teardown = kwargs.pop("is_teardown", False) + on_failure_fail_dagrun = kwargs.pop("on_failure_fail_dagrun", False) + kwargs["task_id"] = self.task_id + op = self.operator_class(**kwargs, _airflow_from_mapped=True) + op.is_setup = is_setup + op.is_teardown = is_teardown + op.on_failure_fail_dagrun = on_failure_fail_dagrun + op.downstream_task_ids = self.downstream_task_ids + op.upstream_task_ids = self.upstream_task_ids + return op def _get_specified_expand_input(self) -> ExpandInput: """Input received from the expand call on the operator.""" @@ -798,6 +758,7 @@ def prepare_for_execution(self) -> MappedOperator: # we don't need to create a copy of the MappedOperator here. return self + # TODO (GH-52141): Do we need this in the SDK? def iter_mapped_dependencies(self) -> Iterator[AbstractOperator]: """Upstream dependencies that provide XComs used by this task for task mapping.""" from airflow.sdk.definitions.xcom_arg import XComArg diff --git a/task-sdk/src/airflow/sdk/definitions/taskgroup.py b/task-sdk/src/airflow/sdk/definitions/taskgroup.py index 79bddbc5edc24..041133f19e00b 100644 --- a/task-sdk/src/airflow/sdk/definitions/taskgroup.py +++ b/task-sdk/src/airflow/sdk/definitions/taskgroup.py @@ -44,6 +44,7 @@ if TYPE_CHECKING: from airflow.models.expandinput import SchedulerExpandInput + from airflow.models.mappedoperator import MappedOperator from airflow.sdk.bases.operator import BaseOperator from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator from airflow.sdk.definitions._internal.expandinput import DictOfListsExpandInput, ListOfDictsExpandInput @@ -587,8 +588,10 @@ def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]: yield group group = group.parent_group - def iter_tasks(self) -> Iterator[AbstractOperator | SerializedBaseOperator]: + # TODO (GH-52141): This should only return SDK operators. Have a db representation for db operators. 
+ def iter_tasks(self) -> Iterator[AbstractOperator | MappedOperator | SerializedBaseOperator]: """Return an iterator of the child tasks.""" + from airflow.models.mappedoperator import MappedOperator from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator from airflow.serialization.serialized_objects import SerializedBaseOperator @@ -598,7 +601,7 @@ def iter_tasks(self) -> Iterator[AbstractOperator | SerializedBaseOperator]: visiting = groups_to_visit.pop(0) for child in visiting.children.values(): - if isinstance(child, (AbstractOperator, SerializedBaseOperator)): + if isinstance(child, (AbstractOperator, MappedOperator, SerializedBaseOperator)): yield child elif isinstance(child, TaskGroup): groups_to_visit.append(child) @@ -729,13 +732,13 @@ def task_group_to_dict(task_item_or_group, parent_group_is_mapped=False): def task_group_to_dict_grid(task_item_or_group, parent_group_is_mapped=False): """Create a nested dict representation of this TaskGroup and its children used to construct the Graph.""" + from airflow.models.mappedoperator import MappedOperator from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator - from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.serialization.serialized_objects import SerializedBaseOperator - if isinstance(task := task_item_or_group, (AbstractOperator, SerializedBaseOperator)): + if isinstance(task := task_item_or_group, (AbstractOperator, MappedOperator, SerializedBaseOperator)): is_mapped = None - if isinstance(task, MappedOperator) or parent_group_is_mapped: + if task.is_mapped or parent_group_is_mapped: is_mapped = True setup_teardown_type = None if task.is_setup is True: diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index 4b74c085c4c35..5ae484d0a87dc 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -24,6 +24,8 @@ from functools import singledispatch from typing import TYPE_CHECKING, Any, overload +import attrs + from airflow.exceptions import AirflowException, XComNotFound from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator from airflow.sdk.definitions._internal.mixins import DependencyMixin, ResolveMixin @@ -188,6 +190,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): SetupTeardownContext.set_work_task_roots_and_leaves() +@attrs.define class PlainXComArg(XComArg): """ Reference to one single XCom without any additional semantics. 
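The @attrs.define conversions in the hunks that follow trade hand-written
__init__/__eq__ boilerplate for declared fields, with attrs generating the
equivalent methods. A self-contained sketch of the pattern, using an
illustrative class rather than the real XComArg types:

    # Self-contained sketch of the attrs pattern adopted for the XComArg classes.
    import attrs

    @attrs.define
    class Ref:
        operator: str  # stands in for the Operator reference
        key: str = "return_value"

    # attrs derives __init__, __eq__, and __repr__ from the field declarations:
    assert Ref("t1") == Ref("t1", key="return_value")
    assert Ref("t1", key="other") != Ref("t1")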
@@ -207,14 +210,8 @@ class inheritance chain and ``__new__`` is implemented in this slightly :meta private: """ - def __init__(self, operator: Operator, key: str = BaseXCom.XCOM_RETURN_KEY): - self.operator = operator - self.key = key - - def __eq__(self, other: Any) -> bool: - if not isinstance(other, PlainXComArg): - return NotImplemented - return self.operator == other.operator and self.key == other.key + operator: Operator + key: str = BaseXCom.XCOM_RETURN_KEY def __getitem__(self, item: str) -> XComArg: """Implement xcomresult['some_result_key'].""" @@ -377,10 +374,10 @@ def _get_callable_name(f: Callable | str) -> str: return "" +@attrs.define class _MapResult(Sequence): - def __init__(self, value: Sequence | dict, callables: MapCallables) -> None: - self.value = value - self.callables = callables + value: Sequence | dict + callables: MapCallables def __getitem__(self, index: Any) -> Any: value = self.value[index] @@ -393,6 +390,7 @@ def __len__(self) -> int: return len(self.value) +@attrs.define class MapXComArg(XComArg): """ An XCom reference with ``map()`` call(s) applied. @@ -403,12 +401,13 @@ class MapXComArg(XComArg): :meta private: """ - def __init__(self, arg: XComArg, callables: MapCallables) -> None: - for c in callables: + arg: XComArg + callables: MapCallables + + def __attrs_post_init__(self) -> None: + for c in self.callables: if getattr(c, "_airflow_is_task_decorator", False): raise ValueError("map() argument must be a plain function, not a @task operator") - self.arg = arg - self.callables = callables def __repr__(self) -> str: map_calls = "".join(f".map({_get_callable_name(f)})" for f in self.callables) @@ -434,10 +433,10 @@ def resolve(self, context: Mapping[str, Any]) -> Any: return _MapResult(value, self.callables) +@attrs.define class _ZipResult(Sequence): - def __init__(self, values: Sequence[Sequence | dict], *, fillvalue: Any = NOTSET) -> None: - self.values = values - self.fillvalue = fillvalue + values: Sequence[Sequence | dict] + fillvalue: Any = attrs.field(default=NOTSET, kw_only=True) @staticmethod def _get_or_fill(container: Sequence | dict, index: Any, fillvalue: Any) -> Any: @@ -458,6 +457,7 @@ def __len__(self) -> int: return max(lengths) +@attrs.define class ZipXComArg(XComArg): """ An XCom reference with ``zip()`` applied. @@ -467,11 +467,8 @@ class ZipXComArg(XComArg): ``itertools.zip_longest()`` if ``fillvalue`` is provided). """ - def __init__(self, args: Sequence[XComArg], *, fillvalue: Any = NOTSET) -> None: - if not args: - raise ValueError("At least one input is required") - self.args = args - self.fillvalue = fillvalue + args: Sequence[XComArg] = attrs.field(validator=attrs.validators.min_len(1)) + fillvalue: Any = attrs.field(default=NOTSET, kw_only=True) def __repr__(self) -> str: args_iter = iter(self.args) @@ -499,9 +496,9 @@ def resolve(self, context: Mapping[str, Any]) -> Any: return _ZipResult(values, fillvalue=self.fillvalue) +@attrs.define class _ConcatResult(Sequence): - def __init__(self, values: Sequence[Sequence | dict]) -> None: - self.values = values + values: Sequence[Sequence | dict] def __getitem__(self, index: Any) -> Any: if index >= 0: @@ -523,6 +520,7 @@ def __len__(self) -> int: return sum(len(v) for v in self.values) +@attrs.define class ConcatXComArg(XComArg): """ Concatenating multiple XCom references into one. @@ -532,10 +530,7 @@ class ConcatXComArg(XComArg): return value also supports index access. 
""" - def __init__(self, args: Sequence[XComArg]) -> None: - if not args: - raise ValueError("At least one input is required") - self.args = args + args: Sequence[XComArg] = attrs.field(validator=attrs.validators.min_len(1)) def __repr__(self) -> str: args_iter = iter(self.args)