From 6ffe3ac79ae4b9754eb89a1b1665756e161e9757 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Thu, 28 Mar 2024 16:00:23 +0800 Subject: [PATCH] Add pre-commit hook to sync template context vars --- airflow/models/taskinstance.py | 9 +- airflow/utils/context.py | 7 +- airflow/utils/context.pyi | 4 +- docs/apache-airflow/templates-ref.rst | 17 ++- .../pre_commit_template_context_key_sync.py | 131 ++++++++++++++++++ 5 files changed, 156 insertions(+), 12 deletions(-) create mode 100644 scripts/ci/pre_commit/pre_commit_template_context_key_sync.py diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 7619d069891db..7214a01fb0fd7 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -752,10 +752,11 @@ def get_triggering_events() -> dict[str, list[DatasetEvent | DatasetEventPydanti except NotMapped: expanded_ti_count = None - # NOTE: If you add anything to this dict, make sure to also update the - # definition in airflow/utils/context.pyi, and KNOWN_CONTEXT_KEYS in - # airflow/utils/context.py! - context = { + # NOTE: If you add to this dict, make sure to also update the following: + # * Context in airflow/utils/context.pyi + # * KNOWN_CONTEXT_KEYS in airflow/utils/context.py + # * Table in docs/apache-airflow/templates-ref.rst + context: dict[str, Any] = { "conf": conf, "dag": dag, "dag_run": dag_run, diff --git a/airflow/utils/context.py b/airflow/utils/context.py index 46ed3ef5a88e3..3501ca7dbc22a 100644 --- a/airflow/utils/context.py +++ b/airflow/utils/context.py @@ -44,8 +44,10 @@ if TYPE_CHECKING: from airflow.models.baseoperator import BaseOperator -# NOTE: Please keep this in sync with Context in airflow/utils/context.pyi. -KNOWN_CONTEXT_KEYS = { +# NOTE: Please keep this in sync with the following: +# * Context in airflow/utils/context.pyi. 
+# * Table in docs/apache-airflow/templates-ref.rst +KNOWN_CONTEXT_KEYS: set[str] = { "conf", "conn", "dag", @@ -74,6 +76,7 @@ "prev_execution_date_success", "prev_start_date_success", "prev_end_date_success", + "reason", "run_id", "task", "task_instance", diff --git a/airflow/utils/context.pyi b/airflow/utils/context.pyi index 124cd9c8c4ec6..eb08201248173 100644 --- a/airflow/utils/context.pyi +++ b/airflow/utils/context.pyi @@ -55,7 +55,9 @@ class VariableAccessor: class ConnectionAccessor: def get(self, key: str, default_conn: Any = None) -> Any: ... -# NOTE: Please keep this in sync with KNOWN_CONTEXT_KEYS in airflow/utils/context.py. +# NOTE: Please keep this in sync with the following: +# * KNOWN_CONTEXT_KEYS in airflow/utils/context.py +# * Table in docs/apache-airflow/templates-ref.rst class Context(TypedDict, total=False): conf: AirflowConfigParser conn: Any diff --git a/docs/apache-airflow/templates-ref.rst b/docs/apache-airflow/templates-ref.rst index 0929c8b0c08fa..dd05fcc831c70 100644 --- a/docs/apache-airflow/templates-ref.rst +++ b/docs/apache-airflow/templates-ref.rst @@ -38,17 +38,20 @@ Variable Type Description =========================================== ===================== =================================================================== ``{{ data_interval_start }}`` `pendulum.DateTime`_ Start of the data interval. Added in version 2.2. ``{{ data_interval_end }}`` `pendulum.DateTime`_ End of the data interval. Added in version 2.2. +``{{ logical_date }}`` `pendulum.DateTime`_ | A date-time that logically identifies the current DAG run. This value does not contain any semantics, but is simply a value for identification. + | Use ``data_interval_start`` and ``data_interval_end`` instead if you want a value that has real-world semantics, + | such as to get a slice of rows from the database based on timestamps. ``{{ ds }}`` str | The DAG run's logical date as ``YYYY-MM-DD``. - | Same as ``{{ dag_run.logical_date | ds }}``. 
-``{{ ds_nodash }}`` str Same as ``{{ dag_run.logical_date | ds_nodash }}``. + | Same as ``{{ logical_date | ds }}``. +``{{ ds_nodash }}`` str Same as ``{{ logical_date | ds_nodash }}``. ``{{ exception }}`` None | str | | Error occurred while running task instance. Exception | KeyboardInterrupt | -``{{ ts }}`` str | Same as ``{{ dag_run.logical_date | ts }}``. +``{{ ts }}`` str | Same as ``{{ logical_date | ts }}``. | Example: ``2018-01-01T00:00:00+00:00``. -``{{ ts_nodash_with_tz }}`` str | Same as ``{{ dag_run.logical_date | ts_nodash_with_tz }}``. +``{{ ts_nodash_with_tz }}`` str | Same as ``{{ logical_date | ts_nodash_with_tz }}``. | Example: ``20180101T000000+0000``. -``{{ ts_nodash }}`` str | Same as ``{{ dag_run.logical_date | ts_nodash }}``. +``{{ ts_nodash }}`` str | Same as ``{{ logical_date | ts_nodash }}``. | Example: ``20180101T000000``. ``{{ prev_data_interval_start_success }}`` `pendulum.DateTime`_ | Start of the data interval of the prior successful :class:`~airflow.models.dagrun.DagRun`. | ``None`` | Added in version 2.2. @@ -56,6 +59,10 @@ Variable Type Description | ``None`` | Added in version 2.2. ``{{ prev_start_date_success }}`` `pendulum.DateTime`_ Start date from prior successful :class:`~airflow.models.dagrun.DagRun` (if available). | ``None`` +``{{ prev_end_date_success }}`` `pendulum.DateTime`_ End date from prior successful :class:`~airflow.models.dagrun.DagRun` (if available). + | ``None`` +``{{ inlets }}`` list List of inlets declared on the task. +``{{ outlets }}`` list List of outlets declared on the task. ``{{ dag }}`` DAG The currently running :class:`~airflow.models.dag.DAG`. You can read more about DAGs in :doc:`DAGs <core-concepts/dags>`. ``{{ task }}`` BaseOperator | The currently running :class:`~airflow.models.baseoperator.BaseOperator`. You can read more about Tasks in :doc:`core-concepts/operators` ``{{ macros }}`` | A reference to the macros package. See Macros_ below. 
diff --git a/scripts/ci/pre_commit/pre_commit_template_context_key_sync.py b/scripts/ci/pre_commit/pre_commit_template_context_key_sync.py new file mode 100644 index 0000000000000..a26615f9f2d06 --- /dev/null +++ b/scripts/ci/pre_commit/pre_commit_template_context_key_sync.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
from __future__ import annotations

import ast
import pathlib
import re
import sys
import typing

ROOT_DIR = pathlib.Path(__file__).resolve().parents[3]

TASKINSTANCE_PY = ROOT_DIR.joinpath("airflow", "models", "taskinstance.py")
CONTEXT_PY = ROOT_DIR.joinpath("airflow", "utils", "context.py")
CONTEXT_PYI = ROOT_DIR.joinpath("airflow", "utils", "context.pyi")
TEMPLATES_REF_RST = ROOT_DIR.joinpath("docs", "apache-airflow", "templates-ref.rst")


def _find_annotated_assignment(
    statements: typing.Iterable[ast.stmt],
    name: str,
    path: pathlib.Path,
) -> ast.expr | None:
    """Return the value of the first ``name: <annotation> = <value>`` statement.

    Only the given statements are inspected; nested bodies are not searched.

    :param statements: Statements to scan, e.g. a module or function body.
    :param name: Name of the assignment target to look for.
    :param path: File the statements were parsed from; used in the error message only.
    :raises ValueError: If no annotated assignment to *name* is found.
    """
    for stmt in statements:
        if (
            isinstance(stmt, ast.AnnAssign)
            and isinstance(stmt.target, ast.Name)
            and stmt.target.id == name
        ):
            return stmt.value
    # An explicit error is raised instead of relying on ``next()``. The callers
    # are generators, and a StopIteration escaping a generator body surfaces
    # as an opaque "RuntimeError: generator raised StopIteration" (PEP 479).
    raise ValueError(f"annotated assignment to {name!r} not found in {path}")


def _iter_template_context_keys_from_original_return() -> typing.Iterator[str]:
    """Yield the keys of the ``context`` dict literal in ``_get_template_context``.

    The function is looked up among the top-level definitions of
    ``TASKINSTANCE_PY``, and ``context`` among the top-level statements of
    that function's body.

    :raises ValueError: If the function or the annotated ``context`` assignment
        cannot be found, if the assigned value is not a dict literal, or if a
        key is not a string literal.
    """
    ti_mod = ast.parse(TASKINSTANCE_PY.read_text("utf-8"), str(TASKINSTANCE_PY))
    fn_get_template_context = next(
        (
            node
            for node in ast.iter_child_nodes(ti_mod)
            if isinstance(node, ast.FunctionDef) and node.name == "_get_template_context"
        ),
        None,
    )
    if fn_get_template_context is None:
        raise ValueError(f"function '_get_template_context' not found in {TASKINSTANCE_PY}")
    st_context_value = _find_annotated_assignment(
        fn_get_template_context.body, "context", TASKINSTANCE_PY
    )
    if not isinstance(st_context_value, ast.Dict):
        raise ValueError("'context' is not assigned a dict literal")
    for expr in st_context_value.keys:
        # A ``**mapping`` entry appears as a ``None`` key and is rejected here too.
        if not isinstance(expr, ast.Constant) or not isinstance(expr.value, str):
            raise ValueError("key in 'context' dict is not a str literal")
        yield expr.value


def _iter_template_context_keys_from_declaration() -> typing.Iterator[str]:
    """Yield the items of the ``KNOWN_CONTEXT_KEYS`` set literal in ``CONTEXT_PY``.

    :raises ValueError: If the annotated module-level assignment cannot be
        found, if the assigned value is not a set literal, or if an item is
        not a string literal.
    """
    context_mod = ast.parse(CONTEXT_PY.read_text("utf-8"), str(CONTEXT_PY))
    st_known_context_keys = _find_annotated_assignment(
        context_mod.body, "KNOWN_CONTEXT_KEYS", CONTEXT_PY
    )
    if not isinstance(st_known_context_keys, ast.Set):
        raise ValueError("'KNOWN_CONTEXT_KEYS' is not assigned a set literal")
    for expr in st_known_context_keys.elts:
        if not isinstance(expr, ast.Constant) or not isinstance(expr.value, str):
            raise ValueError("item in 'KNOWN_CONTEXT_KEYS' set is not a str literal")
        yield expr.value


def _iter_template_context_keys_from_type_hints() -> typing.Iterator[str]:
    """Yield the field names declared in the body of the ``Context`` class in ``CONTEXT_PYI``.

    :raises ValueError: If the class cannot be found among the top-level
        definitions, or if a statement in its body is not an annotated
        assignment to a plain name.
    """
    context_mod = ast.parse(CONTEXT_PYI.read_text("utf-8"), str(CONTEXT_PYI))
    # ``next()`` gets a default so a missing class raises a readable error.
    # A StopIteration escaping this generator would otherwise surface as
    # "RuntimeError: generator raised StopIteration" (PEP 479).
    cls_context = next(
        (
            node
            for node in ast.iter_child_nodes(context_mod)
            if isinstance(node, ast.ClassDef) and node.name == "Context"
        ),
        None,
    )
    if cls_context is None:
        raise ValueError(f"class 'Context' not found in {CONTEXT_PYI}")
    for stmt in cls_context.body:
        if not isinstance(stmt, ast.AnnAssign) or not isinstance(stmt.target, ast.Name):
            raise ValueError("key in 'Context' hint is not an annotated assignment")
        yield stmt.target.id


# Matches a table row that starts with ``{{ name }}`` or ``{{ name.attr }}``.
# Only the leading name is captured; any dotted suffix is skipped.
_DOCUMENTED_VARIABLE_RE = re.compile(r"^``{{ (?P<name>\w+)(?:\.\w+)* }}`` ", re.MULTILINE)


def _iter_template_context_keys_from_documentation() -> typing.Iterator[str]:
    """Yield the variable names that start a line in ``TEMPLATES_REF_RST``.

    A name is yielded once per matching line, so duplicates are possible.
    """
    # We can use docutils to actually parse, but regex is good enough for now.
    # This should find names in the "Variable" and "Deprecated Variable" tables.
    content = TEMPLATES_REF_RST.read_text("utf-8")
    for match in _DOCUMENTED_VARIABLE_RE.finditer(content):
        yield match.group("name")


def _compare_keys(retn_keys: set[str], decl_keys: set[str], hint_keys: set[str], docs_keys: set[str]) -> int:
    """Report keys missing from each source and return the total number missing.

    The canonical key set is the union of all four sources (after the
    adjustments below). For every source that lacks some canonical keys, one
    line listing them is printed to stdout.

    :param retn_keys: Keys found in the ``context`` dict of ``_get_template_context``.
    :param decl_keys: Keys found in ``KNOWN_CONTEXT_KEYS``.
    :param hint_keys: Keys found in the ``Context`` type hint.
    :param docs_keys: Keys found in the templates reference documentation.
    :return: Sum over all sources of the number of canonical keys each lacks;
        ``0`` means the sources are in sync.
    """
    # Added by PythonOperator and commonly used.
    # Not listed in templates-ref (but in operator docs).
    # Only present in callbacks. Not listed in templates-ref (that doc is for task execution).
    implicit_keys = {"templates_dict", "exception", "reason", "try_number"}
    # Build new sets instead of updating in place so the caller's sets are left untouched.
    retn_keys = retn_keys | implicit_keys
    docs_keys = docs_keys | implicit_keys

    check_candidates = [
        ("get_template_context()", retn_keys),
        ("KNOWN_CONTEXT_KEYS", decl_keys),
        ("Context type hint", hint_keys),
        ("templates-ref", docs_keys),
    ]
    canonical_keys = set().union(*(keys for _, keys in check_candidates))

    def _check_one(identifier: str, keys: set[str]) -> int:
        # Compare against this candidate's own keys, not always the first candidate's.
        missing = canonical_keys.difference(keys)
        if missing:
            print("Missing template variables from", f"{identifier}:", ", ".join(sorted(missing)))
        return len(missing)

    return sum(_check_one(identifier, keys) for identifier, keys in check_candidates)


def main() -> str | int | None:
    """Collect the keys from all four sources and return the number of mismatches."""
    retn_keys = set(_iter_template_context_keys_from_original_return())
    decl_keys = set(_iter_template_context_keys_from_declaration())
    hint_keys = set(_iter_template_context_keys_from_type_hints())
    docs_keys = set(_iter_template_context_keys_from_documentation())
    return _compare_keys(retn_keys, decl_keys, hint_keys, docs_keys)


if __name__ == "__main__":
    sys.exit(main())