Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions airflow-core/src/airflow/cli/commands/dag_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from airflow.cli.simple_table import AirflowConsole
from airflow.cli.utils import fetch_dag_run_from_run_id_or_logical_date_string
from airflow.dag_processing.bundles.manager import DagBundlesManager
from airflow.dag_processing.dagbag import DagBag, sync_bag_to_db
from airflow.dag_processing.dagbag import BundleDagBag, DagBag, sync_bag_to_db
from airflow.exceptions import AirflowConfigException, AirflowException
from airflow.jobs.job import Job
from airflow.models import DagModel, DagRun, TaskInstance
Expand Down Expand Up @@ -378,10 +378,12 @@ def dag_list_dags(args, session: Session = NEW_SESSION) -> None:

for bundle in all_bundles:
if bundle.name in bundles_to_search:
dagbag = DagBag(bundle.path, bundle_path=bundle.path, bundle_name=bundle.name)
dagbag.collect_dags()
dags_list.extend(list(dagbag.dags.values()))
dagbag_import_errors += len(dagbag.import_errors)
bundle_dagbag = BundleDagBag(
bundle.path, bundle_path=bundle.path, bundle_name=bundle.name
)
bundle_dagbag.collect_dags()
dags_list.extend(list(bundle_dagbag.dags.values()))
dagbag_import_errors += len(bundle_dagbag.import_errors)
else:
dagbag = DagBag()
dagbag.collect_dags()
Expand Down Expand Up @@ -474,8 +476,10 @@ def dag_list_import_errors(args, session: Session = NEW_SESSION) -> None:

for bundle in all_bundles:
if bundle.name in bundles_to_search:
dagbag = DagBag(bundle.path, bundle_path=bundle.path, bundle_name=bundle.name)
for filename, errors in dagbag.import_errors.items():
bundle_dagbag = BundleDagBag(
bundle.path, bundle_path=bundle.path, bundle_name=bundle.name
)
for filename, errors in bundle_dagbag.import_errors.items():
data.append({"bundle_name": bundle.name, "filepath": filename, "error": errors})
else:
dagbag = DagBag()
Expand Down Expand Up @@ -526,7 +530,7 @@ def dag_report(args) -> None:
if bundle.name not in bundles_to_reserialize:
continue
bundle.initialize()
dagbag = DagBag(bundle.path, bundle_name=bundle.name, include_examples=False)
dagbag = BundleDagBag(bundle.path, bundle_path=bundle.path, bundle_name=bundle.name)
all_dagbag_stats.extend(dagbag.dagbag_stats)

