diff --git a/airflow-core/src/airflow/decorators/base.py b/airflow-core/src/airflow/decorators/base.py index 00a61dcd2965f..d0c8bdcf6da31 100644 --- a/airflow-core/src/airflow/decorators/base.py +++ b/airflow-core/src/airflow/decorators/base.py @@ -41,9 +41,9 @@ ListOfDictsExpandInput, is_mappable, ) +from airflow.sdk.bases.baseoperator import BaseOperator as TaskSDKBaseOperator from airflow.sdk.definitions._internal.contextmanager import DagContext, TaskGroupContext from airflow.sdk.definitions.asset import Asset -from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator from airflow.sdk.definitions.mappedoperator import MappedOperator, ensure_xcomarg_return_value from airflow.sdk.definitions.xcom_arg import XComArg from airflow.typing_compat import ParamSpec diff --git a/airflow-core/src/airflow/decorators/condition.py b/airflow-core/src/airflow/decorators/condition.py index 06fd01391f28b..a38b2ab24197b 100644 --- a/airflow-core/src/airflow/decorators/condition.py +++ b/airflow-core/src/airflow/decorators/condition.py @@ -25,7 +25,7 @@ if TYPE_CHECKING: from typing_extensions import TypeAlias - from airflow.sdk.definitions.baseoperator import TaskPreExecuteHook + from airflow.sdk.bases.baseoperator import TaskPreExecuteHook from airflow.sdk.definitions.context import Context BoolConditionFunc: TypeAlias = Callable[[Context], bool] diff --git a/airflow-core/src/airflow/models/__init__.py b/airflow-core/src/airflow/models/__init__.py index 20e1c65df8af7..9274ae7a79f3f 100644 --- a/airflow-core/src/airflow/models/__init__.py +++ b/airflow-core/src/airflow/models/__init__.py @@ -26,6 +26,7 @@ "Base", "BaseOperator", "BaseOperatorLink", + "BaseXCom", "Connection", "DagBag", "DagWarning", @@ -44,7 +45,7 @@ "TaskReschedule", "Trigger", "Variable", - "XComModel", + "XCom", "clear_task_instances", ] @@ -65,6 +66,7 @@ def import_all_models(): import airflow.models.serialized_dag import airflow.models.taskinstancehistory import airflow.models.tasklog + import airflow.models.xcom def __getattr__(name): @@ -88,7 +90,8 @@ def __getattr__(name): "ID_LEN": "airflow.models.base", "Base": "airflow.models.base", "BaseOperator": "airflow.models.baseoperator", - "BaseOperatorLink": "airflow.sdk.definitions.baseoperatorlink", + "BaseOperatorLink": "airflow.sdk.bases.operatorlink", + "BaseXCom": "airflow.sdk.bases.xcom", "Connection": "airflow.models.connection", "DagBag": "airflow.models.dagbag", "DagModel": "airflow.models.dag", @@ -114,7 +117,6 @@ def __getattr__(name): if TYPE_CHECKING: # I was unable to get mypy to respect a airflow/models/__init__.pyi, so # having to resort back to this hacky method - from airflow.jobs.job import Job from airflow.models.base import ID_LEN, Base from airflow.models.baseoperator import BaseOperator from airflow.models.connection import Connection @@ -135,6 +137,7 @@ def __getattr__(name): from airflow.models.taskreschedule import TaskReschedule from airflow.models.trigger import Trigger from airflow.models.variable import Variable - from airflow.sdk import BaseOperatorLink + from airflow.sdk.bases.operatorlink import BaseOperatorLink + from airflow.sdk.bases.xcom import BaseXCom from airflow.sdk.definitions.param import Param from airflow.sdk.execution_time.xcom import XCom diff --git a/airflow-core/src/airflow/models/abstractoperator.py b/airflow-core/src/airflow/models/abstractoperator.py index 007e43aa5254c..9dc8f32bdfe39 100644 --- a/airflow-core/src/airflow/models/abstractoperator.py +++ b/airflow-core/src/airflow/models/abstractoperator.py @@ -33,7 +33,7 @@ if TYPE_CHECKING: from sqlalchemy.orm import Session - from airflow.sdk.definitions.baseoperator import BaseOperator + from airflow.sdk.bases.baseoperator import BaseOperator from airflow.task.priority_strategy import PriorityWeightStrategy from airflow.triggers.base import StartTriggerArgs diff --git a/airflow-core/src/airflow/models/baseoperator.py b/airflow-core/src/airflow/models/baseoperator.py index 51cd83a6fecbd..23181f97d58f7 100644 --- a/airflow-core/src/airflow/models/baseoperator.py +++ b/airflow-core/src/airflow/models/baseoperator.py @@ -47,15 +47,15 @@ NotMapped, ) from airflow.models.taskinstance import TaskInstance, clear_task_instances -from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator as TaskSDKAbstractOperator -from airflow.sdk.definitions.baseoperator import ( +from airflow.sdk.bases.baseoperator import ( + BaseOperator as TaskSDKBaseOperator, # Re-export for compat chain as chain, chain_linear as chain_linear, cross_downstream as cross_downstream, get_merged_defaults as get_merged_defaults, ) -from airflow.sdk.definitions.dag import BaseOperator as TaskSDKBaseOperator +from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator as TaskSDKAbstractOperator from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup from airflow.serialization.enums import DagAttributeTypes diff --git a/airflow-core/src/airflow/models/baseoperatorlink.py b/airflow-core/src/airflow/models/baseoperatorlink.py index 3f95e162d8f7b..09d21f868515d 100644 --- a/airflow-core/src/airflow/models/baseoperatorlink.py +++ b/airflow-core/src/airflow/models/baseoperatorlink.py @@ -19,4 +19,4 @@ from __future__ import annotations -from airflow.sdk.definitions.baseoperatorlink import BaseOperatorLink as BaseOperatorLink +from airflow.sdk.bases.operatorlink import BaseOperatorLink as BaseOperatorLink diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py index 0e1bcf4a3bd32..bb1d4016fdb4d 100644 --- a/airflow-core/src/airflow/models/taskinstance.py +++ b/airflow-core/src/airflow/models/taskinstance.py @@ -605,7 +605,7 @@ def _execute_task(task_instance: TaskInstance, context: Context, task_orig: Oper :meta private: """ - from airflow.sdk.definitions.baseoperator import ExecutorSafeguard + from airflow.sdk.bases.baseoperator import ExecutorSafeguard from airflow.sdk.definitions.mappedoperator import MappedOperator task_to_execute = task_instance.task diff --git a/airflow-core/src/airflow/models/taskmap.py b/airflow-core/src/airflow/models/taskmap.py index 04bd6d974e139..12f386c3b45a4 100644 --- a/airflow-core/src/airflow/models/taskmap.py +++ b/airflow-core/src/airflow/models/taskmap.py @@ -133,7 +133,7 @@ def expand_mapped_task(cls, task, run_id: str, *, session: Session) -> tuple[Seq from airflow.models.baseoperator import BaseOperator as DBBaseOperator from airflow.models.expandinput import NotFullyPopulated from airflow.models.taskinstance import TaskInstance - from airflow.sdk.definitions.baseoperator import BaseOperator + from airflow.sdk.bases.baseoperator import BaseOperator from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.settings import task_instance_mutation_hook diff --git a/airflow-core/src/airflow/models/xcom.py b/airflow-core/src/airflow/models/xcom.py index 88549d65eb5f7..09bfcff7b02b8 100644 --- a/airflow-core/src/airflow/models/xcom.py +++ b/airflow-core/src/airflow/models/xcom.py @@ -398,12 +398,16 @@ def _process_row(row: Row) -> Any: def __getattr__(name: str): - if name == "BaseXCom" or name == "XCom": - from airflow.sdk.execution_time import xcom + if name == "BaseXCom": + from airflow.sdk.bases.xcom import BaseXCom - val = getattr(xcom, name) + globals()[name] = BaseXCom + return BaseXCom - globals()[name] = val - return val + if name == "XCom": + from airflow.sdk.execution_time.xcom import XCom + + globals()[name] = XCom + return XCom raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/airflow-core/src/airflow/sensors/base.py b/airflow-core/src/airflow/sensors/base.py index 9ac7c28c2d53f..71ae006f53437 100644 --- a/airflow-core/src/airflow/sensors/base.py +++ b/airflow-core/src/airflow/sensors/base.py @@ -17,7 +17,7 @@ # under the License. from __future__ import annotations -from airflow.sdk.definitions.sensors.base import ( +from airflow.sdk.bases.sensor import ( BaseSensorOperator as BaseSensorOperator, PokeReturnValue as PokeReturnValue, poke_mode_only as poke_mode_only, diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py index d2af3540b8038..8dd1ca10ec737 100644 --- a/airflow-core/src/airflow/serialization/serialized_objects.py +++ b/airflow-core/src/airflow/serialization/serialized_objects.py @@ -51,6 +51,7 @@ from airflow.models.xcom import XComModel from airflow.models.xcom_arg import SchedulerXComArg, deserialize_xcom_arg from airflow.providers_manager import ProvidersManager +from airflow.sdk.bases.baseoperator import BaseOperator as TaskSDKBaseOperator from airflow.sdk.definitions.asset import ( Asset, AssetAlias, @@ -63,7 +64,6 @@ AssetWatcher, BaseAsset, ) -from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sdk.definitions.param import Param, ParamsDict from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup diff --git a/airflow-core/src/airflow/utils/task_group.py b/airflow-core/src/airflow/utils/task_group.py index 0712b3cb3aef5..c264274e7167e 100644 --- a/airflow-core/src/airflow/utils/task_group.py +++ b/airflow-core/src/airflow/utils/task_group.py @@ -32,8 +32,8 @@ def task_group_to_dict(task_item_or_group, parent_group_is_mapped=False): """Create a nested dict representation of this TaskGroup and its children used to construct the Graph.""" + from airflow.sdk.bases.baseoperator import BaseOperator from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator - from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.mappedoperator import MappedOperator if isinstance(task := task_item_or_group, AbstractOperator): diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_xcom.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_xcom.py index 61ff52b6bb272..95590bfbe13da 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_xcom.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_xcom.py @@ -27,7 +27,8 @@ from airflow.models.taskinstance import TaskInstance from airflow.models.xcom import XComModel from airflow.providers.standard.operators.empty import EmptyOperator -from airflow.sdk.execution_time.xcom import BaseXCom, resolve_xcom_backend +from airflow.sdk.bases.xcom import BaseXCom +from airflow.sdk.execution_time.xcom import resolve_xcom_backend from airflow.utils import timezone from airflow.utils.session import provide_session from airflow.utils.types import DagRunType diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py index 70029e56a05c3..ca99aff9d153e 100644 --- a/airflow-core/tests/unit/models/test_taskinstance.py +++ b/airflow-core/tests/unit/models/test_taskinstance.py @@ -77,8 +77,8 @@ from airflow.providers.standard.operators.python import PythonOperator from airflow.providers.standard.sensors.python import PythonSensor from airflow.sdk.api.datamodels._generated import AssetEventResponse, AssetResponse +from airflow.sdk.bases.notifier import BaseNotifier from airflow.sdk.definitions.asset import Asset, AssetAlias -from airflow.sdk.definitions.notifier import BaseNotifier from airflow.sdk.definitions.param import process_params from airflow.sdk.execution_time.comms import ( AssetEventsResult, diff --git a/airflow-core/tests/unit/models/test_xcom.py b/airflow-core/tests/unit/models/test_xcom.py index 35391d6a6d9de..73ea34bf85289 100644 --- a/airflow-core/tests/unit/models/test_xcom.py +++ b/airflow-core/tests/unit/models/test_xcom.py @@ -29,7 +29,8 @@ from airflow.models.taskinstance import TaskInstance from airflow.models.xcom import XComModel from airflow.providers.standard.operators.empty import EmptyOperator -from airflow.sdk.execution_time.xcom import BaseXCom, resolve_xcom_backend +from airflow.sdk.bases.xcom import BaseXCom +from airflow.sdk.execution_time.xcom import resolve_xcom_backend from airflow.settings import json from airflow.utils import timezone from airflow.utils.session import create_session diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py index 04b246d6af477..c435aa436a3dd 100644 --- a/airflow-core/tests/unit/serialization/test_dag_serialization.py +++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py @@ -2015,7 +2015,7 @@ def test_edge_info_serialization(self): @pytest.mark.db_test @pytest.mark.parametrize("mode", ["poke", "reschedule"]) def test_serialize_sensor(self, mode): - from airflow.sdk.definitions.sensors.base import BaseSensorOperator + from airflow.sdk.bases.sensor import BaseSensorOperator class DummySensor(BaseSensorOperator): def poke(self, context: Context): @@ -2032,7 +2032,7 @@ def poke(self, context: Context): @pytest.mark.parametrize("mode", ["poke", "reschedule"]) def test_serialize_mapped_sensor_has_reschedule_dep(self, mode): - from airflow.sdk.definitions.sensors.base import BaseSensorOperator + from airflow.sdk.bases.sensor import BaseSensorOperator from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep class DummySensor(BaseSensorOperator): diff --git a/dev/mypy/plugin/outputs.py b/dev/mypy/plugin/outputs.py index a3ba7351f556d..485a50cca0e1e 100644 --- a/dev/mypy/plugin/outputs.py +++ b/dev/mypy/plugin/outputs.py @@ -25,7 +25,7 @@ OUTPUT_PROPERTIES = { "airflow.models.baseoperator.BaseOperator.output", "airflow.models.mappedoperator.MappedOperator.output", - "airflow.sdk.definitions.baseoperator.BaseOperator.output", + "airflow.sdk.bases.baseoperator.BaseOperator.output", } TASK_CALL_FUNCTIONS = { diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index 3b5aa6571337d..d645a100d042b 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -46,7 +46,7 @@ from airflow.models.taskinstance import TaskInstance from airflow.providers.standard.operators.empty import EmptyOperator from airflow.sdk.api.datamodels._generated import IntermediateTIState, TerminalTIState - from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator + from airflow.sdk.bases.baseoperator import BaseOperator as TaskSDKBaseOperator from airflow.sdk.execution_time.comms import StartupDetails, ToSupervisor from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance from airflow.timetables.base import DataInterval diff --git a/providers/common/compat/src/airflow/providers/common/compat/notifier/__init__.py b/providers/common/compat/src/airflow/providers/common/compat/notifier/__init__.py index 91ce1cdb71d96..e58c7bc7c7329 100644 --- a/providers/common/compat/src/airflow/providers/common/compat/notifier/__init__.py +++ b/providers/common/compat/src/airflow/providers/common/compat/notifier/__init__.py @@ -22,9 +22,9 @@ from airflow.providers.common.compat.version_compat import AIRFLOW_V_3_0_PLUS if TYPE_CHECKING: - from airflow.sdk.definitions.notifier import BaseNotifier + from airflow.sdk.bases.notifier import BaseNotifier elif AIRFLOW_V_3_0_PLUS: - from airflow.sdk.definitions.notifier import BaseNotifier + from airflow.sdk.bases.notifier import BaseNotifier else: from airflow.notifications.basenotifier import BaseNotifier diff --git a/providers/common/io/src/airflow/providers/common/io/xcom/backend.py b/providers/common/io/src/airflow/providers/common/io/xcom/backend.py index c318698af3d4e..79bedc3977eb5 100644 --- a/providers/common/io/src/airflow/providers/common/io/xcom/backend.py +++ b/providers/common/io/src/airflow/providers/common/io/xcom/backend.py @@ -37,7 +37,7 @@ from airflow.sdk.execution_time.comms import XComResult if AIRFLOW_V_3_0_PLUS: - from airflow.sdk.execution_time.xcom import BaseXCom + from airflow.sdk.bases.xcom import BaseXCom else: from airflow.models.xcom import BaseXCom # type: ignore[no-redef] diff --git a/providers/google/src/airflow/providers/google/cloud/operators/dataproc_metastore.py b/providers/google/src/airflow/providers/google/cloud/operators/dataproc_metastore.py index 72d4e7fc160a7..d99ba3d160a4e 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/dataproc_metastore.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/dataproc_metastore.py @@ -31,7 +31,6 @@ from google.cloud.metastore_v1.types.metastore import DatabaseDumpSpec, Restore from airflow.exceptions import AirflowException -from airflow.models import BaseOperator, BaseOperatorLink from airflow.providers.google.cloud.hooks.dataproc_metastore import DataprocMetastoreHook from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator from airflow.providers.google.common.links.storage import StorageLink diff --git a/providers/openlineage/src/airflow/providers/openlineage/utils/selective_enable.py b/providers/openlineage/src/airflow/providers/openlineage/utils/selective_enable.py index 1317e84869c5c..794681b413573 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/utils/selective_enable.py +++ b/providers/openlineage/src/airflow/providers/openlineage/utils/selective_enable.py @@ -38,7 +38,7 @@ T = TypeVar("T", bound="DAG | Operator") if TYPE_CHECKING: - from airflow.sdk.definitions.baseoperator import BaseOperator as SdkBaseOperator + from airflow.sdk.bases.baseoperator import BaseOperator as SdkBaseOperator log = logging.getLogger(__name__) diff --git a/providers/openlineage/tests/unit/openlineage/extractors/test_manager.py b/providers/openlineage/tests/unit/openlineage/extractors/test_manager.py index 0445964fee39d..14255f28127b5 100644 --- a/providers/openlineage/tests/unit/openlineage/extractors/test_manager.py +++ b/providers/openlineage/tests/unit/openlineage/extractors/test_manager.py @@ -76,7 +76,7 @@ def hook_lineage_collector(): if AIRFLOW_V_3_0_PLUS: from airflow.sdk.api.datamodels._generated import BundleInfo, TaskInstance as SDKTaskInstance - from airflow.sdk.definitions.baseoperator import BaseOperator + from airflow.sdk.bases.baseoperator import BaseOperator from airflow.sdk.execution_time import task_runner from airflow.sdk.execution_time.comms import StartupDetails from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance, parse diff --git a/scripts/ci/pre_commit/check_base_operator_partial_arguments.py b/scripts/ci/pre_commit/check_base_operator_partial_arguments.py index 970d5d50aa766..83a901bae99d5 100755 --- a/scripts/ci/pre_commit/check_base_operator_partial_arguments.py +++ b/scripts/ci/pre_commit/check_base_operator_partial_arguments.py @@ -30,7 +30,7 @@ from common_precommit_utils import AIRFLOW_CORE_SOURCES_PATH, AIRFLOW_TASK_SDK_SOURCES_PATH, console BASEOPERATOR_PY = AIRFLOW_CORE_SOURCES_PATH / "airflow" / "models" / "baseoperator.py" -SDK_BASEOPERATOR_PY = AIRFLOW_TASK_SDK_SOURCES_PATH / "airflow" / "sdk" / "definitions" / "baseoperator.py" +SDK_BASEOPERATOR_PY = AIRFLOW_TASK_SDK_SOURCES_PATH / "airflow" / "sdk" / "bases" / "baseoperator.py" SDK_MAPPEDOPERATOR_PY = ( AIRFLOW_TASK_SDK_SOURCES_PATH / "airflow" / "sdk" / "definitions" / "mappedoperator.py" ) diff --git a/task-sdk/src/airflow/sdk/definitions/sensors/__init__.py b/task-sdk/__init__.py similarity index 99% rename from task-sdk/src/airflow/sdk/definitions/sensors/__init__.py rename to task-sdk/__init__.py index 217e5db960782..13a83393a9124 100644 --- a/task-sdk/src/airflow/sdk/definitions/sensors/__init__.py +++ b/task-sdk/__init__.py @@ -1,4 +1,3 @@ -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information diff --git a/task-sdk/src/airflow/sdk/__init__.py b/task-sdk/src/airflow/sdk/__init__.py index 25ef125a8c010..a7c768531b0da 100644 --- a/task-sdk/src/airflow/sdk/__init__.py +++ b/task-sdk/src/airflow/sdk/__init__.py @@ -52,18 +52,18 @@ __version__ = "1.0.0.alpha1" if TYPE_CHECKING: + from airflow.sdk.bases.baseoperator import BaseOperator, chain, chain_linear, cross_downstream + from airflow.sdk.bases.notifier import BaseNotifier + from airflow.sdk.bases.operatorlink import BaseOperatorLink + from airflow.sdk.bases.sensor import BaseSensorOperator from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAll, AssetAny, AssetWatcher from airflow.sdk.definitions.asset.decorators import asset from airflow.sdk.definitions.asset.metadata import Metadata - from airflow.sdk.definitions.baseoperator import BaseOperator, chain, chain_linear, cross_downstream - from airflow.sdk.definitions.baseoperatorlink import BaseOperatorLink from airflow.sdk.definitions.connection import Connection from airflow.sdk.definitions.context import Context, get_current_context, get_parsing_context from airflow.sdk.definitions.dag import DAG, dag from airflow.sdk.definitions.edges import EdgeModifier, Label - from airflow.sdk.definitions.notifier import BaseNotifier from airflow.sdk.definitions.param import Param - from airflow.sdk.definitions.sensors.base import BaseSensorOperator from airflow.sdk.definitions.taskgroup import TaskGroup from airflow.sdk.definitions.template import literal from airflow.sdk.definitions.variable import Variable @@ -76,9 +76,9 @@ "AssetAny": ".definitions.asset", "AssetWatcher": ".definitions.asset", "BaseNotifier": ".definitions.notifier", - "BaseOperator": ".definitions.baseoperator", - "BaseOperatorLink": ".definitions.baseoperatorlink", - "BaseSensorOperator": ".definitions.sensors.base", + "BaseOperator": ".bases.baseoperator", + "BaseOperatorLink": ".bases.operatorlink", + "BaseSensorOperator": ".bases.sensor", "Connection": ".definitions.connection", "Context": ".definitions.context", "DAG": ".definitions.dag", @@ -90,9 +90,9 @@ "Variable": ".definitions.variable", "XComArg": ".definitions.xcom_arg", "asset": ".definitions.asset.decorators", - "chain": ".definitions.baseoperator", - "chain_linear": ".definitions.baseoperator", - "cross_downstream": ".definitions.baseoperator", + "chain": ".bases.baseoperator", + "chain_linear": ".bases.baseoperator", + "cross_downstream": ".bases.baseoperator", "dag": ".definitions.dag", "get_current_context": ".definitions.context", "get_parsing_context": ".definitions.context", diff --git a/task-sdk/src/airflow/sdk/bases/__init__.py b/task-sdk/src/airflow/sdk/bases/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/task-sdk/src/airflow/sdk/bases/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/task-sdk/src/airflow/sdk/definitions/baseoperator.py b/task-sdk/src/airflow/sdk/bases/baseoperator.py similarity index 100% rename from task-sdk/src/airflow/sdk/definitions/baseoperator.py rename to task-sdk/src/airflow/sdk/bases/baseoperator.py diff --git a/task-sdk/src/airflow/sdk/definitions/notifier.py b/task-sdk/src/airflow/sdk/bases/notifier.py similarity index 100% rename from task-sdk/src/airflow/sdk/definitions/notifier.py rename to task-sdk/src/airflow/sdk/bases/notifier.py diff --git a/task-sdk/src/airflow/sdk/definitions/baseoperatorlink.py b/task-sdk/src/airflow/sdk/bases/operatorlink.py similarity index 100% rename from task-sdk/src/airflow/sdk/definitions/baseoperatorlink.py rename to task-sdk/src/airflow/sdk/bases/operatorlink.py diff --git a/task-sdk/src/airflow/sdk/definitions/sensors/base.py b/task-sdk/src/airflow/sdk/bases/sensor.py similarity index 99% rename from task-sdk/src/airflow/sdk/definitions/sensors/base.py rename to task-sdk/src/airflow/sdk/bases/sensor.py index 7e89e2550b20d..e896fb9f538bb 100644 --- a/task-sdk/src/airflow/sdk/definitions/sensors/base.py +++ b/task-sdk/src/airflow/sdk/bases/sensor.py @@ -37,7 +37,7 @@ TaskDeferralTimeout, ) from airflow.executors.executor_loader import ExecutorLoader -from airflow.sdk.definitions.baseoperator import BaseOperator +from airflow.sdk.bases.baseoperator import BaseOperator from airflow.utils import timezone if TYPE_CHECKING: diff --git a/task-sdk/src/airflow/sdk/bases/xcom.py b/task-sdk/src/airflow/sdk/bases/xcom.py new file mode 100644 index 0000000000000..3376355ff0745 --- /dev/null +++ b/task-sdk/src/airflow/sdk/bases/xcom.py @@ -0,0 +1,311 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import Any + +import structlog + +from airflow.sdk.execution_time.comms import DeleteXCom, GetXCom, SetXCom, XComResult + +log = structlog.get_logger(logger_name="task") + + +class BaseXCom: + """BaseXcom is an interface now to interact with XCom backends.""" + + @classmethod + def set( + cls, + key: str, + value: Any, + *, + dag_id: str, + task_id: str, + run_id: str, + map_index: int = -1, + _mapped_length: int | None = None, + ) -> None: + """ + Store an XCom value. + + :param key: Key to store the XCom. + :param value: XCom value to store. + :param dag_id: DAG ID. + :param task_id: Task ID. + :param run_id: DAG run ID for the task. + :param map_index: Optional map index to assign XCom for a mapped task. + The default is ``-1`` (set for a non-mapped task). + """ + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + value = cls.serialize_value( + value=value, + key=key, + task_id=task_id, + dag_id=dag_id, + run_id=run_id, + map_index=map_index, + ) + + SUPERVISOR_COMMS.send_request( + log=log, + msg=SetXCom( + key=key, + value=value, + dag_id=dag_id, + task_id=task_id, + run_id=run_id, + map_index=map_index, + mapped_length=_mapped_length, + ), + ) + + @classmethod + def _set_xcom_in_db( + cls, + key: str, + value: Any, + *, + dag_id: str, + task_id: str, + run_id: str, + map_index: int = -1, + ) -> None: + """ + Store an XCom value directly in the metadata database. + + :param key: Key to store the XCom. + :param value: XCom value to store. + :param dag_id: DAG ID. + :param task_id: Task ID. + :param run_id: DAG run ID for the task. + :param map_index: Optional map index to assign XCom for a mapped task. + The default is ``-1`` (set for a non-mapped task). + """ + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + SUPERVISOR_COMMS.send_request( + log=log, + msg=SetXCom( + key=key, + value=value, + dag_id=dag_id, + task_id=task_id, + run_id=run_id, + map_index=map_index, + ), + ) + + @classmethod + def get_value( + cls, + *, + ti_key: Any, + key: str, + ) -> Any: + """ + Retrieve an XCom value for a task instance. + + This method returns "full" XCom values (i.e. uses ``deserialize_value`` + from the XCom backend). Use :meth:`get_many` if you want the "shortened" + value via ``orm_deserialize_value``. + + If there are no results, *None* is returned. If multiple XCom entries + match the criteria, an arbitrary one is returned. + + :param ti_key: The TaskInstanceKey to look up the XCom for. + :param key: A key for the XCom. If provided, only XCom with matching + keys will be returned. Pass *None* (default) to remove the filter. + """ + return cls.get_one( + key=key, + task_id=ti_key.task_id, + dag_id=ti_key.dag_id, + run_id=ti_key.run_id, + map_index=ti_key.map_index, + ) + + @classmethod + def _get_xcom_db_ref( + cls, + *, + key: str, + dag_id: str, + task_id: str, + run_id: str, + map_index: int | None = None, + ) -> XComResult: + """ + Retrieve an XCom value, optionally meeting certain criteria. + + This method returns "full" XCom values (i.e. uses ``deserialize_value`` + from the XCom backend). Use :meth:`get_many` if you want the "shortened" + value via ``orm_deserialize_value``. + + If there are no results, *None* is returned. If multiple XCom entries + match the criteria, an arbitrary one is returned. + + .. seealso:: ``get_value()`` is a convenience function if you already + have a structured TaskInstance or TaskInstanceKey object available. + + :param run_id: DAG run ID for the task. + :param dag_id: Only pull XCom from this DAG. Pass *None* (default) to + remove the filter. + :param task_id: Only XCom from task with matching ID will be pulled. + Pass *None* (default) to remove the filter. + :param map_index: Only XCom from task with matching ID will be pulled. + Pass *None* (default) to remove the filter. + :param key: A key for the XCom. If provided, only XCom with matching + keys will be returned. Pass *None* (default) to remove the filter. + """ + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + SUPERVISOR_COMMS.send_request( + log=log, + msg=GetXCom( + key=key, + dag_id=dag_id, + task_id=task_id, + run_id=run_id, + map_index=map_index, + ), + ) + + msg = SUPERVISOR_COMMS.get_message() + if not isinstance(msg, XComResult): + raise TypeError(f"Expected XComResult, received: {type(msg)} {msg}") + + return msg + + @classmethod + def get_one( + cls, + *, + key: str, + dag_id: str, + task_id: str, + run_id: str, + map_index: int | None = None, + include_prior_dates: bool = False, + ) -> Any | None: + """ + Retrieve an XCom value, optionally meeting certain criteria. + + This method returns "full" XCom values (i.e. uses ``deserialize_value`` + from the XCom backend). Use :meth:`get_many` if you want the "shortened" + value via ``orm_deserialize_value``. + + If there are no results, *None* is returned. If multiple XCom entries + match the criteria, an arbitrary one is returned. + + .. seealso:: ``get_value()`` is a convenience function if you already + have a structured TaskInstance or TaskInstanceKey object available. + + :param run_id: DAG run ID for the task. + :param dag_id: Only pull XCom from this DAG. Pass *None* (default) to + remove the filter. + :param task_id: Only XCom from task with matching ID will be pulled. + Pass *None* (default) to remove the filter. + :param map_index: Only XCom from task with matching ID will be pulled. + Pass *None* (default) to remove the filter. + :param key: A key for the XCom. If provided, only XCom with matching + keys will be returned. Pass *None* (default) to remove the filter. + :param include_prior_dates: If *False* (default), only XCom from the + specified DAG run is returned. If *True*, the latest matching XCom is + returned regardless of the run it belongs to. + """ + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + SUPERVISOR_COMMS.send_request( + log=log, + msg=GetXCom( + key=key, + dag_id=dag_id, + task_id=task_id, + run_id=run_id, + map_index=map_index, + include_prior_dates=include_prior_dates, + ), + ) + + msg = SUPERVISOR_COMMS.get_message() + if not isinstance(msg, XComResult): + raise TypeError(f"Expected XComResult, received: {type(msg)} {msg}") + + if msg.value is not None: + return cls.deserialize_value(msg) + return None + + @staticmethod + def serialize_value( + value: Any, + *, + key: str | None = None, + task_id: str | None = None, + dag_id: str | None = None, + run_id: str | None = None, + map_index: int | None = None, + ) -> str: + """Serialize XCom value to JSON str.""" + from airflow.serialization.serde import serialize + + # return back the value for BaseXCom, custom backends will implement this + return serialize(value) # type: ignore[return-value] + + @staticmethod + def deserialize_value(result) -> Any: + """Deserialize XCom value from str objects.""" + from airflow.serialization.serde import deserialize + + return deserialize(result.value) + + @classmethod + def purge(cls, xcom: XComResult, *args) -> None: + """Purge an XCom entry from underlying storage implementations.""" + pass + + @classmethod + def delete( + cls, + key: str, + task_id: str, + dag_id: str, + run_id: str, + map_index: int | None = None, + ) -> None: + """Delete an Xcom entry, for custom xcom backends, it gets the path associated with the data on the backend and purges it.""" + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + xcom_result = cls._get_xcom_db_ref( + key=key, + dag_id=dag_id, + task_id=task_id, + run_id=run_id, + map_index=map_index, + ) + cls.purge(xcom_result) # type: ignore[call-arg] + SUPERVISOR_COMMS.send_request( + log=log, + msg=DeleteXCom( + key=key, + dag_id=dag_id, + task_id=task_id, + run_id=run_id, + ), + ) diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py b/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py index 944813bce1c24..b5b68d56330ab 100644 --- a/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py +++ b/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py @@ -41,8 +41,8 @@ if TYPE_CHECKING: import jinja2 - from airflow.sdk.definitions.baseoperator import BaseOperator - from airflow.sdk.definitions.baseoperatorlink import BaseOperatorLink + from airflow.sdk.bases.baseoperator import BaseOperator + from airflow.sdk.bases.operatorlink import BaseOperatorLink from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.dag import DAG from airflow.sdk.definitions.mappedoperator import MappedOperator diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/mixins.py b/task-sdk/src/airflow/sdk/definitions/_internal/mixins.py index 93fd9431cbe38..71b88dc400028 100644 --- a/task-sdk/src/airflow/sdk/definitions/_internal/mixins.py +++ b/task-sdk/src/airflow/sdk/definitions/_internal/mixins.py @@ -22,8 +22,8 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: + from airflow.sdk.bases.baseoperator import BaseOperator from airflow.sdk.definitions._internal.abstractoperator import Operator - from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.edges import EdgeModifier diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/node.py b/task-sdk/src/airflow/sdk/definitions/_internal/node.py index 93140f0a07c96..7fab2f1919b39 100644 --- a/task-sdk/src/airflow/sdk/definitions/_internal/node.py +++ b/task-sdk/src/airflow/sdk/definitions/_internal/node.py @@ -122,7 +122,7 @@ def _set_relatives( edge_modifier: EdgeModifier | None = None, ) -> None: """Set relatives for the task or task list.""" - from airflow.sdk.definitions.baseoperator import BaseOperator + from airflow.sdk.bases.baseoperator import BaseOperator from airflow.sdk.definitions.mappedoperator import MappedOperator if not isinstance(task_or_task_list, Sequence): diff --git a/task-sdk/src/airflow/sdk/definitions/context.py b/task-sdk/src/airflow/sdk/definitions/context.py index ee6a48cd904b9..03fbb6a7763fa 100644 --- a/task-sdk/src/airflow/sdk/definitions/context.py +++ b/task-sdk/src/airflow/sdk/definitions/context.py @@ -25,7 +25,7 @@ from datetime import datetime from airflow.models.operator import Operator - from airflow.sdk.definitions.baseoperator import BaseOperator + from airflow.sdk.bases.baseoperator import BaseOperator from airflow.sdk.definitions.dag import DAG from airflow.sdk.execution_time.context import InletEventsAccessors from airflow.sdk.types import ( diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py index 64a076346a001..dd2f4cef71dc4 100644 --- a/task-sdk/src/airflow/sdk/definitions/dag.py +++ b/task-sdk/src/airflow/sdk/definitions/dag.py @@ -51,10 +51,10 @@ ParamValidationError, TaskNotFound, ) +from airflow.sdk.bases.baseoperator import BaseOperator from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator from airflow.sdk.definitions._internal.types import NOTSET from airflow.sdk.definitions.asset import AssetAll, BaseAsset -from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.param import DagParam, ParamsDict from airflow.timetables.base import Timetable diff --git a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py index 06bf309f6ce15..7416967d42cb0 100644 --- a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py +++ b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py @@ -69,8 +69,8 @@ OperatorExpandKwargsArgument, ) from airflow.models.xcom_arg import XComArg - from airflow.sdk.definitions.baseoperator import BaseOperator - from airflow.sdk.definitions.baseoperatorlink import BaseOperatorLink + from airflow.sdk.bases.baseoperator import BaseOperator + from airflow.sdk.bases.operatorlink import BaseOperatorLink from airflow.sdk.definitions.dag import DAG from airflow.sdk.definitions.param import ParamsDict from airflow.sdk.types import Operator diff --git a/task-sdk/src/airflow/sdk/definitions/taskgroup.py b/task-sdk/src/airflow/sdk/definitions/taskgroup.py index 8152b76363ff8..c03c50cc4cd52 100644 --- a/task-sdk/src/airflow/sdk/definitions/taskgroup.py +++ b/task-sdk/src/airflow/sdk/definitions/taskgroup.py @@ -41,9 +41,9 @@ if TYPE_CHECKING: from airflow.models.expandinput import ExpandInput + from airflow.sdk.bases.baseoperator import BaseOperator from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator from airflow.sdk.definitions._internal.mixins import DependencyMixin - from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.dag import DAG from airflow.sdk.definitions.edges import EdgeModifier from airflow.sdk.types import Operator diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index d080cc7ff3b13..c5c124cf3c81b 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -34,7 +34,7 @@ from airflow.utils.xcom import XCOM_RETURN_KEY if TYPE_CHECKING: - from airflow.sdk.definitions.baseoperator import BaseOperator + from airflow.sdk.bases.baseoperator import BaseOperator from airflow.sdk.definitions.edges import EdgeModifier from airflow.sdk.types import Operator diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index d9b865c18b1fc..5ce223962a240 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -43,7 +43,7 @@ from uuid import UUID from airflow.sdk import Variable - from airflow.sdk.definitions.baseoperator import BaseOperator + from airflow.sdk.bases.baseoperator import BaseOperator from airflow.sdk.definitions.connection import Connection from airflow.sdk.definitions.context import Context from airflow.sdk.execution_time.comms import ( diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index f64582f5cbcb1..4d5e9dcd753a9 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -46,10 +46,10 @@ TerminalTIState, TIRunContext, ) +from airflow.sdk.bases.baseoperator import BaseOperator, ExecutorSafeguard from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef -from airflow.sdk.definitions.baseoperator import BaseOperator, ExecutorSafeguard from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sdk.definitions.param import process_params from airflow.sdk.exceptions import ErrorType diff --git a/task-sdk/src/airflow/sdk/execution_time/xcom.py b/task-sdk/src/airflow/sdk/execution_time/xcom.py index abb964907f196..536d10e22c4bd 100644 --- a/task-sdk/src/airflow/sdk/execution_time/xcom.py +++ b/task-sdk/src/airflow/sdk/execution_time/xcom.py @@ -14,302 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - from __future__ import annotations -from typing import Any - -import structlog - from airflow.configuration import conf -from airflow.sdk.execution_time.comms import DeleteXCom, GetXCom, SetXCom, XComResult - -log = structlog.get_logger(logger_name="task") - - -class BaseXCom: - """BaseXcom is an interface now to interact with XCom backends.""" - - @classmethod - def set( - cls, - key: str, - value: Any, - *, - dag_id: str, - task_id: str, - run_id: str, - map_index: int = -1, - _mapped_length: int | None = None, - ) -> None: - """ - Store an XCom value. - - :param key: Key to store the XCom. - :param value: XCom value to store. - :param dag_id: DAG ID. - :param task_id: Task ID. - :param run_id: DAG run ID for the task. - :param map_index: Optional map index to assign XCom for a mapped task. - The default is ``-1`` (set for a non-mapped task). - """ - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - - value = cls.serialize_value( - value=value, - key=key, - task_id=task_id, - dag_id=dag_id, - run_id=run_id, - map_index=map_index, - ) - - SUPERVISOR_COMMS.send_request( - log=log, - msg=SetXCom( - key=key, - value=value, - dag_id=dag_id, - task_id=task_id, - run_id=run_id, - map_index=map_index, - mapped_length=_mapped_length, - ), - ) - - @classmethod - def _set_xcom_in_db( - cls, - key: str, - value: Any, - *, - dag_id: str, - task_id: str, - run_id: str, - map_index: int = -1, - ) -> None: - """ - Store an XCom value directly in the metadata database. - - :param key: Key to store the XCom. - :param value: XCom value to store. - :param dag_id: DAG ID. - :param task_id: Task ID. - :param run_id: DAG run ID for the task. - :param map_index: Optional map index to assign XCom for a mapped task. - The default is ``-1`` (set for a non-mapped task). - """ - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - - SUPERVISOR_COMMS.send_request( - log=log, - msg=SetXCom( - key=key, - value=value, - dag_id=dag_id, - task_id=task_id, - run_id=run_id, - map_index=map_index, - ), - ) - - @classmethod - def get_value( - cls, - *, - ti_key: Any, - key: str, - ) -> Any: - """ - Retrieve an XCom value for a task instance. - - This method returns "full" XCom values (i.e. uses ``deserialize_value`` - from the XCom backend). Use :meth:`get_many` if you want the "shortened" - value via ``orm_deserialize_value``. - - If there are no results, *None* is returned. If multiple XCom entries - match the criteria, an arbitrary one is returned. - - :param ti_key: The TaskInstanceKey to look up the XCom for. - :param key: A key for the XCom. If provided, only XCom with matching - keys will be returned. Pass *None* (default) to remove the filter. - """ - return cls.get_one( - key=key, - task_id=ti_key.task_id, - dag_id=ti_key.dag_id, - run_id=ti_key.run_id, - map_index=ti_key.map_index, - ) - - @classmethod - def _get_xcom_db_ref( - cls, - *, - key: str, - dag_id: str, - task_id: str, - run_id: str, - map_index: int | None = None, - ) -> XComResult: - """ - Retrieve an XCom value, optionally meeting certain criteria. - - This method returns "full" XCom values (i.e. uses ``deserialize_value`` - from the XCom backend). Use :meth:`get_many` if you want the "shortened" - value via ``orm_deserialize_value``. - - If there are no results, *None* is returned. If multiple XCom entries - match the criteria, an arbitrary one is returned. - - .. seealso:: ``get_value()`` is a convenience function if you already - have a structured TaskInstance or TaskInstanceKey object available. - - :param run_id: DAG run ID for the task. - :param dag_id: Only pull XCom from this DAG. Pass *None* (default) to - remove the filter. - :param task_id: Only XCom from task with matching ID will be pulled. - Pass *None* (default) to remove the filter. - :param map_index: Only XCom from task with matching ID will be pulled. - Pass *None* (default) to remove the filter. - :param key: A key for the XCom. If provided, only XCom with matching - keys will be returned. Pass *None* (default) to remove the filter. - """ - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetXCom( - key=key, - dag_id=dag_id, - task_id=task_id, - run_id=run_id, - map_index=map_index, - ), - ) - - msg = SUPERVISOR_COMMS.get_message() - if not isinstance(msg, XComResult): - raise TypeError(f"Expected XComResult, received: {type(msg)} {msg}") - - return msg - - @classmethod - def get_one( - cls, - *, - key: str, - dag_id: str, - task_id: str, - run_id: str, - map_index: int | None = None, - include_prior_dates: bool = False, - ) -> Any | None: - """ - Retrieve an XCom value, optionally meeting certain criteria. - - This method returns "full" XCom values (i.e. uses ``deserialize_value`` - from the XCom backend). Use :meth:`get_many` if you want the "shortened" - value via ``orm_deserialize_value``. - - If there are no results, *None* is returned. If multiple XCom entries - match the criteria, an arbitrary one is returned. - - .. seealso:: ``get_value()`` is a convenience function if you already - have a structured TaskInstance or TaskInstanceKey object available. - - :param run_id: DAG run ID for the task. - :param dag_id: Only pull XCom from this DAG. Pass *None* (default) to - remove the filter. - :param task_id: Only XCom from task with matching ID will be pulled. - Pass *None* (default) to remove the filter. - :param map_index: Only XCom from task with matching ID will be pulled. - Pass *None* (default) to remove the filter. - :param key: A key for the XCom. If provided, only XCom with matching - keys will be returned. Pass *None* (default) to remove the filter. - :param include_prior_dates: If *False* (default), only XCom from the - specified DAG run is returned. If *True*, the latest matching XCom is - returned regardless of the run it belongs to. - """ - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetXCom( - key=key, - dag_id=dag_id, - task_id=task_id, - run_id=run_id, - map_index=map_index, - include_prior_dates=include_prior_dates, - ), - ) - - msg = SUPERVISOR_COMMS.get_message() - if not isinstance(msg, XComResult): - raise TypeError(f"Expected XComResult, received: {type(msg)} {msg}") - - if msg.value is not None: - return cls.deserialize_value(msg) - return None - - @staticmethod - def serialize_value( - value: Any, - *, - key: str | None = None, - task_id: str | None = None, - dag_id: str | None = None, - run_id: str | None = None, - map_index: int | None = None, - ) -> str: - """Serialize XCom value to JSON str.""" - from airflow.serialization.serde import serialize - - # return back the value for BaseXCom, custom backends will implement this - return serialize(value) # type: ignore[return-value] - - @staticmethod - def deserialize_value(result) -> Any: - """Deserialize XCom value from str objects.""" - from airflow.serialization.serde import deserialize - - return deserialize(result.value) - - @classmethod - def purge(cls, xcom: XComResult, *args) -> None: - """Purge an XCom entry from underlying storage implementations.""" - pass - - @classmethod - def delete( - cls, - key: str, - task_id: str, - dag_id: str, - run_id: str, - map_index: int | None = None, - ) -> None: - """Delete an Xcom entry, for custom xcom backends, it gets the path associated with the data on the backend and purges it.""" - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - - xcom_result = cls._get_xcom_db_ref( - key=key, - dag_id=dag_id, - task_id=task_id, - run_id=run_id, - map_index=map_index, - ) - cls.purge(xcom_result) # type: ignore[call-arg] - SUPERVISOR_COMMS.send_request( - log=log, - msg=DeleteXCom( - key=key, - dag_id=dag_id, - task_id=task_id, - run_id=run_id, - ), - ) +from airflow.sdk.bases.xcom import BaseXCom def resolve_xcom_backend(): @@ -318,7 +26,7 @@ def resolve_xcom_backend(): :returns: returns the custom XCom class if configured. """ - clazz = conf.getimport("core", "xcom_backend", fallback="airflow.sdk.execution_time.xcom.BaseXCom") + clazz = conf.getimport("core", "xcom_backend", fallback="airflow.sdk.bases.xcom.BaseXCom") if not clazz: return BaseXCom if not issubclass(clazz, BaseXCom): diff --git a/task-sdk/src/airflow/sdk/types.py b/task-sdk/src/airflow/sdk/types.py index 6c3c40f2ab93d..e7e2036aa5c2b 100644 --- a/task-sdk/src/airflow/sdk/types.py +++ b/task-sdk/src/airflow/sdk/types.py @@ -29,8 +29,8 @@ from collections.abc import Iterator from datetime import datetime + from airflow.sdk.bases.baseoperator import BaseOperator from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, AssetRef, BaseAssetUniqueKey - from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.mappedoperator import MappedOperator diff --git a/task-sdk/tests/task_sdk/bases/__init__.py b/task-sdk/tests/task_sdk/bases/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/task-sdk/tests/task_sdk/bases/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/task-sdk/tests/task_sdk/definitions/notifier/test_notifier.txt b/task-sdk/tests/task_sdk/bases/notifier/test_notifier.txt similarity index 100% rename from task-sdk/tests/task_sdk/definitions/notifier/test_notifier.txt rename to task-sdk/tests/task_sdk/bases/notifier/test_notifier.txt diff --git a/task-sdk/tests/task_sdk/definitions/test_baseoperator.py b/task-sdk/tests/task_sdk/bases/test_baseoperator.py similarity index 99% rename from task-sdk/tests/task_sdk/definitions/test_baseoperator.py rename to task-sdk/tests/task_sdk/bases/test_baseoperator.py index f87c82e6d4e83..969030a22ca86 100644 --- a/task-sdk/tests/task_sdk/definitions/test_baseoperator.py +++ b/task-sdk/tests/task_sdk/bases/test_baseoperator.py @@ -29,7 +29,7 @@ import structlog from airflow.decorators import task as task_decorator -from airflow.sdk.definitions.baseoperator import ( +from airflow.sdk.bases.baseoperator import ( BaseOperator, BaseOperatorMeta, ExecutorSafeguard, diff --git a/task-sdk/tests/task_sdk/definitions/test_notifier.py b/task-sdk/tests/task_sdk/bases/test_notifier.py similarity index 98% rename from task-sdk/tests/task_sdk/definitions/test_notifier.py rename to task-sdk/tests/task_sdk/bases/test_notifier.py index cc8b5f9659b19..b8cedaa518831 100644 --- a/task-sdk/tests/task_sdk/definitions/test_notifier.py +++ b/task-sdk/tests/task_sdk/bases/test_notifier.py @@ -24,8 +24,8 @@ import pytest from airflow.providers.standard.operators.empty import EmptyOperator +from airflow.sdk.bases.notifier import BaseNotifier from airflow.sdk.definitions.dag import DAG -from airflow.sdk.definitions.notifier import BaseNotifier if TYPE_CHECKING: from airflow.sdk.definitions.context import Context diff --git a/task-sdk/tests/task_sdk/definitions/sensors/test_base.py b/task-sdk/tests/task_sdk/bases/test_sensor.py similarity index 99% rename from task-sdk/tests/task_sdk/definitions/sensors/test_base.py rename to task-sdk/tests/task_sdk/bases/test_sensor.py index ebd815d44bf0a..2c0a82783d4d6 100644 --- a/task-sdk/tests/task_sdk/definitions/sensors/test_base.py +++ b/task-sdk/tests/task_sdk/bases/test_sensor.py @@ -34,8 +34,8 @@ ) from airflow.models.trigger import TriggerFailureReason from airflow.providers.standard.operators.empty import EmptyOperator +from airflow.sdk.bases.sensor import BaseSensorOperator, PokeReturnValue, poke_mode_only from airflow.sdk.definitions.dag import DAG -from airflow.sdk.definitions.sensors.base import BaseSensorOperator, PokeReturnValue, poke_mode_only from airflow.sdk.execution_time.comms import RescheduleTask, TaskRescheduleStartDate from airflow.utils import timezone from airflow.utils.state import State diff --git a/task-sdk/tests/task_sdk/dags/super_basic.py b/task-sdk/tests/task_sdk/dags/super_basic.py index b5a50785bcebf..2cccb9ab4c647 100644 --- a/task-sdk/tests/task_sdk/dags/super_basic.py +++ b/task-sdk/tests/task_sdk/dags/super_basic.py @@ -17,7 +17,7 @@ from __future__ import annotations -from airflow.sdk.definitions.baseoperator import BaseOperator +from airflow.sdk.bases.baseoperator import BaseOperator from airflow.sdk.definitions.dag import dag diff --git a/task-sdk/tests/task_sdk/dags/super_basic_run.py b/task-sdk/tests/task_sdk/dags/super_basic_run.py index 87d2a6820226b..e178e7acc8935 100644 --- a/task-sdk/tests/task_sdk/dags/super_basic_run.py +++ b/task-sdk/tests/task_sdk/dags/super_basic_run.py @@ -17,7 +17,7 @@ from __future__ import annotations -from airflow.sdk.definitions.baseoperator import BaseOperator +from airflow.sdk.bases.baseoperator import BaseOperator from airflow.sdk.definitions.dag import dag diff --git a/task-sdk/tests/task_sdk/definitions/test_dag.py b/task-sdk/tests/task_sdk/definitions/test_dag.py index 71d888d3af10e..07a99221d496a 100644 --- a/task-sdk/tests/task_sdk/definitions/test_dag.py +++ b/task-sdk/tests/task_sdk/definitions/test_dag.py @@ -23,7 +23,7 @@ import pytest from airflow.exceptions import DuplicateTaskIdFound -from airflow.sdk.definitions.baseoperator import BaseOperator +from airflow.sdk.bases.baseoperator import BaseOperator from airflow.sdk.definitions.dag import DAG, dag as dag_decorator from airflow.sdk.definitions.param import DagParam, Param, ParamsDict diff --git a/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py b/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py index a75cfacf6cb46..3c780551c6aea 100644 --- a/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py +++ b/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py @@ -25,7 +25,7 @@ import pytest from airflow.sdk.api.datamodels._generated import TerminalTIState -from airflow.sdk.definitions.baseoperator import BaseOperator +from airflow.sdk.bases.baseoperator import BaseOperator from airflow.sdk.definitions.dag import DAG from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sdk.definitions.xcom_arg import XComArg diff --git a/task-sdk/tests/task_sdk/definitions/test_mixins.py b/task-sdk/tests/task_sdk/definitions/test_mixins.py index 83b4d6eabefdf..b7ffa974758ee 100644 --- a/task-sdk/tests/task_sdk/definitions/test_mixins.py +++ b/task-sdk/tests/task_sdk/definitions/test_mixins.py @@ -22,7 +22,7 @@ import pytest from airflow.decorators import setup, task, teardown -from airflow.sdk.definitions.baseoperator import BaseOperator +from airflow.sdk.bases.baseoperator import BaseOperator from airflow.sdk.definitions.dag import DAG