diff --git a/task-sdk/pyproject.toml b/task-sdk/pyproject.toml index 468a6308e6032..fb2c514349556 100644 --- a/task-sdk/pyproject.toml +++ b/task-sdk/pyproject.toml @@ -21,7 +21,7 @@ dynamic = ["version"] description = "Python Task SDK for Apache Airflow DAG Authors" readme = { file = "README.md", content-type = "text/markdown" } license-files.globs = ["LICENSE"] -requires-python = ">=3.9, <3.13" +requires-python = ">=3.10, <3.13" authors = [ {name="Apache Software Foundation", email="dev@airflow.apache.org"}, @@ -38,7 +38,6 @@ classifiers = [ "Intended Audience :: System Administrators", "Framework :: Apache Airflow", "License :: OSI Approved :: Apache Software License", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", @@ -169,14 +168,14 @@ enum-field-as-literal='one' # When a single enum member, make it output a `Liter input-file-type='openapi' output-model-type='pydantic_v2.BaseModel' output-datetime-class='AwareDatetime' -target-python-version='3.9' +target-python-version='3.10' use-annotated=true use-default=true use-double-quotes=true use-schema-description=true # Desc becomes class doc comment use-standard-collections=true # list[] not List[] use-subclass-enum=true # enum, not union of Literals -use-union-operator=true # 3.9+annotations, not `Union[]` +use-union-operator=true # annotations, not `Union[]` custom-formatters = ['datamodel_code_formatter'] url = 'http://0.0.0.0:8080/execution/openapi.json' diff --git a/task-sdk/src/airflow/sdk/bases/decorator.py b/task-sdk/src/airflow/sdk/bases/decorator.py index c1063774220de..fb77b079a3386 100644 --- a/task-sdk/src/airflow/sdk/bases/decorator.py +++ b/task-sdk/src/airflow/sdk/bases/decorator.py @@ -21,9 +21,9 @@ import re import textwrap import warnings -from collections.abc import Collection, Iterator, Mapping, Sequence +from collections.abc import Callable, Collection, Iterator, Mapping, Sequence from functools import cached_property, update_wrapper -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Protocol, TypeVar, cast, overload +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar, cast, overload import attr import typing_extensions @@ -424,7 +424,7 @@ def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = ) if isinstance(kwargs, Sequence): for item in kwargs: - if not isinstance(item, (XComArg, Mapping)): + if not isinstance(item, XComArg | Mapping): raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") elif not isinstance(kwargs, XComArg): raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") diff --git a/task-sdk/src/airflow/sdk/bases/sensor.py b/task-sdk/src/airflow/sdk/bases/sensor.py index ca2b80cb97eb6..a912476322c15 100644 --- a/task-sdk/src/airflow/sdk/bases/sensor.py +++ b/task-sdk/src/airflow/sdk/bases/sensor.py @@ -21,9 +21,9 @@ import hashlib import time import traceback -from collections.abc import Iterable +from collections.abc import Callable, Iterable from datetime import timedelta -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any from airflow.configuration import conf from airflow.exceptions import ( @@ -143,7 +143,7 @@ def __init__( def _coerce_poke_interval(poke_interval: float | timedelta) -> timedelta: if isinstance(poke_interval, timedelta): return poke_interval - if isinstance(poke_interval, (int, float)) and poke_interval >= 0: + if isinstance(poke_interval, int | float) and poke_interval >= 0: return timedelta(seconds=poke_interval) raise AirflowException( "Operator arg `poke_interval` must be timedelta object or a non-negative number" @@ -153,7 +153,7 @@ def _coerce_poke_interval(poke_interval: float | timedelta) -> timedelta: def _coerce_timeout(timeout: float | timedelta) -> timedelta: if isinstance(timeout, timedelta): return timeout - if isinstance(timeout, (int, float)) and timeout >= 0: + if isinstance(timeout, int | float) and timeout >= 0: return timedelta(seconds=timeout) raise AirflowException("Operator arg `timeout` must be timedelta object or a non-negative number") @@ -161,14 +161,14 @@ def _coerce_timeout(timeout: float | timedelta) -> timedelta: def _coerce_max_wait(max_wait: float | timedelta | None) -> timedelta | None: if max_wait is None or isinstance(max_wait, timedelta): return max_wait - if isinstance(max_wait, (int, float)) and max_wait >= 0: + if isinstance(max_wait, int | float) and max_wait >= 0: return timedelta(seconds=max_wait) raise AirflowException("Operator arg `max_wait` must be timedelta object or a non-negative number") def _validate_input_values(self) -> None: - if not isinstance(self.poke_interval, (int, float)) or self.poke_interval < 0: + if not isinstance(self.poke_interval, int | float) or self.poke_interval < 0: raise AirflowException("The poke_interval must be a non-negative number") - if not isinstance(self.timeout, (int, float)) or self.timeout < 0: + if not isinstance(self.timeout, int | float) or self.timeout < 0: raise AirflowException("The timeout must be a non-negative number") if self.mode not in self.valid_modes: raise AirflowException( diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py b/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py index ec2fefa0a08b4..4180c4edaa4db 100644 --- a/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py +++ b/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py @@ -426,7 +426,7 @@ def _walk_group(group: TaskGroup) -> Iterable[tuple[str, DAGNode]]: for key, child in _walk_group(dag.task_group): if key == self.node_id: continue - if not isinstance(child, (MappedOperator, MappedTaskGroup)): + if not isinstance(child, MappedOperator | MappedTaskGroup): continue if self.node_id in child.upstream_task_ids: yield child diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/expandinput.py b/task-sdk/src/airflow/sdk/definitions/_internal/expandinput.py index b1c0c6ee5f979..00b12e24b399d 100644 --- a/task-sdk/src/airflow/sdk/definitions/_internal/expandinput.py +++ b/task-sdk/src/airflow/sdk/definitions/_internal/expandinput.py @@ -62,21 +62,21 @@ def __str__(self) -> str: def is_mappable(v: Any) -> TypeGuard[OperatorExpandArgument]: from airflow.sdk.definitions.xcom_arg import XComArg - return isinstance(v, (MappedArgument, XComArg, Mapping, Sequence)) and not isinstance(v, str) + return isinstance(v, MappedArgument | XComArg | Mapping | Sequence) and not isinstance(v, str) # To replace tedious isinstance() checks. def _is_parse_time_mappable(v: OperatorExpandArgument) -> TypeGuard[Mapping | Sequence]: from airflow.sdk.definitions.xcom_arg import XComArg - return not isinstance(v, (MappedArgument, XComArg)) + return not isinstance(v, MappedArgument | XComArg) # To replace tedious isinstance() checks. def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArgument | XComArg]: from airflow.sdk.definitions.xcom_arg import XComArg - return isinstance(v, (MappedArgument, XComArg)) + return isinstance(v, MappedArgument | XComArg) @attrs.define(kw_only=True) diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/node.py b/task-sdk/src/airflow/sdk/definitions/_internal/node.py index 21fa4ede5b1c9..177111af541c0 100644 --- a/task-sdk/src/airflow/sdk/definitions/_internal/node.py +++ b/task-sdk/src/airflow/sdk/definitions/_internal/node.py @@ -168,7 +168,7 @@ def _set_relatives( task_object.update_relative(self, not upstream, edge_modifier=edge_modifier) relatives = task_object.leaves if upstream else task_object.roots for task in relatives: - if not isinstance(task, (BaseOperator, MappedOperator)): + if not isinstance(task, BaseOperator | MappedOperator): raise TypeError( f"Relationships can only be set between Operators; received {task.__class__.__name__}" ) diff --git a/task-sdk/src/airflow/sdk/definitions/asset/__init__.py b/task-sdk/src/airflow/sdk/definitions/asset/__init__.py index 9cb913807ee91..3fd9985521b74 100644 --- a/task-sdk/src/airflow/sdk/definitions/asset/__init__.py +++ b/task-sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -23,7 +23,8 @@ import os import urllib.parse import warnings -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Literal, Union, overload +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, ClassVar, Literal, overload import attrs @@ -117,7 +118,7 @@ def to_asset_alias(self) -> AssetAlias: return AssetAlias(name=self.name) -BaseAssetUniqueKey = Union[AssetUniqueKey, AssetAliasUniqueKey] +BaseAssetUniqueKey = AssetUniqueKey | AssetAliasUniqueKey def normalize_noop(parts: SplitResult) -> SplitResult: diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py index c0a4230377449..5291a38651486 100644 --- a/task-sdk/src/airflow/sdk/definitions/dag.py +++ b/task-sdk/src/airflow/sdk/definitions/dag.py @@ -25,15 +25,13 @@ import sys import weakref from collections import abc -from collections.abc import Collection, Iterable, MutableSet +from collections.abc import Callable, Collection, Iterable, MutableSet from datetime import datetime, timedelta from inspect import signature from typing import ( TYPE_CHECKING, Any, - Callable, ClassVar, - Union, cast, overload, ) @@ -93,7 +91,7 @@ DagStateChangeCallback = Callable[[Context], None] ScheduleInterval = None | str | timedelta | relativedelta -ScheduleArg = Union[ScheduleInterval, Timetable, BaseAsset, Collection[BaseAsset]] +ScheduleArg = ScheduleInterval | Timetable | BaseAsset | Collection[BaseAsset] _DAG_HASH_ATTRS = frozenset( @@ -124,7 +122,7 @@ def _create_timetable(interval: ScheduleInterval, timezone: Timezone | FixedTime return OnceTimetable() if interval == "@continuous": return ContinuousTimetable() - if isinstance(interval, (timedelta, relativedelta)): + if isinstance(interval, timedelta | relativedelta): if airflow_conf.getboolean("scheduler", "create_cron_data_intervals"): return DeltaDataIntervalTimetable(interval) return DeltaTriggerTimetable(interval) @@ -809,7 +807,7 @@ def partial_subset( direct_upstreams: list[Operator] = [] if include_direct_upstream: for t in itertools.chain(matched_tasks, also_include): - upstream = (u for u in t.upstream_list if isinstance(u, (BaseOperator, MappedOperator))) + upstream = (u for u in t.upstream_list if isinstance(u, BaseOperator | MappedOperator)) direct_upstreams.extend(upstream) # Make sure to not recursively deepcopy the dag or task_group while copying the task. @@ -1284,12 +1282,7 @@ def _run_inline_trigger(trigger): import asyncio async def _run_inline_trigger_main(): - # We can replace it with `return await anext(trigger.run(), default=None)` - # when we drop support for Python 3.9 - try: - return await trigger.run().__anext__() - except StopAsyncIteration: - return None + return await anext(trigger.run(), None) return asyncio.run(_run_inline_trigger_main()) diff --git a/task-sdk/src/airflow/sdk/definitions/deadline.py b/task-sdk/src/airflow/sdk/definitions/deadline.py index 3c775b5064364..2b4a44af0adcf 100644 --- a/task-sdk/src/airflow/sdk/definitions/deadline.py +++ b/task-sdk/src/airflow/sdk/definitions/deadline.py @@ -17,8 +17,9 @@ from __future__ import annotations import logging +from collections.abc import Callable from datetime import datetime, timedelta -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING from airflow.models.deadline import ReferenceModels from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding diff --git a/task-sdk/src/airflow/sdk/definitions/decorators/__init__.py b/task-sdk/src/airflow/sdk/definitions/decorators/__init__.py index b6c3c879faf8c..41a7c2d0bf290 100644 --- a/task-sdk/src/airflow/sdk/definitions/decorators/__init__.py +++ b/task-sdk/src/airflow/sdk/definitions/decorators/__init__.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from typing import Callable +from collections.abc import Callable from airflow.providers_manager import ProvidersManager from airflow.sdk.bases.decorator import TaskDecorator diff --git a/task-sdk/src/airflow/sdk/definitions/decorators/__init__.pyi b/task-sdk/src/airflow/sdk/definitions/decorators/__init__.pyi index 30e921f2f4881..e60852a3f02d9 100644 --- a/task-sdk/src/airflow/sdk/definitions/decorators/__init__.pyi +++ b/task-sdk/src/airflow/sdk/definitions/decorators/__init__.pyi @@ -20,9 +20,9 @@ # documentation for more details. from __future__ import annotations -from collections.abc import Collection, Container, Iterable, Mapping +from collections.abc import Callable, Collection, Container, Iterable, Mapping from datetime import timedelta -from typing import Any, Callable, TypeVar, overload +from typing import Any, TypeVar, overload from docker.types import Mount from kubernetes.client import models as k8s diff --git a/task-sdk/src/airflow/sdk/definitions/decorators/condition.py b/task-sdk/src/airflow/sdk/definitions/decorators/condition.py index 5ccf6b685d497..2808563ffcfd5 100644 --- a/task-sdk/src/airflow/sdk/definitions/decorators/condition.py +++ b/task-sdk/src/airflow/sdk/definitions/decorators/condition.py @@ -16,14 +16,15 @@ # under the License. from __future__ import annotations +from collections.abc import Callable from functools import wraps -from typing import TYPE_CHECKING, Any, Callable, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar from airflow.exceptions import AirflowSkipException from airflow.sdk.bases.decorator import Task, _TaskDecorator if TYPE_CHECKING: - from typing_extensions import TypeAlias + from typing import TypeAlias from airflow.sdk.bases.operator import TaskPreExecuteHook from airflow.sdk.definitions.context import Context diff --git a/task-sdk/src/airflow/sdk/definitions/decorators/setup_teardown.py b/task-sdk/src/airflow/sdk/definitions/decorators/setup_teardown.py index c53a84ea71bc8..e5e2bf40ee991 100644 --- a/task-sdk/src/airflow/sdk/definitions/decorators/setup_teardown.py +++ b/task-sdk/src/airflow/sdk/definitions/decorators/setup_teardown.py @@ -17,7 +17,8 @@ from __future__ import annotations import types -from typing import TYPE_CHECKING, Callable +from collections.abc import Callable +from typing import TYPE_CHECKING from airflow.exceptions import AirflowException from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS diff --git a/task-sdk/src/airflow/sdk/definitions/decorators/task_group.py b/task-sdk/src/airflow/sdk/definitions/decorators/task_group.py index bb718abdb2021..809f5889e003e 100644 --- a/task-sdk/src/airflow/sdk/definitions/decorators/task_group.py +++ b/task-sdk/src/airflow/sdk/definitions/decorators/task_group.py @@ -28,8 +28,8 @@ import functools import inspect import warnings -from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, TypeVar, overload +from collections.abc import Callable, Mapping, Sequence +from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, overload import attr @@ -144,7 +144,7 @@ def expand(self, **kwargs: OperatorExpandArgument) -> DAGNode: def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument) -> DAGNode: if isinstance(kwargs, Sequence): for item in kwargs: - if not isinstance(item, (XComArg, Mapping)): + if not isinstance(item, XComArg | Mapping): raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") elif not isinstance(kwargs, XComArg): raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") diff --git a/task-sdk/src/airflow/sdk/definitions/edges.py b/task-sdk/src/airflow/sdk/definitions/edges.py index 39fafc4b932c8..4a52b9df32c4e 100644 --- a/task-sdk/src/airflow/sdk/definitions/edges.py +++ b/task-sdk/src/airflow/sdk/definitions/edges.py @@ -75,7 +75,7 @@ def _save_nodes( from airflow.sdk.definitions.xcom_arg import XComArg for node in self._make_list(nodes): - if isinstance(node, (TaskGroup, XComArg, DAGNode)): + if isinstance(node, TaskGroup | XComArg | DAGNode): stream.append(node) else: raise TypeError( diff --git a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py index 58c03c31b87d7..9f2cc18604354 100644 --- a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py +++ b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py @@ -21,11 +21,12 @@ import copy import warnings from collections.abc import Collection, Iterable, Iterator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, ClassVar, Union +from typing import TYPE_CHECKING, Any, ClassVar import attrs import methodtools +from airflow.models.abstractoperator import TaskStateChangeCallback from airflow.sdk.definitions._internal.abstractoperator import ( DEFAULT_EXECUTOR, DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, @@ -60,9 +61,6 @@ import jinja2 # Slow import. import pendulum - from airflow.models.abstractoperator import ( - TaskStateChangeCallback, - ) from airflow.models.expandinput import ( OperatorExpandArgument, OperatorExpandKwargsArgument, @@ -73,7 +71,6 @@ from airflow.sdk.definitions.dag import DAG from airflow.sdk.definitions.param import ParamsDict from airflow.sdk.definitions.xcom_arg import XComArg - from airflow.sdk.types import Operator from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.triggers.base import StartTriggerArgs from airflow.typing_compat import TypeGuard @@ -82,9 +79,8 @@ from airflow.utils.task_group import TaskGroup from airflow.utils.trigger_rule import TriggerRule - TaskStateChangeCallbackAttrType = Union[None, TaskStateChangeCallback, list[TaskStateChangeCallback]] - -ValidationSource = Union[Literal["expand"], Literal["partial"]] +TaskStateChangeCallbackAttrType = TaskStateChangeCallback | list[TaskStateChangeCallback] | None +ValidationSource = Literal["expand"] | Literal["partial"] def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None: @@ -144,9 +140,9 @@ def is_mappable_value(value: Any) -> TypeGuard[Collection]: :meta private: """ - if not isinstance(value, (Sequence, dict)): + if not isinstance(value, Sequence | dict): return False - if isinstance(value, (bytearray, bytes, str)): + if isinstance(value, bytearray | bytes | str): return False return True @@ -196,7 +192,7 @@ def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = if isinstance(kwargs, Sequence): for item in kwargs: - if not isinstance(item, (XComArg, Mapping)): + if not isinstance(item, XComArg | Mapping): raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") elif not isinstance(kwargs, XComArg): raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") @@ -786,7 +782,7 @@ def prepare_for_execution(self) -> MappedOperator: # we don't need to create a copy of the MappedOperator here. return self - def iter_mapped_dependencies(self) -> Iterator[Operator]: + def iter_mapped_dependencies(self) -> Iterator[AbstractOperator]: """Upstream dependencies that provide XComs used by this task for task mapping.""" from airflow.sdk.definitions.xcom_arg import XComArg diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index 2a93585304cb0..7a5ed0468e739 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -20,9 +20,9 @@ import contextlib import inspect import itertools -from collections.abc import Iterable, Iterator, Mapping, Sequence, Sized +from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence, Sized from functools import singledispatch -from typing import TYPE_CHECKING, Any, Callable, overload +from typing import TYPE_CHECKING, Any, overload from airflow.exceptions import AirflowException, XComNotFound from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator @@ -104,7 +104,7 @@ def iter_xcom_references(arg: Any) -> Iterator[tuple[Operator, str]]: """ if isinstance(arg, ResolveMixin): yield from arg.iter_references() - elif isinstance(arg, (tuple, set, list)): + elif isinstance(arg, tuple | set | list): for elem in arg: yield from XComArg.iter_xcom_references(elem) elif isinstance(arg, dict): @@ -429,7 +429,7 @@ def map(self, f: Callable[[Any], Any]) -> MapXComArg: def resolve(self, context: Mapping[str, Any]) -> Any: value = self.arg.resolve(context) - if not isinstance(value, (Sequence, dict)): + if not isinstance(value, Sequence | dict): raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}") return _MapResult(value, self.callables) @@ -494,7 +494,7 @@ def iter_references(self) -> Iterator[tuple[Operator, str]]: def resolve(self, context: Mapping[str, Any]) -> Any: values = [arg.resolve(context) for arg in self.args] for value in values: - if not isinstance(value, (Sequence, dict)): + if not isinstance(value, Sequence | dict): raise ValueError(f"XCom zip expects sequence or dict, not {type(value).__name__}") return _ZipResult(values, fillvalue=self.fillvalue) @@ -557,7 +557,7 @@ def concat(self, *others: XComArg) -> ConcatXComArg: def resolve(self, context: Mapping[str, Any]) -> Any: values = [arg.resolve(context) for arg in self.args] for value in values: - if not isinstance(value, (Sequence, dict)): + if not isinstance(value, Sequence | dict): raise ValueError(f"XCom concat expects sequence or dict, not {type(value).__name__}") return _ConcatResult(values) diff --git a/task-sdk/src/airflow/sdk/execution_time/callback_runner.py b/task-sdk/src/airflow/sdk/execution_time/callback_runner.py index f35d76d058915..316c3d38e99b8 100644 --- a/task-sdk/src/airflow/sdk/execution_time/callback_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/callback_runner.py @@ -20,7 +20,8 @@ import inspect import logging -from typing import TYPE_CHECKING, Callable, Generic, Protocol, TypeVar, cast +from collections.abc import Callable +from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast from typing_extensions import ParamSpec diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index 97ed1761ad8eb..338f5df23a3bb 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -54,7 +54,7 @@ from functools import cached_property from pathlib import Path from socket import socket -from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Generic, Literal, TypeVar, Union, overload +from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Generic, Literal, TypeVar, overload from uuid import UUID import attrs @@ -558,27 +558,25 @@ class SentFDs(BaseModel): ToTask = Annotated[ - Union[ - AssetResult, - AssetEventsResult, - ConnectionResult, - DagRunStateResult, - DRCount, - ErrorResponse, - PrevSuccessfulDagRunResult, - SentFDs, - StartupDetails, - TaskRescheduleStartDate, - TICount, - TaskStatesResult, - VariableResult, - XComCountResponse, - XComResult, - XComSequenceIndexResult, - XComSequenceSliceResult, - InactiveAssetsResult, - OKResponse, - ], + AssetResult + | AssetEventsResult + | ConnectionResult + | DagRunStateResult + | DRCount + | ErrorResponse + | PrevSuccessfulDagRunResult + | SentFDs + | StartupDetails + | TaskRescheduleStartDate + | TICount + | TaskStatesResult + | VariableResult + | XComCountResponse + | XComResult + | XComSequenceIndexResult + | XComSequenceSliceResult + | InactiveAssetsResult + | OKResponse, Field(discriminator="type"), ] @@ -841,37 +839,35 @@ class GetDRCount(BaseModel): ToSupervisor = Annotated[ - Union[ - DeferTask, - DeleteXCom, - GetAssetByName, - GetAssetByUri, - GetAssetEventByAsset, - GetAssetEventByAssetAlias, - GetConnection, - GetDagRunState, - GetDRCount, - GetPrevSuccessfulDagRun, - GetTaskRescheduleStartDate, - GetTICount, - GetTaskStates, - GetVariable, - GetXCom, - GetXComCount, - GetXComSequenceItem, - GetXComSequenceSlice, - PutVariable, - RescheduleTask, - RetryTask, - SetRenderedFields, - SetXCom, - SkipDownstreamTasks, - SucceedTask, - ValidateInletsAndOutlets, - TaskState, - TriggerDagRun, - DeleteVariable, - ResendLoggingFD, - ], + DeferTask + | DeleteXCom + | GetAssetByName + | GetAssetByUri + | GetAssetEventByAsset + | GetAssetEventByAssetAlias + | GetConnection + | GetDagRunState + | GetDRCount + | GetPrevSuccessfulDagRun + | GetTaskRescheduleStartDate + | GetTICount + | GetTaskStates + | GetVariable + | GetXCom + | GetXComCount + | GetXComSequenceItem + | GetXComSequenceSlice + | PutVariable + | RescheduleTask + | RetryTask + | SetRenderedFields + | SetXCom + | SkipDownstreamTasks + | SucceedTask + | ValidateInletsAndOutlets + | TaskState + | TriggerDagRun + | DeleteVariable + | ResendLoggingFD, Field(discriminator="type"), ] diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index c76994995ebab..e95b9b173dc90 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -529,7 +529,7 @@ def __getitem__(self, key: int | Asset | AssetAlias | AssetRef) -> list[AssetEve msg: ToSupervisor if isinstance(key, int): # Support index access; it's easier for trivial cases. obj = self._inlets[key] - if not isinstance(obj, (Asset, AssetAlias, AssetRef)): + if not isinstance(obj, Asset | AssetAlias | AssetRef): raise IndexError(key) else: obj = key diff --git a/task-sdk/src/airflow/sdk/execution_time/secrets_masker.py b/task-sdk/src/airflow/sdk/execution_time/secrets_masker.py index cfc6214325dc4..560a493d3bc3f 100644 --- a/task-sdk/src/airflow/sdk/execution_time/secrets_masker.py +++ b/task-sdk/src/airflow/sdk/execution_time/secrets_masker.py @@ -23,28 +23,20 @@ import logging import re import sys -from collections.abc import Generator, Iterable, Iterator +from collections.abc import Callable, Generator, Iterable, Iterator from enum import Enum from functools import cache, cached_property from re import Pattern -from typing import ( - TYPE_CHECKING, - Any, - Callable, - TextIO, - TypeVar, - Union, -) +from typing import TYPE_CHECKING, Any, TextIO, TypeAlias, TypeVar from airflow import settings if TYPE_CHECKING: - from kubernetes.client import V1EnvVar - from airflow.typing_compat import TypeGuard -Redactable = TypeVar("Redactable", str, "V1EnvVar", dict[Any, Any], tuple[Any, ...], list[Any]) -Redacted = Union[Redactable, str] +V1EnvVar = TypeVar("V1EnvVar") +Redactable: TypeAlias = str | V1EnvVar | dict[Any, Any] | tuple[Any, ...] | list[Any] +Redacted: TypeAlias = Redactable | str log = logging.getLogger(__name__) @@ -240,7 +232,7 @@ def _redact_all(self, item: Redactable, depth: int, max_depth: int = MAX_RECURSI return { dict_key: self._redact_all(subval, depth + 1, max_depth) for dict_key, subval in item.items() } - if isinstance(item, (tuple, set)): + if isinstance(item, tuple | set): # Turn set in to tuple! return tuple(self._redact_all(subval, depth + 1, max_depth) for subval in item) if isinstance(item, list): @@ -265,7 +257,7 @@ def _redact(self, item: Redactable, name: str | None, depth: int, max_depth: int if isinstance(item, Enum): return self._redact(item=item.value, name=name, depth=depth, max_depth=max_depth) if _is_v1_env_var(item): - tmp: dict = item.to_dict() + tmp: dict = item.to_dict() # type: ignore[attr-defined] # V1EnvVar has a to_dict method if should_hide_value_for_key(tmp.get("name", "")) and "value" in tmp: tmp["value"] = "***" else: @@ -278,7 +270,7 @@ def _redact(self, item: Redactable, name: str | None, depth: int, max_depth: int # the structure. return self.replacer.sub("***", str(item)) return item - if isinstance(item, (tuple, set)): + if isinstance(item, tuple | set): # Turn set in to tuple! return tuple( self._redact(subval, name=None, depth=(depth + 1), max_depth=max_depth) for subval in item @@ -462,7 +454,7 @@ def writable(self) -> bool: return self.target.writable() def write(self, s: str) -> int: - s = redact(s) + s = str(redact(s)) return self.target.write(s) def writelines(self, lines) -> None: diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 70b384a8c96b2..2bae008a0a6ff 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -29,7 +29,7 @@ import time import weakref from collections import deque -from collections.abc import Generator +from collections.abc import Callable, Generator from contextlib import contextmanager, suppress from datetime import datetime, timezone from http import HTTPStatus @@ -37,7 +37,6 @@ from typing import ( TYPE_CHECKING, BinaryIO, - Callable, ClassVar, NoReturn, TextIO, diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 6c6e597f65e5c..24127d65a2763 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -340,8 +340,8 @@ def xcom_pull( if run_id is None: run_id = self.run_id - single_task_requested = isinstance(task_ids, (str, type(None))) - single_map_index_requested = isinstance(map_indexes, (int, type(None))) + single_task_requested = isinstance(task_ids, str | type(None)) + single_map_index_requested = isinstance(map_indexes, int | type(None)) if task_ids is None: # default to the current task if not provided @@ -618,7 +618,7 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance: ) exit(1) - if not isinstance(task, (BaseOperator, MappedOperator)): + if not isinstance(task, BaseOperator | MappedOperator): raise TypeError( f"task is of the wrong type, got {type(task)}, wanted {BaseOperator} or {MappedOperator}" ) diff --git a/task-sdk/src/airflow/sdk/log.py b/task-sdk/src/airflow/sdk/log.py index 554a96cd9e690..46efd6bf448ed 100644 --- a/task-sdk/src/airflow/sdk/log.py +++ b/task-sdk/src/airflow/sdk/log.py @@ -24,9 +24,10 @@ import re import sys import warnings +from collections.abc import Callable from functools import cache from pathlib import Path -from typing import TYPE_CHECKING, Any, BinaryIO, Callable, Generic, TextIO, TypeVar, cast +from typing import TYPE_CHECKING, Any, BinaryIO, Generic, TextIO, TypeVar, cast import msgspec import structlog diff --git a/task-sdk/src/airflow/sdk/types.py b/task-sdk/src/airflow/sdk/types.py index 8bd0ea0db8d4d..400ad23783325 100644 --- a/task-sdk/src/airflow/sdk/types.py +++ b/task-sdk/src/airflow/sdk/types.py @@ -19,7 +19,7 @@ import uuid from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, Protocol, Union +from typing import TYPE_CHECKING, Any, Protocol, TypeAlias from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet @@ -33,7 +33,7 @@ from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.mappedoperator import MappedOperator - Operator = Union[BaseOperator, MappedOperator] + Operator: TypeAlias = BaseOperator | MappedOperator class DagRunProtocol(Protocol): diff --git a/task-sdk/tests/conftest.py b/task-sdk/tests/conftest.py index 80c71f41a8fe8..f2660faa2d206 100644 --- a/task-sdk/tests/conftest.py +++ b/task-sdk/tests/conftest.py @@ -119,7 +119,7 @@ def captured_logs(request): # We need to replace remove the last processor (the one that turns JSON into text, as we want the # event dict for tests) proc = processors.pop() - assert isinstance(proc, (structlog.dev.ConsoleRenderer, structlog.processors.JSONRenderer)), ( + assert isinstance(proc, structlog.dev.ConsoleRenderer | structlog.processors.JSONRenderer), ( "Pre-condition" ) try: diff --git a/task-sdk/tests/task_sdk/definitions/conftest.py b/task-sdk/tests/task_sdk/definitions/conftest.py index 3f89f34b4d2da..7ad358487ba63 100644 --- a/task-sdk/tests/task_sdk/definitions/conftest.py +++ b/task-sdk/tests/task_sdk/definitions/conftest.py @@ -42,7 +42,7 @@ def run(dag: DAG, task_id: str, map_index: int): for call in mock_supervisor_comms.send.mock_calls: msg = call.kwargs.get("msg") or call.args[0] - if isinstance(msg, (TaskState, SucceedTask)): + if isinstance(msg, TaskState | SucceedTask): return msg.state raise RuntimeError("Unable to find call to TaskState") diff --git a/task-sdk/tests/task_sdk/definitions/test_asset.py b/task-sdk/tests/task_sdk/definitions/test_asset.py index 2a25c0907c7dc..fd70882e96ada 100644 --- a/task-sdk/tests/task_sdk/definitions/test_asset.py +++ b/task-sdk/tests/task_sdk/definitions/test_asset.py @@ -19,7 +19,7 @@ import json import os -from typing import Callable +from collections.abc import Callable from unittest import mock import pytest @@ -244,7 +244,7 @@ def assets_equal(a1: BaseAsset, a2: BaseAsset) -> bool: if isinstance(a1, Asset) and isinstance(a2, Asset): return a1.uri == a2.uri - if isinstance(a1, (AssetAny, AssetAll)) and isinstance(a2, (AssetAny, AssetAll)): + if isinstance(a1, AssetAny | AssetAll) and isinstance(a2, AssetAny | AssetAll): if len(a1.objects) != len(a2.objects): return False diff --git a/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py b/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py index 5c81b64b605b3..6889d7d57c3b0 100644 --- a/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py +++ b/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py @@ -17,8 +17,9 @@ # under the License. from __future__ import annotations +from collections.abc import Callable from datetime import datetime, timedelta -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING from unittest import mock import pendulum diff --git a/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py b/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py index 73468dcb9e12b..521fa99df811f 100644 --- a/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py +++ b/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py @@ -17,7 +17,7 @@ # under the License. from __future__ import annotations -from typing import Callable +from collections.abc import Callable from unittest import mock import pytest