diff --git a/airflow-core/.pre-commit-config.yaml b/airflow-core/.pre-commit-config.yaml index 9fa00dcf37077..52b7d39875b80 100644 --- a/airflow-core/.pre-commit-config.yaml +++ b/airflow-core/.pre-commit-config.yaml @@ -375,7 +375,9 @@ repos: ^src/airflow/plugins_manager\.py$| ^src/airflow/providers_manager\.py$| ^src/airflow/secrets/__init__.py$| - ^src/airflow/serialization/definitions/[_a-z]+\.py$| + ^src/airflow/serialization/decoders\.py$| + ^src/airflow/serialization/definitions/[_/a-z]+\.py$| + ^src/airflow/serialization/encoders\.py$| ^src/airflow/serialization/enums\.py$| ^src/airflow/serialization/helpers\.py$| ^src/airflow/serialization/serialized_objects\.py$| diff --git a/airflow-core/src/airflow/cli/commands/dag_command.py b/airflow-core/src/airflow/cli/commands/dag_command.py index 52f1e1428120c..32fce4553d5e3 100644 --- a/airflow-core/src/airflow/cli/commands/dag_command.py +++ b/airflow-core/src/airflow/cli/commands/dag_command.py @@ -237,16 +237,18 @@ def _save_dot_to_file(dot: Dot, filename: str) -> None: print(f"File {filename} saved") -def _get_dagbag_dag_details(dag: DAG, session: Session) -> dict: +def _get_dagbag_dag_details(dag: DAG) -> dict: """Return a dagbag dag details dict.""" - dag_model: DagModel | None = session.get(DagModel, dag.dag_id) + from airflow.serialization.encoders import coerce_to_core_timetable + + core_timetable = coerce_to_core_timetable(dag.timetable) return { "dag_id": dag.dag_id, "dag_display_name": dag.dag_display_name, - "bundle_name": dag_model.bundle_name if dag_model else None, - "bundle_version": dag_model.bundle_version if dag_model else None, - "is_paused": dag_model.is_paused if dag_model else None, - "is_stale": dag_model.is_stale if dag_model else None, + "bundle_name": None, + "bundle_version": None, + "is_paused": None, + "is_stale": None, "last_parsed_time": None, "last_parse_duration": None, "last_expired": None, @@ -255,8 +257,8 @@ def _get_dagbag_dag_details(dag: DAG, session: Session) -> dict: "file_token": None, "owners": dag.owner, "description": dag.description, - "timetable_summary": dag.timetable.summary, - "timetable_description": dag.timetable.description, + "timetable_summary": core_timetable.summary, + "timetable_description": core_timetable.description, "tags": dag.tags, "max_active_tasks": dag.max_active_tasks, "max_active_runs": dag.max_active_runs, @@ -401,11 +403,10 @@ def dag_list_dags(args, session: Session = NEW_SESSION) -> None: ) def get_dag_detail(dag: DAG) -> dict: - dag_model = DagModel.get_dagmodel(dag.dag_id, session=session) - if dag_model: + if dag_model := DagModel.get_dagmodel(dag.dag_id, session=session): dag_detail = DAGResponse.model_validate(dag_model, from_attributes=True).model_dump() else: - dag_detail = _get_dagbag_dag_details(dag, session) + dag_detail = _get_dagbag_dag_details(dag) if not cols: return dag_detail return {col: dag_detail[col] for col in cols if col in DAG_DETAIL_FIELDS} diff --git a/airflow-core/src/airflow/dag_processing/collection.py b/airflow-core/src/airflow/dag_processing/collection.py index 942ea15ab9e61..0caa4b4cb2b48 100644 --- a/airflow-core/src/airflow/dag_processing/collection.py +++ b/airflow-core/src/airflow/dag_processing/collection.py @@ -934,7 +934,7 @@ def add_task_asset_references( def add_asset_trigger_references( self, assets: dict[tuple[str, str], AssetModel], *, session: Session ) -> None: - from airflow.serialization.serialized_objects import _encode_trigger + from airflow.serialization.encoders import encode_trigger # Update references from 
assets being used refs_to_add: dict[tuple[str, str], set[int]] = {} @@ -948,7 +948,7 @@ def add_asset_trigger_references( # If the asset belong to a DAG not active or paused, consider there is no watcher associated to it asset_watcher_triggers = ( [ - {**_encode_trigger(watcher.trigger), "watcher_name": watcher.name} + {**encode_trigger(watcher.trigger), "watcher_name": watcher.name} for watcher in asset.watchers ] if name_uri in active_assets diff --git a/airflow-core/src/airflow/example_dags/example_assets.py b/airflow-core/src/airflow/example_dags/example_assets.py index 3ab372112585c..ed1e61c489158 100644 --- a/airflow-core/src/airflow/example_dags/example_assets.py +++ b/airflow-core/src/airflow/example_dags/example_assets.py @@ -56,9 +56,7 @@ import pendulum from airflow.providers.standard.operators.bash import BashOperator -from airflow.sdk import DAG, Asset -from airflow.timetables.assets import AssetOrTimeSchedule -from airflow.timetables.trigger import CronTriggerTimetable +from airflow.sdk import DAG, Asset, AssetOrTimeSchedule, CronTriggerTimetable dag1_asset = Asset("s3://dag1/output_1.txt", extra={"hi": "bye"}) dag2_asset = Asset("s3://dag2/output_1.txt", extra={"hi": "bye"}) diff --git a/airflow-core/src/airflow/exceptions.py b/airflow-core/src/airflow/exceptions.py index 3790941229101..ea45e8942ca37 100644 --- a/airflow-core/src/airflow/exceptions.py +++ b/airflow-core/src/airflow/exceptions.py @@ -34,8 +34,9 @@ from airflow.sdk.exceptions import ( AirflowException, AirflowNotFoundException, - AirflowRescheduleException, - TaskNotFound, + AirflowRescheduleException as AirflowRescheduleException, + AirflowTimetableInvalid as AirflowTimetableInvalid, + TaskNotFound as TaskNotFound, ) except ModuleNotFoundError: # When _AIRFLOW__AS_LIBRARY is set, airflow.sdk may not be installed. @@ -43,18 +44,15 @@ class AirflowException(Exception): # type: ignore[no-redef] """Base exception for Airflow errors.""" - pass - class AirflowNotFoundException(AirflowException): # type: ignore[no-redef] """Raise when a requested object is not found.""" - pass + class AirflowTimetableInvalid(AirflowException): # type: ignore[no-redef] + """Raise when a DAG has an invalid timetable.""" class TaskNotFound(AirflowException): # type: ignore[no-redef] """Raise when a Task is not available in the system.""" - pass - class AirflowRescheduleException(AirflowException): # type: ignore[no-redef] """ Raise when the task should be re-scheduled at a later time. 
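A note on the exceptions.py hunk above: the redundant-alias form ("from ... import X as X") is the conventional way to mark a name as an intentionally public re-export, so strict type checkers (for example mypy with implicit re-exports disabled) keep treating these SDK exceptions as part of airflow.exceptions, while the except branch keeps the module importable when airflow.sdk is absent (the _AIRFLOW__AS_LIBRARY case noted in the comment). A simplified sketch of the pattern, illustrative and not copied verbatim from the patch:

    # Illustrative sketch of the re-export-with-fallback idiom used above.
    try:
        # "as AirflowTimetableInvalid" marks the name as an explicit re-export.
        from airflow.sdk.exceptions import AirflowTimetableInvalid as AirflowTimetableInvalid
    except ModuleNotFoundError:
        class AirflowTimetableInvalid(Exception):  # minimal stand-in when airflow.sdk is not installed
            """Raise when a DAG has an invalid timetable."""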
@@ -120,10 +118,6 @@ class AirflowClusterPolicyError(AirflowException): """Raise for a Cluster Policy other than AirflowClusterPolicyViolation or AirflowClusterPolicySkipDag.""" -class AirflowTimetableInvalid(AirflowException): - """Raise when a DAG has an invalid timetable.""" - - class DagNotFound(AirflowNotFoundException): """Raise when a DAG is not available in the system.""" @@ -308,23 +302,23 @@ class AirflowClearRunningTaskException(AirflowException): _DEPRECATED_EXCEPTIONS = { - "AirflowTaskTerminated": "airflow.sdk.exceptions.AirflowTaskTerminated", - "DuplicateTaskIdFound": "airflow.sdk.exceptions.DuplicateTaskIdFound", - "FailFastDagInvalidTriggerRule": "airflow.sdk.exceptions.FailFastDagInvalidTriggerRule", - "TaskAlreadyInTaskGroup": "airflow.sdk.exceptions.TaskAlreadyInTaskGroup", - "TaskDeferralTimeout": "airflow.sdk.exceptions.TaskDeferralTimeout", - "XComNotFound": "airflow.sdk.exceptions.XComNotFound", - "DownstreamTasksSkipped": "airflow.sdk.exceptions.DownstreamTasksSkipped", - "AirflowSensorTimeout": "airflow.sdk.exceptions.AirflowSensorTimeout", - "DagRunTriggerException": "airflow.sdk.exceptions.DagRunTriggerException", - "TaskDeferralError": "airflow.sdk.exceptions.TaskDeferralError", - "AirflowDagCycleException": "airflow.sdk.exceptions.AirflowDagCycleException", - "AirflowInactiveAssetInInletOrOutletException": "airflow.sdk.exceptions.AirflowInactiveAssetInInletOrOutletException", - "AirflowSkipException": "airflow.sdk.exceptions.AirflowSkipException", - "AirflowTaskTimeout": "airflow.sdk.exceptions.AirflowTaskTimeout", - "AirflowFailException": "airflow.sdk.exceptions.AirflowFailException", - "ParamValidationError": "airflow.sdk.exceptions.ParamValidationError", - "TaskDeferred": "airflow.sdk.exceptions.TaskDeferred", + "AirflowDagCycleException", + "AirflowFailException", + "AirflowInactiveAssetInInletOrOutletException", + "AirflowSensorTimeout", + "AirflowSkipException", + "AirflowTaskTerminated", + "AirflowTaskTimeout", + "DagRunTriggerException", + "DownstreamTasksSkipped", + "DuplicateTaskIdFound", + "FailFastDagInvalidTriggerRule", + "ParamValidationError", + "TaskAlreadyInTaskGroup", + "TaskDeferralError", + "TaskDeferralTimeout", + "TaskDeferred", + "XComNotFound", } @@ -336,7 +330,7 @@ def __getattr__(name: str): from airflow import DeprecatedImportWarning from airflow.utils.module_loading import import_string - target_path = _DEPRECATED_EXCEPTIONS[name] + target_path = f"airflow.sdk.exceptions.{name}" warnings.warn( f"airflow.exceptions.{name} is deprecated and will be removed in a future version. 
Use {target_path} instead.", DeprecatedImportWarning, diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index 5833593ee820e..b56a58ca6ccd9 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -937,7 +937,7 @@ async def create_triggers(self): await asyncio.sleep(0) try: - from airflow.serialization.serialized_objects import smart_decode_trigger_kwargs + from airflow.serialization.decoders import smart_decode_trigger_kwargs # Decrypt and clean trigger kwargs before for execution # Note: We only clean up serialization artifacts (__var, __type keys) here, diff --git a/airflow-core/src/airflow/serialization/decoders.py b/airflow-core/src/airflow/serialization/decoders.py new file mode 100644 index 0000000000000..afd6d297f3f42 --- /dev/null +++ b/airflow-core/src/airflow/serialization/decoders.py @@ -0,0 +1,132 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import datetime +from typing import TYPE_CHECKING, Any, TypeVar + +import dateutil.relativedelta + +from airflow.sdk import ( # TODO: Implement serialized assets. + Asset, + AssetAlias, + AssetAll, + AssetAny, +) +from airflow.serialization.definitions.assets import SerializedAssetWatcher +from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding +from airflow.serialization.helpers import find_registered_custom_timetable, is_core_timetable_import_path +from airflow.utils.module_loading import import_string + +if TYPE_CHECKING: + from airflow.sdk.definitions.asset import BaseAsset + from airflow.timetables.base import Timetable as CoreTimetable + +R = TypeVar("R") + + +def decode_relativedelta(var: dict[str, Any]) -> dateutil.relativedelta.relativedelta: + """Decode a relativedelta object.""" + if "weekday" in var: + var["weekday"] = dateutil.relativedelta.weekday(*var["weekday"]) + return dateutil.relativedelta.relativedelta(**var) + + +def decode_interval(value: int | dict) -> datetime.timedelta | dateutil.relativedelta.relativedelta: + if isinstance(value, dict): + return decode_relativedelta(value) + return datetime.timedelta(seconds=value) + + +def decode_run_immediately(value: bool | float) -> bool | datetime.timedelta: + if isinstance(value, float): + return datetime.timedelta(seconds=value) + return value + + +def smart_decode_trigger_kwargs(d): + """ + Slightly clean up kwargs for display or execution. + + This detects one level of BaseSerialization and tries to deserialize the + content, removing some __type __var ugliness when the value is displayed + to the user in the UI and/or used during execution. 
+ """ + from airflow.serialization.serialized_objects import BaseSerialization + + if not isinstance(d, dict) or Encoding.TYPE not in d: + return d + return BaseSerialization.deserialize(d) + + +def decode_asset(var: dict[str, Any]): + watchers = var.get("watchers", []) + return Asset( + name=var["name"], + uri=var["uri"], + group=var["group"], + extra=var["extra"], + watchers=[ + SerializedAssetWatcher( + name=watcher["name"], + trigger={ + "classpath": watcher["trigger"]["classpath"], + "kwargs": smart_decode_trigger_kwargs(watcher["trigger"]["kwargs"]), + }, + ) + for watcher in watchers + ], + ) + + +def decode_asset_condition(var: dict[str, Any]) -> BaseAsset: + """ + Decode a previously serialized asset condition. + + :meta private: + """ + match var["__type"]: + case DAT.ASSET: + return decode_asset(var) + case DAT.ASSET_ALL: + return AssetAll(*(decode_asset_condition(x) for x in var["objects"])) + case DAT.ASSET_ANY: + return AssetAny(*(decode_asset_condition(x) for x in var["objects"])) + case DAT.ASSET_ALIAS: + return AssetAlias(name=var["name"], group=var["group"]) + case DAT.ASSET_REF: + return Asset.ref(**{k: v for k, v in var.items() if k != "__type"}) + case data_type: + raise ValueError(f"deserialization not implemented for DAT {data_type!r}") + + +def decode_timetable(var: dict[str, Any]) -> CoreTimetable: + """ + Decode a previously serialized timetable. + + Most of the deserialization logic is delegated to the actual type, which + we import from string. + + :meta private: + """ + if is_core_timetable_import_path(importable_string := var[Encoding.TYPE]): + timetable_type: type[CoreTimetable] = import_string(importable_string) + else: + timetable_type = find_registered_custom_timetable(importable_string) + return timetable_type.deserialize(var[Encoding.VAR]) diff --git a/airflow-core/src/airflow/serialization/definitions/assets.py b/airflow-core/src/airflow/serialization/definitions/assets.py new file mode 100644 index 0000000000000..d9cf6f9a82c41 --- /dev/null +++ b/airflow-core/src/airflow/serialization/definitions/assets.py @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from airflow.sdk import AssetWatcher # TODO: Implement serialized assets. + + +class SerializedAssetWatcher(AssetWatcher): + """JSON serializable representation of an asset watcher.""" + + trigger: dict diff --git a/airflow-core/src/airflow/serialization/encoders.py b/airflow-core/src/airflow/serialization/encoders.py new file mode 100644 index 0000000000000..64a66dbe584b2 --- /dev/null +++ b/airflow-core/src/airflow/serialization/encoders.py @@ -0,0 +1,308 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import contextlib +import datetime +import functools +from typing import TYPE_CHECKING, Any, TypeVar, overload + +import attrs +import pendulum + +from airflow.sdk import ( + Asset, + AssetAlias, + AssetAll, + AssetAny, + AssetOrTimeSchedule, + CronDataIntervalTimetable, + CronTriggerTimetable, + DeltaDataIntervalTimetable, + DeltaTriggerTimetable, + EventsTimetable, + MultipleCronTriggerTimetable, +) +from airflow.sdk.bases.timetable import BaseTimetable +from airflow.sdk.definitions.asset import AssetRef +from airflow.sdk.definitions.timetables.assets import AssetTriggeredTimetable +from airflow.sdk.definitions.timetables.simple import ContinuousTimetable, NullTimetable, OnceTimetable +from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding +from airflow.serialization.helpers import find_registered_custom_timetable, is_core_timetable_import_path +from airflow.timetables.base import Timetable as CoreTimetable +from airflow.utils.docs import get_docs_url +from airflow.utils.module_loading import qualname + +if TYPE_CHECKING: + from dateutil.relativedelta import relativedelta + + from airflow.sdk.definitions.asset import BaseAsset + from airflow.triggers.base import BaseEventTrigger + + T = TypeVar("T") + + +def encode_relativedelta(var: relativedelta) -> dict[str, Any]: + """Encode a relativedelta object.""" + encoded = {k: v for k, v in var.__dict__.items() if not k.startswith("_") and v} + if var.weekday and var.weekday.n: + # Every n'th Friday for example + encoded["weekday"] = [var.weekday.weekday, var.weekday.n] + elif var.weekday: + encoded["weekday"] = [var.weekday.weekday] + return encoded + + +def encode_timezone(var: str | pendulum.Timezone | pendulum.FixedTimezone) -> str | int: + """ + Encode a Pendulum Timezone for serialization. + + Airflow only supports timezone objects that implements Pendulum's Timezone + interface. We try to keep as much information as possible to make conversion + round-tripping possible (see ``decode_timezone``). We need to special-case + UTC; Pendulum implements it as a FixedTimezone (i.e. it gets encoded as + 0 without the special case), but passing 0 into ``pendulum.timezone`` does + not give us UTC (but ``+00:00``). + """ + if isinstance(var, str): + return var + if isinstance(var, pendulum.FixedTimezone): + if var.offset == 0: + return "UTC" + return var.offset + if isinstance(var, pendulum.Timezone): + return var.name + raise ValueError( + f"DAG timezone should be a pendulum.tz.Timezone, not {var!r}. 
" + f"See {get_docs_url('timezone.html#time-zone-aware-dags')}" + ) + + +def encode_interval(interval: datetime.timedelta | relativedelta) -> float | dict: + if isinstance(interval, datetime.timedelta): + return interval.total_seconds() + return encode_relativedelta(interval) + + +def encode_run_immediately(value: bool | datetime.timedelta) -> bool | float: + if isinstance(value, datetime.timedelta): + return value.total_seconds() + return value + + +def encode_trigger(trigger: BaseEventTrigger | dict): + from airflow.serialization.serialized_objects import BaseSerialization + + def _ensure_serialized(d): + """ + Make sure the kwargs dict is JSON-serializable. + + This is done with BaseSerialization logic. A simple check is added to + ensure we don't double-serialize, which is possible when a trigger goes + through multiple serialization layers. + """ + if isinstance(d, dict) and Encoding.TYPE in d: + return d + return BaseSerialization.serialize(d) + + if isinstance(trigger, dict): + classpath = trigger["classpath"] + kwargs = trigger["kwargs"] + else: + classpath, kwargs = trigger.serialize() + return { + "classpath": classpath, + "kwargs": {k: _ensure_serialized(v) for k, v in kwargs.items()}, + } + + +def encode_asset_condition(a: BaseAsset) -> dict[str, Any]: + """ + Encode an asset condition. + + :meta private: + """ + d: dict[str, Any] + match a: + case Asset(): + d = {"__type": DAT.ASSET, "name": a.name, "uri": a.uri, "group": a.group, "extra": a.extra} + if a.watchers: + d["watchers"] = [{"name": w.name, "trigger": encode_trigger(w.trigger)} for w in a.watchers] + return d + case AssetAlias(): + return {"__type": DAT.ASSET_ALIAS, "name": a.name, "group": a.group} + case AssetAll(): + return {"__type": DAT.ASSET_ALL, "objects": [encode_asset_condition(x) for x in a.objects]} + case AssetAny(): + return {"__type": DAT.ASSET_ANY, "objects": [encode_asset_condition(x) for x in a.objects]} + case AssetRef(): + return {"__type": DAT.ASSET_REF, **attrs.asdict(a)} + raise ValueError(f"serialization not implemented for {type(a).__name__!r}") + + +def _get_serialized_timetable_import_path(var: BaseTimetable | CoreTimetable) -> str: + # Find SDK classes. + with contextlib.suppress(KeyError): + return _serializer.BUILTIN_TIMETABLES[var_type := type(var)] + + # Check Core classes. + if is_core_timetable_import_path(importable_string := qualname(var_type)): + return importable_string + + # Find user-registered classes. + find_registered_custom_timetable(importable_string) # This raises if not found. + return importable_string + + +def encode_timetable(var: BaseTimetable | CoreTimetable) -> dict[str, Any]: + """ + Encode a timetable instance. + + See ``_TimetableSerializer.serialize()`` for more implementation detail. 
+ + :meta private: + """ + importable_string = _get_serialized_timetable_import_path(var) + return {Encoding.TYPE: importable_string, Encoding.VAR: _serializer.serialize(var)} + + +class _TimetableSerializer: + """Timetable serialization logic.""" + + BUILTIN_TIMETABLES: dict[type, str] = { + AssetOrTimeSchedule: "airflow.timetables.assets.AssetOrTimeSchedule", + AssetTriggeredTimetable: "airflow.timetables.simple.AssetTriggeredTimetable", + ContinuousTimetable: "airflow.timetables.simple.ContinuousTimetable", + CronDataIntervalTimetable: "airflow.timetables.interval.CronDataIntervalTimetable", + CronTriggerTimetable: "airflow.timetables.trigger.CronTriggerTimetable", + DeltaDataIntervalTimetable: "airflow.timetables.interval.DeltaDataIntervalTimetable", + DeltaTriggerTimetable: "airflow.timetables.trigger.DeltaTriggerTimetable", + EventsTimetable: "airflow.timetables.events.EventsTimetable", + MultipleCronTriggerTimetable: "airflow.timetables.trigger.MultipleCronTriggerTimetable", + NullTimetable: "airflow.timetables.simple.NullTimetable", + OnceTimetable: "airflow.timetables.simple.OnceTimetable", + } + + @functools.singledispatchmethod + def serialize(self, timetable: BaseTimetable | CoreTimetable) -> dict[str, Any]: + """ + Serialize a timetable into a JSON-compatible dict for storage. + + All timetables defined in the SDK should be handled by registered + single-dispatch variants below. + + This function's body should only be + called on timetables defined in Core (under ``airflow.timetables``), + and user-defined custom timetables registered via plugins, which also + inherit from the Core timetable base class. + + For timetables in Core, serialization work is delegated to the type. + """ + if not isinstance(timetable, CoreTimetable): + raise NotImplementedError(f"can not serialize timetable {type(timetable).__name__}") + return timetable.serialize() + + @serialize.register(ContinuousTimetable) + @serialize.register(NullTimetable) + @serialize.register(OnceTimetable) + def _(self, timetable: ContinuousTimetable | NullTimetable | OnceTimetable) -> dict[str, Any]: + return {} + + @serialize.register + def _(self, timetable: AssetTriggeredTimetable) -> dict[str, Any]: + return {"asset_condition": encode_asset_condition(timetable.asset_condition)} + + @serialize.register + def _(self, timetable: EventsTimetable) -> dict[str, Any]: + return { + "event_dates": [x.isoformat(sep="T") for x in timetable.event_dates], + "restrict_to_events": timetable.restrict_to_events, + "description": timetable.description, + } + + @serialize.register + def _(self, timetable: CronDataIntervalTimetable) -> dict[str, Any]: + return {"expression": timetable.expression, "timezone": encode_timezone(timetable.timezone)} + + @serialize.register + def _(self, timetable: DeltaDataIntervalTimetable) -> dict[str, Any]: + return {"delta": encode_interval(timetable.delta)} + + @serialize.register + def _(self, timetable: CronTriggerTimetable) -> dict[str, Any]: + return { + "expression": timetable.expression, + "timezone": encode_timezone(timetable.timezone), + "interval": encode_interval(timetable.interval), + "run_immediately": encode_run_immediately(timetable.run_immediately), + } + + @serialize.register + def _(self, timetable: DeltaTriggerTimetable) -> dict[str, Any]: + return { + "delta": encode_interval(timetable.delta), + "interval": encode_interval(timetable.interval), + } + + @serialize.register + def _(self, timetable: MultipleCronTriggerTimetable) -> dict[str, Any]: + # All timetables share the same timezone, 
interval, and run_immediately + # values, so we can just use the first to represent them. + representative = timetable.timetables[0] + return { + "expressions": [t.expression for t in timetable.timetables], + "timezone": encode_timezone(representative.timezone), + "interval": encode_interval(representative.interval), + "run_immediately": encode_run_immediately(representative.run_immediately), + } + + @serialize.register + def _(self, timetable: AssetOrTimeSchedule) -> dict[str, Any]: + return { + "asset_condition": encode_asset_condition(timetable.asset_condition), + "timetable": encode_timetable(timetable.timetable), + } + + @serialize.register + def _(self, timetable: CoreTimetable) -> dict[str, Any]: + return timetable.serialize() + + +_serializer = _TimetableSerializer() + + +@overload +def coerce_to_core_timetable(obj: BaseTimetable | CoreTimetable) -> CoreTimetable: ... + + +@overload +def coerce_to_core_timetable(obj: T) -> T: ... + + +def coerce_to_core_timetable(obj: object) -> object: + """ + Convert *obj* from an SDK timetable to a Core timetable instance if possible. + + :meta private: + """ + if isinstance(obj, CoreTimetable) or not isinstance(obj, BaseTimetable): + return obj + + from airflow.serialization.decoders import decode_timetable + + return decode_timetable(encode_timetable(obj)) diff --git a/airflow-core/src/airflow/serialization/helpers.py b/airflow-core/src/airflow/serialization/helpers.py index 949b3cb9c9f09..e9d90fe61ae9b 100644 --- a/airflow-core/src/airflow/serialization/helpers.py +++ b/airflow-core/src/airflow/serialization/helpers.py @@ -18,12 +18,16 @@ from __future__ import annotations -from typing import Any +import contextlib +from typing import TYPE_CHECKING, Any from airflow._shared.secrets_masker import redact from airflow.configuration import conf from airflow.settings import json +if TYPE_CHECKING: + from airflow.timetables.base import Timetable as CoreTimetable + def serialize_template_field(template_field: Any, name: str) -> str | dict | list | int | float: """ @@ -78,3 +82,33 @@ def translate_tuples_to_lists(obj: Any): f"{rendered[: max_length - 79]!r}... " ) return template_field + + +class TimetableNotRegistered(ValueError): + """When an unregistered timetable is being accessed.""" + + def __init__(self, type_string: str) -> None: + self.type_string = type_string + + def __str__(self) -> str: + return ( + f"Timetable class {self.type_string!r} is not registered or " + "you have a top level database access that disrupted the session. " + "Please check the airflow best practices documentation." 
+ ) + + +def find_registered_custom_timetable(importable_string: str) -> type[CoreTimetable]: + """Find a user-defined custom timetable class registered via a plugin.""" + from airflow import plugins_manager + + plugins_manager.initialize_timetables_plugins() + if plugins_manager.timetable_classes is not None: + with contextlib.suppress(KeyError): + return plugins_manager.timetable_classes[importable_string] + raise TimetableNotRegistered(importable_string) + + +def is_core_timetable_import_path(importable_string: str) -> bool: + """Whether an importable string points to a core timetable class.""" + return importable_string.startswith("airflow.timetables.") diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py index 6836fcd2d8bf7..258995e420c25 100644 --- a/airflow-core/src/airflow/serialization/serialized_objects.py +++ b/airflow-core/src/airflow/serialization/serialized_objects.py @@ -75,7 +75,7 @@ from airflow.models.tasklog import LogTemplate from airflow.models.xcom import XComModel from airflow.models.xcom_arg import SchedulerXComArg, deserialize_xcom_arg -from airflow.sdk import DAG, Asset, AssetAlias, AssetAll, AssetAny, AssetWatcher, BaseOperator, XComArg +from airflow.sdk import DAG, Asset, AssetAlias, AssetAll, AssetAny, BaseOperator, XComArg from airflow.sdk.bases.operator import OPERATOR_DEFAULTS # TODO: Copy this into the scheduler? from airflow.sdk.definitions._internal.node import DAGNode from airflow.sdk.definitions.asset import ( @@ -93,10 +93,22 @@ from airflow.sdk.definitions.xcom_arg import serialize_xcom_arg from airflow.sdk.execution_time.context import OutletEventAccessor, OutletEventAccessors from airflow.serialization.dag_dependency import DagDependency +from airflow.serialization.decoders import ( + decode_asset, + decode_asset_condition, + decode_relativedelta, + decode_timetable, +) from airflow.serialization.definitions.param import SerializedParam, SerializedParamsDict from airflow.serialization.definitions.taskgroup import SerializedMappedTaskGroup, SerializedTaskGroup +from airflow.serialization.encoders import ( + encode_asset_condition, + encode_relativedelta, + encode_timetable, + encode_timezone, +) from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding -from airflow.serialization.helpers import serialize_template_field +from airflow.serialization.helpers import TimetableNotRegistered, serialize_template_field from airflow.serialization.json_schema import load_dag_schema from airflow.settings import DAGS_FOLDER, json from airflow.stats import Stats @@ -116,7 +128,6 @@ from airflow.utils.code_utils import get_python_source from airflow.utils.context import ConnectionAccessor, Context, VariableAccessor from airflow.utils.db import LazySelectSequence -from airflow.utils.docs import get_docs_url from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.module_loading import import_string, qualname from airflow.utils.session import NEW_SESSION, create_session, provide_session @@ -138,7 +149,6 @@ from airflow.task.trigger_rule import TriggerRule from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.timetables.simple import PartitionMapper - from airflow.triggers.base import BaseEventTrigger try: from kubernetes.client import models as k8s # noqa: TC004 @@ -163,64 +173,6 @@ log = logging.getLogger(__name__) -def encode_relativedelta(var: relativedelta.relativedelta) -> dict[str, Any]: - """Encode a relativedelta object.""" - 
encoded = {k: v for k, v in var.__dict__.items() if not k.startswith("_") and v} - if var.weekday and var.weekday.n: - # Every n'th Friday for example - encoded["weekday"] = [var.weekday.weekday, var.weekday.n] - elif var.weekday: - encoded["weekday"] = [var.weekday.weekday] - return encoded - - -def decode_relativedelta(var: dict[str, Any]) -> relativedelta.relativedelta: - """Dencode a relativedelta object.""" - if "weekday" in var: - var["weekday"] = relativedelta.weekday(*var["weekday"]) - return relativedelta.relativedelta(**var) - - -def encode_timezone(var: Timezone | FixedTimezone) -> str | int: - """ - Encode a Pendulum Timezone for serialization. - - Airflow only supports timezone objects that implements Pendulum's Timezone - interface. We try to keep as much information as possible to make conversion - round-tripping possible (see ``decode_timezone``). We need to special-case - UTC; Pendulum implements it as a FixedTimezone (i.e. it gets encoded as - 0 without the special case), but passing 0 into ``pendulum.timezone`` does - not give us UTC (but ``+00:00``). - """ - if isinstance(var, FixedTimezone): - if var.offset == 0: - return "UTC" - return var.offset - if isinstance(var, Timezone): - return var.name - raise ValueError( - f"DAG timezone should be a pendulum.tz.Timezone, not {var!r}. " - f"See {get_docs_url('timezone.html#time-zone-aware-dags')}" - ) - - -def decode_timezone(var: str | int) -> Timezone | FixedTimezone: - """Decode a previously serialized Pendulum Timezone.""" - return parse_timezone(var) - - -def _get_registered_timetable(importable_string: str) -> type[Timetable] | None: - from airflow import plugins_manager - - if importable_string.startswith("airflow.timetables."): - return import_string(importable_string) - plugins_manager.initialize_timetables_plugins() - if plugins_manager.timetable_classes: - return plugins_manager.timetable_classes.get(importable_string) - else: - return None - - def _get_registered_priority_weight_strategy( importable_string: str, ) -> type[PriorityWeightStrategy] | None: @@ -235,18 +187,6 @@ def _get_registered_priority_weight_strategy( return None -class _TimetableNotRegistered(ValueError): - def __init__(self, type_string: str) -> None: - self.type_string = type_string - - def __str__(self) -> str: - return ( - f"Timetable class {self.type_string!r} is not registered or " - "you have a top level database access that disrupted the session. " - "Please check the airflow best practices documentation." - ) - - class _PartitionMapperNotFound(ValueError): def __init__(self, type_string: str) -> None: self.type_string = type_string @@ -271,126 +211,6 @@ def __str__(self) -> str: ) -def _encode_trigger(trigger: BaseEventTrigger | dict): - def _ensure_serialized(d): - """ - Make sure the kwargs dict is JSON-serializable. - - This is done with BaseSerialization logic. A simple check is added to - ensure we don't double-serialize, which is possible when a trigger goes - through multiple serialization layers. - """ - if isinstance(d, dict) and Encoding.TYPE in d: - return d - return BaseSerialization.serialize(d) - - if isinstance(trigger, dict): - classpath = trigger["classpath"] - kwargs = trigger["kwargs"] - else: - classpath, kwargs = trigger.serialize() - return { - "classpath": classpath, - "kwargs": {k: _ensure_serialized(v) for k, v in kwargs.items()}, - } - - -def encode_asset_condition(var: BaseAsset) -> dict[str, Any]: - """ - Encode an asset condition. 
- - :meta private: - """ - if isinstance(var, Asset): - - def _encode_watcher(watcher: AssetWatcher): - return { - "name": watcher.name, - "trigger": _encode_trigger(watcher.trigger), - } - - asset = { - "__type": DAT.ASSET, - "name": var.name, - "uri": var.uri, - "group": var.group, - "extra": var.extra, - } - - if len(var.watchers) > 0: - asset["watchers"] = [_encode_watcher(watcher) for watcher in var.watchers] - - return asset - if isinstance(var, AssetAlias): - return {"__type": DAT.ASSET_ALIAS, "name": var.name, "group": var.group} - if isinstance(var, AssetAll): - return { - "__type": DAT.ASSET_ALL, - "objects": [encode_asset_condition(x) for x in var.objects], - } - if isinstance(var, AssetAny): - return { - "__type": DAT.ASSET_ANY, - "objects": [encode_asset_condition(x) for x in var.objects], - } - if isinstance(var, AssetRef): - return {"__type": DAT.ASSET_REF, **attrs.asdict(var)} - raise ValueError(f"serialization not implemented for {type(var).__name__!r}") - - -def decode_asset_condition(var: dict[str, Any]) -> BaseAsset: - """ - Decode a previously serialized asset condition. - - :meta private: - """ - dat = var["__type"] - if dat == DAT.ASSET: - return decode_asset(var) - if dat == DAT.ASSET_ALL: - return AssetAll(*(decode_asset_condition(x) for x in var["objects"])) - if dat == DAT.ASSET_ANY: - return AssetAny(*(decode_asset_condition(x) for x in var["objects"])) - if dat == DAT.ASSET_ALIAS: - return AssetAlias(name=var["name"], group=var["group"]) - if dat == DAT.ASSET_REF: - return Asset.ref(**{k: v for k, v in var.items() if k != "__type"}) - raise ValueError(f"deserialization not implemented for DAT {dat!r}") - - -def smart_decode_trigger_kwargs(d): - """ - Slightly clean up kwargs for display or execution. - - This detects one level of BaseSerialization and tries to deserialize the - content, removing some __type __var ugliness when the value is displayed - in UI to the user and/or while execution. - """ - if not isinstance(d, dict) or Encoding.TYPE not in d: - return d - return BaseSerialization.deserialize(d) - - -def decode_asset(var: dict[str, Any]): - watchers = var.get("watchers", []) - return Asset( - name=var["name"], - uri=var["uri"], - group=var["group"], - extra=var["extra"], - watchers=[ - SerializedAssetWatcher( - name=watcher["name"], - trigger={ - "classpath": watcher["trigger"]["classpath"], - "kwargs": smart_decode_trigger_kwargs(watcher["trigger"]["kwargs"]), - }, - ) - for watcher in watchers - ], - ) - - def encode_outlet_event_accessor(var: OutletEventAccessor) -> dict[str, Any]: key = var.key return { @@ -440,38 +260,6 @@ def decode_outlet_event_accessors(var: dict[str, Any]) -> OutletEventAccessors: return d -def encode_timetable(var: Timetable) -> dict[str, Any]: - """ - Encode a timetable instance. - - This delegates most of the serialization work to the type, so the behavior - can be completely controlled by a custom subclass. - - :meta private: - """ - timetable_class = type(var) - importable_string = qualname(timetable_class) - if _get_registered_timetable(importable_string) is None: - raise _TimetableNotRegistered(importable_string) - return {Encoding.TYPE: importable_string, Encoding.VAR: var.serialize()} - - -def decode_timetable(var: dict[str, Any]) -> Timetable: - """ - Decode a previously serialized timetable. - - Most of the deserialization logic is delegated to the actual type, which - we import from string. 
- - :meta private: - """ - importable_string = var[Encoding.TYPE] - timetable_class = _get_registered_timetable(importable_string) - if timetable_class is None: - raise _TimetableNotRegistered(importable_string) - return timetable_class.deserialize(var[Encoding.VAR]) - - def _load_partition_mapper(importable_string) -> PartitionMapper | None: if importable_string.startswith("airflow.timetables."): return import_string(importable_string) @@ -991,7 +779,7 @@ def deserialize(cls, encoded_var: Any) -> Any: elif type_ == DAT.TIMEDELTA: return datetime.timedelta(seconds=var) elif type_ == DAT.TIMEZONE: - return decode_timezone(var) + return parse_timezone(var) elif type_ == DAT.RELATIVEDELTA: return decode_relativedelta(var) elif type_ == DAT.AIRFLOW_EXC_SER or type_ == DAT.BASE_EXC_SER: @@ -2644,7 +2432,7 @@ def serialize_dag(cls, dag: DAG) -> dict: except SerializationError: raise except Exception as e: - raise SerializationError(f"Failed to serialize DAG {dag.dag_id!r}: {e}") + raise SerializationError(f"Failed to serialize DAG {dag.dag_id!r}: {e}") from e @classmethod def deserialize_dag( @@ -2661,7 +2449,7 @@ def deserialize_dag( try: return cls._deserialize_dag_internal(encoded_dag, client_defaults) - except (_TimetableNotRegistered, DeserializationError): + except (TimetableNotRegistered, DeserializationError): # Let specific errors bubble up unchanged raise except Exception as err: @@ -2879,7 +2667,7 @@ def _create_compat_timetable(value): from airflow.sdk.definitions.dag import _create_timetable if tzs := dag_dict.get("timezone"): - timezone = decode_timezone(tzs) + timezone = parse_timezone(tzs) else: timezone = settings.TIMEZONE timetable = _create_timetable(value, timezone) @@ -3044,10 +2832,6 @@ def roots(self) -> list[SerializedOperator]: def owner(self) -> str: return ", ".join({t.owner for t in self.tasks}) - @property - def timetable_summary(self) -> str: - return self.timetable.summary - def has_task(self, task_id: str) -> bool: return task_id in self.task_dict @@ -3962,12 +3746,6 @@ def set_ref(task: SerializedOperator) -> SerializedOperator: return group -class SerializedAssetWatcher(AssetWatcher): - """JSON serializable representation of an asset watcher.""" - - trigger: dict - - @cache def _has_kubernetes(attempt_import: bool = False) -> bool: """ diff --git a/airflow-core/src/airflow/timetables/_cron.py b/airflow-core/src/airflow/timetables/_cron.py index b8bc22921ba86..f48593dea4ad4 100644 --- a/airflow-core/src/airflow/timetables/_cron.py +++ b/airflow-core/src/airflow/timetables/_cron.py @@ -124,7 +124,9 @@ def __eq__(self, other: object) -> bool: This is only for testing purposes and should not be relied on otherwise. """ - if not isinstance(other, type(self)): + from airflow.serialization.encoders import coerce_to_core_timetable + + if not isinstance(other := coerce_to_core_timetable(other), type(self)): return NotImplemented return self._expression == other._expression and self._timezone == other._timezone diff --git a/airflow-core/src/airflow/timetables/_delta.py b/airflow-core/src/airflow/timetables/_delta.py index acdf6aa704ec3..891f5da3e9ceb 100644 --- a/airflow-core/src/airflow/timetables/_delta.py +++ b/airflow-core/src/airflow/timetables/_delta.py @@ -34,6 +34,21 @@ class DeltaMixin: def __init__(self, delta: datetime.timedelta | relativedelta) -> None: self._delta = delta + def __eq__(self, other: object) -> bool: + """ + Return if the offsets match. + + This is only for testing purposes and should not be relied on otherwise. 
+ """ + from airflow.serialization.encoders import coerce_to_core_timetable + + if not isinstance(other := coerce_to_core_timetable(other), type(self)): + return NotImplemented + return self._delta == other._delta + + def __hash__(self): + return hash(self._delta) + @property def summary(self) -> str: return str(self._delta) diff --git a/airflow-core/src/airflow/timetables/assets.py b/airflow-core/src/airflow/timetables/assets.py index 6d23313243821..98bb2cb9589e7 100644 --- a/airflow-core/src/airflow/timetables/assets.py +++ b/airflow-core/src/airflow/timetables/assets.py @@ -20,7 +20,7 @@ import typing from airflow.exceptions import AirflowTimetableInvalid -from airflow.sdk.definitions.asset import AssetAll, BaseAsset +from airflow.sdk.definitions.asset import AssetAll, BaseAsset # TODO: Use serialized classes. from airflow.timetables.simple import AssetTriggeredTimetable from airflow.utils.types import DagRunType @@ -55,27 +55,27 @@ def __init__( @classmethod def deserialize(cls, data: dict[str, typing.Any]) -> Timetable: - from airflow.serialization.serialized_objects import decode_asset_condition, decode_timetable + from airflow.serialization.decoders import decode_asset_condition, decode_timetable return cls( assets=decode_asset_condition(data["asset_condition"]), timetable=decode_timetable(data["timetable"]), ) + def validate(self) -> None: + if isinstance(self.timetable, AssetTriggeredTimetable): + raise AirflowTimetableInvalid("cannot nest asset timetables") + if not isinstance(self.asset_condition, BaseAsset): + raise AirflowTimetableInvalid("all elements in 'assets' must be assets") + def serialize(self) -> dict[str, typing.Any]: - from airflow.serialization.serialized_objects import encode_asset_condition, encode_timetable + from airflow.serialization.encoders import encode_asset_condition, encode_timetable return { "asset_condition": encode_asset_condition(self.asset_condition), "timetable": encode_timetable(self.timetable), } - def validate(self) -> None: - if isinstance(self.timetable, AssetTriggeredTimetable): - raise AirflowTimetableInvalid("cannot nest asset timetables") - if not isinstance(self.asset_condition, BaseAsset): - raise AirflowTimetableInvalid("all elements in 'assets' must be assets") - @property def summary(self) -> str: return f"Asset or {self.timetable.summary}" diff --git a/airflow-core/src/airflow/timetables/base.py b/airflow-core/src/airflow/timetables/base.py index 47f7344b32b8c..33ebc316aa9af 100644 --- a/airflow-core/src/airflow/timetables/base.py +++ b/airflow-core/src/airflow/timetables/base.py @@ -16,58 +16,18 @@ # under the License. from __future__ import annotations -from collections.abc import Iterator, Sequence +from collections.abc import Sequence from typing import TYPE_CHECKING, Any, NamedTuple, Protocol, runtime_checkable -from airflow.sdk.definitions.asset import AssetUniqueKey, BaseAsset +from airflow.sdk.bases.timetable import NullAsset # TODO: Separate asset definitions. if TYPE_CHECKING: from pendulum import DateTime - from sqlalchemy.orm import Session - from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetRef - from airflow.serialization.dag_dependency import DagDependency + from airflow.sdk.definitions.asset import BaseAsset from airflow.utils.types import DagRunType -class _NullAsset(BaseAsset): - """ - Sentinel type that represents "no assets". - - This is only implemented to make typing easier in timetables, and not - expected to be used anywhere else. 
- - :meta private: - """ - - def __bool__(self) -> bool: - return False - - def __or__(self, other: BaseAsset) -> BaseAsset: - return NotImplemented - - def __and__(self, other: BaseAsset) -> BaseAsset: - return NotImplemented - - def as_expression(self) -> Any: - return None - - def evaluate(self, statuses: dict[AssetUniqueKey, bool], *, session: Session | None = None) -> bool: - return False - - def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]: - return iter(()) - - def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: - return iter(()) - - def iter_asset_refs(self) -> Iterator[AssetRef]: - return iter(()) - - def iter_dag_dependencies(self, source, target) -> Iterator[DagDependency]: - return iter(()) - - class DataInterval(NamedTuple): """ A data interval for a DagRun to operate over. @@ -169,6 +129,7 @@ class Timetable(Protocol): like ``schedule=None`` and ``"@once"`` set it to *False*. """ + # TODO (GH-52141): Find a way to keep this and one in Core in sync. can_be_scheduled: bool = True """ Whether this timetable can actually schedule runs in an automated manner. @@ -184,6 +145,7 @@ class Timetable(Protocol): This should be a list of field names on the DAG run object. """ + # TODO (GH-52141): Find a way to keep this and one in Core in sync. active_runs_limit: int | None = None """Maximum active runs that can be active at one time for a DAG. @@ -193,7 +155,7 @@ class Timetable(Protocol): as for :class:`~airflow.timetable.simple.ContinuousTimetable`. """ - asset_condition: BaseAsset = _NullAsset() + asset_condition: BaseAsset = NullAsset() """The asset condition that triggers a DAG using this timetable. If this is not *None*, this should be an asset, or a combination of, that diff --git a/airflow-core/src/airflow/timetables/events.py b/airflow-core/src/airflow/timetables/events.py index d8e70626d409a..9fedc476c1cda 100644 --- a/airflow-core/src/airflow/timetables/events.py +++ b/airflow-core/src/airflow/timetables/events.py @@ -17,8 +17,7 @@ from __future__ import annotations import itertools -from collections.abc import Iterable -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import pendulum @@ -26,6 +25,8 @@ from airflow.timetables.base import DagRunInfo, DataInterval, Timetable if TYPE_CHECKING: + from collections.abc import Iterable + from pendulum import DateTime from airflow.timetables.base import TimeRestriction @@ -120,7 +121,7 @@ def infer_manual_data_interval(self, *, run_after: DateTime) -> DataInterval: most_recent_event = next(past_events) return DataInterval.exact(most_recent_event) - def serialize(self): + def serialize(self) -> dict[str, Any]: return { "event_dates": [x.isoformat(sep="T") for x in self.event_dates], "restrict_to_events": self.restrict_to_events, @@ -129,12 +130,12 @@ def serialize(self): } @classmethod - def deserialize(cls, data) -> Timetable: - time_table = cls( + def deserialize(cls, data: dict[str, Any]) -> Timetable: + timetable = cls( event_dates=[pendulum.DateTime.fromisoformat(x) for x in data["event_dates"]], restrict_to_events=data["restrict_to_events"], presorted=True, description=data["description"], ) - time_table._summary = data["_summary"] - return time_table + timetable._summary = data["_summary"] + return timetable diff --git a/airflow-core/src/airflow/timetables/interval.py b/airflow-core/src/airflow/timetables/interval.py index c116b2343fbf6..1bc8d1067e368 100644 --- a/airflow-core/src/airflow/timetables/interval.py +++ b/airflow-core/src/airflow/timetables/interval.py @@ -22,7 
+22,7 @@ from dateutil.relativedelta import relativedelta from pendulum import DateTime -from airflow._shared.timezones.timezone import coerce_datetime, utcnow +from airflow._shared.timezones.timezone import coerce_datetime, parse_timezone, utcnow from airflow.timetables._cron import CronMixin from airflow.timetables._delta import DeltaMixin from airflow.timetables.base import DagRunInfo, DataInterval, Timetable @@ -132,12 +132,10 @@ class CronDataIntervalTimetable(CronMixin, _DataIntervalTimetable): @classmethod def deserialize(cls, data: dict[str, Any]) -> Timetable: - from airflow.serialization.serialized_objects import decode_timezone - - return cls(data["expression"], decode_timezone(data["timezone"])) + return cls(data["expression"], parse_timezone(data["timezone"])) def serialize(self) -> dict[str, Any]: - from airflow.serialization.serialized_objects import encode_timezone + from airflow.serialization.encoders import encode_timezone return {"expression": self._expression, "timezone": encode_timezone(self._timezone)} @@ -184,28 +182,15 @@ class DeltaDataIntervalTimetable(DeltaMixin, _DataIntervalTimetable): @classmethod def deserialize(cls, data: dict[str, Any]) -> Timetable: - from airflow.serialization.serialized_objects import decode_relativedelta + from airflow.serialization.decoders import decode_relativedelta delta = data["delta"] if isinstance(delta, dict): return cls(decode_relativedelta(delta)) return cls(datetime.timedelta(seconds=delta)) - def __eq__(self, other: object) -> bool: - """ - Return if the offsets match. - - This is only for testing purposes and should not be relied on otherwise. - """ - if not isinstance(other, DeltaDataIntervalTimetable): - return NotImplemented - return self._delta == other._delta - - def __hash__(self): - return hash(self._delta) - def serialize(self) -> dict[str, Any]: - from airflow.serialization.serialized_objects import encode_relativedelta + from airflow.serialization.encoders import encode_relativedelta delta: Any if isinstance(self._delta, datetime.timedelta): diff --git a/airflow-core/src/airflow/timetables/simple.py b/airflow-core/src/airflow/timetables/simple.py index b9a0b27eb7206..a010802945817 100644 --- a/airflow-core/src/airflow/timetables/simple.py +++ b/airflow-core/src/airflow/timetables/simple.py @@ -47,16 +47,15 @@ def __eq__(self, other: object) -> bool: This is only for testing purposes and should not be relied on otherwise. """ - if not isinstance(other, type(self)): + from airflow.serialization.encoders import coerce_to_core_timetable + + if not isinstance(other := coerce_to_core_timetable(other), type(self)): return NotImplemented return True def __hash__(self): return hash(self.__class__.__name__) - def serialize(self) -> dict[str, Any]: - return {} - def infer_manual_data_interval(self, *, run_after: DateTime) -> DataInterval: return DataInterval.exact(run_after) @@ -68,7 +67,7 @@ class NullTimetable(_TrivialTimetable): This corresponds to ``schedule=None``. """ - can_be_scheduled = False + can_be_scheduled = False # TODO (GH-52141): Find a way to keep this and one in Core in sync. description: str = "Never, external triggers only" @property @@ -124,6 +123,7 @@ class ContinuousTimetable(_TrivialTimetable): description: str = "As frequently as possible, but only one run at a time." + # TODO (GH-52141): Find a way to keep this and one in Core in sync. 
active_runs_limit = 1 # Continuous DAGRuns should be constrained to one run at a time @property @@ -177,7 +177,7 @@ def __init__(self, assets: BaseAsset) -> None: @classmethod def deserialize(cls, data: dict[str, Any]) -> Timetable: - from airflow.serialization.serialized_objects import decode_asset_condition + from airflow.serialization.decoders import decode_asset_condition return cls(decode_asset_condition(data["asset_condition"])) @@ -186,7 +186,7 @@ def summary(self) -> str: return "Asset" def serialize(self) -> dict[str, Any]: - from airflow.serialization.serialized_objects import encode_asset_condition + from airflow.serialization.encoders import encode_asset_condition return {"asset_condition": encode_asset_condition(self.asset_condition)} diff --git a/airflow-core/src/airflow/timetables/trigger.py b/airflow-core/src/airflow/timetables/trigger.py index ed2f65cf24511..f0e04da042e6f 100644 --- a/airflow-core/src/airflow/timetables/trigger.py +++ b/airflow-core/src/airflow/timetables/trigger.py @@ -23,7 +23,7 @@ import time from typing import TYPE_CHECKING, Any -from airflow._shared.timezones.timezone import coerce_datetime, utcnow +from airflow._shared.timezones.timezone import coerce_datetime, parse_timezone, utcnow from airflow.timetables._cron import CronMixin from airflow.timetables._delta import DeltaMixin from airflow.timetables.base import DagRunInfo, DataInterval, Timetable @@ -36,34 +36,6 @@ from airflow.timetables.base import TimeRestriction -def _serialize_interval(interval: datetime.timedelta | relativedelta) -> float | dict: - from airflow.serialization.serialized_objects import encode_relativedelta - - if isinstance(interval, datetime.timedelta): - return interval.total_seconds() - return encode_relativedelta(interval) - - -def _deserialize_interval(value: int | dict) -> datetime.timedelta | relativedelta: - from airflow.serialization.serialized_objects import decode_relativedelta - - if isinstance(value, dict): - return decode_relativedelta(value) - return datetime.timedelta(seconds=value) - - -def _serialize_run_immediately(value: bool | datetime.timedelta) -> bool | float: - if isinstance(value, datetime.timedelta): - return value.total_seconds() - return value - - -def _deserialize_run_immediately(value: bool | float) -> bool | datetime.timedelta: - if isinstance(value, float): - return datetime.timedelta(seconds=value) - return value - - class _TriggerTimetable(Timetable): _interval: datetime.timedelta | relativedelta @@ -150,15 +122,19 @@ def __init__( @classmethod def deserialize(cls, data: dict[str, Any]) -> Timetable: + from airflow.serialization.decoders import decode_interval + return cls( - _deserialize_interval(data["delta"]), - interval=_deserialize_interval(data["interval"]), + decode_interval(data["delta"]), + interval=decode_interval(data["interval"]), ) def serialize(self) -> dict[str, Any]: + from airflow.serialization.encoders import encode_interval + return { - "delta": _serialize_interval(self._delta), - "interval": _serialize_interval(self._interval), + "delta": encode_interval(self._delta), + "interval": encode_interval(self._interval), } def _calc_first_run(self) -> DateTime: @@ -211,23 +187,23 @@ def __init__( @classmethod def deserialize(cls, data: dict[str, Any]) -> Timetable: - from airflow.serialization.serialized_objects import decode_timezone + from airflow.serialization.decoders import decode_interval, decode_run_immediately return cls( data["expression"], - timezone=decode_timezone(data["timezone"]), - 
interval=_deserialize_interval(data["interval"]), - run_immediately=_deserialize_run_immediately(data.get("run_immediately", False)), + timezone=parse_timezone(data["timezone"]), + interval=decode_interval(data["interval"]), + run_immediately=decode_run_immediately(data.get("run_immediately", False)), ) def serialize(self) -> dict[str, Any]: - from airflow.serialization.serialized_objects import encode_timezone + from airflow.serialization.encoders import encode_interval, encode_run_immediately, encode_timezone return { "expression": self._expression, "timezone": encode_timezone(self._timezone), - "interval": _serialize_interval(self._interval), - "run_immediately": _serialize_run_immediately(self._run_immediately), + "interval": encode_interval(self._interval), + "run_immediately": encode_run_immediately(self._run_immediately), } def _calc_first_run(self) -> DateTime: @@ -282,17 +258,17 @@ def __init__( @classmethod def deserialize(cls, data: dict[str, Any]) -> Timetable: - from airflow.serialization.serialized_objects import decode_timezone + from airflow.serialization.decoders import decode_interval, decode_run_immediately return cls( *data["expressions"], - timezone=decode_timezone(data["timezone"]), - interval=_deserialize_interval(data["interval"]), - run_immediately=_deserialize_run_immediately(data["run_immediately"]), + timezone=parse_timezone(data["timezone"]), + interval=decode_interval(data["interval"]), + run_immediately=decode_run_immediately(data["run_immediately"]), ) def serialize(self) -> dict[str, Any]: - from airflow.serialization.serialized_objects import encode_timezone + from airflow.serialization.encoders import encode_interval, encode_run_immediately, encode_timezone # All timetables share the same timezone, interval, and run_immediately # values, so we can just use the first to represent them. 
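The hunks above complete the split of timetable (de)serialization out of serialized_objects.py into the paired airflow.serialization.encoders and airflow.serialization.decoders modules. A rough usage sketch of how the pair is meant to round-trip (illustrative only, not part of the patch; it assumes the Core CronTriggerTimetable keeps its current constructor signature):

    # Illustrative sketch: round-trip a Core timetable through the new modules.
    from airflow.serialization.decoders import decode_timetable
    from airflow.serialization.encoders import coerce_to_core_timetable, encode_timetable
    from airflow.timetables.trigger import CronTriggerTimetable

    tt = CronTriggerTimetable("0 0 * * *", timezone="UTC")
    data = encode_timetable(tt)  # {Encoding.TYPE: "airflow.timetables.trigger.CronTriggerTimetable", Encoding.VAR: {...}}
    restored = decode_timetable(data)  # imports the class by path and calls .deserialize()
    assert restored == tt  # CronMixin.__eq__ compares expression and timezone

coerce_to_core_timetable() composes the same two steps to turn an SDK timetable (for example a DAG's dag.timetable) into its Core equivalent, which is how dag_command.py and the updated tests obtain timetable summary/description.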
@@ -300,8 +276,8 @@ def serialize(self) -> dict[str, Any]: return { "expressions": [t._expression for t in self._timetables], "timezone": encode_timezone(timetable._timezone), - "interval": _serialize_interval(timetable._interval), - "run_immediately": _serialize_run_immediately(timetable._run_immediately), + "interval": encode_interval(timetable._interval), + "run_immediately": encode_run_immediately(timetable._run_immediately), } @property diff --git a/airflow-core/tests/unit/cli/commands/test_dag_command.py b/airflow-core/tests/unit/cli/commands/test_dag_command.py index af44e55930b50..82c8d2b343a19 100644 --- a/airflow-core/tests/unit/cli/commands/test_dag_command.py +++ b/airflow-core/tests/unit/cli/commands/test_dag_command.py @@ -370,7 +370,6 @@ def test_dagbag_dag_col(self, session): dagbag = DBDagBag() dag_details = dag_command._get_dagbag_dag_details( dagbag.get_latest_version_of_dag("tutorial_dag", session=session), - session=session, ) assert sorted(dag_details) == sorted(dag_command.DAG_DETAIL_FIELDS) diff --git a/airflow-core/tests/unit/models/test_dag.py b/airflow-core/tests/unit/models/test_dag.py index e2677033ad3c1..7108a968ebbef 100644 --- a/airflow-core/tests/unit/models/test_dag.py +++ b/airflow-core/tests/unit/models/test_dag.py @@ -71,6 +71,7 @@ from airflow.sdk.definitions.callback import AsyncCallback from airflow.sdk.definitions.deadline import DeadlineAlert, DeadlineReference from airflow.sdk.definitions.param import Param +from airflow.serialization.encoders import coerce_to_core_timetable from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG from airflow.task.trigger_rule import TriggerRule from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable @@ -80,6 +81,7 @@ OnceTimetable, ) from airflow.utils.file import list_py_file_paths +from airflow.utils.module_loading import qualname from airflow.utils.session import create_session from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.types import DagRunTriggeredByType, DagRunType @@ -1011,9 +1013,9 @@ def test_schedule_dag_once(self, testing_dag_bundle): it is called, and not scheduled the second. 
""" dag_id = "test_schedule_dag_once" - dag = DAG(dag_id=dag_id, schedule="@once", start_date=TEST_DATE) - assert isinstance(dag.timetable, OnceTimetable) - dag.add_task(BaseOperator(task_id="faketastic", owner="Also fake", start_date=TEST_DATE)) + with DAG(dag_id=dag_id, schedule="@once", start_date=TEST_DATE) as dag: + BaseOperator(task_id="faketastic", owner="Also fake", start_date=TEST_DATE) + assert qualname(dag.timetable) == "airflow.sdk.definitions.timetables.simple.OnceTimetable" _create_dagrun( dag, @@ -1139,7 +1141,7 @@ def test_timetable_and_description_from_schedule_arg( ): dag = DAG("test_schedule_arg", schedule=schedule_arg, start_date=TEST_DATE) assert dag.timetable == expected_timetable - assert dag.timetable.description == interval_description + assert coerce_to_core_timetable(dag.timetable).description == interval_description def test_timetable_and_description_from_asset(self): uri = "test://asset" @@ -1147,7 +1149,7 @@ def test_timetable_and_description_from_asset(self): "test_schedule_interval_arg", schedule=[Asset(uri=uri, group="test-group")], start_date=TEST_DATE ) assert dag.timetable == AssetTriggeredTimetable(Asset(uri=uri, group="test-group")) - assert dag.timetable.description == "Triggered by assets" + assert coerce_to_core_timetable(dag.timetable).description == "Triggered by assets" @pytest.mark.parametrize( ("timetable", "expected_description"), @@ -1172,7 +1174,7 @@ def test_timetable_and_description_from_asset(self): def test_description_from_timetable(self, timetable, expected_description): dag = DAG("test_schedule_description", schedule=timetable, start_date=TEST_DATE) assert dag.timetable == timetable - assert dag.timetable.description == expected_description + assert coerce_to_core_timetable(dag.timetable).description == expected_description def test_create_dagrun_job_id_is_set(self, testing_dag_bundle): job_id = 42 @@ -1621,8 +1623,8 @@ class FailingTimetable(Timetable): def next_dagrun_info(self, last_automated_data_interval, restriction): raise RuntimeError("this fails") - def _get_registered_timetable(s): - if s == "unit.models.test_dag.FailingTimetable": + def _find_registered_custom_timetable(s): + if s == qualname(FailingTimetable): return FailingTimetable raise ValueError(f"unexpected class {s!r}") @@ -1632,9 +1634,15 @@ def _get_registered_timetable(s): schedule=FailingTimetable(), catchup=True, ) - with mock.patch( - "airflow.serialization.serialized_objects._get_registered_timetable", - _get_registered_timetable, + with ( + mock.patch( + "airflow.serialization.encoders.find_registered_custom_timetable", + _find_registered_custom_timetable, + ), + mock.patch( + "airflow.serialization.decoders.find_registered_custom_timetable", + _find_registered_custom_timetable, + ), ): scheduler_dag = create_scheduler_dag(dag) @@ -2728,9 +2736,15 @@ def _get_registered_timetable(s): start_date=DEFAULT_DATE, schedule=FailingAfterOneTimetable(), ) - with mock.patch( - "airflow.serialization.serialized_objects._get_registered_timetable", - _get_registered_timetable, + with ( + mock.patch( + "airflow.serialization.decoders.find_registered_custom_timetable", + _get_registered_timetable, + ), + mock.patch( + "airflow.serialization.encoders.find_registered_custom_timetable", + _get_registered_timetable, + ), ): scheduler_dag = create_scheduler_dag(dag) @@ -2786,7 +2800,8 @@ def test_get_next_data_interval( next_dagrun_data_interval_end=data_interval_end, ) - assert get_next_data_interval(dag.timetable, dag_model) == expected_data_interval + core_timetable = 
coerce_to_core_timetable(dag.timetable) + assert get_next_data_interval(core_timetable, dag_model) == expected_data_interval @pytest.mark.need_serialized_dag @@ -2847,7 +2862,6 @@ def test__time_restriction(dag_maker, dag_date, tasks_date, catchup, restrict): assert dag._time_restriction == restrict -@pytest.mark.need_serialized_dag def test_get_asset_triggered_next_run_info(dag_maker, clear_assets): asset1 = Asset(uri="test://asset1", name="test_asset1", group="test-group") asset2 = Asset(uri="test://asset2", group="test-group") @@ -3489,9 +3503,11 @@ def test_get_run_data_interval(): data_interval=(DEFAULT_DATE, DEFAULT_DATE), run_type=DagRunType.MANUAL, ) - assert get_run_data_interval(dag.timetable, dr) == DataInterval(start=DEFAULT_DATE, end=DEFAULT_DATE) + timetable = coerce_to_core_timetable(dag.timetable) + assert get_run_data_interval(timetable, dr) == DataInterval(start=DEFAULT_DATE, end=DEFAULT_DATE) +@pytest.mark.need_serialized_dag def test_get_run_data_interval_pre_aip_39(): with DAG( "dag", @@ -3509,4 +3525,5 @@ def test_get_run_data_interval_pre_aip_39(): ) ds_start = current_ts.replace(hour=0, minute=0, second=0, microsecond=0) - timedelta(days=1) ds_end = current_ts.replace(hour=0, minute=0, second=0, microsecond=0) - assert get_run_data_interval(dag.timetable, dr) == DataInterval(start=ds_start, end=ds_end) + timetable = coerce_to_core_timetable(dag.timetable) + assert get_run_data_interval(timetable, dr) == DataInterval(start=ds_start, end=ds_end) diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py index fb261921aef4a..bdcf97b0690f8 100644 --- a/airflow-core/tests/unit/serialization/test_dag_serialization.py +++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py @@ -701,6 +701,8 @@ def validate_deserialized_dag(self, serialized_dag: SerializedDAG, dag: DAG): Verify that all example DAGs work with DAG Serialization by checking fields between Serialized Dags & non-Serialized Dags """ + from airflow.serialization.encoders import _serializer + exclusion_list = { # Doesn't implement __eq__ properly. Check manually. 
"timetable", @@ -733,9 +735,10 @@ def validate_deserialized_dag(self, serialized_dag: SerializedDAG, dag: DAG): f"{dag.dag_id}.default_args[{k}] does not match" ) - assert serialized_dag.timetable.summary == dag.timetable.summary - assert serialized_dag.timetable.serialize() == dag.timetable.serialize() - assert serialized_dag.timezone == dag.timezone + if (tt_type := type(dag.timetable)) in _serializer.BUILTIN_TIMETABLES: + assert _serializer.BUILTIN_TIMETABLES[tt_type] == qualname(serialized_dag.timetable) + else: + assert qualname(dag.timetable) == qualname(serialized_dag.timetable) for task_id in dag.task_ids: self.validate_deserialized_task(serialized_dag.get_task(task_id), dag.get_task(task_id)) @@ -1051,7 +1054,7 @@ def test_deserialization_timetable_summary( } SerializedDAG.validate_schema(serialized) dag = SerializedDAG.from_dict(serialized) - assert dag.timetable_summary == expected_timetable_summary + assert dag.timetable.summary == expected_timetable_summary def test_deserialization_timetable_unregistered(self): serialized = { diff --git a/airflow-core/tests/unit/timetables/test_assets_timetable.py b/airflow-core/tests/unit/timetables/test_assets_timetable.py index 8386f26338ff5..3026cace6327b 100644 --- a/airflow-core/tests/unit/timetables/test_assets_timetable.py +++ b/airflow-core/tests/unit/timetables/test_assets_timetable.py @@ -128,7 +128,7 @@ def test_serialization(asset_timetable: AssetOrTimeSchedule, monkeypatch: Any) - :param monkeypatch: The monkeypatch fixture from pytest. """ monkeypatch.setattr( - "airflow.serialization.serialized_objects.encode_timetable", lambda x: "mock_serialized_timetable" + "airflow.serialization.encoders.encode_timetable", lambda x: "mock_serialized_timetable" ) serialized = asset_timetable.serialize() assert serialized == { @@ -154,9 +154,7 @@ def test_deserialization(monkeypatch: Any) -> None: :param monkeypatch: The monkeypatch fixture from pytest. 
""" - monkeypatch.setattr( - "airflow.serialization.serialized_objects.decode_timetable", lambda x: MockTimetable() - ) + monkeypatch.setattr("airflow.serialization.decoders.decode_timetable", lambda x: MockTimetable()) mock_serialized_data = { "timetable": "mock_serialized_timetable", "asset_condition": { diff --git a/airflow-core/tests/unit/timetables/test_events_timetable.py b/airflow-core/tests/unit/timetables/test_events_timetable.py index 9d42d5ddc6c0d..f4df62e39e164 100644 --- a/airflow-core/tests/unit/timetables/test_events_timetable.py +++ b/airflow-core/tests/unit/timetables/test_events_timetable.py @@ -212,8 +212,8 @@ def test_timetable_after_serialization_is_the_same(): def test_timetable_without_description_after_serialization_is_the_same(): timetable = EventsTimetable(event_dates=EVENT_DATES, presorted=True) - summary = f"{timetable.summary}" - description = f"{timetable.description}" + summary = timetable.summary + description = timetable.description assert timetable.event_dates == EVENT_DATES deserialized: EventsTimetable = timetable.deserialize(timetable.serialize()) diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index 27826b6604a83..cd83af125a1c9 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -2397,6 +2397,13 @@ def execute(self, context): from airflow.sdk.execution_time.comms import BundleInfo, StartupDetails from airflow.timetables.base import TimeRestriction + from tests_common.test_utils.version_compat import AIRFLOW_V_3_2_PLUS + + if AIRFLOW_V_3_2_PLUS: + from airflow.serialization.encoders import coerce_to_core_timetable + else: + coerce_to_core_timetable = lambda t: t + timezone = _import_timezone() def _create_task_instance( @@ -2445,20 +2452,20 @@ def _create_task_instance( data_interval_start = None data_interval_end = None - if task.dag.timetable: - if run_type == DagRunType.MANUAL: - if logical_date is not None: - data_interval_start, data_interval_end = task.dag.timetable.infer_manual_data_interval( - run_after=logical_date, - ) - else: - drinfo = task.dag.timetable.next_dagrun_info( - last_automated_data_interval=None, - restriction=TimeRestriction(earliest=None, latest=None, catchup=False), + timetable = coerce_to_core_timetable(task.dag.timetable) + if run_type == DagRunType.MANUAL: + if logical_date is not None: + data_interval_start, data_interval_end = timetable.infer_manual_data_interval( + run_after=logical_date, ) - if drinfo: - data_interval = drinfo.data_interval - data_interval_start, data_interval_end = data_interval.start, data_interval.end + else: + drinfo = timetable.next_dagrun_info( + last_automated_data_interval=None, + restriction=TimeRestriction(earliest=None, latest=None, catchup=False), + ) + if drinfo: + data_interval = drinfo.data_interval + data_interval_start, data_interval_end = data_interval.start, data_interval.end dag_id = task.dag.dag_id task_retries = task.retries or 0 diff --git a/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py b/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py index 62604042e7437..3a805f7915616 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py +++ b/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py @@ -662,18 +662,29 @@ class DagInfo(InfoJsonEncodable): renames = {"_dag_id": "dag_id"} @classmethod - def timetable_summary(cls, dag: DAG) -> str | None: + def timetable_summary(cls, 
dag: DAG | SerializedDAG) -> str | None: """Extract summary from timetable if missing a ``timetable_summary`` property.""" - if getattr(dag, "timetable_summary", None): - return dag.timetable_summary - if getattr(dag, "timetable", None): - return dag.timetable.summary + if summary := getattr(dag, "timetable_summary", None): + return summary + if (timetable := getattr(dag, "timetable", None)) is None: + return None + if summary := getattr(timetable, "summary", None): + return summary + with suppress(ImportError): + from airflow.serialization.encoders import coerce_to_core_timetable + + return coerce_to_core_timetable(timetable).summary return None @classmethod - def serialize_timetable(cls, dag: DAG) -> dict[str, Any]: + def serialize_timetable(cls, dag: DAG | SerializedDAG) -> dict[str, Any]: # This is enough for Airflow 2.10+ and has all the information needed - serialized = dag.timetable.serialize() or {} + try: + serialized = dag.timetable.serialize() or {} # type: ignore[union-attr] + except AttributeError: + from airflow.serialization.encoders import encode_timetable + + serialized = encode_timetable(dag.timetable)["__var"] # In Airflow 2.9 when using Dataset scheduling we do not receive datasets in serialized timetable # Also for DatasetOrTimeSchedule, we only receive timetable without dataset_condition @@ -914,7 +925,7 @@ def get_airflow_debug_facet() -> dict[str, AirflowDebugRunFacet]: def get_airflow_run_facet( dag_run: DagRun, - dag: DAG, + dag: DAG | SerializedDAG, task_instance: TaskInstance, task: BaseOperator, task_uuid: str, diff --git a/providers/standard/src/airflow/providers/standard/operators/latest_only.py b/providers/standard/src/airflow/providers/standard/operators/latest_only.py index 9cf75573dcfc2..087a8607d45c5 100644 --- a/providers/standard/src/airflow/providers/standard/operators/latest_only.py +++ b/providers/standard/src/airflow/providers/standard/operators/latest_only.py @@ -26,7 +26,7 @@ import pendulum from airflow.providers.standard.operators.branch import BaseBranchOperator -from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS +from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS from airflow.utils.types import DagRunType if TYPE_CHECKING: @@ -35,6 +35,17 @@ from airflow.models import DagRun from airflow.providers.common.compat.sdk import Context +if AIRFLOW_V_3_2_PLUS: + + def _get_dag_timetable(dag): + from airflow.serialization.encoders import coerce_to_core_timetable + + return coerce_to_core_timetable(dag.timetable) +else: + + def _get_dag_timetable(dag): + return dag.timetable + class LatestOnlyOperator(BaseBranchOperator): """ @@ -104,15 +115,13 @@ def _get_compare_dates(self, dag_run: DagRun) -> tuple[DateTime, DateTime] | Non else: end = dagrun_date - current_interval = DataInterval( - start=start, - end=end, - ) - + timetable = _get_dag_timetable(self.dag) + current_interval = DataInterval(start=start, end=end) time_restriction = TimeRestriction( earliest=None, latest=current_interval.end - timedelta(microseconds=1), catchup=True ) - if prev_info := self.dag.timetable.next_dagrun_info( + + if prev_info := timetable.next_dagrun_info( last_automated_data_interval=current_interval, restriction=time_restriction, ): @@ -121,7 +130,7 @@ def _get_compare_dates(self, dag_run: DagRun) -> tuple[DateTime, DateTime] | Non left = current_interval.start time_restriction = TimeRestriction(earliest=current_interval.end, latest=None, catchup=True) - next_info = self.dag.timetable.next_dagrun_info( + 
next_info = timetable.next_dagrun_info( last_automated_data_interval=current_interval, restriction=time_restriction, ) diff --git a/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py b/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py index c95c61b6cd369..25e07bb0c3661 100644 --- a/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py +++ b/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py @@ -1737,8 +1737,9 @@ def run_tasks( for dag in dag_bag.dags.values(): data_interval = DataInterval(coerce_datetime(logical_date), coerce_datetime(logical_date)) if AIRFLOW_V_3_0_PLUS: - runs[dag.dag_id] = dagrun = create_scheduler_dag(dag).create_dagrun( - run_id=dag.timetable.generate_run_id( + scheduler_dag = create_scheduler_dag(dag) + runs[dag.dag_id] = dagrun = scheduler_dag.create_dagrun( + run_id=scheduler_dag.timetable.generate_run_id( run_type=DagRunType.MANUAL, run_after=logical_date, data_interval=data_interval, @@ -1754,7 +1755,7 @@ def run_tasks( ) else: runs[dag.dag_id] = dagrun = dag.create_dagrun( # type: ignore[attr-defined,call-arg] - run_id=dag.timetable.generate_run_id( # type: ignore[call-arg] + run_id=dag.timetable.generate_run_id( # type: ignore[attr-defined,call-arg,union-attr] run_type=DagRunType.MANUAL, logical_date=logical_date, data_interval=data_interval, diff --git a/providers/standard/tests/unit/standard/utils/test_sensor_helper.py b/providers/standard/tests/unit/standard/utils/test_sensor_helper.py index c94f05b56bdf9..d6c5cffd134f8 100644 --- a/providers/standard/tests/unit/standard/utils/test_sensor_helper.py +++ b/providers/standard/tests/unit/standard/utils/test_sensor_helper.py @@ -25,7 +25,7 @@ import pendulum import pytest -from airflow.models import DAG, DagBag, TaskInstance +from airflow.models import DagBag, TaskInstance from airflow.providers.standard.operators.empty import EmptyOperator from airflow.providers.standard.utils.sensor_helper import ( _count_stmt, @@ -36,7 +36,6 @@ from airflow.utils.types import DagRunType from tests_common.test_utils import db -from tests_common.test_utils.dag import create_scheduler_dag from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS try: @@ -51,6 +50,7 @@ if TYPE_CHECKING: from sqlalchemy.orm.session import Session + from airflow.serialization.serialized_objects import SerializedDAG TI = TaskInstance @@ -84,7 +84,7 @@ def _clean_db(): @staticmethod def create_dag_run( - dag: DAG, + dag: SerializedDAG, *, task_states: Mapping[str, TaskInstanceState] | None = None, execution_date: datetime.datetime | None = None, @@ -94,7 +94,7 @@ def create_dag_run( execution_date = pendulum.instance(execution_date or now) run_type = DagRunType.MANUAL data_interval = dag.timetable.infer_manual_data_interval(run_after=execution_date) - dag_run = create_scheduler_dag(dag).create_dagrun( + dag_run = dag.create_dagrun( run_id=dag.timetable.generate_run_id( run_type=run_type, run_after=execution_date, @@ -162,6 +162,7 @@ def create_dag_run( }, ], # these can be any TaskInstanceState ) + @pytest.mark.need_serialized_dag def test_count_stmt(self, dttm_to_task_state, dag_maker, session): with dag_maker(dag_id=self.DAG_ID, session=session) as dag: for task_id in self.TASK_ID_LIST: @@ -193,6 +194,7 @@ def test_count_stmt(self, dttm_to_task_state, dag_maker, session): allowed_state_count = len(allowed_task_instance_states) assert count == allowed_state_count + @pytest.mark.need_serialized_dag def 
test_get_external_task_group_task_ids(self, dag_maker, session): with dag_maker(dag_id=self.DAG_ID) as dag: with TaskGroup(group_id=self.TASK_GROUP_ID): @@ -225,6 +227,7 @@ def test_get_external_task_group_task_ids(self, dag_maker, session): TaskInstanceState.FAILED.value, ], ) + @pytest.mark.need_serialized_dag def test_get_count_with_different_states(self, state, dag_maker, session): with dag_maker(dag_id=self.DAG_ID) as dag: EmptyOperator(task_id=self.TASK_ID) @@ -258,6 +261,7 @@ def test_get_count_with_different_states(self, state, dag_maker, session): }, ], # these can be any TaskInstanceState ) + @pytest.mark.need_serialized_dag def test_get_count_with_one_task(self, task_states, dag_maker, session): with dag_maker(dag_id=self.DAG_ID) as dag: EmptyOperator(task_id=self.TASK_ID) @@ -313,6 +317,7 @@ def test_get_count_with_one_task(self, task_states, dag_maker, session): }, ], # these can be any TaskInstanceState ) + @pytest.mark.need_serialized_dag def test_get_count_with_multiple_tasks(self, dttm_to_task_state, dag_maker, session): with dag_maker(dag_id=self.DAG_ID) as dag: for task_id in self.TASK_ID_LIST: @@ -376,6 +381,7 @@ def test_get_count_with_multiple_tasks(self, dttm_to_task_state, dag_maker, sess }, ], # these can be any TaskInstanceState ) + @pytest.mark.need_serialized_dag def test_get_count_with_task_group(self, dttm_to_subtask_state, dag_maker, session): with dag_maker(dag_id=self.DAG_ID, session=session) as dag: with TaskGroup(group_id=self.TASK_GROUP_ID): diff --git a/scripts/in_container/run_schema_defaults_check.py b/scripts/in_container/run_schema_defaults_check.py index e9744134360d5..eee5b8daba32d 100755 --- a/scripts/in_container/run_schema_defaults_check.py +++ b/scripts/in_container/run_schema_defaults_check.py @@ -28,6 +28,7 @@ import json import sys +import traceback from datetime import timedelta from pathlib import Path from typing import Any @@ -89,9 +90,11 @@ def get_server_side_operator_defaults() -> dict[str, Any]: except ImportError as e: print(f"Error importing SerializedBaseOperator: {e}") + traceback.print_exc() sys.exit(1) except Exception as e: print(f"Error getting server-side defaults: {e}") + traceback.print_exc() sys.exit(1) @@ -121,9 +124,11 @@ def get_server_side_dag_defaults() -> dict[str, Any]: except ImportError as e: print(f"Error importing SerializedDAG: {e}") + traceback.print_exc() sys.exit(1) except Exception as e: print(f"Error getting server-side DAG defaults: {e}") + traceback.print_exc() sys.exit(1) diff --git a/task-sdk/docs/api.rst b/task-sdk/docs/api.rst index c689a21c7e07c..5e51c77aceb0e 100644 --- a/task-sdk/docs/api.rst +++ b/task-sdk/docs/api.rst @@ -128,6 +128,22 @@ Assets .. autoapiclass:: airflow.sdk.Metadata +Timetables +---------- +.. autoapiclass:: airflow.sdk.AssetOrTimeSchedule + +.. autoapiclass:: airflow.sdk.CronDataIntervalTimetable + +.. autoapiclass:: airflow.sdk.CronTriggerTimetable + +.. autoapiclass:: airflow.sdk.DeltaDataIntervalTimetable + +.. autoapiclass:: airflow.sdk.DeltaTriggerTimetable + +.. autoapiclass:: airflow.sdk.EventsTimetable + +.. autoapiclass:: airflow.sdk.MultipleCronTriggerTimetable + I/O Helpers ----------- .. 
autoapiclass:: airflow.sdk.ObjectStoragePath diff --git a/task-sdk/src/airflow/sdk/__init__.py b/task-sdk/src/airflow/sdk/__init__.py index 9c411354c9518..49b09727c0a8e 100644 --- a/task-sdk/src/airflow/sdk/__init__.py +++ b/task-sdk/src/airflow/sdk/__init__.py @@ -24,6 +24,7 @@ "AssetAlias", "AssetAll", "AssetAny", + "AssetOrTimeSchedule", "AssetWatcher", "BaseHook", "BaseNotifier", @@ -32,11 +33,17 @@ "BaseSensorOperator", "Connection", "Context", + "CronDataIntervalTimetable", + "CronTriggerTimetable", "DAG", "DagRunState", + "DeltaDataIntervalTimetable", + "DeltaTriggerTimetable", "EdgeModifier", + "EventsTimetable", "Label", "Metadata", + "MultipleCronTriggerTimetable", "ObjectStoragePath", "Param", "PokeReturnValue", @@ -81,6 +88,17 @@ from airflow.sdk.definitions.param import Param from airflow.sdk.definitions.taskgroup import TaskGroup from airflow.sdk.definitions.template import literal + from airflow.sdk.definitions.timetables.assets import AssetOrTimeSchedule + from airflow.sdk.definitions.timetables.events import EventsTimetable + from airflow.sdk.definitions.timetables.interval import ( + CronDataIntervalTimetable, + DeltaDataIntervalTimetable, + ) + from airflow.sdk.definitions.timetables.trigger import ( + CronTriggerTimetable, + DeltaTriggerTimetable, + MultipleCronTriggerTimetable, + ) from airflow.sdk.definitions.variable import Variable from airflow.sdk.definitions.xcom_arg import XComArg from airflow.sdk.io.path import ObjectStoragePath @@ -90,6 +108,7 @@ "AssetAlias": ".definitions.asset", "AssetAll": ".definitions.asset", "AssetAny": ".definitions.asset", + "AssetOrTimeSchedule": ".definitions.timetables.assets", "AssetWatcher": ".definitions.asset", "BaseHook": ".bases.hook", "BaseNotifier": ".bases.notifier", @@ -98,11 +117,17 @@ "BaseSensorOperator": ".bases.sensor", "Connection": ".definitions.connection", "Context": ".definitions.context", + "CronDataIntervalTimetable": ".definitions.timetables.interval", + "CronTriggerTimetable": ".definitions.timetables.trigger", "DAG": ".definitions.dag", "DagRunState": ".api.datamodels._generated", + "DeltaDataIntervalTimetable": ".definitions.timetables.interval", + "DeltaTriggerTimetable": ".definitions.timetables.trigger", "EdgeModifier": ".definitions.edges", + "EventsTimetable": ".definitions.timetables.events", "Label": ".definitions.edges", "Metadata": ".definitions.asset.metadata", + "MultipleCronTriggerTimetable": ".definitions.timetables.trigger", "ObjectStoragePath": ".io.path", "Param": ".definitions.param", "PokeReturnValue": ".bases.sensor", diff --git a/task-sdk/src/airflow/sdk/__init__.pyi b/task-sdk/src/airflow/sdk/__init__.pyi index bf7f63b209a6d..ca25ae0fdaaf3 100644 --- a/task-sdk/src/airflow/sdk/__init__.pyi +++ b/task-sdk/src/airflow/sdk/__init__.pyi @@ -56,6 +56,17 @@ from airflow.sdk.definitions.edges import EdgeModifier as EdgeModifier, Label as from airflow.sdk.definitions.param import Param as Param from airflow.sdk.definitions.taskgroup import TaskGroup as TaskGroup from airflow.sdk.definitions.template import literal as literal +from airflow.sdk.definitions.timetables.assets import AssetOrTimeSchedule +from airflow.sdk.definitions.timetables.events import EventsTimetable +from airflow.sdk.definitions.timetables.interval import ( + CronDataIntervalTimetable, + DeltaDataIntervalTimetable, +) +from airflow.sdk.definitions.timetables.trigger import ( + CronTriggerTimetable, + DeltaTriggerTimetable, + MultipleCronTriggerTimetable, +) from airflow.sdk.definitions.variable import Variable as Variable 
from airflow.sdk.definitions.xcom_arg import XComArg as XComArg from airflow.sdk.execution_time.cache import SecretCache as SecretCache @@ -67,6 +78,7 @@ __all__ = [ "AssetAlias", "AssetAll", "AssetAny", + "AssetOrTimeSchedule", "AssetWatcher", "BaseHook", "BaseNotifier", @@ -75,11 +87,17 @@ __all__ = [ "BaseSensorOperator", "Connection", "Context", + "CronDataIntervalTimetable", + "CronTriggerTimetable", "DAG", "DagRunState", + "DeltaDataIntervalTimetable", + "DeltaTriggerTimetable", "EdgeModifier", + "EventsTimetable", "Label", "Metadata", + "MultipleCronTriggerTimetable", "ObjectStoragePath", "Param", "PokeReturnValue", @@ -88,8 +106,8 @@ __all__ = [ "TaskInstanceState", "TriggerRule", "Variable", - "XComArg", "WeightRule", + "XComArg", "asset", "chain", "chain_linear", diff --git a/task-sdk/src/airflow/sdk/bases/timetable.py b/task-sdk/src/airflow/sdk/bases/timetable.py new file mode 100644 index 0000000000000..20754ed874766 --- /dev/null +++ b/task-sdk/src/airflow/sdk/bases/timetable.py @@ -0,0 +1,102 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from collections.abc import Iterator +from typing import TYPE_CHECKING, Any + +from airflow.sdk.definitions.asset import AssetUniqueKey, BaseAsset + +if TYPE_CHECKING: + from sqlalchemy.orm import Session + + from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetRef + from airflow.serialization.dag_dependency import DagDependency + + +class NullAsset(BaseAsset): + """ + Sentinel type that represents "no assets". + + This is only implemented to make typing easier in timetables, and not + expected to be used anywhere else. + + :meta private: + """ + + def __bool__(self) -> bool: + return False + + def __or__(self, other: BaseAsset) -> BaseAsset: + return NotImplemented + + def __and__(self, other: BaseAsset) -> BaseAsset: + return NotImplemented + + def as_expression(self) -> Any: + return None + + def evaluate(self, statuses: dict[AssetUniqueKey, bool], *, session: Session | None = None) -> bool: + return False + + def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]: + return iter(()) + + def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: + return iter(()) + + def iter_asset_refs(self) -> Iterator[AssetRef]: + return iter(()) + + def iter_dag_dependencies(self, source, target) -> Iterator[DagDependency]: + return iter(()) + + +class BaseTimetable: + """Base class inherited by all user-facing timetables.""" + + can_be_scheduled: bool = True + """ + Whether this timetable can actually schedule runs in an automated manner. + + This defaults to and should generally be *True* (including non periodic + execution types like *@once* and data triggered tables), but + ``NullTimetable`` sets this to *False*. 
+ """ + + active_runs_limit: int | None = None + """ + Maximum active runs that can be active at one time for a DAG. + + This is called during DAG initialization, and the return value is used as + the DAG's default ``max_active_runs`` if not set on the DAG explicitly. This + should generally return *None* (no limit), but some timetables may limit + parallelism, such as ``ContinuousTimetable``. + """ + + asset_condition: BaseAsset = NullAsset() + + def validate(self) -> None: + """ + Validate the timetable is correctly specified. + + Override this method to provide run-time validation raised when a DAG + is put into a dagbag. The default implementation does nothing. + + :raises: :class:`~airflow.sdk.exceptions.AirflowTimetableInvalid` on validation failure. + """ diff --git a/task-sdk/src/airflow/sdk/definitions/asset/__init__.py b/task-sdk/src/airflow/sdk/definitions/asset/__init__.py index 485584bd8ddb1..26717ea94b7d1 100644 --- a/task-sdk/src/airflow/sdk/definitions/asset/__init__.py +++ b/task-sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -39,7 +39,7 @@ from airflow.models.asset import AssetModel from airflow.sdk.io.path import ObjectStoragePath - from airflow.serialization.serialized_objects import SerializedAssetWatcher + from airflow.serialization.definitions.assets import SerializedAssetWatcher from airflow.triggers.base import BaseEventTrigger AttrsInstance = attrs.AttrsInstance diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py index 3abd205fc3b65..1c363ea9b1fa7 100644 --- a/task-sdk/src/airflow/sdk/definitions/dag.py +++ b/task-sdk/src/airflow/sdk/definitions/dag.py @@ -30,7 +30,7 @@ from collections.abc import Callable, Collection, Iterable, MutableSet from datetime import datetime, timedelta from inspect import signature -from typing import TYPE_CHECKING, Any, ClassVar, TypeGuard, cast, overload +from typing import TYPE_CHECKING, Any, ClassVar, TypeGuard, Union, cast, overload from urllib.parse import urlsplit from uuid import UUID @@ -41,12 +41,15 @@ from airflow import settings from airflow.sdk import TaskInstanceState, TriggerRule from airflow.sdk.bases.operator import BaseOperator +from airflow.sdk.bases.timetable import BaseTimetable from airflow.sdk.definitions._internal.node import validate_key from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet, is_arg_set from airflow.sdk.definitions.asset import AssetAll, BaseAsset from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.deadline import DeadlineAlert from airflow.sdk.definitions.param import DagParam, ParamsDict +from airflow.sdk.definitions.timetables.assets import AssetTriggeredTimetable +from airflow.sdk.definitions.timetables.simple import ContinuousTimetable, NullTimetable, OnceTimetable from airflow.sdk.exceptions import ( AirflowDagCycleException, DuplicateTaskIdFound, @@ -55,20 +58,13 @@ RemovedInAirflow4Warning, TaskNotFound, ) -from airflow.timetables.base import Timetable -from airflow.timetables.simple import ( - AssetTriggeredTimetable, - ContinuousTimetable, - NullTimetable, - OnceTimetable, -) if TYPE_CHECKING: from re import Pattern from typing import TypeAlias from pendulum.tz.timezone import FixedTimezone, Timezone - from typing_extensions import Self + from typing_extensions import Self, TypeIs from airflow.models.taskinstance import TaskInstance as SchedulerTaskInstance from airflow.sdk.definitions.decorators import TaskDecoratorCollection @@ -76,6 +72,7 @@ from 
airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sdk.definitions.taskgroup import TaskGroup from airflow.sdk.execution_time.supervisor import TaskRunResult + from airflow.timetables.base import DataInterval, Timetable as CoreTimetable Operator: TypeAlias = BaseOperator | MappedOperator @@ -101,7 +98,7 @@ DagStateChangeCallback = Callable[[Context], None] ScheduleInterval = None | str | timedelta | relativedelta -ScheduleArg: TypeAlias = ScheduleInterval | Timetable | BaseAsset | Collection[BaseAsset] +ScheduleArg = Union[ScheduleInterval, BaseTimetable, "CoreTimetable", BaseAsset, Collection[BaseAsset]] _DAG_HASH_ATTRS = frozenset( @@ -120,11 +117,22 @@ ) -def _create_timetable(interval: ScheduleInterval, timezone: Timezone | FixedTimezone) -> Timetable: +def _is_core_timetable(schedule: ScheduleArg) -> TypeIs[CoreTimetable]: + try: + from airflow.timetables.base import Timetable + except ImportError: + return False + return isinstance(schedule, Timetable) + + +def _create_timetable(interval: ScheduleInterval, timezone: Timezone | FixedTimezone) -> BaseTimetable: """Create a Timetable instance from a plain ``schedule`` value.""" from airflow.sdk.configuration import conf as airflow_conf - from airflow.timetables.interval import CronDataIntervalTimetable, DeltaDataIntervalTimetable - from airflow.timetables.trigger import CronTriggerTimetable, DeltaTriggerTimetable + from airflow.sdk.definitions.timetables.interval import ( + CronDataIntervalTimetable, + DeltaDataIntervalTimetable, + ) + from airflow.sdk.definitions.timetables.trigger import CronTriggerTimetable, DeltaTriggerTimetable if interval is None: return NullTimetable() @@ -132,7 +140,7 @@ def _create_timetable(interval: ScheduleInterval, timezone: Timezone | FixedTime return OnceTimetable() if interval == "@continuous": return ContinuousTimetable() - if isinstance(interval, (timedelta, relativedelta)): + if isinstance(interval, timedelta | relativedelta): if airflow_conf.getboolean("scheduler", "create_cron_data_intervals"): return DeltaDataIntervalTimetable(interval) return DeltaTriggerTimetable(interval) @@ -190,7 +198,7 @@ def _convert_access_control(access_control): updated_access_control = {} for role, perms in access_control.items(): updated_access_control[role] = updated_access_control.get(role, {}) - if isinstance(perms, (set, list)): + if isinstance(perms, set | list): # Support for old-style access_control where only the actions are specified updated_access_control[role]["DAGs"] = set(perms) else: @@ -422,7 +430,7 @@ def __rich_repr__(self): end_date: datetime | None = None timezone: FixedTimezone | Timezone = attrs.field(init=False) schedule: ScheduleArg = attrs.field(default=None, on_setattr=attrs.setters.frozen) - timetable: Timetable = attrs.field(init=False) + timetable: BaseTimetable | CoreTimetable = attrs.field(init=False) template_searchpath: str | Iterable[str] | None = attrs.field( default=None, converter=_convert_str_to_tuple ) @@ -592,11 +600,13 @@ def _validate_max_active_runs(self, _, max_active_runs): ) @timetable.default - def _default_timetable(instance: DAG): + def _default_timetable(instance: DAG) -> BaseTimetable | CoreTimetable: schedule = instance.schedule # TODO: Once # delattr(self, "schedule") - if isinstance(schedule, Timetable): + if _is_core_timetable(schedule): + return schedule + if isinstance(schedule, BaseTimetable): return schedule if isinstance(schedule, BaseAsset): return AssetTriggeredTimetable(schedule) @@ -757,10 +767,6 @@ def owner(self) -> str: """ return 
", ".join({t.owner for t in self.tasks}) - @property - def timetable_summary(self) -> str: - return self.timetable.summary - def resolve_template_files(self): for t in self.tasks: # TODO: TaskSDK: move this on to BaseOperator and remove the check? @@ -873,7 +879,7 @@ def partial_subset( from airflow.sdk.definitions.mappedoperator import MappedOperator def is_task(obj) -> TypeGuard[Operator]: - return isinstance(obj, (BaseOperator, MappedOperator)) + return isinstance(obj, BaseOperator | MappedOperator) # deep-copying self.task_dict and self.task_group takes a long time, and we don't want all # the tasks anyway, so we copy the tasks manually later @@ -1179,6 +1185,7 @@ def test( from airflow import settings from airflow.models.dagrun import DagRun, get_or_create_dagrun from airflow.sdk import DagRunState, timezone + from airflow.serialization.encoders import coerce_to_core_timetable from airflow.serialization.serialized_objects import SerializedDAG from airflow.utils.types import DagRunTriggeredByType, DagRunType @@ -1223,9 +1230,11 @@ def test( log.debug("Getting dagrun for dag %s", self.dag_id) logical_date = timezone.coerce_datetime(logical_date) run_after = timezone.coerce_datetime(run_after) or timezone.coerce_datetime(timezone.utcnow()) - data_interval = ( - self.timetable.infer_manual_data_interval(run_after=logical_date) if logical_date else None - ) + if logical_date is None: + data_interval: DataInterval | None = None + else: + timetable = coerce_to_core_timetable(self.timetable) + data_interval = timetable.infer_manual_data_interval(run_after=logical_date) from airflow.models.dag_version import DagVersion version = DagVersion.get_version(self.dag_id) diff --git a/task-sdk/src/airflow/sdk/definitions/timetables/__init__.py b/task-sdk/src/airflow/sdk/definitions/timetables/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/task-sdk/src/airflow/sdk/definitions/timetables/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/task-sdk/src/airflow/sdk/definitions/timetables/_cron.py b/task-sdk/src/airflow/sdk/definitions/timetables/_cron.py new file mode 100644 index 0000000000000..e0fa7c26966a0 --- /dev/null +++ b/task-sdk/src/airflow/sdk/definitions/timetables/_cron.py @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING + +import attrs +from croniter import CroniterBadCronError, CroniterBadDateError, croniter + +from airflow.sdk.exceptions import AirflowTimetableInvalid + +if TYPE_CHECKING: + from pendulum.tz.timezone import FixedTimezone, Timezone + + +@attrs.define +class CronMixin: + """Mixin to provide interface to work with croniter.""" + + expression: str + timezone: str | Timezone | FixedTimezone + + def validate(self) -> None: + try: + croniter(self.expression) + except (CroniterBadCronError, CroniterBadDateError) as e: + raise AirflowTimetableInvalid(str(e)) diff --git a/task-sdk/src/airflow/sdk/definitions/timetables/_delta.py b/task-sdk/src/airflow/sdk/definitions/timetables/_delta.py new file mode 100644 index 0000000000000..0c12756c57f5b --- /dev/null +++ b/task-sdk/src/airflow/sdk/definitions/timetables/_delta.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import datetime +from typing import TYPE_CHECKING + +import attrs + +from airflow.sdk.exceptions import AirflowTimetableInvalid + +if TYPE_CHECKING: + from dateutil.relativedelta import relativedelta + + +@attrs.define +class DeltaMixin: + """Mixin to provide interface to work with timedelta and relativedelta.""" + + delta: datetime.timedelta | relativedelta + + def validate(self) -> None: + now = datetime.datetime.now() + if (now + self.delta) <= now: + raise AirflowTimetableInvalid(f"schedule interval must be positive, not {self.delta!r}") diff --git a/task-sdk/src/airflow/sdk/definitions/timetables/assets.py b/task-sdk/src/airflow/sdk/definitions/timetables/assets.py new file mode 100644 index 0000000000000..810007e8123a9 --- /dev/null +++ b/task-sdk/src/airflow/sdk/definitions/timetables/assets.py @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import typing + +import attrs + +from airflow.sdk.bases.timetable import BaseTimetable +from airflow.sdk.definitions.asset import AssetAll, BaseAsset + +if typing.TYPE_CHECKING: + from collections.abc import Collection + + from airflow.sdk import Asset + + +@attrs.define +class AssetTriggeredTimetable(BaseTimetable): + """ + Timetable that never schedules anything. + + This should not be directly used anywhere, but only set if a DAG is triggered by assets. + + :meta private: + """ + + asset_condition: BaseAsset = attrs.field(alias="assets") + + +def _coerce_assets(o: Collection[Asset] | BaseAsset) -> BaseAsset: + if isinstance(o, BaseAsset): + return o + return AssetAll(*o) + + +@attrs.define(kw_only=True) +class AssetOrTimeSchedule(AssetTriggeredTimetable): + """ + Combine time-based scheduling with event-based scheduling. + + :param assets: An asset or list of assets, in the same format as + ``DAG(schedule=...)`` when using event-driven scheduling. This is used + to evaluate event-based scheduling. + :param timetable: A timetable instance to evaluate time-based scheduling. + """ + + asset_condition: BaseAsset = attrs.field(alias="assets", converter=_coerce_assets) + timetable: BaseTimetable + + def __attrs_post_init__(self) -> None: + self.active_runs_limit = self.timetable.active_runs_limit + self.can_be_scheduled = self.timetable.can_be_scheduled diff --git a/task-sdk/src/airflow/sdk/definitions/timetables/events.py b/task-sdk/src/airflow/sdk/definitions/timetables/events.py new file mode 100644 index 0000000000000..382b35516dcf2 --- /dev/null +++ b/task-sdk/src/airflow/sdk/definitions/timetables/events.py @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING + +import attrs + +from airflow.sdk.bases.timetable import BaseTimetable + +if TYPE_CHECKING: + from collections.abc import Iterable + + from pendulum import DateTime + + +@attrs.define(init=False) +class EventsTimetable(BaseTimetable): + """ + Timetable that schedules DAG runs at specific listed datetimes. + + Suitable for predictable but truly irregular scheduling, such as sporting + events, or to schedule against National Holidays. + + :param event_dates: List of datetimes for the DAG to run at. Duplicates + will be ignored.
This must be finite and of reasonable size, as it will + be loaded in its entirety. + :param restrict_to_events: Whether manual runs should use the most recent + event or the current time + :param presorted: if True, event_dates will be assumed to be in ascending + order. Provides modest performance improvement for larger lists of + *event_dates*. + :param description: A name for the timetable to display in the UI. If not + provided explicitly (or *None*) the UI will show "X Events" where X is + the length of *event_dates*. + """ + + event_dates: list[DateTime] + restrict_to_events: bool + description: str | None + + def __init__( + self, + event_dates: Iterable[DateTime], + *, + restrict_to_events: bool = False, + presorted: bool = False, + description: str | None = None, + ) -> None: + self.__attrs_init__( # type: ignore[attr-defined] + sorted(event_dates) if presorted else list(event_dates), + restrict_to_events=restrict_to_events, + description=description, + ) diff --git a/task-sdk/src/airflow/sdk/definitions/timetables/interval.py b/task-sdk/src/airflow/sdk/definitions/timetables/interval.py new file mode 100644 index 0000000000000..a04af6379bdc3 --- /dev/null +++ b/task-sdk/src/airflow/sdk/definitions/timetables/interval.py @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import datetime + +from dateutil.relativedelta import relativedelta + +from airflow.sdk.bases.timetable import BaseTimetable +from airflow.sdk.definitions.timetables._cron import CronMixin +from airflow.sdk.definitions.timetables._delta import DeltaMixin + +Delta = datetime.timedelta | relativedelta + + +class CronDataIntervalTimetable(CronMixin, BaseTimetable): + """ + Timetable that schedules data intervals with a cron expression. + + This corresponds to ``schedule=<cron>``, where ``<cron>`` is either + a five/six-segment representation, or one of ``cron_presets``. + + The implementation extends on croniter to add timezone awareness. This is + because croniter works only with naive timestamps, and cannot consider DST + when determining the next/previous time. + + Using this class is equivalent to supplying a cron expression directly. + + Don't pass ``@once`` in here; use ``OnceTimetable`` instead. + """ + + +class DeltaDataIntervalTimetable(DeltaMixin, BaseTimetable): + """ + Timetable that schedules data intervals with a time delta. + + This corresponds to ``schedule=<delta>``, where ``<delta>`` is + either a ``datetime.timedelta`` or ``dateutil.relativedelta.relativedelta`` + instance.
+ """ diff --git a/task-sdk/src/airflow/sdk/definitions/timetables/simple.py b/task-sdk/src/airflow/sdk/definitions/timetables/simple.py new file mode 100644 index 0000000000000..3387cb3c386fd --- /dev/null +++ b/task-sdk/src/airflow/sdk/definitions/timetables/simple.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from airflow.sdk.bases.timetable import BaseTimetable + + +class NullTimetable(BaseTimetable): + """ + Timetable that never schedules anything. + + This corresponds to ``schedule=None``. + """ + + can_be_scheduled = False + + +class OnceTimetable(BaseTimetable): + """ + Timetable that schedules the execution once as soon as possible. + + This corresponds to ``schedule="@once"``. + """ + + +class ContinuousTimetable(BaseTimetable): + """ + Timetable that schedules continually, while still respecting start_date and end_date. + + This corresponds to ``schedule="@continuous"``. + """ + + active_runs_limit = 1 diff --git a/task-sdk/src/airflow/sdk/definitions/timetables/trigger.py b/task-sdk/src/airflow/sdk/definitions/timetables/trigger.py new file mode 100644 index 0000000000000..66bc3cc245eb0 --- /dev/null +++ b/task-sdk/src/airflow/sdk/definitions/timetables/trigger.py @@ -0,0 +1,116 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import datetime +from typing import TYPE_CHECKING + +import attrs + +from airflow.sdk.bases.timetable import BaseTimetable +from airflow.sdk.definitions.timetables._cron import CronMixin +from airflow.sdk.definitions.timetables._delta import DeltaMixin + +if TYPE_CHECKING: + from dateutil.relativedelta import relativedelta + from pendulum.tz.timezone import FixedTimezone, Timezone + + +@attrs.define +class DeltaTriggerTimetable(DeltaMixin, BaseTimetable): + """ + Timetable that triggers DAG runs according to a cron expression. + + This is different from ``DeltaDataIntervalTimetable``, where the delta value + specifies the *data interval* of a DAG run. 
With this timetable, the data + intervals are specified independently. Also for the same reason, this + timetable kicks off a DAG run immediately at the start of the period, + instead of needing to wait for one data interval to pass. + + :param delta: How much time to wait between each run. + :param interval: The data interval of each run. Default is 0. + """ + + interval: datetime.timedelta | relativedelta = attrs.field(kw_only=True, default=datetime.timedelta()) + + +@attrs.define +class CronTriggerTimetable(CronMixin, BaseTimetable): + """ + Timetable that triggers DAG runs according to a cron expression. + + This is different from ``CronDataIntervalTimetable``, where the cron + expression specifies the *data interval* of a DAG run. With this timetable, + the data intervals are specified independently from the cron expression. + Also for the same reason, this timetable kicks off a DAG run immediately at + the start of the period (similar to POSIX cron), instead of needing to wait + for one data interval to pass. + + Don't pass ``@once`` in here; use ``OnceTimetable`` instead. + + :param cron: cron string that defines when to run + :param timezone: Which timezone to use to interpret the cron string + :param interval: timedelta that defines the data interval start. Default 0. + + *run_immediately* controls, if no *start_time* is given to the DAG, when + the first run of the DAG should be scheduled. It has no effect if there + already exist runs for this DAG. + + * If *True*, always run immediately the most recent possible DAG run. + * If *False*, wait to run until the next scheduled time in the future. + * If passed a ``timedelta``, will run the most recent possible DAG run + if that run's ``data_interval_end`` is within timedelta of now. + * If *None*, the timedelta is calculated as 10% of the time between the + most recent past scheduled time and the next scheduled time. E.g. if + running every hour, this would run the previous time if less than 6 + minutes had passed since the previous run time, otherwise it would wait + until the next hour. + """ + + interval: datetime.timedelta | relativedelta = attrs.field(kw_only=True, default=datetime.timedelta()) + run_immediately: bool | datetime.timedelta = attrs.field(kw_only=True, default=False) + + +@attrs.define(init=False) +class MultipleCronTriggerTimetable(BaseTimetable): + """ + Timetable that triggers DAG runs according to multiple cron expressions. + + This combines multiple ``CronTriggerTimetable`` instances underneath, and + triggers a DAG run whenever one of the timetables wants to trigger a run. + + At most one run is triggered for any given time, even if more than one + timetable fires at the same time.
+ """ + + timetables: list[CronTriggerTimetable] + + def __init__( + self, + *crons: str, + timezone: str | Timezone | FixedTimezone, + interval: datetime.timedelta | relativedelta = datetime.timedelta(), + run_immediately: bool | datetime.timedelta = False, + ) -> None: + if not crons: + raise ValueError("cron expression required") + self.__attrs_init__( # type: ignore[attr-defined] + [ + CronTriggerTimetable(cron, timezone, interval=interval, run_immediately=run_immediately) + for cron in crons + ], + ) diff --git a/task-sdk/src/airflow/sdk/exceptions.py b/task-sdk/src/airflow/sdk/exceptions.py index f17297f2d923a..8fa76b43fdcfd 100644 --- a/task-sdk/src/airflow/sdk/exceptions.py +++ b/task-sdk/src/airflow/sdk/exceptions.py @@ -62,6 +62,10 @@ def __init__(self, error: ErrorResponse): super().__init__(f"{error.error.value}: {error.detail}") +class AirflowTimetableInvalid(AirflowException): + """Raise when a DAG has an invalid timetable.""" + + class ErrorType(enum.Enum): """Error types used in the API client.""" diff --git a/task-sdk/tests/task_sdk/definitions/test_dag.py b/task-sdk/tests/task_sdk/definitions/test_dag.py index b781bcc9e1490..7aa732e456b56 100644 --- a/task-sdk/tests/task_sdk/definitions/test_dag.py +++ b/task-sdk/tests/task_sdk/definitions/test_dag.py @@ -339,7 +339,7 @@ def test_dag_owner_links(self): dag.validate() def test_continuous_schedule_linmits_max_active_runs(self): - from airflow.timetables.simple import ContinuousTimetable + from airflow.sdk.definitions.timetables.simple import ContinuousTimetable dag = DAG("continuous", start_date=DEFAULT_DATE, schedule="@continuous", max_active_runs=1) assert isinstance(dag.timetable, ContinuousTimetable) @@ -471,7 +471,7 @@ def test_create_dag_while_active_context(): @pytest.mark.parametrize("max_active_runs", [0, 1]) def test_continuous_schedule_interval_limits_max_active_runs(max_active_runs): - from airflow.timetables.simple import ContinuousTimetable + from airflow.sdk.definitions.timetables.simple import ContinuousTimetable dag = DAG(dag_id="continuous", schedule="@continuous", max_active_runs=max_active_runs) assert isinstance(dag.timetable, ContinuousTimetable)