diff --git a/airflow-core/src/airflow/dag_processing/dagbag.py b/airflow-core/src/airflow/dag_processing/dagbag.py index cadd26412c96d..173f5b05b4e2f 100644 --- a/airflow-core/src/airflow/dag_processing/dagbag.py +++ b/airflow-core/src/airflow/dag_processing/dagbag.py @@ -49,6 +49,7 @@ ) from airflow.executors.executor_loader import ExecutorLoader from airflow.listeners.listener import get_listener_manager +from airflow.serialization.definitions.notset import NOTSET, ArgNotSet, is_arg_set from airflow.serialization.serialized_objects import LazyDeserializedDAG from airflow.utils.docs import get_docs_url from airflow.utils.file import ( @@ -59,7 +60,6 @@ ) from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.types import NOTSET if TYPE_CHECKING: from collections.abc import Generator @@ -68,7 +68,6 @@ from airflow import DAG from airflow.models.dagwarning import DagWarning - from airflow.utils.types import ArgNotSet @contextlib.contextmanager @@ -231,14 +230,6 @@ def __init__( super().__init__() self.bundle_path = bundle_path self.bundle_name = bundle_name - include_examples = ( - include_examples - if isinstance(include_examples, bool) - else conf.getboolean("core", "LOAD_EXAMPLES") - ) - safe_mode = ( - safe_mode if isinstance(safe_mode, bool) else conf.getboolean("core", "DAG_DISCOVERY_SAFE_MODE") - ) dag_folder = dag_folder or settings.DAGS_FOLDER self.dag_folder = dag_folder @@ -259,8 +250,14 @@ def __init__( if collect_dags: self.collect_dags( dag_folder=dag_folder, - include_examples=include_examples, - safe_mode=safe_mode, + include_examples=( + include_examples + if is_arg_set(include_examples) + else conf.getboolean("core", "LOAD_EXAMPLES") + ), + safe_mode=( + safe_mode if is_arg_set(safe_mode) else conf.getboolean("core", "DAG_DISCOVERY_SAFE_MODE") + ), ) # Should the extra operator link be loaded via plugins? 
# This flag is set to False in Scheduler so that Extra Operator links are not loaded diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py index 97598a3c2de6d..54ca61a75b361 100644 --- a/airflow-core/src/airflow/models/dagrun.py +++ b/airflow-core/src/airflow/models/dagrun.py @@ -69,6 +69,7 @@ from airflow.models.tasklog import LogTemplate from airflow.models.taskmap import TaskMap from airflow.sdk.definitions.deadline import DeadlineReference +from airflow.serialization.definitions.notset import NOTSET, ArgNotSet, is_arg_set from airflow.stats import Stats from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.dependencies_states import SCHEDULEABLE_STATES @@ -90,7 +91,7 @@ from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.strings import get_random_string from airflow.utils.thread_safe_dict import ThreadSafeDict -from airflow.utils.types import NOTSET, DagRunTriggeredByType, DagRunType +from airflow.utils.types import DagRunTriggeredByType, DagRunType if TYPE_CHECKING: from typing import Literal, TypeAlias @@ -105,7 +106,6 @@ from airflow.models.taskinstancekey import TaskInstanceKey from airflow.sdk import DAG as SDKDAG from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG - from airflow.utils.types import ArgNotSet CreatedTasks = TypeVar("CreatedTasks", Iterator["dict[str, Any]"], Iterator[TI]) AttributeValueType: TypeAlias = ( @@ -348,7 +348,7 @@ def __init__( self.conf = conf or {} if state is not None: self.state = state - if queued_at is NOTSET: + if not is_arg_set(queued_at): self.queued_at = timezone.utcnow() if state == DagRunState.QUEUED else None elif queued_at is not None: self.queued_at = queued_at diff --git a/airflow-core/src/airflow/models/variable.py b/airflow-core/src/airflow/models/variable.py index a13eb4fe158dc..83c457eb772b8 100644 --- a/airflow-core/src/airflow/models/variable.py +++ 
b/airflow-core/src/airflow/models/variable.py @@ -156,13 +156,9 @@ def get( stacklevel=1, ) from airflow.sdk import Variable as TaskSDKVariable - from airflow.sdk.definitions._internal.types import NOTSET - var_val = TaskSDKVariable.get( - key, - default=NOTSET if default_var is cls.__NO_DEFAULT_SENTINEL else default_var, - deserialize_json=deserialize_json, - ) + default_kwargs = {} if default_var is cls.__NO_DEFAULT_SENTINEL else {"default": default_var} + var_val = TaskSDKVariable.get(key, deserialize_json=deserialize_json, **default_kwargs) if isinstance(var_val, str): mask_secret(var_val, key) diff --git a/airflow-core/src/airflow/models/xcom_arg.py b/airflow-core/src/airflow/models/xcom_arg.py index bc9326b2b2f52..8da146ca3afab 100644 --- a/airflow-core/src/airflow/models/xcom_arg.py +++ b/airflow-core/src/airflow/models/xcom_arg.py @@ -27,11 +27,10 @@ from airflow.models.referencemixin import ReferenceMixin from airflow.models.xcom import XCOM_RETURN_KEY -from airflow.sdk.definitions._internal.types import ArgNotSet from airflow.sdk.definitions.xcom_arg import XComArg +from airflow.serialization.definitions.notset import NOTSET, is_arg_set from airflow.utils.db import exists_query from airflow.utils.state import State -from airflow.utils.types import NOTSET __all__ = ["XComArg", "get_task_map_length"] @@ -150,7 +149,7 @@ def get_task_map_length(xcom_arg: SchedulerXComArg, run_id: str, *, session: Ses @get_task_map_length.register -def _(xcom_arg: SchedulerPlainXComArg, run_id: str, *, session: Session): +def _(xcom_arg: SchedulerPlainXComArg, run_id: str, *, session: Session) -> int | None: from airflow.models.mappedoperator import is_mapped from airflow.models.taskinstance import TaskInstance from airflow.models.taskmap import TaskMap @@ -193,23 +192,23 @@ def _(xcom_arg: SchedulerPlainXComArg, run_id: str, *, session: Session): @get_task_map_length.register -def _(xcom_arg: SchedulerMapXComArg, run_id: str, *, session: Session): +def _(xcom_arg: 
SchedulerMapXComArg, run_id: str, *, session: Session) -> int | None: return get_task_map_length(xcom_arg.arg, run_id, session=session) @get_task_map_length.register -def _(xcom_arg: SchedulerZipXComArg, run_id: str, *, session: Session): +def _(xcom_arg: SchedulerZipXComArg, run_id: str, *, session: Session) -> int | None: all_lengths = (get_task_map_length(arg, run_id, session=session) for arg in xcom_arg.args) ready_lengths = [length for length in all_lengths if length is not None] if len(ready_lengths) != len(xcom_arg.args): return None # If any of the referenced XComs is not ready, we are not ready either. - if isinstance(xcom_arg.fillvalue, ArgNotSet): - return min(ready_lengths) - return max(ready_lengths) + if is_arg_set(xcom_arg.fillvalue): + return max(ready_lengths) + return min(ready_lengths) @get_task_map_length.register -def _(xcom_arg: SchedulerConcatXComArg, run_id: str, *, session: Session): +def _(xcom_arg: SchedulerConcatXComArg, run_id: str, *, session: Session) -> int | None: all_lengths = (get_task_map_length(arg, run_id, session=session) for arg in xcom_arg.args) ready_lengths = [length for length in all_lengths if length is not None] if len(ready_lengths) != len(xcom_arg.args): diff --git a/airflow-core/src/airflow/serialization/definitions/notset.py b/airflow-core/src/airflow/serialization/definitions/notset.py index 0e2057c45d01e..a7731daed20c8 100644 --- a/airflow-core/src/airflow/serialization/definitions/notset.py +++ b/airflow-core/src/airflow/serialization/definitions/notset.py @@ -18,7 +18,23 @@ from __future__ import annotations -from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet +from typing import TYPE_CHECKING, TypeVar -# TODO (GH-52141): Have different NOTSET and ArgNotSet in the scheduler. 
-__all__ = ["NOTSET", "ArgNotSet"] +if TYPE_CHECKING: + from typing_extensions import TypeIs + + T = TypeVar("T") + +__all__ = ["NOTSET", "ArgNotSet", "is_arg_set"] + + +class ArgNotSet: + """Sentinel type for annotations, useful when None is not viable.""" + + +NOTSET = ArgNotSet() +"""Sentinel value for argument default. See ``ArgNotSet``.""" + + +def is_arg_set(value: T | ArgNotSet) -> TypeIs[T]: + return not isinstance(value, ArgNotSet) diff --git a/airflow-core/src/airflow/serialization/definitions/param.py b/airflow-core/src/airflow/serialization/definitions/param.py index 8169b23f59c1d..733131f3eab47 100644 --- a/airflow-core/src/airflow/serialization/definitions/param.py +++ b/airflow-core/src/airflow/serialization/definitions/param.py @@ -22,7 +22,7 @@ import copy from typing import TYPE_CHECKING, Any -from airflow.serialization.definitions.notset import NOTSET, ArgNotSet +from airflow.serialization.definitions.notset import NOTSET, is_arg_set if TYPE_CHECKING: from collections.abc import Iterator, Mapping @@ -51,7 +51,7 @@ def resolve(self, *, raises: bool = False) -> Any: import jsonschema try: - if isinstance(value := self.value, ArgNotSet): + if not is_arg_set(value := self.value): raise ValueError("No value passed") jsonschema.validate(value, self.schema, format_checker=jsonschema.FormatChecker()) except Exception: diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py index acb05228beffc..9a2a8efe96e30 100644 --- a/airflow-core/src/airflow/serialization/serialized_objects.py +++ b/airflow-core/src/airflow/serialization/serialized_objects.py @@ -122,7 +122,7 @@ from airflow.utils.module_loading import import_string, qualname from airflow.utils.session import NEW_SESSION, create_session, provide_session from airflow.utils.state import DagRunState, TaskInstanceState -from airflow.utils.types import NOTSET, ArgNotSet, DagRunTriggeredByType, DagRunType +from 
airflow.utils.types import DagRunTriggeredByType, DagRunType if TYPE_CHECKING: from inspect import Parameter @@ -736,7 +736,11 @@ def serialize( :meta private: """ - if cls._is_primitive(var): + from airflow.sdk.definitions._internal.types import is_arg_set + + if not is_arg_set(var): + return cls._encode(None, type_=DAT.ARG_NOT_SET) + elif cls._is_primitive(var): # enum.IntEnum is an int instance, it causes json dumps error so we use its value. if isinstance(var, enum.Enum): return var.value @@ -867,8 +871,6 @@ def serialize( obj = cls.serialize(v, strict=strict) d[str(k)] = obj return cls._encode(d, type_=DAT.TASK_CONTEXT) - elif isinstance(var, ArgNotSet): - return cls._encode(None, type_=DAT.ARG_NOT_SET) else: return cls.default_serialization(strict, var) @@ -981,6 +983,8 @@ def deserialize(cls, encoded_var: Any) -> Any: elif type_ == DAT.TASK_INSTANCE_KEY: return TaskInstanceKey(**var) elif type_ == DAT.ARG_NOT_SET: + from airflow.serialization.definitions.notset import NOTSET + return NOTSET elif type_ == DAT.DEADLINE_ALERT: return DeadlineAlert.deserialize_deadline_alert(var) diff --git a/airflow-core/src/airflow/utils/context.py b/airflow-core/src/airflow/utils/context.py index abf6bb1a53062..4793d628323f1 100644 --- a/airflow-core/src/airflow/utils/context.py +++ b/airflow-core/src/airflow/utils/context.py @@ -30,17 +30,15 @@ from sqlalchemy import select -from airflow.models.asset import ( - AssetModel, -) +from airflow.models.asset import AssetModel from airflow.sdk.definitions.context import Context from airflow.sdk.execution_time.context import ( ConnectionAccessor as ConnectionAccessorSDK, OutletEventAccessors as OutletEventAccessorsSDK, VariableAccessor as VariableAccessorSDK, ) +from airflow.serialization.definitions.notset import NOTSET, is_arg_set from airflow.utils.session import create_session -from airflow.utils.types import NOTSET if TYPE_CHECKING: from airflow.sdk.definitions.asset import Asset @@ -100,9 +98,9 @@ def __getattr__(self, key: 
str) -> Any: def get(self, key, default: Any = NOTSET) -> Any: from airflow.models.variable import Variable - if default is NOTSET: - return Variable.get(key, deserialize_json=self._deserialize_json) - return Variable.get(key, default, deserialize_json=self._deserialize_json) + if is_arg_set(default): + return Variable.get(key, default, deserialize_json=self._deserialize_json) + return Variable.get(key, deserialize_json=self._deserialize_json) class ConnectionAccessor(ConnectionAccessorSDK): diff --git a/airflow-core/src/airflow/utils/helpers.py b/airflow-core/src/airflow/utils/helpers.py index 4d79a28d41e77..0e4c5f325b131 100644 --- a/airflow-core/src/airflow/utils/helpers.py +++ b/airflow-core/src/airflow/utils/helpers.py @@ -30,7 +30,7 @@ from airflow.configuration import conf from airflow.exceptions import AirflowException -from airflow.utils.types import NOTSET +from airflow.serialization.definitions.notset import is_arg_set if TYPE_CHECKING: from datetime import datetime @@ -283,13 +283,7 @@ def at_most_one(*args) -> bool: If user supplies an iterable, we raise ValueError and force them to unpack. """ - - def is_set(val): - if val is NOTSET: - return False - return bool(val) - - return sum(map(is_set, args)) in (0, 1) + return sum(is_arg_set(a) and bool(a) for a in args) in (0, 1) def prune_dict(val: Any, mode="strict"): diff --git a/airflow-core/src/airflow/utils/types.py b/airflow-core/src/airflow/utils/types.py index 276901f94bcd9..4e1aa11ffd81c 100644 --- a/airflow-core/src/airflow/utils/types.py +++ b/airflow-core/src/airflow/utils/types.py @@ -14,19 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+ from __future__ import annotations import enum -from typing import TYPE_CHECKING - -import airflow.sdk.definitions._internal.types - -if TYPE_CHECKING: - from typing import TypeAlias -ArgNotSet: TypeAlias = airflow.sdk.definitions._internal.types.ArgNotSet - -NOTSET = airflow.sdk.definitions._internal.types.NOTSET +from airflow.utils.deprecation_tools import add_deprecated_classes class DagRunType(str, enum.Enum): @@ -68,3 +61,14 @@ class DagRunTriggeredByType(enum.Enum): TIMETABLE = "timetable" # for timetable based triggering ASSET = "asset" # for asset_triggered run type BACKFILL = "backfill" + + +add_deprecated_classes( + { + __name__: { + "ArgNotSet": "airflow.serialization.definitions.notset.ArgNotSet", + "NOTSET": "airflow.serialization.definitions.notset.NOTSET", + }, + }, + package=__name__, +) diff --git a/airflow-core/tests/unit/models/test_xcom_arg.py b/airflow-core/tests/unit/models/test_xcom_arg.py index f5ce83df4a999..ee29354444932 100644 --- a/airflow-core/tests/unit/models/test_xcom_arg.py +++ b/airflow-core/tests/unit/models/test_xcom_arg.py @@ -21,7 +21,7 @@ from airflow.models.xcom_arg import XComArg from airflow.providers.standard.operators.bash import BashOperator from airflow.providers.standard.operators.python import PythonOperator -from airflow.utils.types import NOTSET +from airflow.serialization.definitions.notset import NOTSET from tests_common.test_utils.db import clear_db_dags, clear_db_runs diff --git a/airflow-core/tests/unit/utils/test_helpers.py b/airflow-core/tests/unit/utils/test_helpers.py index 6f297904eb72a..6179acadfb94f 100644 --- a/airflow-core/tests/unit/utils/test_helpers.py +++ b/airflow-core/tests/unit/utils/test_helpers.py @@ -26,6 +26,7 @@ from airflow._shared.timezones import timezone from airflow.exceptions import AirflowException from airflow.jobs.base_job_runner import BaseJobRunner +from airflow.serialization.definitions.notset import NOTSET from airflow.utils import helpers from airflow.utils.helpers import 
( at_most_one, @@ -35,7 +36,6 @@ prune_dict, validate_key, ) -from airflow.utils.types import NOTSET from tests_common.test_utils.db import clear_db_dags, clear_db_runs diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index 96f1bfe814dbd..e70ebd0e6934c 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -883,7 +883,7 @@ def dag_maker(request) -> Generator[DagMaker, None, None]: # This fixture is "called" early on in the pytest collection process, and # if we import airflow.* here the wrong (non-test) config will be loaded # and "baked" in to various constants - from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS + from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS, NOTSET want_serialized = False want_activate_assets = True # Only has effect if want_serialized=True on Airflow 3. @@ -896,7 +896,6 @@ def dag_maker(request) -> Generator[DagMaker, None, None]: (want_activate_assets,) = serialized_marker.args or (True,) from airflow.utils.log.logging_mixin import LoggingMixin - from airflow.utils.types import NOTSET class DagFactory(LoggingMixin, DagMaker): _own_session = False @@ -1465,9 +1464,8 @@ def create_task_instance(dag_maker: DagMaker, create_dummy_dag: CreateDummyDAG) Uses ``create_dummy_dag`` to create the dag structure. 
""" from airflow.providers.standard.operators.empty import EmptyOperator - from airflow.utils.types import NOTSET, ArgNotSet - from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS + from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, NOTSET, ArgNotSet def maker( logical_date: datetime | None | ArgNotSet = NOTSET, @@ -1574,7 +1572,7 @@ def __call__( @pytest.fixture def create_serialized_task_instance_of_operator(dag_maker: DagMaker) -> CreateTaskInstanceOfOperator: - from airflow.utils.types import NOTSET + from tests_common.test_utils.version_compat import NOTSET def _create_task_instance( operator_class, @@ -1594,7 +1592,7 @@ def _create_task_instance( @pytest.fixture def create_task_instance_of_operator(dag_maker: DagMaker) -> CreateTaskInstanceOfOperator: - from airflow.utils.types import NOTSET + from tests_common.test_utils.version_compat import NOTSET def _create_task_instance( operator_class, diff --git a/devel-common/src/tests_common/test_utils/version_compat.py b/devel-common/src/tests_common/test_utils/version_compat.py index ad093637b790a..e30c692278fe8 100644 --- a/devel-common/src/tests_common/test_utils/version_compat.py +++ b/devel-common/src/tests_common/test_utils/version_compat.py @@ -39,17 +39,18 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: AIRFLOW_V_3_1_3_PLUS = get_base_airflow_version_tuple() >= (3, 1, 3) AIRFLOW_V_3_2_PLUS = get_base_airflow_version_tuple() >= (3, 2, 0) - if AIRFLOW_V_3_1_PLUS: from airflow.sdk import PokeReturnValue, timezone from airflow.sdk.bases.xcom import BaseXCom from airflow.sdk.definitions._internal.decorators import remove_task_decorator + from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet XCOM_RETURN_KEY = BaseXCom.XCOM_RETURN_KEY else: from airflow.sensors.base import PokeReturnValue # type: ignore[no-redef] from airflow.utils import timezone # type: ignore[attr-defined,no-redef] from airflow.utils.decorators import remove_task_decorator # 
type: ignore[no-redef] + from airflow.utils.types import NOTSET, ArgNotSet # type: ignore[attr-defined,no-redef] from airflow.utils.xcom import XCOM_RETURN_KEY # type: ignore[no-redef] @@ -70,9 +71,11 @@ def get_sqlalchemy_version_tuple() -> tuple[int, int, int]: "AIRFLOW_V_3_0_1", "AIRFLOW_V_3_1_PLUS", "AIRFLOW_V_3_2_PLUS", + "NOTSET", "SQLALCHEMY_V_1_4", "SQLALCHEMY_V_2_0", "XCOM_RETURN_KEY", + "ArgNotSet", "PokeReturnValue", "remove_task_decorator", "timezone", diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py index 5b275cb641f8b..87404dedb4167 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py @@ -43,15 +43,13 @@ from uuid import uuid4 if TYPE_CHECKING: + from aiobotocore.client import AioBaseClient from mypy_boto3_s3.service_resource import ( Bucket as S3Bucket, Object as S3ResourceObject, ) - from airflow.utils.types import ArgNotSet - - with suppress(ImportError): - from aiobotocore.client import AioBaseClient + from airflow.providers.amazon.version_compat import ArgNotSet from asgiref.sync import sync_to_async diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/ssm.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/ssm.py index 7bc2c8214264d..ef8c30323a74c 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/ssm.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/ssm.py @@ -20,7 +20,7 @@ from typing import TYPE_CHECKING from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook -from airflow.utils.types import NOTSET, ArgNotSet +from airflow.providers.amazon.version_compat import NOTSET, ArgNotSet, is_arg_set if TYPE_CHECKING: from airflow.sdk.execution_time.secrets_masker import mask_secret @@ -71,9 +71,9 @@ def get_parameter_value(self, parameter: str, default: str | ArgNotSet = NOTSET) mask_secret(value) return 
value except self.conn.exceptions.ParameterNotFound: - if isinstance(default, ArgNotSet): - raise - return default + if is_arg_set(default): + return default + raise def get_command_invocation(self, command_id: str, instance_id: str) -> dict: """ diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/base_aws.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/base_aws.py index 29fb40d3e67ed..b293c52e3a7be 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/base_aws.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/base_aws.py @@ -25,8 +25,8 @@ AwsHookType, aws_template_fields, ) +from airflow.providers.amazon.version_compat import NOTSET, ArgNotSet from airflow.providers.common.compat.sdk import BaseOperator -from airflow.utils.types import NOTSET, ArgNotSet class AwsBaseOperator(BaseOperator, AwsBaseHookMixin[AwsHookType]): diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py index c023bc04e14b3..b1296b0f9eb3e 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py @@ -57,8 +57,8 @@ waiter, ) from airflow.providers.amazon.aws.utils.waiter_with_logging import wait +from airflow.providers.amazon.version_compat import NOTSET, ArgNotSet from airflow.utils.helpers import exactly_one, prune_dict -from airflow.utils.types import NOTSET, ArgNotSet if TYPE_CHECKING: from airflow.utils.context import Context diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/base_aws.py b/providers/amazon/src/airflow/providers/amazon/aws/sensors/base_aws.py index b13634bc2bdd3..562e0816cea57 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/base_aws.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/base_aws.py @@ -25,8 +25,8 @@ AwsHookType, aws_template_fields, ) +from 
airflow.providers.amazon.version_compat import NOTSET, ArgNotSet from airflow.providers.common.compat.sdk import BaseSensorOperator -from airflow.utils.types import NOTSET, ArgNotSet class AwsBaseSensor(BaseSensorOperator, AwsBaseHookMixin[AwsHookType]): diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/base.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/base.py index 612e57701cb2f..9be7fb54991cf 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/base.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/base.py @@ -22,8 +22,8 @@ from collections.abc import Sequence from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook +from airflow.providers.amazon.version_compat import NOTSET, ArgNotSet, is_arg_set from airflow.providers.common.compat.sdk import BaseOperator -from airflow.utils.types import NOTSET, ArgNotSet class AwsToAwsBaseOperator(BaseOperator): @@ -55,7 +55,7 @@ def __init__( self.source_aws_conn_id = source_aws_conn_id self.dest_aws_conn_id = dest_aws_conn_id self.source_aws_conn_id = source_aws_conn_id - if isinstance(dest_aws_conn_id, ArgNotSet): - self.dest_aws_conn_id = self.source_aws_conn_id - else: + if is_arg_set(dest_aws_conn_id): self.dest_aws_conn_id = dest_aws_conn_id + else: + self.dest_aws_conn_id = self.source_aws_conn_id diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py index d50fcfd00aef6..a683b4c9e21e3 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py @@ -36,8 +36,8 @@ from airflow.utils.helpers import prune_dict if TYPE_CHECKING: - from airflow.utils.context import Context - from airflow.utils.types import ArgNotSet + from airflow.providers.amazon.version_compat import ArgNotSet + from airflow.sdk import 
Context class JSONEncoder(json.JSONEncoder): diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/redshift_to_s3.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/redshift_to_s3.py index a285af2f65b5b..0fe68099ace5c 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/redshift_to_s3.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/redshift_to_s3.py @@ -28,8 +28,8 @@ from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.providers.amazon.aws.utils.redshift import build_credentials_block +from airflow.providers.amazon.version_compat import NOTSET, ArgNotSet, is_arg_set from airflow.providers.common.compat.sdk import BaseOperator -from airflow.utils.types import NOTSET, ArgNotSet if TYPE_CHECKING: from airflow.utils.context import Context @@ -131,12 +131,12 @@ def __init__( # actually provide a connection note that, because we don't want to let the exception bubble up in # that case (since we're silently injecting a connection on their behalf). 
self._aws_conn_id: str | None - if isinstance(aws_conn_id, ArgNotSet): - self.conn_set = False - self._aws_conn_id = "aws_default" - else: + if is_arg_set(aws_conn_id): self.conn_set = True self._aws_conn_id = aws_conn_id + else: + self.conn_set = False + self._aws_conn_id = "aws_default" def _build_unload_query( self, credentials_block: str, select_query: str, s3_key: str, unload_options: str diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_redshift.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_redshift.py index ae36822976ba8..87d1d752e94bf 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_redshift.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_redshift.py @@ -24,8 +24,8 @@ from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.providers.amazon.aws.utils.redshift import build_credentials_block +from airflow.providers.amazon.version_compat import NOTSET, ArgNotSet, is_arg_set from airflow.providers.common.compat.sdk import BaseOperator -from airflow.utils.types import NOTSET, ArgNotSet if TYPE_CHECKING: from airflow.utils.context import Context @@ -122,12 +122,12 @@ def __init__( # actually provide a connection note that, because we don't want to let the exception bubble up in # that case (since we're silently injecting a connection on their behalf). 
self._aws_conn_id: str | None - if isinstance(aws_conn_id, ArgNotSet): - self.conn_set = False - self._aws_conn_id = "aws_default" - else: + if is_arg_set(aws_conn_id): self.conn_set = True self._aws_conn_id = aws_conn_id + else: + self.conn_set = False + self._aws_conn_id = "aws_default" if self.redshift_data_api_kwargs: for arg in ["sql", "parameters"]: diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/bedrock.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/bedrock.py index faac0f90753e6..70d8254a316c1 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/bedrock.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/bedrock.py @@ -20,7 +20,7 @@ from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook, BedrockHook from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger -from airflow.utils.types import NOTSET, ArgNotSet +from airflow.providers.amazon.version_compat import NOTSET, ArgNotSet if TYPE_CHECKING: from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook diff --git a/providers/amazon/src/airflow/providers/amazon/aws/utils/connection_wrapper.py b/providers/amazon/src/airflow/providers/amazon/aws/utils/connection_wrapper.py index 3ed84db484351..e1f7bbeb0f4a7 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/utils/connection_wrapper.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/utils/connection_wrapper.py @@ -28,8 +28,8 @@ from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.utils import trim_none_values +from airflow.providers.amazon.version_compat import NOTSET, ArgNotSet from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.types import NOTSET, ArgNotSet if TYPE_CHECKING: from airflow.providers.common.compat.sdk import Connection diff --git a/providers/amazon/src/airflow/providers/amazon/version_compat.py 
b/providers/amazon/src/airflow/providers/amazon/version_compat.py index a7d116ec0433a..dc76a025ccf5c 100644 --- a/providers/amazon/src/airflow/providers/amazon/version_compat.py +++ b/providers/amazon/src/airflow/providers/amazon/version_compat.py @@ -20,9 +20,13 @@ # ON AIRFLOW VERSION, PLEASE COPY THIS FILE TO THE ROOT PACKAGE OF YOUR PROVIDER AND IMPORT # THOSE CONSTANTS FROM IT RATHER THAN IMPORTING THEM FROM ANOTHER PROVIDER OR TEST CODE # + from __future__ import annotations +import functools + +@functools.cache def get_base_airflow_version_tuple() -> tuple[int, int, int]: from packaging.version import Version @@ -36,8 +40,23 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: AIRFLOW_V_3_1_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 0) AIRFLOW_V_3_1_1_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 1) +try: + from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet +except ImportError: + from airflow.utils.types import NOTSET, ArgNotSet # type: ignore[attr-defined,no-redef] +try: + from airflow.sdk.definitions._internal.types import is_arg_set +except ImportError: + + def is_arg_set(value): # type: ignore[misc,no-redef] + return value is not NOTSET + + __all__ = [ "AIRFLOW_V_3_0_PLUS", "AIRFLOW_V_3_1_PLUS", "AIRFLOW_V_3_1_1_PLUS", + "NOTSET", + "ArgNotSet", + "is_arg_set", ] diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_redshift_sql.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_redshift_sql.py index 399245face35c..f873b6dba7086 100644 --- a/providers/amazon/tests/unit/amazon/aws/hooks/test_redshift_sql.py +++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_redshift_sql.py @@ -24,7 +24,7 @@ from airflow.exceptions import AirflowException from airflow.models import Connection from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook -from airflow.utils.types import NOTSET +from airflow.providers.amazon.version_compat import NOTSET LOGIN_USER = "login" LOGIN_PASSWORD = 
"password" diff --git a/providers/amazon/tests/unit/amazon/aws/notifications/test_ses.py b/providers/amazon/tests/unit/amazon/aws/notifications/test_ses.py index b848f38c5b3e0..423c41a8b2e6e 100644 --- a/providers/amazon/tests/unit/amazon/aws/notifications/test_ses.py +++ b/providers/amazon/tests/unit/amazon/aws/notifications/test_ses.py @@ -21,7 +21,7 @@ import pytest from airflow.providers.amazon.aws.notifications.ses import SesNotifier, send_ses_notification -from airflow.utils.types import NOTSET +from airflow.providers.amazon.version_compat import NOTSET TEST_EMAIL_PARAMS = { "mail_from": "from@test.com", diff --git a/providers/amazon/tests/unit/amazon/aws/notifications/test_sns.py b/providers/amazon/tests/unit/amazon/aws/notifications/test_sns.py index b09098d0d4966..ef3abcfbacbb7 100644 --- a/providers/amazon/tests/unit/amazon/aws/notifications/test_sns.py +++ b/providers/amazon/tests/unit/amazon/aws/notifications/test_sns.py @@ -21,7 +21,7 @@ import pytest from airflow.providers.amazon.aws.notifications.sns import SnsNotifier, send_sns_notification -from airflow.utils.types import NOTSET +from airflow.providers.amazon.version_compat import NOTSET PUBLISH_KWARGS = { "target_arn": "arn:aws:sns:us-west-2:123456789098:TopicName", diff --git a/providers/amazon/tests/unit/amazon/aws/notifications/test_sqs.py b/providers/amazon/tests/unit/amazon/aws/notifications/test_sqs.py index 2c33e77f2e423..10a9d115f73bd 100644 --- a/providers/amazon/tests/unit/amazon/aws/notifications/test_sqs.py +++ b/providers/amazon/tests/unit/amazon/aws/notifications/test_sqs.py @@ -21,7 +21,7 @@ import pytest from airflow.providers.amazon.aws.notifications.sqs import SqsNotifier, send_sqs_notification -from airflow.utils.types import NOTSET +from airflow.providers.amazon.version_compat import NOTSET PARAM_DEFAULT_VALUE = pytest.param(NOTSET, id="default-value") diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_comprehend.py 
b/providers/amazon/tests/unit/amazon/aws/operators/test_comprehend.py index c7aeb0830ba48..fea16411a2909 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_comprehend.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_comprehend.py @@ -29,7 +29,7 @@ ComprehendCreateDocumentClassifierOperator, ComprehendStartPiiEntitiesDetectionJobOperator, ) -from airflow.utils.types import NOTSET +from airflow.providers.amazon.version_compat import NOTSET from unit.amazon.aws.utils.test_template_fields import validate_template_fields diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_ecs.py b/providers/amazon/tests/unit/amazon/aws/operators/test_ecs.py index 14438fdbde717..14b442d19c594 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_ecs.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_ecs.py @@ -37,7 +37,7 @@ ) from airflow.providers.amazon.aws.triggers.ecs import TaskDoneTrigger from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher -from airflow.utils.types import NOTSET +from airflow.providers.amazon.version_compat import NOTSET from unit.amazon.aws.utils.test_template_fields import validate_template_fields diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_serverless.py b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_serverless.py index e88880289bc2e..e43d7b9793c5f 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_serverless.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_serverless.py @@ -31,7 +31,7 @@ EmrServerlessStartJobOperator, EmrServerlessStopApplicationOperator, ) -from airflow.utils.types import NOTSET +from airflow.providers.amazon.version_compat import NOTSET from unit.amazon.aws.utils.test_template_fields import validate_template_fields diff --git a/providers/amazon/tests/unit/amazon/aws/sensors/test_ecs.py b/providers/amazon/tests/unit/amazon/aws/sensors/test_ecs.py index 
8a116fd484b3f..119683407939a 100644 --- a/providers/amazon/tests/unit/amazon/aws/sensors/test_ecs.py +++ b/providers/amazon/tests/unit/amazon/aws/sensors/test_ecs.py @@ -35,12 +35,12 @@ EcsTaskStates, EcsTaskStateSensor, ) +from airflow.providers.amazon.version_compat import NOTSET try: from airflow.sdk import timezone except ImportError: from airflow.utils import timezone # type: ignore[attr-defined,no-redef] -from airflow.utils.types import NOTSET _Operator = TypeVar("_Operator") TEST_CLUSTER_NAME = "fake-cluster" diff --git a/providers/amazon/tests/unit/amazon/aws/utils/test_identifiers.py b/providers/amazon/tests/unit/amazon/aws/utils/test_identifiers.py index 40c1aba2420d1..4ac5ff28dc29e 100644 --- a/providers/amazon/tests/unit/amazon/aws/utils/test_identifiers.py +++ b/providers/amazon/tests/unit/amazon/aws/utils/test_identifiers.py @@ -23,7 +23,7 @@ import pytest from airflow.providers.amazon.aws.utils.identifiers import generate_uuid -from airflow.utils.types import NOTSET +from airflow.providers.amazon.version_compat import NOTSET class TestGenerateUuid: diff --git a/providers/edge3/src/airflow/providers/edge3/example_dags/win_test.py b/providers/edge3/src/airflow/providers/edge3/example_dags/win_test.py index 363fa5e9bac0b..7aa2822f219dc 100644 --- a/providers/edge3/src/airflow/providers/edge3/example_dags/win_test.py +++ b/providers/edge3/src/airflow/providers/edge3/example_dags/win_test.py @@ -32,42 +32,48 @@ from time import sleep from typing import TYPE_CHECKING, Any -try: - from airflow.sdk import task, task_group -except ImportError: - # Airflow 2 path - from airflow.decorators import task, task_group # type: ignore[attr-defined,no-redef] from airflow.exceptions import AirflowException, AirflowNotFoundException, AirflowSkipException from airflow.models import BaseOperator from airflow.models.dag import DAG from airflow.models.variable import Variable from airflow.providers.standard.operators.empty import EmptyOperator +from 
airflow.sdk.execution_time.context import context_to_airflow_vars +try: + from airflow.sdk import task, task_group +except ImportError: + from airflow.decorators import task, task_group # type: ignore[attr-defined,no-redef] try: from airflow.sdk import BaseHook except ImportError: from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] -from airflow.sdk import Param - +try: + from airflow.sdk import Param +except ImportError: + from airflow.models import Param # type: ignore[attr-defined,no-redef] try: from airflow.sdk import TriggerRule except ImportError: - # Compatibility for Airflow < 3.1 from airflow.utils.trigger_rule import TriggerRule # type: ignore[no-redef,attr-defined] -from airflow.sdk.execution_time.context import context_to_airflow_vars -from airflow.utils.types import ArgNotSet - -if TYPE_CHECKING: - try: - from airflow.sdk.types import RuntimeTaskInstanceProtocol as TaskInstance - except ImportError: - from airflow.models import TaskInstance # type: ignore[assignment] - from airflow.utils.context import Context - try: - from airflow.operators.python import PythonOperator + from airflow.providers.common.compat.standard.operators import PythonOperator +except ImportError: + from airflow.operators.python import PythonOperator # type: ignore[no-redef] +try: + from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet except ImportError: - from airflow.providers.common.compat.standard.operators import PythonOperator # type: ignore[no-redef] + from airflow.utils.types import NOTSET, ArgNotSet # type: ignore[attr-defined,no-redef] +try: + from airflow.sdk.definitions._internal.types import is_arg_set +except ImportError: + + def is_arg_set(value): # type: ignore[misc,no-redef] + return value is not NOTSET + + +if TYPE_CHECKING: + from airflow.sdk import Context + from airflow.sdk.types import RuntimeTaskInstanceProtocol as TaskInstance class CmdOperator(BaseOperator): @@ -163,7 +169,7 @@ def __init__( # When using the 
@task.command decorator, the command is not known until the underlying Python # callable is executed and therefore set to NOTSET initially. This flag is useful during execution to # determine whether the command value needs to re-rendered. - self._init_command_not_set = isinstance(self.command, ArgNotSet) + self._init_command_not_set = not is_arg_set(self.command) @staticmethod def refresh_command(ti: TaskInstance) -> None: diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/compute_ssh.py b/providers/google/src/airflow/providers/google/cloud/hooks/compute_ssh.py index a8e00ff453201..54f7aa93c6eec 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/compute_ssh.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/compute_ssh.py @@ -31,7 +31,11 @@ from airflow.providers.google.cloud.hooks.os_login import OSLoginHook from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID from airflow.providers.ssh.hooks.ssh import SSHHook -from airflow.utils.types import NOTSET, ArgNotSet + +try: + from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet +except ImportError: + from airflow.utils.types import NOTSET, ArgNotSet # type: ignore[attr-defined,no-redef] # Paramiko should be imported after airflow.providers.ssh. 
Then the import will fail with # cannot import "airflow.providers.ssh" and will be correctly discovered as optional feature diff --git a/providers/google/src/airflow/providers/google/cloud/log/stackdriver_task_handler.py b/providers/google/src/airflow/providers/google/cloud/log/stackdriver_task_handler.py index c9cc797206e63..2ef1ce2efef67 100644 --- a/providers/google/src/airflow/providers/google/cloud/log/stackdriver_task_handler.py +++ b/providers/google/src/airflow/providers/google/cloud/log/stackdriver_task_handler.py @@ -35,17 +35,20 @@ from airflow.providers.google.cloud.utils.credentials_provider import get_credentials_and_project_id from airflow.providers.google.common.consts import CLIENT_INFO from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS -from airflow.utils.types import NOTSET, ArgNotSet + +try: + from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet +except ImportError: + from airflow.utils.types import NOTSET, ArgNotSet # type: ignore[attr-defined,no-redef] + +if not AIRFLOW_V_3_0_PLUS: + from airflow.utils.log.trigger_handler import ctx_indiv_trigger if TYPE_CHECKING: from google.auth.credentials import Credentials from airflow.models import TaskInstance - -if not AIRFLOW_V_3_0_PLUS: - from airflow.utils.log.trigger_handler import ctx_indiv_trigger - DEFAULT_LOGGER_NAME = "airflow" _GLOBAL_RESOURCE = Resource(type="global", labels={}) diff --git a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py index b286a169adf17..f531dd504366b 100644 --- a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py +++ b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py @@ -31,10 +31,9 @@ from airflow.models import Connection from airflow.providers.postgres.dialects.postgres import PostgresDialect from airflow.providers.postgres.hooks.postgres import CompatConnection, PostgresHook -from airflow.utils.types import NOTSET from 
tests_common.test_utils.common_sql import mock_db_hook -from tests_common.test_utils.version_compat import SQLALCHEMY_V_1_4 +from tests_common.test_utils.version_compat import NOTSET, SQLALCHEMY_V_1_4 INSERT_SQL_STATEMENT = "INSERT INTO connection (id, conn_id, conn_type, description, host, {}, login, password, port, is_encrypted, is_extra_encrypted, extra) VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)" diff --git a/providers/slack/src/airflow/providers/slack/utils/__init__.py b/providers/slack/src/airflow/providers/slack/utils/__init__.py index 6c59b85c0531b..7c46594034fb2 100644 --- a/providers/slack/src/airflow/providers/slack/utils/__init__.py +++ b/providers/slack/src/airflow/providers/slack/utils/__init__.py @@ -20,7 +20,10 @@ from collections.abc import Sequence from typing import Any -from airflow.utils.types import NOTSET +try: + from airflow.sdk.definitions._internal.types import NOTSET +except ImportError: + from airflow.utils.types import NOTSET # type: ignore[attr-defined,no-redef] class ConnectionExtraConfig: diff --git a/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py b/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py index 55497d22d3eb6..b5860a931a535 100644 --- a/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py +++ b/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py @@ -36,7 +36,18 @@ from airflow.exceptions import AirflowException from airflow.providers.common.compat.sdk import BaseHook from airflow.utils.platform import getuser -from airflow.utils.types import NOTSET, ArgNotSet + +try: + from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet +except ImportError: + from airflow.utils.types import NOTSET, ArgNotSet # type: ignore[attr-defined,no-redef] +try: + from airflow.sdk.definitions._internal.types import is_arg_set +except ImportError: + + def is_arg_set(value): # type: ignore[misc,no-redef] + return value is not NOTSET + CMD_TIMEOUT = 10 @@ -438,9 +449,9 @@ def exec_ssh_client_command( self.log.info("Running 
command: %s", command) cmd_timeout: float | None - if not isinstance(timeout, ArgNotSet): + if is_arg_set(timeout): cmd_timeout = timeout - elif not isinstance(self.cmd_timeout, ArgNotSet): + elif is_arg_set(self.cmd_timeout): cmd_timeout = self.cmd_timeout else: cmd_timeout = CMD_TIMEOUT diff --git a/providers/ssh/src/airflow/providers/ssh/operators/ssh.py b/providers/ssh/src/airflow/providers/ssh/operators/ssh.py index f2f53132376dd..3aef97df0c410 100644 --- a/providers/ssh/src/airflow/providers/ssh/operators/ssh.py +++ b/providers/ssh/src/airflow/providers/ssh/operators/ssh.py @@ -26,7 +26,11 @@ from airflow.exceptions import AirflowException, AirflowSkipException from airflow.providers.common.compat.sdk import BaseOperator from airflow.providers.ssh.hooks.ssh import SSHHook -from airflow.utils.types import NOTSET, ArgNotSet + +try: + from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet +except ImportError: + from airflow.utils.types import NOTSET, ArgNotSet # type: ignore[attr-defined,no-redef] if TYPE_CHECKING: from paramiko.client import SSHClient diff --git a/providers/ssh/tests/unit/ssh/operators/test_ssh.py b/providers/ssh/tests/unit/ssh/operators/test_ssh.py index 1d94fe5768376..5747a738fc843 100644 --- a/providers/ssh/tests/unit/ssh/operators/test_ssh.py +++ b/providers/ssh/tests/unit/ssh/operators/test_ssh.py @@ -30,11 +30,10 @@ from airflow.providers.common.compat.sdk import timezone from airflow.providers.ssh.hooks.ssh import SSHHook from airflow.providers.ssh.operators.ssh import SSHOperator -from airflow.utils.types import NOTSET from tests_common.test_utils.config import conf_vars from tests_common.test_utils.dag import sync_dag_to_db -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, NOTSET datetime = timezone.datetime diff --git a/providers/standard/src/airflow/providers/standard/operators/bash.py 
b/providers/standard/src/airflow/providers/standard/operators/bash.py index 533b1dee5fbb1..7c5f0269a967e 100644 --- a/providers/standard/src/airflow/providers/standard/operators/bash.py +++ b/providers/standard/src/airflow/providers/standard/operators/bash.py @@ -31,7 +31,8 @@ if TYPE_CHECKING: from airflow.providers.common.compat.sdk import Context - from airflow.utils.types import ArgNotSet + + from airflow.providers.standard.version_compat import ArgNotSet class BashOperator(BaseOperator): diff --git a/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py b/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py index 3bec158c8b704..c0f8709fa87f9 100644 --- a/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py +++ b/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py @@ -42,7 +42,12 @@ from airflow.providers.standard.triggers.external_task import DagStateTrigger from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS, BaseOperator from airflow.utils.state import DagRunState -from airflow.utils.types import NOTSET, ArgNotSet, DagRunType +from airflow.utils.types import DagRunType + +try: + from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet +except ImportError: + from airflow.utils.types import NOTSET, ArgNotSet # type: ignore[attr-defined,no-redef] XCOM_LOGICAL_DATE_ISO = "trigger_logical_date_iso" XCOM_RUN_ID = "trigger_run_id" diff --git a/providers/standard/tests/unit/standard/decorators/test_bash.py b/providers/standard/tests/unit/standard/decorators/test_bash.py index e3828db3f11f2..d868182f3b79d 100644 --- a/providers/standard/tests/unit/standard/decorators/test_bash.py +++ b/providers/standard/tests/unit/standard/decorators/test_bash.py @@ -41,7 +41,9 @@ else: # bad hack but does the job from airflow.decorators import task # type: ignore[attr-defined,no-redef] - from airflow.utils.types import NOTSET as SET_DURING_EXECUTION #
type: ignore[assignment] + from airflow.utils.types import ( # type: ignore[attr-defined,no-redef] + NOTSET as SET_DURING_EXECUTION, # type: ignore[assignment] + ) if AIRFLOW_V_3_1_PLUS: from airflow.sdk import timezone else: diff --git a/providers/standard/tests/unit/standard/operators/test_python.py b/providers/standard/tests/unit/standard/operators/test_python.py index 9787b85a6c607..cbc2bca2958a0 100644 --- a/providers/standard/tests/unit/standard/operators/test_python.py +++ b/providers/standard/tests/unit/standard/operators/test_python.py @@ -63,10 +63,15 @@ from airflow.providers.standard.utils.python_virtualenv import _execute_in_subprocess, prepare_virtualenv from airflow.utils.session import create_session from airflow.utils.state import DagRunState, State, TaskInstanceState -from airflow.utils.types import NOTSET, DagRunType +from airflow.utils.types import DagRunType from tests_common.test_utils.db import clear_db_runs -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_1, AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS +from tests_common.test_utils.version_compat import ( + AIRFLOW_V_3_0_1, + AIRFLOW_V_3_0_PLUS, + AIRFLOW_V_3_1_PLUS, + NOTSET, +) if AIRFLOW_V_3_0_PLUS: from airflow.sdk import BaseOperator diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/types.py b/task-sdk/src/airflow/sdk/definitions/_internal/types.py index 8ae8ef1b1cba4..47270f823fc73 100644 --- a/task-sdk/src/airflow/sdk/definitions/_internal/types.py +++ b/task-sdk/src/airflow/sdk/definitions/_internal/types.py @@ -17,36 +17,37 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypeVar +if TYPE_CHECKING: + from typing_extensions import TypeIs -class ArgNotSet: - """ - Sentinel type for annotations, useful when None is not viable. 
- - Use like this:: + from airflow.sdk.definitions._internal.node import DAGNode - def is_arg_passed(arg: Union[ArgNotSet, None] = NOTSET) -> bool: - if arg is NOTSET: - return False - return True + T = TypeVar("T") +__all__ = [ + "NOTSET", + "SET_DURING_EXECUTION", + "ArgNotSet", + "SetDuringExecution", + "is_arg_set", + "validate_instance_args", +] - is_arg_passed() # False. - is_arg_passed(None) # True. - """ +try: + # If core and SDK exist together, use core to avoid identity issues. + from airflow.serialization.definitions.notset import NOTSET, ArgNotSet +except ModuleNotFoundError: - @staticmethod - def serialize(): - return "NOTSET" + class ArgNotSet: # type: ignore[no-redef] + """Sentinel type for annotations, useful when None is not viable.""" - @classmethod - def deserialize(cls): - return cls + NOTSET = ArgNotSet() # type: ignore[no-redef] -NOTSET = ArgNotSet() -"""Sentinel value for argument default. See ``ArgNotSet``.""" +def is_arg_set(value: T | ArgNotSet) -> TypeIs[T]: + return not isinstance(value, ArgNotSet) class SetDuringExecution(ArgNotSet): @@ -61,10 +62,6 @@ def serialize() -> str: """Sentinel value for argument default. 
See ``SetDuringExecution``.""" -if TYPE_CHECKING: - from airflow.sdk.definitions._internal.node import DAGNode - - def validate_instance_args(instance: DAGNode, expected_arg_types: dict[str, Any]) -> None: """Validate that the instance has the expected types for the arguments.""" from airflow.sdk.definitions.taskgroup import TaskGroup diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py index 4e26fce37e779..6fc32c188b891 100644 --- a/task-sdk/src/airflow/sdk/definitions/dag.py +++ b/task-sdk/src/airflow/sdk/definitions/dag.py @@ -46,7 +46,7 @@ from airflow.sdk import TaskInstanceState, TriggerRule from airflow.sdk.bases.operator import BaseOperator from airflow.sdk.definitions._internal.node import validate_key -from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet +from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet, is_arg_set from airflow.sdk.definitions.asset import AssetAll, BaseAsset from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.deadline import DeadlineAlert @@ -1197,7 +1197,7 @@ def test( self.validate() # Allow users to explicitly pass None. If it isn't set, we default to current time. 
- logical_date = logical_date if not isinstance(logical_date, ArgNotSet) else timezone.utcnow() + logical_date = logical_date if is_arg_set(logical_date) else timezone.utcnow() log.debug("Clearing existing task instances for logical date %s", logical_date) # TODO: Replace with calling client.dag_run.clear in Execution API at some point diff --git a/task-sdk/src/airflow/sdk/definitions/param.py b/task-sdk/src/airflow/sdk/definitions/param.py index 2c853ce1ffc90..5c0136d7de5e6 100644 --- a/task-sdk/src/airflow/sdk/definitions/param.py +++ b/task-sdk/src/airflow/sdk/definitions/param.py @@ -25,7 +25,7 @@ from airflow.exceptions import AirflowException, ParamValidationError from airflow.sdk.definitions._internal.mixins import ResolveMixin -from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet +from airflow.sdk.definitions._internal.types import NOTSET, is_arg_set if TYPE_CHECKING: from airflow.sdk.definitions.context import Context @@ -90,7 +90,7 @@ def resolve(self, value: Any = NOTSET, suppress_exception: bool = False) -> Any: if value is not NOTSET: self._check_json(value) final_val = self.value if value is NOTSET else value - if isinstance(final_val, ArgNotSet): + if not is_arg_set(final_val): if suppress_exception: return None raise ParamValidationError("No value passed and Param has no default value") diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index fa540333cf317..a674a47e19bde 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -31,7 +31,7 @@ from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator from airflow.sdk.definitions._internal.mixins import DependencyMixin, ResolveMixin from airflow.sdk.definitions._internal.setup_teardown import SetupTeardownContext -from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet +from airflow.sdk.definitions._internal.types import 
NOTSET, is_arg_set from airflow.sdk.execution_time.lazy_sequence import LazyXComSequence from airflow.sdk.execution_time.xcom import BaseXCom @@ -347,7 +347,7 @@ def resolve(self, context: Mapping[str, Any]) -> Any: default=NOTSET, map_indexes=map_indexes, ) - if not isinstance(result, ArgNotSet): + if is_arg_set(result): return result if self.key == BaseXCom.XCOM_RETURN_KEY: return None @@ -452,9 +452,9 @@ def __getitem__(self, index: Any) -> Any: def __len__(self) -> int: lengths = (len(v) for v in self.values) - if isinstance(self.fillvalue, ArgNotSet): - return min(lengths) - return max(lengths) + if is_arg_set(self.fillvalue): + return max(lengths) + return min(lengths) @attrs.define @@ -474,15 +474,15 @@ def __repr__(self) -> str: args_iter = iter(self.args) first = repr(next(args_iter)) rest = ", ".join(repr(arg) for arg in args_iter) - if isinstance(self.fillvalue, ArgNotSet): - return f"{first}.zip({rest})" - return f"{first}.zip({rest}, fillvalue={self.fillvalue!r})" + if is_arg_set(self.fillvalue): + return f"{first}.zip({rest}, fillvalue={self.fillvalue!r})" + return f"{first}.zip({rest})" def _serialize(self) -> dict[str, Any]: args = [serialize_xcom_arg(arg) for arg in self.args] - if isinstance(self.fillvalue, ArgNotSet): - return {"args": args} - return {"args": args, "fillvalue": self.fillvalue} + if is_arg_set(self.fillvalue): + return {"args": args, "fillvalue": self.fillvalue} + return {"args": args} def iter_references(self) -> Iterator[tuple[Operator, str]]: for arg in self.args: @@ -602,9 +602,9 @@ def _(xcom_arg: ZipXComArg, resolved_val: Sized, upstream_map_indexes: dict[str, ready_lengths = [length for length in all_lengths if length is not None] if len(ready_lengths) != len(xcom_arg.args): return None # If any of the referenced XComs is not ready, we are not ready either. 
- if isinstance(xcom_arg.fillvalue, ArgNotSet): - return min(ready_lengths) - return max(ready_lengths) + if is_arg_set(xcom_arg.fillvalue): + return max(ready_lengths) + return min(ready_lengths) @get_task_map_length.register diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 26c1e5ee762b2..74d085f245a8e 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -54,7 +54,7 @@ from airflow.sdk.bases.operator import BaseOperator, ExecutorSafeguard from airflow.sdk.bases.xcom import BaseXCom from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager -from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet +from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet, is_arg_set from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sdk.definitions.param import process_params @@ -360,7 +360,7 @@ def xcom_pull( task_ids = [task_ids] # If map_indexes is not specified, pull xcoms from all map indexes for each task - if isinstance(map_indexes, ArgNotSet): + if not is_arg_set(map_indexes): xcoms: list[Any] = [] for t_id in task_ids: values = XCom.get_all( diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index cb2687a967693..56ee821dd2f89 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -66,7 +66,7 @@ TIRunContext, ) from airflow.sdk.bases.xcom import BaseXCom -from airflow.sdk.definitions._internal.types import NOTSET, SET_DURING_EXECUTION, ArgNotSet +from airflow.sdk.definitions._internal.types import NOTSET, SET_DURING_EXECUTION, is_arg_set from 
airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey, Dataset, Model from airflow.sdk.definitions.param import DagParam from airflow.sdk.exceptions import ErrorType @@ -1570,9 +1570,7 @@ def mock_send_side_effect(*args, **kwargs): for task_id_raw in task_ids: # Without task_ids (or None) expected behavior is to pull with calling task_id - task_id = ( - test_task_id if task_id_raw is None or isinstance(task_id_raw, ArgNotSet) else task_id_raw - ) + task_id = task_id_raw if is_arg_set(task_id_raw) and task_id_raw is not None else test_task_id for map_index in map_indexes: if map_index == NOTSET: