Skip to content

Commit

Permalink
Add pre-commit hook to sync template context vars (#38579)
Browse files Browse the repository at this point in the history
  • Loading branch information
uranusjr authored Mar 28, 2024
1 parent 2589248 commit d4c2ea4
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 12 deletions.
9 changes: 5 additions & 4 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,10 +756,11 @@ def get_triggering_events() -> dict[str, list[DatasetEvent | DatasetEventPydanti
except NotMapped:
expanded_ti_count = None

# NOTE: If you add anything to this dict, make sure to also update the
# definition in airflow/utils/context.pyi, and KNOWN_CONTEXT_KEYS in
# airflow/utils/context.py!
context = {
# NOTE: If you add to this dict, make sure to also update the following:
# * Context in airflow/utils/context.pyi
# * KNOWN_CONTEXT_KEYS in airflow/utils/context.py
# * Table in docs/apache-airflow/templates-ref.rst
context: dict[str, Any] = {
"conf": conf,
"dag": dag,
"dag_run": dag_run,
Expand Down
7 changes: 5 additions & 2 deletions airflow/utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@
if TYPE_CHECKING:
from airflow.models.baseoperator import BaseOperator

# NOTE: Please keep this in sync with Context in airflow/utils/context.pyi.
KNOWN_CONTEXT_KEYS = {
# NOTE: Please keep this in sync with the following:
# * Context in airflow/utils/context.pyi.
# * Table in docs/apache-airflow/templates-ref.rst
KNOWN_CONTEXT_KEYS: set[str] = {
"conf",
"conn",
"dag",
Expand Down Expand Up @@ -74,6 +76,7 @@
"prev_execution_date_success",
"prev_start_date_success",
"prev_end_date_success",
"reason",
"run_id",
"task",
"task_instance",
Expand Down
4 changes: 3 additions & 1 deletion airflow/utils/context.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ class VariableAccessor:
class ConnectionAccessor:
def get(self, key: str, default_conn: Any = None) -> Any: ...

# NOTE: Please keep this in sync with KNOWN_CONTEXT_KEYS in airflow/utils/context.py.
# NOTE: Please keep this in sync with the following:
# * KNOWN_CONTEXT_KEYS in airflow/utils/context.py
# * Table in docs/apache-airflow/templates-ref.rst
class Context(TypedDict, total=False):
conf: AirflowConfigParser
conn: Any
Expand Down
17 changes: 12 additions & 5 deletions docs/apache-airflow/templates-ref.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,24 +38,31 @@ Variable Type Description
=========================================== ===================== ===================================================================
``{{ data_interval_start }}`` `pendulum.DateTime`_ Start of the data interval. Added in version 2.2.
``{{ data_interval_end }}`` `pendulum.DateTime`_ End of the data interval. Added in version 2.2.
``{{ logical_date }}`` `pendulum.DateTime`_ | A date-time that logically identifies the current DAG run. This value does not contain any semantics, but is simply a value for identification.
| Use ``data_interval_start`` and ``date_interval_end`` instead if you want a value that has real-world semantics,
| such as to get a slice of rows from the database based on timestamps.
``{{ ds }}`` str | The DAG run's logical date as ``YYYY-MM-DD``.
| Same as ``{{ dag_run.logical_date | ds }}``.
``{{ ds_nodash }}`` str Same as ``{{ dag_run.logical_date | ds_nodash }}``.
| Same as ``{{ logical_date | ds }}``.
``{{ ds_nodash }}`` str Same as ``{{ logical_date | ds_nodash }}``.
``{{ exception }}`` None | str | | Error occurred while running task instance.
Exception |
KeyboardInterrupt |
``{{ ts }}`` str | Same as ``{{ dag_run.logical_date | ts }}``.
``{{ ts }}`` str | Same as ``{{ logical_date | ts }}``.
| Example: ``2018-01-01T00:00:00+00:00``.
``{{ ts_nodash_with_tz }}`` str | Same as ``{{ dag_run.logical_date | ts_nodash_with_tz }}``.
``{{ ts_nodash_with_tz }}`` str | Same as ``{{ logical_date | ts_nodash_with_tz }}``.
| Example: ``20180101T000000+0000``.
``{{ ts_nodash }}`` str | Same as ``{{ dag_run.logical_date | ts_nodash }}``.
``{{ ts_nodash }}`` str | Same as ``{{ logical_date | ts_nodash }}``.
| Example: ``20180101T000000``.
``{{ prev_data_interval_start_success }}`` `pendulum.DateTime`_ | Start of the data interval of the prior successful :class:`~airflow.models.dagrun.DagRun`.
| ``None`` | Added in version 2.2.
``{{ prev_data_interval_end_success }}`` `pendulum.DateTime`_ | End of the data interval of the prior successful :class:`~airflow.models.dagrun.DagRun`.
| ``None`` | Added in version 2.2.
``{{ prev_start_date_success }}`` `pendulum.DateTime`_ Start date from prior successful :class:`~airflow.models.dagrun.DagRun` (if available).
| ``None``
``{{ prev_end_date_success }}`` `pendulum.DateTime`_ End date from prior successful :class:`~airflow.models.dagrun.DagRun` (if available).
| ``None``
``{{ inlets }}`` list List of inlets declared on the task.
``{{ outlets }}`` list List of outlets declared on the task.
``{{ dag }}`` DAG The currently running :class:`~airflow.models.dag.DAG`. You can read more about DAGs in :doc:`DAGs <core-concepts/dags>`.
``{{ task }}`` BaseOperator | The currently running :class:`~airflow.models.baseoperator.BaseOperator`. You can read more about Tasks in :doc:`core-concepts/operators`
``{{ macros }}`` | A reference to the macros package. See Macros_ below.
Expand Down
131 changes: 131 additions & 0 deletions scripts/ci/pre_commit/pre_commit_template_context_key_sync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
#!/usr/bin/env python
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import ast
import pathlib
import re
import sys
import typing

ROOT_DIR = pathlib.Path(__file__).resolve().parents[3]

TASKINSTANCE_PY = ROOT_DIR.joinpath("airflow", "models", "taskinstance.py")
CONTEXT_PY = ROOT_DIR.joinpath("airflow", "utils", "context.py")
CONTEXT_PYI = ROOT_DIR.joinpath("airflow", "utils", "context.pyi")
TEMPLATES_REF_RST = ROOT_DIR.joinpath("docs", "apache-airflow", "templates-ref.rst")


def _iter_template_context_keys_from_original_return() -> typing.Iterator[str]:
ti_mod = ast.parse(TASKINSTANCE_PY.read_text("utf-8"), str(TASKINSTANCE_PY))
fn_get_template_context = next(
node
for node in ast.iter_child_nodes(ti_mod)
if isinstance(node, ast.FunctionDef) and node.name == "_get_template_context"
)
st_context_value = next(
stmt.value
for stmt in fn_get_template_context.body
if isinstance(stmt, ast.AnnAssign)
and isinstance(stmt.target, ast.Name)
and stmt.target.id == "context"
)
if not isinstance(st_context_value, ast.Dict):
raise ValueError("'context' is not assigned a dict literal")
for expr in st_context_value.keys:
if not isinstance(expr, ast.Constant) or not isinstance(expr.value, str):
raise ValueError("key in 'context' dict is not a str literal")
yield expr.value


def _iter_template_context_keys_from_declaration() -> typing.Iterator[str]:
context_mod = ast.parse(CONTEXT_PY.read_text("utf-8"), str(CONTEXT_PY))
st_known_context_keys = next(
stmt.value
for stmt in context_mod.body
if isinstance(stmt, ast.AnnAssign)
and isinstance(stmt.target, ast.Name)
and stmt.target.id == "KNOWN_CONTEXT_KEYS"
)
if not isinstance(st_known_context_keys, ast.Set):
raise ValueError("'KNOWN_CONTEXT_KEYS' is not assigned a set literal")
for expr in st_known_context_keys.elts:
if not isinstance(expr, ast.Constant) or not isinstance(expr.value, str):
raise ValueError("item in 'KNOWN_CONTEXT_KEYS' set is not a str literal")
yield expr.value


def _iter_template_context_keys_from_type_hints() -> typing.Iterator[str]:
context_mod = ast.parse(CONTEXT_PYI.read_text("utf-8"), str(CONTEXT_PYI))
cls_context = next(
node
for node in ast.iter_child_nodes(context_mod)
if isinstance(node, ast.ClassDef) and node.name == "Context"
)
for stmt in cls_context.body:
if not isinstance(stmt, ast.AnnAssign) or not isinstance(stmt.target, ast.Name):
raise ValueError("key in 'Context' hint is not an annotated assignment")
yield stmt.target.id


def _iter_template_context_keys_from_documentation() -> typing.Iterator[str]:
# We can use docutils to actually parse, but regex is good enough for now.
# This should find names in the "Variable" and "Deprecated Variable" tables.
content = TEMPLATES_REF_RST.read_text("utf-8")
for match in re.finditer(r"^``{{ (?P<name>\w+)(?P<subname>\.\w+)* }}`` ", content, re.MULTILINE):
yield match.group("name")


def _compare_keys(retn_keys: set[str], decl_keys: set[str], hint_keys: set[str], docs_keys: set[str]) -> int:
# Added by PythonOperator and commonly used.
# Not listed in templates-ref (but in operator docs).
retn_keys.add("templates_dict")
docs_keys.add("templates_dict")

# Only present in callbacks. Not listed in templates-ref (that doc is for task execution).
retn_keys.update(("exception", "reason", "try_number"))
docs_keys.update(("exception", "reason", "try_number"))

check_candidates = [
("get_template_context()", retn_keys),
("KNOWN_CONTEXT_KEYS", decl_keys),
("Context type hint", hint_keys),
("templates-ref", docs_keys),
]
canonical_keys = set.union(*(s for _, s in check_candidates))

def _check_one(identifier: str, keys: set[str]) -> int:
if missing := canonical_keys.difference(retn_keys):
print("Missing template variables from", f"{identifier}:", ", ".join(sorted(missing)))
return len(missing)

return sum(_check_one(identifier, keys) for identifier, keys in check_candidates)


def main() -> str | int | None:
retn_keys = set(_iter_template_context_keys_from_original_return())
decl_keys = set(_iter_template_context_keys_from_declaration())
hint_keys = set(_iter_template_context_keys_from_type_hints())
docs_keys = set(_iter_template_context_keys_from_documentation())
return _compare_keys(retn_keys, decl_keys, hint_keys, docs_keys)


if __name__ == "__main__":
sys.exit(main())

0 comments on commit d4c2ea4

Please sign in to comment.