From 1dc5464ee21645fbcfa4e7bdbfb16834a51d3bf3 Mon Sep 17 00:00:00 2001
From: Ash Berlin-Taylor
Date: Tue, 18 Feb 2025 18:19:15 +0000
Subject: [PATCH] Enable accessing Variables from the top level of the DAG files

Since I want to maintain the property of being able to run a DAG processor
without the Execution API server running, and since the Dag Processor Manager
already has a database connection, I have chosen to run the FastAPI execution
server in process.

To achieve this I make use of two features:

- The first is the ability to provide an httpx.Client with a Transport object
  that wraps a WSGI application, so that instead of sending real requests it
  calls the WSGI app directly to service them.
- The second is a2wsgi. Since we are making the call from within a synchronous
  context we have to give httpx a WSGI app (if we were making async requests
  we could give httpx an ASGI app directly), but FastAPI at its outer layers
  is an async framework (even if it supports running sync routes), so we need
  to somehow wrap the async call to return a sync result. a2wsgi does this for
  us by running an async loop off the main thread.

I tested this with a simple DAG file initially:

```python
import time
import sys

from airflow.decorators import dag, task
from airflow.sdk import Variable
from airflow.utils.session import create_session

if Variable.get("hi", default=None):
    raise RuntimeError("Var hi was defined")


@dag(schedule=None)
def hi():
    @task()
    def hello():
        print("hello")
        time.sleep(3)
        print("goodbye")
        print("err mesg", file=sys.stderr)

    hello()


hi()
```

If the variable is defined it results in an import error. If it is not
defined, the DAG parses successfully.
---
 airflow/api_fastapi/app.py                  |  2 +-
 airflow/api_fastapi/execution_api/app.py    | 42 ++++++++-
 airflow/dag_processing/processor.py         | 64 +++++++++++---
 airflow/models/variable.py                  | 24 +++++
 hatch_build.py                              |  1 +
 task_sdk/src/airflow/sdk/__init__.py        |  2 +
 task_sdk/src/airflow/sdk/api/client.py      |  4 +-
 .../src/airflow/sdk/definitions/variable.py | 13 +++
 tests/api_fastapi/test_app.py               |  6 +-
 tests/dag_processing/test_processor.py      | 87 +++++++++----------
 10 files changed, 183 insertions(+), 62 deletions(-)

diff --git a/airflow/api_fastapi/app.py b/airflow/api_fastapi/app.py
index cd7985a3eb97f..3c812cde05645 100644
--- a/airflow/api_fastapi/app.py
+++ b/airflow/api_fastapi/app.py
@@ -88,7 +88,7 @@ def create_app(apps: str = "all") -> FastAPI:
         init_middlewares(app)
 
     if "execution" in apps_list or "all" in apps_list:
-        task_exec_api_app = create_task_execution_api_app(app)
+        task_exec_api_app = create_task_execution_api_app()
         init_error_handlers(task_exec_api_app)
         app.mount("/execution", task_exec_api_app)
 
diff --git a/airflow/api_fastapi/execution_api/app.py b/airflow/api_fastapi/execution_api/app.py
index 2b85be363f25f..cee567f2d4999 100644
--- a/airflow/api_fastapi/execution_api/app.py
+++ b/airflow/api_fastapi/execution_api/app.py
@@ -18,10 +18,16 @@
 from __future__ import annotations
 
 from contextlib import asynccontextmanager
+from functools import cached_property
+from typing import TYPE_CHECKING
 
+import attrs
 from fastapi import FastAPI
 from fastapi.openapi.utils import get_openapi
 
+if TYPE_CHECKING:
+    import httpx
+
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
@@ -30,7 +36,7 @@ async def lifespan(app: FastAPI):
     yield
 
 
-def create_task_execution_api_app(app: FastAPI) -> FastAPI:
+def create_task_execution_api_app() -> FastAPI:
     """Create FastAPI app for task execution API."""
     from airflow.api_fastapi.execution_api.routes import execution_api_router
 
@@ 
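As a rough illustration of the two features described in the commit message (this sketch is not part of the patch; the toy app and `/ping` route are placeholders rather than Airflow's execution API): `a2wsgi.ASGIMiddleware` exposes the async FastAPI app through a WSGI interface by running the event loop on a separate thread, and `httpx.WSGITransport` lets a synchronous `httpx.Client` call that WSGI app directly, with no real network request:

```python
# Standalone sketch only -- a toy FastAPI app, not Airflow's execution API.
import httpx
from a2wsgi import ASGIMiddleware
from fastapi import FastAPI

app = FastAPI()


@app.get("/ping")
def ping():
    return {"ok": True}


# ASGIMiddleware runs the async (ASGI) app on an event loop in a separate
# thread and presents it as a WSGI app; WSGITransport dispatches the sync
# client's requests to that WSGI app in-process instead of over the network.
transport = httpx.WSGITransport(app=ASGIMiddleware(app))
client = httpx.Client(transport=transport, base_url="http://in-process.invalid")

print(client.get("/ping").json())  # {'ok': True}
```

The patch applies the same wrapping to the real execution API app via the `InProcessExecuctionAPI.transport` property below.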
-88,3 +94,37 @@ def get_extra_schemas() -> dict[str, dict]: # as that has different payload requirements "TerminalTIState": {"type": "string", "enum": list(TerminalTIState)}, } + + +@attrs.define() +class InProcessExecuctionAPI: + """ + A helper class to make it possible to run the ExecutionAPI "in-process". + + The sync version of this makes use of a2wsgi which runs the async loop in a separate thread. This is + needed so that we can use the sync httpx client + """ + + _app: FastAPI | None = None + + @cached_property + def app(self): + if not self._app: + from airflow.api_fastapi.execution_api.app import create_task_execution_api_app + + self._app = create_task_execution_api_app() + + return self._app + + @cached_property + def transport(self) -> httpx.WSGITransport: + import httpx + from a2wsgi import ASGIMiddleware + + return httpx.WSGITransport(app=ASGIMiddleware(self.app)) # type: ignore[arg-type] + + @cached_property + def atransport(self) -> httpx.ASGITransport: + import httpx + + return httpx.ASGITransport(app=self.app) diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py index 058d59c9ed1d7..7360f0b7c7a5d 100644 --- a/airflow/dag_processing/processor.py +++ b/airflow/dag_processing/processor.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import functools import os import sys import traceback @@ -32,7 +33,7 @@ ) from airflow.configuration import conf from airflow.models.dagbag import DagBag -from airflow.sdk.execution_time.comms import GetConnection, GetVariable +from airflow.sdk.execution_time.comms import ConnectionResult, GetConnection, GetVariable, VariableResult from airflow.sdk.execution_time.supervisor import WatchedSubprocess from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG from airflow.stats import Stats @@ -40,9 +41,21 @@ if TYPE_CHECKING: from structlog.typing import FilteringBoundLogger + from airflow.api_fastapi.execution_api.app import InProcessExecuctionAPI + from airflow.sdk.api.client import Client from airflow.sdk.definitions.context import Context from airflow.typing_compat import Self +ToManager = Annotated[ + Union["DagFileParsingResult", GetConnection, GetVariable], + Field(discriminator="type"), +] + +ToDagProcessor = Annotated[ + Union["DagFileParseRequest", ConnectionResult, VariableResult], + Field(discriminator="type"), +] + def _parse_file_entrypoint(): import os @@ -51,19 +64,24 @@ def _parse_file_entrypoint(): from airflow.sdk.execution_time import task_runner from airflow.settings import configure_orm + # Parse DAG file, send JSON back up! # We need to reconfigure the orm here, as DagFileProcessorManager does db queries for bundles, and # the session across forks blows things up. 
configure_orm() - comms_decoder = task_runner.CommsDecoder[DagFileParseRequest, DagFileParsingResult]( + comms_decoder = task_runner.CommsDecoder[ToDagProcessor, ToManager]( input=sys.stdin, - decoder=TypeAdapter[DagFileParseRequest](DagFileParseRequest), + decoder=TypeAdapter[ToDagProcessor](ToDagProcessor), ) + msg = comms_decoder.get_message() + if not isinstance(msg, DagFileParseRequest): + raise RuntimeError(f"Required first message to be a DagFileParseRequest, it was {msg}") comms_decoder.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0) + task_runner.SUPERVISOR_COMMS = comms_decoder log = structlog.get_logger(logger_name="task") result = _parse_file(msg, log) @@ -188,10 +206,12 @@ class DagFileParsingResult(BaseModel): type: Literal["DagFileParsingResult"] = "DagFileParsingResult" -ToParent = Annotated[ - Union[DagFileParsingResult, GetConnection, GetVariable], - Field(discriminator="type"), -] +@functools.cache +def in_process_api_server() -> InProcessExecuctionAPI: + from airflow.api_fastapi.execution_api.app import InProcessExecuctionAPI + + api = InProcessExecuctionAPI() + return api @attrs.define(kw_only=True) @@ -207,7 +227,7 @@ class DagFileProcessorProcess(WatchedSubprocess): """ parsing_result: DagFileParsingResult | None = None - decoder: ClassVar[TypeAdapter[ToParent]] = TypeAdapter[ToParent](ToParent) + decoder: ClassVar[TypeAdapter[ToManager]] = TypeAdapter[ToManager](ToManager) @classmethod def start( # type: ignore[override] @@ -237,12 +257,36 @@ def _on_child_started( ) self.stdin.write(msg.model_dump_json().encode() + b"\n") - def _handle_request(self, msg: ToParent, log: FilteringBoundLogger) -> None: # type: ignore[override] - # TODO: GetVariable etc -- parsing a dag can run top level code that asks for an Airflow Variable + @functools.cached_property + def client(self) -> Client: + from airflow.sdk.api.client import Client + + client = Client(base_url=None, token="", dry_run=True, transport=in_process_api_server().transport) + # Mypy is wrong -- the setter accepts a string on the property setter! 
`URLType = URL | str`
+        client.base_url = "http://in-process.invalid./"  # type: ignore[assignment]
+        return client
+
+    def _handle_request(self, msg: ToManager, log: FilteringBoundLogger) -> None:  # type: ignore[override]
+        from airflow.sdk.api.datamodels._generated import ConnectionResponse, VariableResponse
+
         resp = None
         if isinstance(msg, DagFileParsingResult):
             self.parsing_result = msg
             return
+        elif isinstance(msg, GetConnection):
+            conn = self.client.connections.get(msg.conn_id)
+            if isinstance(conn, ConnectionResponse):
+                conn_result = ConnectionResult.from_conn_response(conn)
+                resp = conn_result.model_dump_json(exclude_unset=True).encode()
+            else:
+                resp = conn.model_dump_json().encode()
+        elif isinstance(msg, GetVariable):
+            var = self.client.variables.get(msg.key)
+            if isinstance(var, VariableResponse):
+                var_result = VariableResult.from_variable_response(var)
+                resp = var_result.model_dump_json(exclude_unset=True).encode()
+            else:
+                resp = var.model_dump_json().encode()
         else:
             log.error("Unhandled request", msg=msg)
             return
diff --git a/airflow/models/variable.py b/airflow/models/variable.py
index b4568ad09c489..d2532e1203e22 100644
--- a/airflow/models/variable.py
+++ b/airflow/models/variable.py
@@ -19,6 +19,8 @@
 
 import json
 import logging
+import sys
+import warnings
 from typing import TYPE_CHECKING, Any
 
 from sqlalchemy import Boolean, Column, Integer, String, Text, delete, select
@@ -136,6 +138,28 @@ def get(
         :param default_var: Default value of the Variable if the Variable doesn't exist
         :param deserialize_json: Deserialize the value to a Python dict
         """
+        # TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still
+        # means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big
+        # back-compat layer
+
+        # If this is set it means we are in some kind of execution context (Task, Dag Parse or Triggerer perhaps)
+        # and should use the Task SDK API server path
+        if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"):
+            warnings.warn(
+                "Using Variable.get from `airflow.models` is deprecated. Please use `from airflow.sdk import "
+                "Variable` instead",
+                DeprecationWarning,
+                stacklevel=1,
+            )
+            from airflow.sdk import Variable as TaskSDKVariable
+            from airflow.sdk.definitions._internal.types import NOTSET
+
+            return TaskSDKVariable.get(
+                key,
+                default=NOTSET if default_var is cls.__NO_DEFAULT_SENTINEL else default_var,
+                deserialize_json=deserialize_json,
+            )
+
         var_val = Variable.get_variable_from_secrets(key=key)
         if var_val is None:
             if default_var is not cls.__NO_DEFAULT_SENTINEL:
diff --git a/hatch_build.py b/hatch_build.py
index d5bd397570dcc..ae611b55778cc 100644
--- a/hatch_build.py
+++ b/hatch_build.py
@@ -342,6 +342,7 @@
 }
 
 DEPENDENCIES = [
+    "a2wsgi>=1.10.8",
     # Alembic is important to handle our migrations in predictable and performant way. It is developed
     # together with SQLAlchemy. 
Our experience with Alembic is that it very stable in minor version # The 1.13.0 of alembic marked some migration code as SQLAlchemy 2+ only so we limit it to 1.13.1 diff --git a/task_sdk/src/airflow/sdk/__init__.py b/task_sdk/src/airflow/sdk/__init__.py index 95b08be37aa09..1fabbd2bd0308 100644 --- a/task_sdk/src/airflow/sdk/__init__.py +++ b/task_sdk/src/airflow/sdk/__init__.py @@ -29,6 +29,7 @@ "Label", "MappedOperator", "TaskGroup", + "Variable", "XComArg", "dag", "get_current_context", @@ -46,6 +47,7 @@ from airflow.sdk.definitions.edges import EdgeModifier, Label from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sdk.definitions.taskgroup import TaskGroup + from airflow.sdk.definitions.variable import Variable from airflow.sdk.definitions.xcom_arg import XComArg __lazy_imports: dict[str, str] = { diff --git a/task_sdk/src/airflow/sdk/api/client.py b/task_sdk/src/airflow/sdk/api/client.py index 5d81c32eefa4b..6e5e6d5a2de9a 100644 --- a/task_sdk/src/airflow/sdk/api/client.py +++ b/task_sdk/src/airflow/sdk/api/client.py @@ -384,8 +384,8 @@ def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, * if dry_run: # If dry run is requested, install a no op handler so that simple tasks can "heartbeat" using a # real client, but just don't make any HTTP requests - kwargs["transport"] = httpx.MockTransport(noop_handler) - kwargs["base_url"] = "dry-run://server" + kwargs.setdefault("transport", httpx.MockTransport(noop_handler)) + kwargs.setdefault("base_url", "dry-run://server") else: kwargs["base_url"] = base_url pyver = f"{'.'.join(map(str, sys.version_info[:3]))}" diff --git a/task_sdk/src/airflow/sdk/definitions/variable.py b/task_sdk/src/airflow/sdk/definitions/variable.py index 5f458580065c5..46742e965ed5e 100644 --- a/task_sdk/src/airflow/sdk/definitions/variable.py +++ b/task_sdk/src/airflow/sdk/definitions/variable.py @@ -21,6 +21,8 @@ import attrs +from airflow.sdk.definitions._internal.types import NOTSET + @attrs.define class Variable: @@ -39,3 +41,14 @@ class Variable: description: str | None = None # TODO: Extend this definition for reading/writing variables without context + @classmethod + def get(cls, key: str, default: Any = NOTSET, deserialize_json: bool = False): + from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType + from airflow.sdk.execution_time.context import _get_variable + + try: + return _get_variable(key, deserialize_json=deserialize_json).value + except AirflowRuntimeError as e: + if e.error.error == ErrorType.VARIABLE_NOT_FOUND and default is not NOTSET: + return default + raise diff --git a/tests/api_fastapi/test_app.py b/tests/api_fastapi/test_app.py index 249fbdba15aa4..e2a15988d4b47 100644 --- a/tests/api_fastapi/test_app.py +++ b/tests/api_fastapi/test_app.py @@ -57,10 +57,10 @@ def test_core_api_app( def test_execution_api_app( mock_create_task_exec_api, mock_init_plugins, mock_init_views, mock_init_dag_bag, client ): - test_app = client(apps="execution").app + client(apps="execution") # Assert that execution-related functions were called - mock_create_task_exec_api.assert_called_once_with(test_app) + mock_create_task_exec_api.assert_called_once() # Assert that core-related functions were NOT called mock_init_dag_bag.assert_not_called() @@ -91,4 +91,4 @@ def test_all_apps(mock_create_task_exec_api, mock_init_plugins, mock_init_views, mock_init_plugins.assert_called_once_with(test_app) # Assert that execution-related functions were also called - 
mock_create_task_exec_api.assert_called_once_with(test_app) + mock_create_task_exec_api.assert_called_once_with() diff --git a/tests/dag_processing/test_processor.py b/tests/dag_processing/test_processor.py index a01d3edb2ff3b..a20be18d3583a 100644 --- a/tests/dag_processing/test_processor.py +++ b/tests/dag_processing/test_processor.py @@ -17,10 +17,12 @@ # under the License. from __future__ import annotations +import inspect import pathlib import sys +import textwrap from socket import socketpair -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable from unittest.mock import patch import pytest @@ -32,6 +34,7 @@ from airflow.dag_processing.processor import ( DagFileParseRequest, DagFileParsingResult, + DagFileProcessorProcess, _parse_file, ) from airflow.models import DagBag, TaskInstance @@ -51,16 +54,9 @@ pytestmark = pytest.mark.db_test DEFAULT_DATE = timezone.datetime(2016, 1, 1) -PY311 = sys.version_info >= (3, 11) - -# Include the words "airflow" and "dag" in the file contents, -# tricking airflow into thinking these -# files contain a DAG (otherwise Airflow will skip them) -PARSEABLE_DAG_FILE_CONTENTS = '"airflow DAG"' # Filename to be used for dags that are created in an ad-hoc manner and can be removed/ # created at runtime -TEMP_DAG_FILENAME = "temp_dag.py" TEST_DAG_FOLDER = pathlib.Path(__file__).parents[1].resolve() / "dags" @@ -133,44 +129,45 @@ def fake_collect_dags(dagbag: DagBag, *args, **kwargs): assert resp.import_errors is not None assert "a.py" in resp.import_errors + # @pytest.mark.execution_timeout(10) + def test_top_level_variable_access( + self, spy_agency: SpyAgency, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch + ): + # Create the dag in a fn, and use inspect.getsource to write it to a file so that + # a) the test dag is directly viewable here in the tests + # b) that it shows to IDEs/mypy etc. + def dag_in_a_fn(): + from airflow.sdk import DAG, Variable + + with DAG(f"test_{Variable.get('myvar')}"): + ... 
+ + path = write_dag_in_a_fn_to_file(dag_in_a_fn, tmp_path) + + monkeypatch.setenv("AIRFLOW_VAR_MYVAR", "abc") + proc = DagFileProcessorProcess.start( + id=1, + path=path, + bundle_path=tmp_path, + callbacks=[], + ) + + while not proc.is_ready: + proc._service_subprocess(0.1) -# @conf_vars({("logging", "dag_processor_log_target"): "stdout"}) -# @mock.patch("airflow.dag_processing.processor.settings.dispose_orm", MagicMock) -# @mock.patch("airflow.dag_processing.processor.redirect_stdout") -# def test_dag_parser_output_when_logging_to_stdout(self, mock_redirect_stdout_for_file): -# processor = DagFileProcessorProcess( -# file_path="abc.txt", -# dag_directory=[], -# callback_requests=[], -# ) -# processor._run_file_processor( -# result_channel=MagicMock(), -# parent_channel=MagicMock(), -# file_path="fake_file_path", -# thread_name="fake_thread_name", -# callback_requests=[], -# dag_directory=[], -# ) -# mock_redirect_stdout_for_file.assert_not_called() -# -# @conf_vars({("logging", "dag_processor_log_target"): "file"}) -# @mock.patch("airflow.dag_processing.processor.settings.dispose_orm", MagicMock) -# @mock.patch("airflow.dag_processing.processor.redirect_stdout") -# def test_dag_parser_output_when_logging_to_file(self, mock_redirect_stdout_for_file): -# processor = DagFileProcessorProcess( -# file_path="abc.txt", -# dag_directory=[], -# callback_requests=[], -# ) -# processor._run_file_processor( -# result_channel=MagicMock(), -# parent_channel=MagicMock(), -# file_path="fake_file_path", -# thread_name="fake_thread_name", -# callback_requests=[], -# dag_directory=[], -# ) -# mock_redirect_stdout_for_file.assert_called_once() + result = proc.parsing_result + assert result is not None + assert result.import_errors == {} + assert result.serialized_dags[0].dag_id == "test_abc" + + +def write_dag_in_a_fn_to_file(fn: Callable[[], None], folder: pathlib.Path) -> pathlib.Path: + assert folder.is_dir() + name = fn.__name__ + path = folder.joinpath(name + ".py") + path.write_text(textwrap.dedent(inspect.getsource(fn)) + f"\n\n{name}()") + + return path @pytest.fixture
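Not part of the patch, but to illustrate the end-user pattern the new test and the `Variable.get` classmethod above enable: reading a Variable at the top level of a DAG file, with a default so parsing does not fail when the Variable is undefined. The `deploy_env` key and the DAG id here are hypothetical:

```python
# Hypothetical DAG file: top-level Variable access, as enabled by this patch.
from airflow.sdk import DAG, Variable

# Resolved at parse time via the DAG processor's in-process execution API client
# (or, as in the test above, from an AIRFLOW_VAR_* environment variable).
# The default keeps the file importable when the Variable is not defined.
deploy_env = Variable.get("deploy_env", default="dev")

with DAG(dag_id=f"example_pipeline_{deploy_env}", schedule=None):
    ...
```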