Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions airflow-core/src/airflow/dag_processing/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import contextlib
import functools
import importlib
import inspect
import logging
import os
Expand Down Expand Up @@ -78,6 +79,7 @@

from sqlalchemy.orm import Session
from sqlalchemy.sql import Select
from structlog.typing import FilteringBoundLogger as Logger

from airflow.callbacks.callback_requests import CallbackRequest
from airflow.dag_processing.bundles.base import BaseDagBundle
Expand Down Expand Up @@ -859,6 +861,7 @@ def _collect_results(self, session: Session = NEW_SESSION):
session=session,
is_callback_only=is_callback_only,
relative_fileloc=str(file.rel_path),
log=self.log,
)

for file in finished:
Expand Down Expand Up @@ -1149,6 +1152,7 @@ def process_parse_results(
*,
is_callback_only: bool = False,
relative_fileloc: str | None = None,
log: Logger | None = None,
) -> DagFileStat:
"""Take the parsing result and stats about the parser process and convert it into a DagFileStat."""
if is_callback_only:
Expand Down Expand Up @@ -1204,4 +1208,13 @@ def process_parse_results(
stat.num_dags = len(parsing_result.serialized_dags)
if parsing_result.import_errors:
stat.import_errors = len(parsing_result.import_errors)

# Load Airflow modules that were not loaded during runtime.
if conf.getboolean("dag_processor", "parsing_pre_import_modules", fallback=True) and log:
for module in parsing_result.not_loaded_airflow_modules or []:
try:
if module not in sys.modules.keys():
importlib.import_module(module)
except Exception:
log.warning("Error when trying to pre-import module '%s'", module)
return stat
7 changes: 7 additions & 0 deletions airflow-core/src/airflow/dag_processing/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ class DagFileParsingResult(BaseModel):
warnings: list | None = None
import_errors: dict[str, str] | None = None
type: Literal["DagFileParsingResult"] = "DagFileParsingResult"
not_loaded_airflow_modules: list[str] | None = None


ToManager = Annotated[
Expand Down Expand Up @@ -206,6 +207,7 @@ def _parse_file_entrypoint():

def _parse_file(msg: DagFileParseRequest, log: FilteringBoundLogger) -> DagFileParsingResult | None:
# TODO: Set known_pool names on DagBag!
modules_before = set(sys.modules.keys())

bag = DagBag(
dag_folder=msg.file,
Expand All @@ -221,12 +223,17 @@ def _parse_file(msg: DagFileParseRequest, log: FilteringBoundLogger) -> DagFileP

serialized_dags, serialization_import_errors = _serialize_dags(bag, log)
bag.import_errors.update(serialization_import_errors)

modules_after = set(sys.modules.keys())
not_loaded_airflow_modules = {m for m in (modules_after - modules_before) if m.startswith("airflow.")}

result = DagFileParsingResult(
fileloc=msg.file,
serialized_dags=serialized_dags,
import_errors=bag.import_errors,
# TODO: Make `bag.dag_warnings` not return SQLA model objects
warnings=[],
not_loaded_airflow_modules=list(not_loaded_airflow_modules),
)
return result

Expand Down
56 changes: 56 additions & 0 deletions airflow-core/tests/unit/dag_processing/test_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,38 @@ def fake_collect_dags(self, *args, **kwargs):
assert called is True


def test_parse_file_tracks_newly_imported_airflow_modules(tmp_path, monkeypatch):
    """
    Check that ``_parse_file`` reports airflow modules first imported while parsing.

    The DAG file imports ``airflow.decorators``. That module is evicted from
    ``sys.modules`` beforehand so the import happens during parsing, and it must
    then show up in ``DagFileParsingResult.not_loaded_airflow_modules``.
    """
    dag_file = tmp_path / "test_dag.py"
    dag_file.write_text(
        """
from airflow import DAG
import airflow.decorators
from datetime import datetime

dag = DAG(
    dag_id="test_dag",
    start_date=datetime(2024, 1, 1),
)
"""
    )

    # Evict the module so parsing has to import it again. ``monkeypatch`` undoes the
    # change to ``sys.modules`` at teardown, so other tests are not affected.
    monkeypatch.delitem(sys.modules, "airflow.decorators", raising=False)

    result = _parse_file(
        DagFileParseRequest(
            file=str(dag_file),
            bundle_path=str(tmp_path),
            bundle_name="testing",
            callback_requests=[],
        ),
        log=structlog.get_logger(),
    )

    assert result is not None
    assert "airflow.decorators" in result.not_loaded_airflow_modules
    # Only modules under the ``airflow.`` namespace are reported.
    assert all(m.startswith("airflow.") for m in result.not_loaded_airflow_modules)


def test_parse_file_with_task_callbacks(spy_agency):
called = False

Expand Down Expand Up @@ -645,6 +677,30 @@ def test_normal_parsing_updates_timestamps(session):
assert stat.import_errors == 0


def test_import_not_imported_module(session, monkeypatch):
    """
    Check that ``process_parse_results`` imports modules listed in the parsing result.

    ``airflow.decorators`` is evicted from ``sys.modules`` and passed in
    ``not_loaded_airflow_modules``; after the call it must be importable from
    ``sys.modules`` again.
    """
    # Evict the module so the import performed by ``process_parse_results`` is observable.
    # ``monkeypatch`` undoes the change to ``sys.modules`` at teardown.
    monkeypatch.delitem(sys.modules, "airflow.decorators", raising=False)
    assert "airflow.decorators" not in sys.modules

    finish_time = timezone.utcnow()

    process_parse_results(
        run_duration=2.0,
        finish_time=finish_time,
        run_count=3,
        bundle_name="test-bundle",
        bundle_version="v1",
        parsing_result=DagFileParsingResult(
            fileloc="test.py",
            serialized_dags=[],
            not_loaded_airflow_modules=["airflow.decorators"],
        ),
        session=session,
        is_callback_only=False,
        log=structlog.get_logger(),
    )

    assert "airflow.decorators" in sys.modules


def test_import_error_updates_timestamps(session):
"""last_finish_time should be updated when parsing a dag file results in import errors."""
finish_time = timezone.utcnow()
Expand Down