diff --git a/airflow-core/src/airflow/cli/commands/dag_command.py b/airflow-core/src/airflow/cli/commands/dag_command.py
index df7276d86437d..2a7d6d2471020 100644
--- a/airflow-core/src/airflow/cli/commands/dag_command.py
+++ b/airflow-core/src/airflow/cli/commands/dag_command.py
@@ -41,7 +41,7 @@
 from airflow.jobs.job import Job
 from airflow.models import DagBag, DagModel, DagRun, TaskInstance
 from airflow.models.dag import get_next_data_interval
-from airflow.models.dagbag import sync_bag_to_db
+from airflow.models.dagbag import BundleDagBag, sync_bag_to_db
 from airflow.models.errors import ParseImportError
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.utils import cli as cli_utils
@@ -375,10 +375,10 @@ def dag_list_dags(args, session: Session = NEW_SESSION) -> None:
 
     for bundle in all_bundles:
         if bundle.name in bundles_to_search:
-            dagbag = DagBag(bundle.path, bundle_path=bundle.path)
-            dagbag.collect_dags()
-            dags_list.extend(list(dagbag.dags.values()))
-            dagbag_import_errors += len(dagbag.import_errors)
+            bundle_dagbag = BundleDagBag(bundle.path, bundle_path=bundle.path)
+            bundle_dagbag.collect_dags()
+            dags_list.extend(list(bundle_dagbag.dags.values()))
+            dagbag_import_errors += len(bundle_dagbag.import_errors)
         else:
             dagbag = DagBag()
             dagbag.collect_dags()
@@ -471,8 +471,8 @@ def dag_list_import_errors(args, session: Session = NEW_SESSION) -> None:
 
     for bundle in all_bundles:
         if bundle.name in bundles_to_search:
-            dagbag = DagBag(bundle.path, bundle_path=bundle.path)
-            for filename, errors in dagbag.import_errors.items():
+            bundle_dagbag = BundleDagBag(bundle.path, bundle_path=bundle.path)
+            for filename, errors in bundle_dagbag.import_errors.items():
                 data.append({"bundle_name": bundle.name, "filepath": filename, "error": errors})
         else:
             dagbag = DagBag()
@@ -523,7 +523,7 @@ def dag_report(args) -> None:
         if bundle.name not in bundles_to_reserialize:
             continue
         bundle.initialize()
-        dagbag = DagBag(bundle.path, include_examples=False)
+        dagbag = BundleDagBag(bundle.path, bundle_path=bundle.path)
         all_dagbag_stats.extend(dagbag.dagbag_stats)
 
     AirflowConsole().print_as(
@@ -687,5 +687,5 @@ def dag_reserialize(args, session: Session = NEW_SESSION) -> None:
         if bundle.name not in bundles_to_reserialize:
             continue
         bundle.initialize()
-        dag_bag = DagBag(bundle.path, bundle_path=bundle.path, include_examples=False)
+        dag_bag = BundleDagBag(bundle.path, bundle_path=bundle.path)
         sync_bag_to_db(dag_bag, bundle.name, bundle_version=bundle.get_current_version(), session=session)
diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py
index ac30dc03358b7..49f997f9a1e6e 100644
--- a/airflow-core/src/airflow/dag_processing/processor.py
+++ b/airflow-core/src/airflow/dag_processing/processor.py
@@ -19,7 +19,6 @@
 import contextlib
 import importlib
 import os
-import sys
 import traceback
 from collections.abc import Callable, Sequence
 from pathlib import Path
@@ -36,7 +35,7 @@
 )
 from airflow.configuration import conf
 from airflow.exceptions import TaskNotFound
-from airflow.models.dagbag import DagBag
+from airflow.models.dagbag import BundleDagBag, DagBag
 from airflow.sdk.execution_time.comms import (
     ConnectionResult,
     DeleteVariable,
@@ -191,11 +190,6 @@ def _parse_file_entrypoint():
     task_runner.SUPERVISOR_COMMS = comms_decoder
     log = structlog.get_logger(logger_name="task")
 
-    # Put bundle root on sys.path if needed. This allows the dag bundle to add
-    # code in util modules to be shared between files within the same bundle.
-    if (bundle_root := os.fspath(msg.bundle_path)) not in sys.path:
-        sys.path.append(bundle_root)
-
     result = _parse_file(msg, log)
     if result is not None:
         comms_decoder.send(result)
@@ -204,10 +198,9 @@
 
 def _parse_file(msg: DagFileParseRequest, log: FilteringBoundLogger) -> DagFileParsingResult | None:
     # TODO: Set known_pool names on DagBag!
-    bag = DagBag(
+    bag = BundleDagBag(
         dag_folder=msg.file,
         bundle_path=msg.bundle_path,
-        include_examples=False,
         load_op_links=False,
     )
     if msg.callback_requests:
diff --git a/airflow-core/src/airflow/models/dagbag.py b/airflow-core/src/airflow/models/dagbag.py
index 5ee277d949d06..4576819a7dece 100644
--- a/airflow-core/src/airflow/models/dagbag.py
+++ b/airflow-core/src/airflow/models/dagbag.py
@@ -650,6 +650,40 @@ def dagbag_report(self):
         return report
 
 
+class BundleDagBag(DagBag):
+    """
+    Bundle-aware DagBag that permanently modifies sys.path.
+
+    This class adds the bundle_path to sys.path permanently to allow DAG files
+    to import modules from their bundle directory. No cleanup is performed.
+
+    WARNING: Only use for one-off usages like CLI commands. Using this in long-running
+    processes will cause sys.path to accumulate entries.
+
+    Same parameters as DagBag, but bundle_path is required and examples are not loaded.
+    """
+
+    def __init__(self, *args, bundle_path: Path | None = None, **kwargs):
+        if not bundle_path:
+            raise ValueError("bundle_path is required for BundleDagBag")
+
+        if str(bundle_path) not in sys.path:
+            sys.path.append(str(bundle_path))
+
+        # Warn if user explicitly set include_examples=True, since bundles never contain examples
+        if kwargs.get("include_examples") is True:
+            warnings.warn(
+                "include_examples=True is ignored for BundleDagBag. "
+                "Bundles do not contain example DAGs, so include_examples is always False.",
+                UserWarning,
+                stacklevel=2,
+            )
+
+        kwargs["bundle_path"] = bundle_path
+        kwargs["include_examples"] = False
+        super().__init__(*args, **kwargs)
+
+
 @provide_session
 def sync_bag_to_db(
     dagbag: DagBag,
diff --git a/airflow-core/src/airflow/utils/cli.py b/airflow-core/src/airflow/utils/cli.py
index b6423c5af3a34..66fe0ac47f72e 100644
--- a/airflow-core/src/airflow/utils/cli.py
+++ b/airflow-core/src/airflow/utils/cli.py
@@ -271,14 +271,15 @@ def get_bagged_dag(bundle_names: list | None, dag_id: str, dagfile_path: str | N
     find the correct path (assuming it's a file) and failing that, use the configured dags folder.
     """
-    from airflow.models.dagbag import DagBag, sync_bag_to_db
+    from airflow.models.dagbag import BundleDagBag, sync_bag_to_db
 
     manager = DagBundlesManager()
     for bundle_name in bundle_names or ():
         bundle = manager.get_bundle(bundle_name)
         with _airflow_parsing_context_manager(dag_id=dag_id):
-            dagbag = DagBag(
-                dag_folder=dagfile_path or bundle.path, bundle_path=bundle.path, include_examples=False
+            dagbag = BundleDagBag(
+                dag_folder=dagfile_path or bundle.path,
+                bundle_path=bundle.path,
             )
             if dag := dagbag.dags.get(dag_id):
                 return dag
@@ -287,9 +288,7 @@ def get_bagged_dag(bundle_names: list | None, dag_id: str, dagfile_path: str | N
     for bundle in manager.get_all_dag_bundles():
         bundle.initialize()
         with _airflow_parsing_context_manager(dag_id=dag_id):
-            dagbag = DagBag(
-                dag_folder=dagfile_path or bundle.path, bundle_path=bundle.path, include_examples=False
-            )
+            dagbag = BundleDagBag(dag_folder=dagfile_path or bundle.path, bundle_path=bundle.path)
             sync_bag_to_db(dagbag, bundle.name, bundle.version)
             if dag := dagbag.dags.get(dag_id):
                 return dag
@@ -315,7 +314,7 @@ def get_db_dag(bundle_names: list | None, dag_id: str, dagfile_path: str | None
 
 
 def get_dags(bundle_names: list | None, dag_id: str, use_regex: bool = False, from_db: bool = False):
     """Return DAG(s) matching a given regex or dag_id."""
-    from airflow.models import DagBag
+    from airflow.models.dagbag import BundleDagBag
 
     bundle_names = bundle_names or []
@@ -325,7 +324,7 @@ def get_dags(bundle_names: list | None, dag_id: str, use_regex: bool = False, fr
         return [get_bagged_dag(bundle_names=bundle_names, dag_id=dag_id)]
 
     def _find_dag(bundle):
-        dagbag = DagBag(dag_folder=bundle.path, bundle_path=bundle.path)
+        dagbag = BundleDagBag(dag_folder=bundle.path, bundle_path=bundle.path)
         matched_dags = [dag for dag in dagbag.dags.values() if re.search(dag_id, dag.dag_id)]
         return matched_dags
 
diff --git a/airflow-core/tests/unit/cli/commands/test_dag_command.py b/airflow-core/tests/unit/cli/commands/test_dag_command.py
index f6cae39b72083..517fdee2a51bc 100644
--- a/airflow-core/tests/unit/cli/commands/test_dag_command.py
+++ b/airflow-core/tests/unit/cli/commands/test_dag_command.py
@@ -759,7 +759,7 @@ def test_dag_test_show_dag(self, mock_get_dag, mock_render_dag, stdout_capture):
         mock_render_dag.assert_has_calls([mock.call(mock_get_dag.return_value, tis=[])])
         assert "SOURCE" in output
 
-    @mock.patch("airflow.models.dagbag.DagBag")
+    @mock.patch("airflow.models.dagbag.BundleDagBag")
     def test_dag_test_with_bundle_name(self, mock_dagbag, configure_dag_bundles):
         """Test that DAG can be tested using bundle name."""
         mock_dagbag.return_value.get_dag.return_value.test.return_value = DagRun(
@@ -783,10 +783,9 @@ def test_dag_test_with_bundle_name(self, mock_dagbag, configure_dag_bundles):
         mock_dagbag.assert_called_once_with(
             bundle_path=TEST_DAGS_FOLDER,
             dag_folder=TEST_DAGS_FOLDER,
-            include_examples=False,
         )
 
-    @mock.patch("airflow.models.dagbag.DagBag")
+    @mock.patch("airflow.models.dagbag.BundleDagBag")
     def test_dag_test_with_dagfile_path(self, mock_dagbag, configure_dag_bundles):
         """Test that DAG can be tested using dagfile path."""
         mock_dagbag.return_value.get_dag.return_value.test.return_value = DagRun(
@@ -804,10 +803,9 @@ def test_dag_test_with_dagfile_path(self, mock_dagbag, configure_dag_bundles):
         mock_dagbag.assert_called_once_with(
             bundle_path=TEST_DAGS_FOLDER,
             dag_folder=str(dag_file),
-            include_examples=False,
         )
 
-    @mock.patch("airflow.models.dagbag.DagBag")
+    @mock.patch("airflow.models.dagbag.BundleDagBag")
     def test_dag_test_with_both_bundle_and_dagfile_path(self, mock_dagbag, configure_dag_bundles):
         """Test that DAG can be tested using both bundle name and dagfile path."""
         mock_dagbag.return_value.get_dag.return_value.test.return_value = DagRun(
@@ -835,7 +833,6 @@ def test_dag_test_with_both_bundle_and_dagfile_path(self, mock_dagbag, configure
         mock_dagbag.assert_called_once_with(
             bundle_path=TEST_DAGS_FOLDER,
             dag_folder=str(dag_file),
-            include_examples=False,
         )
 
     @mock.patch("airflow.models.dagrun.get_or_create_dagrun")
diff --git a/airflow-core/tests/unit/models/test_dag.py b/airflow-core/tests/unit/models/test_dag.py
index e9c77912e9f0b..c511e2a4df40c 100644
--- a/airflow-core/tests/unit/models/test_dag.py
+++ b/airflow-core/tests/unit/models/test_dag.py
@@ -54,7 +54,7 @@
     get_next_data_interval,
     get_run_data_interval,
 )
-from airflow.models.dagbag import DBDagBag
+from airflow.models.dagbag import BundleDagBag, DBDagBag
 from airflow.models.dagbundle import DagBundleModel
 from airflow.models.dagrun import DagRun
 from airflow.models.serialized_dag import SerializedDagModel
@@ -2147,7 +2147,7 @@ def test_relative_fileloc(self, session, testing_dag_bundle):
         rel_path = "test_assets.py"
         bundle_path = TEST_DAGS_FOLDER
         file_path = bundle_path / rel_path
-        bag = DagBag(dag_folder=file_path, bundle_path=bundle_path)
+        bag = BundleDagBag(dag_folder=file_path, bundle_path=bundle_path)
 
         dag = bag.get_dag("dag_with_skip_task")
 
diff --git a/airflow-core/tests/unit/models/test_dagbag.py b/airflow-core/tests/unit/models/test_dagbag.py
index 5c064c973ac18..5e4639750c795 100644
--- a/airflow-core/tests/unit/models/test_dagbag.py
+++ b/airflow-core/tests/unit/models/test_dagbag.py
@@ -38,7 +38,7 @@
 from airflow.exceptions import UnknownExecutorException
 from airflow.executors.executor_loader import ExecutorLoader
 from airflow.models.dag import DagModel
-from airflow.models.dagbag import DagBag, _capture_with_reraise, _validate_executor_fields
+from airflow.models.dagbag import BundleDagBag, DagBag, _capture_with_reraise, _validate_executor_fields
 from airflow.models.dagwarning import DagWarning, DagWarningType
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.sdk import DAG, BaseOperator
@@ -921,3 +921,96 @@ def test_capture_warnings_with_error_filters(self):
             self.raise_warnings()
         assert len(cw) == 1
         assert len(records) == 1
+
+
+class TestBundlePathSysPath:
+    """Tests for bundle_path sys.path handling in BundleDagBag."""
+
+    def test_bundle_path_added_to_syspath(self, tmp_path):
+        """Test that BundleDagBag adds bundle_path to sys.path when provided."""
+        util_file = tmp_path / "bundle_util.py"
+        util_file.write_text('def get_message(): return "Hello from bundle!"')
+
+        dag_file = tmp_path / "test_dag.py"
+        dag_file.write_text(
+            textwrap.dedent(
+                """\
+                from airflow.sdk import DAG
+                from airflow.operators.empty import EmptyOperator
+
+                import sys
+                import bundle_util
+
+                with DAG('test_import', description=f"DAG with sys.path: {sys.path}"):
+                    EmptyOperator(task_id="mytask")
+                """
+            )
+        )
+
+        assert str(tmp_path) not in sys.path
+
+        dagbag = BundleDagBag(dag_folder=str(dag_file), bundle_path=tmp_path)
+
+        # Check import was successful
+        assert len(dagbag.dags) == 1
+        assert not dagbag.import_errors
+
+        dag = dagbag.get_dag("test_import")
+        assert dag is not None
+        assert str(tmp_path) in dag.description  # sys.path was enhanced during parse
+
+        # Path remains in sys.path (no cleanup - intentional for ephemeral processes)
+        assert str(tmp_path) in sys.path
+
+        # Cleanup for other tests
+        sys.path.remove(str(tmp_path))
+
+    def test_bundle_path_not_duplicated(self, tmp_path):
+        """Test that bundle_path is not added to sys.path if already present."""
+        dag_file = tmp_path / "simple_dag.py"
+        dag_file.write_text(
+            textwrap.dedent(
+                """\
+                from airflow.sdk import DAG
+                from airflow.operators.empty import EmptyOperator
+
+                with DAG("simple_dag"):
+                    EmptyOperator(task_id="mytask")
+                """
+            )
+        )
+
+        # Pre-add the path
+        sys.path.append(str(tmp_path))
+        count_before = sys.path.count(str(tmp_path))
+
+        BundleDagBag(dag_folder=str(dag_file), bundle_path=tmp_path)
+
+        # Should not add duplicate
+        assert sys.path.count(str(tmp_path)) == count_before
+
+        # Cleanup for other tests
+        sys.path.remove(str(tmp_path))
+
+    def test_dagbag_no_bundle_path_no_syspath_modification(self, tmp_path):
+        """Test that no sys.path modification occurs when DagBag is used without bundle_path."""
+        dag_file = tmp_path / "simple_dag.py"
+        dag_file.write_text(
+            textwrap.dedent(
+                """\
+                from airflow.sdk import DAG
+                from airflow.operators.empty import EmptyOperator
+
+                import sys
+
+                with DAG("simple_dag", description=f"DAG with sys.path: {sys.path}") as dag:
+                    EmptyOperator(task_id="mytask")
+                """
+            )
+        )
+        syspath_before = deepcopy(sys.path)
+        dagbag = DagBag(dag_folder=str(dag_file), include_examples=False)
+        dag = dagbag.get_dag("simple_dag")
+
+        assert str(tmp_path) not in dag.description
+        assert sys.path == syspath_before
diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py
index b51012d64a0fd..e8d954e880a1a 100644
--- a/task-sdk/src/airflow/sdk/definitions/dag.py
+++ b/task-sdk/src/airflow/sdk/definitions/dag.py
@@ -1209,7 +1209,7 @@ def test(
             version = DagVersion.get_version(self.dag_id)
             if not version:
                 from airflow.dag_processing.bundles.manager import DagBundlesManager
-                from airflow.models.dagbag import DagBag, sync_bag_to_db
+                from airflow.models.dagbag import BundleDagBag, sync_bag_to_db
                 from airflow.sdk.definitions._internal.dag_parsing_context import (
                     _airflow_parsing_context_manager,
                 )
@@ -1223,9 +1223,7 @@ def test(
                 if not bundle.is_initialized:
                     bundle.initialize()
                 with _airflow_parsing_context_manager(dag_id=self.dag_id):
-                    dagbag = DagBag(
-                        dag_folder=bundle.path, bundle_path=bundle.path, include_examples=False
-                    )
+                    dagbag = BundleDagBag(dag_folder=bundle.path, bundle_path=bundle.path)
                     sync_bag_to_db(dagbag, bundle.name, bundle.version)
                     version = DagVersion.get_version(self.dag_id)
                     if version:
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 4ef67fb7496d7..a08f842acba4b 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -637,9 +637,8 @@ def _maybe_reschedule_startup_failure(
 
 def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance:
     # TODO: Task-SDK:
-    # Using DagBag here is about 98% wrong, but it'll do for now
-
-    from airflow.models.dagbag import DagBag
+    # Using BundleDagBag here is about 98% wrong, but it'll do for now
+    from airflow.models.dagbag import BundleDagBag
 
     bundle_info = what.bundle_info
     bundle_instance = DagBundlesManager().get_bundle(
@@ -648,17 +647,12 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance:
     )
     bundle_instance.initialize()
 
-    # Put bundle root on sys.path if needed. This allows the dag bundle to add
-    # code in util modules to be shared between files within the same bundle.
-    if (bundle_root := os.fspath(bundle_instance.path)) not in sys.path:
-        sys.path.append(bundle_root)
-
     dag_absolute_path = os.fspath(Path(bundle_instance.path, what.dag_rel_path))
-    bag = DagBag(
+    bag = BundleDagBag(
         dag_folder=dag_absolute_path,
-        include_examples=False,
         safe_mode=False,
         load_op_links=False,
+        bundle_path=bundle_instance.path,
     )
     if TYPE_CHECKING:
         assert what.ti.dag_id
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index b8f61a2bc50ef..fe8f1ad667d25 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -183,12 +183,60 @@ def test_parse(test_dags_dir: Path, make_ti_context):
     ):
         ti = parse(what, mock.Mock())
 
-    assert ti.task
-    assert ti.task.dag
     assert isinstance(ti.task, BaseOperator)
     assert isinstance(ti.task.dag, DAG)
 
 
+@mock.patch("airflow.models.dagbag.BundleDagBag")
+def test_parse_dag_bag(mock_dagbag, test_dags_dir: Path, make_ti_context):
+    """Test that checks that the BundleDagBag is constructed as expected during parsing"""
+    mock_bag_instance = mock.Mock()
+    mock_dagbag.return_value = mock_bag_instance
+    mock_dag = mock.Mock(spec=DAG)
+    mock_task = mock.Mock(spec=BaseOperator)
+
+    mock_bag_instance.dags = {"super_basic": mock_dag}
+    mock_dag.task_dict = {"a": mock_task}
+
+    what = StartupDetails(
+        ti=TaskInstance(
+            id=uuid7(),
+            task_id="a",
+            dag_id="super_basic",
+            run_id="c",
+            try_number=1,
+            dag_version_id=uuid7(),
+        ),
+        dag_rel_path="super_basic.py",
+        bundle_info=BundleInfo(name="my-bundle", version=None),
+        ti_context=make_ti_context(),
+        start_date=timezone.utcnow(),
+    )
+
+    with patch.dict(
+        os.environ,
+        {
+            "AIRFLOW__DAG_PROCESSOR__DAG_BUNDLE_CONFIG_LIST": json.dumps(
+                [
+                    {
+                        "name": "my-bundle",
+                        "classpath": "airflow.dag_processing.bundles.local.LocalDagBundle",
+                        "kwargs": {"path": str(test_dags_dir), "refresh_interval": 1},
+                    }
+                ]
+            ),
+        },
+    ):
+        parse(what, mock.Mock())
+
+    mock_dagbag.assert_called_once_with(
+        dag_folder=mock.ANY,
+        safe_mode=False,
+        load_op_links=False,
+        bundle_path=test_dags_dir,
+    )
+
+
 @pytest.mark.parametrize(
     ("dag_id", "task_id", "expected_error"),
     (
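
Usage note (reviewer sketch, not part of the patch): a minimal illustration of how a one-off, CLI-style caller is expected to use the new BundleDagBag, pieced together from the dag_reserialize and sync_bag_to_db hunks above; the DagBundlesManager loop, bundle attributes, and sync_bag_to_db call are taken from those hunks rather than invented here, and session handling is left to @provide_session.

    from airflow.dag_processing.bundles.manager import DagBundlesManager
    from airflow.models.dagbag import BundleDagBag, sync_bag_to_db

    manager = DagBundlesManager()
    for bundle in manager.get_all_dag_bundles():
        bundle.initialize()
        # bundle_path is required: BundleDagBag appends it to sys.path (and never
        # removes it) and always forces include_examples=False.
        dag_bag = BundleDagBag(bundle.path, bundle_path=bundle.path)
        sync_bag_to_db(dag_bag, bundle.name, bundle_version=bundle.get_current_version())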