Merged
86 commits
e76e142
AIP-72: Handle Custom XCom Backend on Task SDK
amoghrajesh Mar 4, 2025
6b8424d
adding unit tests
amoghrajesh Mar 4, 2025
5cdcf50
routine push
amoghrajesh Mar 5, 2025
504c76f
Merge branch 'main' into AIP-72-custom-xcom-backend
amoghrajesh Mar 7, 2025
72c95ed
adding datamodel for xcom now
amoghrajesh Mar 7, 2025
37b0b1c
introducing delete xcom API
amoghrajesh Mar 8, 2025
dd7317a
tests for the delete API
amoghrajesh Mar 8, 2025
2e6979d
do not clear xcoms in /run endpoint
amoghrajesh Mar 8, 2025
35de648
send requests to clear xcoms while marking a task as running
amoghrajesh Mar 8, 2025
c0b405c
Merge branch 'main' into AIP-72-custom-xcom-backend
amoghrajesh Mar 10, 2025
a11f3d0
Use the API to get xcoms in baseoperatorlink
amoghrajesh Mar 10, 2025
adfb229
updating taskinstance model to use the right model
amoghrajesh Mar 10, 2025
2b5398c
use XcomModel now in core api
amoghrajesh Mar 10, 2025
dda7e46
changes to models/xcom, task runner and execution time xcom introduced
amoghrajesh Mar 10, 2025
ad6304f
changes to supervisor and comms
amoghrajesh Mar 10, 2025
7cd0ba5
fixing models in xcom.py
amoghrajesh Mar 10, 2025
06b5ba8
fixing the object store backend
amoghrajesh Mar 10, 2025
1ce5b20
lazy loading for execution time xcoms
amoghrajesh Mar 10, 2025
b49311e
an attempt to fix the tests and mypy
amoghrajesh Mar 10, 2025
04b4f05
adding back xcom_return_key
amoghrajesh Mar 10, 2025
574d470
task sdk mypy
amoghrajesh Mar 10, 2025
13a6983
mypy providers
amoghrajesh Mar 10, 2025
78892f3
fixing tests for TestXComObjectStorageBackend
amoghrajesh Mar 11, 2025
a1ded63
fixing mypy
amoghrajesh Mar 11, 2025
dc82948
as xcom
amoghrajesh Mar 11, 2025
c4ba8af
remove unwanted log and add docstring
amoghrajesh Mar 11, 2025
97e4c3f
changing few more older occurences
amoghrajesh Mar 11, 2025
af42da7
fixing mypy for providers
amoghrajesh Mar 11, 2025
ef2ffe9
more occurences of Xcom
amoghrajesh Mar 11, 2025
d2af742
fixing compat tests
amoghrajesh Mar 11, 2025
b7e2f7f
fixing devel-common
amoghrajesh Mar 11, 2025
7f85d95
fixing task instance tests
amoghrajesh Mar 12, 2025
bcaff13
attemtping to fix tests
amoghrajesh Mar 12, 2025
14465ad
fixing another test
amoghrajesh Mar 12, 2025
4b18308
review from ash
amoghrajesh Mar 12, 2025
4eebc45
fixing checks for tests
amoghrajesh Mar 12, 2025
4b1a2b3
Merge branch 'main' into AIP-72-custom-xcom-backend
amoghrajesh Mar 12, 2025
bcc297b
use xcommodel
amoghrajesh Mar 12, 2025
e19b14d
Merge branch 'main' into AIP-72-custom-xcom-backend
amoghrajesh Mar 12, 2025
3ba63f5
fixing some tests
amoghrajesh Mar 12, 2025
99d77c5
fixing xcom core tests
amoghrajesh Mar 13, 2025
f757168
fixing triggered dag run tests
amoghrajesh Mar 13, 2025
86b6fbe
fixing google tests
amoghrajesh Mar 13, 2025
7231969
fixing sql tests
amoghrajesh Mar 13, 2025
ec1a9f5
fixing amazon tests
amoghrajesh Mar 13, 2025
f2569d0
Merge branch 'main' into AIP-72-custom-xcom-backend
amoghrajesh Mar 13, 2025
70ab429
fixing tests cos of rebase
amoghrajesh Mar 13, 2025
78cb1c7
fixing standard operator tests
amoghrajesh Mar 13, 2025
c309fb9
brining back LazySelectSequence for tests
amoghrajesh Mar 13, 2025
dd3d280
fixing dbt links and cncf tests
amoghrajesh Mar 13, 2025
a2348bd
Merge branch 'main' into AIP-72-custom-xcom-backend
amoghrajesh Mar 13, 2025
6d42c7c
fixing core tests
amoghrajesh Mar 13, 2025
0f1ee28
fixing microsoft tests
amoghrajesh Mar 13, 2025
13bc4a4
removing wrong import
amoghrajesh Mar 13, 2025
9ce3063
Merge branch 'main' into AIP-72-custom-xcom-backend
amoghrajesh Mar 13, 2025
8c6fd85
fixing backend tests
amoghrajesh Mar 14, 2025
3e38920
tryna fix core and amazon compat tests
amoghrajesh Mar 14, 2025
5f06413
removing residue imports
amoghrajesh Mar 14, 2025
78bf7db
adding backcompat to fixture
amoghrajesh Mar 14, 2025
9ccc3f8
adding backcompat to fixture and tests
amoghrajesh Mar 15, 2025
40e82d0
Merge branch 'main' into AIP-72-custom-xcom-backend
amoghrajesh Mar 16, 2025
dbc3d62
Merge branch 'main' into AIP-72-custom-xcom-backend
amoghrajesh Mar 16, 2025
b4307f5
yeild not return
amoghrajesh Mar 16, 2025
a987cc0
drop this commit
amoghrajesh Mar 16, 2025
e72e27b
adding a return to fixture
amoghrajesh Mar 16, 2025
6231668
removcing unwanted stuff
amoghrajesh Mar 16, 2025
24724c3
fix faulty rebase
amoghrajesh Mar 16, 2025
ed6eb59
might need to drop this commit as well
amoghrajesh Mar 16, 2025
e5788f0
compat tests..
amoghrajesh Mar 16, 2025
a3d6896
not a big fan but fixing test_backend
amoghrajesh Mar 17, 2025
a3f6602
review comments
amoghrajesh Mar 17, 2025
be704d8
review comments from ash and kaxil
amoghrajesh Mar 17, 2025
c07c981
review comments from ash and kaxil
amoghrajesh Mar 17, 2025
c257369
Merge branch 'main' into AIP-72-custom-xcom-backend
amoghrajesh Mar 17, 2025
b8a9d7c
fixing import in backend
amoghrajesh Mar 17, 2025
826c4d8
fix compat test
amoghrajesh Mar 17, 2025
2d03273
fix compat test -- last
amoghrajesh Mar 17, 2025
479f647
fix compat test -- last
amoghrajesh Mar 17, 2025
b2a934e
ash comments
amoghrajesh Mar 17, 2025
e42a929
ash comments about imports
amoghrajesh Mar 17, 2025
f4347ac
fixing core test
amoghrajesh Mar 17, 2025
dc26349
this should fix the providers tests
amoghrajesh Mar 17, 2025
6296659
another try compat tests
amoghrajesh Mar 17, 2025
7291ec8
another try compat tests
amoghrajesh Mar 17, 2025
64b5968
fixing tests
amoghrajesh Mar 17, 2025
9540fbb
empty session purge
amoghrajesh Mar 17, 2025
2 changes: 1 addition & 1 deletion airflow/api_fastapi/core_api/openapi/v1-generated.yaml
@@ -4641,7 +4641,7 @@ paths:
required: false
schema:
type: boolean
default: true
default: false
title: Stringify
responses:
'200':
96 changes: 57 additions & 39 deletions airflow/api_fastapi/core_api/routes/public/xcom.py
@@ -37,7 +37,8 @@
from airflow.api_fastapi.core_api.security import ReadableXComFilterDep, requires_access_dag
from airflow.api_fastapi.logging.decorators import action_logging
from airflow.exceptions import TaskNotFound
from airflow.models import DAG, DagRun as DR, XCom
from airflow.models import DAG, DagRun as DR
from airflow.models.xcom import XComModel
from airflow.settings import conf

xcom_router = AirflowRouter(
@@ -63,22 +64,25 @@ def get_xcom_entry(
session: SessionDep,
map_index: Annotated[int, Query(ge=-1)] = -1,
deserialize: Annotated[bool, Query()] = False,
stringify: Annotated[bool, Query()] = True,
stringify: Annotated[bool, Query()] = False,
) -> XComResponseNative | XComResponseString:
"""Get an XCom entry."""
if deserialize:
if not conf.getboolean("api", "enable_xcom_deserialize_support", fallback=False):
raise HTTPException(
status.HTTP_400_BAD_REQUEST, "XCom deserialization is disabled in configuration."
)
query = select(XCom, XCom.value)
query = select(XComModel, XComModel.value)
else:
query = select(XCom)
query = select(XComModel)

query = query.where(
XCom.dag_id == dag_id, XCom.task_id == task_id, XCom.key == xcom_key, XCom.map_index == map_index
XComModel.dag_id == dag_id,
XComModel.task_id == task_id,
XComModel.key == xcom_key,
XComModel.map_index == map_index,
)
query = query.join(DR, and_(XCom.dag_id == DR.dag_id, XCom.run_id == DR.run_id))
query = query.join(DR, and_(XComModel.dag_id == DR.dag_id, XComModel.run_id == DR.run_id))
query = query.where(DR.run_id == dag_run_id)

if deserialize:
@@ -90,6 +94,8 @@ def get_xcom_entry(
raise HTTPException(status.HTTP_404_NOT_FOUND, f"XCom entry with key: `{xcom_key}` not found")

if deserialize:
from airflow.sdk.execution_time.xcom import XCom

xcom, value = item
xcom_stub = copy.copy(xcom)
xcom_stub.value = value
@@ -127,19 +133,19 @@ def get_xcom_entries(

This endpoint allows specifying `~` as the dag_id, dag_run_id, task_id to retrieve XCom entries for all DAGs.
"""
query = select(XCom)
query = select(XComModel)
if dag_id != "~":
query = query.where(XCom.dag_id == dag_id)
query = query.join(DR, and_(XCom.dag_id == DR.dag_id, XCom.run_id == DR.run_id))
query = query.where(XComModel.dag_id == dag_id)
query = query.join(DR, and_(XComModel.dag_id == DR.dag_id, XComModel.run_id == DR.run_id))

if task_id != "~":
query = query.where(XCom.task_id == task_id)
query = query.where(XComModel.task_id == task_id)
if dag_run_id != "~":
query = query.where(DR.run_id == dag_run_id)
if map_index is not None:
query = query.where(XCom.map_index == map_index)
query = query.where(XComModel.map_index == map_index)
if xcom_key is not None:
query = query.where(XCom.key == xcom_key)
query = query.where(XComModel.key == xcom_key)

query, total_entries = paginated_select(
statement=query,
@@ -148,7 +154,9 @@ def get_xcom_entries(
limit=limit,
session=session,
)
query = query.order_by(XCom.dag_id, XCom.task_id, XCom.run_id, XCom.map_index, XCom.key)
query = query.order_by(
XComModel.dag_id, XComModel.task_id, XComModel.run_id, XComModel.map_index, XComModel.key
)
xcoms = session.scalars(query)
return XComCollectionResponse(xcom_entries=xcoms, total_entries=total_entries)

@@ -197,38 +205,48 @@ def create_xcom_entry(
)

# Check existing XCom
if XCom.get_one(
already_existing_query = XComModel.get_many(
key=request_body.key,
task_id=task_id,
dag_id=dag_id,
task_ids=task_id,
dag_ids=dag_id,
run_id=dag_run_id,
map_index=request_body.map_index,
map_indexes=request_body.map_index,
session=session,
):
)
result = already_existing_query.with_entities(XComModel.value).first()
if result:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"The XCom with key: `{request_body.key}` with mentioned task instance already exists.",
)

# Create XCom entry
XCom.set(
dag_id=dag_id,
task_id=task_id,
run_id=dag_run_id,
try:
value = XComModel.serialize_value(request_body.value)
except (ValueError, TypeError):
raise HTTPException(
status.HTTP_400_BAD_REQUEST, f"Couldn't serialise the XCom with key: `{request_body.key}`"
)

new = XComModel(
dag_run_id=dag_run.id,
key=request_body.key,
value=XCom.serialize_value(request_body.value),
value=value,
run_id=dag_run_id,
task_id=task_id,
dag_id=dag_id,
map_index=request_body.map_index,
session=session,
)
session.add(new)
session.flush()

xcom = session.scalar(
select(XCom)
select(XComModel)
.filter(
XCom.dag_id == dag_id,
XCom.task_id == task_id,
XCom.run_id == dag_run_id,
XCom.key == request_body.key,
XCom.map_index == request_body.map_index,
XComModel.dag_id == dag_id,
XComModel.task_id == task_id,
XComModel.run_id == dag_run_id,
XComModel.key == request_body.key,
XComModel.map_index == request_body.map_index,
)
.limit(1)
)
@@ -260,15 +278,15 @@ def update_xcom_entry(
) -> XComResponseNative:
"""Update an existing XCom entry."""
# Check if XCom entry exists
xcom_new_value = XCom.serialize_value(patch_body.value)
xcom_new_value = XComModel.serialize_value(patch_body.value)
xcom_entry = session.scalar(
select(XCom)
select(XComModel)
.where(
XCom.dag_id == dag_id,
XCom.task_id == task_id,
XCom.run_id == dag_run_id,
XCom.key == xcom_key,
XCom.map_index == patch_body.map_index,
XComModel.dag_id == dag_id,
XComModel.task_id == task_id,
XComModel.run_id == dag_run_id,
XComModel.key == xcom_key,
XComModel.map_index == patch_body.map_index,
)
.limit(1)
)
@@ -280,6 +298,6 @@ def update_xcom_entry(
)

# Update XCom entry
xcom_entry.value = XCom.serialize_value(xcom_new_value)
xcom_entry.value = XComModel.serialize_value(xcom_new_value)

return XComResponseNative.model_validate(xcom_entry)
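
The `create_xcom_entry` hunk above follows a check, serialise, insert sequence: an existing entry yields a 409, a value that cannot be serialised yields a 400, and only then is a row added. The sketch below illustrates that ordering with the standard library only. The in-memory `store`, the two exception classes, and the use of `json.dumps` are assumptions made for illustration; they are not the Airflow implementation.

```python
import json


class XComConflict(Exception):
    """Hypothetical stand-in for an HTTP 409 response."""


class XComBadValue(Exception):
    """Hypothetical stand-in for an HTTP 400 response."""


# Hypothetical in-memory table keyed by the full XCom identity.
store: dict[tuple[str, str, str, int, str], str] = {}


def create_xcom(dag_id, run_id, task_id, key, value, map_index=-1):
    """Insert a new entry, refusing duplicates and unserialisable values."""
    identity = (dag_id, run_id, task_id, map_index, key)

    # Step 1: reject the request if an entry with this identity already exists.
    if identity in store:
        raise XComConflict(f"XCom with key `{key}` already exists")

    # Step 2: serialise before touching the store, so a bad value leaves no row.
    try:
        serialized = json.dumps(value)
    except (TypeError, ValueError) as exc:
        raise XComBadValue(f"Couldn't serialise the XCom with key `{key}`") from exc

    # Step 3: insert the new row.
    store[identity] = serialized
    return serialized
```

Serialising before inserting means a failed request never leaves a partially written entry behind.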
4 changes: 2 additions & 2 deletions airflow/api_fastapi/core_api/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from airflow.models.dag import DagModel, DagRun, DagTag
from airflow.models.dagwarning import DagWarning
from airflow.models.taskinstance import TaskInstance as TI
from airflow.models.xcom import XCom
from airflow.models.xcom import XComModel

if TYPE_CHECKING:
from sqlalchemy.sql import Select
@@ -132,7 +132,7 @@ class PermittedXComFilter(PermittedDagFilter):
"""A parameter that filters the permitted XComs for the user."""

def to_orm(self, select: Select) -> Select:
return select.where(XCom.dag_id.in_(self.value))
return select.where(XComModel.dag_id.in_(self.value))


class PermittedTagFilter(PermittedDagFilter):
3 changes: 3 additions & 0 deletions airflow/api_fastapi/execution_api/datamodels/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,9 @@ class TIRunContext(BaseModel):
Can either be a "decorated" dict, or a string encrypted with the shared Fernet key.
"""

xcom_keys_to_clear: Annotated[list[str], Field(default_factory=list)]
"""List of XCom keys that need to be cleared and purged by the worker."""


class PrevSuccessfulDagRunResponse(BaseModel):
"""Schema for response with previous successful DagRun information for Task Template Context."""
23 changes: 14 additions & 9 deletions airflow/api_fastapi/execution_api/routes/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from airflow.models.taskinstance import TaskInstance as TI, _update_rtif
from airflow.models.taskreschedule import TaskReschedule
from airflow.models.trigger import Trigger
from airflow.models.xcom import XCom
from airflow.models.xcom import XComModel
from airflow.utils import timezone
from airflow.utils.state import DagRunState, TaskInstanceState, TerminalTIState

@@ -187,18 +187,22 @@ def ti_run(
if not dr:
raise ValueError(f"DagRun with dag_id={ti.dag_id} and run_id={ti.run_id} not found.")

# Clear XCom data for the task instance since we are certain it is executing
# Send the keys to the SDK so that the client requests to clear those XComs from the server.
# The reason we cannot do this here in the server is because we need to issue a purge on custom XCom backends
# too. With the current assumption, the workers ONLY have access to the custom XCom backends directly and they
# can issue the purge.

# However, do not clear it for deferral
xcom_keys = []
if not ti.next_method:
map_index = None if ti.map_index < 0 else ti.map_index
log.info("Clearing xcom data for task id: %s", ti_id_str)
XCom.clear(
dag_id=ti.dag_id,
task_id=ti.task_id,
run_id=ti.run_id,
map_index=map_index,
session=session,
query = session.query(XComModel.key).filter_by(
dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id
)
if map_index is not None:
query = query.filter_by(map_index=map_index)

xcom_keys = [row.key for row in session.execute(query).all()]

task_reschedule_count = (
session.query(
@@ -216,6 +220,7 @@ def ti_run(
# TODO: Add variables and connections that are needed (and has perms) for the task
variables=[],
connections=[],
xcom_keys_to_clear=xcom_keys,
)

# Only set if they are non-null
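
The `ti_run` change above stops clearing XComs on the server and instead returns the matching keys in `xcom_keys_to_clear`, so the worker can request the deletes and purge a custom backend itself. A minimal sketch of that key selection follows, using plain dataclasses instead of SQLAlchemy and Pydantic. `XComRow`, `RunContext`, and `collect_keys_to_clear` are hypothetical names introduced here for illustration.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass(frozen=True)
class XComRow:
    """Hypothetical stand-in for one row of the XCom table."""

    dag_id: str
    run_id: str
    task_id: str
    map_index: int
    key: str


@dataclass
class RunContext:
    """Hypothetical stand-in for the run context returned to the worker."""

    # Defaults to an empty list, mirroring Field(default_factory=list) in the diff.
    xcom_keys_to_clear: list[str] = field(default_factory=list)


def collect_keys_to_clear(
    rows, dag_id, run_id, task_id, map_index, next_method: Optional[str] = None
) -> list[str]:
    """Return the keys the worker should clear before the task starts."""
    # A task resuming from deferral (next_method set) keeps its XComs.
    if next_method:
        return []

    # A negative map_index means "not mapped": do not filter on map_index at all.
    wanted_index = None if map_index < 0 else map_index
    return [
        row.key
        for row in rows
        if (row.dag_id, row.run_id, row.task_id) == (dag_id, run_id, task_id)
        and (wanted_index is None or row.map_index == wanted_index)
    ]
```

The server only reports which keys exist; deleting them is left to the worker, which is the side assumed to have direct access to the custom backend.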
66 changes: 58 additions & 8 deletions airflow/api_fastapi/execution_api/routes/xcoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from fastapi import Body, Depends, HTTPException, Query, Response, status
from pydantic import JsonValue
from sqlalchemy import delete
from sqlalchemy.sql.selectable import Select

from airflow.api_fastapi.common.db.common import SessionDep
@@ -30,7 +31,7 @@
from airflow.api_fastapi.execution_api.datamodels.token import TIToken
from airflow.api_fastapi.execution_api.datamodels.xcom import XComResponse
from airflow.models.taskmap import TaskMap
from airflow.models.xcom import BaseXCom
from airflow.models.xcom import XComModel
from airflow.utils.db import get_query_count

# TODO: Add dependency on JWT token
@@ -62,7 +63,7 @@ async def xcom_query(
},
)

query = BaseXCom.get_many(
query = XComModel.get_many(
run_id=run_id,
key=key,
task_ids=task_id,
@@ -126,7 +127,7 @@ def get_xcom(
"""Get an Airflow XCom from database - not other XCom Backends."""
# The xcom_query allows no map_index to be passed. This endpoint should always return just a single item,
# so we override that query value
xcom_query = xcom_query.filter(BaseXCom.map_index == map_index)
xcom_query = xcom_query.filter(XComModel.map_index == map_index)
# We use `BaseXCom.get_many` to fetch XComs directly from the database, bypassing the XCom Backend.
# This avoids deserialization via the backend (e.g., from a remote storage like S3) and instead
# retrieves the raw serialized value from the database. By not relying on `XCom.get_many` or `XCom.get_one`
@@ -222,18 +223,39 @@ def set_xcom(
# TODO: Can/should we check if a client _hasn't_ provided this for an upstream of a mapped task? That
# means loading the serialized dag and that seems like a relatively costly operation for minimal benefit
# (the mapped task would fail in a moment as it can't be expanded anyway.)
from airflow.models.dagrun import DagRun

if not run_id:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Run with ID: `{run_id}` was not found")

dag_run_id = session.query(DagRun.id).filter_by(dag_id=dag_id, run_id=run_id).scalar()
if dag_run_id is None:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"DAG run not found on DAG {dag_id} with ID {run_id}")

# Remove duplicate XComs and insert a new one.
session.execute(
delete(XComModel).where(
XComModel.key == key,
XComModel.run_id == run_id,
XComModel.task_id == task_id,
XComModel.dag_id == dag_id,
XComModel.map_index == map_index,
)
)

# We use `BaseXCom.set` to set XComs directly to the database, bypassing the XCom Backend.
try:
BaseXCom.set(
# The caller (the SDK) is expected to send an already-serialised value; do not serialise it again here.
new = XComModel(
dag_run_id=dag_run_id,
key=key,
value=value,
dag_id=dag_id,
task_id=task_id,
run_id=run_id,
session=session,
task_id=task_id,
dag_id=dag_id,
map_index=map_index,
)
session.add(new)
session.flush()
except TypeError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@@ -246,6 +268,34 @@ def set_xcom(
return {"message": "XCom successfully set"}
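
The `set_xcom` hunk above removes any row with the same identity and then inserts the new one, so setting a key twice leaves a single row. The sketch below shows that delete-then-insert pattern against an in-memory SQLite table. The table layout and the `make_db` and `set_xcom` helpers are assumptions for illustration only, not the Airflow schema or API.

```python
import sqlite3


def make_db() -> sqlite3.Connection:
    """Create a hypothetical, simplified xcom table in memory."""
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE xcom ("
        "dag_id TEXT, run_id TEXT, task_id TEXT, map_index INTEGER, key TEXT, value TEXT)"
    )
    return conn


def set_xcom(conn, dag_id, run_id, task_id, map_index, key, value) -> None:
    """Replace any existing entry with the same identity, then insert the new value."""
    identity = (dag_id, run_id, task_id, map_index, key)
    # `with conn` commits both statements together, or rolls both back on error.
    with conn:
        conn.execute(
            "DELETE FROM xcom WHERE dag_id = ? AND run_id = ? AND task_id = ? "
            "AND map_index = ? AND key = ?",
            identity,
        )
        # The value is stored as received; the caller is assumed to have serialised it.
        conn.execute("INSERT INTO xcom VALUES (?, ?, ?, ?, ?, ?)", (*identity, value))
```

Running both statements in one transaction avoids a window in which the old value is gone and the new one is not yet written.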


@router.delete(
"/{dag_id}/{run_id}/{task_id}/{key}",
responses={status.HTTP_404_NOT_FOUND: {"description": "XCom not found"}},
description="Delete a single XCom Value",
)
def delete_xcom(
session: SessionDep,
token: deps.TokenDep,
dag_id: str,
run_id: str,
task_id: str,
key: str,
):
if not has_xcom_access(dag_id, run_id, task_id, key, token, write=True):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={
"reason": "access_denied",
"message": f"Task does not have access to delete XCom with key '{key}'",
},
)

query = session.query(XComModel).where(XComModel.key == key).first()
session.delete(query)
session.commit()
return {"message": f"XCom with key: {key} successfully deleted."}
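
In the `delete_xcom` handler as shown, the lookup filters on `key` only, and the route declares a 404 response that the body never raises. The following sketch shows one way a delete could be scoped to the whole `dag_id`, `run_id`, `task_id`, `key` identity and report a missing row. It is an assumption-laden illustration on SQLite, not the code in this diff; `make_db`, `delete_xcom`, and the use of `LookupError` as a 404 stand-in are hypothetical.

```python
import sqlite3


def make_db() -> sqlite3.Connection:
    """Create a hypothetical, simplified xcom table in memory."""
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE xcom (dag_id TEXT, run_id TEXT, task_id TEXT, key TEXT, value TEXT)"
    )
    return conn


def delete_xcom(conn, dag_id, run_id, task_id, key) -> int:
    """Delete the entries for one task instance and key; raise if none matched."""
    with conn:
        cursor = conn.execute(
            "DELETE FROM xcom WHERE dag_id = ? AND run_id = ? AND task_id = ? AND key = ?",
            (dag_id, run_id, task_id, key),
        )
    # LookupError is a hypothetical stand-in for an HTTP 404 response.
    if cursor.rowcount == 0:
        raise LookupError(f"XCom with key: {key} not found")
    return cursor.rowcount
```

Filtering on the full identity keeps a delete for one task from touching a same-named key that belongs to another task or run.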


def has_xcom_access(
dag_id: str, run_id: str, task_id: str, xcom_key: str, token: TIToken, write: bool = False
) -> bool:
2 changes: 1 addition & 1 deletion airflow/config_templates/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ core:
version_added: 1.10.12
type: string
example: "path.to.CustomXCom"
default: "airflow.models.xcom.BaseXCom"
default: "airflow.sdk.execution_time.xcom.BaseXCom"
lazy_load_plugins:
description: |
By default Airflow plugins are lazily-loaded (only loaded when required). Set it to ``False``,
Expand Down