21 changes: 9 additions & 12 deletions airflow-core/src/airflow/dag_processing/dagbag.py
@@ -49,6 +49,7 @@
)
from airflow.executors.executor_loader import ExecutorLoader
from airflow.listeners.listener import get_listener_manager
from airflow.serialization.definitions.notset import NOTSET, ArgNotSet, is_arg_set
from airflow.serialization.serialized_objects import LazyDeserializedDAG
from airflow.utils.docs import get_docs_url
from airflow.utils.file import (
@@ -59,7 +60,6 @@
)
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.types import NOTSET

if TYPE_CHECKING:
from collections.abc import Generator
@@ -68,7 +68,6 @@

from airflow import DAG
from airflow.models.dagwarning import DagWarning
from airflow.utils.types import ArgNotSet


@contextlib.contextmanager
@@ -231,14 +230,6 @@ def __init__(
super().__init__()
self.bundle_path = bundle_path
self.bundle_name = bundle_name
include_examples = (
include_examples
if isinstance(include_examples, bool)
else conf.getboolean("core", "LOAD_EXAMPLES")
)
safe_mode = (
safe_mode if isinstance(safe_mode, bool) else conf.getboolean("core", "DAG_DISCOVERY_SAFE_MODE")
)

dag_folder = dag_folder or settings.DAGS_FOLDER
self.dag_folder = dag_folder
@@ -259,8 +250,14 @@ def __init__(
if collect_dags:
self.collect_dags(
dag_folder=dag_folder,
include_examples=include_examples,
safe_mode=safe_mode,
include_examples=(
include_examples
if is_arg_set(include_examples)
else conf.getboolean("core", "LOAD_EXAMPLES")
),
safe_mode=(
safe_mode if is_arg_set(safe_mode) else conf.getboolean("core", "DAG_DISCOVERY_SAFE_MODE")
),
)
# Should the extra operator link be loaded via plugins?
# This flag is set to False in Scheduler so that Extra Operator links are not loaded
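Note: the constructor now passes `include_examples` and `safe_mode` through unresolved and only falls back to `conf.getboolean(...)` at the `collect_dags` call, using `is_arg_set` instead of the old `isinstance(..., bool)` check. A minimal standalone sketch of that pattern (the function and fallback below are illustrative, not the Airflow API):

```python
from __future__ import annotations

from airflow.serialization.definitions.notset import NOTSET, ArgNotSet, is_arg_set


def read_config_default() -> bool:
    # Stand-in for conf.getboolean("core", "LOAD_EXAMPLES").
    return True


def collect(include_examples: bool | ArgNotSet = NOTSET) -> bool:
    # Resolve the flag only at the point it is used, so callers that never
    # collect DAGs never have to touch the configuration at all.
    return include_examples if is_arg_set(include_examples) else read_config_default()
```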
6 changes: 3 additions & 3 deletions airflow-core/src/airflow/models/dagrun.py
@@ -69,6 +69,7 @@
from airflow.models.tasklog import LogTemplate
from airflow.models.taskmap import TaskMap
from airflow.sdk.definitions.deadline import DeadlineReference
from airflow.serialization.definitions.notset import NOTSET, ArgNotSet, is_arg_set
from airflow.stats import Stats
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_states import SCHEDULEABLE_STATES
@@ -90,7 +91,7 @@
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.strings import get_random_string
from airflow.utils.thread_safe_dict import ThreadSafeDict
from airflow.utils.types import NOTSET, DagRunTriggeredByType, DagRunType
from airflow.utils.types import DagRunTriggeredByType, DagRunType

if TYPE_CHECKING:
from typing import Literal, TypeAlias
@@ -105,7 +106,6 @@
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.sdk import DAG as SDKDAG
from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG
from airflow.utils.types import ArgNotSet

CreatedTasks = TypeVar("CreatedTasks", Iterator["dict[str, Any]"], Iterator[TI])
AttributeValueType: TypeAlias = (
@@ -348,7 +348,7 @@ def __init__(
self.conf = conf or {}
if state is not None:
self.state = state
if queued_at is NOTSET:
if not is_arg_set(queued_at):
self.queued_at = timezone.utcnow() if state == DagRunState.QUEUED else None
elif queued_at is not None:
self.queued_at = queued_at
8 changes: 2 additions & 6 deletions airflow-core/src/airflow/models/variable.py
@@ -156,13 +156,9 @@ def get(
stacklevel=1,
)
from airflow.sdk import Variable as TaskSDKVariable
from airflow.sdk.definitions._internal.types import NOTSET

var_val = TaskSDKVariable.get(
key,
default=NOTSET if default_var is cls.__NO_DEFAULT_SENTINEL else default_var,
deserialize_json=deserialize_json,
)
default_kwargs = {} if default_var is cls.__NO_DEFAULT_SENTINEL else {"default": default_var}
var_val = TaskSDKVariable.get(key, deserialize_json=deserialize_json, **default_kwargs)
if isinstance(var_val, str):
mask_secret(var_val, key)

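Note: instead of importing the Task SDK's private `NOTSET` just to mean "no default was given", the call now builds a keyword dict that contains `default` only when the caller supplied one. A generic, self-contained sketch of that forwarding pattern (all names below are made up for illustration):

```python
from __future__ import annotations

from typing import Any


def lookup(store: dict[str, Any], key: str, **kwargs: Any) -> Any:
    # Downstream API: raises KeyError unless an explicit default is passed.
    if key in store:
        return store[key]
    if "default" in kwargs:
        return kwargs["default"]
    raise KeyError(key)


def get_var(store: dict[str, Any], key: str, *default: Any) -> Any:
    # Forward `default` only when one was actually provided, preserving the
    # downstream "no default" behaviour without sharing a sentinel object.
    default_kwargs = {"default": default[0]} if default else {}
    return lookup(store, key, **default_kwargs)
```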
17 changes: 8 additions & 9 deletions airflow-core/src/airflow/models/xcom_arg.py
@@ -27,11 +27,10 @@

from airflow.models.referencemixin import ReferenceMixin
from airflow.models.xcom import XCOM_RETURN_KEY
from airflow.sdk.definitions._internal.types import ArgNotSet
from airflow.sdk.definitions.xcom_arg import XComArg
from airflow.serialization.definitions.notset import NOTSET, is_arg_set
from airflow.utils.db import exists_query
from airflow.utils.state import State
from airflow.utils.types import NOTSET

__all__ = ["XComArg", "get_task_map_length"]

@@ -150,7 +149,7 @@ def get_task_map_length(xcom_arg: SchedulerXComArg, run_id: str, *, session: Ses


@get_task_map_length.register
def _(xcom_arg: SchedulerPlainXComArg, run_id: str, *, session: Session):
def _(xcom_arg: SchedulerPlainXComArg, run_id: str, *, session: Session) -> int | None:
from airflow.models.mappedoperator import is_mapped
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskmap import TaskMap
@@ -193,23 +192,23 @@ def _(xcom_arg: SchedulerPlainXComArg, run_id: str, *, session: Session):


@get_task_map_length.register
def _(xcom_arg: SchedulerMapXComArg, run_id: str, *, session: Session):
def _(xcom_arg: SchedulerMapXComArg, run_id: str, *, session: Session) -> int | None:
return get_task_map_length(xcom_arg.arg, run_id, session=session)


@get_task_map_length.register
def _(xcom_arg: SchedulerZipXComArg, run_id: str, *, session: Session):
def _(xcom_arg: SchedulerZipXComArg, run_id: str, *, session: Session) -> int | None:
all_lengths = (get_task_map_length(arg, run_id, session=session) for arg in xcom_arg.args)
ready_lengths = [length for length in all_lengths if length is not None]
if len(ready_lengths) != len(xcom_arg.args):
return None # If any of the referenced XComs is not ready, we are not ready either.
if isinstance(xcom_arg.fillvalue, ArgNotSet):
return min(ready_lengths)
return max(ready_lengths)
if is_arg_set(xcom_arg.fillvalue):
return max(ready_lengths)
return min(ready_lengths)


@get_task_map_length.register
def _(xcom_arg: SchedulerConcatXComArg, run_id: str, *, session: Session):
def _(xcom_arg: SchedulerConcatXComArg, run_id: str, *, session: Session) -> int | None:
all_lengths = (get_task_map_length(arg, run_id, session=session) for arg in xcom_arg.args)
ready_lengths = [length for length in all_lengths if length is not None]
if len(ready_lengths) != len(xcom_arg.args):
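Note on the zip branch: the inverted condition keeps the original semantics — with no fillvalue the mapped length is bounded by the shortest input (like built-in `zip`), while with a fillvalue it extends to the longest (like `itertools.zip_longest`). A quick standalone illustration:

```python
import itertools

a, b = range(3), range(5)

# Without a fillvalue, zipping stops at the shortest argument.
assert len(list(zip(a, b))) == min(3, 5) == 3

# With a fillvalue, shorter arguments are padded out to the longest.
assert len(list(itertools.zip_longest(a, b, fillvalue=None))) == max(3, 5) == 5
```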
22 changes: 19 additions & 3 deletions airflow-core/src/airflow/serialization/definitions/notset.py
@@ -18,7 +18,23 @@

from __future__ import annotations

from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet
from typing import TYPE_CHECKING, TypeVar

# TODO (GH-52141): Have different NOTSET and ArgNotSet in the scheduler.
__all__ = ["NOTSET", "ArgNotSet"]
if TYPE_CHECKING:
from typing_extensions import TypeIs

T = TypeVar("T")

__all__ = ["NOTSET", "ArgNotSet", "is_arg_set"]


class ArgNotSet:
"""Sentinel type for annotations, useful when None is not viable."""


NOTSET = ArgNotSet()
"""Sentinel value for argument default. See ``ArgNotSet``."""


def is_arg_set(value: T | ArgNotSet) -> TypeIs[T]:
return not isinstance(value, ArgNotSet)
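The `TypeIs` annotation lets type checkers narrow `T | ArgNotSet` to `T` on the true branch, which is what makes the `is_arg_set(...)` call sites elsewhere in this PR type-safe. A minimal usage sketch (illustrative; `resolve_timeout` and its fallback are not real Airflow code):

```python
from __future__ import annotations

from airflow.serialization.definitions.notset import NOTSET, ArgNotSet, is_arg_set


def resolve_timeout(timeout: float | None | ArgNotSet = NOTSET) -> float | None:
    # NOTSET means "the caller said nothing"; None stays a meaningful value.
    if is_arg_set(timeout):
        return timeout  # narrowed to float | None by TypeIs
    return 300.0  # hypothetical fallback default
```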
4 changes: 2 additions & 2 deletions airflow-core/src/airflow/serialization/definitions/param.py
@@ -22,7 +22,7 @@
import copy
from typing import TYPE_CHECKING, Any

from airflow.serialization.definitions.notset import NOTSET, ArgNotSet
from airflow.serialization.definitions.notset import NOTSET, is_arg_set

if TYPE_CHECKING:
from collections.abc import Iterator, Mapping
@@ -51,7 +51,7 @@ def resolve(self, *, raises: bool = False) -> Any:
import jsonschema

try:
if isinstance(value := self.value, ArgNotSet):
if not is_arg_set(value := self.value):
raise ValueError("No value passed")
jsonschema.validate(value, self.schema, format_checker=jsonschema.FormatChecker())
except Exception:
12 changes: 8 additions & 4 deletions airflow-core/src/airflow/serialization/serialized_objects.py
@@ -122,7 +122,7 @@
from airflow.utils.module_loading import import_string, qualname
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.utils.state import DagRunState, TaskInstanceState
from airflow.utils.types import NOTSET, ArgNotSet, DagRunTriggeredByType, DagRunType
from airflow.utils.types import DagRunTriggeredByType, DagRunType

if TYPE_CHECKING:
from inspect import Parameter
@@ -736,7 +736,11 @@ def serialize(

:meta private:
"""
if cls._is_primitive(var):
from airflow.sdk.definitions._internal.types import is_arg_set

if not is_arg_set(var):
return cls._encode(None, type_=DAT.ARG_NOT_SET)
elif cls._is_primitive(var):
# enum.IntEnum is an int instance, it causes json dumps error so we use its value.
if isinstance(var, enum.Enum):
return var.value
@@ -867,8 +871,6 @@ def serialize(
obj = cls.serialize(v, strict=strict)
d[str(k)] = obj
return cls._encode(d, type_=DAT.TASK_CONTEXT)
elif isinstance(var, ArgNotSet):
return cls._encode(None, type_=DAT.ARG_NOT_SET)
else:
return cls.default_serialization(strict, var)

@@ -981,6 +983,8 @@ def deserialize(cls, encoded_var: Any) -> Any:
elif type_ == DAT.TASK_INSTANCE_KEY:
return TaskInstanceKey(**var)
elif type_ == DAT.ARG_NOT_SET:
from airflow.serialization.definitions.notset import NOTSET

return NOTSET
elif type_ == DAT.DEADLINE_ALERT:
return DeadlineAlert.deserialize_deadline_alert(var)
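The `ARG_NOT_SET` branch now runs first in `serialize` using the Task SDK's `is_arg_set`, and `deserialize` hands back the scheduler-side `NOTSET` singleton, so the sentinel survives a round trip even though the two sides use different classes. A simplified sketch of that round trip (the tag names below are illustrative, not the real encoding):

```python
from airflow.sdk.definitions._internal.types import is_arg_set
from airflow.serialization.definitions.notset import NOTSET as SCHEDULER_NOTSET


def encode(value):
    # Anything the DAG author never set becomes a tagged null.
    if not is_arg_set(value):
        return {"__type": "arg_not_set", "__var": None}
    return {"__type": "plain", "__var": value}


def decode(payload):
    # The scheduler gets back its own sentinel instance, not the SDK's.
    if payload["__type"] == "arg_not_set":
        return SCHEDULER_NOTSET
    return payload["__var"]
```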
12 changes: 5 additions & 7 deletions airflow-core/src/airflow/utils/context.py
@@ -30,17 +30,15 @@

from sqlalchemy import select

from airflow.models.asset import (
AssetModel,
)
from airflow.models.asset import AssetModel
from airflow.sdk.definitions.context import Context
from airflow.sdk.execution_time.context import (
ConnectionAccessor as ConnectionAccessorSDK,
OutletEventAccessors as OutletEventAccessorsSDK,
VariableAccessor as VariableAccessorSDK,
)
from airflow.serialization.definitions.notset import NOTSET, is_arg_set
from airflow.utils.session import create_session
from airflow.utils.types import NOTSET

if TYPE_CHECKING:
from airflow.sdk.definitions.asset import Asset
@@ -100,9 +98,9 @@ def __getattr__(self, key: str) -> Any:
def get(self, key, default: Any = NOTSET) -> Any:
from airflow.models.variable import Variable

if default is NOTSET:
return Variable.get(key, deserialize_json=self._deserialize_json)
return Variable.get(key, default, deserialize_json=self._deserialize_json)
if is_arg_set(default):
return Variable.get(key, default, deserialize_json=self._deserialize_json)
return Variable.get(key, deserialize_json=self._deserialize_json)


class ConnectionAccessor(ConnectionAccessorSDK):
10 changes: 2 additions & 8 deletions airflow-core/src/airflow/utils/helpers.py
@@ -30,7 +30,7 @@

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.utils.types import NOTSET
from airflow.serialization.definitions.notset import is_arg_set

if TYPE_CHECKING:
from datetime import datetime
@@ -283,13 +283,7 @@ def at_most_one(*args) -> bool:

If user supplies an iterable, we raise ValueError and force them to unpack.
"""

def is_set(val):
if val is NOTSET:
return False
return bool(val)

return sum(map(is_set, args)) in (0, 1)
return sum(is_arg_set(a) and bool(a) for a in args) in (0, 1)


def prune_dict(val: Any, mode="strict"):
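With `is_arg_set`, unset arguments drop out before truthiness is counted, so `NOTSET` behaves like a falsy value here. A quick illustration of the intended behaviour:

```python
from airflow.serialization.definitions.notset import NOTSET
from airflow.utils.helpers import at_most_one

assert at_most_one(NOTSET, None, False)       # nothing truthy supplied
assert at_most_one("only-one", NOTSET, None)  # exactly one truthy value
assert not at_most_one("a", "b", NOTSET)      # two truthy values -> False
```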
22 changes: 13 additions & 9 deletions airflow-core/src/airflow/utils/types.py
@@ -14,19 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import enum
from typing import TYPE_CHECKING

import airflow.sdk.definitions._internal.types

if TYPE_CHECKING:
from typing import TypeAlias

ArgNotSet: TypeAlias = airflow.sdk.definitions._internal.types.ArgNotSet

NOTSET = airflow.sdk.definitions._internal.types.NOTSET
from airflow.utils.deprecation_tools import add_deprecated_classes


class DagRunType(str, enum.Enum):
@@ -68,3 +61,14 @@ class DagRunTriggeredByType(enum.Enum):
TIMETABLE = "timetable" # for timetable based triggering
ASSET = "asset" # for asset_triggered run type
BACKFILL = "backfill"


add_deprecated_classes(
{
__name__: {
"ArgNotSet": "airflow.serialization.definitions.notset.ArgNotSet",
"NOTSET": "airflow.serialization.definitions.notset.ArgNotSet",
},
},
package=__name__,
)
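`add_deprecated_classes` keeps `from airflow.utils.types import NOTSET` (and `ArgNotSet`) importable with a deprecation warning rather than re-exporting the SDK types directly. A rough sketch of the module-level `__getattr__` mechanism such a shim typically relies on — an assumption about how the helper works, not its actual implementation:

```python
from __future__ import annotations

import importlib
import warnings

_MOVED_ATTRIBUTES = {
    "ArgNotSet": "airflow.serialization.definitions.notset.ArgNotSet",
}


def __getattr__(name: str):
    # PEP 562: only invoked when normal module attribute lookup fails.
    if name in _MOVED_ATTRIBUTES:
        target = _MOVED_ATTRIBUTES[name]
        module_path, _, attr = target.rpartition(".")
        warnings.warn(
            f"{__name__}.{name} is deprecated; import {target} instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return getattr(importlib.import_module(module_path), attr)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```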
2 changes: 1 addition & 1 deletion airflow-core/tests/unit/models/test_xcom_arg.py
@@ -21,7 +21,7 @@
from airflow.models.xcom_arg import XComArg
from airflow.providers.standard.operators.bash import BashOperator
from airflow.providers.standard.operators.python import PythonOperator
from airflow.utils.types import NOTSET
from airflow.serialization.definitions.notset import NOTSET

from tests_common.test_utils.db import clear_db_dags, clear_db_runs

Expand Down
2 changes: 1 addition & 1 deletion airflow-core/tests/unit/utils/test_helpers.py
@@ -26,6 +26,7 @@
from airflow._shared.timezones import timezone
from airflow.exceptions import AirflowException
from airflow.jobs.base_job_runner import BaseJobRunner
from airflow.serialization.definitions.notset import NOTSET
from airflow.utils import helpers
from airflow.utils.helpers import (
at_most_one,
@@ -35,7 +36,6 @@
prune_dict,
validate_key,
)
from airflow.utils.types import NOTSET

from tests_common.test_utils.db import clear_db_dags, clear_db_runs

Expand Down