diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 995e255cf363d..a1cd1440d4bd0 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -227,6 +227,14 @@ repos:
         files: ^airflow-core/src/airflow/models/taskinstance\.py$|^airflow-core/src/airflow/models/taskinstancehistory\.py$
         pass_filenames: false
         require_serial: true
+      - id: prevent-usage-of-session.query
+        name: Prevent usage of session.query
+        entry: ./scripts/ci/pre_commit/usage_session_query.py
+        language: python
+        additional_dependencies: ['rich>=12.4.4']
+        files: ^airflow.*\.py$|^task_sdk.*\.py$
+        exclude: ^task_sdk/tests/.*\.py$
+        pass_filenames: true
       - id: check-deferrable-default
         name: Check and fix default value of default_deferrable
         language: python
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index e22d0a5f34d0d..0d4b6808868f7 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -246,11 +246,9 @@ def ti_run(
     xcom_keys = list(session.scalars(query))
 
     task_reschedule_count = (
-        session.query(
-            func.count(TaskReschedule.id)  # or any other primary key column
-        )
-        .filter(TaskReschedule.ti_id == ti_id_str)
-        .scalar()
+        session.execute(
+            select(func.count(TaskReschedule.id)).where(TaskReschedule.ti_id == ti_id_str)
+        ).scalar()
         or 0
     )
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
index a9ed4a5b48d20..0fbd091606f98 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
@@ -400,7 +400,7 @@ def set_xcom(
     if not run_id:
         raise HTTPException(status.HTTP_404_NOT_FOUND, f"Run with ID: `{run_id}` was not found")
 
-    dag_run_id = session.query(DagRun.id).filter_by(dag_id=dag_id, run_id=run_id).scalar()
+    dag_run_id = session.scalar(select(DagRun.id).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id))
     if dag_run_id is None:
         raise HTTPException(status.HTTP_404_NOT_FOUND, f"DAG run not found on DAG {dag_id} with ID {run_id}")
diff --git a/airflow-core/src/airflow/dag_processing/bundles/manager.py b/airflow-core/src/airflow/dag_processing/bundles/manager.py
index a3538f1e29191..d465d5724c5a3 100644
--- a/airflow-core/src/airflow/dag_processing/bundles/manager.py
+++ b/airflow-core/src/airflow/dag_processing/bundles/manager.py
@@ -19,6 +19,7 @@
 from typing import TYPE_CHECKING
 
 from sqlalchemy import delete
+from sqlalchemy import select
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowConfigException
@@ -124,8 +125,9 @@ def parse_config(self) -> None:
     @provide_session
     def sync_bundles_to_db(self, *, session: Session = NEW_SESSION) -> None:
         self.log.debug("Syncing DAG bundles to the database")
-        stored = {b.name: b for b in session.query(DagBundleModel).all()}
-        for name in self._bundle_config.keys():
+        stored = {b.name: b for b in session.scalars(select(DagBundleModel)).all()}
+        active_bundle_names = set(self._bundle_config.keys())
+        for name in active_bundle_names:
             if bundle := stored.pop(name, None):
                 bundle.active = True
             else:
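All of the hunks above apply the same migration: legacy `session.query(...)` calls (the SQLAlchemy 1.x `Query` API) are rewritten as 2.0-style `select()` statements handed to `session.execute()`, `session.scalars()`, or `session.scalar()`. A minimal before/after sketch of the pattern, using an illustrative `MyModel` mapped class that is not part of this PR:

    from sqlalchemy import func, select

    # Legacy 1.x style, which the new pre-commit hook rejects:
    #   running = session.query(MyModel).filter(MyModel.state == "running").all()
    #   total = session.query(func.count(MyModel.id)).scalar()

    # SQLAlchemy 2.0 style: build a statement, then hand it to the session.
    running = session.scalars(select(MyModel).where(MyModel.state == "running")).all()

    # For a single aggregate value, session.scalar() returns it directly.
    total = session.scalar(select(func.count(MyModel.id)))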
diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
index 0cc71bf19b482..71d404832c54e 100644
--- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
+++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -2028,9 +2028,8 @@ def _get_num_times_stuck_in_queued(self, ti: TaskInstance, session: Session = NE
         We can then use this information to determine whether to reschedule a task or fail it.
         """
-        return (
-            session.query(Log)
-            .where(
+        return session.execute(
+            select(func.count(Log.id)).where(
                 Log.task_id == ti.task_id,
                 Log.dag_id == ti.dag_id,
                 Log.run_id == ti.run_id,
@@ -2038,24 +2037,22 @@ def _get_num_times_stuck_in_queued(self, ti: TaskInstance, session: Session = NE
                 Log.try_number == ti.try_number,
                 Log.event == TASK_STUCK_IN_QUEUED_RESCHEDULE_EVENT,
             )
-            .count()
-        )
+        ).scalar()
 
     previous_ti_running_metrics: dict[tuple[str, str, str], int] = {}
 
     @provide_session
     def _emit_running_ti_metrics(self, session: Session = NEW_SESSION) -> None:
-        running = (
-            session.query(
+        running = session.execute(
+            select(
                 TaskInstance.dag_id,
                 TaskInstance.task_id,
                 TaskInstance.queue,
                 func.count(TaskInstance.task_id).label("running_count"),
             )
-            .filter(TaskInstance.state == State.RUNNING)
+            .where(TaskInstance.state == State.RUNNING)
             .group_by(TaskInstance.dag_id, TaskInstance.task_id, TaskInstance.queue)
-            .all()
-        )
+        ).all()
 
         ti_running_metrics = {(row.dag_id, row.task_id, row.queue): row.running_count for row in running}
diff --git a/airflow-core/src/airflow/migrations/versions/0015_2_9_0_update_trigger_kwargs_type.py b/airflow-core/src/airflow/migrations/versions/0015_2_9_0_update_trigger_kwargs_type.py
index 08843afb0e3b2..d929390a53b3f 100644
--- a/airflow-core/src/airflow/migrations/versions/0015_2_9_0_update_trigger_kwargs_type.py
+++ b/airflow-core/src/airflow/migrations/versions/0015_2_9_0_update_trigger_kwargs_type.py
@@ -60,7 +60,7 @@ def upgrade():
     if not context.is_offline_mode():
         session = get_session()
         try:
-            for trigger in session.query(Trigger).options(lazyload(Trigger.task_instance)):
+            for trigger in session.scalars(select(Trigger).options(lazyload(Trigger.task_instance))):
                 trigger.kwargs = trigger.kwargs
             session.commit()
         finally:
@@ -81,7 +81,7 @@ def downgrade():
     else:
         session = get_session()
         try:
-            for trigger in session.query(Trigger).options(lazyload(Trigger.task_instance)):
+            for trigger in session.scalars(select(Trigger).options(lazyload(Trigger.task_instance))):
                 trigger.encrypted_kwargs = json.dumps(BaseSerialization.serialize(trigger.kwargs))
             session.commit()
         finally:
diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py
index 8dcd00fef0cbc..653cd7a1e66b2 100644
--- a/airflow-core/src/airflow/models/dagrun.py
+++ b/airflow-core/src/airflow/models/dagrun.py
@@ -788,10 +788,13 @@ def fetch_task_instances(
     def _check_last_n_dagruns_failed(self, dag_id, max_consecutive_failed_dag_runs, session):
         """Check if last N dags failed."""
         dag_runs = (
-            session.query(DagRun)
-            .filter(DagRun.dag_id == dag_id)
-            .order_by(DagRun.logical_date.desc())
-            .limit(max_consecutive_failed_dag_runs)
+            session.execute(
+                select(DagRun)
+                .where(DagRun.dag_id == dag_id)
+                .order_by(DagRun.logical_date.desc())
+                .limit(max_consecutive_failed_dag_runs)
+            )
+            .scalars()
             .all()
         )
         """ Marking dag as paused, if needed"""
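The `_emit_running_ti_metrics` hunk depends on `session.execute(...).all()` returning named `Row` objects, so the existing `row.dag_id` / `row.running_count` attribute access keeps working after the migration. A short sketch of that behavior (a simplified query, not the scheduler's exact one):

    from sqlalchemy import func, select

    stmt = (
        select(TaskInstance.dag_id, func.count(TaskInstance.task_id).label("running_count"))
        .where(TaskInstance.state == State.RUNNING)
        .group_by(TaskInstance.dag_id)
    )
    for row in session.execute(stmt).all():
        # .label() names become attributes on each Row
        print(row.dag_id, row.running_count)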
diff --git a/airflow-core/src/airflow/models/serialized_dag.py b/airflow-core/src/airflow/models/serialized_dag.py
index 71722b54adee6..ab15ebe21c905 100644
--- a/airflow-core/src/airflow/models/serialized_dag.py
+++ b/airflow-core/src/airflow/models/serialized_dag.py
@@ -479,8 +479,8 @@ def get_latest_serialized_dags(
         """
         # Subquery to get the latest serdag per dag_id
         latest_serdag_subquery = (
-            session.query(cls.dag_id, func.max(cls.created_at).label("created_at"))
-            .filter(cls.dag_id.in_(dag_ids))
+            select(cls.dag_id, func.max(cls.created_at).label("created_at"))
+            .where(cls.dag_id.in_(dag_ids))
             .group_by(cls.dag_id)
             .subquery()
         )
@@ -504,9 +504,7 @@ def read_all_dags(cls, session: Session = NEW_SESSION) -> dict[str, SerializedDA
         :returns: a dict of DAGs read from database
         """
         latest_serialized_dag_subquery = (
-            session.query(cls.dag_id, func.max(cls.created_at).label("max_created"))
-            .group_by(cls.dag_id)
-            .subquery()
+            select(cls.dag_id, func.max(cls.created_at).label("max_created")).group_by(cls.dag_id).subquery()
         )
         serialized_dags = session.scalars(
             select(cls).join(
diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py
index 21e20b32b9627..640e9f693641d 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -280,16 +280,14 @@ def clear_task_instances(
     for instance in tis:
         run_ids_by_dag_id[instance.dag_id].add(instance.run_id)
 
-    drs = (
-        session.query(DagRun)
-        .filter(
+    drs = session.scalars(
+        select(DagRun).where(
             or_(
                 and_(DagRun.dag_id == dag_id, DagRun.run_id.in_(run_ids))
                 for dag_id, run_ids in run_ids_by_dag_id.items()
             )
         )
-        .all()
-    )
+    ).all()
     dag_run_state = DagRunState(dag_run_state)  # Validate the state value.
     for dr in drs:
         if dr.state in State.finished_dr_states:
@@ -804,22 +802,22 @@ def get_task_instance(
         session: Session = NEW_SESSION,
     ) -> TaskInstance | None:
         query = (
-            session.query(TaskInstance)
-            .options(lazyload(TaskInstance.dag_run))  # lazy load dag run to avoid locking it
-            .filter_by(
-                dag_id=dag_id,
-                run_id=run_id,
-                task_id=task_id,
-                map_index=map_index,
+            select(TaskInstance)
+            .options(lazyload(TaskInstance.dag_run))
+            .where(
+                TaskInstance.dag_id == dag_id,
+                TaskInstance.run_id == run_id,
+                TaskInstance.task_id == task_id,
+                TaskInstance.map_index == map_index,
             )
         )
 
         if lock_for_update:
             for attempt in run_with_db_retries(logger=cls.logger()):
                 with attempt:
-                    return query.with_for_update().one_or_none()
+                    return session.scalars(query.with_for_update()).one_or_none()
         else:
-            return query.one_or_none()
+            return session.scalars(query).one_or_none()
 
         return None
@@ -969,13 +967,15 @@ def are_dependents_done(self, session: Session = NEW_SESSION) -> bool:
         if not task.downstream_task_ids:
             return True
 
-        ti = session.query(func.count(TaskInstance.task_id)).filter(
-            TaskInstance.dag_id == self.dag_id,
-            TaskInstance.task_id.in_(task.downstream_task_ids),
-            TaskInstance.run_id == self.run_id,
-            TaskInstance.state.in_((TaskInstanceState.SKIPPED, TaskInstanceState.SUCCESS)),
+        ti = session.execute(
+            select(func.count(TaskInstance.task_id)).where(
+                TaskInstance.dag_id == self.dag_id,
+                TaskInstance.task_id.in_(task.downstream_task_ids),
+                TaskInstance.run_id == self.run_id,
+                TaskInstance.state.in_((TaskInstanceState.SKIPPED, TaskInstanceState.SUCCESS)),
+            )
         )
-        count = ti[0][0]
+        count = ti.scalar()
         return count == len(task.downstream_task_ids)
@@ -1157,7 +1157,7 @@ def ready_for_retry(self) -> bool:
     def _get_dagrun(dag_id, run_id, session) -> DagRun:
         from airflow.models.dagrun import DagRun  # Avoid circular import
 
-        dr = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == run_id).one()
+        dr = session.scalars(select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id)).one()
         return dr
 
     @provide_session
@@ -2209,16 +2209,16 @@ def xcom_pull(
     def get_num_running_task_instances(self, session: Session, same_dagrun: bool = False) -> int:
         """Return Number of running TIs from the DB."""
         # .count() is inefficient
-        num_running_task_instances_query = session.query(func.count()).filter(
-            TaskInstance.dag_id == self.dag_id,
-            TaskInstance.task_id == self.task_id,
-            TaskInstance.state == TaskInstanceState.RUNNING,
+        num_running_task_instances_query = select(func.count()).where(
+            TaskInstance.dag_id == self.dag_id,
+            TaskInstance.task_id == self.task_id,
+            TaskInstance.state == TaskInstanceState.RUNNING,
         )
         if same_dagrun:
-            num_running_task_instances_query = num_running_task_instances_query.filter(
+            num_running_task_instances_query = num_running_task_instances_query.where(
                 TaskInstance.run_id == self.run_id
             )
-        return num_running_task_instances_query.scalar()
+        return session.execute(num_running_task_instances_query).scalar()
 
     @staticmethod
     def filter_for_tis(tis: Iterable[TaskInstance | TaskInstanceKey]) -> BooleanClauseList | None:
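The `get_task_instance` and `_get_dagrun` hunks above use `session.scalars(...)` rather than `session.execute(...)` because the two return different things for entity selects; mixing them up is an easy regression when migrating off `Query`. A sketch of the distinction (standard SQLAlchemy 2.0 calls, illustrative variables):

    from sqlalchemy import select

    stmt = select(TaskInstance).where(TaskInstance.run_id == run_id)

    row = session.execute(stmt).one_or_none()  # Row | None; the entity is row[0]
    ti = session.scalars(stmt).one_or_none()   # TaskInstance | None, already unwrapped

Callers annotated as returning `TaskInstance | None` need the `scalars()` form (or `session.execute(stmt).scalar_one_or_none()`).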
diff --git a/airflow-core/src/airflow/models/xcom.py b/airflow-core/src/airflow/models/xcom.py
index 5efb9414d1a78..511e51de61093 100644
--- a/airflow-core/src/airflow/models/xcom.py
+++ b/airflow-core/src/airflow/models/xcom.py
@@ -142,11 +142,10 @@ def clear(
         if not run_id:
             raise ValueError(f"run_id must be passed. Passed run_id={run_id}")
 
-        query = session.query(cls).filter_by(dag_id=dag_id, task_id=task_id, run_id=run_id)
+        query = select(cls).where(cls.dag_id == dag_id, cls.task_id == task_id, cls.run_id == run_id)
         if map_index is not None:
-            query = query.filter_by(map_index=map_index)
+            query = query.where(cls.map_index == map_index)
 
-        for xcom in query:
-            # print(f"Clearing XCOM {xcom} with value {xcom.value}")
+        for xcom in session.scalars(query):
             session.delete(xcom)
@@ -186,7 +186,9 @@ def set(
         if not run_id:
             raise ValueError(f"run_id must be passed. Passed run_id={run_id}")
 
-        dag_run_id = session.query(DagRun.id).filter_by(dag_id=dag_id, run_id=run_id).scalar()
+        dag_run_id = session.execute(
+            select(DagRun.id).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id)
+        ).scalar()
         if dag_run_id is None:
             raise ValueError(f"DAG run not found on DAG {dag_id!r} with ID {run_id!r}")
@@ -287,42 +289,42 @@ def get_many(
         if not run_id:
             raise ValueError(f"run_id must be passed. Passed run_id={run_id}")
 
-        query = session.query(cls).join(XComModel.dag_run)
+        query = select(cls).join(XComModel.dag_run)
 
         if key:
-            query = query.filter(XComModel.key == key)
+            query = query.where(XComModel.key == key)
 
         if is_container(task_ids):
-            query = query.filter(cls.task_id.in_(task_ids))
+            query = query.where(cls.task_id.in_(task_ids))
         elif task_ids is not None:
-            query = query.filter(cls.task_id == task_ids)
+            query = query.where(cls.task_id == task_ids)
 
         if is_container(dag_ids):
-            query = query.filter(cls.dag_id.in_(dag_ids))
+            query = query.where(cls.dag_id.in_(dag_ids))
         elif dag_ids is not None:
-            query = query.filter(cls.dag_id == dag_ids)
+            query = query.where(cls.dag_id == dag_ids)
 
         if isinstance(map_indexes, range) and map_indexes.step == 1:
-            query = query.filter(cls.map_index >= map_indexes.start, cls.map_index < map_indexes.stop)
+            query = query.where(cls.map_index >= map_indexes.start, cls.map_index < map_indexes.stop)
         elif is_container(map_indexes):
-            query = query.filter(cls.map_index.in_(map_indexes))
+            query = query.where(cls.map_index.in_(map_indexes))
         elif map_indexes is not None:
-            query = query.filter(cls.map_index == map_indexes)
+            query = query.where(cls.map_index == map_indexes)
 
         if include_prior_dates:
             dr = (
-                session.query(
+                select(
                     func.coalesce(DagRun.logical_date, DagRun.run_after).label("logical_date_or_run_after")
                 )
-                .filter(DagRun.run_id == run_id)
+                .where(DagRun.run_id == run_id)
                 .subquery()
             )
-            query = query.filter(
+            query = query.where(
                 func.coalesce(DagRun.logical_date, DagRun.run_after) <= dr.c.logical_date_or_run_after
             )
         else:
-            query = query.filter(cls.run_id == run_id)
+            query = query.where(cls.run_id == run_id)
 
         query = query.order_by(DagRun.logical_date.desc(), cls.timestamp.desc())
         if limit:
diff --git a/airflow-core/src/airflow/utils/log/file_task_handler.py b/airflow-core/src/airflow/utils/log/file_task_handler.py
index 15339c00cbb69..a7dda96426125 100644
--- a/airflow-core/src/airflow/utils/log/file_task_handler.py
+++ b/airflow-core/src/airflow/utils/log/file_task_handler.py
@@ -32,6 +32,7 @@
 import pendulum
 from pydantic import BaseModel, ConfigDict, ValidationError
+from sqlalchemy import select
 
 from airflow.configuration import conf
 from airflow.executors.executor_loader import ExecutorLoader
@@ -179,6 +180,30 @@ def _interleave_logs(*logs: str | LogMessages) -> Iterable[StructuredLogMessage]
         last = msg
 
 
+def _ensure_ti(ti: TaskInstanceKey | TaskInstance, session) -> TaskInstance:
+    """
+    Given TI | TIKey, return a TI object.
+
+    Will raise exception if no TI is found in the database.
+    """
+    from airflow.models.taskinstance import TaskInstance
+
+    if isinstance(ti, TaskInstance):
+        return ti
+    val = session.scalars(
+        select(TaskInstance).where(
+            TaskInstance.task_id == ti.task_id,
+            TaskInstance.dag_id == ti.dag_id,
+            TaskInstance.run_id == ti.run_id,
+            TaskInstance.map_index == ti.map_index,
+        )
+    ).one_or_none()
+    if not val:
+        raise AirflowException(f"Could not find TaskInstance for {ti}")
+    val.try_number = ti.try_number
+    return val
+
+
 class FileTaskHandler(logging.Handler):
     """
     FileTaskHandler is a python log handler that handles and reads task instance logs.
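A hypothetical call site for the new `_ensure_ti` helper, to illustrate its contract; the key values here are made up:

    from airflow.models.taskinstance import TaskInstanceKey

    key = TaskInstanceKey(dag_id="example_dag", task_id="extract", run_id="manual__2025-01-01", try_number=2)
    ti = _ensure_ti(key, session)  # resolves the key to a TaskInstance row, or raises AirflowException
    assert ti.try_number == key.try_number  # try_number is copied over from the key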
diff --git a/contributing-docs/08_static_code_checks.rst b/contributing-docs/08_static_code_checks.rst
index d1b8a7aa4ee62..5646ff06b0f03 100644
--- a/contributing-docs/08_static_code_checks.rst
+++ b/contributing-docs/08_static_code_checks.rst
@@ -352,6 +352,8 @@ require Breeze Docker image to be built locally.
 +-----------------------------------------------------------+--------------------------------------------------------+---------+
 | pretty-format-json                                        | Format JSON files                                      |         |
 +-----------------------------------------------------------+--------------------------------------------------------+---------+
+| prevent-usage-of-session.query                            | Prevent usage of session.query                         |         |
++-----------------------------------------------------------+--------------------------------------------------------+---------+
 | pylint                                                    | pylint                                                 |         |
 +-----------------------------------------------------------+--------------------------------------------------------+---------+
 | python-no-log-warn                                        | Check if there are no deprecate log warn               |         |
diff --git a/dev/breeze/doc/images/output_static-checks.svg b/dev/breeze/doc/images/output_static-checks.svg
index beaf498a9e529..18dff1a3f0ca2 100644
--- a/dev/breeze/doc/images/output_static-checks.svg
+++ b/dev/breeze/doc/images/output_static-checks.svg
@@ -414,6 +414,74 @@
 --verbose-vPrint verbose information about performed steps.
 --help-hShow this message and exit.
 ╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
+check-get-lineage-collector-providers | check-hatch-build-order |
+check-hooks-apply | check-imports-in-providers |
+check-incorrect-use-of-LoggingMixin | check-init-decorator-arguments |
+check-integrations-list-consistent | check-lazy-logging |
+check-links-to-example-dags-do-not-use-hardcoded-versions | check-merge-conflict
+| check-min-python-version | check-newsfragments-are-valid |
+check-no-airflow-deprecation-in-providers | check-no-providers-in-core-examples |
+check-only-new-session-with-provide-session |
+check-persist-credentials-disabled-in-github-workflows |
+check-pre-commit-information-consistent | check-provide-create-sessions-imports |
+check-provider-docs-valid | check-provider-yaml-valid |
+check-providers-subpackages-init-file-exist | check-pydevd-left-in-code |
+check-pyproject-toml-consistency | check-revision-heads-map |
+check-safe-filter-usage-in-html | check-significant-newsfragments-are-valid |
+check-sql-dependency-common-data-structure |
+check-start-date-not-used-in-defaults | check-system-tests-present |
+check-system-tests-tocs | check-taskinstance-tis-attrs |
+check-template-context-variable-in-sync | check-template-fields-valid |
+check-tests-in-the-right-folders | check-tests-unittest-testcase |
+check-urlparse-usage-in-code | check-xml | check-zip-file-is-not-committed |
+codespell | compile-fab-assets | compile-ui-assets | compile-ui-assets-dev |
+create-missing-init-py-files-tests | debug-statements | detect-private-key |
+doctoc | end-of-file-fixer | fix-encoding-pragma | flynt |
+generate-airflow-diagrams | generate-openapi-spec | generate-pypi-readme |
+generate-tasksdk-datamodels | generate-volumes-for-sources | identity |
+insert-license | kubeconform | lint-chart-schema | lint-dockerfile |
+lint-helm-chart | lint-json-schema | lint-markdown | mixed-line-ending |
+mypy-airflow | mypy-dev | mypy-docs | mypy-providers | mypy-task-sdk |
+pretty-format-json | prevent-usage-of-session.query | pylint | python-no-log-warn
+| replace-bad-characters | rst-backticks | ruff | ruff-format | shellcheck |
+trailing-whitespace | ts-compile-format-lint-ui | update-black-version |
+update-breeze-cmd-output | update-breeze-readme-config-hash |
+update-chart-dependencies | update-er-diagram | update-extras |
+update-in-the-wild-to-be-sorted | update-inlined-dockerfile-scripts |
+update-installed-providers-to-be-sorted | update-installers-and-pre-commit |
+update-local-yml-file | update-migration-references |
+update-providers-build-files | update-providers-dependencies |
+update-reproducible-source-date-epoch | update-spelling-wordlist-to-be-sorted |
+update-supported-versions | update-vendored-in-k8s-json-schema | update-version |
+validate-operators-init | yamllint | zizmor)
+--show-diff-on-failure-sShow diff for files modified by the checks.
+--initialize-environmentInitialize environment before running checks.
+--max-initialization-attemptsMaximum number of attempts to initialize environment before giving up.
+(INTEGER RANGE)
+[default: 3; 1<=x<=10]
+╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
+╭─ Selecting files to run the checks on ───────────────────────────────────────────────────────────────────────────────╮
+--file-fList of files to run the checks on.(PATH)
+--all-files-aRun checks on all files.
+--commit-ref-rRun checks for this commit reference only (can be any git commit-ish reference). Mutually
+exclusive with --last-commit.
+(TEXT)
+--last-commit-cRun checks for all files in last commit. Mutually exclusive with --commit-ref.
+--only-my-changes-mRun checks for commits belonging to my PR only: for all commits between merge base to `main`
+branch and HEAD of your branch.
+╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
+╭─ Building image before running checks ───────────────────────────────────────────────────────────────────────────────╮
+--skip-image-upgrade-checkSkip checking if the CI image is up to date.
+--force-buildForce image build no matter if it is determined as needed.
+--github-repository-gGitHub repository used to pull, push run images.(TEXT)[default: apache/airflow]
+--builderBuildx builder used to perform `docker buildx build` commands.(TEXT)
+[default: autodetect]
+╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
+╭─ Common options ─────────────────────────────────────────────────────────────────────────────────────────────────────╮
+--dry-run-DIf dry-run is set, commands are only printed, not executed.
+--verbose-vPrint verbose information about performed steps.
+--help-hShow this message and exit.
+╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
diff --git a/dev/breeze/src/airflow_breeze/pre_commit_ids.py b/dev/breeze/src/airflow_breeze/pre_commit_ids.py
index de0227170d7fd..5036f648901ab 100644
--- a/dev/breeze/src/airflow_breeze/pre_commit_ids.py
+++ b/dev/breeze/src/airflow_breeze/pre_commit_ids.py
@@ -128,6 +128,7 @@
     "mypy-providers",
     "mypy-task-sdk",
     "pretty-format-json",
+    "prevent-usage-of-session.query",
     "pylint",
     "python-no-log-warn",
     "replace-bad-characters",
diff --git a/scripts/ci/pre_commit/usage_session_query.py b/scripts/ci/pre_commit/usage_session_query.py
new file mode 100755
index 0000000000000..a6357978426e0
--- /dev/null
+++ b/scripts/ci/pre_commit/usage_session_query.py
@@ -0,0 +1,58 @@
+#!/usr/bin/env python
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import ast
+import sys
+from pathlib import Path
+
+from rich.console import Console
+
+console = Console(color_system="standard", width=200)
+
+
+def check_session_query(mod: ast.Module) -> bool:
+    errors = False
+    for node in ast.walk(mod):
+        if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute):
+            if (
+                node.func.attr == "query"
+                and isinstance(node.func.value, ast.Name)
+                and node.func.value.id == "session"
+            ):
+                console.print(
+                    f"\nUse of legacy `session.query` detected on line {node.lineno}. "
+                    "SQLAlchemy 2.0 deprecates the `Query` object; "
+                    "use the `select()` construct instead."
+                )
+                errors = True
+    return errors
+
+
+def main() -> int:
+    errors = False
+    for file in sys.argv[1:]:
+        file_path = Path(file)
+        ast_module = ast.parse(file_path.read_text(encoding="utf-8"), file)
+        if check_session_query(ast_module):
+            errors = True
+    return 1 if errors else 0
+
+
+if __name__ == "__main__":
+    sys.exit(main())
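A quick way to sanity-check the hook's AST logic from a REPL, assuming `check_session_query` is imported from the script; the parsed snippets are illustrative:

    import ast

    good = ast.parse("tis = session.scalars(select(TaskInstance)).all()")
    bad = ast.parse("tis = session.query(TaskInstance).all()")

    assert check_session_query(good) is False
    assert check_session_query(bad) is True  # also prints the rich warning

Note that the checker only matches attribute calls on a name literally spelled `session`, so aliased sessions (for example `s.query(...)`) will not be flagged.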