diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py index b2399635499cc..ce8309341f29e 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py @@ -385,39 +385,20 @@ def set_xcom( # TODO: Can/should we check if a client _hasn't_ provided this for an upstream of a mapped task? That # means loading the serialized dag and that seems like a relatively costly operation for minimal benefit # (the mapped task would fail in a moment as it can't be expanded anyway.) - from airflow.models.dagrun import DagRun - - if not run_id: - raise HTTPException(status.HTTP_404_NOT_FOUND, f"Run with ID: `{run_id}` was not found") - - dag_run_id = session.query(DagRun.id).filter_by(dag_id=dag_id, run_id=run_id).scalar() - if dag_run_id is None: - raise HTTPException(status.HTTP_404_NOT_FOUND, f"DAG run not found on DAG {dag_id} with ID {run_id}") - - # Remove duplicate XComs and insert a new one. - session.execute( - delete(XComModel).where( - XComModel.key == key, - XComModel.run_id == run_id, - XComModel.task_id == task_id, - XComModel.dag_id == dag_id, - XComModel.map_index == map_index, - ) - ) - try: # We expect serialised value from the caller - sdk, do not serialise in here - new = XComModel( - dag_run_id=dag_run_id, + XComModel.set( key=key, value=value, run_id=run_id, task_id=task_id, dag_id=dag_id, map_index=map_index, + serialize=False, + session=session, ) - session.add(new) - session.flush() + except ValueError as e: + raise HTTPException(status.HTTP_404_NOT_FOUND, str(e)) except TypeError as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, diff --git a/airflow-core/src/airflow/models/xcom.py b/airflow-core/src/airflow/models/xcom.py index a0f3d3501beae..2729dadda6149 100644 --- a/airflow-core/src/airflow/models/xcom.py +++ b/airflow-core/src/airflow/models/xcom.py @@ -167,6 +167,7 @@ def set( task_id: str, run_id: str, map_index: int = -1, + serialize: bool = True, session: Session = NEW_SESSION, ) -> None: """ @@ -178,7 +179,8 @@ def set( :param task_id: Task ID. :param run_id: DAG run ID for the task. :param map_index: Optional map index to assign XCom for a mapped task. - The default is ``-1`` (set for a non-mapped task). + :param serialize: Optional parameter to specify if value should be serialized or not. + The default is ``True``. :param session: Database session. If not given, a new session will be created for this function. """ @@ -215,14 +217,15 @@ def set( ) value = list(value) - value = cls.serialize_value( - value=value, - key=key, - task_id=task_id, - dag_id=dag_id, - run_id=run_id, - map_index=map_index, - ) + if serialize: + value = cls.serialize_value( + value=value, + key=key, + task_id=task_id, + dag_id=dag_id, + run_id=run_id, + map_index=map_index, + ) # Remove duplicate XComs and insert a new one. session.execute(