diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml index 3611dcdee3dd6..62da7d0a27771 100644 --- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml @@ -4641,7 +4641,7 @@ paths: required: false schema: type: boolean - default: true + default: false title: Stringify responses: '200': diff --git a/airflow/api_fastapi/core_api/routes/public/xcom.py b/airflow/api_fastapi/core_api/routes/public/xcom.py index ecabb80ea36dd..4b1361d402738 100644 --- a/airflow/api_fastapi/core_api/routes/public/xcom.py +++ b/airflow/api_fastapi/core_api/routes/public/xcom.py @@ -37,7 +37,8 @@ from airflow.api_fastapi.core_api.security import ReadableXComFilterDep, requires_access_dag from airflow.api_fastapi.logging.decorators import action_logging from airflow.exceptions import TaskNotFound -from airflow.models import DAG, DagRun as DR, XCom +from airflow.models import DAG, DagRun as DR +from airflow.models.xcom import XComModel from airflow.settings import conf xcom_router = AirflowRouter( @@ -63,7 +64,7 @@ def get_xcom_entry( session: SessionDep, map_index: Annotated[int, Query(ge=-1)] = -1, deserialize: Annotated[bool, Query()] = False, - stringify: Annotated[bool, Query()] = True, + stringify: Annotated[bool, Query()] = False, ) -> XComResponseNative | XComResponseString: """Get an XCom entry.""" if deserialize: @@ -71,14 +72,17 @@ def get_xcom_entry( raise HTTPException( status.HTTP_400_BAD_REQUEST, "XCom deserialization is disabled in configuration." ) - query = select(XCom, XCom.value) + query = select(XComModel, XComModel.value) else: - query = select(XCom) + query = select(XComModel) query = query.where( - XCom.dag_id == dag_id, XCom.task_id == task_id, XCom.key == xcom_key, XCom.map_index == map_index + XComModel.dag_id == dag_id, + XComModel.task_id == task_id, + XComModel.key == xcom_key, + XComModel.map_index == map_index, ) - query = query.join(DR, and_(XCom.dag_id == DR.dag_id, XCom.run_id == DR.run_id)) + query = query.join(DR, and_(XComModel.dag_id == DR.dag_id, XComModel.run_id == DR.run_id)) query = query.where(DR.run_id == dag_run_id) if deserialize: @@ -90,6 +94,8 @@ def get_xcom_entry( raise HTTPException(status.HTTP_404_NOT_FOUND, f"XCom entry with key: `{xcom_key}` not found") if deserialize: + from airflow.sdk.execution_time.xcom import XCom + xcom, value = item xcom_stub = copy.copy(xcom) xcom_stub.value = value @@ -127,19 +133,19 @@ def get_xcom_entries( This endpoint allows specifying `~` as the dag_id, dag_run_id, task_id to retrieve XCom entries for all DAGs. """ - query = select(XCom) + query = select(XComModel) if dag_id != "~": - query = query.where(XCom.dag_id == dag_id) - query = query.join(DR, and_(XCom.dag_id == DR.dag_id, XCom.run_id == DR.run_id)) + query = query.where(XComModel.dag_id == dag_id) + query = query.join(DR, and_(XComModel.dag_id == DR.dag_id, XComModel.run_id == DR.run_id)) if task_id != "~": - query = query.where(XCom.task_id == task_id) + query = query.where(XComModel.task_id == task_id) if dag_run_id != "~": query = query.where(DR.run_id == dag_run_id) if map_index is not None: - query = query.where(XCom.map_index == map_index) + query = query.where(XComModel.map_index == map_index) if xcom_key is not None: - query = query.where(XCom.key == xcom_key) + query = query.where(XComModel.key == xcom_key) query, total_entries = paginated_select( statement=query, @@ -148,7 +154,9 @@ def get_xcom_entries( limit=limit, session=session, ) - query = query.order_by(XCom.dag_id, XCom.task_id, XCom.run_id, XCom.map_index, XCom.key) + query = query.order_by( + XComModel.dag_id, XComModel.task_id, XComModel.run_id, XComModel.map_index, XComModel.key + ) xcoms = session.scalars(query) return XComCollectionResponse(xcom_entries=xcoms, total_entries=total_entries) @@ -197,38 +205,48 @@ def create_xcom_entry( ) # Check existing XCom - if XCom.get_one( + already_existing_query = XComModel.get_many( key=request_body.key, - task_id=task_id, - dag_id=dag_id, + task_ids=task_id, + dag_ids=dag_id, run_id=dag_run_id, - map_index=request_body.map_index, + map_indexes=request_body.map_index, session=session, - ): + ) + result = already_existing_query.with_entities(XComModel.value).first() + if result: raise HTTPException( status_code=status.HTTP_409_CONFLICT, detail=f"The XCom with key: `{request_body.key}` with mentioned task instance already exists.", ) - # Create XCom entry - XCom.set( - dag_id=dag_id, - task_id=task_id, - run_id=dag_run_id, + try: + value = XComModel.serialize_value(request_body.value) + except (ValueError, TypeError): + raise HTTPException( + status.HTTP_400_BAD_REQUEST, f"Couldn't serialise the XCom with key: `{request_body.key}`" + ) + + new = XComModel( + dag_run_id=dag_run.id, key=request_body.key, - value=XCom.serialize_value(request_body.value), + value=value, + run_id=dag_run_id, + task_id=task_id, + dag_id=dag_id, map_index=request_body.map_index, - session=session, ) + session.add(new) + session.flush() xcom = session.scalar( - select(XCom) + select(XComModel) .filter( - XCom.dag_id == dag_id, - XCom.task_id == task_id, - XCom.run_id == dag_run_id, - XCom.key == request_body.key, - XCom.map_index == request_body.map_index, + XComModel.dag_id == dag_id, + XComModel.task_id == task_id, + XComModel.run_id == dag_run_id, + XComModel.key == request_body.key, + XComModel.map_index == request_body.map_index, ) .limit(1) ) @@ -260,15 +278,15 @@ def update_xcom_entry( ) -> XComResponseNative: """Update an existing XCom entry.""" # Check if XCom entry exists - xcom_new_value = XCom.serialize_value(patch_body.value) + xcom_new_value = XComModel.serialize_value(patch_body.value) xcom_entry = session.scalar( - select(XCom) + select(XComModel) .where( - XCom.dag_id == dag_id, - XCom.task_id == task_id, - XCom.run_id == dag_run_id, - XCom.key == xcom_key, - XCom.map_index == patch_body.map_index, + XComModel.dag_id == dag_id, + XComModel.task_id == task_id, + XComModel.run_id == dag_run_id, + XComModel.key == xcom_key, + XComModel.map_index == patch_body.map_index, ) .limit(1) ) @@ -280,6 +298,6 @@ def update_xcom_entry( ) # Update XCom entry - xcom_entry.value = XCom.serialize_value(xcom_new_value) + xcom_entry.value = XComModel.serialize_value(xcom_new_value) return XComResponseNative.model_validate(xcom_entry) diff --git a/airflow/api_fastapi/core_api/security.py b/airflow/api_fastapi/core_api/security.py index d84bf3410bb54..b10f7ac8549fc 100644 --- a/airflow/api_fastapi/core_api/security.py +++ b/airflow/api_fastapi/core_api/security.py @@ -43,7 +43,7 @@ from airflow.models.dag import DagModel, DagRun, DagTag from airflow.models.dagwarning import DagWarning from airflow.models.taskinstance import TaskInstance as TI -from airflow.models.xcom import XCom +from airflow.models.xcom import XComModel if TYPE_CHECKING: from sqlalchemy.sql import Select @@ -132,7 +132,7 @@ class PermittedXComFilter(PermittedDagFilter): """A parameter that filters the permitted XComs for the user.""" def to_orm(self, select: Select) -> Select: - return select.where(XCom.dag_id.in_(self.value)) + return select.where(XComModel.dag_id.in_(self.value)) class PermittedTagFilter(PermittedDagFilter): diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py index 32d34a07c2332..6e2157e3c7b2a 100644 --- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py +++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py @@ -273,6 +273,9 @@ class TIRunContext(BaseModel): Can either be a "decorated" dict, or a string encrypted with the shared Fernet key. """ + xcom_keys_to_clear: Annotated[list[str], Field(default_factory=list)] + """List of Xcom keys that need to be cleared and purged on by the worker.""" + class PrevSuccessfulDagRunResponse(BaseModel): """Schema for response with previous successful DagRun information for Task Template Context.""" diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow/api_fastapi/execution_api/routes/task_instances.py index c666cc7ff98aa..07caee01797db 100644 --- a/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -48,7 +48,7 @@ from airflow.models.taskinstance import TaskInstance as TI, _update_rtif from airflow.models.taskreschedule import TaskReschedule from airflow.models.trigger import Trigger -from airflow.models.xcom import XCom +from airflow.models.xcom import XComModel from airflow.utils import timezone from airflow.utils.state import DagRunState, TaskInstanceState, TerminalTIState @@ -187,18 +187,22 @@ def ti_run( if not dr: raise ValueError(f"DagRun with dag_id={ti.dag_id} and run_id={ti.run_id} not found.") - # Clear XCom data for the task instance since we are certain it is executing + # Send the keys to the SDK so that the client requests to clear those XComs from the server. + # The reason we cannot do this here in the server is because we need to issue a purge on custom XCom backends + # too. With the current assumption, the workers ONLY have access to the custom XCom backends directly and they + # can issue the purge. + # However, do not clear it for deferral + xcom_keys = [] if not ti.next_method: map_index = None if ti.map_index < 0 else ti.map_index - log.info("Clearing xcom data for task id: %s", ti_id_str) - XCom.clear( - dag_id=ti.dag_id, - task_id=ti.task_id, - run_id=ti.run_id, - map_index=map_index, - session=session, + query = session.query(XComModel.key).filter_by( + dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id ) + if map_index is not None: + query = query.filter_by(map_index=map_index) + + xcom_keys = [row.key for row in session.execute(query).all()] task_reschedule_count = ( session.query( @@ -216,6 +220,7 @@ def ti_run( # TODO: Add variables and connections that are needed (and has perms) for the task variables=[], connections=[], + xcom_keys_to_clear=xcom_keys, ) # Only set if they are non-null diff --git a/airflow/api_fastapi/execution_api/routes/xcoms.py b/airflow/api_fastapi/execution_api/routes/xcoms.py index 3b023836ea864..808f6deda2bdd 100644 --- a/airflow/api_fastapi/execution_api/routes/xcoms.py +++ b/airflow/api_fastapi/execution_api/routes/xcoms.py @@ -22,6 +22,7 @@ from fastapi import Body, Depends, HTTPException, Query, Response, status from pydantic import JsonValue +from sqlalchemy import delete from sqlalchemy.sql.selectable import Select from airflow.api_fastapi.common.db.common import SessionDep @@ -30,7 +31,7 @@ from airflow.api_fastapi.execution_api.datamodels.token import TIToken from airflow.api_fastapi.execution_api.datamodels.xcom import XComResponse from airflow.models.taskmap import TaskMap -from airflow.models.xcom import BaseXCom +from airflow.models.xcom import XComModel from airflow.utils.db import get_query_count # TODO: Add dependency on JWT token @@ -62,7 +63,7 @@ async def xcom_query( }, ) - query = BaseXCom.get_many( + query = XComModel.get_many( run_id=run_id, key=key, task_ids=task_id, @@ -126,7 +127,7 @@ def get_xcom( """Get an Airflow XCom from database - not other XCom Backends.""" # The xcom_query allows no map_index to be passed. This endpoint should always return just a single item, # so we override that query value - xcom_query = xcom_query.filter(BaseXCom.map_index == map_index) + xcom_query = xcom_query.filter(XComModel.map_index == map_index) # We use `BaseXCom.get_many` to fetch XComs directly from the database, bypassing the XCom Backend. # This avoids deserialization via the backend (e.g., from a remote storage like S3) and instead # retrieves the raw serialized value from the database. By not relying on `XCom.get_many` or `XCom.get_one` @@ -222,18 +223,39 @@ def set_xcom( # TODO: Can/should we check if a client _hasn't_ provided this for an upstream of a mapped task? That # means loading the serialized dag and that seems like a relatively costly operation for minimal benefit # (the mapped task would fail in a moment as it can't be expanded anyway.) + from airflow.models.dagrun import DagRun + + if not run_id: + raise HTTPException(status.HTTP_404_NOT_FOUND, f"Run with ID: `{run_id}` was not found") + + dag_run_id = session.query(DagRun.id).filter_by(dag_id=dag_id, run_id=run_id).scalar() + if dag_run_id is None: + raise HTTPException(status.HTTP_404_NOT_FOUND, f"DAG run not found on DAG {dag_id} with ID {run_id}") + + # Remove duplicate XComs and insert a new one. + session.execute( + delete(XComModel).where( + XComModel.key == key, + XComModel.run_id == run_id, + XComModel.task_id == task_id, + XComModel.dag_id == dag_id, + XComModel.map_index == map_index, + ) + ) - # We use `BaseXCom.set` to set XComs directly to the database, bypassing the XCom Backend. try: - BaseXCom.set( + # We expect serialised value from the caller - sdk, do not serialise in here + new = XComModel( + dag_run_id=dag_run_id, key=key, value=value, - dag_id=dag_id, - task_id=task_id, run_id=run_id, - session=session, + task_id=task_id, + dag_id=dag_id, map_index=map_index, ) + session.add(new) + session.flush() except TypeError as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -246,6 +268,34 @@ def set_xcom( return {"message": "XCom successfully set"} +@router.delete( + "/{dag_id}/{run_id}/{task_id}/{key}", + responses={status.HTTP_404_NOT_FOUND: {"description": "XCom not found"}}, + description="Delete a single XCom Value", +) +def delete_xcom( + session: SessionDep, + token: deps.TokenDep, + dag_id: str, + run_id: str, + task_id: str, + key: str, +): + if not has_xcom_access(dag_id, run_id, task_id, key, token, write=True): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail={ + "reason": "access_denied", + "message": f"Task does not have access to delete XCom with key '{key}'", + }, + ) + + query = session.query(XComModel).where(XComModel.key == key).first() + session.delete(query) + session.commit() + return {"message": f"XCom with key: {key} successfully deleted."} + + def has_xcom_access( dag_id: str, run_id: str, task_id: str, xcom_key: str, token: TIToken, write: bool = False ) -> bool: diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index 122d81f9dc788..f6804ecaf7ac8 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -401,7 +401,7 @@ core: version_added: 1.10.12 type: string example: "path.to.CustomXCom" - default: "airflow.models.xcom.BaseXCom" + default: "airflow.sdk.execution_time.xcom.BaseXCom" lazy_load_plugins: description: | By default Airflow plugins are lazily-loaded (only loaded when required). Set it to ``False``, diff --git a/airflow/models/__init__.py b/airflow/models/__init__.py index 89f93e1f2a9d5..20e1c65df8af7 100644 --- a/airflow/models/__init__.py +++ b/airflow/models/__init__.py @@ -44,7 +44,7 @@ "TaskReschedule", "Trigger", "Variable", - "XCom", + "XComModel", "clear_task_instances", ] @@ -107,7 +107,7 @@ def __getattr__(name): "TaskReschedule": "airflow.models.taskreschedule", "Trigger": "airflow.models.trigger", "Variable": "airflow.models.variable", - "XCom": "airflow.models.xcom", + "XCom": "airflow.sdk.execution_time.xcom", "clear_task_instances": "airflow.models.taskinstance", } @@ -135,6 +135,6 @@ def __getattr__(name): from airflow.models.taskreschedule import TaskReschedule from airflow.models.trigger import Trigger from airflow.models.variable import Variable - from airflow.models.xcom import XCom from airflow.sdk import BaseOperatorLink from airflow.sdk.definitions.param import Param + from airflow.sdk.execution_time.xcom import XCom diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index ec87342cdb7d8..259184a2febd8 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -102,7 +102,7 @@ from airflow.models.taskinstancekey import TaskInstanceKey from airflow.models.taskmap import TaskMap from airflow.models.taskreschedule import TaskReschedule -from airflow.models.xcom import LazyXComSelectSequence, XCom +from airflow.models.xcom import LazyXComSelectSequence, XComModel from airflow.plugins_manager import integrate_macros_plugins from airflow.sentry import Sentry from airflow.settings import task_instance_mutation_hook @@ -528,7 +528,7 @@ def _xcom_pull( if run_id is None: run_id = ti.run_id - query = XCom.get_many( + query = XComModel.get_many( key=key, run_id=run_id, dag_ids=dag_id, @@ -545,12 +545,12 @@ def _xcom_pull( # We are only pulling one single task. if (task_ids is None or isinstance(task_ids, str)) and not isinstance(map_indexes, Iterable): first = query.with_entities( - XCom.run_id, XCom.task_id, XCom.dag_id, XCom.map_index, XCom.value + XComModel.run_id, XComModel.task_id, XComModel.dag_id, XComModel.map_index, XComModel.value ).first() if first is None: # No matching XCom at all. return default if map_indexes is not None or first.map_index < 0: - return XCom.deserialize_value(first) + return XComModel.deserialize_value(first) # raise RuntimeError("Nothing should hit this anymore") @@ -560,24 +560,24 @@ def _xcom_pull( # Order return values to match task_ids and map_indexes ordering. ordering = [] if task_ids is None or isinstance(task_ids, str): - ordering.append(XCom.task_id) + ordering.append(XComModel.task_id) elif task_id_whens := {tid: i for i, tid in enumerate(task_ids)}: - ordering.append(case(task_id_whens, value=XCom.task_id)) + ordering.append(case(task_id_whens, value=XComModel.task_id)) else: - ordering.append(XCom.task_id) + ordering.append(XComModel.task_id) if map_indexes is None or isinstance(map_indexes, int): - ordering.append(XCom.map_index) + ordering.append(XComModel.map_index) elif isinstance(map_indexes, range): - order = XCom.map_index + order = XComModel.map_index if map_indexes.step < 0: order = order.desc() ordering.append(order) elif map_index_whens := {map_index: i for i, map_index in enumerate(map_indexes)}: - ordering.append(case(map_index_whens, value=XCom.map_index)) + ordering.append(case(map_index_whens, value=XComModel.map_index)) else: - ordering.append(XCom.map_index) + ordering.append(XComModel.map_index) return LazyXComSelectSequence.from_select( - query.with_entities(XCom.value).order_by(None).statement, + query.with_entities(XComModel.value).order_by(None).statement, order_by=ordering, session=session, ) @@ -2139,7 +2139,7 @@ def _clear_xcom_data(ti: TaskInstance, session: Session = NEW_SESSION) -> None: map_index: int | None = None else: map_index = ti.map_index - XCom.clear( + XComModel.clear( dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id, @@ -3332,7 +3332,7 @@ def xcom_push( :param key: Key to store the value under. :param value: Value to store. Only be JSON-serializable may be used otherwise. """ - XCom.set( + XComModel.set( key=key, value=value, task_id=self.task_id, @@ -3596,7 +3596,7 @@ def clear_db_references(self, session: Session): from airflow.models.renderedtifields import RenderedTaskInstanceFields tables: list[type[TaskInstanceDependencies]] = [ - XCom, + XComModel, RenderedTaskInstanceFields, TaskMap, ] diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py index 96683855ffb36..a8f1d3ee3102d 100644 --- a/airflow/models/xcom.py +++ b/airflow/models/xcom.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import inspect import json import logging from collections.abc import Iterable @@ -37,15 +36,13 @@ ) from sqlalchemy.dialects import postgresql from sqlalchemy.ext.associationproxy import association_proxy -from sqlalchemy.orm import Query, reconstructor, relationship +from sqlalchemy.orm import Query, relationship -from airflow.configuration import conf from airflow.models.base import COLLATION_ARGS, ID_LEN, TaskInstanceDependencies from airflow.utils import timezone from airflow.utils.db import LazySelectSequence from airflow.utils.helpers import is_container from airflow.utils.json import XComDecoder, XComEncoder -from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.sqlalchemy import UtcDateTime @@ -63,11 +60,9 @@ from sqlalchemy.orm import Session from sqlalchemy.sql.expression import Select, TextClause - from airflow.models.taskinstancekey import TaskInstanceKey - -class BaseXCom(TaskInstanceDependencies, LoggingMixin): - """Base class for XCom objects.""" +class XComModel(TaskInstanceDependencies): + """XCom model class. Contains table and some utilities.""" __tablename__ = "xcom" @@ -105,26 +100,56 @@ class BaseXCom(TaskInstanceDependencies, LoggingMixin): dag_run = relationship( "DagRun", - primaryjoin="BaseXCom.dag_run_id == foreign(DagRun.id)", + primaryjoin="XComModel.dag_run_id == foreign(DagRun.id)", uselist=False, lazy="joined", passive_deletes="all", ) logical_date = association_proxy("dag_run", "logical_date") - @reconstructor - def init_on_load(self): + @classmethod + @provide_session + def clear( + cls, + *, + dag_id: str, + task_id: str, + run_id: str, + map_index: int | None = None, + session: Session = NEW_SESSION, + ) -> None: """ - Execute after the instance has been loaded from the DB or otherwise reconstituted; called by the ORM. + Clear all XCom data from the database for the given task instance. - i.e automatically deserialize Xcom value when loading from DB. + .. note:: This **will not** purge any data from a custom XCom backend. + + :param dag_id: ID of DAG to clear the XCom for. + :param task_id: ID of task to clear the XCom for. + :param run_id: ID of DAG run to clear the XCom for. + :param map_index: If given, only clear XCom from this particular mapped + task. The default ``None`` clears *all* XComs from the task. + :param session: Database session. If not given, a new session will be + created for this function. """ - self.value = self.orm_deserialize_value() + # Given the historic order of this function (logical_date was first argument) to add a new optional + # param we need to add default values for everything :( + if dag_id is None: + raise TypeError("clear() missing required argument: dag_id") + if task_id is None: + raise TypeError("clear() missing required argument: task_id") - def __repr__(self): - if self.map_index < 0: - return f'' - return f'' + if not run_id: + raise ValueError(f"run_id must be passed. Passed run_id={run_id}") + + query = session.query(cls).filter_by(dag_id=dag_id, task_id=task_id, run_id=run_id) + if map_index is not None: + query = query.filter_by(map_index=map_index) + + for xcom in query: + # print(f"Clearing XCOM {xcom} with value {xcom.value}") + session.delete(xcom) + + session.commit() @classmethod @provide_session @@ -201,6 +226,7 @@ def set( cls.map_index == map_index, ) ) + new = cast(Any, cls)( # Work around Mypy complaining model not defining '__init__'. dag_run_id=dag_run_id, key=key, @@ -213,98 +239,10 @@ def set( session.add(new) session.flush() - @staticmethod - @provide_session - def get_value( - *, - ti_key: TaskInstanceKey, - key: str | None = None, - session: Session = NEW_SESSION, - ) -> Any: - """ - Retrieve an XCom value for a task instance. - - This method returns "full" XCom values (i.e. uses ``deserialize_value`` - from the XCom backend). Use :meth:`get_many` if you want the "shortened" - value via ``orm_deserialize_value``. - - If there are no results, *None* is returned. If multiple XCom entries - match the criteria, an arbitrary one is returned. - - :param ti_key: The TaskInstanceKey to look up the XCom for. - :param key: A key for the XCom. If provided, only XCom with matching - keys will be returned. Pass *None* (default) to remove the filter. - :param session: Database session. If not given, a new session will be - created for this function. - """ - return BaseXCom.get_one( - key=key, - task_id=ti_key.task_id, - dag_id=ti_key.dag_id, - run_id=ti_key.run_id, - map_index=ti_key.map_index, - session=session, - ) - - @staticmethod - @provide_session - def get_one( - *, - key: str | None = None, - dag_id: str | None = None, - task_id: str | None = None, - run_id: str, - map_index: int | None = None, - session: Session = NEW_SESSION, - include_prior_dates: bool = False, - ) -> Any | None: - """ - Retrieve an XCom value, optionally meeting certain criteria. - - This method returns "full" XCom values (i.e. uses ``deserialize_value`` - from the XCom backend). Use :meth:`get_many` if you want the "shortened" - value via ``orm_deserialize_value``. - - If there are no results, *None* is returned. If multiple XCom entries - match the criteria, an arbitrary one is returned. - - .. seealso:: ``get_value()`` is a convenience function if you already - have a structured TaskInstance or TaskInstanceKey object available. - - :param run_id: DAG run ID for the task. - :param dag_id: Only pull XCom from this DAG. Pass *None* (default) to - remove the filter. - :param task_id: Only XCom from task with matching ID will be pulled. - Pass *None* (default) to remove the filter. - :param map_index: Only XCom from task with matching ID will be pulled. - Pass *None* (default) to remove the filter. - :param key: A key for the XCom. If provided, only XCom with matching - keys will be returned. Pass *None* (default) to remove the filter. - :param include_prior_dates: If *False* (default), only XCom from the - specified DAG run is returned. If *True*, the latest matching XCom is - returned regardless of the run it belongs to. - :param session: Database session. If not given, a new session will be - created for this function. - """ - query = BaseXCom.get_many( - run_id=run_id, - key=key, - task_ids=task_id, - dag_ids=dag_id, - map_indexes=map_index, - include_prior_dates=include_prior_dates, - limit=1, - session=session, - ) - - result = query.with_entities(BaseXCom.value).first() - if result: - return XCom.deserialize_value(result) - return None - - @staticmethod + @classmethod @provide_session def get_many( + cls, *, run_id: str, key: str | None = None, @@ -342,101 +280,39 @@ def get_many( if not run_id: raise ValueError(f"run_id must be passed. Passed run_id={run_id}") - query = session.query(BaseXCom).join(BaseXCom.dag_run) + query = session.query(cls).join(XComModel.dag_run) if key: - query = query.filter(BaseXCom.key == key) + query = query.filter(XComModel.key == key) if is_container(task_ids): - query = query.filter(BaseXCom.task_id.in_(task_ids)) + query = query.filter(cls.task_id.in_(task_ids)) elif task_ids is not None: - query = query.filter(BaseXCom.task_id == task_ids) + query = query.filter(cls.task_id == task_ids) if is_container(dag_ids): - query = query.filter(BaseXCom.dag_id.in_(dag_ids)) + query = query.filter(cls.dag_id.in_(dag_ids)) elif dag_ids is not None: - query = query.filter(BaseXCom.dag_id == dag_ids) + query = query.filter(cls.dag_id == dag_ids) if isinstance(map_indexes, range) and map_indexes.step == 1: - query = query.filter( - BaseXCom.map_index >= map_indexes.start, BaseXCom.map_index < map_indexes.stop - ) + query = query.filter(cls.map_index >= map_indexes.start, cls.map_index < map_indexes.stop) elif is_container(map_indexes): - query = query.filter(BaseXCom.map_index.in_(map_indexes)) + query = query.filter(cls.map_index.in_(map_indexes)) elif map_indexes is not None: - query = query.filter(BaseXCom.map_index == map_indexes) + query = query.filter(cls.map_index == map_indexes) if include_prior_dates: dr = session.query(DagRun.logical_date).filter(DagRun.run_id == run_id).subquery() - query = query.filter(BaseXCom.logical_date <= dr.c.logical_date) + query = query.filter(cls.logical_date <= dr.c.logical_date) else: - query = query.filter(BaseXCom.run_id == run_id) + query = query.filter(cls.run_id == run_id) - query = query.order_by(DagRun.logical_date.desc(), BaseXCom.timestamp.desc()) + query = query.order_by(DagRun.logical_date.desc(), cls.timestamp.desc()) if limit: return query.limit(limit) return query - @classmethod - @provide_session - def delete(cls, xcoms: XCom | Iterable[XCom], session: Session) -> None: - """Delete one or multiple XCom entries.""" - if isinstance(xcoms, XCom): - xcoms = [xcoms] - for xcom in xcoms: - if not isinstance(xcom, XCom): - raise TypeError(f"Expected XCom; received {xcom.__class__.__name__}") - XCom.purge(xcom, session) - session.delete(xcom) - session.commit() - - @staticmethod - def purge(xcom: XCom, session: Session) -> None: - """Purge an XCom entry from underlying storage implementations.""" - pass - - @staticmethod - @provide_session - def clear( - *, - dag_id: str, - task_id: str, - run_id: str, - map_index: int | None = None, - session: Session = NEW_SESSION, - ) -> None: - """ - Clear all XCom data from the database for the given task instance. - - :param dag_id: ID of DAG to clear the XCom for. - :param task_id: ID of task to clear the XCom for. - :param run_id: ID of DAG run to clear the XCom for. - :param map_index: If given, only clear XCom from this particular mapped - task. The default ``None`` clears *all* XComs from the task. - :param session: Database session. If not given, a new session will be - created for this function. - """ - # Given the historic order of this function (logical_date was first argument) to add a new optional - # param we need to add default values for everything :( - if dag_id is None: - raise TypeError("clear() missing required argument: dag_id") - if task_id is None: - raise TypeError("clear() missing required argument: task_id") - - if not run_id: - raise ValueError(f"run_id must be passed. Passed run_id={run_id}") - - query = session.query(BaseXCom).filter_by(dag_id=dag_id, task_id=task_id, run_id=run_id) - if map_index is not None: - query = query.filter_by(map_index=map_index) - - for xcom in query: - # print(f"Clearing XCOM {xcom} with value {xcom.value}") - XCom.purge(xcom, session) - session.delete(xcom) - - session.commit() - @staticmethod def serialize_value( value: Any, @@ -454,31 +330,12 @@ def serialize_value( raise ValueError("XCom value must be JSON serializable") @staticmethod - def _deserialize_value(result: XCom, orm: bool) -> Any: - object_hook = None - if orm: - object_hook = XComDecoder.orm_object_hook - + def deserialize_value(result) -> Any: + """Deserialize XCom value from str objects.""" if result.value is None: return None - return json.loads(result.value, cls=XComDecoder, object_hook=object_hook) - - @staticmethod - def deserialize_value(result: XCom) -> Any: - """Deserialize XCom value from str or pickle object.""" - return BaseXCom._deserialize_value(result, False) - - def orm_deserialize_value(self) -> Any: - """ - Deserialize method which is used to reconstruct ORM XCom object. - - This method should be overridden in custom XCom backends to avoid - unnecessary request or other resource consuming operations when - creating XCom orm model. This is used when viewing XCom listing - in the webserver, for example. - """ - return BaseXCom._deserialize_value(self, True) + return json.loads(result.value, cls=XComDecoder) class LazyXComSelectSequence(LazySelectSequence[Any]): @@ -490,44 +347,20 @@ class LazyXComSelectSequence(LazySelectSequence[Any]): @staticmethod def _rebuild_select(stmt: TextClause) -> Select: - return select(XCom.value).from_statement(stmt) + return select(XComModel.value).from_statement(stmt) @staticmethod def _process_row(row: Row) -> Any: - return XCom.deserialize_value(row) + return XComModel.deserialize_value(row) -def _get_function_params(function) -> list[str]: - """ - Return the list of variables names of a function. +def __getattr__(name: str): + if name == "BaseXCom" or name == "XCom": + from airflow.sdk.execution_time import xcom - :param function: The function to inspect - """ - parameters = inspect.signature(function).parameters - bound_arguments = [ - name for name, p in parameters.items() if p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD) - ] - return bound_arguments + val = getattr(xcom, name) + globals()[name] = val + return val -def resolve_xcom_backend() -> type[BaseXCom]: - """ - Resolve custom XCom class. - - Confirm that custom XCom class extends the BaseXCom. - Compare the function signature of the custom XCom serialize_value to the base XCom serialize_value. - """ - clazz = conf.getimport("core", "xcom_backend", fallback=f"airflow.models.xcom.{BaseXCom.__name__}") - if not clazz: - return BaseXCom - if not issubclass(clazz, BaseXCom): - raise TypeError( - f"Your custom XCom class `{clazz.__name__}` is not a subclass of `{BaseXCom.__name__}`." - ) - return clazz - - -if TYPE_CHECKING: - XCom = BaseXCom # Hack to avoid Mypy "Variable 'XCom' is not valid as a type". -else: - XCom = resolve_xcom_backend() + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py index 1d885fb5bd1b0..cfda9295cec26 100644 --- a/airflow/models/xcom_arg.py +++ b/airflow/models/xcom_arg.py @@ -114,7 +114,7 @@ def get_task_map_length(xcom_arg: SchedulerXComArg, run_id: str, *, session: Ses def _(xcom_arg: SchedulerPlainXComArg, run_id: str, *, session: Session): from airflow.models.taskinstance import TaskInstance from airflow.models.taskmap import TaskMap - from airflow.models.xcom import XCom + from airflow.models.xcom import XComModel dag_id = xcom_arg.operator.dag_id task_id = xcom_arg.operator.task_id @@ -136,12 +136,12 @@ def _(xcom_arg: SchedulerPlainXComArg, run_id: str, *, session: Session): ) if unfinished_ti_exists: return None # Not all of the expanded tis are done yet. - query = select(func.count(XCom.map_index)).where( - XCom.dag_id == dag_id, - XCom.run_id == run_id, - XCom.task_id == task_id, - XCom.map_index >= 0, - XCom.key == XCOM_RETURN_KEY, + query = select(func.count(XComModel.map_index)).where( + XComModel.dag_id == dag_id, + XComModel.run_id == run_id, + XComModel.task_id == task_id, + XComModel.map_index >= 0, + XComModel.key == XCOM_RETURN_KEY, ) else: query = select(TaskMap.length).where( diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 8451d6850acbd..df09356a40ec6 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -48,7 +48,7 @@ ) from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance from airflow.models.taskinstancekey import TaskInstanceKey -from airflow.models.xcom import BaseXCom +from airflow.models.xcom import XComModel from airflow.models.xcom_arg import SchedulerXComArg, deserialize_xcom_arg from airflow.providers_manager import ProvidersManager from airflow.sdk.definitions.asset import ( @@ -2019,13 +2019,13 @@ def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str: self.log.info( "Attempting to retrieve link from XComs with key: %s for task id: %s", self.xcom_key, ti_key ) - value = BaseXCom.get_one( + value = XComModel.get_many( key=self.xcom_key, run_id=ti_key.run_id, - dag_id=ti_key.dag_id, - task_id=ti_key.task_id, - map_index=ti_key.map_index, - ) + dag_ids=ti_key.dag_id, + task_ids=ti_key.task_id, + map_indexes=ti_key.map_index, + ).first() if not value: self.log.debug( "No link with name: %s present in XCom as key: %s, returning empty link", @@ -2033,4 +2033,4 @@ def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str: self.xcom_key, ) return "" - return value + return XComModel.deserialize_value(value) diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index d3fd4c22117c2..b661d0ccbbb4e 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -29,6 +29,7 @@ from datetime import datetime, timedelta, timezone from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Protocol, TypeVar +from unittest import mock import pytest import time_machine @@ -1058,13 +1059,18 @@ def __call__( return self def cleanup(self): - from airflow.models import DagModel, DagRun, TaskInstance, XCom + from airflow.models import DagModel, DagRun, TaskInstance from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskmap import TaskMap from airflow.utils.retries import run_with_db_retries from tests_common.test_utils.compat import AssetEvent + if AIRFLOW_V_3_0_PLUS: + from airflow.models.xcom import XComModel as XCom + else: + from airflow.models.xcom import XCom + for attempt in run_with_db_retries(logger=self.log): with attempt: dag_ids = list(self.dagbag.dag_ids) @@ -1830,3 +1836,17 @@ def override_caplog(request): import airflow.logging_config airflow.logging_config.configure_logging() + + +@pytest.fixture +def mock_supervisor_comms(): + # for back-compat + from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS + + if not AIRFLOW_V_3_0_PLUS: + yield None + return + with mock.patch( + "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True + ) as supervisor_comms: + yield supervisor_comms diff --git a/devel-common/src/tests_common/test_utils/db.py b/devel-common/src/tests_common/test_utils/db.py index 374d65d63da52..2948c1941a450 100644 --- a/devel-common/src/tests_common/test_utils/db.py +++ b/devel-common/src/tests_common/test_utils/db.py @@ -33,7 +33,6 @@ TaskReschedule, Trigger, Variable, - XCom, ) from airflow.models.dag import DagOwnerAttributes from airflow.models.dagcode import DagCode @@ -56,6 +55,11 @@ if TYPE_CHECKING: from pathlib import Path +if AIRFLOW_V_3_0_PLUS: + from airflow.models.xcom import XComModel as XCom +else: + from airflow.models.xcom import XCom + def _bootstrap_dagbag(): from airflow.models.dag import DAG diff --git a/devel-common/src/tests_common/test_utils/mock_operators.py b/devel-common/src/tests_common/test_utils/mock_operators.py index 4f43f479f69e9..8854ea9cb1c1c 100644 --- a/devel-common/src/tests_common/test_utils/mock_operators.py +++ b/devel-common/src/tests_common/test_utils/mock_operators.py @@ -21,13 +21,18 @@ import attr from airflow.models.baseoperator import BaseOperator -from airflow.models.xcom import XCom from tests_common.test_utils.compat import BaseOperatorLink +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS if TYPE_CHECKING: from airflow.sdk.definitions.context import Context +if AIRFLOW_V_3_0_PLUS: + from airflow.models.xcom import XComModel as XCom +else: + from airflow.models.xcom import XCom # type: ignore[no-redef] + class MockOperator(BaseOperator): """Operator for testing purposes.""" @@ -89,9 +94,17 @@ def xcom_key(self) -> str: return f"bigquery_{self.index + 1}" def get_link(self, operator, *, ti_key): - search_queries = XCom.get_one( - task_id=ti_key.task_id, dag_id=ti_key.dag_id, run_id=ti_key.run_id, key="search_query" - ) + if AIRFLOW_V_3_0_PLUS: + search_queries = XCom.get_many( + task_id=ti_key.task_id, dag_id=ti_key.dag_id, run_id=ti_key.run_id, key="search_query" + ).first() + + search_queries = XCom.deserialize_value(search_queries) + else: + search_queries = XCom.get_one( + task_id=ti_key.task_id, dag_id=ti_key.dag_id, run_id=ti_key.run_id, key="search_query" + ) + if not search_queries: return None if len(search_queries) < self.index: @@ -106,13 +119,23 @@ class CustomOpLink(BaseOperatorLink): name = "Google Custom" def get_link(self, operator, *, ti_key): - search_query = XCom.get_one( - task_id=ti_key.task_id, - dag_id=ti_key.dag_id, - run_id=ti_key.run_id, - map_index=ti_key.map_index, - key="search_query", - ) + if AIRFLOW_V_3_0_PLUS: + search_query = XCom.get_many( + task_ids=ti_key.task_id, + dag_ids=ti_key.dag_id, + run_id=ti_key.run_id, + map_indexes=ti_key.map_index, + key="search_query", + ).first() + search_query = XCom.deserialize_value(search_query) + else: + search_query = XCom.get_one( + task_id=ti_key.task_id, + dag_id=ti_key.dag_id, + run_id=ti_key.run_id, + map_index=ti_key.map_index, + key="search_query", + ) if not search_query: return None return f"http://google.com/custom_base_link?search={search_query}" diff --git a/providers/amazon/src/airflow/providers/amazon/aws/links/base_aws.py b/providers/amazon/src/airflow/providers/amazon/aws/links/base_aws.py index 47330c473c80b..e909776bdffd8 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/links/base_aws.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/links/base_aws.py @@ -19,7 +19,6 @@ from typing import TYPE_CHECKING, ClassVar -from airflow.models import XCom from airflow.providers.amazon.aws.utils.suppress import return_on_error from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS @@ -30,7 +29,9 @@ if AIRFLOW_V_3_0_PLUS: from airflow.sdk import BaseOperatorLink + from airflow.sdk.execution_time.xcom import XCom else: + from airflow.models import XCom # type: ignore[no-redef] from airflow.models.baseoperatorlink import BaseOperatorLink # type: ignore[no-redef] diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_athena.py b/providers/amazon/tests/unit/amazon/aws/links/test_athena.py index 94c9a6152bc11..5f4bafc89cafd 100644 --- a/providers/amazon/tests/unit/amazon/aws/links/test_athena.py +++ b/providers/amazon/tests/unit/amazon/aws/links/test_athena.py @@ -17,13 +17,28 @@ from __future__ import annotations from airflow.providers.amazon.aws.links.athena import AthenaQueryResultsLink +from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS from unit.amazon.aws.links.test_base_aws import BaseAwsLinksTestCase +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.execution_time.comms import XComResult + class TestAthenaQueryResultsLink(BaseAwsLinksTestCase): link_class = AthenaQueryResultsLink - def test_extra_link(self): + def test_extra_link(self, mock_supervisor_comms): + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key=AthenaQueryResultsLink.key, + value={ + "region_name": "eu-west-1", + "aws_domain": AthenaQueryResultsLink.get_aws_domain("aws"), + "aws_partition": "aws", + "query_execution_id": "00000000-0000-0000-0000-000000000000", + }, + ) + self.assert_extra_link_url( expected_url=( "https://console.aws.amazon.com/athena/home" diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_base_aws.py b/providers/amazon/tests/unit/amazon/aws/links/test_base_aws.py index 1749d51e36e47..7423d8579773f 100644 --- a/providers/amazon/tests/unit/amazon/aws/links/test_base_aws.py +++ b/providers/amazon/tests/unit/amazon/aws/links/test_base_aws.py @@ -22,7 +22,6 @@ import pytest -from airflow.models.xcom import XCom from airflow.providers.amazon.aws.links.base_aws import BaseAwsLink from airflow.serialization.serialized_objects import SerializedDAG @@ -32,6 +31,11 @@ if TYPE_CHECKING: from airflow.models.taskinstance import TaskInstance +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.execution_time.xcom import XCom +else: + from airflow.models import XCom # type: ignore[no-redef] + XCOM_KEY = "test_xcom_key" CUSTOM_KEYS = { "foo": "bar", diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_batch.py b/providers/amazon/tests/unit/amazon/aws/links/test_batch.py index 634a313dd2810..3e2d01405dd64 100644 --- a/providers/amazon/tests/unit/amazon/aws/links/test_batch.py +++ b/providers/amazon/tests/unit/amazon/aws/links/test_batch.py @@ -21,13 +21,27 @@ BatchJobDetailsLink, BatchJobQueueLink, ) +from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS from unit.amazon.aws.links.test_base_aws import BaseAwsLinksTestCase +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.execution_time.comms import XComResult + class TestBatchJobDefinitionLink(BaseAwsLinksTestCase): link_class = BatchJobDefinitionLink - def test_extra_link(self): + def test_extra_link(self, mock_supervisor_comms): + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key=self.link_class.key, + value={ + "region_name": "eu-west-1", + "aws_domain": self.link_class.get_aws_domain("aws"), + "aws_partition": "aws", + "job_definition_arn": "arn:fake:jd", + }, + ) self.assert_extra_link_url( expected_url=( "https://console.aws.amazon.com/batch/home" @@ -42,7 +56,17 @@ def test_extra_link(self): class TestBatchJobDetailsLink(BaseAwsLinksTestCase): link_class = BatchJobDetailsLink - def test_extra_link(self): + def test_extra_link(self, mock_supervisor_comms): + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key=self.link_class.key, + value={ + "region_name": "cn-north-1", + "aws_domain": self.link_class.get_aws_domain("aws-cn"), + "aws_partition": "aws-cn", + "job_id": "fake-id", + }, + ) self.assert_extra_link_url( expected_url="https://console.amazonaws.cn/batch/home?region=cn-north-1#jobs/detail/fake-id", region_name="cn-north-1", @@ -54,7 +78,17 @@ def test_extra_link(self): class TestBatchJobQueueLink(BaseAwsLinksTestCase): link_class = BatchJobQueueLink - def test_extra_link(self): + def test_extra_link(self, mock_supervisor_comms): + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key=self.link_class.key, + value={ + "region_name": "us-east-1", + "aws_domain": self.link_class.get_aws_domain("aws"), + "aws_partition": "aws", + "job_queue_arn": "arn:fake:jq", + }, + ) self.assert_extra_link_url( expected_url=( "https://console.aws.amazon.com/batch/home?region=us-east-1#queues/detail/arn:fake:jq" diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_comprehend.py b/providers/amazon/tests/unit/amazon/aws/links/test_comprehend.py index e00b1cabb6f06..639d8222b59f6 100644 --- a/providers/amazon/tests/unit/amazon/aws/links/test_comprehend.py +++ b/providers/amazon/tests/unit/amazon/aws/links/test_comprehend.py @@ -20,14 +20,28 @@ ComprehendDocumentClassifierLink, ComprehendPiiEntitiesDetectionLink, ) +from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS from unit.amazon.aws.links.test_base_aws import BaseAwsLinksTestCase +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.execution_time.comms import XComResult + class TestComprehendPiiEntitiesDetectionLink(BaseAwsLinksTestCase): link_class = ComprehendPiiEntitiesDetectionLink - def test_extra_link(self): + def test_extra_link(self, mock_supervisor_comms): test_job_id = "123-345-678" + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key=self.link_class.key, + value={ + "region_name": "eu-west-1", + "aws_domain": self.link_class.get_aws_domain("aws"), + "aws_partition": "aws", + "job_id": test_job_id, + }, + ) self.assert_extra_link_url( expected_url=( f"https://console.aws.amazon.com/comprehend/home?region=eu-west-1#/analysis-job-details/pii/{test_job_id}" @@ -41,10 +55,20 @@ def test_extra_link(self): class TestComprehendDocumentClassifierLink(BaseAwsLinksTestCase): link_class = ComprehendDocumentClassifierLink - def test_extra_link(self): + def test_extra_link(self, mock_supervisor_comms): test_job_id = ( "arn:aws:comprehend:us-east-1:0123456789:document-classifier/test-custom-document-classifier" ) + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key=self.link_class.key, + value={ + "region_name": "us-east-1", + "aws_domain": self.link_class.get_aws_domain("aws"), + "aws_partition": "aws", + "arn": test_job_id, + }, + ) self.assert_extra_link_url( expected_url=( f"https://console.aws.amazon.com/comprehend/home?region=us-east-1#classifier-version-details/{test_job_id}" diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_datasync.py b/providers/amazon/tests/unit/amazon/aws/links/test_datasync.py index bfb0e62733dcf..4a92c08d87ca0 100644 --- a/providers/amazon/tests/unit/amazon/aws/links/test_datasync.py +++ b/providers/amazon/tests/unit/amazon/aws/links/test_datasync.py @@ -17,8 +17,12 @@ from __future__ import annotations from airflow.providers.amazon.aws.links.datasync import DataSyncTaskExecutionLink, DataSyncTaskLink +from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS from unit.amazon.aws.links.test_base_aws import BaseAwsLinksTestCase +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.execution_time.comms import XComResult + TASK_ID = "task-0b36221bf94ad2bdd" EXECUTION_ID = "exec-00000000000000004" @@ -26,8 +30,18 @@ class TestDataSyncTaskLink(BaseAwsLinksTestCase): link_class = DataSyncTaskLink - def test_extra_link(self): + def test_extra_link(self, mock_supervisor_comms): task_id = TASK_ID + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key=self.link_class.key, + value={ + "region_name": "us-east-1", + "aws_domain": self.link_class.get_aws_domain("aws"), + "aws_partition": "aws", + "task_id": task_id, + }, + ) self.assert_extra_link_url( expected_url=(f"https://console.aws.amazon.com/datasync/home?region=us-east-1#/tasks/{TASK_ID}"), region_name="us-east-1", @@ -39,7 +53,18 @@ def test_extra_link(self): class TestDataSyncTaskExecutionLink(BaseAwsLinksTestCase): link_class = DataSyncTaskExecutionLink - def test_extra_link(self): + def test_extra_link(self, mock_supervisor_comms): + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key=self.link_class.key, + value={ + "region_name": "us-east-1", + "aws_domain": self.link_class.get_aws_domain("aws"), + "aws_partition": "aws", + "task_id": TASK_ID, + "task_execution_id": EXECUTION_ID, + }, + ) self.assert_extra_link_url( expected_url=( f"https://console.aws.amazon.com/datasync/home?region=us-east-1#/history/{TASK_ID}/{EXECUTION_ID}" diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_ec2.py b/providers/amazon/tests/unit/amazon/aws/links/test_ec2.py index 279127484351b..8aefd8a412256 100644 --- a/providers/amazon/tests/unit/amazon/aws/links/test_ec2.py +++ b/providers/amazon/tests/unit/amazon/aws/links/test_ec2.py @@ -17,15 +17,29 @@ from __future__ import annotations from airflow.providers.amazon.aws.links.ec2 import EC2InstanceDashboardLink, EC2InstanceLink +from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS from unit.amazon.aws.links.test_base_aws import BaseAwsLinksTestCase +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.execution_time.comms import XComResult + class TestEC2InstanceLink(BaseAwsLinksTestCase): link_class = EC2InstanceLink INSTANCE_ID = "i-xxxxxxxxxxxx" - def test_extra_link(self): + def test_extra_link(self, mock_supervisor_comms): + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key=self.link_class.key, + value={ + "region_name": "eu-west-1", + "aws_domain": self.link_class.get_aws_domain("aws"), + "aws_partition": "aws", + "instance_id": self.INSTANCE_ID, + }, + ) self.assert_extra_link_url( expected_url=( "https://console.aws.amazon.com/ec2/home" @@ -48,8 +62,18 @@ def test_instance_id_filter(self): result = EC2InstanceDashboardLink.format_instance_id_filter(self.INSTANCE_IDS) assert result == instance_list - def test_extra_link(self): + def test_extra_link(self, mock_supervisor_comms): instance_list = ",:".join(self.INSTANCE_IDS) + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key=self.link_class.key, + value={ + "region_name": "eu-west-1", + "aws_domain": self.link_class.get_aws_domain("aws"), + "aws_partition": "aws", + "instance_ids": EC2InstanceDashboardLink.format_instance_id_filter(self.INSTANCE_IDS), + }, + ) self.assert_extra_link_url( expected_url=(f"{self.BASE_URL}?region=eu-west-1#Instances:instanceId=:{instance_list}"), region_name="eu-west-1", diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_emr.py b/providers/amazon/tests/unit/amazon/aws/links/test_emr.py index 412398d5bc777..00ff0adabb914 100644 --- a/providers/amazon/tests/unit/amazon/aws/links/test_emr.py +++ b/providers/amazon/tests/unit/amazon/aws/links/test_emr.py @@ -32,13 +32,27 @@ get_log_uri, get_serverless_dashboard_url, ) +from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS from unit.amazon.aws.links.test_base_aws import BaseAwsLinksTestCase +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.execution_time.comms import XComResult + class TestEmrClusterLink(BaseAwsLinksTestCase): link_class = EmrClusterLink - def test_extra_link(self): + def test_extra_link(self, mock_supervisor_comms): + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key=self.link_class.key, + value={ + "region_name": "us-west-1", + "aws_domain": self.link_class.get_aws_domain("aws"), + "aws_partition": "aws", + "job_flow_id": "j-TEST-FLOW-ID", + }, + ) self.assert_extra_link_url( expected_url=( "https://console.aws.amazon.com/emr/home?region=us-west-1#/clusterDetails/j-TEST-FLOW-ID" @@ -65,7 +79,17 @@ def test_get_log_uri(cluster_info, expected_uri): class TestEmrLogsLink(BaseAwsLinksTestCase): link_class = EmrLogsLink - def test_extra_link(self): + def test_extra_link(self, mock_supervisor_comms): + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key=self.link_class.key, + value={ + "region_name": "eu-west-2", + "aws_domain": self.link_class.get_aws_domain("aws"), + "log_uri": "myLogUri/", + "job_flow_id": "j-8989898989", + }, + ) self.assert_extra_link_url( expected_url=( "https://console.aws.amazon.com/s3/buckets/myLogUri/?region=eu-west-2&prefix=j-8989898989/" @@ -97,10 +121,20 @@ def mocked_emr_serverless_hook(): class TestEmrServerlessLogsLink(BaseAwsLinksTestCase): link_class = EmrServerlessLogsLink - def test_extra_link(self, mocked_emr_serverless_hook): + def test_extra_link(self, mocked_emr_serverless_hook, mock_supervisor_comms): mocked_client = mocked_emr_serverless_hook.return_value.conn mocked_client.get_dashboard_for_job_run.return_value = {"url": "https://example.com/?authToken=1234"} + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key=self.link_class.key, + value={ + "conn_id": "aws-test", + "application_id": "app-id", + "job_run_id": "job-run-id", + }, + ) + self.assert_extra_link_url( expected_url="https://example.com/logs/SPARK_DRIVER/stdout.gz?authToken=1234", conn_id="aws-test", @@ -120,10 +154,19 @@ def test_extra_link(self, mocked_emr_serverless_hook): class TestEmrServerlessDashboardLink(BaseAwsLinksTestCase): link_class = EmrServerlessDashboardLink - def test_extra_link(self, mocked_emr_serverless_hook): + def test_extra_link(self, mocked_emr_serverless_hook, mock_supervisor_comms): mocked_client = mocked_emr_serverless_hook.return_value.conn mocked_client.get_dashboard_for_job_run.return_value = {"url": "https://example.com/?authToken=1234"} + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key=self.link_class.key, + value={ + "conn_id": "aws-test", + "application_id": "app-id", + "job_run_id": "job-run-id", + }, + ) self.assert_extra_link_url( expected_url="https://example.com/?authToken=1234", conn_id="aws-test", @@ -208,7 +251,19 @@ def test_get_serverless_dashboard_url_parameters(): class TestEmrServerlessS3LogsLink(BaseAwsLinksTestCase): link_class = EmrServerlessS3LogsLink - def test_extra_link(self): + def test_extra_link(self, mock_supervisor_comms): + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key=self.link_class.key, + value={ + "region_name": "us-west-1", + "aws_domain": self.link_class.get_aws_domain("aws"), + "aws_partition": "aws", + "log_uri": "s3://bucket-name/logs/", + "application_id": "app-id", + "job_run_id": "job-run-id", + }, + ) self.assert_extra_link_url( expected_url=( "https://console.aws.amazon.com/s3/buckets/bucket-name?region=us-west-1&prefix=logs/applications/app-id/jobs/job-run-id/" @@ -224,7 +279,20 @@ def test_extra_link(self): class TestEmrServerlessCloudWatchLogsLink(BaseAwsLinksTestCase): link_class = EmrServerlessCloudWatchLogsLink - def test_extra_link(self): + def test_extra_link(self, mock_supervisor_comms): + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key=self.link_class.key, + value={ + "region_name": "us-west-1", + "aws_domain": self.link_class.get_aws_domain("aws"), + "aws_partition": "aws", + "awslogs_group": "/aws/emrs", + "stream_prefix": "some-prefix", + "application_id": "app-id", + "job_run_id": "job-run-id", + }, + ) self.assert_extra_link_url( expected_url=( "https://console.aws.amazon.com/cloudwatch/home?region=us-west-1#logsV2:log-groups/log-group/%2Faws%2Femrs$3FlogStreamNameFilter$3Dsome-prefix" diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_glue.py b/providers/amazon/tests/unit/amazon/aws/links/test_glue.py index 87f34ab26a6c4..5d05c397bdfdd 100644 --- a/providers/amazon/tests/unit/amazon/aws/links/test_glue.py +++ b/providers/amazon/tests/unit/amazon/aws/links/test_glue.py @@ -17,13 +17,28 @@ from __future__ import annotations from airflow.providers.amazon.aws.links.glue import GlueJobRunDetailsLink +from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS from unit.amazon.aws.links.test_base_aws import BaseAwsLinksTestCase +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.execution_time.comms import XComResult + class TestGlueJobRunDetailsLink(BaseAwsLinksTestCase): link_class = GlueJobRunDetailsLink - def test_extra_link(self): + def test_extra_link(self, mock_supervisor_comms): + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key=self.link_class.key, + value={ + "region_name": "ap-southeast-2", + "aws_domain": self.link_class.get_aws_domain("aws"), + "aws_partition": "aws", + "job_run_id": "11111", + "job_name": "test_job_name", + }, + ) self.assert_extra_link_url( expected_url=( "https://console.aws.amazon.com/gluestudio/home" diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_logs.py b/providers/amazon/tests/unit/amazon/aws/links/test_logs.py index e6df3705e0309..ea6c07f269e97 100644 --- a/providers/amazon/tests/unit/amazon/aws/links/test_logs.py +++ b/providers/amazon/tests/unit/amazon/aws/links/test_logs.py @@ -17,13 +17,29 @@ from __future__ import annotations from airflow.providers.amazon.aws.links.logs import CloudWatchEventsLink +from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS from unit.amazon.aws.links.test_base_aws import BaseAwsLinksTestCase +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.execution_time.comms import XComResult + class TestCloudWatchEventsLink(BaseAwsLinksTestCase): link_class = CloudWatchEventsLink - def test_extra_link(self): + def test_extra_link(self, mock_supervisor_comms): + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key=self.link_class.key, + value={ + "region_name": "us-west-1", + "aws_domain": self.link_class.get_aws_domain("aws"), + "aws_partition": "aws", + "awslogs_region": "ap-southeast-2", + "awslogs_group": "/test/logs/group", + "awslogs_stream_name": "test/stream/d56a66bb98a14c4593defa1548686edf", + }, + ) self.assert_extra_link_url( expected_url=( "https://console.aws.amazon.com/cloudwatch/home" diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_sagemaker.py b/providers/amazon/tests/unit/amazon/aws/links/test_sagemaker.py index f4681c52a48dc..568bfe1d780cf 100644 --- a/providers/amazon/tests/unit/amazon/aws/links/test_sagemaker.py +++ b/providers/amazon/tests/unit/amazon/aws/links/test_sagemaker.py @@ -16,14 +16,29 @@ # under the License. from __future__ import annotations +from airflow.providers.amazon.aws.links.base_aws import BaseAwsLink from airflow.providers.amazon.aws.links.sagemaker import SageMakerTransformJobLink +from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS from unit.amazon.aws.links.test_base_aws import BaseAwsLinksTestCase +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.execution_time.comms import XComResult + class TestSageMakerTransformDetailsLink(BaseAwsLinksTestCase): link_class = SageMakerTransformJobLink - def test_extra_link(self): + def test_extra_link(self, mock_supervisor_comms): + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key="sagemaker_transform_job_details", + value={ + "region_name": "us-east-1", + "aws_domain": BaseAwsLink.get_aws_domain("aws"), + **{"job_name": "test_job_name"}, + }, + ) + self.assert_extra_link_url( expected_url=( "https://console.aws.amazon.com/sagemaker/home" diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_sagemaker_unified_studio.py b/providers/amazon/tests/unit/amazon/aws/links/test_sagemaker_unified_studio.py index c55d1231fd83f..af7fbcf03a09a 100644 --- a/providers/amazon/tests/unit/amazon/aws/links/test_sagemaker_unified_studio.py +++ b/providers/amazon/tests/unit/amazon/aws/links/test_sagemaker_unified_studio.py @@ -17,13 +17,27 @@ from __future__ import annotations from airflow.providers.amazon.aws.links.sagemaker_unified_studio import SageMakerUnifiedStudioLink +from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS from unit.amazon.aws.links.test_base_aws import BaseAwsLinksTestCase +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.execution_time.comms import XComResult + class TestSageMakerUnifiedStudioLink(BaseAwsLinksTestCase): link_class = SageMakerUnifiedStudioLink - def test_extra_link(self): + def test_extra_link(self, mock_supervisor_comms): + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key=self.link_class.key, + value={ + "region_name": "us-east-1", + "aws_domain": self.link_class.get_aws_domain("aws"), + "aws_partition": "aws", + "job_name": "test_job_name", + }, + ) self.assert_extra_link_url( expected_url=("https://console.aws.amazon.com/datazone/home?region=us-east-1"), region_name="us-east-1", diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_step_function.py b/providers/amazon/tests/unit/amazon/aws/links/test_step_function.py index ea2b5e0edfd53..00c50fc3d70f1 100644 --- a/providers/amazon/tests/unit/amazon/aws/links/test_step_function.py +++ b/providers/amazon/tests/unit/amazon/aws/links/test_step_function.py @@ -22,8 +22,12 @@ StateMachineDetailsLink, StateMachineExecutionsDetailsLink, ) +from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS from unit.amazon.aws.links.test_base_aws import BaseAwsLinksTestCase +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.execution_time.comms import XComResult + class TestStateMachineDetailsLink(BaseAwsLinksTestCase): link_class = StateMachineDetailsLink @@ -40,7 +44,17 @@ class TestStateMachineDetailsLink(BaseAwsLinksTestCase): ), ], ) - def test_extra_link(self, state_machine_arn, expected_url: str): + def test_extra_link(self, state_machine_arn, expected_url: str, mock_supervisor_comms): + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key=self.link_class.key, + value={ + "region_name": "eu-west-1", + "aws_domain": self.link_class.get_aws_domain("aws"), + "aws_partition": "aws", + "state_machine_arn": state_machine_arn, + }, + ) self.assert_extra_link_url( expected_url=expected_url, region_name="eu-west-1", @@ -65,7 +79,17 @@ class TestStateMachineExecutionsDetailsLink(BaseAwsLinksTestCase): ), ], ) - def test_extra_link(self, execution_arn, expected_url: str): + def test_extra_link(self, execution_arn, expected_url: str, mock_supervisor_comms): + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key=self.link_class.key, + value={ + "region_name": "eu-west-1", + "aws_domain": self.link_class.get_aws_domain("aws"), + "aws_partition": "aws", + "execution_arn": execution_arn, + }, + ) self.assert_extra_link_url( expected_url=expected_url, region_name="eu-west-1", diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_pod.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_pod.py index fd38de09f37ac..75ec45f94b84d 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_pod.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_pod.py @@ -36,7 +36,6 @@ TaskDeferred, ) from airflow.models import DAG, DagModel, DagRun, TaskInstance -from airflow.models.xcom import XCom from airflow.providers.cncf.kubernetes import pod_generator from airflow.providers.cncf.kubernetes.operators.pod import ( KubernetesPodOperator, @@ -57,6 +56,11 @@ from tests_common.test_utils import db from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS +if AIRFLOW_V_3_0_PLUS: + from airflow.models.xcom import XComModel as XCom +else: + from airflow.models.xcom import XCom # type: ignore[no-redef] + pytestmark = pytest.mark.db_test @@ -1305,8 +1309,18 @@ def test_push_xcom_pod_info( ) pod, _ = self.run_pod(k) - pod_name = XCom.get_one(run_id=self.dag_run.run_id, task_id="task", key="pod_name") - pod_namespace = XCom.get_one(run_id=self.dag_run.run_id, task_id="task", key="pod_namespace") + if AIRFLOW_V_3_0_PLUS: + pod_name = XCom.get_many(run_id=self.dag_run.run_id, task_ids="task", key="pod_name").first() + pod_namespace = XCom.get_many( + run_id=self.dag_run.run_id, task_ids="task", key="pod_namespace" + ).first() + + pod_name = XCom.deserialize_value(pod_name) + pod_namespace = XCom.deserialize_value(pod_namespace) + else: + pod_name = XCom.get_one(run_id=self.dag_run.run_id, task_id="task", key="pod_name") + pod_namespace = XCom.get_one(run_id=self.dag_run.run_id, task_id="task", key="pod_namespace") + assert pod_name == pod.metadata.name assert pod_namespace == pod.metadata.namespace diff --git a/providers/common/io/src/airflow/providers/common/io/xcom/backend.py b/providers/common/io/src/airflow/providers/common/io/xcom/backend.py index addde4c6c10dc..c318698af3d4e 100644 --- a/providers/common/io/src/airflow/providers/common/io/xcom/backend.py +++ b/providers/common/io/src/airflow/providers/common/io/xcom/backend.py @@ -28,14 +28,18 @@ from airflow.configuration import conf from airflow.io.path import ObjectStoragePath -from airflow.models.xcom import BaseXCom from airflow.providers.common.io.version_compat import AIRFLOW_V_3_0_PLUS from airflow.utils.json import XComDecoder, XComEncoder if TYPE_CHECKING: from sqlalchemy.orm import Session - from airflow.models import XCom + from airflow.sdk.execution_time.comms import XComResult + +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.execution_time.xcom import BaseXCom +else: + from airflow.models.xcom import BaseXCom # type: ignore[no-redef] T = TypeVar("T") @@ -149,7 +153,7 @@ def serialize_value( # type: ignore[override] return BaseXCom.serialize_value(str(p)) @staticmethod - def deserialize_value(result: XCom) -> Any: + def deserialize_value(result) -> Any: """ Deserializes the value from the database or object storage. @@ -167,7 +171,7 @@ def deserialize_value(result: XCom) -> Any: return data @staticmethod - def purge(xcom: XCom, session: Session) -> None: + def purge(xcom: XComResult, session: Session | None = None) -> None: # type: ignore[override] if not isinstance(xcom.value, str): return with contextlib.suppress(TypeError, ValueError): diff --git a/providers/common/io/tests/unit/common/io/xcom/test_backend.py b/providers/common/io/tests/unit/common/io/xcom/test_backend.py index da5df7feb3c40..802106024b887 100644 --- a/providers/common/io/tests/unit/common/io/xcom/test_backend.py +++ b/providers/common/io/tests/unit/common/io/xcom/test_backend.py @@ -20,7 +20,7 @@ import pytest import airflow.models.xcom -from airflow.models.xcom import BaseXCom, resolve_xcom_backend +from airflow.providers.common.io.version_compat import AIRFLOW_V_3_0_PLUS from airflow.providers.common.io.xcom.backend import XComObjectStorageBackend from airflow.providers.standard.operators.empty import EmptyOperator from airflow.utils import timezone @@ -31,6 +31,13 @@ pytestmark = [pytest.mark.db_test] +if AIRFLOW_V_3_0_PLUS: + from airflow.models.xcom import XComModel + from airflow.sdk.execution_time.comms import XComResult + from airflow.sdk.execution_time.xcom import resolve_xcom_backend +else: + from airflow.models.xcom import BaseXCom, resolve_xcom_backend # type: ignore[no-redef] + @pytest.fixture(autouse=True) def reset_db(): @@ -80,7 +87,7 @@ def setup_test_cases(self, tmp_path): with conf_vars(configuration): yield - def test_value_db(self, task_instance, session): + def test_value_db(self, task_instance, mock_supervisor_comms, session): session.add(task_instance) session.commit() XCom = resolve_xcom_backend() @@ -92,26 +99,20 @@ def test_value_db(self, task_instance, session): dag_id=task_instance.dag_id, task_id=task_instance.task_id, run_id=task_instance.run_id, - session=session, ) + if AIRFLOW_V_3_0_PLUS: + mock_supervisor_comms.get_message.return_value = XComResult( + key="return_value", value={"key": "value"} + ) + value = XCom.get_value( key=XCOM_RETURN_KEY, ti_key=task_instance.key, - session=session, ) assert value == {"key": "value"} - qry = XCom.get_many( - key=XCOM_RETURN_KEY, - dag_ids=task_instance.dag_id, - task_ids=task_instance.task_id, - run_id=task_instance.run_id, - session=session, - ) - assert qry.first().value == {"key": "value"} - - def test_value_storage(self, task_instance, session): + def test_value_storage(self, task_instance, mock_supervisor_comms, session): session.add(task_instance) session.commit() XCom = resolve_xcom_backend() @@ -123,42 +124,77 @@ def test_value_storage(self, task_instance, session): dag_id=task_instance.dag_id, task_id=task_instance.task_id, run_id=task_instance.run_id, - session=session, ) - res = ( - XCom.get_many( + if AIRFLOW_V_3_0_PLUS: + XComModel.set( key=XCOM_RETURN_KEY, - dag_ids=task_instance.dag_id, - task_ids=task_instance.task_id, + value=self.path, + dag_id=task_instance.dag_id, + task_id=task_instance.task_id, run_id=task_instance.run_id, - session=session, ) - .with_entities(BaseXCom.value) - .first() - ) - data = BaseXCom.deserialize_value(res) + res = ( + XComModel.get_many( + key=XCOM_RETURN_KEY, + dag_ids=task_instance.dag_id, + task_ids=task_instance.task_id, + run_id=task_instance.run_id, + session=session, + ) + .with_entities(XComModel.value) + .first() + ) + data = XComModel.deserialize_value(res) + else: + res = ( + XCom.get_many( + key=XCOM_RETURN_KEY, + dag_ids=task_instance.dag_id, + task_ids=task_instance.task_id, + run_id=task_instance.run_id, + session=session, + ) + .with_entities(BaseXCom.value) + .first() + ) + data = BaseXCom.deserialize_value(res) + p = XComObjectStorageBackend._get_full_path(data) assert p.exists() is True + if AIRFLOW_V_3_0_PLUS: + mock_supervisor_comms.get_message.return_value = XComResult( + key=XCOM_RETURN_KEY, value={"key": "bigvaluebigvaluebigvalue" * 100} + ) + value = XCom.get_value( key=XCOM_RETURN_KEY, ti_key=task_instance.key, - session=session, ) assert value == {"key": "bigvaluebigvaluebigvalue" * 100} - qry = XCom.get_many( - key=XCOM_RETURN_KEY, - dag_ids=task_instance.dag_id, - task_ids=task_instance.task_id, - run_id=task_instance.run_id, - session=session, - ) - assert str(p) == qry.first().value + if AIRFLOW_V_3_0_PLUS: + qry = XComModel.get_many( + key=XCOM_RETURN_KEY, + dag_ids=task_instance.dag_id, + task_ids=task_instance.task_id, + run_id=task_instance.run_id, + session=session, + ) + assert str(p) == XComModel.deserialize_value(qry.first()) + else: + qry = XCom.get_many( + key=XCOM_RETURN_KEY, + dag_ids=task_instance.dag_id, + task_ids=task_instance.task_id, + run_id=task_instance.run_id, + session=session, + ) + assert str(p) == qry.first().value - def test_clear(self, task_instance, session): + def test_clear(self, task_instance, session, mock_supervisor_comms): session.add(task_instance) session.commit() XCom = resolve_xcom_backend() @@ -170,52 +206,103 @@ def test_clear(self, task_instance, session): dag_id=task_instance.dag_id, task_id=task_instance.task_id, run_id=task_instance.run_id, - session=session, ) - res = ( - XCom.get_many( + if AIRFLOW_V_3_0_PLUS: + path = mock_supervisor_comms.send_request.call_args_list[-1].kwargs["msg"].value + XComModel.set( key=XCOM_RETURN_KEY, - dag_ids=task_instance.dag_id, - task_ids=task_instance.task_id, + value=path, + dag_id=task_instance.dag_id, + task_id=task_instance.task_id, run_id=task_instance.run_id, - session=session, ) - .with_entities(BaseXCom.value) - .first() - ) - data = BaseXCom.deserialize_value(res) + res = ( + XComModel.get_many( + key=XCOM_RETURN_KEY, + dag_ids=task_instance.dag_id, + task_ids=task_instance.task_id, + run_id=task_instance.run_id, + session=session, + ) + .with_entities(XComModel.value) + .first() + ) + data = XComModel.deserialize_value(res) + else: + res = ( + XCom.get_many( + key=XCOM_RETURN_KEY, + dag_ids=task_instance.dag_id, + task_ids=task_instance.task_id, + run_id=task_instance.run_id, + session=session, + ) + .with_entities(BaseXCom.value) + .first() + ) + data = BaseXCom.deserialize_value(res) p = XComObjectStorageBackend._get_full_path(data) assert p.exists() is True + if AIRFLOW_V_3_0_PLUS: + mock_supervisor_comms.get_message.return_value = XComResult( + key=XCOM_RETURN_KEY, value={"key": "superlargevalue" * 100} + ) value = XCom.get_value( key=XCOM_RETURN_KEY, ti_key=task_instance.key, - session=session, ) assert value - XCom.clear( - dag_id=task_instance.dag_id, - task_id=task_instance.task_id, - run_id=task_instance.run_id, - session=session, - ) + if AIRFLOW_V_3_0_PLUS: + mock_supervisor_comms.get_message.return_value = XComResult(key=XCOM_RETURN_KEY, value=path) + XCom.delete( + dag_id=task_instance.dag_id, + task_id=task_instance.task_id, + run_id=task_instance.run_id, + key=XCOM_RETURN_KEY, + map_index=task_instance.map_index, + ) + XComModel.clear( + dag_id=task_instance.dag_id, + task_id=task_instance.task_id, + run_id=task_instance.run_id, + map_index=task_instance.map_index, + ) + value = ( + XComModel.get_many( + key=XCOM_RETURN_KEY, + dag_ids=task_instance.dag_id, + task_ids=task_instance.task_id, + run_id=task_instance.run_id, + session=session, + ) + .with_entities(XComModel.value) + .first() + ) + else: + XCom.clear( + dag_id=task_instance.dag_id, + task_id=task_instance.task_id, + run_id=task_instance.run_id, + session=session, + ) + value = XCom.get_value( + key=XCOM_RETURN_KEY, + ti_key=task_instance.key, + session=session, + ) assert p.exists() is False - - value = XCom.get_value( - key=XCOM_RETURN_KEY, - ti_key=task_instance.key, - session=session, - ) assert not value @conf_vars({("common.io", "xcom_objectstorage_compression"): "gzip"}) - def test_compression(self, task_instance, session): + def test_compression(self, task_instance, session, mock_supervisor_comms): session.add(task_instance) session.commit() + XCom = resolve_xcom_backend() airflow.models.xcom.XCom = XCom @@ -225,30 +312,53 @@ def test_compression(self, task_instance, session): dag_id=task_instance.dag_id, task_id=task_instance.task_id, run_id=task_instance.run_id, - session=session, ) - res = ( - XCom.get_many( + if AIRFLOW_V_3_0_PLUS: + XComModel.set( key=XCOM_RETURN_KEY, - dag_ids=task_instance.dag_id, - task_ids=task_instance.task_id, + value=self.path + ".gz", + dag_id=task_instance.dag_id, + task_id=task_instance.task_id, run_id=task_instance.run_id, - session=session, ) - .with_entities(BaseXCom.value) - .first() - ) - data = BaseXCom.deserialize_value(res) - p = XComObjectStorageBackend._get_full_path(data) - assert p.exists() is True - assert p.suffix == ".gz" + res = ( + XComModel.get_many( + key=XCOM_RETURN_KEY, + dag_ids=task_instance.dag_id, + task_ids=task_instance.task_id, + run_id=task_instance.run_id, + session=session, + ) + .with_entities(XComModel.value) + .first() + ) + data = XComModel.deserialize_value(res) + else: + res = ( + XCom.get_many( + key=XCOM_RETURN_KEY, + dag_ids=task_instance.dag_id, + task_ids=task_instance.task_id, + run_id=task_instance.run_id, + session=session, + ) + .with_entities(BaseXCom.value) + .first() + ) + data = BaseXCom.deserialize_value(res) + + assert data.endswith(".gz") + + if AIRFLOW_V_3_0_PLUS: + mock_supervisor_comms.get_message.return_value = XComResult( + key=XCOM_RETURN_KEY, value={"key": "superlargevalue" * 100} + ) value = XCom.get_value( key=XCOM_RETURN_KEY, ti_key=task_instance.key, - session=session, ) assert value == {"key": "superlargevalue" * 100} diff --git a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py index 3f4ceefda1ce2..28c18a32cc55d 100644 --- a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py +++ b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py @@ -26,7 +26,7 @@ from airflow import DAG from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning -from airflow.models import Connection, DagRun, TaskInstance as TI, XCom +from airflow.models import Connection, DagRun, TaskInstance as TI from airflow.providers.common.sql.hooks.handlers import fetch_all_handler from airflow.providers.common.sql.operators.sql import ( BaseSQLOperator, @@ -50,7 +50,10 @@ from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS if AIRFLOW_V_3_0_PLUS: + from airflow.models.xcom import XComModel as XCom from airflow.utils.types import DagRunTriggeredByType +else: + from airflow.models.xcom import XCom # type: ignore[no-redef] pytestmark = pytest.mark.db_test diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py index 5888db6ad7cf9..2de7de67b3e94 100644 --- a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py @@ -29,7 +29,7 @@ from airflow.configuration import conf from airflow.exceptions import AirflowException -from airflow.models import BaseOperator, XCom +from airflow.models import BaseOperator from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunLifeCycleState, RunState from airflow.providers.databricks.operators.databricks_workflow import ( DatabricksWorkflowTaskGroup, @@ -50,7 +50,9 @@ if AIRFLOW_V_3_0_PLUS: from airflow.sdk import BaseOperatorLink + from airflow.sdk.execution_time.xcom import XCom else: + from airflow.models import XCom # type: ignore[no-redef] from airflow.models.baseoperatorlink import BaseOperatorLink # type: ignore[no-redef] DEFER_METHOD_NAME = "execute_complete" diff --git a/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py b/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py index 3dc8646b9b30b..85a884a53e595 100644 --- a/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py +++ b/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py @@ -31,7 +31,6 @@ from airflow.models.dag import DAG, clear_task_instances from airflow.models.dagrun import DagRun from airflow.models.taskinstance import TaskInstance, TaskInstanceKey -from airflow.models.xcom import XCom from airflow.plugins_manager import AirflowPlugin from airflow.providers.databricks.hooks.databricks import DatabricksHook from airflow.providers.databricks.version_compat import AIRFLOW_V_3_0_PLUS @@ -53,7 +52,9 @@ if AIRFLOW_V_3_0_PLUS: from airflow.sdk import BaseOperatorLink + from airflow.sdk.execution_time.xcom import XCom else: + from airflow.models import XCom # type: ignore[no-redef] from airflow.models.baseoperatorlink import BaseOperatorLink # type: ignore[no-redef] diff --git a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/operators/dbt.py b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/operators/dbt.py index 706973dbc55c5..68e06133fd627 100644 --- a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/operators/dbt.py +++ b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/operators/dbt.py @@ -24,7 +24,7 @@ from typing import TYPE_CHECKING, Any from airflow.configuration import conf -from airflow.models import BaseOperator, XCom +from airflow.models import BaseOperator from airflow.providers.common.compat.version_compat import AIRFLOW_V_3_0_PLUS from airflow.providers.dbt.cloud.hooks.dbt import ( DbtCloudHook, @@ -41,7 +41,9 @@ if AIRFLOW_V_3_0_PLUS: from airflow.sdk import BaseOperatorLink + from airflow.sdk.execution_time.xcom import XCom else: + from airflow.models import XCom # type: ignore[no-redef] from airflow.models.baseoperatorlink import BaseOperatorLink # type: ignore[no-redef] diff --git a/providers/dbt/cloud/tests/unit/dbt/cloud/operators/test_dbt.py b/providers/dbt/cloud/tests/unit/dbt/cloud/operators/test_dbt.py index c8d0b06c2863e..9e8408c56955c 100644 --- a/providers/dbt/cloud/tests/unit/dbt/cloud/operators/test_dbt.py +++ b/providers/dbt/cloud/tests/unit/dbt/cloud/operators/test_dbt.py @@ -31,8 +31,12 @@ DbtCloudRunJobOperator, ) from airflow.providers.dbt.cloud.triggers.dbt import DbtCloudRunJobTrigger +from airflow.providers.dbt.cloud.version_compat import AIRFLOW_V_3_0_PLUS from airflow.utils import db, timezone +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.execution_time.comms import XComResult + pytestmark = pytest.mark.db_test DEFAULT_DATE = timezone.datetime(2021, 1, 1) @@ -641,7 +645,9 @@ def test_custom_trigger_reason(self, mock_run_job, conn_id, account_id): [(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)], ids=["default_account", "explicit_account"], ) - def test_run_job_operator_link(self, conn_id, account_id, create_task_instance_of_operator, request): + def test_run_job_operator_link( + self, conn_id, account_id, create_task_instance_of_operator, request, mock_supervisor_comms + ): ti = create_task_instance_of_operator( DbtCloudRunJobOperator, dag_id="test_dbt_cloud_run_job_op_link", @@ -658,6 +664,15 @@ def test_run_job_operator_link(self, conn_id, account_id, create_task_instance_o ti.xcom_push(key="job_run_url", value=_run_response["data"]["href"]) + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key="job_run_url", + value=EXPECTED_JOB_RUN_OP_EXTRA_LINK.format( + account_id=account_id or DEFAULT_ACCOUNT_ID, + project_id=PROJECT_ID, + run_id=_run_response["data"]["id"], + ), + ) url = ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key) assert url == ( diff --git a/providers/fab/src/airflow/providers/fab/www/auth.py b/providers/fab/src/airflow/providers/fab/www/auth.py index 5ca881e1f8815..cd6998c1e35a5 100644 --- a/providers/fab/src/airflow/providers/fab/www/auth.py +++ b/providers/fab/src/airflow/providers/fab/www/auth.py @@ -53,7 +53,7 @@ ) from airflow.models import DagRun, Pool, TaskInstance, Variable from airflow.models.connection import Connection - from airflow.models.xcom import BaseXCom + from airflow.models.xcom import XComModel T = TypeVar("T", bound=Callable) @@ -249,7 +249,7 @@ def has_access_dag_entities(method: ResourceMethod, access_entity: DagAccessEnti def has_access_decorator(func: T): @wraps(func) def decorated(*args, **kwargs): - items: set[BaseXCom | DagRun | TaskInstance] = set(args[1]) + items: set[XComModel | DagRun | TaskInstance] = set(args[1]) requests: Sequence[IsAuthorizedDagRequest] = [ { "method": method, diff --git a/providers/google/src/airflow/providers/google/cloud/links/base.py b/providers/google/src/airflow/providers/google/cloud/links/base.py index 49c3e09b29e10..3bf9114cc9b9f 100644 --- a/providers/google/src/airflow/providers/google/cloud/links/base.py +++ b/providers/google/src/airflow/providers/google/cloud/links/base.py @@ -19,7 +19,6 @@ from typing import TYPE_CHECKING, ClassVar -from airflow.models import XCom from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS if TYPE_CHECKING: @@ -28,8 +27,10 @@ if AIRFLOW_V_3_0_PLUS: from airflow.sdk import BaseOperatorLink + from airflow.sdk.execution_time.xcom import XCom else: from airflow.models.baseoperatorlink import BaseOperatorLink # type: ignore[no-redef] + from airflow.models.xcom import XCom # type: ignore[no-redef] BASE_LINK = "https://console.cloud.google.com" diff --git a/providers/google/src/airflow/providers/google/cloud/links/datafusion.py b/providers/google/src/airflow/providers/google/cloud/links/datafusion.py index 117be86825c6b..c3e2351e2adc0 100644 --- a/providers/google/src/airflow/providers/google/cloud/links/datafusion.py +++ b/providers/google/src/airflow/providers/google/cloud/links/datafusion.py @@ -21,12 +21,13 @@ from typing import TYPE_CHECKING, ClassVar -from airflow.models import XCom from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS if AIRFLOW_V_3_0_PLUS: from airflow.sdk import BaseOperatorLink + from airflow.sdk.execution_time.xcom import XCom else: + from airflow.models import XCom # type: ignore[no-redef] from airflow.models.baseoperatorlink import BaseOperatorLink # type: ignore[no-redef] if TYPE_CHECKING: diff --git a/providers/google/src/airflow/providers/google/cloud/links/dataproc.py b/providers/google/src/airflow/providers/google/cloud/links/dataproc.py index 657ee717361a0..72dfef308e628 100644 --- a/providers/google/src/airflow/providers/google/cloud/links/dataproc.py +++ b/providers/google/src/airflow/providers/google/cloud/links/dataproc.py @@ -25,7 +25,6 @@ import attr from airflow.exceptions import AirflowProviderDeprecationWarning -from airflow.models import XCom from airflow.providers.google.cloud.links.base import BASE_LINK, BaseGoogleLink from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS @@ -36,7 +35,9 @@ if AIRFLOW_V_3_0_PLUS: from airflow.sdk import BaseOperatorLink + from airflow.sdk.execution_time.xcom import XCom else: + from airflow.models import XCom # type: ignore[no-redef] from airflow.models.baseoperatorlink import BaseOperatorLink # type: ignore[no-redef] diff --git a/providers/google/src/airflow/providers/google/cloud/operators/dataproc_metastore.py b/providers/google/src/airflow/providers/google/cloud/operators/dataproc_metastore.py index 64b3128560e41..72d4e7fc160a7 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/dataproc_metastore.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/dataproc_metastore.py @@ -32,7 +32,6 @@ from airflow.exceptions import AirflowException from airflow.models import BaseOperator, BaseOperatorLink -from airflow.models.xcom import XCom from airflow.providers.google.cloud.hooks.dataproc_metastore import DataprocMetastoreHook from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator from airflow.providers.google.common.links.storage import StorageLink @@ -48,7 +47,9 @@ if AIRFLOW_V_3_0_PLUS: from airflow.sdk import BaseOperatorLink + from airflow.sdk.execution_time.xcom import XCom else: + from airflow.models import XCom # type: ignore[no-redef] from airflow.models.baseoperatorlink import BaseOperatorLink # type: ignore[no-redef] diff --git a/providers/google/src/airflow/providers/google/marketing_platform/links/analytics_admin.py b/providers/google/src/airflow/providers/google/marketing_platform/links/analytics_admin.py index cb7432dba90b9..9c055889bd751 100644 --- a/providers/google/src/airflow/providers/google/marketing_platform/links/analytics_admin.py +++ b/providers/google/src/airflow/providers/google/marketing_platform/links/analytics_admin.py @@ -18,9 +18,8 @@ from typing import TYPE_CHECKING, ClassVar -from airflow.models import BaseOperator, XCom - if TYPE_CHECKING: + from airflow.models import BaseOperator from airflow.models.taskinstancekey import TaskInstanceKey from airflow.utils.context import Context @@ -28,7 +27,9 @@ if AIRFLOW_V_3_0_PLUS: from airflow.sdk import BaseOperatorLink + from airflow.sdk.execution_time.xcom import XCom else: + from airflow.models import XCom # type: ignore[no-redef] from airflow.models.baseoperatorlink import BaseOperatorLink # type: ignore[no-redef] diff --git a/providers/google/tests/unit/google/cloud/links/test_dataplex.py b/providers/google/tests/unit/google/cloud/links/test_dataplex.py index d0a08c75c1e3d..3ec2905751390 100644 --- a/providers/google/tests/unit/google/cloud/links/test_dataplex.py +++ b/providers/google/tests/unit/google/cloud/links/test_dataplex.py @@ -43,6 +43,10 @@ DataplexCreateTaskOperator, DataplexListTasksOperator, ) +from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS + +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.execution_time.comms import XComResult TEST_LOCATION = "test-location" TEST_PROJECT_ID = "test-project-id" @@ -101,7 +105,7 @@ class TestDataplexTaskLink: @pytest.mark.db_test - def test_get_link(self, create_task_instance_of_operator, session): + def test_get_link(self, create_task_instance_of_operator, session, mock_supervisor_comms): expected_url = EXPECTED_DATAPLEX_TASK_LINK link = DataplexTaskLink() ti = create_task_instance_of_operator( @@ -117,13 +121,24 @@ def test_get_link(self, create_task_instance_of_operator, session): session.add(ti) session.commit() link.persist(context={"ti": ti}, task_instance=ti.task) + + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key="key", + value={ + "lake_id": ti.task.lake_id, + "task_id": ti.task.dataplex_task_id, + "region": ti.task.region, + "project_id": ti.task.project_id, + }, + ) actual_url = link.get_link(operator=ti.task, ti_key=ti.key) assert actual_url == expected_url class TestDataplexTasksLink: @pytest.mark.db_test - def test_get_link(self, create_task_instance_of_operator, session): + def test_get_link(self, create_task_instance_of_operator, session, mock_supervisor_comms): expected_url = EXPECTED_DATAPLEX_TASKS_LINK link = DataplexTasksLink() ti = create_task_instance_of_operator( @@ -137,13 +152,22 @@ def test_get_link(self, create_task_instance_of_operator, session): session.add(ti) session.commit() link.persist(context={"ti": ti}, task_instance=ti.task) + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key="key", + value={ + "project_id": ti.task.project_id, + "lake_id": ti.task.lake_id, + "region": ti.task.region, + }, + ) actual_url = link.get_link(operator=ti.task, ti_key=ti.key) assert actual_url == expected_url class TestDataplexLakeLink: @pytest.mark.db_test - def test_get_link(self, create_task_instance_of_operator, session): + def test_get_link(self, create_task_instance_of_operator, session, mock_supervisor_comms): expected_url = DATAPLEX_LAKE_LINK link = DataplexLakeLink() ti = create_task_instance_of_operator( @@ -158,13 +182,22 @@ def test_get_link(self, create_task_instance_of_operator, session): session.add(ti) session.commit() link.persist(context={"ti": ti}, task_instance=ti.task) + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key="key", + value={ + "lake_id": ti.task.lake_id, + "region": ti.task.region, + "project_id": ti.task.project_id, + }, + ) actual_url = link.get_link(operator=ti.task, ti_key=ti.key) assert actual_url == expected_url class TestDataplexCatalogEntryGroupLink: @pytest.mark.db_test - def test_get_link(self, create_task_instance_of_operator, session): + def test_get_link(self, create_task_instance_of_operator, session, mock_supervisor_comms): expected_url = EXPECTED_DATAPLEX_CATALOG_ENTRY_GROUP_LINK link = DataplexCatalogEntryGroupLink() ti = create_task_instance_of_operator( @@ -178,13 +211,22 @@ def test_get_link(self, create_task_instance_of_operator, session): session.add(ti) session.commit() link.persist(context={"ti": ti}, task_instance=ti.task) + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key="key", + value={ + "entry_group_id": ti.task.entry_group_id, + "location": ti.task.location, + "project_id": ti.task.project_id, + }, + ) actual_url = link.get_link(operator=ti.task, ti_key=ti.key) assert actual_url == expected_url class TestDataplexCatalogEntryGroupsLink: @pytest.mark.db_test - def test_get_link(self, create_task_instance_of_operator, session): + def test_get_link(self, create_task_instance_of_operator, session, mock_supervisor_comms): expected_url = EXPECTED_DATAPLEX_CATALOG_ENTRY_GROUPS_LINK link = DataplexCatalogEntryGroupsLink() ti = create_task_instance_of_operator( @@ -199,13 +241,21 @@ def test_get_link(self, create_task_instance_of_operator, session): session.add(ti) session.commit() link.persist(context={"ti": ti}, task_instance=ti.task) + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key="key", + value={ + "location": ti.task.location, + "project_id": ti.task.project_id, + }, + ) actual_url = link.get_link(operator=ti.task, ti_key=ti.key) assert actual_url == expected_url class TestDataplexCatalogEntryTypeLink: @pytest.mark.db_test - def test_get_link(self, create_task_instance_of_operator, session): + def test_get_link(self, create_task_instance_of_operator, session, mock_supervisor_comms): expected_url = EXPECTED_DATAPLEX_CATALOG_ENTRY_TYPE_LINK link = DataplexCatalogEntryTypeLink() ti = create_task_instance_of_operator( @@ -219,13 +269,22 @@ def test_get_link(self, create_task_instance_of_operator, session): session.add(ti) session.commit() link.persist(context={"ti": ti}, task_instance=ti.task) + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key="key", + value={ + "entry_type_id": ti.task.entry_type_id, + "location": ti.task.location, + "project_id": ti.task.project_id, + }, + ) actual_url = link.get_link(operator=ti.task, ti_key=ti.key) assert actual_url == expected_url class TestDataplexCatalogEntryTypesLink: @pytest.mark.db_test - def test_get_link(self, create_task_instance_of_operator, session): + def test_get_link(self, create_task_instance_of_operator, session, mock_supervisor_comms): expected_url = EXPECTED_DATAPLEX_CATALOG_ENTRY_TYPES_LINK link = DataplexCatalogEntryTypesLink() ti = create_task_instance_of_operator( @@ -240,13 +299,21 @@ def test_get_link(self, create_task_instance_of_operator, session): session.add(ti) session.commit() link.persist(context={"ti": ti}, task_instance=ti.task) + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key="key", + value={ + "location": ti.task.location, + "project_id": ti.task.project_id, + }, + ) actual_url = link.get_link(operator=ti.task, ti_key=ti.key) assert actual_url == expected_url class TestDataplexCatalogAspectTypeLink: @pytest.mark.db_test - def test_get_link(self, create_task_instance_of_operator, session): + def test_get_link(self, create_task_instance_of_operator, session, mock_supervisor_comms): expected_url = EXPECTED_DATAPLEX_CATALOG_ASPECT_TYPE_LINK link = DataplexCatalogAspectTypeLink() ti = create_task_instance_of_operator( @@ -260,13 +327,22 @@ def test_get_link(self, create_task_instance_of_operator, session): session.add(ti) session.commit() link.persist(context={"ti": ti}, task_instance=ti.task) + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key="key", + value={ + "aspect_type_id": ti.task.aspect_type_id, + "location": ti.task.location, + "project_id": ti.task.project_id, + }, + ) actual_url = link.get_link(operator=ti.task, ti_key=ti.key) assert actual_url == expected_url class TestDataplexCatalogAspectTypesLink: @pytest.mark.db_test - def test_get_link(self, create_task_instance_of_operator, session): + def test_get_link(self, create_task_instance_of_operator, session, mock_supervisor_comms): expected_url = EXPECTED_DATAPLEX_CATALOG_ASPECT_TYPES_LINK link = DataplexCatalogAspectTypesLink() ti = create_task_instance_of_operator( @@ -281,13 +357,21 @@ def test_get_link(self, create_task_instance_of_operator, session): session.add(ti) session.commit() link.persist(context={"ti": ti}, task_instance=ti.task) + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key="key", + value={ + "location": ti.task.location, + "project_id": ti.task.project_id, + }, + ) actual_url = link.get_link(operator=ti.task, ti_key=ti.key) assert actual_url == expected_url class TestDataplexCatalogEntryLink: @pytest.mark.db_test - def test_get_link(self, create_task_instance_of_operator, session): + def test_get_link(self, create_task_instance_of_operator, session, mock_supervisor_comms): expected_url = EXPECTED_DATAPLEX_CATALOG_ENTRY_LINK link = DataplexCatalogEntryLink() ti = create_task_instance_of_operator( @@ -302,5 +386,15 @@ def test_get_link(self, create_task_instance_of_operator, session): session.add(ti) session.commit() link.persist(context={"ti": ti}, task_instance=ti.task) + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key="key", + value={ + "entry_id": ti.task.entry_id, + "entry_group_id": ti.task.entry_group_id, + "location": ti.task.location, + "project_id": ti.task.project_id, + }, + ) actual_url = link.get_link(operator=ti.task, ti_key=ti.key) assert actual_url == expected_url diff --git a/providers/google/tests/unit/google/cloud/links/test_translate.py b/providers/google/tests/unit/google/cloud/links/test_translate.py index 1d3822ad32d3e..ddfc3e205e8e2 100644 --- a/providers/google/tests/unit/google/cloud/links/test_translate.py +++ b/providers/google/tests/unit/google/cloud/links/test_translate.py @@ -19,6 +19,8 @@ import pytest +from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS + # For no Pydantic environment, we need to skip the tests pytest.importorskip("google.cloud.aiplatform_v1") @@ -35,6 +37,9 @@ AutoMLTrainModelOperator, ) +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.execution_time.comms import XComResult + GCP_LOCATION = "test-location" GCP_PROJECT_ID = "test-project" DATASET = "test-dataset" @@ -43,7 +48,7 @@ class TestTranslationLegacyDatasetLink: @pytest.mark.db_test - def test_get_link(self, create_task_instance_of_operator, session): + def test_get_link(self, create_task_instance_of_operator, session, mock_supervisor_comms): expected_url = f"{TRANSLATION_BASE_LINK}/locations/{GCP_LOCATION}/datasets/{DATASET}/sentences?project={GCP_PROJECT_ID}" link = TranslationLegacyDatasetLink() ti = create_task_instance_of_operator( @@ -56,13 +61,18 @@ def test_get_link(self, create_task_instance_of_operator, session): session.add(ti) session.commit() link.persist(context={"ti": ti}, task_instance=ti.task, dataset_id=DATASET, project_id=GCP_PROJECT_ID) + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key="key", + value={"location": ti.task.location, "dataset_id": DATASET, "project_id": GCP_PROJECT_ID}, + ) actual_url = link.get_link(operator=ti.task, ti_key=ti.key) assert actual_url == expected_url class TestTranslationDatasetListLink: @pytest.mark.db_test - def test_get_link(self, create_task_instance_of_operator, session): + def test_get_link(self, create_task_instance_of_operator, session, mock_supervisor_comms): expected_url = f"{TRANSLATION_BASE_LINK}/datasets?project={GCP_PROJECT_ID}" link = TranslationDatasetListLink() ti = create_task_instance_of_operator( @@ -74,13 +84,20 @@ def test_get_link(self, create_task_instance_of_operator, session): session.add(ti) session.commit() link.persist(context={"ti": ti}, task_instance=ti.task, project_id=GCP_PROJECT_ID) + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key="key", + value={ + "project_id": GCP_PROJECT_ID, + }, + ) actual_url = link.get_link(operator=ti.task, ti_key=ti.key) assert actual_url == expected_url class TestTranslationLegacyModelLink: @pytest.mark.db_test - def test_get_link(self, create_task_instance_of_operator, session): + def test_get_link(self, create_task_instance_of_operator, session, mock_supervisor_comms): expected_url = ( f"{TRANSLATION_BASE_LINK}/locations/{GCP_LOCATION}/datasets/{DATASET}/" f"evaluate;modelId={MODEL}?project={GCP_PROJECT_ID}" @@ -103,13 +120,23 @@ def test_get_link(self, create_task_instance_of_operator, session): model_id=MODEL, project_id=GCP_PROJECT_ID, ) + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key="key", + value={ + "location": ti.task.location, + "dataset_id": DATASET, + "model_id": MODEL, + "project_id": GCP_PROJECT_ID, + }, + ) actual_url = link.get_link(operator=ti.task, ti_key=ti.key) assert actual_url == expected_url class TestTranslationLegacyModelTrainLink: @pytest.mark.db_test - def test_get_link(self, create_task_instance_of_operator, session): + def test_get_link(self, create_task_instance_of_operator, session, mock_supervisor_comms): expected_url = ( f"{TRANSLATION_BASE_LINK}/locations/{GCP_LOCATION}/datasets/{DATASET}/" f"train?project={GCP_PROJECT_ID}" @@ -130,5 +157,14 @@ def test_get_link(self, create_task_instance_of_operator, session): task_instance=ti.task, project_id=GCP_PROJECT_ID, ) + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key="key", + value={ + "location": ti.task.location, + "dataset_id": ti.task.model["dataset_id"], + "project_id": GCP_PROJECT_ID, + }, + ) actual_url = link.get_link(operator=ti.task, ti_key=ti.key) assert actual_url == expected_url diff --git a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py index 7eacfcf4ca415..73f3e6ce848d2 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py +++ b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py @@ -76,6 +76,9 @@ from tests_common.test_utils.db import clear_db_runs, clear_db_xcom from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.execution_time.comms import XComResult + AIRFLOW_VERSION_LABEL = "v" + str(AIRFLOW_VERSION).replace(".", "-").replace("+", "-") cluster_params = inspect.signature(ClusterGenerator.__init__).parameters @@ -1110,7 +1113,9 @@ def test_create_execute_call_finished_before_defer(self, mock_trigger_hook, mock @pytest.mark.db_test @pytest.mark.need_serialized_dag -def test_create_cluster_operator_extra_links(dag_maker, create_task_instance_of_operator): +def test_create_cluster_operator_extra_links( + dag_maker, create_task_instance_of_operator, mock_supervisor_comms +): ti = create_task_instance_of_operator( DataprocCreateClusterOperator, dag_id=TEST_DAG_ID, @@ -1128,11 +1133,21 @@ def test_create_cluster_operator_extra_links(dag_maker, create_task_instance_of_ operator_extra_link = deserialized_dag.tasks[0].operator_extra_links[0] assert operator_extra_link.name == "Dataproc Cluster" + if AIRFLOW_V_3_0_PLUS: + mock_supervisor_comms.get_message.return_value = XComResult( + key="key", + value="", + ) # Assert operator link is empty when no XCom push occurred assert ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key) == "" ti.xcom_push(key="dataproc_cluster", value=DATAPROC_CLUSTER_EXPECTED) + if AIRFLOW_V_3_0_PLUS: + mock_supervisor_comms.get_message.return_value = XComResult( + key="key", + value={"cluster_id": "cluster_name", "project_id": "test-project", "region": "test-location"}, + ) # Assert operator links after execution assert ( ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key) @@ -1211,7 +1226,9 @@ def test_execute(self, mock_hook): @pytest.mark.db_test @pytest.mark.need_serialized_dag -def test_scale_cluster_operator_extra_links(dag_maker, create_task_instance_of_operator): +def test_scale_cluster_operator_extra_links( + dag_maker, create_task_instance_of_operator, mock_supervisor_comms +): ti = create_task_instance_of_operator( DataprocScaleClusterOperator, dag_id=TEST_DAG_ID, @@ -1232,6 +1249,11 @@ def test_scale_cluster_operator_extra_links(dag_maker, create_task_instance_of_o operator_extra_link = deserialized_dag.tasks[0].operator_extra_links[0] assert operator_extra_link.name == "Dataproc resource" + if AIRFLOW_V_3_0_PLUS: + mock_supervisor_comms.get_message.return_value = XComResult( + key="key", + value="", + ) # Assert operator link is empty when no XCom push occurred assert ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key) == "" @@ -1240,6 +1262,12 @@ def test_scale_cluster_operator_extra_links(dag_maker, create_task_instance_of_o value=DATAPROC_CLUSTER_CONF_EXPECTED, ) + if AIRFLOW_V_3_0_PLUS: + mock_supervisor_comms.get_message.return_value = XComResult( + key="key", + value=DATAPROC_CLUSTER_CONF_EXPECTED, + ) + # Assert operator links after execution assert ( ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key) @@ -2076,7 +2104,9 @@ def test_missing_region_parameter(self): @pytest.mark.db_test @pytest.mark.need_serialized_dag @mock.patch(DATAPROC_PATH.format("DataprocHook")) -def test_submit_job_operator_extra_links(mock_hook, dag_maker, create_task_instance_of_operator): +def test_submit_job_operator_extra_links( + mock_hook, dag_maker, create_task_instance_of_operator, mock_supervisor_comms +): mock_hook.return_value.project_id = GCP_PROJECT ti = create_task_instance_of_operator( DataprocSubmitJobOperator, @@ -2095,11 +2125,23 @@ def test_submit_job_operator_extra_links(mock_hook, dag_maker, create_task_insta operator_extra_link = deserialized_dag.tasks[0].operator_extra_links[0] assert operator_extra_link.name == "Dataproc Job" + if AIRFLOW_V_3_0_PLUS: + mock_supervisor_comms.get_message.return_value = XComResult( + key="dataproc_job", + value="", + ) + # Assert operator link is empty when no XCom push occurred assert ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key) == "" ti.xcom_push(key="dataproc_job", value=DATAPROC_JOB_EXPECTED) + if AIRFLOW_V_3_0_PLUS: + mock_supervisor_comms.get_message.return_value = XComResult( + key="dataproc_job", + value=DATAPROC_JOB_EXPECTED, + ) + # Assert operator links after execution assert ( ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key) @@ -2276,7 +2318,9 @@ def test_create_execute_call_finished_before_defer(self, mock_trigger_hook, mock @pytest.mark.db_test @pytest.mark.need_serialized_dag -def test_update_cluster_operator_extra_links(dag_maker, create_task_instance_of_operator): +def test_update_cluster_operator_extra_links( + dag_maker, create_task_instance_of_operator, mock_supervisor_comms +): ti = create_task_instance_of_operator( DataprocUpdateClusterOperator, dag_id=TEST_DAG_ID, @@ -2297,11 +2341,22 @@ def test_update_cluster_operator_extra_links(dag_maker, create_task_instance_of_ operator_extra_link = deserialized_dag.tasks[0].operator_extra_links[0] assert operator_extra_link.name == "Dataproc Cluster" + if AIRFLOW_V_3_0_PLUS: + mock_supervisor_comms.get_message.return_value = XComResult( + key="dataproc_cluster", + value="", + ) # Assert operator link is empty when no XCom push occurred assert ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key) == "" ti.xcom_push(key="dataproc_cluster", value=DATAPROC_CLUSTER_EXPECTED) + if AIRFLOW_V_3_0_PLUS: + mock_supervisor_comms.get_message.return_value = XComResult( + key="dataproc_cluster", + value=DATAPROC_CLUSTER_EXPECTED, + ) + # Assert operator links after execution assert ( ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key) @@ -2492,7 +2547,9 @@ def test_on_kill(self, mock_hook): @pytest.mark.db_test @pytest.mark.need_serialized_dag @mock.patch(DATAPROC_PATH.format("DataprocHook")) -def test_instantiate_workflow_operator_extra_links(mock_hook, dag_maker, create_task_instance_of_operator): +def test_instantiate_workflow_operator_extra_links( + mock_hook, dag_maker, create_task_instance_of_operator, mock_supervisor_comms +): mock_hook.return_value.project_id = GCP_PROJECT ti = create_task_instance_of_operator( DataprocInstantiateWorkflowTemplateOperator, @@ -2510,11 +2567,20 @@ def test_instantiate_workflow_operator_extra_links(mock_hook, dag_maker, create_ operator_extra_link = deserialized_dag.tasks[0].operator_extra_links[0] assert operator_extra_link.name == "Dataproc Workflow" + if AIRFLOW_V_3_0_PLUS: + mock_supervisor_comms.get_message.return_value = XComResult( + key="dataproc_workflow", + value="", + ) # Assert operator link is empty when no XCom push occurred assert ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key) == "" ti.xcom_push(key="dataproc_workflow", value=DATAPROC_WORKFLOW_EXPECTED) - + if AIRFLOW_V_3_0_PLUS: + mock_supervisor_comms.get_message.return_value = XComResult( + key="dataproc_workflow", + value=DATAPROC_WORKFLOW_EXPECTED, + ) # Assert operator links after execution assert ( ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key) @@ -3155,7 +3221,7 @@ def test_execute_openlineage_transport_info_injection_skipped_when_ol_not_access @pytest.mark.need_serialized_dag @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_instantiate_inline_workflow_operator_extra_links( - mock_hook, dag_maker, create_task_instance_of_operator + mock_hook, dag_maker, create_task_instance_of_operator, mock_supervisor_comms ): mock_hook.return_value.project_id = GCP_PROJECT ti = create_task_instance_of_operator( @@ -3173,11 +3239,19 @@ def test_instantiate_inline_workflow_operator_extra_links( deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag["dag"]) operator_extra_link = deserialized_dag.tasks[0].operator_extra_links[0] assert operator_extra_link.name == "Dataproc Workflow" - + if AIRFLOW_V_3_0_PLUS: + mock_supervisor_comms.get_message.return_value = XComResult( + key="dataproc_workflow", + value="", + ) # Assert operator link is empty when no XCom push occurred assert ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key) == "" ti.xcom_push(key="dataproc_workflow", value=DATAPROC_WORKFLOW_EXPECTED) + if AIRFLOW_V_3_0_PLUS: + mock_supervisor_comms.get_message.return_value = XComResult( + key="dataproc_workflow", value=DATAPROC_WORKFLOW_EXPECTED + ) # Assert operator links after execution assert ( diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/data_factory.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/data_factory.py index 9f3fb298e8f13..71ca318b63652 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/data_factory.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/data_factory.py @@ -25,7 +25,7 @@ from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook -from airflow.models import BaseOperator, XCom +from airflow.models import BaseOperator from airflow.providers.microsoft.azure.hooks.data_factory import ( AzureDataFactoryHook, AzureDataFactoryPipelineRunException, @@ -43,7 +43,9 @@ if AIRFLOW_V_3_0_PLUS: from airflow.sdk import BaseOperatorLink + from airflow.sdk.execution_time.xcom import XCom else: + from airflow.models import XCom # type: ignore[no-redef] from airflow.models.baseoperatorlink import BaseOperatorLink # type: ignore[no-redef] diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/synapse.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/synapse.py index d8052f66d4fc3..1a096baece470 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/synapse.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/synapse.py @@ -23,7 +23,7 @@ from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook -from airflow.models import BaseOperator, XCom +from airflow.models import BaseOperator from airflow.providers.microsoft.azure.hooks.synapse import ( AzureSynapseHook, AzureSynapsePipelineHook, @@ -42,7 +42,9 @@ if AIRFLOW_V_3_0_PLUS: from airflow.sdk import BaseOperatorLink + from airflow.sdk.execution_time.xcom import XCom else: + from airflow.models import XCom # type: ignore[no-redef] from airflow.models.baseoperatorlink import BaseOperatorLink # type: ignore[no-redef] diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_data_factory.py b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_data_factory.py index f1002f80ffcc8..f59355565dc62 100644 --- a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_data_factory.py +++ b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_data_factory.py @@ -43,6 +43,9 @@ if TYPE_CHECKING: from airflow.models.baseoperator import BaseOperator +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.execution_time.comms import XComResult + DEFAULT_DATE = timezone.datetime(2021, 1, 1) SUBSCRIPTION_ID = "my-subscription-id" TASK_ID = "run_pipeline_op" @@ -235,7 +238,9 @@ def test_execute_no_wait_for_termination(self, mock_run_pipeline): (None, None), ], ) - def test_run_pipeline_operator_link(self, resource_group, factory, create_task_instance_of_operator): + def test_run_pipeline_operator_link( + self, resource_group, factory, create_task_instance_of_operator, mock_supervisor_comms + ): ti = create_task_instance_of_operator( AzureDataFactoryRunPipelineOperator, dag_id="test_adf_run_pipeline_op_link", @@ -246,6 +251,13 @@ def test_run_pipeline_operator_link(self, resource_group, factory, create_task_i factory_name=factory, ) ti.xcom_push(key="run_id", value=PIPELINE_RUN_RESPONSE["run_id"]) + + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key="run_id", + value=PIPELINE_RUN_RESPONSE["run_id"], + ) + url = ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key) EXPECTED_PIPELINE_RUN_OP_EXTRA_LINK = ( "https://adf.azure.com/en-us/monitoring/pipelineruns/{run_id}" diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_synapse.py b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_synapse.py index 9c327e98b0fad..45c538c2c400c 100644 --- a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_synapse.py +++ b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_synapse.py @@ -33,8 +33,12 @@ AzureSynapseRunPipelineOperator, AzureSynapseRunSparkBatchOperator, ) +from airflow.providers.microsoft.azure.version_compat import AIRFLOW_V_3_0_PLUS from airflow.utils import timezone +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.execution_time.comms import XComResult + DEFAULT_DATE = timezone.datetime(2021, 1, 1) SUBSCRIPTION_ID = "subscription_id" TENANT_ID = "tenant_id" @@ -276,7 +280,7 @@ def test_execute_no_wait_for_termination(self, mock_run_pipeline): mock_get_pipeline_run.assert_not_called() @pytest.mark.db_test - def test_run_pipeline_operator_link(self, create_task_instance_of_operator): + def test_run_pipeline_operator_link(self, create_task_instance_of_operator, mock_supervisor_comms): ti = create_task_instance_of_operator( AzureSynapseRunPipelineOperator, dag_id="test_synapse_run_pipeline_op_link", @@ -287,6 +291,11 @@ def test_run_pipeline_operator_link(self, create_task_instance_of_operator): ) ti.xcom_push(key="run_id", value=PIPELINE_RUN_RESPONSE["run_id"]) + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = XComResult( + key="run_id", + value=PIPELINE_RUN_RESPONSE["run_id"], + ) url = ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key) diff --git a/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py b/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py index 8084f975c28e8..e00468d1bc42b 100644 --- a/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py +++ b/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py @@ -38,7 +38,6 @@ from airflow.models.dag import DagModel from airflow.models.dagbag import DagBag from airflow.models.dagrun import DagRun -from airflow.models.xcom import XCom from airflow.providers.standard.triggers.external_task import DagStateTrigger from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS from airflow.utils import timezone @@ -63,7 +62,9 @@ if AIRFLOW_V_3_0_PLUS: from airflow.sdk import BaseOperatorLink + from airflow.sdk.execution_time.xcom import XCom else: + from airflow.models import XCom # type: ignore[no-redef] from airflow.models.baseoperatorlink import BaseOperatorLink # type: ignore[no-redef] diff --git a/providers/standard/tests/unit/standard/operators/test_weekday.py b/providers/standard/tests/unit/standard/operators/test_weekday.py index a2e19cffd304e..2fbbe8271d530 100644 --- a/providers/standard/tests/unit/standard/operators/test_weekday.py +++ b/providers/standard/tests/unit/standard/operators/test_weekday.py @@ -25,17 +25,20 @@ from airflow.exceptions import AirflowException from airflow.models.dagrun import DagRun from airflow.models.taskinstance import TaskInstance as TI -from airflow.models.xcom import XCom from airflow.providers.standard.operators.empty import EmptyOperator from airflow.providers.standard.operators.weekday import BranchDayOfWeekOperator from airflow.providers.standard.utils.skipmixin import XCOM_SKIPMIXIN_FOLLOWED, XCOM_SKIPMIXIN_KEY +from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS from airflow.timetables.base import DataInterval from airflow.utils import timezone from airflow.utils.session import create_session from airflow.utils.state import State from airflow.utils.weekday import WeekDay -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS +if AIRFLOW_V_3_0_PLUS: + from airflow.models.xcom import XComModel as XCom +else: + from airflow.models.xcom import XCom # type: ignore[no-redef] pytestmark = pytest.mark.db_test diff --git a/providers/yandex/src/airflow/providers/yandex/links/yq.py b/providers/yandex/src/airflow/providers/yandex/links/yq.py index 72305fd6acaa9..49a42473ca446 100644 --- a/providers/yandex/src/airflow/providers/yandex/links/yq.py +++ b/providers/yandex/src/airflow/providers/yandex/links/yq.py @@ -18,8 +18,6 @@ from typing import TYPE_CHECKING -from airflow.models import XCom - if TYPE_CHECKING: from airflow.models import BaseOperator from airflow.models.taskinstancekey import TaskInstanceKey @@ -34,7 +32,9 @@ if AIRFLOW_V_3_0_PLUS: from airflow.sdk import BaseOperatorLink + from airflow.sdk.execution_time.xcom import XCom else: + from airflow.models import XCom # type: ignore[no-redef] from airflow.models.baseoperatorlink import BaseOperatorLink # type: ignore[no-redef] XCOM_WEBLINK_KEY = "web_link" diff --git a/providers/yandex/tests/unit/yandex/links/test_yq.py b/providers/yandex/tests/unit/yandex/links/test_yq.py index 444ae762ac264..9bfa0fbede6ff 100644 --- a/providers/yandex/tests/unit/yandex/links/test_yq.py +++ b/providers/yandex/tests/unit/yandex/links/test_yq.py @@ -21,12 +21,16 @@ import pytest from airflow.models.taskinstance import TaskInstance -from airflow.models.xcom import XCom from airflow.providers.yandex.links.yq import YQLink from tests_common.test_utils.mock_operators import MockOperator from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.execution_time.xcom import XCom +else: + from airflow.models import XCom # type: ignore[no-redef] + yandexcloud = pytest.importorskip("yandexcloud") diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index a28ea40a79182..d5c02b1f7b6ea 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -333,6 +333,20 @@ def set( # decouple from the server response string return {"ok": True} + def delete( + self, + dag_id: str, + run_id: str, + task_id: str, + key: str, + ) -> dict[str, bool]: + """Delete a XCom with given key via the API server.""" + self.client.delete(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}") + # Any error from the server will anyway be propagated down to the supervisor, + # so we choose to send a generic response to the supervisor over the server response to + # decouple from the server response string + return {"ok": True} + class AssetOperations: __slots__ = ("client",) diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index 19dda0a681cb4..622f3d0a74e72 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -377,6 +377,7 @@ class TIRunContext(BaseModel): upstream_map_indexes: Annotated[dict[str, int] | None, Field(title="Upstream Map Indexes")] = None next_method: Annotated[str | None, Field(title="Next Method")] = None next_kwargs: Annotated[dict[str, Any] | str | None, Field(title="Next Kwargs")] = None + xcom_keys_to_clear: Annotated[list[str] | None, Field(title="Xcom Keys To Clear")] = None class TITerminalStatePayload(BaseModel): diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index 79fe87c9caa62..ffe73092adfd3 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -335,6 +335,14 @@ class SetXCom(BaseModel): type: Literal["SetXCom"] = "SetXCom" +class DeleteXCom(BaseModel): + key: str + dag_id: str + run_id: str + task_id: str + type: Literal["DeleteXCom"] = "DeleteXCom" + + class GetConnection(BaseModel): conn_id: str type: Literal["GetConnection"] = "GetConnection" @@ -408,6 +416,7 @@ class GetPrevSuccessfulDagRun(BaseModel): SetXCom, TaskState, RuntimeCheckOnTask, + DeleteXCom, ], Field(discriminator="type"), ] diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 6ab1d2e4e1809..b7706e7735e43 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -65,6 +65,7 @@ AssetResult, ConnectionResult, DeferTask, + DeleteXCom, GetAssetByName, GetAssetByUri, GetAssetEventByAsset, @@ -900,6 +901,8 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): self.client.xcoms.set( msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.value, msg.map_index, msg.mapped_length ) + elif isinstance(msg, DeleteXCom): + self.client.xcoms.delete(msg.dag_id, msg.run_id, msg.task_id, msg.key) elif isinstance(msg, PutVariable): self.client.variables.set(msg.key, msg.value, msg.description) elif isinstance(msg, SetRenderedFields): diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 86bd8c953f62f..15cc8de3a2d47 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -52,19 +52,16 @@ from airflow.sdk.definitions.param import process_params from airflow.sdk.execution_time.comms import ( DeferTask, - GetXCom, OKResponse, RescheduleTask, RuntimeCheckOnTask, SetRenderedFields, - SetXCom, SkipDownstreamTasks, StartupDetails, SucceedTask, TaskState, ToSupervisor, ToTask, - XComResult, ) from airflow.sdk.execution_time.context import ( ConnectionAccessor, @@ -75,6 +72,7 @@ get_previous_dagrun_success, set_current_context, ) +from airflow.sdk.execution_time.xcom import XCom from airflow.utils.net import get_hostname from airflow.utils.state import TaskInstanceState @@ -256,7 +254,9 @@ def xcom_pull( run_id: str | None = None, ) -> Any: """ - Pull XComs that optionally meet certain criteria. + Pull XComs either from the API server (BaseXCom) or from the custom XCOM backend if configured. + + The pull can be filtered optionally by certain criterion. :param key: A key for the XCom. If provided, only XComs with matching keys will be returned. The default key is ``'return_value'``, also @@ -303,37 +303,16 @@ def xcom_pull( elif isinstance(map_indexes, Iterable): # TODO: Handle multiple map_indexes or remove support raise NotImplementedError("Multiple map_indexes are not supported yet") - - log = structlog.get_logger(logger_name="task") - xcoms = [] for t in task_ids: - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetXCom( - key=key, - dag_id=dag_id, - task_id=t, - run_id=run_id, - map_index=map_indexes, - ), + value = XCom.get_one( + run_id=run_id, + key=key, + task_id=t, + dag_id=dag_id, + map_index=map_indexes, ) - - msg = SUPERVISOR_COMMS.get_message() - if not isinstance(msg, XComResult): - raise TypeError(f"Expected XComResult, received: {type(msg)} {msg}") - - if msg.value is not None: - from airflow.serialization.serde import deserialize - - # TODO: Move XCom serialization & deserialization to Task SDK - # https://github.com/apache/airflow/issues/45231 - - # The execution API server deals in json compliant types now. - # serde's deserialize can handle deserializing primitive, collections, and complex objects too - xcoms.append(deserialize(msg.value)) # type: ignore[type-var] - else: - xcoms.append(default) + xcoms.append(value if value else default) if len(xcoms) == 1: return xcoms[0] @@ -358,28 +337,12 @@ def get_relevant_upstream_map_indexes( def _xcom_push(ti: RuntimeTaskInstance, key: str, value: Any, mapped_length: int | None = None) -> None: # Private function, as we don't want to expose the ability to manually set `mapped_length` to SDK # consumers - from airflow.serialization.serde import serialize - - # TODO: Move XCom serialization & deserialization to Task SDK - # https://github.com/apache/airflow/issues/45231 - - # The execution API server now deals in json compliant objects. - # It is responsibility of the client to handle any non native object serialization. - # serialize does just that. - value = serialize(value) - - log = structlog.get_logger(logger_name="task") - SUPERVISOR_COMMS.send_request( - log=log, - msg=SetXCom( - key=key, - value=value, - dag_id=ti.dag_id, - task_id=ti.task_id, - run_id=ti.run_id, - map_index=ti.map_index, - mapped_length=mapped_length, - ), + XCom.set( + key=key, + value=value, + dag_id=ti.dag_id, + task_id=ti.task_id, + run_id=ti.run_id, ) @@ -596,6 +559,17 @@ def run( state: IntermediateTIState | TerminalTIState error: BaseException | None = None try: + # First, clear the xcom data sent from server + if ti._ti_context_from_server and (keys_to_delete := ti._ti_context_from_server.xcom_keys_to_clear): + for x in keys_to_delete: + log.debug("Clearing XCom with key", key=x) + XCom.delete( + key=x, + dag_id=ti.dag_id, + task_id=ti.task_id, + run_id=ti.run_id, + ) + context = ti.get_template_context() with set_current_context(context): # This is the earliest that we can render templates -- as if it excepts for any reason we need to diff --git a/task-sdk/src/airflow/sdk/execution_time/xcom.py b/task-sdk/src/airflow/sdk/execution_time/xcom.py new file mode 100644 index 0000000000000..dff2466d708fe --- /dev/null +++ b/task-sdk/src/airflow/sdk/execution_time/xcom.py @@ -0,0 +1,292 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import Any + +import structlog + +from airflow.configuration import conf +from airflow.sdk.execution_time.comms import DeleteXCom, GetXCom, SetXCom, XComResult + +log = structlog.get_logger(logger_name="task") + + +class BaseXCom: + """BaseXcom is an interface now to interact with XCom backends.""" + + @classmethod + def set( + cls, + key: str, + value: Any, + *, + dag_id: str, + task_id: str, + run_id: str, + map_index: int = -1, + ) -> None: + """ + Store an XCom value. + + :param key: Key to store the XCom. + :param value: XCom value to store. + :param dag_id: DAG ID. + :param task_id: Task ID. + :param run_id: DAG run ID for the task. + :param map_index: Optional map index to assign XCom for a mapped task. + The default is ``-1`` (set for a non-mapped task). + """ + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + value = cls.serialize_value( + value=value, + key=key, + task_id=task_id, + dag_id=dag_id, + run_id=run_id, + map_index=map_index, + ) + + SUPERVISOR_COMMS.send_request( + log=log, + msg=SetXCom( + key=key, + value=value, + dag_id=dag_id, + task_id=task_id, + run_id=run_id, + map_index=map_index, + ), + ) + + @classmethod + def get_value( + cls, + *, + ti_key: Any, + key: str, + ) -> Any: + """ + Retrieve an XCom value for a task instance. + + This method returns "full" XCom values (i.e. uses ``deserialize_value`` + from the XCom backend). Use :meth:`get_many` if you want the "shortened" + value via ``orm_deserialize_value``. + + If there are no results, *None* is returned. If multiple XCom entries + match the criteria, an arbitrary one is returned. + + :param ti_key: The TaskInstanceKey to look up the XCom for. + :param key: A key for the XCom. If provided, only XCom with matching + keys will be returned. Pass *None* (default) to remove the filter. + """ + return cls.get_one( + key=key, + task_id=ti_key.task_id, + dag_id=ti_key.dag_id, + run_id=ti_key.run_id, + map_index=ti_key.map_index, + ) + + @classmethod + def _get_xcom_db_ref( + cls, + *, + key: str, + dag_id: str, + task_id: str, + run_id: str, + map_index: int | None = None, + ) -> XComResult: + """ + Retrieve an XCom value, optionally meeting certain criteria. + + This method returns "full" XCom values (i.e. uses ``deserialize_value`` + from the XCom backend). Use :meth:`get_many` if you want the "shortened" + value via ``orm_deserialize_value``. + + If there are no results, *None* is returned. If multiple XCom entries + match the criteria, an arbitrary one is returned. + + .. seealso:: ``get_value()`` is a convenience function if you already + have a structured TaskInstance or TaskInstanceKey object available. + + :param run_id: DAG run ID for the task. + :param dag_id: Only pull XCom from this DAG. Pass *None* (default) to + remove the filter. + :param task_id: Only XCom from task with matching ID will be pulled. + Pass *None* (default) to remove the filter. + :param map_index: Only XCom from task with matching ID will be pulled. + Pass *None* (default) to remove the filter. + :param key: A key for the XCom. If provided, only XCom with matching + keys will be returned. Pass *None* (default) to remove the filter. + """ + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + SUPERVISOR_COMMS.send_request( + log=log, + msg=GetXCom( + key=key, + dag_id=dag_id, + task_id=task_id, + run_id=run_id, + map_index=map_index, + ), + ) + + msg = SUPERVISOR_COMMS.get_message() + if not isinstance(msg, XComResult): + raise TypeError(f"Expected XComResult, received: {type(msg)} {msg}") + + return msg + + @classmethod + def get_one( + cls, + *, + key: str, + dag_id: str, + task_id: str, + run_id: str, + map_index: int | None = None, + include_prior_dates: bool = False, + ) -> Any | None: + """ + Retrieve an XCom value, optionally meeting certain criteria. + + This method returns "full" XCom values (i.e. uses ``deserialize_value`` + from the XCom backend). Use :meth:`get_many` if you want the "shortened" + value via ``orm_deserialize_value``. + + If there are no results, *None* is returned. If multiple XCom entries + match the criteria, an arbitrary one is returned. + + .. seealso:: ``get_value()`` is a convenience function if you already + have a structured TaskInstance or TaskInstanceKey object available. + + :param run_id: DAG run ID for the task. + :param dag_id: Only pull XCom from this DAG. Pass *None* (default) to + remove the filter. + :param task_id: Only XCom from task with matching ID will be pulled. + Pass *None* (default) to remove the filter. + :param map_index: Only XCom from task with matching ID will be pulled. + Pass *None* (default) to remove the filter. + :param key: A key for the XCom. If provided, only XCom with matching + keys will be returned. Pass *None* (default) to remove the filter. + :param include_prior_dates: If *False* (default), only XCom from the + specified DAG run is returned. If *True*, the latest matching XCom is + returned regardless of the run it belongs to. + """ + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + SUPERVISOR_COMMS.send_request( + log=log, + msg=GetXCom( + key=key, + dag_id=dag_id, + task_id=task_id, + run_id=run_id, + map_index=map_index, + ), + ) + + msg = SUPERVISOR_COMMS.get_message() + if not isinstance(msg, XComResult): + raise TypeError(f"Expected XComResult, received: {type(msg)} {msg}") + + if msg.value is not None: + return cls.deserialize_value(msg) + return None + + @staticmethod + def serialize_value( + value: Any, + *, + key: str | None = None, + task_id: str | None = None, + dag_id: str | None = None, + run_id: str | None = None, + map_index: int | None = None, + ) -> str: + """Serialize XCom value to JSON str.""" + from airflow.serialization.serde import serialize + + # return back the value for BaseXCom, custom backends will implement this + return serialize(value) # type: ignore[return-value] + + @staticmethod + def deserialize_value(result) -> Any: + """Deserialize XCom value from str objects.""" + from airflow.serialization.serde import deserialize + + return deserialize(result.value) + + @classmethod + def purge(cls, xcom: XComResult, *args) -> None: + """Purge an XCom entry from underlying storage implementations.""" + pass + + @classmethod + def delete( + cls, + key: str, + task_id: str, + dag_id: str, + run_id: str, + map_index: int | None = None, + ) -> None: + """Delete an Xcom entry, for custom xcom backends, it gets the path associated with the data on the backend and purges it.""" + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + xcom_result = cls._get_xcom_db_ref( + key=key, + dag_id=dag_id, + task_id=task_id, + run_id=run_id, + map_index=map_index, + ) + cls.purge(xcom_result) # type: ignore[call-arg] + SUPERVISOR_COMMS.send_request( + log=log, + msg=DeleteXCom( + key=key, + dag_id=dag_id, + task_id=task_id, + run_id=run_id, + ), + ) + + +def resolve_xcom_backend(): + """ + Resolve custom XCom class. + + :returns: returns the custom XCom class if configured. + """ + clazz = conf.getimport("core", "xcom_backend", fallback="airflow.sdk.execution_time.xcom.BaseXCom") + if not clazz: + return BaseXCom + if not issubclass(clazz, BaseXCom): + raise TypeError( + f"Your custom XCom class `{clazz.__name__}` is not a subclass of `{BaseXCom.__name__}`." + ) + return clazz + + +XCom = resolve_xcom_backend() diff --git a/task-sdk/tests/conftest.py b/task-sdk/tests/conftest.py index 4d2560106835a..89d9a94853e89 100644 --- a/task-sdk/tests/conftest.py +++ b/task-sdk/tests/conftest.py @@ -268,6 +268,12 @@ def mock_supervisor_comms(): yield supervisor_comms +@pytest.fixture +def mock_xcom_backend(): + with mock.patch("airflow.sdk.execution_time.task_runner.XCom", create=True) as xcom_backend: + yield xcom_backend + + @pytest.fixture def mocked_parse(spy_agency): """ diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 4d0cb91c4a716..851a9c29b7ff1 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -1446,6 +1446,82 @@ def execute(self, context): f"Returned dictionary keys must be strings when using multiple_outputs, found 2 ({int}) instead" ) + def test_xcom_push_to_custom_xcom_backend( + self, create_runtime_ti, mock_supervisor_comms, mock_xcom_backend + ): + """Test that a task pushes a xcom to the custom xcom backend.""" + + class CustomOperator(BaseOperator): + def execute(self, context): + return "pushing to xcom backend!" + + task = CustomOperator(task_id="pull_task") + runtime_ti = create_runtime_ti(task=task) + + run(runtime_ti, log=mock.MagicMock()) + + mock_xcom_backend.set.assert_called_once_with( + key="return_value", + value="pushing to xcom backend!", + dag_id="test_dag", + task_id="pull_task", + run_id="test_run", + ) + + # assert that we didn't call the API when XCom backend is configured + assert not any( + x + == mock.call( + log=mock.ANY, + msg=SetXCom( + key="key", + value="pushing to xcom backend!", + dag_id="test_dag", + run_id="test_run", + task_id="pull_task", + map_index=-1, + ), + ) + for x in mock_supervisor_comms.send_request.call_args_list + ) + + def test_xcom_pull_from_custom_xcom_backend( + self, create_runtime_ti, mock_supervisor_comms, mock_xcom_backend + ): + """Test that a task pulls the expected XCom value if it exists, but from custom xcom backend.""" + + class CustomOperator(BaseOperator): + def execute(self, context): + value = context["ti"].xcom_pull(task_ids="pull_task", key="key") + print(f"Pulled XCom Value: {value}") + + task = CustomOperator(task_id="pull_task") + runtime_ti = create_runtime_ti(task=task) + run(runtime_ti, log=mock.MagicMock()) + + mock_xcom_backend.get_one.assert_called_once_with( + key="key", + dag_id="test_dag", + task_id="pull_task", + run_id="test_run", + map_index=-1, + ) + + assert not any( + x + == mock.call( + log=mock.ANY, + msg=GetXCom( + key="key", + dag_id="test_dag", + run_id="test_run", + task_id="pull_task", + map_index=-1, + ), + ) + for x in mock_supervisor_comms.send_request.call_args_list + ) + class TestDagParamRuntime: DEFAULT_ARGS = { diff --git a/tests/api_fastapi/core_api/routes/public/test_extra_links.py b/tests/api_fastapi/core_api/routes/public/test_extra_links.py index 547673e4d2926..4f2f84c525131 100644 --- a/tests/api_fastapi/core_api/routes/public/test_extra_links.py +++ b/tests/api_fastapi/core_api/routes/public/test_extra_links.py @@ -22,7 +22,7 @@ from airflow.dag_processing.bundles.manager import DagBundlesManager from airflow.models.dagbag import DagBag -from airflow.models.xcom import XCom +from airflow.models.xcom import XComModel as XCom from airflow.plugins_manager import AirflowPlugin from airflow.utils import timezone from airflow.utils.state import DagRunState diff --git a/tests/api_fastapi/core_api/routes/public/test_xcom.py b/tests/api_fastapi/core_api/routes/public/test_xcom.py index 811e4551abf69..669a5d547b48f 100644 --- a/tests/api_fastapi/core_api/routes/public/test_xcom.py +++ b/tests/api_fastapi/core_api/routes/public/test_xcom.py @@ -16,17 +16,18 @@ # under the License. from __future__ import annotations +import json from unittest import mock import pytest from airflow.api_fastapi.core_api.datamodels.xcom import XComCreateBody -from airflow.models import XCom from airflow.models.dag import DagModel from airflow.models.dagrun import DagRun from airflow.models.taskinstance import TaskInstance -from airflow.models.xcom import BaseXCom, resolve_xcom_backend +from airflow.models.xcom import XComModel from airflow.providers.standard.operators.empty import EmptyOperator +from airflow.sdk.execution_time.xcom import BaseXCom, resolve_xcom_backend from airflow.utils import timezone from airflow.utils.session import provide_session from airflow.utils.types import DagRunType @@ -39,6 +40,7 @@ TEST_XCOM_KEY = "test_xcom_key" TEST_XCOM_VALUE = {"key": "value"} +TEST_XCOM_VALUE_AS_JSON = json.dumps(TEST_XCOM_VALUE) TEST_XCOM_KEY_2 = "test_xcom_key_non_existing" TEST_DAG_ID = "test-dag-id" @@ -58,7 +60,7 @@ @provide_session def _create_xcom(key, value, backend, session=None) -> None: - backend.set( + XComModel.set( key=key, value=value, dag_id=TEST_DAG_ID, @@ -85,7 +87,7 @@ def _create_dag_run(dag_maker, session=None): class CustomXCom(BaseXCom): @classmethod - def deserialize_value(cls, xcom: XCom): + def deserialize_value(cls, xcom): return f"real deserialized {super().deserialize_value(xcom)}" def orm_deserialize_value(self): @@ -107,34 +109,15 @@ def setup(self, dag_maker) -> None: def teardown_method(self) -> None: self.clear_db() - def _create_xcom(self, key, value, backend=XCom) -> None: + def _create_xcom(self, key, value, backend=None) -> None: _create_xcom(key, value, backend) class TestGetXComEntry(TestXComEndpoint): - def test_should_respond_200_stringify(self, test_client): - self._create_xcom(TEST_XCOM_KEY, TEST_XCOM_VALUE) - response = test_client.get( - f"/public/dags/{TEST_DAG_ID}/dagRuns/{run_id}/taskInstances/{TEST_TASK_ID}/xcomEntries/{TEST_XCOM_KEY}" - ) - assert response.status_code == 200 - - current_data = response.json() - assert current_data == { - "dag_id": TEST_DAG_ID, - "logical_date": logical_date_parsed.strftime("%Y-%m-%dT%H:%M:%SZ"), - "run_id": run_id, - "key": TEST_XCOM_KEY, - "task_id": TEST_TASK_ID, - "map_index": -1, - "timestamp": current_data["timestamp"], - "value": str(TEST_XCOM_VALUE), - } - def test_should_respond_200_native(self, test_client): self._create_xcom(TEST_XCOM_KEY, TEST_XCOM_VALUE) response = test_client.get( - f"/public/dags/{TEST_DAG_ID}/dagRuns/{run_id}/taskInstances/{TEST_TASK_ID}/xcomEntries/{TEST_XCOM_KEY}?stringify=false" + f"/public/dags/{TEST_DAG_ID}/dagRuns/{run_id}/taskInstances/{TEST_TASK_ID}/xcomEntries/{TEST_XCOM_KEY}" ) assert response.status_code == 200 @@ -147,7 +130,7 @@ def test_should_respond_200_native(self, test_client): "task_id": TEST_TASK_ID, "map_index": -1, "timestamp": current_data["timestamp"], - "value": TEST_XCOM_VALUE, + "value": json.dumps(TEST_XCOM_VALUE), } def test_should_respond_401(self, unauthenticated_test_client): @@ -175,7 +158,7 @@ def test_should_raise_404_for_non_existent_xcom(self, test_client): pytest.param( True, {"deserialize": True}, - f"real deserialized {TEST_XCOM_VALUE}", + f"real deserialized {TEST_XCOM_VALUE_AS_JSON}", id="enabled deserialize-true", ), pytest.param( @@ -187,25 +170,25 @@ def test_should_raise_404_for_non_existent_xcom(self, test_client): pytest.param( True, {"deserialize": False}, - f"orm deserialized {TEST_XCOM_VALUE}", + f"{TEST_XCOM_VALUE_AS_JSON}", id="enabled deserialize-false", ), pytest.param( False, {"deserialize": False}, - f"orm deserialized {TEST_XCOM_VALUE}", + f"{TEST_XCOM_VALUE_AS_JSON}", id="disabled deserialize-false", ), pytest.param( True, {}, - f"orm deserialized {TEST_XCOM_VALUE}", + f"{TEST_XCOM_VALUE_AS_JSON}", id="enabled default", ), pytest.param( False, {}, - f"orm deserialized {TEST_XCOM_VALUE}", + f"{TEST_XCOM_VALUE_AS_JSON}", id="disabled default", ), ], @@ -218,7 +201,7 @@ def test_custom_xcom_deserialize( self._create_xcom(TEST_XCOM_KEY, TEST_XCOM_VALUE, backend=XCom) url = f"/public/dags/{TEST_DAG_ID}/dagRuns/{run_id}/taskInstances/{TEST_TASK_ID}/xcomEntries/{TEST_XCOM_KEY}" - with mock.patch("airflow.api_fastapi.core_api.routes.public.xcom.XCom", XCom): + with mock.patch("airflow.sdk.execution_time.xcom.XCom", XCom): with conf_vars({("api", "enable_xcom_deserialize_support"): str(support_deserialize)}): response = test_client.get(url, params=params) @@ -441,7 +424,7 @@ def _create_xcom_entries(self, dag_id, run_id, logical_date, task_id, mapped_ti= key = f"{TEST_XCOM_KEY}-{i}" map_index = -1 - XCom.set( + XComModel.set( key=key, value=TEST_XCOM_VALUE, run_id=run_id, @@ -612,7 +595,7 @@ def test_create_xcom_entry( # Validate the created XCom response current_data = response.json() assert current_data["key"] == request_body.key - assert current_data["value"] == XCom.serialize_value(request_body.value) + assert current_data["value"] == XComModel.serialize_value(request_body.value) assert current_data["dag_id"] == dag_id assert current_data["task_id"] == task_id assert current_data["run_id"] == dag_run_id @@ -660,7 +643,7 @@ def test_patch_xcom_entry(self, key, patch_body, expected_status, expected_detai # Ensure the XCom entry exists before updating if expected_status != 404: self._create_xcom(TEST_XCOM_KEY, TEST_XCOM_VALUE) - new_value = XCom.serialize_value(patch_body["value"]) + new_value = XComModel.serialize_value(patch_body["value"]) response = test_client.patch( f"/public/dags/{TEST_DAG_ID}/dagRuns/{run_id}/taskInstances/{TEST_TASK_ID}/xcomEntries/{key}", @@ -670,7 +653,7 @@ def test_patch_xcom_entry(self, key, patch_body, expected_status, expected_detai assert response.status_code == expected_status if expected_status == 200: - assert response.json()["value"] == XCom.serialize_value(new_value) + assert response.json()["value"] == XComModel.serialize_value(new_value) else: assert response.json()["detail"] == expected_detail check_last_log(session, dag_id=TEST_DAG_ID, event="update_xcom_entry", logical_date=None) diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py b/tests/api_fastapi/execution_api/routes/test_task_instances.py index da5413e9b9c3b..5f56a09eaa787 100644 --- a/tests/api_fastapi/execution_api/routes/test_task_instances.py +++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py @@ -112,6 +112,7 @@ def test_ti_run_state_to_running(self, client, session, create_task_instance, ti "max_tries": 0, "variables": [], "connections": [], + "xcom_keys_to_clear": [], } # Refresh the Task Instance from the database so that we can check the updated values @@ -188,6 +189,7 @@ def test_next_kwargs_still_encoded(self, client, session, create_task_instance, "max_tries": 0, "variables": [], "connections": [], + "xcom_keys_to_clear": [], "next_method": "execute_complete", "next_kwargs": { "__type": "dict", @@ -231,40 +233,6 @@ def test_ti_run_state_conflict_if_not_queued( assert session.scalar(select(TaskInstance.state).where(TaskInstance.id == ti.id)) == initial_ti_state - def test_xcom_cleared_when_ti_runs(self, client, session, create_task_instance, time_machine): - """ - Test that the xcoms are cleared when the Task Instance state is updated to running. - """ - instant_str = "2024-09-30T12:00:00Z" - instant = timezone.parse(instant_str) - time_machine.move_to(instant, tick=False) - - ti = create_task_instance( - task_id="test_xcom_cleared_when_ti_runs", - state=State.QUEUED, - session=session, - start_date=instant, - ) - session.commit() - - # Lets stage a xcom push - ti.xcom_push(key="key", value="value") - - response = client.patch( - f"/execution/task-instances/{ti.id}/run", - json={ - "state": "running", - "hostname": "random-hostname", - "unixname": "random-unixname", - "pid": 100, - "start_date": instant_str, - }, - ) - - assert response.status_code == 200 - # Once the task is running, we can check if xcom is cleared - assert ti.xcom_pull(task_ids="test_xcom_cleared_when_ti_runs", key="key") is None - def test_xcom_not_cleared_for_deferral(self, client, session, create_task_instance, time_machine): """ Test that the xcoms are not cleared when the Task Instance state is re-running after deferral. diff --git a/tests/api_fastapi/execution_api/routes/test_xcoms.py b/tests/api_fastapi/execution_api/routes/test_xcoms.py index 2f0755fe5a236..b8d10538c5b39 100644 --- a/tests/api_fastapi/execution_api/routes/test_xcoms.py +++ b/tests/api_fastapi/execution_api/routes/test_xcoms.py @@ -26,7 +26,8 @@ from airflow.api_fastapi.execution_api.datamodels.xcom import XComResponse from airflow.models.dagrun import DagRun from airflow.models.taskmap import TaskMap -from airflow.models.xcom import XCom +from airflow.models.xcom import XComModel +from airflow.serialization.serde import serialize from airflow.utils.session import create_session pytestmark = pytest.mark.db_test @@ -37,32 +38,39 @@ def reset_db(): """Reset XCom entries.""" with create_session() as session: session.query(DagRun).delete() - session.query(XCom).delete() + session.query(XComModel).delete() class TestXComsGetEndpoint: @pytest.mark.parametrize( - ("value", "expected_value"), + ("db_value"), [ - ('"value1"', '"value1"'), - ('{"key2": "value2"}', '{"key2": "value2"}'), - ('{"key2": "value2", "key3": ["value3"]}', '{"key2": "value2", "key3": ["value3"]}'), - ('["value1"]', '["value1"]'), + ("value1"), + ({"key2": "value2"}), + ({"key2": "value2", "key3": ["value3"]}), + (["value1"]), ], ) - def test_xcom_get_from_db(self, client, create_task_instance, session, value, expected_value): + def test_xcom_get_from_db(self, client, create_task_instance, session, db_value): """Test that XCom value is returned from the database in JSON-compatible format.""" + # The tests expect serialised strings because v2 serialised and stored in the DB ti = create_task_instance() - ti.xcom_push(key="xcom_1", value=value, session=session) - session.commit() - xcom = session.query(XCom).filter_by(task_id=ti.task_id, dag_id=ti.dag_id, key="xcom_1").first() - assert xcom.value == expected_value + x = XComModel( + key="xcom_1", + value=db_value, + dag_run_id=ti.dag_run.id, + run_id=ti.run_id, + task_id=ti.task_id, + dag_id=ti.dag_id, + ) + session.add(x) + session.commit() response = client.get(f"/execution/xcoms/{ti.dag_id}/{ti.run_id}/{ti.task_id}/xcom_1") assert response.status_code == 200 - assert response.json() == {"key": "xcom_1", "value": expected_value} + assert response.json() == {"key": "xcom_1", "value": db_value} def test_xcom_not_found(self, client, create_task_instance): response = client.get("/execution/xcoms/dag/runid/task/xcom_non_existent") @@ -106,7 +114,7 @@ def test_xcom_set(self, client, create_task_instance, session, value, expected_v """ ti = create_task_instance() session.commit() - + value = serialize(value) response = client.post( f"/execution/xcoms/{ti.dag_id}/{ti.run_id}/{ti.task_id}/xcom_1", json=value, @@ -115,7 +123,7 @@ def test_xcom_set(self, client, create_task_instance, session, value, expected_v assert response.status_code == 201 assert response.json() == {"message": "XCom successfully set"} - xcom = session.query(XCom).filter_by(task_id=ti.task_id, dag_id=ti.dag_id, key="xcom_1").first() + xcom = session.query(XComModel).filter_by(task_id=ti.task_id, dag_id=ti.dag_id, key="xcom_1").first() assert xcom.value == expected_value task_map = session.query(TaskMap).filter_by(task_id=ti.task_id, dag_id=ti.dag_id).one_or_none() assert task_map is None, "Should not be mapped" @@ -124,17 +132,19 @@ def test_xcom_set_mapped(self, client, create_task_instance, session): ti = create_task_instance() session.commit() + value = serialize("value1") + response = client.post( f"/execution/xcoms/{ti.dag_id}/{ti.run_id}/{ti.task_id}/xcom_1", params={"map_index": -1, "mapped_length": 3}, - json="value1", + json=value, ) assert response.status_code == 201 assert response.json() == {"message": "XCom successfully set"} xcom = ( - session.query(XCom) + session.query(XComModel) .filter_by(task_id=ti.task_id, dag_id=ti.dag_id, key="xcom_1", map_index=-1) .first() ) @@ -217,6 +227,7 @@ def test_xcom_roundtrip(self, client, create_task_instance, session, value, expe """ ti = create_task_instance() + value = serialize(value) session.commit() client.post( f"/execution/xcoms/{ti.dag_id}/{ti.run_id}/{ti.task_id}/test_xcom_roundtrip", @@ -224,7 +235,7 @@ def test_xcom_roundtrip(self, client, create_task_instance, session, value, expe ) xcom = ( - session.query(XCom) + session.query(XComModel) .filter_by(task_id=ti.task_id, dag_id=ti.dag_id, key="test_xcom_roundtrip") .first() ) @@ -234,3 +245,22 @@ def test_xcom_roundtrip(self, client, create_task_instance, session, value, expe assert response.status_code == 200 assert XComResponse.model_validate_json(response.read()).value == expected_value + + +class TestXComsDeleteEndpoint: + def test_xcom_delete_endpoint(self, client, create_task_instance, session): + """Test that XCom value is deleted when Delete API is called.""" + ti = create_task_instance() + ti.xcom_push(key="xcom_1", value='"value1"', session=session) + session.commit() + + xcom = session.query(XComModel).filter_by(task_id=ti.task_id, dag_id=ti.dag_id, key="xcom_1").first() + assert xcom is not None + + response = client.delete(f"/execution/xcoms/{ti.dag_id}/{ti.run_id}/{ti.task_id}/xcom_1") + + assert response.status_code == 200 + assert response.json() == {"message": "XCom with key: xcom_1 successfully deleted."} + + xcom = session.query(XComModel).filter_by(task_id=ti.task_id, dag_id=ti.dag_id, key="xcom_1").first() + assert xcom is None diff --git a/tests/conftest.py b/tests/conftest.py index 28c359b81754c..c1502e319f488 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,6 +24,7 @@ from contextlib import contextmanager from pathlib import Path from typing import TYPE_CHECKING +from unittest import mock import pytest @@ -137,6 +138,14 @@ def test_zip_path(tmp_path: Path): return os.fspath(zipped) +@pytest.fixture +def mock_supervisor_comms(): + with mock.patch( + "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True + ) as supervisor_comms: + yield supervisor_comms + + if TYPE_CHECKING: # Static checkers do not know about pytest fixtures' types and return, # In case if them distributed through third party packages. diff --git a/tests/decorators/test_sensor.py b/tests/decorators/test_sensor.py index 9e21e11d8b6a2..caa824ec97539 100644 --- a/tests/decorators/test_sensor.py +++ b/tests/decorators/test_sensor.py @@ -22,7 +22,7 @@ from airflow.decorators import task from airflow.exceptions import AirflowSensorTimeout -from airflow.models import XCom +from airflow.models.xcom import XComModel from airflow.sensors.base import PokeReturnValue from airflow.utils.state import State @@ -60,10 +60,10 @@ def dummy_f(): assert ti.state == State.SUCCESS if ti.task_id == "dummy_f": assert ti.state == State.NONE - actual_xcom_value = XCom.get_one( - key="return_value", task_id="sensor_f", dag_id=dr.dag_id, run_id=dr.run_id - ) - assert actual_xcom_value == sensor_xcom_value + actual_xcom_value = XComModel.get_many( + key="return_value", task_ids="sensor_f", dag_ids=dr.dag_id, run_id=dr.run_id + ).first() + assert XComModel.deserialize_value(actual_xcom_value) == sensor_xcom_value def test_basic_sensor_success_returns_bool(self, dag_maker): @task.sensor @@ -216,7 +216,7 @@ def sensor_f(n: int): assert ti.state == State.SUCCESS if ti.task_id == "dummy_f": assert ti.state == State.SUCCESS - actual_xcom_value = XCom.get_one( - key="return_value", task_id="sensor_f", dag_id=dr.dag_id, run_id=dr.run_id - ) - assert actual_xcom_value == sensor_xcom_value + actual_xcom_value = XComModel.get_many( + key="return_value", task_ids="sensor_f", dag_ids=dr.dag_id, run_id=dr.run_id + ).first() + assert XComModel.deserialize_value(actual_xcom_value) == sensor_xcom_value diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py index b3f9ee4184733..e61200ca24bfc 100644 --- a/tests/models/test_dagrun.py +++ b/tests/models/test_dagrun.py @@ -2234,7 +2234,7 @@ def test_clearing_task_and_moving_from_non_mapped_to_mapped(dag_maker, session): was not used in the test since it would require that the task is expanded first. """ - from airflow.models.xcom import XCom + from airflow.models.xcom import XComModel @task def printx(x): @@ -2265,12 +2265,12 @@ def printx(x): # Purposely omitted RenderedTaskInstanceFields because the ti need # to be expanded but here we are mimicking and made it map_index -1 session.add(tr) - XCom.set(key="test", value="value", task_id=ti.task_id, dag_id=dag.dag_id, run_id=ti.run_id) + XComModel.set(key="test", value="value", task_id=ti.task_id, dag_id=dag.dag_id, run_id=ti.run_id) session.commit() - for table in [TaskInstanceNote, TaskReschedule, XCom]: + for table in [TaskInstanceNote, TaskReschedule, XComModel]: assert session.query(table).count() == 1 dr1.task_instance_scheduling_decisions(session) - for table in [TaskInstanceNote, TaskReschedule, XCom]: + for table in [TaskInstanceNote, TaskReschedule, XComModel]: assert session.query(table).count() == 0 diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 06548e8a40caf..870f54cdf2caf 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -19,6 +19,7 @@ import contextlib import datetime +import json import operator import os import pathlib @@ -71,7 +72,7 @@ from airflow.models.taskmap import TaskMap from airflow.models.taskreschedule import TaskReschedule from airflow.models.variable import Variable -from airflow.models.xcom import XCom +from airflow.models.xcom import XComModel from airflow.notifications.basenotifier import BaseNotifier from airflow.providers.standard.operators.bash import BashOperator from airflow.providers.standard.operators.empty import EmptyOperator @@ -2408,9 +2409,11 @@ def _write2_post_execute(context, result): ti.run(session=session) xcom = session.scalars( - select(XCom).filter_by(dag_id=dr.dag_id, run_id=dr.run_id, task_id="write1", key="return_value") + select(XComModel).filter_by( + dag_id=dr.dag_id, run_id=dr.run_id, task_id="write1", key="return_value" + ) ).one() - assert xcom.value == "write_1 result" + assert xcom.value == json.dumps("write_1 result") events = dict(iter(session.execute(select(AssetEvent.source_task_id, AssetEvent)))) assert set(events) == {"write1", "write2"} @@ -4004,7 +4007,7 @@ def test_operator_field_with_serialization(self, create_task_instance): assert ser_ti.task.operator_name == "EmptyOperator" def test_clear_db_references(self, session, create_task_instance): - tables = [RenderedTaskInstanceFields, XCom] + tables = [RenderedTaskInstanceFields, XComModel] ti = create_task_instance() ti.note = "sample note" @@ -4012,7 +4015,7 @@ def test_clear_db_references(self, session, create_task_instance): session.commit() for table in [RenderedTaskInstanceFields]: session.add(table(ti)) - XCom.set(key="key", value="value", task_id=ti.task_id, dag_id=ti.dag_id, run_id=ti.run_id) + XComModel.set(key="key", value="value", task_id=ti.task_id, dag_id=ti.dag_id, run_id=ti.run_id) session.commit() for table in tables: assert session.query(table).count() == 1 diff --git a/tests/models/test_trigger.py b/tests/models/test_trigger.py index ce36c7b5137a4..c07c8f5acbf91 100644 --- a/tests/models/test_trigger.py +++ b/tests/models/test_trigger.py @@ -29,8 +29,9 @@ from airflow.jobs.job import Job from airflow.jobs.triggerer_job_runner import TriggererJobRunner -from airflow.models import TaskInstance, Trigger, XCom +from airflow.models import TaskInstance, Trigger from airflow.models.asset import AssetEvent, AssetModel, asset_trigger_association_table +from airflow.models.xcom import XComModel from airflow.providers.standard.operators.empty import EmptyOperator from airflow.serialization.serialized_objects import BaseSerialization from airflow.triggers.base import ( @@ -241,7 +242,7 @@ def test_submit_event_task_end(mock_utcnow, session, create_task_instance, event session.commit() def get_xcoms(ti): - return XCom.get_many(dag_ids=[ti.dag_id], task_ids=[ti.task_id], run_id=ti.run_id).all() + return XComModel.get_many(dag_ids=[ti.dag_id], task_ids=[ti.task_id], run_id=ti.run_id).all() # now for the real test # first check initial state @@ -264,7 +265,10 @@ def get_xcoms(ti): assert ti.end_date == now assert ti.duration is not None actual_xcoms = {x.key: x.value for x in get_xcoms(ti)} - assert actual_xcoms == {"return_value": "xcomret", "a": "b", "c": "d"} + expected_xcoms = {} + for k, v in {"return_value": "xcomret", "a": "b", "c": "d"}.items(): + expected_xcoms[k] = json.dumps(v) + assert actual_xcoms == expected_xcoms def test_assign_unassigned(session, create_task_instance): diff --git a/tests/models/test_xcom.py b/tests/models/test_xcom.py index 45d7b264c44bd..c01c6a88c0520 100644 --- a/tests/models/test_xcom.py +++ b/tests/models/test_xcom.py @@ -27,8 +27,9 @@ from airflow.configuration import conf from airflow.models.dagrun import DagRun, DagRunType from airflow.models.taskinstance import TaskInstance -from airflow.models.xcom import BaseXCom, XCom, resolve_xcom_backend +from airflow.models.xcom import XComModel from airflow.providers.standard.operators.empty import EmptyOperator +from airflow.sdk.execution_time.xcom import BaseXCom, resolve_xcom_backend from airflow.settings import json from airflow.utils import timezone from airflow.utils.session import create_session @@ -52,7 +53,7 @@ def reset_db(): """Reset XCom entries.""" with create_session() as session: session.query(DagRun).delete() - session.query(XCom).delete() + session.query(XComModel).delete() @pytest.fixture @@ -114,52 +115,17 @@ def test_resolve_xcom_class(self): def test_resolve_xcom_class_fallback_to_basexcom(self): cls = resolve_xcom_backend() assert issubclass(cls, BaseXCom) - assert cls.serialize_value([1]) == "[1]" + assert cls.serialize_value([1]) == [1] @conf_vars({("core", "xcom_backend"): "to be removed"}) def test_resolve_xcom_class_fallback_to_basexcom_no_config(self): conf.remove_option("core", "xcom_backend") cls = resolve_xcom_backend() assert issubclass(cls, BaseXCom) - assert cls.serialize_value([1]) == "[1]" - - @mock.patch("airflow.models.xcom.XCom.orm_deserialize_value") - def test_xcom_init_on_load_uses_orm_deserialize_value(self, mock_orm_deserialize): - instance = BaseXCom( - key="key", - value="value", - timestamp=timezone.utcnow(), - logical_date=timezone.utcnow(), - task_id="task_id", - dag_id="dag_id", - ) - instance.init_on_load() - mock_orm_deserialize.assert_called_once_with() - - @conf_vars({("core", "xcom_backend"): "tests.models.test_xcom.CustomXCom"}) - def test_get_one_custom_backend_no_use_orm_deserialize_value(self, task_instance, session): - """Test that XCom.get_one does not call orm_deserialize_value""" - XCom = resolve_xcom_backend() - XCom.set( - key=XCOM_RETURN_KEY, - value={"key": "value"}, - dag_id=task_instance.dag_id, - task_id=task_instance.task_id, - run_id=task_instance.run_id, - session=session, - ) - - value = XCom.get_one( - dag_id=task_instance.dag_id, - task_id=task_instance.task_id, - run_id=task_instance.run_id, - session=session, - ) - assert value == {"key": "value"} - XCom.orm_deserialize_value.assert_not_called() + assert cls.serialize_value([1]) == [1] - @mock.patch("airflow.models.xcom.conf.getimport") - def test_set_serialize_call_current_signature(self, get_import, task_instance): + @mock.patch("airflow.sdk.execution_time.xcom.conf.getimport") + def test_set_serialize_call_current_signature(self, get_import, task_instance, mock_supervisor_comms): """ When XCom.serialize_value includes params logical_date, key, dag_id, task_id and run_id, then XCom.set should pass all of them. @@ -210,7 +176,7 @@ def serialize_value( @pytest.fixture def push_simple_json_xcom(session): def func(*, ti: TaskInstance, key: str, value): - return XCom.set( + return XComModel.set( key=key, value=value, dag_id=ti.dag_id, @@ -229,14 +195,14 @@ def setup_for_xcom_get_one(self, task_instance, push_simple_json_xcom): @pytest.mark.usefixtures("setup_for_xcom_get_one") def test_xcom_get_one(self, session, task_instance): - stored_value = XCom.get_one( + stored_value = XComModel.get_many( key="xcom_1", - dag_id=task_instance.dag_id, - task_id=task_instance.task_id, + dag_ids=task_instance.dag_id, + task_ids=task_instance.task_id, run_id=task_instance.run_id, session=session, - ) - assert stored_value == {"key": "value"} + ).first() + assert XComModel.deserialize_value(stored_value) == {"key": "value"} @pytest.fixture def tis_for_xcom_get_one_from_prior_date(self, task_instance_factory, push_simple_json_xcom): @@ -256,15 +222,15 @@ def tis_for_xcom_get_one_from_prior_date(self, task_instance_factory, push_simpl def test_xcom_get_one_from_prior_date(self, session, tis_for_xcom_get_one_from_prior_date): _, ti2 = tis_for_xcom_get_one_from_prior_date - retrieved_value = XCom.get_one( + retrieved_value = XComModel.get_many( run_id=ti2.run_id, key="xcom_1", - task_id="task_1", - dag_id="dag", + task_ids="task_1", + dag_ids="dag", include_prior_dates=True, session=session, - ) - assert retrieved_value == {"key": "value"} + ).first() + assert XComModel.deserialize_value(retrieved_value) == {"key": "value"} @pytest.fixture def setup_for_xcom_get_many_single_argument_value(self, task_instance, push_simple_json_xcom): @@ -272,7 +238,7 @@ def setup_for_xcom_get_many_single_argument_value(self, task_instance, push_simp @pytest.mark.usefixtures("setup_for_xcom_get_many_single_argument_value") def test_xcom_get_many_single_argument_value(self, session, task_instance): - stored_xcoms = XCom.get_many( + stored_xcoms = XComModel.get_many( key="xcom_1", dag_ids=task_instance.dag_id, task_ids=task_instance.task_id, @@ -281,7 +247,7 @@ def test_xcom_get_many_single_argument_value(self, session, task_instance): ).all() assert len(stored_xcoms) == 1 assert stored_xcoms[0].key == "xcom_1" - assert stored_xcoms[0].value == {"key": "value"} + assert stored_xcoms[0].value == json.dumps({"key": "value"}) @pytest.fixture def setup_for_xcom_get_many_multiple_tasks(self, task_instances, push_simple_json_xcom): @@ -291,7 +257,7 @@ def setup_for_xcom_get_many_multiple_tasks(self, task_instances, push_simple_jso @pytest.mark.usefixtures("setup_for_xcom_get_many_multiple_tasks") def test_xcom_get_many_multiple_tasks(self, session, task_instance): - stored_xcoms = XCom.get_many( + stored_xcoms = XComModel.get_many( key="xcom_1", dag_ids=task_instance.dag_id, task_ids=["task_1", "task_2"], @@ -299,7 +265,7 @@ def test_xcom_get_many_multiple_tasks(self, session, task_instance): session=session, ) sorted_values = [x.value for x in sorted(stored_xcoms, key=operator.attrgetter("task_id"))] - assert sorted_values == [{"key1": "value1"}, {"key2": "value2"}] + assert sorted_values == [json.dumps({"key1": "value1"}), json.dumps({"key2": "value2"})] @pytest.fixture def tis_for_xcom_get_many_from_prior_dates(self, task_instance_factory, push_simple_json_xcom): @@ -313,7 +279,7 @@ def tis_for_xcom_get_many_from_prior_dates(self, task_instance_factory, push_sim def test_xcom_get_many_from_prior_dates(self, session, tis_for_xcom_get_many_from_prior_dates): ti1, ti2 = tis_for_xcom_get_many_from_prior_dates - stored_xcoms = XCom.get_many( + stored_xcoms = XComModel.get_many( run_id=ti2.run_id, key="xcom_1", dag_ids="dag", @@ -323,7 +289,9 @@ def test_xcom_get_many_from_prior_dates(self, session, tis_for_xcom_get_many_fro ) # The retrieved XComs should be ordered by logical date, latest first. - assert [x.value for x in stored_xcoms] == [{"key2": "value2"}, {"key1": "value1"}] + assert [x.value for x in stored_xcoms] == list( + map(lambda j: json.dumps(j), [{"key2": "value2"}, {"key1": "value1"}]) + ) assert [x.logical_date for x in stored_xcoms] == [ti2.logical_date, ti1.logical_date] @@ -340,7 +308,7 @@ class TestXComSet: ], ) def test_xcom_set(self, session, task_instance, key, value, expected_value): - XCom.set( + XComModel.set( key=key, value=value, dag_id=task_instance.dag_id, @@ -348,10 +316,10 @@ def test_xcom_set(self, session, task_instance, key, value, expected_value): run_id=task_instance.run_id, session=session, ) - stored_xcoms = session.query(XCom).all() + stored_xcoms = session.query(XComModel).all() assert stored_xcoms[0].key == key - assert isinstance(stored_xcoms[0].value, type(expected_value)) - assert stored_xcoms[0].value == expected_value + assert isinstance(stored_xcoms[0].value, type(json.dumps(expected_value))) + assert stored_xcoms[0].value == json.dumps(expected_value) assert stored_xcoms[0].dag_id == "dag" assert stored_xcoms[0].task_id == "task_1" assert stored_xcoms[0].logical_date == task_instance.logical_date @@ -362,8 +330,8 @@ def setup_for_xcom_set_again_replace(self, task_instance, push_simple_json_xcom) @pytest.mark.usefixtures("setup_for_xcom_set_again_replace") def test_xcom_set_again_replace(self, session, task_instance): - assert session.query(XCom).one().value == {"key1": "value1"} - XCom.set( + assert session.query(XComModel).one().value == json.dumps({"key1": "value1"}) + XComModel.set( key="xcom_1", value={"key2": "value2"}, dag_id=task_instance.dag_id, @@ -371,7 +339,7 @@ def test_xcom_set_again_replace(self, session, task_instance): run_id=task_instance.run_id, session=session, ) - assert session.query(XCom).one().value == {"key2": "value2"} + assert session.query(XComModel).one().value == json.dumps({"key2": "value2"}) class TestXComClear: @@ -382,22 +350,23 @@ def setup_for_xcom_clear(self, task_instance, push_simple_json_xcom): @pytest.mark.usefixtures("setup_for_xcom_clear") @mock.patch("airflow.models.xcom.XCom.purge") def test_xcom_clear(self, mock_purge, session, task_instance): - assert session.query(XCom).count() == 1 - XCom.clear( + assert session.query(XComModel).count() == 1 + XComModel.clear( dag_id=task_instance.dag_id, task_id=task_instance.task_id, run_id=task_instance.run_id, session=session, ) - assert session.query(XCom).count() == 0 - assert mock_purge.call_count == 1 + assert session.query(XComModel).count() == 0 + # purge will not be done when we clear, will be handled in task sdk + assert mock_purge.call_count == 0 @pytest.mark.usefixtures("setup_for_xcom_clear") def test_xcom_clear_different_run(self, session, task_instance): - XCom.clear( + XComModel.clear( dag_id=task_instance.dag_id, task_id=task_instance.task_id, run_id="different_run", session=session, ) - assert session.query(XCom).count() == 1 + assert session.query(XComModel).count() == 1 diff --git a/tests/operators/test_trigger_dagrun.py b/tests/operators/test_trigger_dagrun.py index e8464dde2b398..2ebf29bc9c4e7 100644 --- a/tests/operators/test_trigger_dagrun.py +++ b/tests/operators/test_trigger_dagrun.py @@ -32,6 +32,7 @@ from airflow.models.taskinstance import TaskInstance from airflow.providers.standard.operators.trigger_dagrun import TriggerDagRunOperator from airflow.providers.standard.triggers.external_task import DagStateTrigger +from airflow.sdk.execution_time.comms import XComResult from airflow.utils import timezone from airflow.utils.session import create_session from airflow.utils.state import DagRunState, State, TaskInstanceState @@ -122,7 +123,7 @@ def assert_extra_link(self, triggered_dag_run, triggering_task, session): } assert expected_args in args - def test_trigger_dagrun(self, dag_maker): + def test_trigger_dagrun(self, dag_maker, mock_supervisor_comms): """Test TriggerDagRunOperator.""" with time_machine.travel("2025-02-18T08:04:46Z", tick=False): with dag_maker( @@ -144,6 +145,8 @@ def test_trigger_dagrun(self, dag_maker): assert actual_run_id == expected_run_id + mock_supervisor_comms.get_message.return_value = XComResult(key="xcom_key", value=dagrun.run_id) + self.assert_extra_link(dagrun, task, dag_maker.session) def test_trigger_dagrun_custom_run_id(self, dag_maker): @@ -165,7 +168,7 @@ def test_trigger_dagrun_custom_run_id(self, dag_maker): assert len(dagruns) == 1 assert dagruns[0].run_id == "custom_run_id" - def test_trigger_dagrun_with_logical_date(self, dag_maker): + def test_trigger_dagrun_with_logical_date(self, dag_maker, mock_supervisor_comms): """Test TriggerDagRunOperator with custom logical_date.""" custom_logical_date = timezone.datetime(2021, 1, 2, 3, 4, 5) with dag_maker( @@ -188,9 +191,10 @@ def test_trigger_dagrun_with_logical_date(self, dag_maker): assert dagrun.run_id == DagRun.generate_run_id( run_type=DagRunType.MANUAL, logical_date=custom_logical_date, run_after=custom_logical_date ) + mock_supervisor_comms.get_message.return_value = XComResult(key="xcom_key", value=dagrun.run_id) self.assert_extra_link(dagrun, task, session) - def test_trigger_dagrun_twice(self, dag_maker): + def test_trigger_dagrun_twice(self, dag_maker, mock_supervisor_comms): """Test TriggerDagRunOperator with custom logical_date.""" utc_now = timezone.utcnow() run_id = f"manual__{utc_now.isoformat()}" @@ -227,9 +231,12 @@ def test_trigger_dagrun_twice(self, dag_maker): triggered_dag_run = dagruns[0] assert triggered_dag_run.run_type == DagRunType.MANUAL assert triggered_dag_run.logical_date == utc_now + mock_supervisor_comms.get_message.return_value = XComResult( + key="xcom_key", value=triggered_dag_run.run_id + ) self.assert_extra_link(triggered_dag_run, task, dag_maker.session) - def test_trigger_dagrun_with_scheduled_dag_run(self, dag_maker): + def test_trigger_dagrun_with_scheduled_dag_run(self, dag_maker, mock_supervisor_comms): """Test TriggerDagRunOperator with custom logical_date and scheduled dag_run.""" utc_now = timezone.utcnow() run_id = f"scheduled__{utc_now.isoformat()}" @@ -266,9 +273,12 @@ def test_trigger_dagrun_with_scheduled_dag_run(self, dag_maker): assert len(dagruns) == 1 triggered_dag_run = dagruns[0] assert triggered_dag_run.logical_date == utc_now + mock_supervisor_comms.get_message.return_value = XComResult( + key="xcom_key", value=triggered_dag_run.run_id + ) self.assert_extra_link(triggered_dag_run, task, dag_maker.session) - def test_trigger_dagrun_with_templated_logical_date(self, dag_maker): + def test_trigger_dagrun_with_templated_logical_date(self, dag_maker, mock_supervisor_comms): """Test TriggerDagRunOperator with templated logical_date.""" with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True @@ -289,9 +299,12 @@ def test_trigger_dagrun_with_templated_logical_date(self, dag_maker): triggered_dag_run = dagruns[0] assert triggered_dag_run.run_type == DagRunType.MANUAL assert triggered_dag_run.logical_date == DEFAULT_DATE + mock_supervisor_comms.get_message.return_value = XComResult( + key="xcom_key", value=triggered_dag_run.run_id + ) self.assert_extra_link(triggered_dag_run, task, session) - def test_trigger_dagrun_with_templated_trigger_dag_id(self, dag_maker): + def test_trigger_dagrun_with_templated_trigger_dag_id(self, dag_maker, mock_supervisor_comms): """Test TriggerDagRunOperator with templated trigger dag id.""" with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True @@ -311,6 +324,9 @@ def test_trigger_dagrun_with_templated_trigger_dag_id(self, dag_maker): triggered_dag_run = dagruns[0] assert triggered_dag_run.run_type == DagRunType.MANUAL assert triggered_dag_run.dag_id == TRIGGERED_DAG_ID + mock_supervisor_comms.get_message.return_value = XComResult( + key="xcom_key", value=triggered_dag_run.run_id + ) self.assert_extra_link(triggered_dag_run, task, session) def test_trigger_dagrun_operator_conf(self, dag_maker): diff --git a/tests/sensors/test_base.py b/tests/sensors/test_base.py index f7e1021ecb142..30ba877b2c02b 100644 --- a/tests/sensors/test_base.py +++ b/tests/sensors/test_base.py @@ -44,7 +44,7 @@ from airflow.executors.sequential_executor import SequentialExecutor from airflow.models import TaskInstance, TaskReschedule from airflow.models.trigger import TriggerFailureReason -from airflow.models.xcom import XCom +from airflow.models.xcom import XComModel from airflow.providers.celery.executors.celery_executor import CeleryExecutor from airflow.providers.cncf.kubernetes.executors.kubernetes_executor import KubernetesExecutor from airflow.providers.standard.operators.empty import EmptyOperator @@ -1275,9 +1275,10 @@ def test_sensor_with_xcom(self, make_sensor): assert ti.state == State.SUCCESS if ti.task_id == DUMMY_OP: assert ti.state == State.NONE - actual_xcom_value = XCom.get_one( - key="return_value", task_id=SENSOR_OP, dag_id=dr.dag_id, run_id=dr.run_id - ) + actual_xcom_value = XComModel.get_many( + key="return_value", task_ids=SENSOR_OP, dag_ids=dr.dag_id, run_id=dr.run_id + ).first() + actual_xcom_value = XComModel.deserialize_value(actual_xcom_value) assert actual_xcom_value == xcom_value def test_sensor_with_xcom_fails(self, make_sensor): @@ -1293,9 +1294,11 @@ def test_sensor_with_xcom_fails(self, make_sensor): assert ti.state == State.FAILED if ti.task_id == DUMMY_OP: assert ti.state == State.NONE - actual_xcom_value = XCom.get_one( - key="return_value", task_id=SENSOR_OP, dag_id=dr.dag_id, run_id=dr.run_id - ) + actual_xcom_value = XComModel.get_many( + key="return_value", task_ids=SENSOR_OP, dag_ids=dr.dag_id, run_id=dr.run_id + ).first() + if actual_xcom_value: + actual_xcom_value = XComModel.deserialize_value(actual_xcom_value) assert actual_xcom_value is None @pytest.mark.parametrize( diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index cb0afc2f8cb6f..8acde585817b3 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -62,7 +62,7 @@ from airflow.models.dagbag import DagBag from airflow.models.expandinput import EXPAND_INPUT_EMPTY from airflow.models.mappedoperator import MappedOperator -from airflow.models.xcom import XCom +from airflow.models.xcom import XComModel from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator from airflow.providers.standard.operators.bash import BashOperator from airflow.providers.standard.sensors.bash import BashSensor @@ -1146,7 +1146,7 @@ def test_extra_serialized_field_and_operator_links( dr = dag_maker.create_dagrun(logical_date=test_date) (ti,) = dr.task_instances - XCom.set( + XComModel.set( key="search_query", value=bash_command, task_id=simple_task.task_id, @@ -1158,7 +1158,7 @@ def test_extra_serialized_field_and_operator_links( # Test Deserialized inbuilt link for name, expected in links.items(): # staging the part where a task at runtime pushes xcom for extra links - XCom.set( + XComModel.set( key=simple_task.operator_extra_links[c].xcom_key, value=expected, task_id=simple_task.task_id,