diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 995e255cf363d..a1cd1440d4bd0 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -227,6 +227,14 @@ repos:
files: ^airflow-core/src/airflow/models/taskinstance\.py$|^airflow-core/src/airflow/models/taskinstancehistory\.py$
pass_filenames: false
require_serial: true
+ - id: prevent-usage-of-session.query
+ name: Prevent usage of session.query
+ entry: ./scripts/ci/pre_commit/usage_session_query.py
+ language: python
+ additional_dependencies: ['rich>=12.4.4']
+ files: ^airflow.*\.py$|^task_sdk.*\.py$
+ exclude: ^task_sdk/tests/.*\.py$
+ pass_filenames: true
- id: check-deferrable-default
name: Check and fix default value of default_deferrable
language: python
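
For reviewers, a minimal sketch of what the new hook flags versus what it accepts (model and session names are illustrative only):

    from sqlalchemy import select

    # Flagged: legacy 1.x Query API accessed as `.query` on a name called `session`.
    tis = session.query(TaskInstance).filter(TaskInstance.state == "running").all()

    # Accepted: 2.0-style select() executed through the session.
    tis = session.scalars(select(TaskInstance).where(TaskInstance.state == "running")).all()
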
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index e22d0a5f34d0d..0d4b6808868f7 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -246,11 +246,9 @@ def ti_run(
xcom_keys = list(session.scalars(query))
task_reschedule_count = (
- session.query(
- func.count(TaskReschedule.id) # or any other primary key column
- )
- .filter(TaskReschedule.ti_id == ti_id_str)
- .scalar()
+ session.execute(
+ select(func.count(TaskReschedule.id)).where(TaskReschedule.ti_id == ti_id_str)
+ ).scalar()
or 0
)
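
As an aside, `Session.scalar()` collapses the `execute(...).scalar()` pair into one call for single-value statements; a sketch of the same count, assuming the names in the hunk above:

    reschedule_count = session.scalar(
        select(func.count(TaskReschedule.id)).where(TaskReschedule.ti_id == ti_id_str)
    ) or 0
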
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
index a9ed4a5b48d20..0fbd091606f98 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
@@ -400,7 +400,7 @@ def set_xcom(
if not run_id:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Run with ID: `{run_id}` was not found")
- dag_run_id = session.query(DagRun.id).filter_by(dag_id=dag_id, run_id=run_id).scalar()
+ dag_run_id = session.scalar(select(DagRun.id).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id))
if dag_run_id is None:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"DAG run not found on DAG {dag_id} with ID {run_id}")
diff --git a/airflow-core/src/airflow/dag_processing/bundles/manager.py b/airflow-core/src/airflow/dag_processing/bundles/manager.py
index a3538f1e29191..d465d5724c5a3 100644
--- a/airflow-core/src/airflow/dag_processing/bundles/manager.py
+++ b/airflow-core/src/airflow/dag_processing/bundles/manager.py
@@ -19,6 +19,7 @@
from typing import TYPE_CHECKING
from sqlalchemy import delete
+from sqlalchemy import select, update
from airflow.configuration import conf
from airflow.exceptions import AirflowConfigException
@@ -124,8 +126,9 @@ def parse_config(self) -> None:
@provide_session
def sync_bundles_to_db(self, *, session: Session = NEW_SESSION) -> None:
self.log.debug("Syncing DAG bundles to the database")
- stored = {b.name: b for b in session.query(DagBundleModel).all()}
- for name in self._bundle_config.keys():
+ stored = {b.name: b for b in session.scalars(select(DagBundleModel)).all()}
+ active_bundle_names = set(self._bundle_config.keys())
+ for name in active_bundle_names:
if bundle := stored.pop(name, None):
bundle.active = True
else:
diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
index 0cc71bf19b482..71d404832c54e 100644
--- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
+++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -2028,9 +2028,8 @@ def _get_num_times_stuck_in_queued(self, ti: TaskInstance, session: Session = NE
We can then use this information to determine whether to reschedule a task or fail it.
"""
- return (
- session.query(Log)
- .where(
+ return session.execute(
+ select(func.count(Log.id)).where(
Log.task_id == ti.task_id,
Log.dag_id == ti.dag_id,
Log.run_id == ti.run_id,
@@ -2038,24 +2037,22 @@ def _get_num_times_stuck_in_queued(self, ti: TaskInstance, session: Session = NE
Log.try_number == ti.try_number,
Log.event == TASK_STUCK_IN_QUEUED_RESCHEDULE_EVENT,
)
- .count()
- )
+ ).scalar()
previous_ti_running_metrics: dict[tuple[str, str, str], int] = {}
@provide_session
def _emit_running_ti_metrics(self, session: Session = NEW_SESSION) -> None:
- running = (
- session.query(
+ running = session.execute(
+ select(
TaskInstance.dag_id,
TaskInstance.task_id,
TaskInstance.queue,
func.count(TaskInstance.task_id).label("running_count"),
)
- .filter(TaskInstance.state == State.RUNNING)
+ .where(TaskInstance.state == State.RUNNING)
.group_by(TaskInstance.dag_id, TaskInstance.task_id, TaskInstance.queue)
- .all()
- )
+ ).all()
ti_running_metrics = {(row.dag_id, row.task_id, row.queue): row.running_count for row in running}
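
Rows returned by `session.execute()` expose the selected columns and labels as attributes, which the comprehension above relies on. A small sketch, assuming the `running` result above:

    for row in running:
        # Attribute names come from the selected columns and the .label() call.
        print(row.dag_id, row.task_id, row.queue, row.running_count)
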
diff --git a/airflow-core/src/airflow/migrations/versions/0015_2_9_0_update_trigger_kwargs_type.py b/airflow-core/src/airflow/migrations/versions/0015_2_9_0_update_trigger_kwargs_type.py
index 08843afb0e3b2..d929390a53b3f 100644
--- a/airflow-core/src/airflow/migrations/versions/0015_2_9_0_update_trigger_kwargs_type.py
+++ b/airflow-core/src/airflow/migrations/versions/0015_2_9_0_update_trigger_kwargs_type.py
@@ -60,7 +60,7 @@ def upgrade():
if not context.is_offline_mode():
session = get_session()
try:
- for trigger in session.query(Trigger).options(lazyload(Trigger.task_instance)):
+ for trigger in session.scalars(select(Trigger).options(lazyload(Trigger.task_instance))):
trigger.kwargs = trigger.kwargs
session.commit()
finally:
@@ -81,7 +81,7 @@ def downgrade():
else:
session = get_session()
try:
- for trigger in session.query(Trigger).options(lazyload(Trigger.task_instance)):
+ for trigger in session.scalars(select(Trigger).options(lazyload(Trigger.task_instance))):
trigger.encrypted_kwargs = json.dumps(BaseSerialization.serialize(trigger.kwargs))
session.commit()
finally:
diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py
index 8dcd00fef0cbc..653cd7a1e66b2 100644
--- a/airflow-core/src/airflow/models/dagrun.py
+++ b/airflow-core/src/airflow/models/dagrun.py
@@ -788,10 +788,13 @@ def fetch_task_instances(
def _check_last_n_dagruns_failed(self, dag_id, max_consecutive_failed_dag_runs, session):
"""Check if last N dags failed."""
dag_runs = (
- session.query(DagRun)
- .filter(DagRun.dag_id == dag_id)
- .order_by(DagRun.logical_date.desc())
- .limit(max_consecutive_failed_dag_runs)
+ session.execute(
+ select(DagRun)
+ .where(DagRun.dag_id == dag_id)
+ .order_by(DagRun.logical_date.desc())
+ .limit(max_consecutive_failed_dag_runs)
+ )
+ .scalars()
.all()
)
""" Marking dag as paused, if needed"""
diff --git a/airflow-core/src/airflow/models/serialized_dag.py b/airflow-core/src/airflow/models/serialized_dag.py
index 71722b54adee6..ab15ebe21c905 100644
--- a/airflow-core/src/airflow/models/serialized_dag.py
+++ b/airflow-core/src/airflow/models/serialized_dag.py
@@ -479,8 +479,8 @@ def get_latest_serialized_dags(
"""
# Subquery to get the latest serdag per dag_id
latest_serdag_subquery = (
- session.query(cls.dag_id, func.max(cls.created_at).label("created_at"))
- .filter(cls.dag_id.in_(dag_ids))
+ select(cls.dag_id, func.max(cls.created_at).label("created_at"))
+ .where(cls.dag_id.in_(dag_ids))
.group_by(cls.dag_id)
.subquery()
)
@@ -504,9 +504,7 @@ def read_all_dags(cls, session: Session = NEW_SESSION) -> dict[str, SerializedDA
:returns: a dict of DAGs read from database
"""
latest_serialized_dag_subquery = (
- session.query(cls.dag_id, func.max(cls.created_at).label("max_created"))
- .group_by(cls.dag_id)
- .subquery()
+ select(cls.dag_id, func.max(cls.created_at).label("max_created")).group_by(cls.dag_id).subquery()
)
serialized_dags = session.scalars(
select(cls).join(
diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py
index 21e20b32b9627..640e9f693641d 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -280,16 +280,14 @@ def clear_task_instances(
for instance in tis:
run_ids_by_dag_id[instance.dag_id].add(instance.run_id)
- drs = (
- session.query(DagRun)
- .filter(
+ drs = session.scalars(
+ select(DagRun).where(
or_(
and_(DagRun.dag_id == dag_id, DagRun.run_id.in_(run_ids))
for dag_id, run_ids in run_ids_by_dag_id.items()
)
)
- .all()
- )
+ ).all()
dag_run_state = DagRunState(dag_run_state) # Validate the state value.
for dr in drs:
if dr.state in State.finished_dr_states:
@@ -804,22 +802,22 @@ def get_task_instance(
session: Session = NEW_SESSION,
) -> TaskInstance | None:
query = (
- session.query(TaskInstance)
- .options(lazyload(TaskInstance.dag_run)) # lazy load dag run to avoid locking it
- .filter_by(
- dag_id=dag_id,
- run_id=run_id,
- task_id=task_id,
- map_index=map_index,
+ select(TaskInstance)
+ .options(lazyload(TaskInstance.dag_run))
+ .where(
+ TaskInstance.dag_id == dag_id,
+ TaskInstance.run_id == run_id,
+ TaskInstance.task_id == task_id,
+ TaskInstance.map_index == map_index,
)
)
if lock_for_update:
for attempt in run_with_db_retries(logger=cls.logger()):
with attempt:
- return query.with_for_update().one_or_none()
+ return session.scalars(query.with_for_update()).one_or_none()
else:
- return query.one_or_none()
+ return session.scalars(query).one_or_none()
return None
@@ -969,13 +967,15 @@ def are_dependents_done(self, session: Session = NEW_SESSION) -> bool:
if not task.downstream_task_ids:
return True
- ti = session.query(func.count(TaskInstance.task_id)).filter(
- TaskInstance.dag_id == self.dag_id,
- TaskInstance.task_id.in_(task.downstream_task_ids),
- TaskInstance.run_id == self.run_id,
- TaskInstance.state.in_((TaskInstanceState.SKIPPED, TaskInstanceState.SUCCESS)),
+ ti = session.execute(
+ select(func.count(TaskInstance.task_id)).where(
+ TaskInstance.dag_id == self.dag_id,
+ TaskInstance.task_id.in_(task.downstream_task_ids),
+ TaskInstance.run_id == self.run_id,
+ TaskInstance.state.in_((TaskInstanceState.SKIPPED, TaskInstanceState.SUCCESS)),
+ )
)
- count = ti[0][0]
+ count = ti.scalar()
return count == len(task.downstream_task_ids)
@provide_session
@@ -1157,7 +1157,7 @@ def ready_for_retry(self) -> bool:
def _get_dagrun(dag_id, run_id, session) -> DagRun:
from airflow.models.dagrun import DagRun # Avoid circular import
- dr = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == run_id).one()
+ dr = session.scalars(select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id)).one()
return dr
@provide_session
@@ -2209,16 +2209,19 @@ def xcom_pull(
def get_num_running_task_instances(self, session: Session, same_dagrun: bool = False) -> int:
"""Return Number of running TIs from the DB."""
# .count() is inefficient
- num_running_task_instances_query = session.query(func.count()).filter(
- TaskInstance.dag_id == self.dag_id,
- TaskInstance.task_id == self.task_id,
- TaskInstance.state == TaskInstanceState.RUNNING,
+
+ num_running_task_instances_query = select(func.count()).where(
+ TaskInstance.dag_id == self.dag_id,
+ TaskInstance.task_id == self.task_id,
+ TaskInstance.state == TaskInstanceState.RUNNING,
+ )
if same_dagrun:
- num_running_task_instances_query = num_running_task_instances_query.filter(
+ num_running_task_instances_query = num_running_task_instances_query.where(
TaskInstance.run_id == self.run_id
)
- return num_running_task_instances_query.scalar()
+ return session.execute(num_running_task_instances_query).scalar()
@staticmethod
def filter_for_tis(tis: Iterable[TaskInstance | TaskInstanceKey]) -> BooleanClauseList | None:
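
A note on the pattern used throughout this file: `session.execute()` returns `Row` objects while `session.scalars()` returns ORM instances, so the two are not interchangeable when the caller expects a model object. A sketch of the distinction, assuming the mapped `TaskInstance` class:

    stmt = select(TaskInstance).where(TaskInstance.task_id == "example")

    row = session.execute(stmt).one_or_none()        # Row | None; the instance is row[0]
    ti = session.scalars(stmt).one_or_none()         # TaskInstance | None
    ti = session.execute(stmt).scalar_one_or_none()  # equivalent one-step alternative
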
diff --git a/airflow-core/src/airflow/models/xcom.py b/airflow-core/src/airflow/models/xcom.py
index 5efb9414d1a78..511e51de61093 100644
--- a/airflow-core/src/airflow/models/xcom.py
+++ b/airflow-core/src/airflow/models/xcom.py
@@ -142,11 +142,11 @@ def clear(
if not run_id:
raise ValueError(f"run_id must be passed. Passed run_id={run_id}")
- query = session.query(cls).filter_by(dag_id=dag_id, task_id=task_id, run_id=run_id)
+ query = select(cls).where(cls.dag_id == dag_id, cls.task_id == task_id, cls.run_id == run_id)
if map_index is not None:
- query = query.filter_by(map_index=map_index)
+ query = query.where(cls.map_index == map_index)
- for xcom in query:
+ for xcom in session.scalars(query):
# print(f"Clearing XCOM {xcom} with value {xcom.value}")
session.delete(xcom)
@@ -186,7 +186,9 @@ def set(
if not run_id:
raise ValueError(f"run_id must be passed. Passed run_id={run_id}")
- dag_run_id = session.query(DagRun.id).filter_by(dag_id=dag_id, run_id=run_id).scalar()
+ dag_run_id = session.execute(
+ select(DagRun.id).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id)
+ ).scalar()
if dag_run_id is None:
raise ValueError(f"DAG run not found on DAG {dag_id!r} with ID {run_id!r}")
@@ -287,42 +289,43 @@ def get_many(
if not run_id:
raise ValueError(f"run_id must be passed. Passed run_id={run_id}")
- query = session.query(cls).join(XComModel.dag_run)
+ query = select(cls).join(XComModel.dag_run)
if key:
- query = query.filter(XComModel.key == key)
+ query = query.where(XComModel.key == key)
if is_container(task_ids):
- query = query.filter(cls.task_id.in_(task_ids))
+ query = query.where(cls.task_id.in_(task_ids))
elif task_ids is not None:
- query = query.filter(cls.task_id == task_ids)
+ query = query.where(cls.task_id == task_ids)
if is_container(dag_ids):
- query = query.filter(cls.dag_id.in_(dag_ids))
+ query = query.where(cls.dag_id.in_(dag_ids))
elif dag_ids is not None:
- query = query.filter(cls.dag_id == dag_ids)
+ query = query.where(cls.dag_id == dag_ids)
if isinstance(map_indexes, range) and map_indexes.step == 1:
- query = query.filter(cls.map_index >= map_indexes.start, cls.map_index < map_indexes.stop)
+ query = query.where(cls.map_index >= map_indexes.start, cls.map_index < map_indexes.stop)
elif is_container(map_indexes):
- query = query.filter(cls.map_index.in_(map_indexes))
+ query = query.where(cls.map_index.in_(map_indexes))
elif map_indexes is not None:
- query = query.filter(cls.map_index == map_indexes)
+ query = query.where(cls.map_index == map_indexes)
if include_prior_dates:
dr = (
- session.query(
+ select(
func.coalesce(DagRun.logical_date, DagRun.run_after).label("logical_date_or_run_after")
)
- .filter(DagRun.run_id == run_id)
+ .where(DagRun.run_id == run_id)
.subquery()
)
- query = query.filter(
+ query = query.where(
func.coalesce(DagRun.logical_date, DagRun.run_after) <= dr.c.logical_date_or_run_after
)
else:
- query = query.filter(cls.run_id == run_id)
+ query = query.where(cls.run_id == run_id)
query = query.order_by(DagRun.logical_date.desc(), cls.timestamp.desc())
if limit:
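
Unlike a legacy `Query`, a bare `select()` statement is not iterable; it must be executed through the session, which is why `clear()` wraps its statement in `session.scalars()`. A sketch of the delete-loop pattern, assuming the model above:

    stmt = select(XComModel).where(XComModel.dag_id == dag_id, XComModel.run_id == run_id)
    for xcom in session.scalars(stmt):  # yields XComModel instances, not Row tuples
        session.delete(xcom)
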
diff --git a/airflow-core/src/airflow/utils/log/file_task_handler.py b/airflow-core/src/airflow/utils/log/file_task_handler.py
index 15339c00cbb69..a7dda96426125 100644
--- a/airflow-core/src/airflow/utils/log/file_task_handler.py
+++ b/airflow-core/src/airflow/utils/log/file_task_handler.py
@@ -32,6 +32,7 @@
import pendulum
from pydantic import BaseModel, ConfigDict, ValidationError
+from sqlalchemy import select
from airflow.configuration import conf
from airflow.executors.executor_loader import ExecutorLoader
@@ -179,6 +180,30 @@ def _interleave_logs(*logs: str | LogMessages) -> Iterable[StructuredLogMessage]
last = msg
+def _ensure_ti(ti: TaskInstanceKey | TaskInstance, session) -> TaskInstance:
+ """
+ Given TI | TIKey, return a TI object.
+
+ Will raise exception if no TI is found in the database.
+ """
+ from airflow.models.taskinstance import TaskInstance
+
+ if isinstance(ti, TaskInstance):
+ return ti
+ val = session.scalars(
+ select(TaskInstance).where(
+ TaskInstance.task_id == ti.task_id,
+ TaskInstance.dag_id == ti.dag_id,
+ TaskInstance.run_id == ti.run_id,
+ TaskInstance.map_index == ti.map_index,
+ )
+ ).one_or_none()
+ if not val:
+ raise AirflowException(f"Could not find TaskInstance for {ti}")
+ val.try_number = ti.try_number
+ return val
+
+
class FileTaskHandler(logging.Handler):
"""
FileTaskHandler is a python log handler that handles and reads task instance logs.
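
A hedged usage sketch for `_ensure_ti`, assuming a `TaskInstanceKey` named `ti_key` and an open session (`handler` and the downstream call are illustrative only); copying `try_number` from the key preserves the attempt being read, since the database row always holds the latest try:

    ti = _ensure_ti(ti_key, session)  # TaskInstance, or raises AirflowException
    logs = handler.read(ti, try_number=ti.try_number)  # hypothetical downstream use
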
diff --git a/contributing-docs/08_static_code_checks.rst b/contributing-docs/08_static_code_checks.rst
index d1b8a7aa4ee62..5646ff06b0f03 100644
--- a/contributing-docs/08_static_code_checks.rst
+++ b/contributing-docs/08_static_code_checks.rst
@@ -352,6 +352,8 @@ require Breeze Docker image to be built locally.
+-----------------------------------------------------------+--------------------------------------------------------+---------+
| pretty-format-json | Format JSON files | |
+-----------------------------------------------------------+--------------------------------------------------------+---------+
+| prevent-usage-of-session.query | Prevent usage of session.query | |
++-----------------------------------------------------------+--------------------------------------------------------+---------+
| pylint | pylint | |
+-----------------------------------------------------------+--------------------------------------------------------+---------+
| python-no-log-warn | Check if there are no deprecate log warn | |
diff --git a/dev/breeze/doc/images/output_static-checks.svg b/dev/breeze/doc/images/output_static-checks.svg
index beaf498a9e529..18dff1a3f0ca2 100644
--- a/dev/breeze/doc/images/output_static-checks.svg
+++ b/dev/breeze/doc/images/output_static-checks.svg
@@ -414,3 +415,68 @@
-│--verbose-vPrint verbose information about performed steps.│
-│--help-hShow this message and exit.│
-╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
+│check-get-lineage-collector-providers | check-hatch-build-order | │
+│check-hooks-apply | check-imports-in-providers | │
+│check-incorrect-use-of-LoggingMixin | check-init-decorator-arguments | │
+│check-integrations-list-consistent | check-lazy-logging | │
+│check-links-to-example-dags-do-not-use-hardcoded-versions | check-merge-conflict │
+│| check-min-python-version | check-newsfragments-are-valid | │
+│check-no-airflow-deprecation-in-providers | check-no-providers-in-core-examples |│
+│check-only-new-session-with-provide-session | │
+│check-persist-credentials-disabled-in-github-workflows | │
+│check-pre-commit-information-consistent | check-provide-create-sessions-imports |│
+│check-provider-docs-valid | check-provider-yaml-valid | │
+│check-providers-subpackages-init-file-exist | check-pydevd-left-in-code | │
+│check-pyproject-toml-consistency | check-revision-heads-map | │
+│check-safe-filter-usage-in-html | check-significant-newsfragments-are-valid | │
+│check-sql-dependency-common-data-structure | │
+│check-start-date-not-used-in-defaults | check-system-tests-present | │
+│check-system-tests-tocs | check-taskinstance-tis-attrs | │
+│check-template-context-variable-in-sync | check-template-fields-valid | │
+│check-tests-in-the-right-folders | check-tests-unittest-testcase | │
+│check-urlparse-usage-in-code | check-xml | check-zip-file-is-not-committed | │
+│codespell | compile-fab-assets | compile-ui-assets | compile-ui-assets-dev | │
+│create-missing-init-py-files-tests | debug-statements | detect-private-key | │
+│doctoc | end-of-file-fixer | fix-encoding-pragma | flynt | │
+│generate-airflow-diagrams | generate-openapi-spec | generate-pypi-readme | │
+│generate-tasksdk-datamodels | generate-volumes-for-sources | identity | │
+│insert-license | kubeconform | lint-chart-schema | lint-dockerfile | │
+│lint-helm-chart | lint-json-schema | lint-markdown | mixed-line-ending | │
+│mypy-airflow | mypy-dev | mypy-docs | mypy-providers | mypy-task-sdk | │
+│pretty-format-json | prevent-usage-of-session.query | pylint | python-no-log-warn│
+│| replace-bad-characters | rst-backticks | ruff | ruff-format | shellcheck | │
+│trailing-whitespace | ts-compile-format-lint-ui | update-black-version | │
+│update-breeze-cmd-output | update-breeze-readme-config-hash | │
+│update-chart-dependencies | update-er-diagram | update-extras | │
+│update-in-the-wild-to-be-sorted | update-inlined-dockerfile-scripts | │
+│update-installed-providers-to-be-sorted | update-installers-and-pre-commit | │
+│update-local-yml-file | update-migration-references | │
+│update-providers-build-files | update-providers-dependencies | │
+│update-reproducible-source-date-epoch | update-spelling-wordlist-to-be-sorted | │
+│update-supported-versions | update-vendored-in-k8s-json-schema | update-version |│
+│validate-operators-init | yamllint | zizmor) │
+│--show-diff-on-failure-sShow diff for files modified by the checks.│
+│--initialize-environmentInitialize environment before running checks.│
+│--max-initialization-attemptsMaximum number of attempts to initialize environment before giving up.│
+│(INTEGER RANGE) │
+│[default: 3; 1<=x<=10] │
+╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
+╭─ Selecting files to run the checks on ───────────────────────────────────────────────────────────────────────────────╮
+│--file-fList of files to run the checks on.(PATH)│
+│--all-files-aRun checks on all files.│
+│--commit-ref-rRun checks for this commit reference only (can be any git commit-ish reference). Mutually │
+│exclusive with --last-commit. │
+│(TEXT) │
+│--last-commit-cRun checks for all files in last commit. Mutually exclusive with --commit-ref.│
+│--only-my-changes-mRun checks for commits belonging to my PR only: for all commits between merge base to `main` │
+│branch and HEAD of your branch. │
+╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
+╭─ Building image before running checks ───────────────────────────────────────────────────────────────────────────────╮
+│--skip-image-upgrade-checkSkip checking if the CI image is up to date.│
+│--force-buildForce image build no matter if it is determined as needed.│
+│--github-repository-gGitHub repository used to pull, push run images.(TEXT)[default: apache/airflow]│
+│--builderBuildx builder used to perform `docker buildx build` commands.(TEXT)│
+│[default: autodetect] │
+╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
+╭─ Common options ─────────────────────────────────────────────────────────────────────────────────────────────────────╮
+│--dry-run-DIf dry-run is set, commands are only printed, not executed.│
+│--verbose-vPrint verbose information about performed steps.│
+│--help-hShow this message and exit.│
+╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
diff --git a/dev/breeze/src/airflow_breeze/pre_commit_ids.py b/dev/breeze/src/airflow_breeze/pre_commit_ids.py
index de0227170d7fd..5036f648901ab 100644
--- a/dev/breeze/src/airflow_breeze/pre_commit_ids.py
+++ b/dev/breeze/src/airflow_breeze/pre_commit_ids.py
@@ -128,6 +128,7 @@
"mypy-providers",
"mypy-task-sdk",
"pretty-format-json",
+ "prevent-usage-of-session.query",
"pylint",
"python-no-log-warn",
"replace-bad-characters",
diff --git a/scripts/ci/pre_commit/usage_session_query.py b/scripts/ci/pre_commit/usage_session_query.py
new file mode 100755
index 0000000000000..a6357978426e0
--- /dev/null
+++ b/scripts/ci/pre_commit/usage_session_query.py
@@ -0,0 +1,56 @@
+#!/usr/bin/env python
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import ast
+import sys
+from pathlib import Path
+
+from rich.console import Console
+
+console = Console(color_system="standard", width=200)
+
+
+def check_session_query(mod: ast.Module, file: str) -> bool:
+ errors = False
+ for node in ast.walk(mod):
+ if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute):
+ if (
+ node.func.attr == "query"
+ and isinstance(node.func.value, ast.Name)
+ and node.func.value.id == "session"
+ ):
+ console.print(
+ f"\nUse of legacy `session.query` detected on line {node.lineno}. "
+ f"\nSQLAlchemy 2.0 deprecates the `Query` object"
+ f"use the `select()` construct instead."
+ )
+ errors = True
+ return errors
+
+
+def main() -> int:
+ errors = False
+ for file in sys.argv[1:]:
+ file_path = Path(file)
+ ast_module = ast.parse(file_path.read_text(encoding="utf-8"), file)
+ errors |= check_session_query(ast_module, file)
+ return 1 if errors else 0
+
+
+if __name__ == "__main__":
+ sys.exit(main())
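
A quick sanity check of the detector, assuming `check_session_query` is importable (e.g. from a test module beside the script):

    import ast

    flagged = "def f(session):\n    return session.query(User).all()\n"
    clean = "def f(session):\n    return session.scalars(select(User)).all()\n"

    assert check_session_query(ast.parse(flagged), "flagged.py") is True
    assert check_session_query(ast.parse(clean), "clean.py") is False
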