diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py index b8e569f877e70..8b209fb324cd8 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py @@ -23,8 +23,8 @@ import structlog from airflow.api_fastapi.common.parameters import state_priority +from airflow.models.mappedoperator import MappedOperator from airflow.models.taskmap import TaskMap -from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup, get_task_group_children_getter from airflow.serialization.serialized_objects import SerializedBaseOperator diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py index 9eed649f0cad3..f112a1ebba98e 100644 --- a/airflow-core/src/airflow/models/dagrun.py +++ b/airflow-core/src/airflow/models/dagrun.py @@ -102,13 +102,11 @@ from airflow.serialization.serialized_objects import SerializedBaseOperator from airflow.utils.types import ArgNotSet - Operator: TypeAlias = MappedOperator | SerializedBaseOperator - CreatedTasks = TypeVar("CreatedTasks", Iterator["dict[str, Any]"], Iterator[TI]) - - AttributeValueType = ( + AttributeValueType: TypeAlias = ( str | bool | int | float | Sequence[str] | Sequence[bool] | Sequence[int] | Sequence[float] ) + Operator: TypeAlias = MappedOperator | SerializedBaseOperator RUN_ID_REGEX = r"^(?:manual|scheduled|asset_triggered)__(?:\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\+00:00)$" @@ -1483,15 +1481,15 @@ def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None: If the ti does not need expansion, either because the task is not mapped, or has already been expanded, *None* is returned. """ + from airflow.models.mappedoperator import is_mapped + if TYPE_CHECKING: assert ti.task if ti.map_index >= 0: # Already expanded, we're good. return None - from airflow.sdk.definitions.mappedoperator import MappedOperator as TaskSDKMappedOperator - - if isinstance(ti.task, TaskSDKMappedOperator): + if is_mapped(ti.task): # If we get here, it could be that we are moving from non-mapped to mapped # after task instance clearing or this ti is not yet expanded. Safe to clear # the db references. 
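The hunk above is the first call site of the new `is_mapped` helper from `airflow/models/mappedoperator.py`: a `TypeGuard` that checks the class-level flag, so the scheduler no longer needs to import the SDK `MappedOperator` just for an `isinstance()` check. A minimal, self-contained sketch of the narrowing pattern (the classes here are simplified stand-ins, not the real Airflow types):

from __future__ import annotations

from typing import TypeGuard, Union


class SerializedBaseOperator:
    """Stand-in for the scheduler's serialized operator."""

    is_mapped = False


class MappedOperator:
    """Stand-in for the scheduler's mapped operator."""

    is_mapped = True

    def iter_mapped_dependencies(self):
        yield from ()


Operator = Union[MappedOperator, SerializedBaseOperator]


def is_mapped(task: Operator) -> TypeGuard[MappedOperator]:
    # Same shape as the helper the diff adds: a class-level flag check,
    # so call sites stop importing the SDK class for isinstance().
    return task.is_mapped


def expand_if_needed(task: Operator) -> None:
    if is_mapped(task):
        # Type checkers narrow `task` to MappedOperator in this branch,
        # so mapped-only attributes resolve without a cast().
        list(task.iter_mapped_dependencies())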
@@ -1510,7 +1508,7 @@ def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None: revised_map_index_task_ids: set[str] = set() for schedulable in itertools.chain(schedulable_tis, additional_tis): if TYPE_CHECKING: - assert isinstance(schedulable.task, SerializedBaseOperator) + assert isinstance(schedulable.task, Operator) old_state = schedulable.state if not schedulable.are_dependencies_met(session=session, dep_context=dep_context): old_states[schedulable.key] = old_state @@ -1995,25 +1993,23 @@ def schedule_tis( empty_ti_ids: list[str] = [] schedulable_ti_ids: list[str] = [] for ti in schedulable_tis: + task = ti.task if TYPE_CHECKING: - assert isinstance(ti.task, SerializedBaseOperator) + assert isinstance(task, Operator) if ( - ti.task.inherits_from_empty_operator - and not ti.task.on_execute_callback - and not ti.task.on_success_callback - and not ti.task.outlets - and not ti.task.inlets + task.inherits_from_empty_operator + and not task.on_execute_callback + and not task.on_success_callback + and not task.outlets + and not task.inlets ): empty_ti_ids.append(ti.id) # check "start_trigger_args" to see whether the operator supports start execution from triggerer # if so, we'll then check "start_from_trigger" to see whether this feature is turned on and defer # this task. # if not, we'll add this "ti" into "schedulable_ti_ids" and later execute it to run in the worker - elif ti.task.start_trigger_args is not None: - context = ti.get_template_context() - start_from_trigger = ti.task.expand_start_from_trigger(context=context, session=session) - - if start_from_trigger: + elif task.start_trigger_args is not None: + if task.expand_start_from_trigger(context=ti.get_template_context()): ti.start_date = timezone.utcnow() if ti.state != TaskInstanceState.UP_FOR_RESCHEDULE: ti.try_number += 1 diff --git a/airflow-core/src/airflow/models/expandinput.py b/airflow-core/src/airflow/models/expandinput.py index 6aa44316f6e77..8487022b05e75 100644 --- a/airflow-core/src/airflow/models/expandinput.py +++ b/airflow-core/src/airflow/models/expandinput.py @@ -19,18 +19,11 @@ import functools import operator -from collections.abc import Iterable, Sized +from collections.abc import Iterable, Mapping, Sequence, Sized from typing import TYPE_CHECKING, Any, ClassVar import attrs -if TYPE_CHECKING: - from typing import TypeGuard - - from sqlalchemy.orm import Session - - from airflow.models.xcom_arg import SchedulerXComArg - from airflow.sdk.definitions._internal.expandinput import ( DictOfListsExpandInput, ListOfDictsExpandInput, @@ -41,6 +34,18 @@ is_mappable, ) +if TYPE_CHECKING: + from typing import TypeAlias, TypeGuard + + from sqlalchemy.orm import Session + + from airflow.models.mappedoperator import MappedOperator + from airflow.models.xcom_arg import SchedulerXComArg + from airflow.serialization.serialized_objects import SerializedBaseOperator + + Operator: TypeAlias = MappedOperator | SerializedBaseOperator + + __all__ = [ "DictOfListsExpandInput", "ListOfDictsExpandInput", @@ -111,6 +116,26 @@ def get_total_map_length(self, run_id: str, *, session: Session) -> int: lengths = self._get_map_lengths(run_id, session=session) return functools.reduce(operator.mul, (lengths[name] for name in self.value), 1) + def iter_references(self) -> Iterable[tuple[Operator, str]]: + from airflow.models.referencemixin import ReferenceMixin + + for x in self.value.values(): + if isinstance(x, ReferenceMixin): + yield from x.iter_references() + + +# To replace tedious isinstance() checks. 
+def _is_parse_time_mappable(v: OperatorExpandArgument) -> TypeGuard[Mapping | Sequence]: + from airflow.sdk.definitions.xcom_arg import XComArg + + return not isinstance(v, (MappedArgument, XComArg)) + + +def _describe_type(value: Any) -> str: + if value is None: + return "None" + return type(value).__name__ + @attrs.define class SchedulerListOfDictsExpandInput: @@ -133,6 +158,16 @@ def get_total_map_length(self, run_id: str, *, session: Session) -> int: raise NotFullyPopulated({"expand_kwargs() argument"}) return length + def iter_references(self) -> Iterable[tuple[Operator, str]]: + from airflow.models.referencemixin import ReferenceMixin + + if isinstance(self.value, ReferenceMixin): + yield from self.value.iter_references() + else: + for x in self.value: + if isinstance(x, ReferenceMixin): + yield from x.iter_references() + _EXPAND_INPUT_TYPES: dict[str, type[SchedulerExpandInput]] = { "dict-of-lists": SchedulerDictOfListsExpandInput, diff --git a/airflow-core/src/airflow/models/mappedoperator.py b/airflow-core/src/airflow/models/mappedoperator.py index 7dbd3309fd07e..5e5fe9975fdde 100644 --- a/airflow-core/src/airflow/models/mappedoperator.py +++ b/airflow-core/src/airflow/models/mappedoperator.py @@ -15,50 +15,65 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + from __future__ import annotations import functools import operator -from collections.abc import Mapping -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias, TypeGuard import attrs +import methodtools import structlog from sqlalchemy.orm import Session from airflow.exceptions import AirflowException -from airflow.sdk.bases.operator import BaseOperator as TaskSDKBaseOperator -from airflow.sdk.definitions._internal.abstractoperator import NotMapped +from airflow.sdk import BaseOperator as TaskSDKBaseOperator +from airflow.sdk.definitions._internal.abstractoperator import ( + DEFAULT_EXECUTOR, + DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, + DEFAULT_OWNER, + DEFAULT_POOL_NAME, + DEFAULT_POOL_SLOTS, + DEFAULT_PRIORITY_WEIGHT, + DEFAULT_QUEUE, + DEFAULT_RETRIES, + DEFAULT_RETRY_DELAY, + DEFAULT_TRIGGER_RULE, + DEFAULT_WEIGHT_RULE, + NotMapped, + TaskStateChangeCallbackAttrType, +) +from airflow.sdk.definitions._internal.node import DAGNode from airflow.sdk.definitions.mappedoperator import MappedOperator as TaskSDKMappedOperator from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup from airflow.serialization.serialized_objects import DEFAULT_OPERATOR_DEPS, SerializedBaseOperator +from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy if TYPE_CHECKING: - from collections.abc import Iterator + import datetime + from collections.abc import Collection, Iterator, Sequence + + import pendulum from airflow.models import TaskInstance from airflow.models.dag import DAG as SchedulerDAG + from airflow.models.expandinput import SchedulerExpandInput from airflow.sdk import BaseOperatorLink - from airflow.sdk.definitions._internal.node import DAGNode from airflow.sdk.definitions.context import Context + from airflow.sdk.definitions.param import ParamsDict from airflow.ti_deps.deps.base_ti_dep import BaseTIDep + from airflow.triggers.base import StartTriggerArgs + from airflow.utils.operator_resources import Resources + from airflow.utils.trigger_rule import TriggerRule -log = structlog.get_logger(__name__) + Operator: TypeAlias = 
"SerializedBaseOperator | MappedOperator" +log = structlog.get_logger(__name__) -def _prevent_duplicates(kwargs1: dict[str, Any], kwargs2: Mapping[str, Any], *, fail_reason: str) -> None: - """ - Ensure *kwargs1* and *kwargs2* do not contain common keys. - :raises TypeError: If common keys are found. - """ - duplicated_keys = set(kwargs1).intersection(kwargs2) - if not duplicated_keys: - return - if len(duplicated_keys) == 1: - raise TypeError(f"{fail_reason} argument: {duplicated_keys.pop()}") - duplicated_keys_display = ", ".join(sorted(duplicated_keys)) - raise TypeError(f"{fail_reason} arguments: {duplicated_keys_display}") +def is_mapped(task: Operator) -> TypeGuard[MappedOperator]: + return task.is_mapped @attrs.define( @@ -73,49 +88,229 @@ def _prevent_duplicates(kwargs1: dict[str, Any], kwargs2: Mapping[str, Any], *, getstate_setstate=False, repr=False, ) -class MappedOperator(TaskSDKMappedOperator): +# TODO (GH-52141): Duplicate DAGNode in the scheduler. +class MappedOperator(DAGNode): """Object representing a mapped operator in a DAG.""" + operator_class: dict[str, Any] + partial_kwargs: dict[str, Any] = attrs.field(init=False, factory=dict) + + # Needed for serialization. + task_id: str + params: ParamsDict | dict = attrs.field(init=False, factory=dict) + operator_extra_links: Collection[BaseOperatorLink] + template_ext: Sequence[str] + template_fields: Collection[str] + template_fields_renderers: dict[str, str] + ui_color: str + ui_fgcolor: str + _is_empty: bool = attrs.field(alias="is_empty", init=False, default=False) + _can_skip_downstream: bool = attrs.field(alias="can_skip_downstream") + _is_sensor: bool = attrs.field(alias="is_sensor", default=False) + _task_module: str + _task_type: str + _operator_name: str + start_trigger_args: StartTriggerArgs | None + start_from_trigger: bool + _needs_expansion: bool = True + + dag: SchedulerDAG = attrs.field(init=False) + task_group: TaskGroup = attrs.field(init=False) + start_date: pendulum.DateTime | None = attrs.field(init=False, default=None) + end_date: pendulum.DateTime | None = attrs.field(init=False, default=None) + upstream_task_ids: set[str] = attrs.field(factory=set, init=False) + downstream_task_ids: set[str] = attrs.field(factory=set, init=False) + + _disallow_kwargs_override: bool + """Whether execution fails if ``expand_input`` has duplicates to ``partial_kwargs``. + + If *False*, values from ``expand_input`` under duplicate keys override those + under corresponding keys in ``partial_kwargs``. + """ + + _expand_input_attr: str + """Where to get kwargs to calculate expansion length against. + + This should be a name to call ``getattr()`` on. + """ + deps: frozenset[BaseTIDep] = attrs.field(init=False, default=DEFAULT_OPERATOR_DEPS) - def expand_start_from_trigger(self, *, context: Context, session: Session) -> bool: - """ - Get the start_from_trigger value of the current abstract operator. + is_mapped: ClassVar[bool] = True - MappedOperator uses this to unmap start_from_trigger to decide whether to start the task - execution directly from triggerer. + @property + def node_id(self) -> str: + return self.task_id - :meta private: - """ - if self.partial_kwargs.get("start_from_trigger", self.start_from_trigger): - log.warning( - "Starting a mapped task from triggerer is currently unsupported", - task_id=self.task_id, - dag_id=self.dag_id, - ) + @property + def roots(self) -> Sequence[DAGNode]: + """Required by DAGNode.""" + return [self] - # This is intentional. 
start_from_trigger does not work correctly with - # sdk-db separation yet, so it is disabled unconditionally for now. - # TODO: TaskSDK: Implement this properly. - return False + @property + def leaves(self) -> Sequence[DAGNode]: + """Required by DAGNode.""" + return [self] - # start_from_trigger only makes sense when start_trigger_args exists. - if not self.start_trigger_args: - return False + # TODO (GH-52141): Review if any of the properties below are used in the + # SDK and the scheduler, and remove those not needed. - mapped_kwargs, _ = self._expand_mapped_kwargs(context) - if self._disallow_kwargs_override: - _prevent_duplicates( - self.partial_kwargs, - mapped_kwargs, - fail_reason="unmappable or already specified", - ) + @property + def task_type(self) -> str: + """Implementing Operator.""" + return self._task_type + + @property + def operator_name(self) -> str: + return self._operator_name - # Ordering is significant; mapped kwargs should override partial ones. - return mapped_kwargs.get( - "start_from_trigger", self.partial_kwargs.get("start_from_trigger", self.start_from_trigger) + @property + def task_display_name(self) -> str: + return self.partial_kwargs.get("task_display_name") or self.task_id + + @property + def doc_md(self) -> str | None: + return self.partial_kwargs.get("doc_md") + + @property + def inherits_from_empty_operator(self) -> bool: + """Implementing an empty Operator.""" + return self._is_empty + + @property + def inherits_from_skipmixin(self) -> bool: + return self._can_skip_downstream + + @property + def owner(self) -> str: + return self.partial_kwargs.get("owner", DEFAULT_OWNER) + + @property + def trigger_rule(self) -> TriggerRule: + return self.partial_kwargs.get("trigger_rule", DEFAULT_TRIGGER_RULE) + + @property + def is_setup(self) -> bool: + return bool(self.partial_kwargs.get("is_setup")) + + @property + def is_teardown(self) -> bool: + return bool(self.partial_kwargs.get("is_teardown")) + + @property + def depends_on_past(self) -> bool: + return bool(self.partial_kwargs.get("depends_on_past")) + + @property + def ignore_first_depends_on_past(self) -> bool: + value = self.partial_kwargs.get("ignore_first_depends_on_past", DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST) + return bool(value) + + @property + def wait_for_downstream(self) -> bool: + return bool(self.partial_kwargs.get("wait_for_downstream")) + + @property + def retries(self) -> int: + return self.partial_kwargs.get("retries", DEFAULT_RETRIES) + + @property + def queue(self) -> str: + return self.partial_kwargs.get("queue", DEFAULT_QUEUE) + + @property + def pool(self) -> str: + return self.partial_kwargs.get("pool", DEFAULT_POOL_NAME) + + @property + def pool_slots(self) -> int: + return self.partial_kwargs.get("pool_slots", DEFAULT_POOL_SLOTS) + + @property + def resources(self) -> Resources | None: + return self.partial_kwargs.get("resources") + + @property + def max_active_tis_per_dag(self) -> int | None: + return self.partial_kwargs.get("max_active_tis_per_dag") + + @property + def max_active_tis_per_dagrun(self) -> int | None: + return self.partial_kwargs.get("max_active_tis_per_dagrun") + + @property + def on_execute_callback(self) -> TaskStateChangeCallbackAttrType: + return self.partial_kwargs.get("on_execute_callback") or [] + + @property + def on_failure_callback(self) -> TaskStateChangeCallbackAttrType: + return self.partial_kwargs.get("on_failure_callback") or [] + + @property + def on_retry_callback(self) -> TaskStateChangeCallbackAttrType: + return 
self.partial_kwargs.get("on_retry_callback") or [] + + @property + def on_success_callback(self) -> TaskStateChangeCallbackAttrType: + return self.partial_kwargs.get("on_success_callback") or [] + + @property + def on_skipped_callback(self) -> TaskStateChangeCallbackAttrType: + return self.partial_kwargs.get("on_skipped_callback") or [] + + @property + def run_as_user(self) -> str | None: + return self.partial_kwargs.get("run_as_user") + + @property + def priority_weight(self) -> int: + return self.partial_kwargs.get("priority_weight", DEFAULT_PRIORITY_WEIGHT) + + @property + def retry_delay(self) -> datetime.timedelta: + return self.partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY) + + @property + def retry_exponential_backoff(self) -> bool: + return bool(self.partial_kwargs.get("retry_exponential_backoff")) + + @property + def weight_rule(self) -> PriorityWeightStrategy: + return validate_and_load_priority_weight_strategy( + self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE) ) + @property + def executor(self) -> str | None: + return self.partial_kwargs.get("executor", DEFAULT_EXECUTOR) + + @property + def executor_config(self) -> dict: + return self.partial_kwargs.get("executor_config", {}) + + @property + def execution_timeout(self) -> datetime.timedelta | None: + return self.partial_kwargs.get("execution_timeout") + + @property + def inlets(self) -> list[Any]: + return self.partial_kwargs.get("inlets", []) + + @property + def outlets(self) -> list[Any]: + return self.partial_kwargs.get("outlets", []) + + @property + def on_failure_fail_dagrun(self) -> bool: + return bool(self.partial_kwargs.get("on_failure_fail_dagrun")) + + @on_failure_fail_dagrun.setter + def on_failure_fail_dagrun(self, v) -> None: + self.partial_kwargs["on_failure_fail_dagrun"] = bool(v) + + def get_serialized_fields(self): + return TaskSDKMappedOperator.get_serialized_fields() + @functools.cached_property def operator_extra_link_dict(self) -> dict[str, BaseOperatorLink]: """Returns dictionary of all extra links for the operator.""" @@ -167,6 +362,104 @@ def get_extra_links(self, ti: TaskInstance, name: str) -> str | None: return None return link.get_link(self, ti_key=ti.key) # type: ignore[arg-type] # TODO: GH-52141 - BaseOperatorLink.get_link expects BaseOperator but receives MappedOperator + # TODO (GH-52141): Copied from sdk. Find a better place for this to live in. + def _get_specified_expand_input(self) -> SchedulerExpandInput: + """Input received from the expand call on the operator.""" + return getattr(self, self._expand_input_attr) + + # TODO (GH-52141): Copied from sdk. Find a better place for this to live in. + def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]: + """ + Return mapped task groups this task belongs to. + + Groups are returned from the innermost to the outmost. + + :meta private: + """ + if (group := self.task_group) is None: + return + yield from group.iter_mapped_task_groups() + + # TODO (GH-52141): Copied from sdk. Find a better place for this to live in. + def get_closest_mapped_task_group(self) -> MappedTaskGroup | None: + """ + Get the mapped task group "closest" to this task in the DAG. + + :meta private: + """ + return next(self.iter_mapped_task_groups(), None) + + # TODO (GH-52141): Copied from sdk. Find a better place for this to live in. + def get_needs_expansion(self) -> bool: + """ + Return true if the task is MappedOperator or is in a mapped task group. + + :meta private: + """ + return self._needs_expansion + + # TODO (GH-52141): Copied from sdk. 
Find a better place for this to live in. + @methodtools.lru_cache(maxsize=1) + def get_parse_time_mapped_ti_count(self) -> int: + current_count = self._get_specified_expand_input().get_parse_time_mapped_ti_count() + + def _get_parent_count() -> int: + if (group := self.get_closest_mapped_task_group()) is None: + raise NotMapped() + return group.get_parse_time_mapped_ti_count() + + try: + parent_count = _get_parent_count() + except NotMapped: + return current_count + return parent_count * current_count + + def iter_mapped_dependencies(self) -> Iterator[Operator]: + """Upstream dependencies that provide XComs used by this task for task mapping.""" + from airflow.models.xcom_arg import SchedulerXComArg + + for op, _ in SchedulerXComArg.iter_xcom_references(self._get_specified_expand_input()): + yield op + + def expand_start_from_trigger(self, *, context: Context) -> bool: + if not self.partial_kwargs.get("start_from_trigger", self.start_from_trigger): + return False + # TODO (GH-52141): Implement this. + log.warning( + "Starting a mapped task from triggerer is currently unsupported", + task_id=self.task_id, + dag_id=self.dag_id, + ) + return False + + # TODO (GH-52141): Move the implementation in SDK MappedOperator here. + def expand_start_trigger_args(self, *, context: Context) -> StartTriggerArgs | None: + raise NotImplementedError + + def unmap(self, resolve: None) -> SerializedBaseOperator: + """ + Get the "normal" Operator after applying the current mapping. + + The *resolve* argument is never used and should always be *None*. It + exists only to match the signature of the non-serialized implementation. + + The return value is a SerializedBaseOperator that "looks like" the + actual unmapping result. + + :meta private: + """ + # After a mapped operator is serialized, there's no real way to actually + # unmap it since we've lost access to the underlying operator class. + # This tries its best to simply "forward" all the attributes on this + # mapped operator to a new SerializedBaseOperator instance. + sop = SerializedBaseOperator(task_id=self.task_id, params=self.params, _airflow_from_mapped=True) + for partial_attr, value in self.partial_kwargs.items(): + setattr(sop, partial_attr, value) + SerializedBaseOperator.populate_operator(sop, self.operator_class) + if self.dag is not None: # For Mypy; we only serialize tasks in a DAG so the check always satisfies. + SerializedBaseOperator.set_task_dag_references(sop, self.dag) + return sop + @functools.singledispatch def get_mapped_ti_count(task: DAGNode, run_id: str, *, session: Session) -> int: @@ -192,8 +485,6 @@ def _(task: MappedOperator | TaskSDKMappedOperator, run_id: str, *, session: Ses from airflow.serialization.serialized_objects import BaseSerialization, _ExpandInputRef exp_input = task._get_specified_expand_input() - if isinstance(exp_input, _ExpandInputRef): - exp_input = exp_input.deref(task.dag) # TODO (GH-52141): 'task' here should be scheduler-bound and returns scheduler expand input. if not hasattr(exp_input, "get_total_map_length"): if TYPE_CHECKING: diff --git a/airflow-core/src/airflow/models/referencemixin.py b/airflow-core/src/airflow/models/referencemixin.py new file mode 100644 index 0000000000000..19a775417f865 --- /dev/null +++ b/airflow-core/src/airflow/models/referencemixin.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Protocol, runtime_checkable + +if TYPE_CHECKING: + from collections.abc import Iterable + from typing import TypeAlias + + from airflow.models.mappedoperator import MappedOperator + from airflow.serialization.serialized_objects import SerializedBaseOperator + + Operator: TypeAlias = MappedOperator | SerializedBaseOperator + + +@runtime_checkable +class ReferenceMixin(Protocol): + """ + Mixin for things that reference a task. + + This should be implemented by things that reference operators and use them + to lazily resolve values at runtime. The most prominent examples are XCom + references (XComArg). + + This is a partial interface to the SDK's ResolveMixin with the resolve() + method removed since the scheduler should not need to resolve the reference. + """ + + def iter_references(self) -> Iterable[tuple[Operator, str]]: + """ + Find underlying XCom references this contains. + + This is used by the DAG parser to recursively find task dependencies. + + :meta private: + """ + raise NotImplementedError diff --git a/airflow-core/src/airflow/models/serialized_dag.py b/airflow-core/src/airflow/models/serialized_dag.py index 5ea0947f9f07a..6444d5720de20 100644 --- a/airflow-core/src/airflow/models/serialized_dag.py +++ b/airflow-core/src/airflow/models/serialized_dag.py @@ -295,7 +295,7 @@ class SerializedDagModel(Base): dag_runs = relationship( DagRun, - primaryjoin=dag_id == foreign(DagRun.dag_id), + primaryjoin=dag_id == foreign(DagRun.dag_id), # type: ignore[has-type] backref=backref("serialized_dag", uselist=False, innerjoin=True), ) diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py index eac003c816cb6..8b690b417e1ef 100644 --- a/airflow-core/src/airflow/models/taskinstance.py +++ b/airflow-core/src/airflow/models/taskinstance.py @@ -120,9 +120,9 @@ from airflow.models.dag import DAG as SchedulerDAG, DagModel from airflow.models.dagrun import DagRun from airflow.models.mappedoperator import MappedOperator + from airflow.sdk import DAG from airflow.sdk.api.datamodels._generated import AssetProfile from airflow.sdk.definitions.asset import AssetNameRef, AssetUniqueKey, AssetUriRef - from airflow.sdk.definitions.dag import DAG from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup from airflow.sdk.types import RuntimeTaskInstanceProtocol from airflow.serialization.serialized_objects import SerializedBaseOperator @@ -1430,7 +1430,7 @@ def update_heartbeat(self): @provide_session def defer_task(self, exception: TaskDeferred | None, session: Session = NEW_SESSION) -> None: """ - Mark the task as deferred and sets up the trigger that is needed to resume it when TaskDeferred is raised. + Mark the task as deferred and set up the trigger to resume it. 
:meta: private """ @@ -1590,7 +1590,8 @@ def fetch_handle_failure_context( ti.clear_next_method_args() context = None - # In extreme cases (task instance heartbeat timeout in case of dag with parse error) we might _not_ have a Task. + # In extreme cases (task instance heartbeat timeout in case of dag with + # parse error) we might _not_ have a Task. if getattr(ti, "task", None): context = ti.get_template_context(session) @@ -1853,7 +1854,7 @@ def get_triggering_events() -> dict[str, list[AssetEvent]]: "_upstream_map_indexes", { upstream.task_id: self.get_relevant_upstream_map_indexes( - cast("Operator", upstream), + upstream, expanded_ti_count, session=session, ) @@ -2288,7 +2289,7 @@ def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) -> Mapp def _is_further_mapped_inside(operator: Operator, container: TaskGroup) -> bool: """Whether given operator is *further* mapped inside a task group.""" - from airflow.sdk.definitions.mappedoperator import MappedOperator + from airflow.models.mappedoperator import MappedOperator from airflow.sdk.definitions.taskgroup import MappedTaskGroup if isinstance(operator, MappedOperator): diff --git a/airflow-core/src/airflow/models/xcom_arg.py b/airflow-core/src/airflow/models/xcom_arg.py index 1109f03bb1f99..78021e5043123 100644 --- a/airflow-core/src/airflow/models/xcom_arg.py +++ b/airflow-core/src/airflow/models/xcom_arg.py @@ -17,20 +17,18 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Iterator, Sequence from functools import singledispatch -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypeAlias, cast import attrs from sqlalchemy import func, or_, select from sqlalchemy.orm import Session +from airflow.models.referencemixin import ReferenceMixin from airflow.models.xcom import XCOM_RETURN_KEY from airflow.sdk.definitions._internal.types import ArgNotSet -from airflow.sdk.definitions.mappedoperator import MappedOperator -from airflow.sdk.definitions.xcom_arg import ( - XComArg, -) +from airflow.sdk.definitions.xcom_arg import XComArg from airflow.utils.db import exists_query from airflow.utils.state import State from airflow.utils.types import NOTSET @@ -39,11 +37,13 @@ if TYPE_CHECKING: from airflow.models.dag import DAG as SchedulerDAG - from airflow.models.operator import Operator + from airflow.models.mappedoperator import MappedOperator + from airflow.serialization.serialized_objects import SerializedBaseOperator from airflow.typing_compat import Self + Operator: TypeAlias = MappedOperator | SerializedBaseOperator + -@attrs.define class SchedulerXComArg: @classmethod def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: @@ -57,7 +57,33 @@ def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: dicts to the correct ``_deserialize`` information, so this function does not need to validate whether the incoming data contains correct keys. """ - raise NotImplementedError() + raise NotImplementedError("This class should not be instantiated directly") + + @classmethod + def iter_xcom_references(cls, arg: Any) -> Iterator[tuple[Operator, str]]: + """ + Return XCom references in an arbitrary value. + + Recursively traverse ``arg`` and look for XComArg instances in any + collection objects, and instances with ``template_fields`` set. 
+ """ + from airflow.models.mappedoperator import MappedOperator + from airflow.serialization.serialized_objects import SerializedBaseOperator + + if isinstance(arg, ReferenceMixin): + yield from arg.iter_references() + elif isinstance(arg, (tuple, set, list)): + for elem in arg: + yield from cls.iter_xcom_references(elem) + elif isinstance(arg, dict): + for elem in arg.values(): + yield from cls.iter_xcom_references(elem) + elif isinstance(arg, (MappedOperator, SerializedBaseOperator)): + for attr in arg.template_fields: + yield from cls.iter_xcom_references(getattr(arg, attr)) + + def iter_references(self) -> Iterator[tuple[Operator, str]]: + raise NotImplementedError("This class should not be instantiated directly") @attrs.define @@ -67,7 +93,11 @@ class SchedulerPlainXComArg(SchedulerXComArg): @classmethod def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: - return cls(dag.get_task(data["task_id"]), data["key"]) + # TODO (GH-52141): SchedulerDAG should return scheduler operator instead. + return cls(cast("Operator", dag.get_task(data["task_id"])), data["key"]) + + def iter_references(self) -> Iterator[tuple[Operator, str]]: + yield self.operator, self.key @attrs.define @@ -81,6 +111,9 @@ def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: # in the UI, and displaying a function object is useless. return cls(deserialize_xcom_arg(data["arg"], dag), data["callables"]) + def iter_references(self) -> Iterator[tuple[Operator, str]]: + yield from self.arg.iter_references() + @attrs.define class SchedulerConcatXComArg(SchedulerXComArg): @@ -90,6 +123,10 @@ class SchedulerConcatXComArg(SchedulerXComArg): def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: return cls([deserialize_xcom_arg(arg, dag) for arg in data["args"]]) + def iter_references(self) -> Iterator[tuple[Operator, str]]: + for arg in self.args: + yield from arg.iter_references() + @attrs.define class SchedulerZipXComArg(SchedulerXComArg): @@ -103,6 +140,10 @@ def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: fillvalue=data.get("fillvalue", NOTSET), ) + def iter_references(self) -> Iterator[tuple[Operator, str]]: + for arg in self.args: + yield from arg.iter_references() + @singledispatch def get_task_map_length(xcom_arg: SchedulerXComArg, run_id: str, *, session: Session) -> int | None: @@ -112,15 +153,15 @@ def get_task_map_length(xcom_arg: SchedulerXComArg, run_id: str, *, session: Ses @get_task_map_length.register def _(xcom_arg: SchedulerPlainXComArg, run_id: str, *, session: Session): + from airflow.models.mappedoperator import is_mapped from airflow.models.taskinstance import TaskInstance from airflow.models.taskmap import TaskMap from airflow.models.xcom import XComModel dag_id = xcom_arg.operator.dag_id task_id = xcom_arg.operator.task_id - is_mapped = xcom_arg.operator.is_mapped or isinstance(xcom_arg.operator, MappedOperator) - if is_mapped: + if is_mapped(xcom_arg.operator): unfinished_ti_exists = exists_query( TaskInstance.dag_id == dag_id, TaskInstance.run_id == run_id, diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py index a8a766e553b31..3ea6a8c685c05 100644 --- a/airflow-core/src/airflow/serialization/serialized_objects.py +++ b/airflow-core/src/airflow/serialization/serialized_objects.py @@ -46,9 +46,7 @@ from airflow.exceptions import AirflowException, SerializationError, TaskDeferred from airflow.models.connection import Connection from 
airflow.models.dag import DAG, _get_model_data_interval -from airflow.models.expandinput import ( - create_expand_input, -) +from airflow.models.expandinput import create_expand_input from airflow.models.taskinstancekey import TaskInstanceKey from airflow.models.xcom import XComModel from airflow.models.xcom_arg import SchedulerXComArg, deserialize_xcom_arg @@ -101,14 +99,13 @@ if TYPE_CHECKING: from inspect import Parameter - from sqlalchemy.orm import Session - from airflow.models import DagRun from airflow.models.expandinput import SchedulerExpandInput from airflow.models.mappedoperator import MappedOperator as SchedulerMappedOperator from airflow.models.taskinstance import TaskInstance from airflow.sdk import DAG as SdkDag, BaseOperatorLink from airflow.serialization.json_schema import Validator + from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.timetables.base import DagRunInfo, DataInterval, Timetable from airflow.triggers.base import BaseEventTrigger from airflow.typing_compat import Self @@ -125,7 +122,7 @@ SchedulerOperator: TypeAlias = "SchedulerMappedOperator | SerializedBaseOperator" SdkOperator: TypeAlias = BaseOperator | MappedOperator -DEFAULT_OPERATOR_DEPS = frozenset( +DEFAULT_OPERATOR_DEPS: frozenset[BaseTIDep] = frozenset( ( NotInRetryPeriodDep(), PrevDagrunDep(), @@ -611,12 +608,12 @@ class BaseSerialization: SERIALIZER_VERSION = 2 @classmethod - def to_json(cls, var: DAG | SerializedBaseOperator | dict | list | set | tuple) -> str: + def to_json(cls, var: DAG | SchedulerOperator | dict | list | set | tuple) -> str: """Stringify DAGs and operators contained by var and returns a JSON string of var.""" return json.dumps(cls.to_dict(var), ensure_ascii=True) @classmethod - def to_dict(cls, var: DAG | SerializedBaseOperator | dict | list | set | tuple) -> dict: + def to_dict(cls, var: DAG | SchedulerOperator | dict | list | set | tuple) -> dict: """Stringify DAGs and operators contained by var and returns a dict of var.""" # Don't call on this class directly - only SerializedDAG or # SerializedBaseOperator should be used as the "entrypoint" @@ -671,8 +668,8 @@ def _is_excluded(cls, var: Any, attrname: str, instance: Any) -> bool: @classmethod def serialize_to_json( cls, - # TODO (GH-52141): When can we remove SerializedBaseOperator here? - object_to_serialize: BaseOperator | MappedOperator | SerializedBaseOperator | SdkDag, + # TODO (GH-52141): When can we remove scheduler constructs here? + object_to_serialize: SdkOperator | SchedulerOperator | SdkDag | DAG, decorated_fields: set, ) -> dict[str, Any]: """Serialize an object to JSON.""" @@ -720,6 +717,8 @@ def serialize( :meta private: """ + from airflow.models.mappedoperator import MappedOperator as SchedulerMappedOperator + if cls._is_primitive(var): # enum.IntEnum is an int instance, it causes json dumps error so we use its value. 
if isinstance(var, enum.Enum): @@ -759,9 +758,9 @@ def serialize( return cls._encode(DeadlineAlert.serialize_deadline_alert(var), type_=DAT.DEADLINE_ALERT) elif isinstance(var, Resources): return var.to_dict() - elif isinstance(var, MappedOperator): + elif isinstance(var, (MappedOperator, SchedulerMappedOperator)): return cls._encode(SerializedBaseOperator.serialize_mapped_operator(var), type_=DAT.OP) - elif isinstance(var, BaseOperator): + elif isinstance(var, (BaseOperator, SerializedBaseOperator)): var._needs_expansion = var.get_needs_expansion() return cls._encode(SerializedBaseOperator.serialize_operator(var), type_=DAT.OP) elif isinstance(var, cls._datetime_types): @@ -1252,6 +1251,13 @@ def __init__( self.deps = DEFAULT_OPERATOR_DEPS self._operator_name: str | None = None + def __eq__(self, other: Any) -> bool: + if not isinstance(other, (SerializedBaseOperator, BaseOperator)): + return NotImplemented + return self.task_type == other.task_type and all( + getattr(self, c, None) == getattr(other, c, None) for c in BaseOperator._comps + ) + @property def node_id(self) -> str: return self.task_id @@ -1352,7 +1358,7 @@ def __getattr__(self, name): raise AttributeError(f"'{self.task_type}' object has no attribute '{name}'") @classmethod - def serialize_mapped_operator(cls, op: MappedOperator) -> dict[str, Any]: + def serialize_mapped_operator(cls, op: MappedOperator | SchedulerMappedOperator) -> dict[str, Any]: serialized_op = cls._serialize_node(op) # Handle expand_input and op_kwargs_expand_input. expansion_kwargs = op._get_specified_expand_input() @@ -1378,11 +1384,11 @@ def serialize_mapped_operator(cls, op: MappedOperator) -> dict[str, Any]: return serialized_op @classmethod - def serialize_operator(cls, op: BaseOperator | MappedOperator) -> dict[str, Any]: + def serialize_operator(cls, op: SdkOperator | SchedulerOperator) -> dict[str, Any]: return cls._serialize_node(op) @classmethod - def _serialize_node(cls, op: BaseOperator | MappedOperator) -> dict[str, Any]: + def _serialize_node(cls, op: SdkOperator | SchedulerOperator) -> dict[str, Any]: """Serialize operator into a JSON object.""" serialize_op = cls.serialize_to_json(op, cls._decorated_fields) @@ -1442,11 +1448,7 @@ def _serialize_node(cls, op: BaseOperator | MappedOperator) -> dict[str, Any]: return serialize_op @classmethod - def populate_operator( - cls, - op: SchedulerMappedOperator | SerializedBaseOperator, - encoded_op: dict[str, Any], - ) -> None: + def populate_operator(cls, op: SchedulerOperator, encoded_op: dict[str, Any]) -> None: """ Populate operator attributes with serialized values. @@ -1610,12 +1612,9 @@ def set_task_dag_references(task: SchedulerOperator, dag: DAG) -> None: dag.task_dict[task_id].upstream_task_ids.add(task.task_id) @classmethod - def deserialize_operator( - cls, - encoded_op: dict[str, Any], - ) -> SchedulerMappedOperator | SerializedBaseOperator: + def deserialize_operator(cls, encoded_op: dict[str, Any]) -> SchedulerOperator: """Deserializes an operator from a JSON object.""" - op: SchedulerMappedOperator | SerializedBaseOperator + op: SchedulerOperator if encoded_op.get("_is_mapped", False): # Most of these will be loaded later, these are just some stand-ins. 
op_data = {k: v for k, v in encoded_op.items() if k in BaseOperator.get_serialized_fields()} @@ -1629,26 +1628,18 @@ def deserialize_operator( op = SchedulerMappedOperator( operator_class=op_data, - expand_input=EXPAND_INPUT_EMPTY, - partial_kwargs={}, task_id=encoded_op["task_id"], - params={}, operator_extra_links=BaseOperator.operator_extra_links, template_ext=BaseOperator.template_ext, template_fields=BaseOperator.template_fields, template_fields_renderers=BaseOperator.template_fields_renderers, ui_color=BaseOperator.ui_color, ui_fgcolor=BaseOperator.ui_fgcolor, - is_empty=False, is_sensor=encoded_op.get("_is_sensor", False), can_skip_downstream=encoded_op.get("_can_skip_downstream", False), task_module=encoded_op["_task_module"], task_type=encoded_op["task_type"], operator_name=operator_name, - dag=None, - task_group=None, - start_date=None, - end_date=None, disallow_kwargs_override=encoded_op["_disallow_kwargs_override"], expand_input_attr=encoded_op["_expand_input_attr"], start_trigger_args=encoded_op.get("start_trigger_args", None), @@ -1753,7 +1744,7 @@ def inherits_from_empty_operator(self) -> bool: def inherits_from_skipmixin(self) -> bool: return self._can_skip_downstream - def expand_start_from_trigger(self, *, context: Context, session: Session) -> bool: + def expand_start_from_trigger(self, *, context: Context) -> bool: """ Get the start_from_trigger value of the current abstract operator. @@ -2408,7 +2399,11 @@ def get_run_data_interval(self, run: DagRun) -> DataInterval | None: @attrs.define() class XComOperatorLink(LoggingMixin): - """A generic operator link class that can retrieve link only using XCOMs. Used while deserializing operators.""" + """ + Generic operator link class that can retrieve a link using only XComs. + + Used while deserializing operators. + """ name: str xcom_key: str @@ -2449,12 +2444,10 @@ def create_scheduler_operator(op: BaseOperator | SerializedBaseOperator) -> Seri def create_scheduler_operator(op: MappedOperator | SchedulerMappedOperator) -> SchedulerMappedOperator: ... 
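The two `@overload` stubs above keep precise per-type signatures, while the implementation that follows collapses the four concrete operator types into the `SdkOperator`/`SchedulerOperator` unions. A minimal sketch of the same pattern, using toy stand-in classes rather than the real Airflow types:

from __future__ import annotations

from typing import overload


class ToySdkOperator:
    """Stand-in for a Task SDK operator (illustrative, not the real class)."""


class ToySchedulerOperator:
    """Stand-in for a scheduler-bound operator."""


@overload
def to_scheduler(op: ToySdkOperator) -> ToySchedulerOperator: ...
@overload
def to_scheduler(op: ToySchedulerOperator) -> ToySchedulerOperator: ...
def to_scheduler(op: ToySdkOperator | ToySchedulerOperator) -> ToySchedulerOperator:
    # Values that are already scheduler-bound pass through unchanged,
    # mirroring the early return in create_scheduler_operator below.
    if isinstance(op, ToySchedulerOperator):
        return op
    # The real implementation serializes `op` and rebuilds it; constructing
    # a fresh stand-in keeps this sketch self-contained.
    return ToySchedulerOperator()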
-def create_scheduler_operator( - op: BaseOperator | MappedOperator | SerializedBaseOperator | SchedulerMappedOperator, -) -> SerializedBaseOperator | SchedulerMappedOperator: +def create_scheduler_operator(op: SdkOperator | SchedulerOperator) -> SchedulerOperator: from airflow.models.mappedoperator import MappedOperator as SchedulerMappedOperator - if isinstance(op, (SchedulerMappedOperator, SerializedBaseOperator)): + if isinstance(op, (SerializedBaseOperator, SchedulerMappedOperator)): return op if isinstance(op, BaseOperator): d = SerializedBaseOperator.serialize_operator(op) diff --git a/airflow-core/src/airflow/ti_deps/deps/mapped_task_upstream_dep.py b/airflow-core/src/airflow/ti_deps/deps/mapped_task_upstream_dep.py index 2f3c8015af29c..d5922074ef5c0 100644 --- a/airflow-core/src/airflow/ti_deps/deps/mapped_task_upstream_dep.py +++ b/airflow-core/src/airflow/ti_deps/deps/mapped_task_upstream_dep.py @@ -18,7 +18,7 @@ from __future__ import annotations from collections.abc import Iterator -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypeAlias, cast from sqlalchemy import select @@ -28,10 +28,14 @@ if TYPE_CHECKING: from sqlalchemy.orm import Session + from airflow.models.mappedoperator import MappedOperator from airflow.models.taskinstance import TaskInstance + from airflow.serialization.serialized_objects import SerializedBaseOperator from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.deps.base_ti_dep import TIDepStatus + Operator: TypeAlias = MappedOperator | SerializedBaseOperator + class MappedTaskUpstreamDep(BaseTIDep): """ @@ -51,13 +55,17 @@ def _get_dep_statuses( session: Session, dep_context: DepContext, ) -> Iterator[TIDepStatus]: + from airflow.models.mappedoperator import is_mapped from airflow.models.taskinstance import TaskInstance - from airflow.sdk.definitions.mappedoperator import MappedOperator - if isinstance(ti.task, MappedOperator): + if ti.task is None: + return + elif is_mapped(ti.task): mapped_dependencies = ti.task.iter_mapped_dependencies() - elif ti.task is not None and (task_group := ti.task.get_closest_mapped_task_group()) is not None: - mapped_dependencies = task_group.iter_mapped_dependencies() + elif (task_group := ti.task.get_closest_mapped_task_group()) is not None: + # TODO (GH-52141): Task group in scheduler needs to return scheduler + # types instead, but currently the scheduler uses SDK's TaskGroup. + mapped_dependencies = cast("Iterator[Operator]", task_group.iter_mapped_dependencies()) else: return diff --git a/airflow-core/src/airflow/ti_deps/deps/prev_dagrun_dep.py b/airflow-core/src/airflow/ti_deps/deps/prev_dagrun_dep.py index 1f4363c586ee1..71a6f583f2907 100644 --- a/airflow-core/src/airflow/ti_deps/deps/prev_dagrun_dep.py +++ b/airflow-core/src/airflow/ti_deps/deps/prev_dagrun_dep.py @@ -17,7 +17,7 @@ # under the License. 
from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypeAlias from sqlalchemy import func, or_, select @@ -32,9 +32,12 @@ if TYPE_CHECKING: from sqlalchemy.orm import Session - from airflow.sdk.types import Operator + from airflow.models.mappedoperator import MappedOperator + from airflow.sdk.types import Operator as SdkOperator from airflow.serialization.serialized_objects import SerializedBaseOperator + SchedulerOperator: TypeAlias = MappedOperator | SerializedBaseOperator + _SUCCESSFUL_STATES = (TaskInstanceState.SKIPPED, TaskInstanceState.SUCCESS) @@ -107,7 +110,7 @@ def _count_unsuccessful_tis(dagrun: DagRun, task_id: str, *, session: Session) - @staticmethod def _has_unsuccessful_dependants( dagrun: DagRun, - task: Operator | SerializedBaseOperator, + task: SdkOperator | SchedulerOperator, *, session: Session, ) -> bool: diff --git a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py index dd62af8324a15..a327b50c6a71d 100644 --- a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py +++ b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py @@ -149,10 +149,10 @@ def _get_expanded_ti_count() -> int: return get_mapped_ti_count(ti.task, ti.run_id, session=session) def _iter_expansion_dependencies(task_group: MappedTaskGroup) -> Iterator[str]: - from airflow.sdk.definitions.mappedoperator import MappedOperator + from airflow.models.mappedoperator import is_mapped - if isinstance(ti.task, MappedOperator): - for op in ti.task.iter_mapped_dependencies(): + if (task := ti.task) is not None and is_mapped(task): + for op in task.iter_mapped_dependencies(): yield op.task_id if task_group and task_group.iter_mapped_task_groups(): yield from ( diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py index c8382df8b0e13..132ebc4b99432 100644 --- a/airflow-core/tests/unit/serialization/test_dag_serialization.py +++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py @@ -728,10 +728,10 @@ def validate_deserialized_task( task, ): """Verify non-Airflow operators are casted to BaseOperator or MappedOperator.""" + from airflow.models.mappedoperator import MappedOperator as SchedulerMappedOperator from airflow.sdk import BaseOperator from airflow.sdk.definitions.mappedoperator import MappedOperator - assert not isinstance(task, SerializedBaseOperator) assert isinstance(task, (BaseOperator, MappedOperator)) # Every task should have a task_group property -- even if it's the DAG's root task group @@ -762,7 +762,7 @@ def validate_deserialized_task( "_is_sensor", } else: # Promised to be mapped by the assert above. 
- assert isinstance(serialized_task, MappedOperator) + assert isinstance(serialized_task, SchedulerMappedOperator) fields_to_check = {f.name for f in attrs.fields(MappedOperator)} fields_to_check -= { "map_index_template", @@ -829,7 +829,9 @@ def validate_deserialized_task( assert serialized_partial_kwargs == original_partial_kwargs # ExpandInputs have different classes between scheduler and definition - assert attrs.asdict(serialized_task.expand_input) == attrs.asdict(task.expand_input) + assert attrs.asdict(serialized_task._get_specified_expand_input()) == attrs.asdict( + task._get_specified_expand_input() + ) @pytest.mark.parametrize( "dag_start_date, task_start_date, expected_task_start_date", diff --git a/providers/celery/src/airflow/providers/celery/version_compat.py b/providers/celery/src/airflow/providers/celery/version_compat.py index 4e4de45c3bad2..9dbb0942ea72f 100644 --- a/providers/celery/src/airflow/providers/celery/version_compat.py +++ b/providers/celery/src/airflow/providers/celery/version_compat.py @@ -31,7 +31,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: try: from airflow.sdk.execution_time.timeout import timeout except ImportError: - from airflow.utils.timeout import timeout # type: ignore[attr-defined,no-redef] + from airflow.utils.timeout import timeout # type: ignore[assignment,attr-defined,no-redef] __all__ = ["AIRFLOW_V_3_0_PLUS", "timeout"] diff --git a/providers/google/src/airflow/providers/google/version_compat.py b/providers/google/src/airflow/providers/google/version_compat.py index 2a77fd0f7163a..764970dcf34c4 100644 --- a/providers/google/src/airflow/providers/google/version_compat.py +++ b/providers/google/src/airflow/providers/google/version_compat.py @@ -61,7 +61,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: try: from airflow.sdk.execution_time.timeout import timeout except ImportError: - from airflow.utils.timeout import timeout # type: ignore[attr-defined,no-redef] + from airflow.utils.timeout import timeout # type: ignore[assignment,attr-defined,no-redef] # Explicitly export these imports to protect them from being removed by linters __all__ = [ diff --git a/providers/openlineage/src/airflow/providers/openlineage/version_compat.py b/providers/openlineage/src/airflow/providers/openlineage/version_compat.py index c1093b109b735..ddf39a1898aaa 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/version_compat.py +++ b/providers/openlineage/src/airflow/providers/openlineage/version_compat.py @@ -43,7 +43,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: from airflow.sdk import timezone from airflow.sdk.execution_time.timeout import timeout except ImportError: - from airflow.utils import timezone # type: ignore[no-redef, attr-defined] - from airflow.utils.timeout import timeout # type: ignore[attr-defined,no-redef] + from airflow.utils import timezone # type: ignore[attr-defined,no-redef] + from airflow.utils.timeout import timeout # type: ignore[assignment,attr-defined,no-redef] __all__ = ["AIRFLOW_V_3_0_PLUS", "BaseOperator", "timeout", "timezone"] diff --git a/providers/standard/src/airflow/providers/standard/utils/skipmixin.py b/providers/standard/src/airflow/providers/standard/utils/skipmixin.py index 56295bae64d07..432b287f3a190 100644 --- a/providers/standard/src/airflow/providers/standard/utils/skipmixin.py +++ b/providers/standard/src/airflow/providers/standard/utils/skipmixin.py @@ -45,7 +45,7 @@ def _ensure_tasks(nodes: Iterable[DAGNode]) -> Sequence[Operator]: from 
airflow.sdk.definitions.mappedoperator import MappedOperator else: from airflow.models.baseoperator import BaseOperator # type: ignore[no-redef] - from airflow.models.mappedoperator import MappedOperator + from airflow.models.mappedoperator import MappedOperator # type: ignore[assignment,no-redef] return [n for n in nodes if isinstance(n, (BaseOperator, MappedOperator))] diff --git a/providers/standard/tests/unit/standard/decorators/test_python.py b/providers/standard/tests/unit/standard/decorators/test_python.py index 49be927664e97..82620170e756f 100644 --- a/providers/standard/tests/unit/standard/decorators/test_python.py +++ b/providers/standard/tests/unit/standard/decorators/test_python.py @@ -48,7 +48,7 @@ from airflow.models.baseoperator import BaseOperator # type: ignore[no-redef] from airflow.models.dag import DAG # type: ignore[assignment] from airflow.models.expandinput import DictOfListsExpandInput - from airflow.models.mappedoperator import MappedOperator + from airflow.models.mappedoperator import MappedOperator # type: ignore[assignment,no-redef] from airflow.models.xcom_arg import XComArg from airflow.utils.task_group import TaskGroup # type: ignore[no-redef] diff --git a/providers/standard/tests/unit/standard/utils/test_skipmixin.py b/providers/standard/tests/unit/standard/utils/test_skipmixin.py index f79b60e967434..db805f5f2a598 100644 --- a/providers/standard/tests/unit/standard/utils/test_skipmixin.py +++ b/providers/standard/tests/unit/standard/utils/test_skipmixin.py @@ -23,7 +23,6 @@ import pytest from airflow.exceptions import AirflowException -from airflow.models.mappedoperator import MappedOperator from airflow.models.taskinstance import TaskInstance as TI from airflow.providers.standard.operators.empty import EmptyOperator from airflow.utils import timezone @@ -40,8 +39,10 @@ from airflow.models.dag_version import DagVersion from airflow.providers.standard.utils.skipmixin import SkipMixin from airflow.sdk import task, task_group + from airflow.sdk.definitions.mappedoperator import MappedOperator else: from airflow.decorators import task, task_group # type: ignore[attr-defined,no-redef] + from airflow.models.mappedoperator import MappedOperator # type: ignore[assignment] from airflow.models.skipmixin import SkipMixin DEFAULT_DATE = timezone.datetime(2016, 1, 1) diff --git a/scripts/ci/pre_commit/check_base_operator_partial_arguments.py b/scripts/ci/pre_commit/check_base_operator_partial_arguments.py index 98128063dd919..d53f09122734a 100755 --- a/scripts/ci/pre_commit/check_base_operator_partial_arguments.py +++ b/scripts/ci/pre_commit/check_base_operator_partial_arguments.py @@ -61,11 +61,6 @@ # Task-SDK migration ones. "deps", "downstream_task_ids", - "on_execute_callback", - "on_failure_callback", - "on_retry_callback", - "on_skipped_callback", - "on_success_callback", "operator_extra_links", "start_from_trigger", "start_trigger_args", diff --git a/task-sdk/src/airflow/sdk/bases/decorator.py b/task-sdk/src/airflow/sdk/bases/decorator.py index 6c56a641014d9..6bd42c30b5f88 100644 --- a/task-sdk/src/airflow/sdk/bases/decorator.py +++ b/task-sdk/src/airflow/sdk/bases/decorator.py @@ -564,12 +564,6 @@ class DecoratedMappedOperator(MappedOperator): def __hash__(self): return id(self) - def __attrs_post_init__(self): - # The magic super() doesn't work here, so we use the explicit form. - # Not using super(..., self) to work around pyupgrade bug. 
- super(DecoratedMappedOperator, DecoratedMappedOperator).__attrs_post_init__(self) - XComArg.apply_upstream_relationship(self, self.op_kwargs_expand_input.value) - def _expand_mapped_kwargs(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]: # We only use op_kwargs_expand_input so this must always be empty. if self.expand_input is not EXPAND_INPUT_EMPTY: diff --git a/task-sdk/src/airflow/sdk/bases/operator.py b/task-sdk/src/airflow/sdk/bases/operator.py index 398ee5649a21e..c35dc8ada7b31 100644 --- a/task-sdk/src/airflow/sdk/bases/operator.py +++ b/task-sdk/src/airflow/sdk/bases/operator.py @@ -1503,14 +1503,6 @@ def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]: return DagAttributeTypes.OP, self.task_id - @property - def inherits_from_empty_operator(self): - """Used to determine if an Operator is inherited from EmptyOperator.""" - # This looks like `isinstance(self, EmptyOperator) would work, but this also - # needs to cope when `self` is a Serialized instance of a EmptyOperator or one - # of its subclasses (which don't inherit from anything but BaseOperator). - return getattr(self, "_is_empty", False) - def unmap(self, resolve: None | Mapping[str, Any]) -> Self: """ Get the "normal" operator from the current operator. @@ -1557,7 +1549,8 @@ def execute(self, context: Context) -> Any: """ Derive when creating an operator. - The main method to execute the task. Context is the same dictionary used as when rendering jinja templates. + The main method to execute the task. Context is the same dictionary used + as when rendering jinja templates. Refer to get_template_context for more context. """ diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py b/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py index b101cebfd82d0..65e15b9bfa1a4 100644 --- a/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py +++ b/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py @@ -26,7 +26,7 @@ Iterable, Iterator, ) -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias import methodtools @@ -49,6 +49,7 @@ from airflow.sdk.definitions.taskgroup import MappedTaskGroup TaskStateChangeCallback = Callable[[Context], None] +TaskStateChangeCallbackAttrType: TypeAlias = TaskStateChangeCallback | list[TaskStateChangeCallback] | None DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner") DEFAULT_POOL_SLOTS: int = 1 @@ -150,10 +151,6 @@ def task_type(self) -> str: def operator_name(self) -> str: raise NotImplementedError() - @property - def inherits_from_empty_operator(self) -> bool: - raise NotImplementedError() - _is_sensor: bool = False _is_mapped: bool = False _can_skip_downstream: bool = False @@ -212,6 +209,14 @@ def on_failure_fail_dagrun(self, value): ) self._on_failure_fail_dagrun = value + @property + def inherits_from_empty_operator(self): + """Used to determine if an Operator is inherited from EmptyOperator.""" + # This looks like `isinstance(self, EmptyOperator) would work, but this also + # needs to cope when `self` is a Serialized instance of a EmptyOperator or one + # of its subclasses (which don't inherit from anything but BaseOperator). 
diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py
index a1c612393b7d7..4b56c9bd94e53 100644
--- a/task-sdk/src/airflow/sdk/definitions/dag.py
+++ b/task-sdk/src/airflow/sdk/definitions/dag.py
@@ -786,11 +786,12 @@ def partial_subset(
         """
         from typing import TypeGuard
 
+        from airflow.models.mappedoperator import MappedOperator as DbMappedOperator
         from airflow.sdk.definitions.mappedoperator import MappedOperator
         from airflow.serialization.serialized_objects import SerializedBaseOperator
 
         def is_task(obj) -> TypeGuard[Operator]:
-            if isinstance(obj, SerializedBaseOperator):
+            if isinstance(obj, (DbMappedOperator, SerializedBaseOperator)):
                 return True
             # TODO (GH-52141): Split DAG implementation to straight this up.
             return isinstance(obj, (BaseOperator, MappedOperator))
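The TypeGuard return type is what lets callers of is_task() get static narrowing after the runtime isinstance checks. A self-contained sketch of the same pattern, with toy classes rather than the Airflow types:

# TypeGuard narrowing sketch; the classes are toys, not Airflow's.
from typing import TypeGuard


class BaseOperator: ...


class MappedOperator: ...


Operator = BaseOperator | MappedOperator


def is_task(obj: object) -> TypeGuard[Operator]:
    return isinstance(obj, (BaseOperator, MappedOperator))


def collect_tasks(nodes: list[object]) -> list[Operator]:
    # After the guard passes, type checkers treat each n as an Operator.
    return [n for n in nodes if is_task(n)]
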
diff --git a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
index 551f17865e8c0..ba5595e3b4828 100644
--- a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
+++ b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
@@ -21,7 +21,7 @@
 import copy
 import warnings
 from collections.abc import Collection, Iterable, Iterator, Mapping, Sequence
-from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeAlias, TypeGuard
+from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeGuard
 
 import attrs
 import methodtools
@@ -43,7 +43,7 @@
     DEFAULT_WEIGHT_RULE,
     AbstractOperator,
     NotMapped,
-    TaskStateChangeCallback,
+    TaskStateChangeCallbackAttrType,
 )
 from airflow.sdk.definitions._internal.expandinput import (
     DictOfListsExpandInput,
@@ -64,19 +64,13 @@
         OperatorExpandArgument,
         OperatorExpandKwargsArgument,
     )
-    from airflow.sdk.bases.operator import BaseOperator
-    from airflow.sdk.bases.operatorlink import BaseOperatorLink
+    from airflow.sdk import DAG, BaseOperator, BaseOperatorLink, Context, TaskGroup, XComArg
     from airflow.sdk.definitions._internal.expandinput import ExpandInput
-    from airflow.sdk.definitions.context import Context
-    from airflow.sdk.definitions.dag import DAG
     from airflow.sdk.definitions.param import ParamsDict
-    from airflow.sdk.definitions.taskgroup import TaskGroup
-    from airflow.sdk.definitions.xcom_arg import XComArg
     from airflow.triggers.base import StartTriggerArgs
     from airflow.utils.operator_resources import Resources
     from airflow.utils.trigger_rule import TriggerRule
 
-TaskStateChangeCallbackAttrType: TypeAlias = TaskStateChangeCallback | list[TaskStateChangeCallback] | None
 ValidationSource = Literal["expand"] | Literal["partial"]
 
@@ -287,12 +281,7 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator:
 class MappedOperator(AbstractOperator):
     """Object representing a mapped operator in a DAG."""
 
-    # This attribute serves double purpose. For a "normal" operator instance
-    # loaded from DAG, this holds the underlying non-mapped operator class that
-    # can be used to create an unmapped operator for execution. For an operator
-    # recreated from a serialized DAG, however, this holds the serialized data
-    # that can be used to unmap this into a SerializedBaseOperator.
-    operator_class: type[BaseOperator] | dict[str, Any]
+    operator_class: type[BaseOperator]
 
     _is_mapped: bool = attrs.field(init=False, default=True)
 
@@ -358,7 +347,7 @@ def __attrs_post_init__(self):
             self.task_group.add(self)
         if self.dag:
             self.dag.add_task(self)
-        XComArg.apply_upstream_relationship(self, self.expand_input.value)
+        XComArg.apply_upstream_relationship(self, self._get_specified_expand_input().value)
         for k, v in self.partial_kwargs.items():
             if k in self.template_fields:
                 XComArg.apply_upstream_relationship(self, v)
@@ -390,11 +379,6 @@ def task_type(self) -> str:
     def operator_name(self) -> str:
         return self._operator_name
 
-    @property
-    def inherits_from_empty_operator(self) -> bool:
-        """Implementing an empty Operator."""
-        return self._is_empty
-
     @property
     def roots(self) -> Sequence[AbstractOperator]:
         """Implementing DAGNode."""
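The __attrs_post_init__ change is what makes the DecoratedMappedOperator override removed earlier in this patch unnecessary: the base class now routes through the _get_specified_expand_input() hook, so a subclass that stores its expansion in a different field only overrides the hook. A simplified sketch of that shape, with toy attrs classes and wire_upstream standing in for XComArg.apply_upstream_relationship:

# Simplified hook-method sketch; these classes and wire_upstream are
# illustrative stand-ins, not the real Airflow implementations.
import attrs


def wire_upstream(task: object, value: object) -> None:
    print(f"wiring {value!r}")


@attrs.define
class ExpandInput:
    value: dict


@attrs.define
class Mapped:
    expand_input: ExpandInput

    def __attrs_post_init__(self):
        # The base consults the hook, so subclasses never redo this wiring.
        wire_upstream(self, self._get_specified_expand_input().value)

    def _get_specified_expand_input(self) -> ExpandInput:
        return self.expand_input


@attrs.define
class DecoratedMapped(Mapped):
    op_kwargs_expand_input: ExpandInput = attrs.field(factory=lambda: ExpandInput({}))

    def _get_specified_expand_input(self) -> ExpandInput:
        return self.op_kwargs_expand_input


DecoratedMapped(ExpandInput({"x": 1}))  # wires {} via the overridden hook
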
@@ -743,50 +727,26 @@ def unmap(self, resolve: None | Mapping[str, Any]) -> BaseOperator:
         """
         Get the "normal" Operator after applying the current mapping.
 
-        The *resolve* argument is only used if ``operator_class`` is a real
-        class, i.e. if this operator is not serialized. If ``operator_class`` is
-        not a class (i.e. this DAG has been deserialized), this returns a
-        SerializedBaseOperator that "looks like" the actual unmapping result.
-
         :meta private:
         """
-        if isinstance(self.operator_class, type):
-            if isinstance(resolve, Mapping):
-                kwargs = resolve
-            elif resolve is not None:
-                kwargs, _ = self._expand_mapped_kwargs(*resolve)
-            else:
-                raise RuntimeError("cannot unmap a non-serialized operator without context")
-            kwargs = self._get_unmap_kwargs(kwargs, strict=self._disallow_kwargs_override)
-            is_setup = kwargs.pop("is_setup", False)
-            is_teardown = kwargs.pop("is_teardown", False)
-            on_failure_fail_dagrun = kwargs.pop("on_failure_fail_dagrun", False)
-            kwargs["task_id"] = self.task_id
-            op = self.operator_class(**kwargs, _airflow_from_mapped=True)
-            op.is_setup = is_setup
-            op.is_teardown = is_teardown
-            op.on_failure_fail_dagrun = on_failure_fail_dagrun
-            op.downstream_task_ids = self.downstream_task_ids
-            op.upstream_task_ids = self.upstream_task_ids
-            return op
-
-        # TODO (GH-52141): Move this bottom part to the db-backed mapped operator implementation.
-
-        # After a mapped operator is serialized, there's no real way to actually
-        # unmap it since we've lost access to the underlying operator class.
-        # This tries its best to simply "forward" all the attributes on this
-        # mapped operator to a new SerializedBaseOperator instance.
-        from typing import cast
-
-        from airflow.serialization.serialized_objects import SerializedBaseOperator
-
-        sop = SerializedBaseOperator(task_id=self.task_id, params=self.params, _airflow_from_mapped=True)
-        for partial_attr, value in self.partial_kwargs.items():
-            setattr(sop, partial_attr, value)
-        SerializedBaseOperator.populate_operator(sop, self.operator_class)
-        if self.dag is not None:  # For Mypy; we only serialize tasks in a DAG so the check always satisfies.
-            SerializedBaseOperator.set_task_dag_references(sop, self.dag)  # type: ignore[arg-type]
-        return cast("BaseOperator", sop)
+        if isinstance(resolve, Mapping):
+            kwargs = resolve
+        elif resolve is not None:
+            kwargs, _ = self._expand_mapped_kwargs(*resolve)
+        else:
+            raise RuntimeError("cannot unmap a non-serialized operator without context")
+        kwargs = self._get_unmap_kwargs(kwargs, strict=self._disallow_kwargs_override)
+        is_setup = kwargs.pop("is_setup", False)
+        is_teardown = kwargs.pop("is_teardown", False)
+        on_failure_fail_dagrun = kwargs.pop("on_failure_fail_dagrun", False)
+        kwargs["task_id"] = self.task_id
+        op = self.operator_class(**kwargs, _airflow_from_mapped=True)
+        op.is_setup = is_setup
+        op.is_teardown = is_teardown
+        op.on_failure_fail_dagrun = on_failure_fail_dagrun
+        op.downstream_task_ids = self.downstream_task_ids
+        op.upstream_task_ids = self.upstream_task_ids
+        return op
 
     def _get_specified_expand_input(self) -> ExpandInput:
         """Input received from the expand call on the operator."""
@@ -798,6 +758,7 @@ def prepare_for_execution(self) -> MappedOperator:
         # we don't need to create a copy of the MappedOperator here.
         return self
 
+    # TODO (GH-52141): Do we need this in the SDK?
     def iter_mapped_dependencies(self) -> Iterator[AbstractOperator]:
         """Upstream dependencies that provide XComs used by this task for task mapping."""
         from airflow.sdk.definitions.xcom_arg import XComArg
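With the serialized fallback gone, unmap() has exactly one job: resolve kwargs, build the real operator_class, and re-apply the flags that are not constructor arguments. A condensed, self-contained sketch of that surviving shape; ToyOperator and the kwargs are hypothetical stand-ins:

# Condensed sketch of the surviving unmap() control flow; ToyOperator and
# the kwargs below are hypothetical, not taken from this patch.
from collections.abc import Mapping


class ToyOperator:
    def __init__(self, task_id: str, command: str):
        self.task_id = task_id
        self.command = command
        self.is_setup = False
        self.is_teardown = False


def unmap(operator_class: type, task_id: str, resolve: Mapping | None):
    if isinstance(resolve, Mapping):
        kwargs = dict(resolve)
    else:
        # The serialized-DAG branch was removed, so context is mandatory now.
        raise RuntimeError("cannot unmap a non-serialized operator without context")
    # Flags that are not constructor arguments get popped, then re-applied.
    is_setup = kwargs.pop("is_setup", False)
    is_teardown = kwargs.pop("is_teardown", False)
    kwargs["task_id"] = task_id
    op = operator_class(**kwargs)
    op.is_setup = is_setup
    op.is_teardown = is_teardown
    return op


op = unmap(ToyOperator, "t1", {"command": "echo 1", "is_setup": True})
assert op.is_setup and op.command == "echo 1"
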
diff --git a/task-sdk/src/airflow/sdk/definitions/taskgroup.py b/task-sdk/src/airflow/sdk/definitions/taskgroup.py
index 79bddbc5edc24..041133f19e00b 100644
--- a/task-sdk/src/airflow/sdk/definitions/taskgroup.py
+++ b/task-sdk/src/airflow/sdk/definitions/taskgroup.py
@@ -44,6 +44,7 @@
 if TYPE_CHECKING:
     from airflow.models.expandinput import SchedulerExpandInput
+    from airflow.models.mappedoperator import MappedOperator
     from airflow.sdk.bases.operator import BaseOperator
     from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
     from airflow.sdk.definitions._internal.expandinput import DictOfListsExpandInput, ListOfDictsExpandInput
@@ -587,8 +588,10 @@ def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
             yield group
             group = group.parent_group
 
-    def iter_tasks(self) -> Iterator[AbstractOperator | SerializedBaseOperator]:
+    # TODO (GH-52141): This should only return SDK operators. Have a db representation for db operators.
+    def iter_tasks(self) -> Iterator[AbstractOperator | MappedOperator | SerializedBaseOperator]:
         """Return an iterator of the child tasks."""
+        from airflow.models.mappedoperator import MappedOperator
         from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
         from airflow.serialization.serialized_objects import SerializedBaseOperator
 
@@ -598,7 +601,7 @@
             visiting = groups_to_visit.pop(0)
 
             for child in visiting.children.values():
-                if isinstance(child, (AbstractOperator, SerializedBaseOperator)):
+                if isinstance(child, (AbstractOperator, MappedOperator, SerializedBaseOperator)):
                     yield child
                 elif isinstance(child, TaskGroup):
                     groups_to_visit.append(child)
@@ -729,13 +732,13 @@ def task_group_to_dict(task_item_or_group, parent_group_is_mapped=False):
 
 def task_group_to_dict_grid(task_item_or_group, parent_group_is_mapped=False):
     """Create a nested dict representation of this TaskGroup and its children used to construct the Graph."""
+    from airflow.models.mappedoperator import MappedOperator
     from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
-    from airflow.sdk.definitions.mappedoperator import MappedOperator
     from airflow.serialization.serialized_objects import SerializedBaseOperator
 
-    if isinstance(task := task_item_or_group, (AbstractOperator, SerializedBaseOperator)):
+    if isinstance(task := task_item_or_group, (AbstractOperator, MappedOperator, SerializedBaseOperator)):
         is_mapped = None
-        if isinstance(task, MappedOperator) or parent_group_is_mapped:
+        if task.is_mapped or parent_group_is_mapped:
             is_mapped = True
         setup_teardown_type = None
         if task.is_setup is True:
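iter_tasks() is an iterative breadth-first walk: operators are yielded as they are seen, and child groups are queued for later visits. The same traversal in isolation, with toy node classes replacing the operator and group types:

# BFS traversal sketch; Task and Group are toy stand-ins for the real types.
from collections.abc import Iterator


class Task:
    def __init__(self, name: str):
        self.name = name


class Group:
    def __init__(self, children: list):
        self.children = children


def iter_tasks(root: Group) -> Iterator[Task]:
    groups_to_visit = [root]
    while groups_to_visit:
        visiting = groups_to_visit.pop(0)
        for child in visiting.children:
            if isinstance(child, Task):
                yield child
            elif isinstance(child, Group):
                groups_to_visit.append(child)


tree = Group([Task("a"), Group([Task("b"), Task("c")])])
assert [t.name for t in iter_tasks(tree)] == ["a", "b", "c"]
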
diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py
index 4b74c085c4c35..5ae484d0a87dc 100644
--- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py
+++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py
@@ -24,6 +24,8 @@
 from functools import singledispatch
 from typing import TYPE_CHECKING, Any, overload
 
+import attrs
+
 from airflow.exceptions import AirflowException, XComNotFound
 from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
 from airflow.sdk.definitions._internal.mixins import DependencyMixin, ResolveMixin
@@ -188,6 +190,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         SetupTeardownContext.set_work_task_roots_and_leaves()
 
 
+@attrs.define
 class PlainXComArg(XComArg):
     """
     Reference to one single XCom without any additional semantics.
@@ -207,14 +210,8 @@ class inheritance chain and ``__new__`` is implemented in this slightly
 
     :meta private:
     """
 
-    def __init__(self, operator: Operator, key: str = BaseXCom.XCOM_RETURN_KEY):
-        self.operator = operator
-        self.key = key
-
-    def __eq__(self, other: Any) -> bool:
-        if not isinstance(other, PlainXComArg):
-            return NotImplemented
-        return self.operator == other.operator and self.key == other.key
+    operator: Operator
+    key: str = BaseXCom.XCOM_RETURN_KEY
 
     def __getitem__(self, item: str) -> XComArg:
         """Implement xcomresult['some_result_key']."""
@@ -377,10 +374,10 @@ def _get_callable_name(f: Callable | str) -> str:
     return ""
 
 
+@attrs.define
 class _MapResult(Sequence):
-    def __init__(self, value: Sequence | dict, callables: MapCallables) -> None:
-        self.value = value
-        self.callables = callables
+    value: Sequence | dict
+    callables: MapCallables
 
     def __getitem__(self, index: Any) -> Any:
         value = self.value[index]
@@ -393,6 +390,7 @@ def __len__(self) -> int:
         return len(self.value)
 
 
+@attrs.define
 class MapXComArg(XComArg):
     """
     An XCom reference with ``map()`` call(s) applied.
@@ -403,12 +401,13 @@ class MapXComArg(XComArg):
     :meta private:
     """
 
-    def __init__(self, arg: XComArg, callables: MapCallables) -> None:
-        for c in callables:
+    arg: XComArg
+    callables: MapCallables
+
+    def __attrs_post_init__(self) -> None:
+        for c in self.callables:
             if getattr(c, "_airflow_is_task_decorator", False):
                 raise ValueError("map() argument must be a plain function, not a @task operator")
-        self.arg = arg
-        self.callables = callables
 
     def __repr__(self) -> str:
         map_calls = "".join(f".map({_get_callable_name(f)})" for f in self.callables)
@@ -434,10 +433,10 @@ def resolve(self, context: Mapping[str, Any]) -> Any:
         return _MapResult(value, self.callables)
 
 
+@attrs.define
 class _ZipResult(Sequence):
-    def __init__(self, values: Sequence[Sequence | dict], *, fillvalue: Any = NOTSET) -> None:
-        self.values = values
-        self.fillvalue = fillvalue
+    values: Sequence[Sequence | dict]
+    fillvalue: Any = attrs.field(default=NOTSET, kw_only=True)
 
     @staticmethod
     def _get_or_fill(container: Sequence | dict, index: Any, fillvalue: Any) -> Any:
@@ -458,6 +457,7 @@ def __len__(self) -> int:
         return max(lengths)
 
 
+@attrs.define
 class ZipXComArg(XComArg):
     """
     An XCom reference with ``zip()`` applied.
@@ -467,11 +467,8 @@ class ZipXComArg(XComArg):
     ``itertools.zip_longest()`` if ``fillvalue`` is provided).
     """
 
-    def __init__(self, args: Sequence[XComArg], *, fillvalue: Any = NOTSET) -> None:
-        if not args:
-            raise ValueError("At least one input is required")
-        self.args = args
-        self.fillvalue = fillvalue
+    args: Sequence[XComArg] = attrs.field(validator=attrs.validators.min_len(1))
+    fillvalue: Any = attrs.field(default=NOTSET, kw_only=True)
 
     def __repr__(self) -> str:
         args_iter = iter(self.args)
@@ -499,9 +496,9 @@ def resolve(self, context: Mapping[str, Any]) -> Any:
         return _ZipResult(values, fillvalue=self.fillvalue)
 
 
+@attrs.define
 class _ConcatResult(Sequence):
-    def __init__(self, values: Sequence[Sequence | dict]) -> None:
-        self.values = values
+    values: Sequence[Sequence | dict]
 
     def __getitem__(self, index: Any) -> Any:
         if index >= 0:
@@ -523,6 +520,7 @@ def __len__(self) -> int:
         return sum(len(v) for v in self.values)
 
 
+@attrs.define
 class ConcatXComArg(XComArg):
     """
     Concatenating multiple XCom references into one.
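The conversions above swap hand-written __init__/__eq__ boilerplate for attrs-generated ones, and attrs.validators.min_len(1) takes over the manual "at least one input" check. A quick demonstration of what the decorator provides, using a throwaway class that is not part of this patch:

# Throwaway class demonstrating what @attrs.define generates.
import attrs


@attrs.define
class Ref:
    args: list = attrs.field(validator=attrs.validators.min_len(1))
    fillvalue: object = attrs.field(default=None, kw_only=True)


assert Ref([1]) == Ref([1])  # generated __eq__, like PlainXComArg's old one
try:
    Ref([])  # the validator replaces the manual empty-sequence ValueError
except ValueError:
    pass
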
""" - def __init__(self, args: Sequence[XComArg]) -> None: - if not args: - raise ValueError("At least one input is required") - self.args = args + args: Sequence[XComArg] = attrs.field(validator=attrs.validators.min_len(1)) def __repr__(self) -> str: args_iter = iter(self.args)