Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,8 @@ def _execute_python_callable_in_subprocess(self, python_path: Path):
)

env_vars = dict(os.environ) if self.inherit_env else {}
if fd := os.getenv("__AIRFLOW_SUPERVISOR_FD"):
env_vars["__AIRFLOW_SUPERVISOR_FD"] = fd
if self.env_vars:
env_vars.update(self.env_vars)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def _execute_in_subprocess(cmd: list[str], cwd: str | None = None, env: dict[str
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
bufsize=0,
close_fds=True,
close_fds=False,
cwd=cwd,
env=env,
) as proc:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,23 @@ if sys.version_info >= (3,6):
pass
{% endif %}

try:
from airflow.sdk.execution_time import task_runner
except ModuleNotFoundError:
pass
else:
{#-
We are in an Airflow 3.x environment, try and set up supervisor comms so
virtualenv can access Vars/Conn/XCom/etc that normal tasks can

We don't use the walrus operator (`:=`) below as it is possible people can
be using this on pre-3.8 versions of python, and while Airflow doesn't
support them, it's easy to not break it not using that operator here.
#}
reinit_supervisor_comms = getattr(task_runner, "reinit_supervisor_comms", None)
if reinit_supervisor_comms:
reinit_supervisor_comms()

# Script
{{ python_callable_source }}

Expand All @@ -49,12 +66,10 @@ if sys.version_info >= (3,6):
import types

{{ modified_dag_module_name }} = types.ModuleType("{{ modified_dag_module_name }}")

{{ modified_dag_module_name }}.{{ python_callable }} = {{ python_callable }}

sys.modules["{{modified_dag_module_name}}"] = {{modified_dag_module_name}}

{% endif%}
{%- endif -%}

{% if op_args or op_kwargs %}
with open(sys.argv[1], "rb") as file:
Expand Down
22 changes: 12 additions & 10 deletions providers/standard/tests/unit/standard/operators/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ def run_as_operator(self, fn, **kwargs):
def run_as_task(self, fn, return_ti=False, **kwargs):
"""Create TaskInstance and run it."""
ti = self.create_ti(fn, **kwargs)
assert ti.task is not None
ti.run()
if return_ti:
return ti
Expand Down Expand Up @@ -976,15 +977,16 @@ def test_return_none(self):
def f():
return None

task = self.run_as_task(f)
assert task.execute_callable() is None
ti = self.run_as_task(f, return_ti=True)
assert ti.xcom_pull() is None

def test_return_false(self):
def f():
return False

task = self.run_as_task(f)
assert task.execute_callable() is False
ti = self.run_as_task(f, return_ti=True)

assert ti.xcom_pull() is False

def test_lambda(self):
with pytest.raises(
Expand Down Expand Up @@ -1149,8 +1151,8 @@ def f():

return os.environ["MY_ENV_VAR"]

task = self.run_as_task(f, env_vars={"MY_ENV_VAR": "ABCDE"})
assert task.execute_callable() == "ABCDE"
ti = self.run_as_task(f, env_vars={"MY_ENV_VAR": "ABCDE"}, return_ti=True)
assert ti.xcom_pull() == "ABCDE"

def test_environment_variables_with_inherit_env_true(self, monkeypatch):
monkeypatch.setenv("MY_ENV_VAR", "QWERT")
Expand All @@ -1160,8 +1162,8 @@ def f():

return os.environ["MY_ENV_VAR"]

task = self.run_as_task(f, inherit_env=True)
assert task.execute_callable() == "QWERT"
ti = self.run_as_task(f, inherit_env=True, return_ti=True)
assert ti.xcom_pull() == "QWERT"

def test_environment_variables_with_inherit_env_false(self, monkeypatch):
monkeypatch.setenv("MY_ENV_VAR", "TYUIO")
Expand All @@ -1182,8 +1184,8 @@ def f():

return os.environ["MY_ENV_VAR"]

task = self.run_as_task(f, env_vars={"MY_ENV_VAR": "EFGHI"}, inherit_env=True)
assert task.execute_callable() == "EFGHI"
ti = self.run_as_task(f, env_vars={"MY_ENV_VAR": "EFGHI"}, inherit_env=True, return_ti=True)
assert ti.xcom_pull() == "EFGHI"


venv_cache_path = tempfile.mkdtemp(prefix="venv_cache_path")
Expand Down
Loading