class BundleDagBag(DagBag):
    """
    DagBag variant for parsing DAG files that live inside a DAG bundle.

    On construction the bundle root is appended to ``sys.path`` (unless it is
    already listed) so DAG files can import helper modules shipped in the same
    bundle. The entry is never removed again.

    WARNING: Intended for one-off, short-lived usages such as CLI commands. A
    long-running process that keeps creating instances for different bundles
    will accumulate entries on ``sys.path``.

    Takes the same arguments as ``DagBag``, except that ``bundle_path`` is
    mandatory and ``include_examples`` is always forced to ``False``.

    :param bundle_path: Root directory of the bundle. Must be given and non-empty.
    :raises ValueError: If ``bundle_path`` is missing or empty.
    """

    def __init__(self, *args, bundle_path: Path | None = None, **kwargs):
        if not bundle_path:
            raise ValueError("bundle_path is required for BundleDagBag")

        # Expose the bundle root to the import system before the parent
        # constructor runs; deliberately left in place afterwards.
        bundle_root = str(bundle_path)
        if bundle_root not in sys.path:
            sys.path.append(bundle_root)

        # Only an explicit ``True`` triggers the warning; the value is
        # overridden below regardless of what the caller passed.
        examples_requested = kwargs.get("include_examples") is True
        if examples_requested:
            warnings.warn(
                "include_examples=True is ignored for BundleDagBag. "
                "Bundles do not contain example DAGs, so include_examples is always False.",
                UserWarning,
                stacklevel=2,
            )

        forced_kwargs = {**kwargs, "bundle_path": bundle_path, "include_examples": False}
        super().__init__(*args, **forced_kwargs)
