diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/xcom.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/xcom.py index 57b83f0dbb3a2..6aa86d09f22e7 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/xcom.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/xcom.py @@ -87,7 +87,7 @@ def get_xcom_entry( dag_ids=dag_id, map_indexes=map_index, limit=1, - ) + ).options(joinedload(XComModel.task), joinedload(XComModel.dag_run).joinedload(DR.dag_model)) # We use `BaseXCom.get_many` to fetch XComs directly from the database, bypassing the XCom Backend. # This avoids deserialization via the backend (e.g., from a remote storage like S3) and instead @@ -162,7 +162,7 @@ def get_xcom_entries( query = ( query.join(DR, and_(XComModel.dag_id == DR.dag_id, XComModel.run_id == DR.run_id)) .join(DagModel, DR.dag_id == DagModel.dag_id) - .options(joinedload(XComModel.dag_run).joinedload(DR.dag_model)) + .options(joinedload(XComModel.task), joinedload(XComModel.dag_run).joinedload(DR.dag_model)) ) if task_id != "~": @@ -285,7 +285,7 @@ def create_xcom_entry( XComModel.map_index == request_body.map_index, ) .limit(1) - .options(joinedload(XComModel.dag_run).joinedload(DR.dag_model)) + .options(joinedload(XComModel.task), joinedload(XComModel.dag_run).joinedload(DR.dag_model)) ) return XComResponseNative.model_validate(xcom) @@ -326,7 +326,7 @@ def update_xcom_entry( XComModel.map_index == patch_body.map_index, ) .limit(1) - .options(joinedload(XComModel.dag_run).joinedload(DR.dag_model)) + .options(joinedload(XComModel.task), joinedload(XComModel.dag_run).joinedload(DR.dag_model)) ) if not xcom_entry: diff --git a/airflow-core/src/airflow/models/xcom.py b/airflow-core/src/airflow/models/xcom.py index 709d6bf69030c..0d3236eb408af 100644 --- a/airflow-core/src/airflow/models/xcom.py +++ b/airflow-core/src/airflow/models/xcom.py @@ -107,7 +107,7 @@ class XComModel(TaskInstanceDependencies): task = relationship( "TaskInstance", viewonly=True, - lazy="selectin", + lazy="noload", ) @classmethod diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_xcom.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_xcom.py index 38ca7f48c73ac..d7a4837ddcc17 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_xcom.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_xcom.py @@ -35,6 +35,7 @@ from airflow.utils.session import provide_session from airflow.utils.types import DagRunType +from tests_common.test_utils.asserts import assert_queries_count from tests_common.test_utils.config import conf_vars from tests_common.test_utils.dag import sync_dag_to_db from tests_common.test_utils.db import clear_db_dag_bundles, clear_db_dags, clear_db_runs, clear_db_xcom @@ -202,9 +203,10 @@ def setup(self, dag_maker) -> None: def test_should_respond_200(self, test_client): self._create_xcom_entries(TEST_DAG_ID, run_id, logical_date_parsed, TEST_TASK_ID) - response = test_client.get( - f"/dags/{TEST_DAG_ID}/dagRuns/{run_id}/taskInstances/{TEST_TASK_ID}/xcomEntries" - ) + with assert_queries_count(4): + response = test_client.get( + f"/dags/{TEST_DAG_ID}/dagRuns/{run_id}/taskInstances/{TEST_TASK_ID}/xcomEntries" + ) assert response.status_code == 200 response_data = response.json() for xcom_entry in response_data["xcom_entries"]: @@ -243,7 +245,8 @@ def test_should_respond_200_with_tilde(self, test_client): self._create_xcom_entries(TEST_DAG_ID, run_id, logical_date_parsed, TEST_TASK_ID) self._create_xcom_entries(TEST_DAG_ID_2, run_id, logical_date_parsed, TEST_TASK_ID_2) - response = test_client.get("/dags/~/dagRuns/~/taskInstances/~/xcomEntries") + with assert_queries_count(4): + response = test_client.get("/dags/~/dagRuns/~/taskInstances/~/xcomEntries") assert response.status_code == 200 response_data = response.json() for xcom_entry in response_data["xcom_entries"]: @@ -304,10 +307,11 @@ def test_should_respond_200_with_tilde(self, test_client): def test_should_respond_200_with_map_index(self, map_index, test_client): self._create_xcom_entries(TEST_DAG_ID, run_id, logical_date_parsed, TEST_TASK_ID, mapped_ti=True) - response = test_client.get( - "/dags/~/dagRuns/~/taskInstances/~/xcomEntries", - params={"map_index": map_index} if map_index is not None else None, - ) + with assert_queries_count(4): + response = test_client.get( + "/dags/~/dagRuns/~/taskInstances/~/xcomEntries", + params={"map_index": map_index} if map_index is not None else None, + ) assert response.status_code == 200 response_data = response.json() @@ -382,10 +386,11 @@ def test_should_respond_200_with_map_index(self, map_index, test_client): ) def test_should_respond_200_with_xcom_key(self, key, expected_entries, test_client): self._create_xcom_entries(TEST_DAG_ID, run_id, logical_date_parsed, TEST_TASK_ID, mapped_ti=True) - response = test_client.get( - "/dags/~/dagRuns/~/taskInstances/~/xcomEntries", - params={"xcom_key_pattern": key} if key is not None else None, - ) + with assert_queries_count(4): + response = test_client.get( + "/dags/~/dagRuns/~/taskInstances/~/xcomEntries", + params={"xcom_key_pattern": key} if key is not None else None, + ) assert response.status_code == 200 response_data = response.json()