diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a8dd47fd5bad7..2701af10c3f1c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -534,7 +534,8 @@ repos: files: > (?x) ^airflow-ctl.*\.py$| - ^providers/fab.*\.py$| + ^airflow-core/src/airflow/models/.*\.py$| + ^providers/fab/.*\.py$| ^task_sdk.*\.py$ pass_filenames: true - id: update-supported-versions diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_run.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_run.py index 833b927df1389..30374d5e41a0d 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_run.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_run.py @@ -494,6 +494,7 @@ def wait_dag_run_until_finished( run_id=dag_run_id, interval=interval, result_task_ids=result_task_ids, + session=session, ) return StreamingResponse(waiter.wait()) diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/xcom.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/xcom.py index 2cdb89fcc88d9..57b83f0dbb3a2 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/xcom.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/xcom.py @@ -86,14 +86,13 @@ def get_xcom_entry( task_ids=task_id, dag_ids=dag_id, map_indexes=map_index, - session=session, limit=1, ) # We use `BaseXCom.get_many` to fetch XComs directly from the database, bypassing the XCom Backend. # This avoids deserialization via the backend (e.g., from a remote storage like S3) and instead # retrieves the raw serialized value from the database. - result = xcom_query.limit(1).first() + result = session.scalars(xcom_query).first() if result is None: raise HTTPException(status.HTTP_404_NOT_FOUND, f"XCom entry with key: `{xcom_key}` not found") @@ -249,9 +248,8 @@ def create_xcom_entry( dag_ids=dag_id, run_id=dag_run_id, map_indexes=request_body.map_index, - session=session, ) - result = already_existing_query.with_entities(XComModel.value).first() + result = session.execute(already_existing_query.with_only_columns(XComModel.value)).first() if result: raise HTTPException( status_code=status.HTTP_409_CONFLICT, diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/public/dag_run.py b/airflow-core/src/airflow/api_fastapi/core_api/services/public/dag_run.py index 259389e799494..37a5d761b3033 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/services/public/dag_run.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/services/public/dag_run.py @@ -26,6 +26,7 @@ import attrs from sqlalchemy import select +from airflow.api_fastapi.common.db.common import SessionDep from airflow.models.dagrun import DagRun from airflow.models.xcom import XCOM_RETURN_KEY, XComModel from airflow.utils.session import create_session_async @@ -43,6 +44,7 @@ class DagRunWaiter: run_id: str interval: float result_task_ids: list[str] | None + session: SessionDep async def _get_dag_run(self) -> DagRun: async with create_session_async() as session: @@ -55,7 +57,7 @@ def _serialize_xcoms(self) -> dict[str, Any]: task_ids=self.result_task_ids, dag_ids=self.dag_id, ) - xcom_query = xcom_query.order_by(XComModel.task_id, XComModel.map_index) + xcom_query = self.session.scalars(xcom_query.order_by(XComModel.task_id, XComModel.map_index)).all() def _group_xcoms(g: Iterator[XComModel]) -> Any: entries = list(g) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py index 576f6c4e0dc72..b2399635499cc 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py @@ -76,7 +76,6 @@ async def xcom_query( run_id: str, task_id: str, key: str, - session: SessionDep, map_index: Annotated[int | None, Query()] = None, ) -> Select: query = XComModel.get_many( @@ -85,7 +84,6 @@ async def xcom_query( task_ids=task_id, dag_ids=dag_id, map_indexes=map_index, - session=session, ) return query @@ -151,23 +149,22 @@ def get_xcom( task_ids=task_id, dag_ids=dag_id, include_prior_dates=params.include_prior_dates, - session=session, ) if params.offset is not None: - xcom_query = xcom_query.filter(XComModel.value.is_not(None)).order_by(None) + xcom_query = xcom_query.where(XComModel.value.is_not(None)).order_by(None) if params.offset >= 0: xcom_query = xcom_query.order_by(XComModel.map_index.asc()).offset(params.offset) else: xcom_query = xcom_query.order_by(XComModel.map_index.desc()).offset(-1 - params.offset) else: - xcom_query = xcom_query.filter(XComModel.map_index == params.map_index) + xcom_query = xcom_query.where(XComModel.map_index == params.map_index) # We use `BaseXCom.get_many` to fetch XComs directly from the database, bypassing the XCom Backend. # This avoids deserialization via the backend (e.g., from a remote storage like S3) and instead # retrieves the raw serialized value from the database. By not relying on `XCom.get_many` or `XCom.get_one` # (which automatically deserializes using the backend), we avoid potential # performance hits from retrieving large data files into the API server. - result = xcom_query.limit(1).first() + result = session.scalars(xcom_query).first() if result is None: if params.offset is None: message = ( @@ -204,7 +201,6 @@ def get_mapped_xcom_by_index( key=key, task_ids=task_id, dag_ids=dag_id, - session=session, ) xcom_query = xcom_query.order_by(None) if offset >= 0: @@ -212,7 +208,7 @@ def get_mapped_xcom_by_index( else: xcom_query = xcom_query.order_by(XComModel.map_index.desc()).offset(-1 - offset) - if (result := xcom_query.limit(1).first()) is None: + if (result := session.scalars(xcom_query).first()) is None: message = ( f"XCom with {key=} {offset=} not found for task {task_id!r} in DAG run {run_id!r} of {dag_id!r}" ) @@ -250,7 +246,6 @@ def get_mapped_xcom_by_slice( task_ids=task_id, dag_ids=dag_id, include_prior_dates=params.include_prior_dates, - session=session, ) query = query.order_by(None) @@ -309,7 +304,7 @@ def get_mapped_xcom_by_slice( else: query = query.slice(-stop, -start) - values = [row.value for row in query.with_entities(XComModel.value)] + values = [row.value for row in session.execute(query.with_only_columns(XComModel.value)).all()] if step != 1: values = values[::step] return XComSequenceSliceResponse(values) diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py index 4cb09c9e45941..c87496858243d 100644 --- a/airflow-core/src/airflow/models/dagrun.py +++ b/airflow-core/src/airflow/models/dagrun.py @@ -790,13 +790,12 @@ def fetch_task_instances( def _check_last_n_dagruns_failed(self, dag_id, max_consecutive_failed_dag_runs, session): """Check if last N dags failed.""" - dag_runs = ( - session.query(DagRun) - .filter(DagRun.dag_id == dag_id) + dag_runs = session.scalars( + select(DagRun) + .where(DagRun.dag_id == dag_id) .order_by(DagRun.logical_date.desc()) .limit(max_consecutive_failed_dag_runs) - .all() - ) + ).all() """ Marking dag as paused, if needed""" to_be_paused = len(dag_runs) >= max_consecutive_failed_dag_runs and all( dag_run.state == DagRunState.FAILED for dag_run in dag_runs diff --git a/airflow-core/src/airflow/models/deadline.py b/airflow-core/src/airflow/models/deadline.py index c199d85d43149..f41fe648418a3 100644 --- a/airflow-core/src/airflow/models/deadline.py +++ b/airflow-core/src/airflow/models/deadline.py @@ -164,9 +164,10 @@ def prune_deadlines(cls, *, session: Session, conditions: dict[Column, Any]) -> try: # Get deadlines which match the provided conditions and their associated DagRuns. - deadline_dagrun_pairs = ( - session.query(Deadline, DagRun).join(DagRun).filter(and_(*filter_conditions)).all() - ) + deadline_dagrun_pairs = session.execute( + select(Deadline, DagRun).join(DagRun).where(and_(*filter_conditions)) + ).all() + except AttributeError as e: logger.exception("Error resolving deadlines: %s", e) raise diff --git a/airflow-core/src/airflow/models/serialized_dag.py b/airflow-core/src/airflow/models/serialized_dag.py index a9c05c71d5109..b6bbf67fae3c2 100644 --- a/airflow-core/src/airflow/models/serialized_dag.py +++ b/airflow-core/src/airflow/models/serialized_dag.py @@ -476,8 +476,8 @@ def get_latest_serialized_dags( """ # Subquery to get the latest serdag per dag_id latest_serdag_subquery = ( - session.query(cls.dag_id, func.max(cls.created_at).label("created_at")) - .filter(cls.dag_id.in_(dag_ids)) + select(cls.dag_id, func.max(cls.created_at).label("created_at")) + .where(cls.dag_id.in_(dag_ids)) .group_by(cls.dag_id) .subquery() ) @@ -501,9 +501,7 @@ def read_all_dags(cls, session: Session = NEW_SESSION) -> dict[str, SerializedDA :returns: a dict of DAGs read from database """ latest_serialized_dag_subquery = ( - session.query(cls.dag_id, func.max(cls.created_at).label("max_created")) - .group_by(cls.dag_id) - .subquery() + select(cls.dag_id, func.max(cls.created_at).label("max_created")).group_by(cls.dag_id).subquery() ) serialized_dags = session.scalars( select(cls).join( diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py index 27a46b366dc6c..77d6a6e97c269 100644 --- a/airflow-core/src/airflow/models/taskinstance.py +++ b/airflow-core/src/airflow/models/taskinstance.py @@ -264,16 +264,14 @@ def clear_task_instances( for instance in tis: run_ids_by_dag_id[instance.dag_id].add(instance.run_id) - drs = ( - session.query(DagRun) - .filter( + drs = session.scalars( + select(DagRun).where( or_( and_(DagRun.dag_id == dag_id, DagRun.run_id.in_(run_ids)) for dag_id, run_ids in run_ids_by_dag_id.items() ) ) - .all() - ) + ).all() dag_run_state = DagRunState(dag_run_state) # Validate the state value. for dr in drs: if dr.state in State.finished_dr_states: @@ -659,7 +657,7 @@ def get_task_instance( session: Session = NEW_SESSION, ) -> TaskInstance | None: query = ( - session.query(TaskInstance) + select(TaskInstance) .options(lazyload(TaskInstance.dag_run)) # lazy load dag run to avoid locking it .filter_by( dag_id=dag_id, @@ -672,9 +670,9 @@ def get_task_instance( if lock_for_update: for attempt in run_with_db_retries(logger=cls.logger()): with attempt: - return query.with_for_update().one_or_none() + return session.execute(query.with_for_update()).scalar_one_or_none() else: - return query.one_or_none() + return session.execute(query).scalar_one_or_none() return None @@ -824,13 +822,13 @@ def are_dependents_done(self, session: Session = NEW_SESSION) -> bool: if not task.downstream_task_ids: return True - ti = session.query(func.count(TaskInstance.task_id)).filter( + ti = select(func.count(TaskInstance.task_id)).where( TaskInstance.dag_id == self.dag_id, TaskInstance.task_id.in_(task.downstream_task_ids), TaskInstance.run_id == self.run_id, TaskInstance.state.in_((TaskInstanceState.SKIPPED, TaskInstanceState.SUCCESS)), ) - count = ti[0][0] + count = session.scalar(ti) return count == len(task.downstream_task_ids) @provide_session @@ -1005,7 +1003,9 @@ def ready_for_retry(self) -> bool: def _get_dagrun(dag_id, run_id, session) -> DagRun: from airflow.models.dagrun import DagRun # Avoid circular import - dr = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == run_id).one() + dr = session.execute( + select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id) + ).scalar_one() return dr @provide_session @@ -1947,7 +1947,6 @@ def xcom_pull( task_ids=task_ids, map_indexes=map_indexes, include_prior_dates=include_prior_dates, - session=session, ) # NOTE: Since we're only fetching the value field and not the whole @@ -1956,8 +1955,14 @@ def xcom_pull( # We are only pulling one single task. if (task_ids is None or isinstance(task_ids, str)) and not isinstance(map_indexes, Iterable): - first = query.with_entities( - XComModel.run_id, XComModel.task_id, XComModel.dag_id, XComModel.map_index, XComModel.value + first = session.execute( + query.with_only_columns( + XComModel.run_id, + XComModel.task_id, + XComModel.dag_id, + XComModel.map_index, + XComModel.value, + ) ).first() if first is None: # No matching XCom at all. return default @@ -1998,16 +2003,20 @@ def xcom_pull( def get_num_running_task_instances(self, session: Session, same_dagrun: bool = False) -> int: """Return Number of running TIs from the DB.""" # .count() is inefficient - num_running_task_instances_query = session.query(func.count()).filter( - TaskInstance.dag_id == self.dag_id, - TaskInstance.task_id == self.task_id, - TaskInstance.state == TaskInstanceState.RUNNING, + num_running_task_instances_query = ( + select(func.count()) + .select_from(TaskInstance) + .where( + TaskInstance.dag_id == self.dag_id, + TaskInstance.task_id == self.task_id, + TaskInstance.state == TaskInstanceState.RUNNING, + ) ) if same_dagrun: - num_running_task_instances_query = num_running_task_instances_query.filter( + num_running_task_instances_query = num_running_task_instances_query.where( TaskInstance.run_id == self.run_id ) - return num_running_task_instances_query.scalar() + return session.scalar(num_running_task_instances_query) @staticmethod def filter_for_tis(tis: Iterable[TaskInstance | TaskInstanceKey]) -> BooleanClauseList | None: diff --git a/airflow-core/src/airflow/models/xcom.py b/airflow-core/src/airflow/models/xcom.py index e9fc2ac2bfc9a..709d6bf69030c 100644 --- a/airflow-core/src/airflow/models/xcom.py +++ b/airflow-core/src/airflow/models/xcom.py @@ -37,7 +37,7 @@ ) from sqlalchemy.dialects import postgresql from sqlalchemy.ext.associationproxy import association_proxy -from sqlalchemy.orm import Query, relationship +from sqlalchemy.orm import relationship from airflow._shared.timezones import timezone from airflow.models.base import COLLATION_ARGS, ID_LEN, TaskInstanceDependencies @@ -144,11 +144,11 @@ def clear( if not run_id: raise ValueError(f"run_id must be passed. Passed run_id={run_id}") - query = session.query(cls).filter_by(dag_id=dag_id, task_id=task_id, run_id=run_id) + query = select(cls).where(cls.dag_id == dag_id, cls.task_id == task_id, cls.run_id == run_id) if map_index is not None: - query = query.filter_by(map_index=map_index) + query = query.where(cls.map_index == map_index) - for xcom in query: + for xcom in session.scalars(query): # print(f"Clearing XCOM {xcom} with value {xcom.value}") session.delete(xcom) @@ -188,7 +188,7 @@ def set( if not run_id: raise ValueError(f"run_id must be passed. Passed run_id={run_id}") - dag_run_id = session.query(DagRun.id).filter_by(dag_id=dag_id, run_id=run_id).scalar() + dag_run_id = session.scalar(select(DagRun.id).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id)) if dag_run_id is None: raise ValueError(f"DAG run not found on DAG {dag_id!r} with ID {run_id!r}") @@ -246,7 +246,6 @@ def set( session.flush() @classmethod - @provide_session def get_many( cls, *, @@ -257,8 +256,7 @@ def get_many( map_indexes: int | Iterable[int] | None = None, include_prior_dates: bool = False, limit: int | None = None, - session: Session = NEW_SESSION, - ) -> Query: + ) -> Select: """ Composes a query to get one or more XCom entries. @@ -289,42 +287,42 @@ def get_many( if not run_id: raise ValueError(f"run_id must be passed. Passed run_id={run_id}") - query = session.query(cls).join(XComModel.dag_run) + query = select(cls).join(XComModel.dag_run) if key: - query = query.filter(XComModel.key == key) + query = query.where(XComModel.key == key) if is_container(task_ids): - query = query.filter(cls.task_id.in_(task_ids)) + query = query.where(cls.task_id.in_(task_ids)) elif task_ids is not None: - query = query.filter(cls.task_id == task_ids) + query = query.where(cls.task_id == task_ids) if is_container(dag_ids): - query = query.filter(cls.dag_id.in_(dag_ids)) + query = query.where(cls.dag_id.in_(dag_ids)) elif dag_ids is not None: - query = query.filter(cls.dag_id == dag_ids) + query = query.where(cls.dag_id == dag_ids) if isinstance(map_indexes, range) and map_indexes.step == 1: - query = query.filter(cls.map_index >= map_indexes.start, cls.map_index < map_indexes.stop) + query = query.where(cls.map_index >= map_indexes.start, cls.map_index < map_indexes.stop) elif is_container(map_indexes): - query = query.filter(cls.map_index.in_(map_indexes)) + query = query.where(cls.map_index.in_(map_indexes)) elif map_indexes is not None: - query = query.filter(cls.map_index == map_indexes) + query = query.where(cls.map_index == map_indexes) if include_prior_dates: dr = ( - session.query( + select( func.coalesce(DagRun.logical_date, DagRun.run_after).label("logical_date_or_run_after") ) - .filter(DagRun.run_id == run_id) + .where(DagRun.run_id == run_id) .subquery() ) - query = query.filter( + query = query.where( func.coalesce(DagRun.logical_date, DagRun.run_after) <= dr.c.logical_date_or_run_after ) else: - query = query.filter(cls.run_id == run_id) + query = query.where(cls.run_id == run_id) query = query.order_by(DagRun.logical_date.desc(), cls.timestamp.desc()) if limit: diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py index b930b921c9788..d6594c5c544fa 100644 --- a/airflow-core/src/airflow/serialization/serialized_objects.py +++ b/airflow-core/src/airflow/serialization/serialized_objects.py @@ -99,7 +99,7 @@ from airflow.utils.docs import get_docs_url from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.module_loading import import_string, qualname -from airflow.utils.session import NEW_SESSION, provide_session +from airflow.utils.session import NEW_SESSION, create_session, provide_session from airflow.utils.state import DagRunState, TaskInstanceState from airflow.utils.types import NOTSET, ArgNotSet, DagRunTriggeredByType, DagRunType @@ -3592,13 +3592,16 @@ def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str: self.log.info( "Attempting to retrieve link from XComs with key: %s for task id: %s", self.xcom_key, ti_key ) - value = XComModel.get_many( - key=self.xcom_key, - run_id=ti_key.run_id, - dag_ids=ti_key.dag_id, - task_ids=ti_key.task_id, - map_indexes=ti_key.map_index, - ).first() + with create_session() as session: + value = session.execute( + XComModel.get_many( + key=self.xcom_key, + run_id=ti_key.run_id, + dag_ids=ti_key.dag_id, + task_ids=ti_key.task_id, + map_indexes=ti_key.map_index, + ).with_only_columns(XComModel.value) + ).first() if not value: self.log.debug( "No link with name: %s present in XCom as key: %s, returning empty link", diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py index f9408c808f2b2..cf8f80642f7ca 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py @@ -408,12 +408,13 @@ def test_xcom_round_trip(self, client, create_task_instance, session, orig_value assert response.status_code == 201 - stored_value = XComModel.get_many( - key="xcom_1", - dag_ids=ti.dag_id, - task_ids=ti.task_id, - run_id=ti.run_id, - session=session, + stored_value = session.execute( + XComModel.get_many( + key="xcom_1", + dag_ids=ti.dag_id, + task_ids=ti.task_id, + run_id=ti.run_id, + ).with_only_columns(XComModel.value) ).first() deserialized_value = XComModel.deserialize_value(stored_value) diff --git a/airflow-core/tests/unit/models/test_deadline.py b/airflow-core/tests/unit/models/test_deadline.py index 73b68d76d0c45..1412b16bdbbaa 100644 --- a/airflow-core/tests/unit/models/test_deadline.py +++ b/airflow-core/tests/unit/models/test_deadline.py @@ -127,17 +127,13 @@ def test_prune_deadlines(self, mock_session, conditions, dagrun): # Set up the query chain to return a list of (Deadline, DagRun) pairs mock_dagrun = mock.Mock(spec=DagRun, end_date=datetime.now()) mock_deadline = mock.Mock(spec=Deadline, deadline_time=mock_dagrun.end_date + timedelta(days=365)) - mock_query = mock_session.query.return_value - mock_query.join.return_value = mock_query - mock_query.filter.return_value = mock_query + mock_query = mock_session.execute.return_value mock_query.all.return_value = [(mock_deadline, mock_dagrun)] if conditions else [] result = Deadline.prune_deadlines(conditions=conditions, session=mock_session) - assert result == expected_result if conditions: - mock_session.query.assert_called_once_with(Deadline, DagRun) - mock_session.query.return_value.filter.assert_called_once() # Assert that the conditions are applied. + mock_session.execute.return_value.all.assert_called_once() mock_session.delete.assert_called_once_with(mock_deadline) else: mock_session.query.assert_not_called() diff --git a/airflow-core/tests/unit/models/test_trigger.py b/airflow-core/tests/unit/models/test_trigger.py index d160ee4585cb3..69a4ba5847f90 100644 --- a/airflow-core/tests/unit/models/test_trigger.py +++ b/airflow-core/tests/unit/models/test_trigger.py @@ -271,7 +271,9 @@ def test_submit_event_task_end(mock_utcnow, session, create_task_instance, event session.commit() def get_xcoms(ti): - return XComModel.get_many(dag_ids=[ti.dag_id], task_ids=[ti.task_id], run_id=ti.run_id).all() + return session.scalars( + XComModel.get_many(dag_ids=[ti.dag_id], task_ids=[ti.task_id], run_id=ti.run_id) + ).all() # now for the real test # first check initial state diff --git a/airflow-core/tests/unit/models/test_xcom.py b/airflow-core/tests/unit/models/test_xcom.py index 37dc65dfc6508..edfdf7f3bd5b1 100644 --- a/airflow-core/tests/unit/models/test_xcom.py +++ b/airflow-core/tests/unit/models/test_xcom.py @@ -210,12 +210,13 @@ def setup_for_xcom_get_one(self, task_instance, push_simple_json_xcom): @pytest.mark.usefixtures("setup_for_xcom_get_one") def test_xcom_get_one(self, session, task_instance): - stored_value = XComModel.get_many( - key="xcom_1", - dag_ids=task_instance.dag_id, - task_ids=task_instance.task_id, - run_id=task_instance.run_id, - session=session, + stored_value = session.execute( + XComModel.get_many( + key="xcom_1", + dag_ids=task_instance.dag_id, + task_ids=task_instance.task_id, + run_id=task_instance.run_id, + ).with_only_columns(XComModel.value) ).first() assert XComModel.deserialize_value(stored_value) == {"key": "value"} @@ -256,13 +257,14 @@ def tis_for_xcom_get_one_from_prior_date_without_logical_date( def test_xcom_get_one_from_prior_date(self, session, tis_for_xcom_get_one_from_prior_date): _, ti2 = tis_for_xcom_get_one_from_prior_date - retrieved_value = XComModel.get_many( - run_id=ti2.run_id, - key="xcom_1", - task_ids="task_1", - dag_ids="dag", - include_prior_dates=True, - session=session, + retrieved_value = session.execute( + XComModel.get_many( + run_id=ti2.run_id, + key="xcom_1", + task_ids="task_1", + dag_ids="dag", + include_prior_dates=True, + ).with_only_columns(XComModel.value) ).first() assert XComModel.deserialize_value(retrieved_value) == {"key": "value"} @@ -270,13 +272,14 @@ def test_xcom_get_one_from_prior_date_with_no_logical_dates( self, session, tis_for_xcom_get_one_from_prior_date_without_logical_date ): _, ti2 = tis_for_xcom_get_one_from_prior_date_without_logical_date - retrieved_value = XComModel.get_many( - run_id=ti2.run_id, - key="xcom_1", - task_ids="task_1", - dag_ids="dag", - include_prior_dates=True, - session=session, + retrieved_value = session.execute( + XComModel.get_many( + run_id=ti2.run_id, + key="xcom_1", + task_ids="task_1", + dag_ids="dag", + include_prior_dates=True, + ).with_only_columns(XComModel.value) ).first() assert XComModel.deserialize_value(retrieved_value) == {"key": "value"} @@ -286,12 +289,13 @@ def setup_for_xcom_get_many_single_argument_value(self, task_instance, push_simp @pytest.mark.usefixtures("setup_for_xcom_get_many_single_argument_value") def test_xcom_get_many_single_argument_value(self, session, task_instance): - stored_xcoms = XComModel.get_many( - key="xcom_1", - dag_ids=task_instance.dag_id, - task_ids=task_instance.task_id, - run_id=task_instance.run_id, - session=session, + stored_xcoms = session.scalars( + XComModel.get_many( + key="xcom_1", + dag_ids=task_instance.dag_id, + task_ids=task_instance.task_id, + run_id=task_instance.run_id, + ) ).all() assert len(stored_xcoms) == 1 assert stored_xcoms[0].key == "xcom_1" @@ -305,13 +309,14 @@ def setup_for_xcom_get_many_multiple_tasks(self, task_instances, push_simple_jso @pytest.mark.usefixtures("setup_for_xcom_get_many_multiple_tasks") def test_xcom_get_many_multiple_tasks(self, session, task_instance): - stored_xcoms = XComModel.get_many( - key="xcom_1", - dag_ids=task_instance.dag_id, - task_ids=["task_1", "task_2"], - run_id=task_instance.run_id, - session=session, - ) + stored_xcoms = session.scalars( + XComModel.get_many( + key="xcom_1", + dag_ids=task_instance.dag_id, + task_ids=["task_1", "task_2"], + run_id=task_instance.run_id, + ) + ).all() sorted_values = [x.value for x in sorted(stored_xcoms, key=operator.attrgetter("task_id"))] assert sorted_values == [json.dumps({"key1": "value1"}), json.dumps({"key2": "value2"})] @@ -328,14 +333,15 @@ def tis_for_xcom_get_many_from_prior_dates(self, task_instance_factory, push_sim def test_xcom_get_many_from_prior_dates(self, session, tis_for_xcom_get_many_from_prior_dates): ti1, ti2 = tis_for_xcom_get_many_from_prior_dates session.add(ti1) # for some reason, ti1 goes out of the session scope - stored_xcoms = XComModel.get_many( - run_id=ti2.run_id, - key="xcom_1", - dag_ids="dag", - task_ids="task_1", - include_prior_dates=True, - session=session, - ) + stored_xcoms = session.scalars( + XComModel.get_many( + run_id=ti2.run_id, + key="xcom_1", + dag_ids="dag", + task_ids="task_1", + include_prior_dates=True, + ) + ).all() # The retrieved XComs should be ordered by logical date, latest first. assert [x.value for x in stored_xcoms] == list( @@ -351,7 +357,6 @@ def test_xcom_get_invalid_key(self, session, task_instance): dag_ids=task_instance.dag_id, task_ids=task_instance.task_id, run_id=task_instance.run_id, - session=session, ) @@ -472,12 +477,13 @@ def test_xcom_round_trip(self, value, expected_value, push_simple_json_xcom, tas """Test that XComModel serialization and deserialization work as expected.""" push_simple_json_xcom(ti=task_instance, key="xcom_1", value=value) - stored_value = XComModel.get_many( - key="xcom_1", - dag_ids=task_instance.dag_id, - task_ids=task_instance.task_id, - run_id=task_instance.run_id, - session=session, + stored_value = session.execute( + XComModel.get_many( + key="xcom_1", + dag_ids=task_instance.dag_id, + task_ids=task_instance.task_id, + run_id=task_instance.run_id, + ).with_only_columns(XComModel.value) ).first() deserialized_value = XComModel.deserialize_value(stored_value) diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_pod.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_pod.py index adca2955a30a3..e1369aefa0cea 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_pod.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_pod.py @@ -52,9 +52,9 @@ from tests_common.test_utils import db from tests_common.test_utils.dag import sync_dag_to_db -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS -if AIRFLOW_V_3_0_PLUS: +if AIRFLOW_V_3_0_PLUS or AIRFLOW_V_3_1_PLUS: from airflow.models.xcom import XComModel as XCom else: from airflow.models.xcom import XCom # type: ignore[no-redef] @@ -1463,11 +1463,23 @@ def test_push_xcom_pod_info( pod, _ = self.run_pod(k) if AIRFLOW_V_3_0_PLUS: - pod_name = XCom.get_many(run_id=self.dag_run.run_id, task_ids="task", key="pod_name").first() - pod_namespace = XCom.get_many( - run_id=self.dag_run.run_id, task_ids="task", key="pod_namespace" - ).first() - + if AIRFLOW_V_3_1_PLUS: + with create_session() as session: + pod_name = session.execute( + XCom.get_many( + run_id=self.dag_run.run_id, task_ids="task", key="pod_name" + ).with_only_columns(XCom.value) + ).first() + pod_namespace = session.execute( + XCom.get_many( + run_id=self.dag_run.run_id, task_ids="task", key="pod_namespace" + ).with_only_columns(XCom.value) + ).first() + else: + pod_name = XCom.get_many(run_id=self.dag_run.run_id, task_ids="task", key="pod_name").first() + pod_namespace = XCom.get_many( + run_id=self.dag_run.run_id, task_ids="task", key="pod_namespace" + ).first() pod_name = XCom.deserialize_value(pod_name) pod_namespace = XCom.deserialize_value(pod_namespace) else: diff --git a/providers/common/io/tests/unit/common/io/xcom/test_backend.py b/providers/common/io/tests/unit/common/io/xcom/test_backend.py index 72f868c62fcd3..5aca935e8d5e8 100644 --- a/providers/common/io/tests/unit/common/io/xcom/test_backend.py +++ b/providers/common/io/tests/unit/common/io/xcom/test_backend.py @@ -24,14 +24,20 @@ import airflow.models.xcom from airflow.providers.common.io.xcom.backend import XComObjectStorageBackend from airflow.providers.standard.operators.empty import EmptyOperator -from airflow.utils import timezone + +try: + from airflow.sdk import timezone +except ImportError: + from airflow.utils import timezone # type: ignore[attr-defined,no-redef] from tests_common.test_utils import db from tests_common.test_utils.config import conf_vars -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, XCOM_RETURN_KEY +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS, XCOM_RETURN_KEY pytestmark = [pytest.mark.db_test] +if AIRFLOW_V_3_1_PLUS: + from airflow.models.xcom import XComModel if AIRFLOW_V_3_0_PLUS: from airflow.models.xcom import XComModel from airflow.sdk import ObjectStoragePath @@ -137,18 +143,27 @@ def test_value_storage(self, task_instance, mock_supervisor_comms, session): task_id=task_instance.task_id, run_id=task_instance.run_id, ) - - res = ( - XComModel.get_many( - key=XCOM_RETURN_KEY, - dag_ids=task_instance.dag_id, - task_ids=task_instance.task_id, - run_id=task_instance.run_id, - session=session, + if AIRFLOW_V_3_1_PLUS: + res = session.execute( + XComModel.get_many( + key=XCOM_RETURN_KEY, + dag_ids=task_instance.dag_id, + task_ids=task_instance.task_id, + run_id=task_instance.run_id, + ).with_only_columns(XComModel.value) + ).first() + else: + res = ( + XComModel.get_many( + key=XCOM_RETURN_KEY, + dag_ids=task_instance.dag_id, + task_ids=task_instance.task_id, + run_id=task_instance.run_id, + session=session, + ) + .with_entities(XComModel.value) + .first() ) - .with_entities(XComModel.value) - .first() - ) data = XComModel.deserialize_value(res) else: res = ( @@ -179,14 +194,28 @@ def test_value_storage(self, task_instance, mock_supervisor_comms, session): assert value == {"key": "bigvaluebigvaluebigvalue" * 100} if AIRFLOW_V_3_0_PLUS: - qry = XComModel.get_many( - key=XCOM_RETURN_KEY, - dag_ids=task_instance.dag_id, - task_ids=task_instance.task_id, - run_id=task_instance.run_id, - session=session, - ) - assert str(p) == XComModel.deserialize_value(qry.first()) + if AIRFLOW_V_3_1_PLUS: + value = session.execute( + XComModel.get_many( + key=XCOM_RETURN_KEY, + dag_ids=task_instance.dag_id, + task_ids=task_instance.task_id, + run_id=task_instance.run_id, + ).with_only_columns(XComModel.value) + ).first() + else: + value = ( + XComModel.get_many( + key=XCOM_RETURN_KEY, + dag_ids=task_instance.dag_id, + task_ids=task_instance.task_id, + run_id=task_instance.run_id, + session=session, + ) + .with_entities(XComModel.value) + .first() + ) + assert str(p) == XComModel.deserialize_value(value) else: qry = XCom.get_many( key=XCOM_RETURN_KEY, @@ -225,18 +254,27 @@ def test_clear(self, task_instance, session, mock_supervisor_comms): task_id=task_instance.task_id, run_id=task_instance.run_id, ) - - res = ( - XComModel.get_many( - key=XCOM_RETURN_KEY, - dag_ids=task_instance.dag_id, - task_ids=task_instance.task_id, - run_id=task_instance.run_id, - session=session, + if AIRFLOW_V_3_1_PLUS: + res = session.execute( + XComModel.get_many( + key=XCOM_RETURN_KEY, + dag_ids=task_instance.dag_id, + task_ids=task_instance.task_id, + run_id=task_instance.run_id, + ).with_only_columns(XComModel.value) + ).first() + else: + res = ( + XComModel.get_many( + key=XCOM_RETURN_KEY, + dag_ids=task_instance.dag_id, + task_ids=task_instance.task_id, + run_id=task_instance.run_id, + session=session, + ) + .with_entities(XComModel.value) + .first() ) - .with_entities(XComModel.value) - .first() - ) data = XComModel.deserialize_value(res) else: res = ( @@ -279,17 +317,27 @@ def test_clear(self, task_instance, session, mock_supervisor_comms): run_id=task_instance.run_id, map_index=task_instance.map_index, ) - value = ( - XComModel.get_many( - key=XCOM_RETURN_KEY, - dag_ids=task_instance.dag_id, - task_ids=task_instance.task_id, - run_id=task_instance.run_id, - session=session, + if AIRFLOW_V_3_1_PLUS: + value = session.execute( + XComModel.get_many( + key=XCOM_RETURN_KEY, + dag_ids=task_instance.dag_id, + task_ids=task_instance.task_id, + run_id=task_instance.run_id, + ).with_only_columns(XComModel.value) + ).first() + else: + value = ( + XComModel.get_many( + key=XCOM_RETURN_KEY, + dag_ids=task_instance.dag_id, + task_ids=task_instance.task_id, + run_id=task_instance.run_id, + session=session, + ) + .with_entities(XComModel.value) + .first() ) - .with_entities(XComModel.value) - .first() - ) else: XCom.clear( dag_id=task_instance.dag_id, @@ -330,18 +378,27 @@ def test_compression(self, task_instance, session, mock_supervisor_comms): task_id=task_instance.task_id, run_id=task_instance.run_id, ) - - res = ( - XComModel.get_many( - key=XCOM_RETURN_KEY, - dag_ids=task_instance.dag_id, - task_ids=task_instance.task_id, - run_id=task_instance.run_id, - session=session, + if AIRFLOW_V_3_1_PLUS: + res = session.execute( + XComModel.get_many( + key=XCOM_RETURN_KEY, + dag_ids=task_instance.dag_id, + task_ids=task_instance.task_id, + run_id=task_instance.run_id, + ).with_only_columns(XComModel.value) + ).first() + else: + res = ( + XComModel.get_many( + key=XCOM_RETURN_KEY, + dag_ids=task_instance.dag_id, + task_ids=task_instance.task_id, + run_id=task_instance.run_id, + session=session, + ) + .with_entities(XComModel.value) + .first() ) - .with_entities(XComModel.value) - .first() - ) data = XComModel.deserialize_value(res) else: res = (