- if (bundle_root := os.fspath(msg.bundle_path)) not in sys.path: - sys.path.append(bundle_root) - result = _parse_file(msg, log) if result is not None: comms_decoder.send(result) @@ -211,11 +205,10 @@ def _parse_file_entrypoint(): def _parse_file(msg: DagFileParseRequest, log: FilteringBoundLogger) -> DagFileParsingResult | None: # TODO: Set known_pool names on DagBag! - bag = DagBag( + bag = BundleDagBag( dag_folder=msg.file, bundle_path=msg.bundle_path, bundle_name=msg.bundle_name, - include_examples=False, load_op_links=False, ) if msg.callback_requests: diff --git a/airflow-core/src/airflow/utils/cli.py b/airflow-core/src/airflow/utils/cli.py index 20455f38e33e1..4df33b26894ff 100644 --- a/airflow-core/src/airflow/utils/cli.py +++ b/airflow-core/src/airflow/utils/cli.py @@ -272,18 +272,17 @@ def get_bagged_dag(bundle_names: list | None, dag_id: str, dagfile_path: str | N find the correct path (assuming it's a file) and failing that, use the configured dags folder. """ - from airflow.dag_processing.dagbag import DagBag, sync_bag_to_db + from airflow.dag_processing.dagbag import BundleDagBag, sync_bag_to_db from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager manager = DagBundlesManager() for bundle_name in bundle_names or (): bundle = manager.get_bundle(bundle_name) with _airflow_parsing_context_manager(dag_id=dag_id): - dagbag = DagBag( + dagbag = BundleDagBag( dag_folder=dagfile_path or bundle.path, bundle_path=bundle.path, bundle_name=bundle.name, - include_examples=False, ) if dag := dagbag.dags.get(dag_id): return dag @@ -292,11 +291,10 @@ def get_bagged_dag(bundle_names: list | None, dag_id: str, dagfile_path: str | N for bundle in manager.get_all_dag_bundles(): bundle.initialize() with _airflow_parsing_context_manager(dag_id=dag_id): - dagbag = DagBag( + dagbag = BundleDagBag( dag_folder=dagfile_path or bundle.path, bundle_path=bundle.path, bundle_name=bundle.name, - include_examples=False, ) 
sync_bag_to_db(dagbag, bundle.name, bundle.version) if dag := dagbag.dags.get(dag_id): @@ -323,7 +321,7 @@ def get_db_dag(bundle_names: list | None, dag_id: str, dagfile_path: str | None def get_dags(bundle_names: list | None, dag_id: str, use_regex: bool = False, from_db: bool = False): """Return DAG(s) matching a given regex or dag_id.""" - from airflow.dag_processing.dagbag import DagBag + from airflow.dag_processing.dagbag import BundleDagBag bundle_names = bundle_names or [] @@ -333,7 +331,7 @@ def get_dags(bundle_names: list | None, dag_id: str, use_regex: bool = False, fr return [get_bagged_dag(bundle_names=bundle_names, dag_id=dag_id)] def _find_dag(bundle): - dagbag = DagBag(dag_folder=bundle.path, bundle_path=bundle.path, bundle_name=bundle.name) + dagbag = BundleDagBag(dag_folder=bundle.path, bundle_path=bundle.path, bundle_name=bundle.name) matched_dags = [dag for dag in dagbag.dags.values() if re.search(dag_id, dag.dag_id)] return matched_dags diff --git a/airflow-core/tests/unit/cli/commands/test_dag_command.py b/airflow-core/tests/unit/cli/commands/test_dag_command.py index 30a751c82791b..643248528fab0 100644 --- a/airflow-core/tests/unit/cli/commands/test_dag_command.py +++ b/airflow-core/tests/unit/cli/commands/test_dag_command.py @@ -760,7 +760,7 @@ def test_dag_test_show_dag(self, mock_get_dag, mock_render_dag, stdout_capture): mock_render_dag.assert_has_calls([mock.call(mock_get_dag.return_value, tis=[])]) assert "SOURCE" in output - @mock.patch("airflow.dag_processing.dagbag.DagBag") + @mock.patch("airflow.dag_processing.dagbag.BundleDagBag") def test_dag_test_with_bundle_name(self, mock_dagbag, configure_dag_bundles): """Test that DAG can be tested using bundle name.""" mock_dagbag.return_value.get_dag.return_value.test.return_value = DagRun( @@ -785,10 +785,9 @@ def test_dag_test_with_bundle_name(self, mock_dagbag, configure_dag_bundles): bundle_path=TEST_DAGS_FOLDER, dag_folder=TEST_DAGS_FOLDER, bundle_name="testing", - 
include_examples=False, ) - @mock.patch("airflow.dag_processing.dagbag.DagBag") + @mock.patch("airflow.dag_processing.dagbag.BundleDagBag") def test_dag_test_with_dagfile_path(self, mock_dagbag, configure_dag_bundles): """Test that DAG can be tested using dagfile path.""" mock_dagbag.return_value.get_dag.return_value.test.return_value = DagRun( @@ -807,10 +806,9 @@ def test_dag_test_with_dagfile_path(self, mock_dagbag, configure_dag_bundles): bundle_path=TEST_DAGS_FOLDER, dag_folder=str(dag_file), bundle_name="testing", - include_examples=False, ) - @mock.patch("airflow.dag_processing.dagbag.DagBag") + @mock.patch("airflow.dag_processing.dagbag.BundleDagBag") def test_dag_test_with_both_bundle_and_dagfile_path(self, mock_dagbag, configure_dag_bundles): """Test that DAG can be tested using both bundle name and dagfile path.""" mock_dagbag.return_value.get_dag.return_value.test.return_value = DagRun( @@ -839,7 +837,6 @@ def test_dag_test_with_both_bundle_and_dagfile_path(self, mock_dagbag, configure bundle_path=TEST_DAGS_FOLDER, dag_folder=str(dag_file), bundle_name="testing", - include_examples=False, ) @mock.patch("airflow.models.dagrun.get_or_create_dagrun") diff --git a/airflow-core/tests/unit/dag_processing/test_dagbag.py b/airflow-core/tests/unit/dag_processing/test_dagbag.py index 247935d464fc3..fafde1b3a9b08 100644 --- a/airflow-core/tests/unit/dag_processing/test_dagbag.py +++ b/airflow-core/tests/unit/dag_processing/test_dagbag.py @@ -35,7 +35,12 @@ from sqlalchemy import select from airflow import settings -from airflow.dag_processing.dagbag import DagBag, _capture_with_reraise, _validate_executor_fields +from airflow.dag_processing.dagbag import ( + BundleDagBag, + DagBag, + _capture_with_reraise, + _validate_executor_fields, +) from airflow.exceptions import UnknownExecutorException from airflow.executors.executor_loader import ExecutorLoader from airflow.models.dag import DagModel @@ -1192,3 +1197,96 @@ def 
class TestBundlePathSysPath:
    """
    Tests for bundle_path sys.path handling in BundleDagBag.

    Each test snapshots ``sys.path`` and restores it in a ``finally`` block so
    that a failing assertion cannot leak a temporary directory into the import
    path seen by later tests.
    """

    def test_bundle_path_added_to_syspath(self, tmp_path):
        """Test that BundleDagBag adds bundle_path to sys.path when provided."""
        util_file = tmp_path / "bundle_util.py"
        util_file.write_text('def get_message(): return "Hello from bundle!"')

        dag_file = tmp_path / "test_dag.py"
        dag_file.write_text(
            textwrap.dedent(
                """\
                from airflow.sdk import DAG
                from airflow.operators.empty import EmptyOperator

                import sys
                import bundle_util

                with DAG('test_import', description=f"DAG with sys.path: {sys.path}"):
                    EmptyOperator(task_id="mytask")
                """
            )
        )

        bundle_root = str(tmp_path)
        assert bundle_root not in sys.path

        saved_sys_path = list(sys.path)
        try:
            dagbag = BundleDagBag(dag_folder=str(dag_file), bundle_path=tmp_path, bundle_name="test-bundle")

            # Check import was successful
            assert len(dagbag.dags) == 1
            assert not dagbag.import_errors

            dag = dagbag.get_dag("test_import")
            assert dag is not None
            assert bundle_root in dag.description  # sys.path was enhanced during parse

            # Path remains in sys.path (no cleanup - intentional for ephemeral processes)
            assert bundle_root in sys.path
        finally:
            # Restore in place so the list object held by the import system is unchanged.
            sys.path[:] = saved_sys_path
            # The DAG file imported the helper module; drop it so it cannot
            # satisfy an import in another test after tmp_path is gone.
            sys.modules.pop("bundle_util", None)

    def test_bundle_path_not_duplicated(self, tmp_path):
        """Test that bundle_path is not added to sys.path if already present."""
        dag_file = tmp_path / "simple_dag.py"
        dag_file.write_text(
            textwrap.dedent(
                """\
                from airflow.sdk import DAG
                from airflow.operators.empty import EmptyOperator

                with DAG("simple_dag"):
                    EmptyOperator(task_id="mytask")
                """
            )
        )

        bundle_root = str(tmp_path)
        saved_sys_path = list(sys.path)
        try:
            # Pre-add the path
            sys.path.append(bundle_root)
            count_before = sys.path.count(bundle_root)

            BundleDagBag(dag_folder=str(dag_file), bundle_path=tmp_path, bundle_name="test-bundle")

            # Should not add duplicate
            assert sys.path.count(bundle_root) == count_before
        finally:
            # Removes the pre-added entry and any duplicate a regression might add.
            sys.path[:] = saved_sys_path

    def test_dagbag_no_bundle_path_no_syspath_modification(self, tmp_path):
        """Test that no sys.path modification occurs when DagBag is used without bundle_path."""
        dag_file = tmp_path / "simple_dag.py"
        dag_file.write_text(
            textwrap.dedent(
                """\
                from airflow.sdk import DAG
                from airflow.operators.empty import EmptyOperator

                import sys

                with DAG("simple_dag", description=f"DAG with sys.path: {sys.path}") as dag:
                    EmptyOperator(task_id="mytask")
                """
            )
        )
        syspath_before = deepcopy(sys.path)
        try:
            dagbag = DagBag(dag_folder=str(dag_file), include_examples=False)
            dag = dagbag.get_dag("simple_dag")

            # Fail with a clear assertion rather than an AttributeError on None.
            assert dag is not None
            assert str(tmp_path) not in dag.description
            assert sys.path == syspath_before
        finally:
            sys.path[:] = syspath_before
log=structlog.get_logger()) - # Verify DagBag was called with correct bundle_name + # Verify BundleDagBag was called with correct bundle_name mock_dagbag_class.assert_called_once() call_kwargs = mock_dagbag_class.call_args.kwargs assert call_kwargs["bundle_name"] == "test_bundle" diff --git a/airflow-core/tests/unit/models/test_dag.py b/airflow-core/tests/unit/models/test_dag.py index a6889d782cb7d..78af76114d2c1 100644 --- a/airflow-core/tests/unit/models/test_dag.py +++ b/airflow-core/tests/unit/models/test_dag.py @@ -40,7 +40,7 @@ from airflow._shared.timezones import timezone from airflow._shared.timezones.timezone import datetime as datetime_tz from airflow.configuration import conf -from airflow.dag_processing.dagbag import DagBag +from airflow.dag_processing.dagbag import BundleDagBag, DagBag from airflow.exceptions import AirflowException from airflow.models.asset import ( AssetAliasModel, @@ -2200,7 +2200,7 @@ def test_relative_fileloc(self, session, testing_dag_bundle): rel_path = "test_assets.py" bundle_path = TEST_DAGS_FOLDER file_path = bundle_path / rel_path - bag = DagBag(dag_folder=file_path, bundle_path=bundle_path) + bag = BundleDagBag(dag_folder=file_path, bundle_path=bundle_path, bundle_name="testing") dag = bag.get_dag("dag_with_skip_task") diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py index 788cf6c63dc7a..c01af7c691dad 100644 --- a/task-sdk/src/airflow/sdk/definitions/dag.py +++ b/task-sdk/src/airflow/sdk/definitions/dag.py @@ -1244,7 +1244,7 @@ def test( version = DagVersion.get_version(self.dag_id) if not version: from airflow.dag_processing.bundles.manager import DagBundlesManager - from airflow.dag_processing.dagbag import DagBag, sync_bag_to_db + from airflow.dag_processing.dagbag import BundleDagBag, sync_bag_to_db from airflow.sdk.definitions._internal.dag_parsing_context import ( _airflow_parsing_context_manager, ) @@ -1258,8 +1258,10 @@ def test( if not bundle.is_initialized: 
bundle.initialize() with _airflow_parsing_context_manager(dag_id=self.dag_id): - dagbag = DagBag( - dag_folder=bundle.path, bundle_path=bundle.path, include_examples=False + dagbag = BundleDagBag( + dag_folder=bundle.path, + bundle_path=bundle.path, + bundle_name=bundle.name, ) sync_bag_to_db(dagbag, bundle.name, bundle.version) version = DagVersion.get_version(self.dag_id) diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 814673d93fe27..a52670c0ba0b8 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -671,8 +671,8 @@ def _xcom_push_to_db(ti: RuntimeTaskInstance, key: str, value: Any) -> None: def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance: # TODO: Task-SDK: - # Using DagBag here is about 98% wrong, but it'll do for now - from airflow.dag_processing.dagbag import DagBag + # Using BundleDagBag here is about 98% wrong, but it'll do for now + from airflow.dag_processing.dagbag import BundleDagBag bundle_info = what.bundle_info bundle_instance = DagBundlesManager().get_bundle( @@ -681,17 +681,12 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance: ) bundle_instance.initialize() - # Put bundle root on sys.path if needed. This allows the dag bundle to add - # code in util modules to be shared between files within the same bundle. 
- if (bundle_root := os.fspath(bundle_instance.path)) not in sys.path: - sys.path.append(bundle_root) - dag_absolute_path = os.fspath(Path(bundle_instance.path, what.dag_rel_path)) - bag = DagBag( + bag = BundleDagBag( dag_folder=dag_absolute_path, - include_examples=False, safe_mode=False, load_op_links=False, + bundle_path=bundle_instance.path, bundle_name=bundle_info.name, ) if TYPE_CHECKING: diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 4a5c54ba3a467..1dee447ac101c 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -197,9 +197,9 @@ def test_parse(test_dags_dir: Path, make_ti_context): assert ti.task.dag -@mock.patch("airflow.dag_processing.dagbag.DagBag") +@mock.patch("airflow.dag_processing.dagbag.BundleDagBag") def test_parse_dag_bag(mock_dagbag, test_dags_dir: Path, make_ti_context): - """Test that checks that the dagbag is constructed as expected during parsing""" + """Test that checks that the BundleDagBag is constructed as expected during parsing""" mock_bag_instance = mock.Mock() mock_dagbag.return_value = mock_bag_instance mock_dag = mock.Mock(spec=DAG) @@ -242,9 +242,9 @@ def test_parse_dag_bag(mock_dagbag, test_dags_dir: Path, make_ti_context): mock_dagbag.assert_called_once_with( dag_folder=mock.ANY, - include_examples=False, safe_mode=False, load_op_links=False, + bundle_path=test_dags_dir, bundle_name="my-bundle", )