AirflowConsole().print_as(
Expand Down Expand Up @@ -690,7 +694,5 @@ def dag_reserialize(args, session: Session = NEW_SESSION) -> None:
if bundle.name not in bundles_to_reserialize:
continue
bundle.initialize()
dag_bag = DagBag(
bundle.path, bundle_path=bundle.path, bundle_name=bundle.name, include_examples=False
)
dag_bag = BundleDagBag(bundle.path, bundle_path=bundle.path, bundle_name=bundle.name)
sync_bag_to_db(dag_bag, bundle.name, bundle_version=bundle.get_current_version(), session=session)
34 changes: 34 additions & 0 deletions airflow-core/src/airflow/dag_processing/dagbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,6 +695,40 @@ def dagbag_report(self):
return report


class BundleDagBag(DagBag):
"""
Bundle-aware DagBag that permanently modifies sys.path.

This class adds the bundle_path to sys.path permanently to allow DAG files
to import modules from their bundle directory. No cleanup is performed.

WARNING: Only use for one-off usages like CLI commands. Using this in long-running
processes will cause sys.path to accumulate entries.

Same parameters as DagBag, but bundle_path is required and examples are not loaded.
"""

def __init__(self, *args, bundle_path: Path | None = None, **kwargs):
if not bundle_path:
raise ValueError("bundle_path is required for BundleDagBag")

if str(bundle_path) not in sys.path:
sys.path.append(str(bundle_path))

# Warn if user explicitly set include_examples=True, since bundles never contain examples
if kwargs.get("include_examples") is True:
warnings.warn(
"include_examples=True is ignored for BundleDagBag. "
"Bundles do not contain example DAGs, so include_examples is always False.",
UserWarning,
stacklevel=2,
)

kwargs["bundle_path"] = bundle_path
kwargs["include_examples"] = False
super().__init__(*args, **kwargs)


@provide_session
def sync_bag_to_db(
dagbag: DagBag,
Expand Down
11 changes: 2 additions & 9 deletions airflow-core/src/airflow/dag_processing/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import contextlib
import importlib
import os
import sys
import traceback
from collections.abc import Callable, Sequence
from pathlib import Path
Expand All @@ -35,7 +34,7 @@
TaskCallbackRequest,
)
from airflow.configuration import conf
from airflow.dag_processing.dagbag import DagBag
from airflow.dag_processing.dagbag import BundleDagBag, DagBag
from airflow.observability.stats import Stats
from airflow.sdk.exceptions import TaskNotFound
from airflow.sdk.execution_time.comms import (
Expand Down Expand Up @@ -198,11 +197,6 @@ def _parse_file_entrypoint():
task_runner.SUPERVISOR_COMMS = comms_decoder
log = structlog.get_logger(logger_name="task")

# Put bundle root on sys.path if needed. This allows the dag bundle to add
# code in util modules to be shared between files within the same bundle.
if (bundle_root := os.fspath(msg.bundle_path)) not in sys.path:
sys.path.append(bundle_root)

result = _parse_file(msg, log)
if result is not None:
comms_decoder.send(result)
Expand All @@ -211,11 +205,10 @@ def _parse_file_entrypoint():
def _parse_file(msg: DagFileParseRequest, log: FilteringBoundLogger) -> DagFileParsingResult | None:
# TODO: Set known_pool names on DagBag!

bag = DagBag(
bag = BundleDagBag(
dag_folder=msg.file,
bundle_path=msg.bundle_path,
bundle_name=msg.bundle_name,
include_examples=False,
load_op_links=False,
)
if msg.callback_requests:
Expand Down
12 changes: 5 additions & 7 deletions airflow-core/src/airflow/utils/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,18 +272,17 @@ def get_bagged_dag(bundle_names: list | None, dag_id: str, dagfile_path: str | N
find the correct path (assuming it's a file) and failing that, use the configured
dags folder.
"""
from airflow.dag_processing.dagbag import DagBag, sync_bag_to_db
from airflow.dag_processing.dagbag import BundleDagBag, sync_bag_to_db
from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager

manager = DagBundlesManager()
for bundle_name in bundle_names or ():
bundle = manager.get_bundle(bundle_name)
with _airflow_parsing_context_manager(dag_id=dag_id):
dagbag = DagBag(
dagbag = BundleDagBag(
dag_folder=dagfile_path or bundle.path,
bundle_path=bundle.path,
bundle_name=bundle.name,
include_examples=False,
)
if dag := dagbag.dags.get(dag_id):
return dag
Expand All @@ -292,11 +291,10 @@ def get_bagged_dag(bundle_names: list | None, dag_id: str, dagfile_path: str | N
for bundle in manager.get_all_dag_bundles():
bundle.initialize()
with _airflow_parsing_context_manager(dag_id=dag_id):
dagbag = DagBag(
dagbag = BundleDagBag(
dag_folder=dagfile_path or bundle.path,
bundle_path=bundle.path,
bundle_name=bundle.name,
include_examples=False,
)
sync_bag_to_db(dagbag, bundle.name, bundle.version)
if dag := dagbag.dags.get(dag_id):
Expand All @@ -323,7 +321,7 @@ def get_db_dag(bundle_names: list | None, dag_id: str, dagfile_path: str | None

def get_dags(bundle_names: list | None, dag_id: str, use_regex: bool = False, from_db: bool = False):
"""Return DAG(s) matching a given regex or dag_id."""
from airflow.dag_processing.dagbag import DagBag
from airflow.dag_processing.dagbag import BundleDagBag

bundle_names = bundle_names or []

Expand All @@ -333,7 +331,7 @@ def get_dags(bundle_names: list | None, dag_id: str, use_regex: bool = False, fr
return [get_bagged_dag(bundle_names=bundle_names, dag_id=dag_id)]

def _find_dag(bundle):
dagbag = DagBag(dag_folder=bundle.path, bundle_path=bundle.path, bundle_name=bundle.name)
dagbag = BundleDagBag(dag_folder=bundle.path, bundle_path=bundle.path, bundle_name=bundle.name)
matched_dags = [dag for dag in dagbag.dags.values() if re.search(dag_id, dag.dag_id)]
return matched_dags

Expand Down
9 changes: 3 additions & 6 deletions airflow-core/tests/unit/cli/commands/test_dag_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,7 +760,7 @@ def test_dag_test_show_dag(self, mock_get_dag, mock_render_dag, stdout_capture):
mock_render_dag.assert_has_calls([mock.call(mock_get_dag.return_value, tis=[])])
assert "SOURCE" in output

@mock.patch("airflow.dag_processing.dagbag.DagBag")
@mock.patch("airflow.dag_processing.dagbag.BundleDagBag")
def test_dag_test_with_bundle_name(self, mock_dagbag, configure_dag_bundles):
"""Test that DAG can be tested using bundle name."""
mock_dagbag.return_value.get_dag.return_value.test.return_value = DagRun(
Expand All @@ -785,10 +785,9 @@ def test_dag_test_with_bundle_name(self, mock_dagbag, configure_dag_bundles):
bundle_path=TEST_DAGS_FOLDER,
dag_folder=TEST_DAGS_FOLDER,
bundle_name="testing",
include_examples=False,
)

@mock.patch("airflow.dag_processing.dagbag.DagBag")
@mock.patch("airflow.dag_processing.dagbag.BundleDagBag")
def test_dag_test_with_dagfile_path(self, mock_dagbag, configure_dag_bundles):
"""Test that DAG can be tested using dagfile path."""
mock_dagbag.return_value.get_dag.return_value.test.return_value = DagRun(
Expand All @@ -807,10 +806,9 @@ def test_dag_test_with_dagfile_path(self, mock_dagbag, configure_dag_bundles):
bundle_path=TEST_DAGS_FOLDER,
dag_folder=str(dag_file),
bundle_name="testing",
include_examples=False,
)

@mock.patch("airflow.dag_processing.dagbag.DagBag")
@mock.patch("airflow.dag_processing.dagbag.BundleDagBag")
def test_dag_test_with_both_bundle_and_dagfile_path(self, mock_dagbag, configure_dag_bundles):
"""Test that DAG can be tested using both bundle name and dagfile path."""
mock_dagbag.return_value.get_dag.return_value.test.return_value = DagRun(
Expand Down Expand Up @@ -839,7 +837,6 @@ def test_dag_test_with_both_bundle_and_dagfile_path(self, mock_dagbag, configure
bundle_path=TEST_DAGS_FOLDER,
dag_folder=str(dag_file),
bundle_name="testing",
include_examples=False,
)

@mock.patch("airflow.models.dagrun.get_or_create_dagrun")
Expand Down
100 changes: 99 additions & 1 deletion airflow-core/tests/unit/dag_processing/test_dagbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,12 @@
from sqlalchemy import select

from airflow import settings
from airflow.dag_processing.dagbag import DagBag, _capture_with_reraise, _validate_executor_fields
from airflow.dag_processing.dagbag import (
BundleDagBag,
DagBag,
_capture_with_reraise,
_validate_executor_fields,
)
from airflow.exceptions import UnknownExecutorException
from airflow.executors.executor_loader import ExecutorLoader
from airflow.models.dag import DagModel
Expand Down Expand Up @@ -1192,3 +1197,96 @@ def test_capture_warnings_with_error_filters(self):
self.raise_warnings()
assert len(cw) == 1
assert len(records) == 1


class TestBundlePathSysPath:
"""Tests for bundle_path sys.path handling in BundleDagBag."""

def test_bundle_path_added_to_syspath(self, tmp_path):
"""Test that BundleDagBag adds bundle_path to sys.path when provided."""
util_file = tmp_path / "bundle_util.py"
util_file.write_text('def get_message(): return "Hello from bundle!"')

dag_file = tmp_path / "test_dag.py"
dag_file.write_text(
textwrap.dedent(
"""\
from airflow.sdk import DAG
from airflow.operators.empty import EmptyOperator

import sys
import bundle_util

with DAG('test_import', description=f"DAG with sys.path: {sys.path}"):
EmptyOperator(task_id="mytask")
"""
)
)

assert str(tmp_path) not in sys.path

dagbag = BundleDagBag(dag_folder=str(dag_file), bundle_path=tmp_path, bundle_name="test-bundle")

# Check import was successful
assert len(dagbag.dags) == 1
assert not dagbag.import_errors

dag = dagbag.get_dag("test_import")
assert dag is not None
assert str(tmp_path) in dag.description # sys.path was enhanced during parse

# Path remains in sys.path (no cleanup - intentional for ephemeral processes)
assert str(tmp_path) in sys.path

# Cleanup for other tests
sys.path.remove(str(tmp_path))

def test_bundle_path_not_duplicated(self, tmp_path):
"""Test that bundle_path is not added to sys.path if already present."""
dag_file = tmp_path / "simple_dag.py"
dag_file.write_text(
textwrap.dedent(
"""\
from airflow.sdk import DAG
from airflow.operators.empty import EmptyOperator

with DAG("simple_dag"):
EmptyOperator(task_id="mytask")
"""
)
)

# Pre-add the path
sys.path.append(str(tmp_path))
count_before = sys.path.count(str(tmp_path))

BundleDagBag(dag_folder=str(dag_file), bundle_path=tmp_path, bundle_name="test-bundle")

# Should not add duplicate
assert sys.path.count(str(tmp_path)) == count_before

# Cleanup for other tests
sys.path.remove(str(tmp_path))

def test_dagbag_no_bundle_path_no_syspath_modification(self, tmp_path):
"""Test that no sys.path modification occurs when DagBag is used without bundle_path."""
dag_file = tmp_path / "simple_dag.py"
dag_file.write_text(
textwrap.dedent(
"""\
from airflow.sdk import DAG
from airflow.operators.empty import EmptyOperator

import sys

with DAG("simple_dag", description=f"DAG with sys.path: {sys.path}") as dag:
EmptyOperator(task_id="mytask")
"""
)
)
syspath_before = deepcopy(sys.path)
dagbag = DagBag(dag_folder=str(dag_file), include_examples=False)
dag = dagbag.get_dag("simple_dag")

assert str(tmp_path) not in dag.description
assert sys.path == syspath_before
8 changes: 4 additions & 4 deletions airflow-core/tests/unit/dag_processing/test_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1795,9 +1795,9 @@ def fake_collect_dags(self, *args, **kwargs):
_execute_email_callbacks(dagbag, request, log)

def test_parse_file_passes_bundle_name_to_dagbag(self):
"""Test that _parse_file() creates DagBag with correct bundle_name parameter"""
# Mock the DagBag constructor to capture its arguments
with patch("airflow.dag_processing.processor.DagBag") as mock_dagbag_class:
"""Test that _parse_file() creates BundleDagBag with correct bundle_name parameter"""
# Mock the BundleDagBag constructor to capture its arguments
with patch("airflow.dag_processing.processor.BundleDagBag") as mock_dagbag_class:
# Create a mock instance with proper attributes for Pydantic validation
mock_dagbag_instance = MagicMock()
mock_dagbag_instance.dags = {}
Expand All @@ -1813,7 +1813,7 @@ def test_parse_file_passes_bundle_name_to_dagbag(self):

_parse_file(request, log=structlog.get_logger())

# Verify DagBag was called with correct bundle_name
# Verify BundleDagBag was called with correct bundle_name
mock_dagbag_class.assert_called_once()
call_kwargs = mock_dagbag_class.call_args.kwargs
assert call_kwargs["bundle_name"] == "test_bundle"
Expand Down
4 changes: 2 additions & 2 deletions airflow-core/tests/unit/models/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from airflow._shared.timezones import timezone
from airflow._shared.timezones.timezone import datetime as datetime_tz
from airflow.configuration import conf
from airflow.dag_processing.dagbag import DagBag
from airflow.dag_processing.dagbag import BundleDagBag, DagBag
from airflow.exceptions import AirflowException
from airflow.models.asset import (
AssetAliasModel,
Expand Down Expand Up @@ -2200,7 +2200,7 @@ def test_relative_fileloc(self, session, testing_dag_bundle):
rel_path = "test_assets.py"
bundle_path = TEST_DAGS_FOLDER
file_path = bundle_path / rel_path
bag = DagBag(dag_folder=file_path, bundle_path=bundle_path)
bag = BundleDagBag(dag_folder=file_path, bundle_path=bundle_path, bundle_name="testing")

dag = bag.get_dag("dag_with_skip_task")

Expand Down
Loading