diff --git a/airflow-core/src/airflow/config_templates/config.yml b/airflow-core/src/airflow/config_templates/config.yml index 5a6e7119ebd99..b65337a6f95a6 100644 --- a/airflow-core/src/airflow/config_templates/config.yml +++ b/airflow-core/src/airflow/config_templates/config.yml @@ -2121,16 +2121,6 @@ scheduler: type: integer default: "20" see_also: ":ref:`scheduler:ha:tunables`" - parsing_pre_import_modules: - description: | - The scheduler reads dag files to extract the airflow modules that are going to be used, - and imports them ahead of time to avoid having to re-do it for each parsing process. - This flag can be set to ``False`` to disable this behavior in case an airflow module needs - to be freshly imported each time (at the cost of increased DAG parsing time). - version_added: 2.6.0 - type: boolean - example: ~ - default: "True" dag_stale_not_seen_duration: description: | Time in seconds after which dags, which were not updated by Dag Processor are deactivated. @@ -2485,3 +2475,13 @@ dag_processor: type: integer example: ~ default: "10" + parsing_pre_import_modules: + description: | + The dag_processor reads dag files to extract the airflow modules that are going to be used, + and imports them ahead of time to avoid having to re-do it for each parsing process. + This flag can be set to ``False`` to disable this behavior in case an airflow module needs + to be freshly imported each time (at the cost of increased DAG parsing time). + version_added: 2.6.0 + type: boolean + example: ~ + default: "True" diff --git a/airflow-core/src/airflow/configuration.py b/airflow-core/src/airflow/configuration.py index d2f2cce2dfbc4..d1bcb8015a0f9 100644 --- a/airflow-core/src/airflow/configuration.py +++ b/airflow-core/src/airflow/configuration.py @@ -372,6 +372,7 @@ def sensitive_config_values(self) -> set[tuple[str, str]]: ("api", "auto_refresh_interval"): ("webserver", "auto_refresh_interval", "3.1.0"), ("api", "require_confirmation_dag_change"): ("webserver", "require_confirmation_dag_change", "3.1.0"), ("api", "instance_name"): ("webserver", "instance_name", "3.1.0"), + ("dag_processor", "parsing_pre_import_modules"): ("scheduler", "parsing_pre_import_modules", "3.1.0"), } # A mapping of new section -> (old section, since_version). diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py index d626338482420..bfc889542fffb 100644 --- a/airflow-core/src/airflow/dag_processing/processor.py +++ b/airflow-core/src/airflow/dag_processing/processor.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import importlib import os import sys import traceback @@ -46,6 +47,7 @@ from airflow.sdk.execution_time.supervisor import WatchedSubprocess from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG from airflow.stats import Stats +from airflow.utils.file import iter_airflow_imports if TYPE_CHECKING: from structlog.typing import FilteringBoundLogger @@ -99,6 +101,27 @@ class DagFileParsingResult(BaseModel): ] +def _pre_import_airflow_modules(file_path: str, log: FilteringBoundLogger) -> None: + """ + Pre-import Airflow modules found in the given file. + + This prevents modules from being re-imported in each processing process, + saving CPU time and memory. + (The default value of "parsing_pre_import_modules" is set to True) + + :param file_path: Path to the file to scan for imports + :param log: Logger instance to use for warnings + """ + if not conf.getboolean("dag_processor", "parsing_pre_import_modules", fallback=True): + return + + for module in iter_airflow_imports(file_path): + try: + importlib.import_module(module) + except ModuleNotFoundError as e: + log.warning("Error when trying to pre-import module '%s' found in %s: %s", module, file_path, e) + + def _parse_file_entrypoint(): import structlog @@ -128,6 +151,7 @@ def _parse_file_entrypoint(): def _parse_file(msg: DagFileParseRequest, log: FilteringBoundLogger) -> DagFileParsingResult | None: # TODO: Set known_pool names on DagBag! + bag = DagBag( dag_folder=msg.file, bundle_path=msg.bundle_path, @@ -251,6 +275,10 @@ def start( # type: ignore[override] client: Client, **kwargs, ) -> Self: + logger = kwargs["logger"] + + _pre_import_airflow_modules(os.fspath(path), logger) + proc: Self = super().start(target=target, client=client, **kwargs) proc._on_child_started(callbacks, path, bundle_path) return proc diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py index 1125d003c31d6..9b9afd69cf09a 100644 --- a/airflow-core/tests/unit/dag_processing/test_processor.py +++ b/airflow-core/tests/unit/dag_processing/test_processor.py @@ -38,6 +38,7 @@ DagFileParsingResult, DagFileProcessorProcess, _parse_file, + _pre_import_airflow_modules, ) from airflow.models import DagBag, TaskInstance from airflow.models.baseoperator import BaseOperator @@ -141,8 +142,13 @@ def fake_collect_dags(dagbag: DagBag, *args, **kwargs): assert "a.py" in resp.import_errors def test_top_level_variable_access( - self, spy_agency: SpyAgency, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch, inprocess_client + self, + spy_agency: SpyAgency, + tmp_path: pathlib.Path, + monkeypatch: pytest.MonkeyPatch, + inprocess_client, ): + logger = MagicMock() logger_filehandle = MagicMock() def dag_in_a_fn(): @@ -159,6 +165,7 @@ def dag_in_a_fn(): path=path, bundle_path=tmp_path, callbacks=[], + logger=logger, logger_filehandle=logger_filehandle, client=inprocess_client, ) @@ -172,8 +179,13 @@ def dag_in_a_fn(): assert result.serialized_dags[0].dag_id == "test_abc" def test_top_level_variable_access_not_found( - self, spy_agency: SpyAgency, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch, inprocess_client + self, + spy_agency: SpyAgency, + tmp_path: pathlib.Path, + monkeypatch: pytest.MonkeyPatch, + inprocess_client, ): + logger = MagicMock() logger_filehandle = MagicMock() def dag_in_a_fn(): @@ -188,6 +200,7 @@ def dag_in_a_fn(): path=path, bundle_path=tmp_path, callbacks=[], + logger=logger, logger_filehandle=logger_filehandle, client=inprocess_client, ) @@ -204,6 +217,7 @@ def dag_in_a_fn(): def test_top_level_variable_set(self, tmp_path: pathlib.Path, inprocess_client): from airflow.models.variable import Variable as VariableORM + logger = MagicMock() logger_filehandle = MagicMock() def dag_in_a_fn(): @@ -219,6 +233,7 @@ def dag_in_a_fn(): path=path, bundle_path=tmp_path, callbacks=[], + logger=logger, logger_filehandle=logger_filehandle, client=inprocess_client, ) @@ -239,6 +254,7 @@ def dag_in_a_fn(): def test_top_level_variable_delete(self, tmp_path: pathlib.Path, inprocess_client): from airflow.models.variable import Variable as VariableORM + logger = MagicMock() logger_filehandle = MagicMock() def dag_in_a_fn(): @@ -260,6 +276,7 @@ def dag_in_a_fn(): path=path, bundle_path=tmp_path, callbacks=[], + logger=logger, logger_filehandle=logger_filehandle, client=inprocess_client, ) @@ -279,6 +296,7 @@ def dag_in_a_fn(): def test_top_level_connection_access( self, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch, inprocess_client ): + logger = MagicMock() logger_filehandle = MagicMock() def dag_in_a_fn(): @@ -295,6 +313,7 @@ def dag_in_a_fn(): path=path, bundle_path=tmp_path, callbacks=[], + logger=logger, logger_filehandle=logger_filehandle, client=inprocess_client, ) @@ -308,6 +327,7 @@ def dag_in_a_fn(): assert result.serialized_dags[0].dag_id == "test_my_conn" def test_top_level_connection_access_not_found(self, tmp_path: pathlib.Path, inprocess_client): + logger = MagicMock() logger_filehandle = MagicMock() def dag_in_a_fn(): @@ -322,6 +342,7 @@ def dag_in_a_fn(): path=path, bundle_path=tmp_path, callbacks=[], + logger=logger, logger_filehandle=logger_filehandle, client=inprocess_client, ) @@ -354,6 +375,7 @@ def test_import_module_in_bundle_root(self, tmp_path: pathlib.Path, inprocess_cl path=dag1_path, bundle_path=tmp_path, callbacks=[], + logger=MagicMock(), logger_filehandle=MagicMock(), client=inprocess_client, ) @@ -365,6 +387,65 @@ def test_import_module_in_bundle_root(self, tmp_path: pathlib.Path, inprocess_cl assert result.import_errors == {} assert result.serialized_dags[0].dag_id == "dag_name" + def test__pre_import_airflow_modules_when_disabled(self): + logger = MagicMock() + with ( + env_vars({"AIRFLOW__DAG_PROCESSOR__PARSING_PRE_IMPORT_MODULES": "false"}), + patch("airflow.dag_processing.processor.iter_airflow_imports") as mock_iter, + ): + _pre_import_airflow_modules("test.py", logger) + + mock_iter.assert_not_called() + logger.warning.assert_not_called() + + def test__pre_import_airflow_modules_when_enabled(self): + logger = MagicMock() + with ( + env_vars({"AIRFLOW__DAG_PROCESSOR__PARSING_PRE_IMPORT_MODULES": "true"}), + patch("airflow.dag_processing.processor.iter_airflow_imports", return_value=["airflow.models"]), + patch("airflow.dag_processing.processor.importlib.import_module") as mock_import, + ): + _pre_import_airflow_modules("test.py", logger) + + mock_import.assert_called_once_with("airflow.models") + logger.warning.assert_not_called() + + def test__pre_import_airflow_modules_warns_on_missing_module(self): + logger = MagicMock() + with ( + env_vars({"AIRFLOW__DAG_PROCESSOR__PARSING_PRE_IMPORT_MODULES": "true"}), + patch( + "airflow.dag_processing.processor.iter_airflow_imports", return_value=["non_existent_module"] + ), + patch( + "airflow.dag_processing.processor.importlib.import_module", side_effect=ModuleNotFoundError() + ), + ): + _pre_import_airflow_modules("test.py", logger) + + logger.warning.assert_called_once() + warning_args = logger.warning.call_args[0] + assert "Error when trying to pre-import module" in warning_args[0] + assert "non_existent_module" in warning_args[1] + assert "test.py" in warning_args[2] + + def test__pre_import_airflow_modules_partial_success_and_warning(self): + logger = MagicMock() + with ( + env_vars({"AIRFLOW__DAG_PROCESSOR__PARSING_PRE_IMPORT_MODULES": "true"}), + patch( + "airflow.dag_processing.processor.iter_airflow_imports", + return_value=["airflow.models", "non_existent_module"], + ), + patch( + "airflow.dag_processing.processor.importlib.import_module", + side_effect=[None, ModuleNotFoundError()], + ), + ): + _pre_import_airflow_modules("test.py", logger) + + assert logger.warning.call_count == 1 + def write_dag_in_a_fn_to_file(fn: Callable[[], None], folder: pathlib.Path) -> pathlib.Path: # Create the dag in a fn, and use inspect.getsource to write it to a file so that