Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -486,8 +486,28 @@ def execute(self, context: Context) -> Any:
serializable_keys = set(self._iter_serializable_context_keys())
new = {k: v for k, v in context.items() if k in serializable_keys}
serializable_context = cast("Context", new)
# Store bundle_path for subprocess execution
self._bundle_path = self._get_bundle_path_from_context(context)
return super().execute(context=serializable_context)

def _get_bundle_path_from_context(self, context: Context) -> str | None:
"""
Extract bundle_path from the task instance's bundle_instance.

:param context: The task execution context
:return: Path to the bundle root directory, or None if not in a bundle
"""
if not AIRFLOW_V_3_0_PLUS:
return None

# In Airflow 3.x, the RuntimeTaskInstance has a bundle_instance attribute
# that contains the bundle information including its path
ti = context["ti"]
if bundle_instance := getattr(ti, "bundle_instance", None):
return bundle_instance.path

return None

def get_python_source(self):
"""Return the source of self.python_callable."""
return textwrap.dedent(inspect.getsource(self.python_callable))
Expand Down Expand Up @@ -565,6 +585,16 @@ def _execute_python_callable_in_subprocess(self, python_path: Path):
if self.env_vars:
env_vars.update(self.env_vars)

# Add bundle_path to PYTHONPATH for subprocess to import Dag bundle modules
if self._bundle_path:
bundle_path = self._bundle_path
existing_pythonpath = env_vars.get("PYTHONPATH", "")
if existing_pythonpath:
# Append bundle_path after existing PYTHONPATH
env_vars["PYTHONPATH"] = f"{existing_pythonpath}{os.pathsep}{bundle_path}"
else:
env_vars["PYTHONPATH"] = bundle_path

try:
cmd: list[str] = [
os.fspath(python_path),
Expand Down
84 changes: 84 additions & 0 deletions providers/standard/tests/unit/standard/operators/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,6 +940,90 @@ def poke(self, context):
virtualenv_string_args: list[str] = []


@pytest.mark.execution_timeout(120)
@pytest.mark.parametrize(
("opcls", "test_class_ref"),
[
pytest.param(
PythonVirtualenvOperator,
lambda: TestPythonVirtualenvOperator,
id="PythonVirtualenvOperator",
),
pytest.param(
ExternalPythonOperator,
lambda: TestExternalPythonOperator,
id="ExternalPythonOperator",
),
],
)
class TestDagBundleImportInSubprocess(BasePythonTest):
"""
Test Dag bundle imports for subprocess-based Python operators.

This test ensures that callables running in subprocesses can import modules
from their Dag bundle by verifying PYTHONPATH is correctly set (Airflow 3.x+).
"""

@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Dag Bundle import fix is for Airflow 3.x+")
@mock.patch("airflow.providers.standard.operators.python._execute_in_subprocess")
def test_dag_bundle_import_in_subprocess(
self, mock_execute_subprocess, dag_maker, opcls, test_class_ref, tmp_path
):
"""
Tests that a callable in a subprocess can import modules from its
own Dag bundle (Airflow 3.x+).
"""

def _callable_that_imports_from_bundle():
from test_bundle_pkg.lib.helper import get_message

return get_message()

bundle_root = tmp_path

module_dir = bundle_root / "test_bundle_pkg"
lib_dir = module_dir / "lib"
lib_dir.mkdir(parents=True)

(module_dir / "__init__.py").touch()
(lib_dir / "__init__.py").touch()
(lib_dir / "helper.py").write_text("def get_message():\n return 'it works from bundle'")

# We need a real DAG to create a real TI context
with dag_maker(self.dag_id, serialized=True):
op = opcls(
task_id=self.task_id,
python_callable=_callable_that_imports_from_bundle,
**test_class_ref().default_kwargs(),
)

dr = dag_maker.create_dagrun()
ti = dr.get_task_instance(self.task_id)

mock_bundle_instance = mock.Mock()
mock_bundle_instance.path = str(bundle_root)
ti.bundle_instance = mock_bundle_instance

context = ti.get_template_context()

# Mock subprocess execution to avoid testing-environment related issues
# on the ExternalPythonOperator (Socket operation on non-socket)
# Instead, we just check the env argument of _execute_in_subprocess
# if the bundle_path was added to PYTHONPATH

# Mock _read_result to avoid reading the non-existent output file
with mock.patch.object(op, "_read_result", return_value=None):
op.execute(context)

assert mock_execute_subprocess.called, "_execute_in_subprocess should have been called"
call_kwargs = mock_execute_subprocess.call_args.kwargs
env = call_kwargs.get("env")
assert "PYTHONPATH" in env, "PYTHONPATH should be in env"

pythonpath = env["PYTHONPATH"]
assert str(bundle_root) in pythonpath, f"Bundle path {bundle_root} should be in PYTHONPATH"


@pytest.mark.execution_timeout(120)
class BaseTestPythonVirtualenvOperator(BasePythonTest):
def test_template_fields(self):
Expand Down
Loading