diff --git a/providers/standard/src/airflow/providers/standard/operators/python.py b/providers/standard/src/airflow/providers/standard/operators/python.py index 160617c2b789b..b597ef38cc5f1 100644 --- a/providers/standard/src/airflow/providers/standard/operators/python.py +++ b/providers/standard/src/airflow/providers/standard/operators/python.py @@ -562,6 +562,8 @@ def _execute_python_callable_in_subprocess(self, python_path: Path): ) env_vars = dict(os.environ) if self.inherit_env else {} + if fd := os.getenv("__AIRFLOW_SUPERVISOR_FD"): + env_vars["__AIRFLOW_SUPERVISOR_FD"] = fd if self.env_vars: env_vars.update(self.env_vars) diff --git a/providers/standard/src/airflow/providers/standard/utils/python_virtualenv.py b/providers/standard/src/airflow/providers/standard/utils/python_virtualenv.py index ee71f33a56056..891b3e0bce3a1 100644 --- a/providers/standard/src/airflow/providers/standard/utils/python_virtualenv.py +++ b/providers/standard/src/airflow/providers/standard/utils/python_virtualenv.py @@ -150,7 +150,7 @@ def _execute_in_subprocess(cmd: list[str], cwd: str | None = None, env: dict[str stdout=subprocess.PIPE, stderr=subprocess.STDOUT, bufsize=0, - close_fds=True, + close_fds=False, cwd=cwd, env=env, ) as proc: diff --git a/providers/standard/src/airflow/providers/standard/utils/python_virtualenv_script.jinja2 b/providers/standard/src/airflow/providers/standard/utils/python_virtualenv_script.jinja2 index cb4b738138f07..8cb3ace35aa93 100644 --- a/providers/standard/src/airflow/providers/standard/utils/python_virtualenv_script.jinja2 +++ b/providers/standard/src/airflow/providers/standard/utils/python_virtualenv_script.jinja2 @@ -40,6 +40,23 @@ if sys.version_info >= (3,6): pass {% endif %} +try: + from airflow.sdk.execution_time import task_runner +except ModuleNotFoundError: + pass +else: + {#- + We are in an Airflow 3.x environment, try and set up supervisor comms so + virtualenv can access Vars/Conn/XCom/etc that normal tasks can + + We don't use the walrus operator (`:=`) below as it is possible people can + be using this on pre-3.8 versions of python, and while Airflow doesn't + support them, it's easy to not break it not using that operator here. + #} + reinit_supervisor_comms = getattr(task_runner, "reinit_supervisor_comms", None) + if reinit_supervisor_comms: + reinit_supervisor_comms() + # Script {{ python_callable_source }} @@ -49,12 +66,10 @@ if sys.version_info >= (3,6): import types {{ modified_dag_module_name }} = types.ModuleType("{{ modified_dag_module_name }}") - {{ modified_dag_module_name }}.{{ python_callable }} = {{ python_callable }} - sys.modules["{{modified_dag_module_name}}"] = {{modified_dag_module_name}} -{% endif%} +{%- endif -%} {% if op_args or op_kwargs %} with open(sys.argv[1], "rb") as file: diff --git a/providers/standard/tests/unit/standard/operators/test_python.py b/providers/standard/tests/unit/standard/operators/test_python.py index 7526b5128f57a..9787b85a6c607 100644 --- a/providers/standard/tests/unit/standard/operators/test_python.py +++ b/providers/standard/tests/unit/standard/operators/test_python.py @@ -201,6 +201,7 @@ def run_as_operator(self, fn, **kwargs): def run_as_task(self, fn, return_ti=False, **kwargs): """Create TaskInstance and run it.""" ti = self.create_ti(fn, **kwargs) + assert ti.task is not None ti.run() if return_ti: return ti @@ -976,15 +977,16 @@ def test_return_none(self): def f(): return None - task = self.run_as_task(f) - assert task.execute_callable() is None + ti = self.run_as_task(f, return_ti=True) + assert ti.xcom_pull() is None def test_return_false(self): def f(): return False - task = self.run_as_task(f) - assert task.execute_callable() is False + ti = self.run_as_task(f, return_ti=True) + + assert ti.xcom_pull() is False def test_lambda(self): with pytest.raises( @@ -1149,8 +1151,8 @@ def f(): return os.environ["MY_ENV_VAR"] - task = self.run_as_task(f, env_vars={"MY_ENV_VAR": "ABCDE"}) - assert task.execute_callable() == "ABCDE" + ti = self.run_as_task(f, env_vars={"MY_ENV_VAR": "ABCDE"}, return_ti=True) + assert ti.xcom_pull() == "ABCDE" def test_environment_variables_with_inherit_env_true(self, monkeypatch): monkeypatch.setenv("MY_ENV_VAR", "QWERT") @@ -1160,8 +1162,8 @@ def f(): return os.environ["MY_ENV_VAR"] - task = self.run_as_task(f, inherit_env=True) - assert task.execute_callable() == "QWERT" + ti = self.run_as_task(f, inherit_env=True, return_ti=True) + assert ti.xcom_pull() == "QWERT" def test_environment_variables_with_inherit_env_false(self, monkeypatch): monkeypatch.setenv("MY_ENV_VAR", "TYUIO") @@ -1182,8 +1184,8 @@ def f(): return os.environ["MY_ENV_VAR"] - task = self.run_as_task(f, env_vars={"MY_ENV_VAR": "EFGHI"}, inherit_env=True) - assert task.execute_callable() == "EFGHI" + ti = self.run_as_task(f, env_vars={"MY_ENV_VAR": "EFGHI"}, inherit_env=True, return_ti=True) + assert ti.xcom_pull() == "EFGHI" venv_cache_path = tempfile.mkdtemp(prefix="venv_cache_path")