@@ -23,8 +23,8 @@
 import structlog
 
 from airflow.api_fastapi.common.parameters import state_priority
-from airflow.models.mappedoperator import MappedOperator
 from airflow.models.taskmap import TaskMap
+from airflow.sdk.definitions.mappedoperator import MappedOperator
 from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup, get_task_group_children_getter
 from airflow.serialization.serialized_objects import SerializedBaseOperator
 
34 changes: 15 additions & 19 deletions airflow-core/src/airflow/models/dagrun.py
@@ -102,13 +102,11 @@
 from airflow.serialization.serialized_objects import SerializedBaseOperator
 from airflow.utils.types import ArgNotSet
 
-Operator: TypeAlias = MappedOperator | SerializedBaseOperator
-
 CreatedTasks = TypeVar("CreatedTasks", Iterator["dict[str, Any]"], Iterator[TI])
 
-AttributeValueType = (
+AttributeValueType: TypeAlias = (
     str | bool | int | float | Sequence[str] | Sequence[bool] | Sequence[int] | Sequence[float]
 )
+Operator: TypeAlias = MappedOperator | SerializedBaseOperator
 
 RUN_ID_REGEX = r"^(?:manual|scheduled|asset_triggered)__(?:\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\+00:00)$"
 
@@ -1483,15 +1481,15 @@ def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None:
     If the ti does not need expansion, either because the task is not
     mapped, or has already been expanded, *None* is returned.
     """
+    from airflow.models.mappedoperator import is_mapped
+
     if TYPE_CHECKING:
         assert ti.task
 
     if ti.map_index >= 0:  # Already expanded, we're good.
         return None
 
-    from airflow.sdk.definitions.mappedoperator import MappedOperator as TaskSDKMappedOperator
-
-    if isinstance(ti.task, TaskSDKMappedOperator):
+    if is_mapped(ti.task):
         # If we get here, it could be that we are moving from non-mapped to mapped
         # after task instance clearing or this ti is not yet expanded. Safe to clear
         # the db references.
@@ -1510,7 +1508,7 @@
         revised_map_index_task_ids: set[str] = set()
         for schedulable in itertools.chain(schedulable_tis, additional_tis):
             if TYPE_CHECKING:
-                assert isinstance(schedulable.task, SerializedBaseOperator)
+                assert isinstance(schedulable.task, Operator)
             old_state = schedulable.state
             if not schedulable.are_dependencies_met(session=session, dep_context=dep_context):
                 old_states[schedulable.key] = old_state
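For context on the `Operator` alias and the `assert isinstance(..., Operator)` lines above: a minimal sketch (with stand-in classes, not the real Airflow ones) of the pattern these hunks rely on, where a union `TypeAlias` plus a `TYPE_CHECKING`-guarded assert narrows a loosely typed attribute for the type checker at no runtime cost.

from __future__ import annotations

from typing import TYPE_CHECKING, TypeAlias


class MappedOperator:  # stand-in for the scheduler's mapped-operator class
    task_id = "mapped"


class SerializedBaseOperator:  # stand-in for the deserialized plain operator
    task_id = "plain"


# Mirrors `Operator: TypeAlias = MappedOperator | SerializedBaseOperator` from the diff.
Operator: TypeAlias = MappedOperator | SerializedBaseOperator


class TaskInstanceStub:
    # Loosely typed, much like the attribute on the real task-instance model.
    task: object | None = None


def task_id_of(ti: TaskInstanceStub) -> str:
    if TYPE_CHECKING:
        # mypy treats TYPE_CHECKING as true, so this assert narrows ti.task to
        # Operator; at runtime the block is skipped entirely.
        assert isinstance(ti.task, Operator)
    return ti.task.task_id

The `is_mapped(ti.task)` call in the earlier hunk serves the runtime side of the same question; presumably it wraps the isinstance check that the removed `TaskSDKMappedOperator` branch spelled out inline.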
@@ -1995,25 +1993,23 @@ def schedule_tis(
         empty_ti_ids: list[str] = []
         schedulable_ti_ids: list[str] = []
         for ti in schedulable_tis:
+            task = ti.task
             if TYPE_CHECKING:
-                assert isinstance(ti.task, SerializedBaseOperator)
+                assert isinstance(task, Operator)
             if (
-                ti.task.inherits_from_empty_operator
-                and not ti.task.on_execute_callback
-                and not ti.task.on_success_callback
-                and not ti.task.outlets
-                and not ti.task.inlets
+                task.inherits_from_empty_operator
+                and not task.on_execute_callback
+                and not task.on_success_callback
+                and not task.outlets
+                and not task.inlets
             ):
                 empty_ti_ids.append(ti.id)
             # check "start_trigger_args" to see whether the operator supports start execution from triggerer
             # if so, we'll then check "start_from_trigger" to see whether this feature is turned on and defer
             # this task.
             # if not, we'll add this "ti" into "schedulable_ti_ids" and later execute it to run in the worker
-            elif ti.task.start_trigger_args is not None:
-                context = ti.get_template_context()
-                start_from_trigger = ti.task.expand_start_from_trigger(context=context, session=session)
-
-                if start_from_trigger:
+            elif task.start_trigger_args is not None:
+                if task.expand_start_from_trigger(context=ti.get_template_context()):
                     ti.start_date = timezone.utcnow()
                     if ti.state != TaskInstanceState.UP_FOR_RESCHEDULE:
                         ti.try_number += 1
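The `schedule_tis` hunk reads `ti.task` into a local `task` once and folds the deferral check into a single `elif`. Below is a simplified sketch of that control flow using hypothetical `TaskStub`/`TIStub` stand-ins, not Airflow's real models; the real method also does bulk state updates and takes a database session.

from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class TaskStub:
    """Hypothetical stand-in for a scheduler-side operator."""

    inherits_from_empty_operator: bool = False
    on_execute_callback: list = field(default_factory=list)
    on_success_callback: list = field(default_factory=list)
    outlets: list = field(default_factory=list)
    inlets: list = field(default_factory=list)
    start_trigger_args: object | None = None

    def expand_start_from_trigger(self, context: dict) -> bool:
        # Stand-in: pretend the feature is enabled whenever trigger args exist.
        return self.start_trigger_args is not None


@dataclass
class TIStub:
    """Hypothetical stand-in for a schedulable task instance."""

    id: str
    task: TaskStub

    def get_template_context(self) -> dict:
        return {}


def bucket_tis(schedulable_tis: list[TIStub]) -> tuple[list[str], list[str], list[str]]:
    """Split TIs into: finish immediately, defer to the triggerer, send to a worker."""
    empty_ti_ids: list[str] = []
    deferred_ti_ids: list[str] = []
    schedulable_ti_ids: list[str] = []
    for ti in schedulable_tis:
        task = ti.task  # bind once instead of re-reading ti.task in every condition
        if (
            task.inherits_from_empty_operator
            and not task.on_execute_callback
            and not task.on_success_callback
            and not task.outlets
            and not task.inlets
        ):
            empty_ti_ids.append(ti.id)
        elif task.start_trigger_args is not None and task.expand_start_from_trigger(
            context=ti.get_template_context()
        ):
            deferred_ti_ids.append(ti.id)
        else:
            schedulable_ti_ids.append(ti.id)
    return empty_ti_ids, deferred_ti_ids, schedulable_ti_ids


tis = [TIStub("empty_1", TaskStub(inherits_from_empty_operator=True)), TIStub("normal_1", TaskStub())]
print(bucket_tis(tis))  # (['empty_1'], [], ['normal_1'])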
51 changes: 43 additions & 8 deletions airflow-core/src/airflow/models/expandinput.py
@@ -19,18 +19,11 @@
 
 import functools
 import operator
-from collections.abc import Iterable, Sized
+from collections.abc import Iterable, Mapping, Sequence, Sized
 from typing import TYPE_CHECKING, Any, ClassVar
 
 import attrs
 
-if TYPE_CHECKING:
-    from typing import TypeGuard
-
-    from sqlalchemy.orm import Session
-
-    from airflow.models.xcom_arg import SchedulerXComArg
-
 from airflow.sdk.definitions._internal.expandinput import (
     DictOfListsExpandInput,
     ListOfDictsExpandInput,
@@ -41,6 +34,18 @@
     is_mappable,
 )
 
+if TYPE_CHECKING:
+    from typing import TypeAlias, TypeGuard
+
+    from sqlalchemy.orm import Session
+
+    from airflow.models.mappedoperator import MappedOperator
+    from airflow.models.xcom_arg import SchedulerXComArg
+    from airflow.serialization.serialized_objects import SerializedBaseOperator
+
+    Operator: TypeAlias = MappedOperator | SerializedBaseOperator
+
+
 __all__ = [
     "DictOfListsExpandInput",
     "ListOfDictsExpandInput",
@@ -111,6 +116,26 @@ def get_total_map_length(self, run_id: str, *, session: Session) -> int:
         lengths = self._get_map_lengths(run_id, session=session)
         return functools.reduce(operator.mul, (lengths[name] for name in self.value), 1)
 
+    def iter_references(self) -> Iterable[tuple[Operator, str]]:
+        from airflow.models.referencemixin import ReferenceMixin
+
+        for x in self.value.values():
+            if isinstance(x, ReferenceMixin):
+                yield from x.iter_references()
+
+
+# To replace tedious isinstance() checks.
+def _is_parse_time_mappable(v: OperatorExpandArgument) -> TypeGuard[Mapping | Sequence]:
+    from airflow.sdk.definitions.xcom_arg import XComArg
+
+    return not isinstance(v, (MappedArgument, XComArg))
+
+
+def _describe_type(value: Any) -> str:
+    if value is None:
+        return "None"
+    return type(value).__name__
+
 
 @attrs.define
 class SchedulerListOfDictsExpandInput:
@@ -133,6 +158,16 @@ def get_total_map_length(self, run_id: str, *, session: Session) -> int:
             raise NotFullyPopulated({"expand_kwargs() argument"})
         return length
 
+    def iter_references(self) -> Iterable[tuple[Operator, str]]:
+        from airflow.models.referencemixin import ReferenceMixin
+
+        if isinstance(self.value, ReferenceMixin):
+            yield from self.value.iter_references()
+        else:
+            for x in self.value:
+                if isinstance(x, ReferenceMixin):
+                    yield from x.iter_references()
+
 
 _EXPAND_INPUT_TYPES: dict[str, type[SchedulerExpandInput]] = {
     "dict-of-lists": SchedulerDictOfListsExpandInput,
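The two `iter_references` methods added above let the scheduler ask an expand input which upstream operators it reads from. A sketch of how such a traversal is typically consumed, using hypothetical stand-ins (`ReferenceMixin`, `XComArgStub`, `OperatorStub`, `DictOfListsInputStub`) rather than Airflow's real `ReferenceMixin`/`SchedulerXComArg` classes:

from __future__ import annotations

from collections.abc import Iterable, Iterator


class ReferenceMixin:
    """Hypothetical stand-in: anything that can point back at upstream operators."""

    def iter_references(self) -> Iterable[tuple[OperatorStub, str]]:
        raise NotImplementedError


class OperatorStub:
    def __init__(self, task_id: str) -> None:
        self.task_id = task_id


class XComArgStub(ReferenceMixin):
    """Hypothetical stand-in for an XCom reference to one upstream operator's key."""

    def __init__(self, operator: OperatorStub, key: str = "return_value") -> None:
        self.operator = operator
        self.key = key

    def iter_references(self) -> Iterator[tuple[OperatorStub, str]]:
        yield self.operator, self.key


class DictOfListsInputStub:
    """Mimics the dict-of-lists shape: only reference values contribute upstreams."""

    def __init__(self, value: dict[str, object]) -> None:
        self.value = value

    def iter_references(self) -> Iterator[tuple[OperatorStub, str]]:
        for x in self.value.values():
            if isinstance(x, ReferenceMixin):
                yield from x.iter_references()


# Example: "x" comes from an upstream task, "y" is a literal list and is skipped
# by the isinstance filter, just as in the diff above.
upstream = OperatorStub("produce_numbers")
expand_input = DictOfListsInputStub({"x": XComArgStub(upstream), "y": [1, 2, 3]})
upstream_task_ids = {op.task_id for op, _ in expand_input.iter_references()}
print(upstream_task_ids)  # {'produce_numbers'}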