Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make current working directory as templated field in BashOperator #37968

Merged
merged 4 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions airflow/operators/bash.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,10 @@ class BashOperator(BaseOperator):
:param skip_on_exit_code: If task exits with this exit code, leave the task
in ``skipped`` state (default: 99). If set to ``None``, any non-zero
exit code will be treated as a failure.
:param cwd: Working directory to execute the command in.
:param cwd: Working directory to execute the command in (templated).
If None (default), the command is run in a temporary directory.
To use current DAG folder as the working directory,
you might set template ``{{ dag_run.dag.folder }}``.

Airflow will evaluate the exit code of the Bash command. In general, a non-zero exit code will result in
task failure and zero will result in task success.
Expand Down Expand Up @@ -130,7 +132,7 @@ class BashOperator(BaseOperator):

"""

template_fields: Sequence[str] = ("bash_command", "env")
template_fields: Sequence[str] = ("bash_command", "env", "cwd")
template_fields_renderers = {"bash_command": "bash", "env": "json"}
template_ext: Sequence[str] = (".sh", ".bash")
ui_color = "#f0ede4"
Expand Down
19 changes: 12 additions & 7 deletions tests/models/test_renderedtifields.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,11 @@ def test_get_templated_fields(self, templated_field, expected_rendered_field, da
session.add(rtif)
session.flush()

assert {"bash_command": expected_rendered_field, "env": None} == RTIF.get_templated_fields(
ti=ti, session=session
)
assert {
"bash_command": expected_rendered_field,
"env": None,
"cwd": None,
} == RTIF.get_templated_fields(ti=ti, session=session)
# Test the else part of get_templated_fields
# i.e. for the TIs that are not stored in RTIF table
# Fetching them will return None
Expand Down Expand Up @@ -261,7 +263,7 @@ def test_write(self, dag_maker):
)
.first()
)
assert ("test_write", "test", {"bash_command": "echo test_val", "env": None}) == result
assert ("test_write", "test", {"bash_command": "echo test_val", "env": None, "cwd": None}) == result

# Test that overwrite saves new values to the DB
Variable.delete("test_key")
Expand All @@ -287,7 +289,7 @@ def test_write(self, dag_maker):
assert (
"test_write",
"test",
{"bash_command": "echo test_val_updated", "env": None},
{"bash_command": "echo test_val_updated", "env": None, "cwd": None},
) == result_updated

@mock.patch.dict(os.environ, {"AIRFLOW_VAR_API_KEY": "secret"})
Expand All @@ -301,8 +303,10 @@ def test_redact(self, redact, dag_maker):
)
dr = dag_maker.create_dagrun()
redact.side_effect = [
"val 1",
"val 2",
# Order depends on order in Operator template_fields
"val 1", # bash_command
"val 2", # env
"val 3", # cwd
]

ti = dr.task_instances[0]
Expand All @@ -311,4 +315,5 @@ def test_redact(self, redact, dag_maker):
assert rtif.rendered_fields == {
"bash_command": "val 1",
"env": "val 2",
"cwd": "val 3",
}
6 changes: 5 additions & 1 deletion tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,11 @@ def test_retry_handling(self, dag_maker):
"""
Test that task retries are handled properly
"""
expected_rendered_ti_fields = {"env": None, "bash_command": "echo test_retry_handling; exit 1"}
expected_rendered_ti_fields = {
"env": None,
"bash_command": "echo test_retry_handling; exit 1",
"cwd": None,
}

with dag_maker(dag_id="test_retry_handling") as dag:
task = BashOperator(
Expand Down
20 changes: 20 additions & 0 deletions tests/operators/test_bash.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import os
import signal
from datetime import datetime, timedelta
from pathlib import Path
from time import sleep
from unittest import mock

Expand Down Expand Up @@ -244,3 +245,22 @@ def test_bash_operator_kill(self, dag_maker):
os.kill(proc.pid, signal.SIGTERM)
assert False, "BashOperator's subprocess still running after stopping on timeout!"
break

@pytest.mark.db_test
def test_templated_fields(self, create_task_instance_of_operator):
ti = create_task_instance_of_operator(
BashOperator,
# Templated fields
bash_command='echo "{{ dag_run.dag_id }}"',
env={"FOO": "{{ ds }}"},
cwd="{{ dag_run.dag.folder }}",
# Other parameters
dag_id="test_templated_fields_dag",
task_id="test_templated_fields_task",
execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
)
ti.render_templates()
task: BashOperator = ti.task
assert task.bash_command == 'echo "test_templated_fields_dag"'
assert task.env == {"FOO": "2024-02-01"}
assert task.cwd == Path(__file__).absolute().parent.as_posix()
6 changes: 3 additions & 3 deletions tests/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def detect_task_dependencies(task: Operator) -> DagDependency | None: # type: i
"ui_color": "#f0ede4",
"ui_fgcolor": "#000",
"template_ext": [".sh", ".bash"],
"template_fields": ["bash_command", "env"],
"template_fields": ["bash_command", "env", "cwd"],
"template_fields_renderers": {"bash_command": "bash", "env": "json"},
"bash_command": "echo {{ task.task_id }}",
"_task_type": "BashOperator",
Expand Down Expand Up @@ -2150,7 +2150,7 @@ def test_operator_expand_serde():
},
"task_id": "a",
"operator_extra_links": [],
"template_fields": ["bash_command", "env"],
"template_fields": ["bash_command", "env", "cwd"],
"template_ext": [".sh", ".bash"],
"template_fields_renderers": {"bash_command": "bash", "env": "json"},
"ui_color": "#f0ede4",
Expand All @@ -2168,7 +2168,7 @@ def test_operator_expand_serde():
"downstream_task_ids": [],
"task_id": "a",
"template_ext": [".sh", ".bash"],
"template_fields": ["bash_command", "env"],
"template_fields": ["bash_command", "env", "cwd"],
"template_fields_renderers": {"bash_command": "bash", "env": "json"},
"ui_color": "#f0ede4",
"ui_fgcolor": "#000",
Expand Down
Loading