Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions airflow-core/src/airflow/config_templates/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2121,16 +2121,6 @@ scheduler:
type: integer
default: "20"
see_also: ":ref:`scheduler:ha:tunables`"
parsing_pre_import_modules:
description: |
The scheduler reads dag files to extract the airflow modules that are going to be used,
and imports them ahead of time to avoid having to re-do it for each parsing process.
This flag can be set to ``False`` to disable this behavior in case an airflow module needs
to be freshly imported each time (at the cost of increased DAG parsing time).
version_added: 2.6.0
type: boolean
example: ~
default: "True"
dag_stale_not_seen_duration:
description: |
Time in seconds after which dags, which were not updated by Dag Processor are deactivated.
Expand Down Expand Up @@ -2485,3 +2475,13 @@ dag_processor:
type: integer
example: ~
default: "10"
parsing_pre_import_modules:
description: |
The dag_processor reads dag files to extract the airflow modules that are going to be used,
and imports them ahead of time to avoid having to re-do it for each parsing process.
This flag can be set to ``False`` to disable this behavior in case an airflow module needs
to be freshly imported each time (at the cost of increased DAG parsing time).
version_added: 2.6.0
type: boolean
example: ~
default: "True"
1 change: 1 addition & 0 deletions airflow-core/src/airflow/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,7 @@ def sensitive_config_values(self) -> set[tuple[str, str]]:
("api", "auto_refresh_interval"): ("webserver", "auto_refresh_interval", "3.1.0"),
("api", "require_confirmation_dag_change"): ("webserver", "require_confirmation_dag_change", "3.1.0"),
("api", "instance_name"): ("webserver", "instance_name", "3.1.0"),
("dag_processor", "parsing_pre_import_modules"): ("scheduler", "parsing_pre_import_modules", "3.1.0"),
}

# A mapping of new section -> (old section, since_version).
Expand Down
28 changes: 28 additions & 0 deletions airflow-core/src/airflow/dag_processing/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import importlib
import os
import sys
import traceback
Expand Down Expand Up @@ -46,6 +47,7 @@
from airflow.sdk.execution_time.supervisor import WatchedSubprocess
from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG
from airflow.stats import Stats
from airflow.utils.file import iter_airflow_imports

if TYPE_CHECKING:
from structlog.typing import FilteringBoundLogger
Expand Down Expand Up @@ -99,6 +101,27 @@ class DagFileParsingResult(BaseModel):
]


def _pre_import_airflow_modules(file_path: str, log: FilteringBoundLogger) -> None:
"""
Pre-import Airflow modules found in the given file.

This prevents modules from being re-imported in each processing process,
saving CPU time and memory.
(The default value of "parsing_pre_import_modules" is set to True)

:param file_path: Path to the file to scan for imports
:param log: Logger instance to use for warnings
"""
if not conf.getboolean("dag_processor", "parsing_pre_import_modules", fallback=True):
return

for module in iter_airflow_imports(file_path):
try:
importlib.import_module(module)
except ModuleNotFoundError as e:
log.warning("Error when trying to pre-import module '%s' found in %s: %s", module, file_path, e)


def _parse_file_entrypoint():
import structlog

Expand Down Expand Up @@ -128,6 +151,7 @@ def _parse_file_entrypoint():

def _parse_file(msg: DagFileParseRequest, log: FilteringBoundLogger) -> DagFileParsingResult | None:
# TODO: Set known_pool names on DagBag!

bag = DagBag(
dag_folder=msg.file,
bundle_path=msg.bundle_path,
Expand Down Expand Up @@ -251,6 +275,10 @@ def start( # type: ignore[override]
client: Client,
**kwargs,
) -> Self:
logger = kwargs["logger"]

_pre_import_airflow_modules(os.fspath(path), logger)

proc: Self = super().start(target=target, client=client, **kwargs)
proc._on_child_started(callbacks, path, bundle_path)
return proc
Expand Down
85 changes: 83 additions & 2 deletions airflow-core/tests/unit/dag_processing/test_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
DagFileParsingResult,
DagFileProcessorProcess,
_parse_file,
_pre_import_airflow_modules,
)
from airflow.models import DagBag, TaskInstance
from airflow.models.baseoperator import BaseOperator
Expand Down Expand Up @@ -141,8 +142,13 @@ def fake_collect_dags(dagbag: DagBag, *args, **kwargs):
assert "a.py" in resp.import_errors

def test_top_level_variable_access(
self, spy_agency: SpyAgency, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch, inprocess_client
self,
spy_agency: SpyAgency,
tmp_path: pathlib.Path,
monkeypatch: pytest.MonkeyPatch,
inprocess_client,
):
logger = MagicMock()
logger_filehandle = MagicMock()

def dag_in_a_fn():
Expand All @@ -159,6 +165,7 @@ def dag_in_a_fn():
path=path,
bundle_path=tmp_path,
callbacks=[],
logger=logger,
logger_filehandle=logger_filehandle,
client=inprocess_client,
)
Expand All @@ -172,8 +179,13 @@ def dag_in_a_fn():
assert result.serialized_dags[0].dag_id == "test_abc"

def test_top_level_variable_access_not_found(
self, spy_agency: SpyAgency, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch, inprocess_client
self,
spy_agency: SpyAgency,
tmp_path: pathlib.Path,
monkeypatch: pytest.MonkeyPatch,
inprocess_client,
):
logger = MagicMock()
logger_filehandle = MagicMock()

