diff --git a/airflow-core/src/airflow/dag_processing/manager.py b/airflow-core/src/airflow/dag_processing/manager.py index 3593928aaf23f..c143e682608a6 100644 --- a/airflow-core/src/airflow/dag_processing/manager.py +++ b/airflow-core/src/airflow/dag_processing/manager.py @@ -48,6 +48,7 @@ from uuid6 import uuid7 import airflow.models +from airflow.api_fastapi.execution_api.app import InProcessExecutionAPI from airflow.configuration import conf from airflow.dag_processing.bundles.manager import DagBundlesManager from airflow.dag_processing.collection import update_dag_parsing_results_in_db @@ -80,6 +81,7 @@ from airflow.callbacks.callback_requests import CallbackRequest from airflow.dag_processing.bundles.base import BaseDagBundle + from airflow.sdk.api.client import Client class DagParsingStat(NamedTuple): @@ -213,6 +215,9 @@ class DagFileProcessorManager(LoggingMixin): _force_refresh_bundles: set[str] = attrs.field(factory=set, init=False) """List of bundles that need to be force refreshed in the next loop""" + _api_server: InProcessExecutionAPI = attrs.field(init=False, factory=InProcessExecutionAPI) + """API server to interact with Metadata DB""" + def register_exit_signals(self): """Register signals that stop child processes.""" signal.signal(signal.SIGINT, self._exit_gracefully) @@ -867,6 +872,15 @@ def _get_logger_for_dag_file(self, dag_file: DagFileInfo): underlying_logger, processors=processors, logger_name="processor" ).bind(), logger_filehandle + @functools.cached_property + def client(self) -> Client: + from airflow.sdk.api.client import Client + + client = Client(base_url=None, token="", dry_run=True, transport=self._api_server.transport) + # Mypy is wrong -- the setter accepts a string on the property setter! `URLType = URL | str` + client.base_url = "http://in-process.invalid./" # type: ignore[assignment] + return client + def _create_process(self, dag_file: DagFileInfo) -> DagFileProcessorProcess: id = uuid7() @@ -881,6 +895,7 @@ def _create_process(self, dag_file: DagFileInfo) -> DagFileProcessorProcess: selector=self.selector, logger=logger, logger_filehandle=logger_filehandle, + client=self.client, ) def _start_new_processes(self): diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py index 79b3a9816c25f..01ac41f7a326a 100644 --- a/airflow-core/src/airflow/dag_processing/processor.py +++ b/airflow-core/src/airflow/dag_processing/processor.py @@ -16,7 +16,6 @@ # under the License. from __future__ import annotations -import functools import os import sys import traceback @@ -239,6 +238,9 @@ class DagFileProcessorProcess(WatchedSubprocess): parsing_result: DagFileParsingResult | None = None decoder: ClassVar[TypeAdapter[ToManager]] = TypeAdapter[ToManager](ToManager) + client: Client + """The HTTP client to use for communication with the API server.""" + @classmethod def start( # type: ignore[override] cls, @@ -247,9 +249,10 @@ def start( # type: ignore[override] bundle_path: Path, callbacks: list[CallbackRequest], target: Callable[[], None] = _parse_file_entrypoint, + client: Client, **kwargs, ) -> Self: - proc: Self = super().start(target=target, **kwargs) + proc: Self = super().start(target=target, client=client, **kwargs) proc._on_child_started(callbacks, path, bundle_path) return proc @@ -267,15 +270,6 @@ def _on_child_started( ) self.send_msg(msg) - @functools.cached_property - def client(self) -> Client: - from airflow.sdk.api.client import Client - - client = Client(base_url=None, token="", dry_run=True, transport=in_process_api_server().transport) - # Mypy is wrong -- the setter accepts a string on the property setter! `URLType = URL | str` - client.base_url = "http://in-process.invalid./" # type: ignore[assignment] - return client - def _handle_request(self, msg: ToManager, log: FilteringBoundLogger) -> None: # type: ignore[override] from airflow.sdk.api.datamodels._generated import ConnectionResponse, VariableResponse diff --git a/airflow-core/tests/unit/dag_processing/test_manager.py b/airflow-core/tests/unit/dag_processing/test_manager.py index ebf684c678843..c9e974b8cdb9a 100644 --- a/airflow-core/tests/unit/dag_processing/test_manager.py +++ b/airflow-core/tests/unit/dag_processing/test_manager.py @@ -147,6 +147,7 @@ def mock_processor(self, start_time: float | None = None) -> tuple[DagFileProces stdin=write_end, requests_fd=123, logger_filehandle=logger_filehandle, + client=MagicMock(), ) if start_time: ret.start_time = start_time @@ -899,6 +900,7 @@ def test_callback_queue(self, mock_get_logger, configure_testing_dag_bundle): selector=mock.ANY, logger=mock_logger, logger_filehandle=mock_filehandle, + client=mock.ANY, ), mock.call( id=mock.ANY, @@ -908,6 +910,7 @@ def test_callback_queue(self, mock_get_logger, configure_testing_dag_bundle): selector=mock.ANY, logger=mock_logger, logger_filehandle=mock_filehandle, + client=mock.ANY, ), ] # And removed from the queue diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py index 6a3eb978067ee..28ce7a8c23fe6 100644 --- a/airflow-core/tests/unit/dag_processing/test_processor.py +++ b/airflow-core/tests/unit/dag_processing/test_processor.py @@ -29,6 +29,7 @@ import structlog from pydantic import TypeAdapter +from airflow.api_fastapi.execution_api.app import InProcessExecutionAPI from airflow.callbacks.callback_requests import CallbackRequest, DagCallbackRequest, TaskCallbackRequest from airflow.configuration import conf from airflow.dag_processing.processor import ( @@ -40,6 +41,7 @@ from airflow.models import DagBag, TaskInstance from airflow.models.baseoperator import BaseOperator from airflow.models.serialized_dag import SerializedDagModel +from airflow.sdk.api.client import Client from airflow.sdk.execution_time.task_runner import CommsDecoder from airflow.utils import timezone from airflow.utils.session import create_session @@ -67,6 +69,15 @@ def disable_load_example(): yield +@pytest.fixture +def inprocess_client(): + """Provides an in-process Client backed by a single API server.""" + api = InProcessExecutionAPI() + client = Client(base_url=None, token="", dry_run=True, transport=api.transport) + client.base_url = "http://in-process.invalid/" # type: ignore[assignment] + return client + + @pytest.mark.usefixtures("disable_load_example") class TestDagFileProcessor: def _process_file( @@ -130,7 +141,7 @@ def fake_collect_dags(dagbag: DagBag, *args, **kwargs): assert "a.py" in resp.import_errors def test_top_level_variable_access( - self, spy_agency: SpyAgency, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch + self, spy_agency: SpyAgency, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch, inprocess_client ): logger_filehandle = MagicMock() @@ -144,7 +155,12 @@ def dag_in_a_fn(): monkeypatch.setenv("AIRFLOW_VAR_MYVAR", "abc") proc = DagFileProcessorProcess.start( - id=1, path=path, bundle_path=tmp_path, callbacks=[], logger_filehandle=logger_filehandle + id=1, + path=path, + bundle_path=tmp_path, + callbacks=[], + logger_filehandle=logger_filehandle, + client=inprocess_client, ) while not proc.is_ready: @@ -156,7 +172,7 @@ def dag_in_a_fn(): assert result.serialized_dags[0].dag_id == "test_abc" def test_top_level_variable_access_not_found( - self, spy_agency: SpyAgency, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch + self, spy_agency: SpyAgency, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch, inprocess_client ): logger_filehandle = MagicMock() @@ -168,7 +184,12 @@ def dag_in_a_fn(): path = write_dag_in_a_fn_to_file(dag_in_a_fn, tmp_path) proc = DagFileProcessorProcess.start( - id=1, path=path, bundle_path=tmp_path, callbacks=[], logger_filehandle=logger_filehandle + id=1, + path=path, + bundle_path=tmp_path, + callbacks=[], + logger_filehandle=logger_filehandle, + client=inprocess_client, ) while not proc.is_ready: @@ -180,7 +201,7 @@ def dag_in_a_fn(): if result.import_errors: assert "VARIABLE_NOT_FOUND" in next(iter(result.import_errors.values())) - def test_top_level_variable_set(self, tmp_path: pathlib.Path): + def test_top_level_variable_set(self, tmp_path: pathlib.Path, inprocess_client): from airflow.models.variable import Variable as VariableORM logger_filehandle = MagicMock() @@ -194,7 +215,12 @@ def dag_in_a_fn(): path = write_dag_in_a_fn_to_file(dag_in_a_fn, tmp_path) proc = DagFileProcessorProcess.start( - id=1, path=path, bundle_path=tmp_path, callbacks=[], logger_filehandle=logger_filehandle + id=1, + path=path, + bundle_path=tmp_path, + callbacks=[], + logger_filehandle=logger_filehandle, + client=inprocess_client, ) while not proc.is_ready: @@ -210,7 +236,7 @@ def dag_in_a_fn(): assert len(all_vars) == 1 assert all_vars[0].key == "mykey" - def test_top_level_variable_delete(self, tmp_path: pathlib.Path): + def test_top_level_variable_delete(self, tmp_path: pathlib.Path, inprocess_client): from airflow.models.variable import Variable as VariableORM logger_filehandle = MagicMock() @@ -230,7 +256,12 @@ def dag_in_a_fn(): path = write_dag_in_a_fn_to_file(dag_in_a_fn, tmp_path) proc = DagFileProcessorProcess.start( - id=1, path=path, bundle_path=tmp_path, callbacks=[], logger_filehandle=logger_filehandle + id=1, + path=path, + bundle_path=tmp_path, + callbacks=[], + logger_filehandle=logger_filehandle, + client=inprocess_client, ) while not proc.is_ready: @@ -245,7 +276,9 @@ def dag_in_a_fn(): all_vars = session.query(VariableORM).all() assert len(all_vars) == 0 - def test_top_level_connection_access(self, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch): + def test_top_level_connection_access( + self, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch, inprocess_client + ): logger_filehandle = MagicMock() def dag_in_a_fn(): @@ -259,7 +292,12 @@ def dag_in_a_fn(): monkeypatch.setenv("AIRFLOW_CONN_MY_CONN", '{"conn_type": "aws"}') proc = DagFileProcessorProcess.start( - id=1, path=path, bundle_path=tmp_path, callbacks=[], logger_filehandle=logger_filehandle + id=1, + path=path, + bundle_path=tmp_path, + callbacks=[], + logger_filehandle=logger_filehandle, + client=inprocess_client, ) while not proc.is_ready: @@ -270,7 +308,7 @@ def dag_in_a_fn(): assert result.import_errors == {} assert result.serialized_dags[0].dag_id == "test_my_conn" - def test_top_level_connection_access_not_found(self, tmp_path: pathlib.Path): + def test_top_level_connection_access_not_found(self, tmp_path: pathlib.Path, inprocess_client): logger_filehandle = MagicMock() def dag_in_a_fn(): @@ -282,7 +320,12 @@ def dag_in_a_fn(): path = write_dag_in_a_fn_to_file(dag_in_a_fn, tmp_path) proc = DagFileProcessorProcess.start( - id=1, path=path, bundle_path=tmp_path, callbacks=[], logger_filehandle=logger_filehandle + id=1, + path=path, + bundle_path=tmp_path, + callbacks=[], + logger_filehandle=logger_filehandle, + client=inprocess_client, ) while not proc.is_ready: @@ -294,7 +337,7 @@ def dag_in_a_fn(): if result.import_errors: assert "CONNECTION_NOT_FOUND" in next(iter(result.import_errors.values())) - def test_import_module_in_bundle_root(self, tmp_path: pathlib.Path): + def test_import_module_in_bundle_root(self, tmp_path: pathlib.Path, inprocess_client): tmp_path.joinpath("util.py").write_text("NAME = 'dag_name'") dag1_path = tmp_path.joinpath("dag1.py") @@ -314,6 +357,7 @@ def test_import_module_in_bundle_root(self, tmp_path: pathlib.Path): bundle_path=tmp_path, callbacks=[], logger_filehandle=MagicMock(), + client=inprocess_client, ) while not proc.is_ready: proc._service_subprocess(0.1)