diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/xcom.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/xcom.py index 59f314b6540a2..c05ae6246a8d9 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/xcom.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/xcom.py @@ -87,7 +87,7 @@ def get_xcom_entry( dag_ids=dag_id, map_indexes=map_index, limit=1, - ) + ).options(joinedload(XComModel.task), joinedload(XComModel.dag_run).joinedload(DR.dag_model)) # We use `BaseXCom.get_many` to fetch XComs directly from the database, bypassing the XCom Backend. # This avoids deserialization via the backend (e.g., from a remote storage like S3) and instead @@ -162,7 +162,7 @@ def get_xcom_entries( query = ( query.join(DR, and_(XComModel.dag_id == DR.dag_id, XComModel.run_id == DR.run_id)) .join(DagModel, DR.dag_id == DagModel.dag_id) - .options(joinedload(XComModel.dag_run).joinedload(DR.dag_model)) + .options(joinedload(XComModel.task), joinedload(XComModel.dag_run).joinedload(DR.dag_model)) ) if task_id != "~": @@ -284,7 +284,7 @@ def create_xcom_entry( XComModel.map_index == request_body.map_index, ) .limit(1) - .options(joinedload(XComModel.dag_run).joinedload(DR.dag_model)) + .options(joinedload(XComModel.task), joinedload(XComModel.dag_run).joinedload(DR.dag_model)) ) return XComResponseNative.model_validate(xcom) @@ -325,7 +325,7 @@ def update_xcom_entry( XComModel.map_index == patch_body.map_index, ) .limit(1) - .options(joinedload(XComModel.dag_run).joinedload(DR.dag_model)) + .options(joinedload(XComModel.task), joinedload(XComModel.dag_run).joinedload(DR.dag_model)) ) if not xcom_entry: diff --git a/airflow-core/src/airflow/models/xcom.py b/airflow-core/src/airflow/models/xcom.py index b17c7a724482b..a0f3d3501beae 100644 --- a/airflow-core/src/airflow/models/xcom.py +++ b/airflow-core/src/airflow/models/xcom.py @@ -109,7 +109,7 @@ class XComModel(TaskInstanceDependencies): task = relationship( "TaskInstance", viewonly=True, - lazy="selectin", + lazy="noload", ) @classmethod diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_xcom.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_xcom.py index cd4f497005a84..13eeea5a17f4b 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_xcom.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_xcom.py @@ -35,6 +35,7 @@ from airflow.utils.session import provide_session from airflow.utils.types import DagRunType +from tests_common.test_utils.asserts import assert_queries_count from tests_common.test_utils.config import conf_vars from tests_common.test_utils.dag import sync_dag_to_db from tests_common.test_utils.db import clear_db_dag_bundles, clear_db_dags, clear_db_runs, clear_db_xcom @@ -218,9 +219,10 @@ def setup(self, dag_maker) -> None: def test_should_respond_200(self, test_client): self._create_xcom_entries(TEST_DAG_ID, run_id, logical_date_parsed, TEST_TASK_ID) - response = test_client.get( - f"/dags/{TEST_DAG_ID}/dagRuns/{run_id}/taskInstances/{TEST_TASK_ID}/xcomEntries" - ) + with assert_queries_count(4): + response = test_client.get( + f"/dags/{TEST_DAG_ID}/dagRuns/{run_id}/taskInstances/{TEST_TASK_ID}/xcomEntries" + ) assert response.status_code == 200 response_data = response.json() for xcom_entry in response_data["xcom_entries"]: @@ -259,7 +261,8 @@ def test_should_respond_200_with_tilde(self, test_client): self._create_xcom_entries(TEST_DAG_ID, run_id, logical_date_parsed, TEST_TASK_ID) self._create_xcom_entries(TEST_DAG_ID_2, run_id, logical_date_parsed, TEST_TASK_ID_2) - response = test_client.get("/dags/~/dagRuns/~/taskInstances/~/xcomEntries") + with assert_queries_count(4): + response = test_client.get("/dags/~/dagRuns/~/taskInstances/~/xcomEntries") assert response.status_code == 200 response_data = response.json() for xcom_entry in response_data["xcom_entries"]: @@ -320,10 +323,11 @@ def test_should_respond_200_with_tilde(self, test_client): def test_should_respond_200_with_map_index(self, map_index, test_client): self._create_xcom_entries(TEST_DAG_ID, run_id, logical_date_parsed, TEST_TASK_ID, mapped_ti=True) - response = test_client.get( - "/dags/~/dagRuns/~/taskInstances/~/xcomEntries", - params={"map_index": map_index} if map_index is not None else None, - ) + with assert_queries_count(4): + response = test_client.get( + "/dags/~/dagRuns/~/taskInstances/~/xcomEntries", + params={"map_index": map_index} if map_index is not None else None, + ) assert response.status_code == 200 response_data = response.json() @@ -398,10 +402,11 @@ def test_should_respond_200_with_map_index(self, map_index, test_client): ) def test_should_respond_200_with_xcom_key(self, key, expected_entries, test_client): self._create_xcom_entries(TEST_DAG_ID, run_id, logical_date_parsed, TEST_TASK_ID, mapped_ti=True) - response = test_client.get( - "/dags/~/dagRuns/~/taskInstances/~/xcomEntries", - params={"xcom_key_pattern": key} if key is not None else None, - ) + with assert_queries_count(4): + response = test_client.get( + "/dags/~/dagRuns/~/taskInstances/~/xcomEntries", + params={"xcom_key_pattern": key} if key is not None else None, + ) assert response.status_code == 200 response_data = response.json()