Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5403,9 +5403,7 @@ paths:
description: Successful Response
content:
application/json:
schema:
type: 'null'
title: Response Delete Task Instance
schema: {}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure to understand why this was updated.

'401':
content:
application/json:
Expand Down Expand Up @@ -7743,9 +7741,7 @@ paths:
description: Successful Response
content:
application/json:
schema:
type: 'null'
title: Response Reparse Dag File
schema: {}
'401':
content:
application/json:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import copy
from typing import Annotated
from urllib.parse import unquote

from fastapi import Depends, HTTPException, Query, status
from sqlalchemy import and_, select
Expand Down Expand Up @@ -80,6 +81,7 @@ def get_xcom_entry(
stringify: Annotated[bool, Query()] = False,
) -> XComResponseNative | XComResponseString:
"""Get an XCom entry."""
xcom_key = unquote(xcom_key)
xcom_query = XComModel.get_many(
run_id=dag_run_id,
key=xcom_key,
Expand Down Expand Up @@ -156,6 +158,7 @@ def get_xcom_entries(

This endpoint allows specifying `~` as the dag_id, dag_run_id, task_id to retrieve XCom entries for all DAGs.
"""
xcom_key = unquote(xcom_key) if xcom_key else None
query = select(XComModel)
if dag_id != "~":
query = query.where(XComModel.dag_id == dag_id)
Expand Down Expand Up @@ -242,6 +245,7 @@ def create_xcom_entry(
)

# Check existing XCom
request_body.key = unquote(request_body.key)
already_existing_query = XComModel.get_many(
key=request_body.key,
task_ids=task_id,
Expand Down Expand Up @@ -315,6 +319,7 @@ def update_xcom_entry(
) -> XComResponseNative:
"""Update an existing XCom entry."""
# Check if XCom entry exists
xcom_key = unquote(xcom_key)
xcom_new_value = XComModel.serialize_value(patch_body.value)
xcom_entry = session.scalar(
select(XComModel)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import logging
from typing import Annotated
from urllib.parse import unquote

from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request, Response, status
from pydantic import BaseModel, JsonValue, StringConstraints
Expand Down Expand Up @@ -78,6 +79,7 @@ async def xcom_query(
key: str,
map_index: Annotated[int | None, Query()] = None,
) -> Select:
key = unquote(key)
query = XComModel.get_many(
run_id=run_id,
key=key,
Expand Down Expand Up @@ -143,6 +145,7 @@ def get_xcom(
params: Annotated[GetXcomFilterParams, Query()],
) -> XComResponse:
"""Get an Airflow XCom from database - not other XCom Backends."""
key = unquote(key)
xcom_query = XComModel.get_many(
run_id=run_id,
key=key,
Expand Down Expand Up @@ -196,6 +199,7 @@ def get_mapped_xcom_by_index(
offset: int,
session: SessionDep,
) -> XComSequenceIndexResponse:
key = unquote(key)
xcom_query = XComModel.get_many(
run_id=run_id,
key=key,
Expand Down Expand Up @@ -240,6 +244,7 @@ def get_mapped_xcom_by_slice(
params: Annotated[GetXComSliceFilterParams, Query()],
session: SessionDep,
) -> XComSequenceSliceResponse:
key = unquote(key)
query = XComModel.get_many(
run_id=run_id,
key=key,
Expand Down Expand Up @@ -360,7 +365,7 @@ def set_xcom(
"message": "XCom key must be a non-empty string.",
},
)

key = unquote(key)
if mapped_length is not None:
task_map = TaskMap(
dag_id=dag_id,
Expand Down Expand Up @@ -444,6 +449,7 @@ def delete_xcom(
map_index: Annotated[int, Query()] = -1,
):
"""Delete a single XCom Value."""
key = unquote(key)
query = delete(XComModel).where(
XComModel.key == key,
XComModel.run_id == run_id,
Expand Down
7 changes: 7 additions & 0 deletions task-sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from functools import cache
from http import HTTPStatus
from typing import TYPE_CHECKING, Any, TypeVar
from urllib.parse import quote

import certifi
import httpx
Expand Down Expand Up @@ -418,6 +419,7 @@ def __init__(self, client: Client):

def head(self, dag_id: str, run_id: str, task_id: str, key: str) -> XComCountResponse:
"""Get the number of mapped XCom values."""
key = quote(key, safe="")
resp = self.client.head(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}")

# content_range: str | None
Expand All @@ -444,6 +446,7 @@ def get(
params.update({"map_index": map_index})
if include_prior_dates:
params.update({"include_prior_dates": include_prior_dates})
key = quote(key, safe="")
try:
resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params)
except ServerResponseError as e:
Expand Down Expand Up @@ -483,6 +486,7 @@ def set(
params = {"map_index": map_index}
if mapped_length is not None and mapped_length >= 0:
params["mapped_length"] = mapped_length
key = quote(key, safe="")
self.client.post(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params, json=value)
# Any error from the server will anyway be propagated down to the supervisor,
# so we choose to send a generic response to the supervisor over the server response to
Expand All @@ -501,6 +505,7 @@ def delete(
params = {}
if map_index is not None and map_index >= 0:
params = {"map_index": map_index}
key = quote(key, safe="")
self.client.delete(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params)
# Any error from the server will anyway be propagated down to the supervisor,
# so we choose to send a generic response to the supervisor over the server response to
Expand All @@ -515,6 +520,7 @@ def get_sequence_item(
key: str,
offset: int,
) -> XComSequenceIndexResponse | ErrorResponse:
key = quote(key, safe="")
try:
resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}/item/{offset}")
except ServerResponseError as e:
Expand Down Expand Up @@ -553,6 +559,7 @@ def get_sequence_slice(
step: int | None,
include_prior_dates: bool = False,
) -> XComSequenceSliceResponse:
key = quote(key, safe="")
params = {}
if start is not None:
params["start"] = start
Expand Down
5 changes: 4 additions & 1 deletion task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ def xcom_pull(
a non-str iterable), a list of matching XComs is returned. Elements in
the list is ordered by item ordering in ``task_id`` and ``map_index``.
"""
key = quote(key, safe="")
if dag_id is None:
dag_id = self.dag_id
if run_id is None:
Expand Down Expand Up @@ -1363,8 +1364,10 @@ def _push_xcom_if_needed(result: Any, ti: RuntimeTaskInstance, log: Logger):
"Returned dictionary keys must be strings when using "
f"multiple_outputs, found {key} ({type(key)}) instead"
)

for k, v in result.items():
ti.xcom_push(k, v)
encoded_key = quote(k, safe="")
ti.xcom_push(encoded_key, v)

_xcom_push(ti, BaseXCom.XCOM_RETURN_KEY, result, mapped_length=mapped_length)

Expand Down
80 changes: 80 additions & 0 deletions task-sdk/tests/task_sdk/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from typing import TYPE_CHECKING
from unittest import mock
from unittest.mock import patch
from urllib.parse import quote

import pandas as pd
import pytest
Expand Down Expand Up @@ -1986,6 +1987,85 @@ def test_xcom_clearing_without_keys_to_clear(self, create_runtime_ti, mock_super

mock_delete.assert_not_called()

def test_xcom_push_pull_with_slash_in_key(self, create_runtime_ti, mock_supervisor_comms):
"""
Ensure that XCom keys containing slashes are correctly quoted/unquoted
and do not break API routes (no 400/404).
"""

class PushOperator(BaseOperator):
def execute(self, context):
context["ti"].xcom_push(key="some/key/with/slash", value="slash_value")

task = PushOperator(task_id="push_task")
runtime_ti = create_runtime_ti(task=task, dag_id="test_dag")

# Run the task (which should trigger xcom_push)
run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())

# Verify supervisor received a SetXCom with quoted key
called_args = [
call.kwargs.get("msg") or call.args[0] for call in mock_supervisor_comms.send.call_args_list
]
assert any(getattr(arg, "key", None) == "some/key/with/slash" for arg in called_args)

ser_value = BaseXCom.serialize_value("slash_value")
mock_supervisor_comms.send.reset_mock()
mock_supervisor_comms.send.return_value = XComSequenceSliceResult(
key="some/key/with/slash",
root=[ser_value],
)

pulled_value = runtime_ti.xcom_pull(key="some/key/with/slash", task_ids="push_task")
assert pulled_value == "slash_value"

expected_key = quote("some/key/with/slash", safe="")
mock_supervisor_comms.send.assert_any_call(
GetXComSequenceSlice(
key=expected_key,
dag_id="test_dag",
run_id="test_run",
task_id="push_task",
map_index=0,
include_prior_dates=False,
start=None,
stop=None,
step=None,
type="GetXComSequenceSlice",
)
)

def test_taskflow_dict_return_with_slash_key(self, create_runtime_ti, mock_supervisor_comms):
"""
High-level: Ensure TaskFlow returning dict with slash in key doesn't 404 during XCom push.
"""

@dag_decorator(schedule=None, start_date=timezone.datetime(2024, 12, 3))
def dag_with_slash_key():
@task_decorator
def dict_task():
return {"key with slash /": "Some Value"}

return dict_task() # returns XComArg

dag_obj = dag_with_slash_key()
task_op = dag_obj.get_task("dict_task")
runtime_ti = create_runtime_ti(task=task_op, dag_id=dag_obj.dag_id)

# Run task instance → should trigger TaskFlow dict expansion + XCom push
run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())

# Mock supervisor response to simulate retrieval
ser_value = BaseXCom.serialize_value("Some Value")
mock_supervisor_comms.send.reset_mock()
mock_supervisor_comms.send.return_value = XComSequenceSliceResult(
key="key/slash",
root=[ser_value],
)

pulled = runtime_ti.xcom_pull(key="key/slash", task_ids="dict_task")
assert pulled == "Some Value"


class TestXComAfterTaskExecution:
@pytest.mark.parametrize(
Expand Down
Loading