def dag_in_a_fn():
Expand All @@ -188,6 +200,7 @@ def dag_in_a_fn():
path=path,
bundle_path=tmp_path,
callbacks=[],
logger=logger,
logger_filehandle=logger_filehandle,
client=inprocess_client,
)
Expand All @@ -204,6 +217,7 @@ def dag_in_a_fn():
def test_top_level_variable_set(self, tmp_path: pathlib.Path, inprocess_client):
from airflow.models.variable import Variable as VariableORM

logger = MagicMock()
logger_filehandle = MagicMock()

def dag_in_a_fn():
Expand All @@ -219,6 +233,7 @@ def dag_in_a_fn():
path=path,
bundle_path=tmp_path,
callbacks=[],
logger=logger,
logger_filehandle=logger_filehandle,
client=inprocess_client,
)
Expand All @@ -239,6 +254,7 @@ def dag_in_a_fn():
def test_top_level_variable_delete(self, tmp_path: pathlib.Path, inprocess_client):
from airflow.models.variable import Variable as VariableORM

logger = MagicMock()
logger_filehandle = MagicMock()

def dag_in_a_fn():
Expand All @@ -260,6 +276,7 @@ def dag_in_a_fn():
path=path,
bundle_path=tmp_path,
callbacks=[],
logger=logger,
logger_filehandle=logger_filehandle,
client=inprocess_client,
)
Expand All @@ -279,6 +296,7 @@ def dag_in_a_fn():
def test_top_level_connection_access(
self, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch, inprocess_client
):
logger = MagicMock()
logger_filehandle = MagicMock()

def dag_in_a_fn():
Expand All @@ -295,6 +313,7 @@ def dag_in_a_fn():
path=path,
bundle_path=tmp_path,
callbacks=[],
logger=logger,
logger_filehandle=logger_filehandle,
client=inprocess_client,
)
Expand All @@ -308,6 +327,7 @@ def dag_in_a_fn():
assert result.serialized_dags[0].dag_id == "test_my_conn"

def test_top_level_connection_access_not_found(self, tmp_path: pathlib.Path, inprocess_client):
logger = MagicMock()
logger_filehandle = MagicMock()

def dag_in_a_fn():
Expand All @@ -322,6 +342,7 @@ def dag_in_a_fn():
path=path,
bundle_path=tmp_path,
callbacks=[],
logger=logger,
logger_filehandle=logger_filehandle,
client=inprocess_client,
)
Expand Down Expand Up @@ -354,6 +375,7 @@ def test_import_module_in_bundle_root(self, tmp_path: pathlib.Path, inprocess_cl
path=dag1_path,
bundle_path=tmp_path,
callbacks=[],
logger=MagicMock(),
logger_filehandle=MagicMock(),
client=inprocess_client,
)
Expand All @@ -365,6 +387,65 @@ def test_import_module_in_bundle_root(self, tmp_path: pathlib.Path, inprocess_cl
assert result.import_errors == {}
assert result.serialized_dags[0].dag_id == "dag_name"

def test__pre_import_airflow_modules_when_disabled(self):
logger = MagicMock()
with (
env_vars({"AIRFLOW__DAG_PROCESSOR__PARSING_PRE_IMPORT_MODULES": "false"}),
patch("airflow.dag_processing.processor.iter_airflow_imports") as mock_iter,
):
_pre_import_airflow_modules("test.py", logger)

mock_iter.assert_not_called()
logger.warning.assert_not_called()

def test__pre_import_airflow_modules_when_enabled(self):
logger = MagicMock()
with (
env_vars({"AIRFLOW__DAG_PROCESSOR__PARSING_PRE_IMPORT_MODULES": "true"}),
patch("airflow.dag_processing.processor.iter_airflow_imports", return_value=["airflow.models"]),
patch("airflow.dag_processing.processor.importlib.import_module") as mock_import,
):
_pre_import_airflow_modules("test.py", logger)

mock_import.assert_called_once_with("airflow.models")
logger.warning.assert_not_called()

def test__pre_import_airflow_modules_warns_on_missing_module(self):
logger = MagicMock()
with (
env_vars({"AIRFLOW__DAG_PROCESSOR__PARSING_PRE_IMPORT_MODULES": "true"}),
patch(
"airflow.dag_processing.processor.iter_airflow_imports", return_value=["non_existent_module"]
),
patch(
"airflow.dag_processing.processor.importlib.import_module", side_effect=ModuleNotFoundError()
),
):
_pre_import_airflow_modules("test.py", logger)

logger.warning.assert_called_once()
warning_args = logger.warning.call_args[0]
assert "Error when trying to pre-import module" in warning_args[0]
assert "non_existent_module" in warning_args[1]
assert "test.py" in warning_args[2]

def test__pre_import_airflow_modules_partial_success_and_warning(self):
logger = MagicMock()
with (
env_vars({"AIRFLOW__DAG_PROCESSOR__PARSING_PRE_IMPORT_MODULES": "true"}),
patch(
"airflow.dag_processing.processor.iter_airflow_imports",
return_value=["airflow.models", "non_existent_module"],
),
patch(
"airflow.dag_processing.processor.importlib.import_module",
side_effect=[None, ModuleNotFoundError()],
),
):
_pre_import_airflow_modules("test.py", logger)

assert logger.warning.call_count == 1


def write_dag_in_a_fn_to_file(fn: Callable[[], None], folder: pathlib.Path) -> pathlib.Path:
# Create the dag in a fn, and use inspect.getsource to write it to a file so that
Expand Down