Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions providers/docker/src/airflow/providers/docker/operators/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ class DockerOperator(BaseOperator):
``AIRFLOW_TMP_DIR`` inside the container.
:param user: Default user inside the docker container.
:param mounts: List of volumes to mount into the container. Each item should
be a :py:class:`docker.types.Mount` instance.
be a :py:class:`docker.types.Mount` instance. (templated)
:param entrypoint: Overwrite the default ENTRYPOINT of the image
:param working_dir: Working directory to
set on the container (equivalent to the -w switch the docker client)
Expand Down Expand Up @@ -198,7 +198,14 @@ class DockerOperator(BaseOperator):
# - docs/apache-airflow-providers-docker/decorators/docker.rst
# - airflow/decorators/__init__.pyi (by a separate PR)

template_fields: Sequence[str] = ("image", "command", "environment", "env_file", "container_name")
template_fields: Sequence[str] = (
"image",
"command",
"environment",
"env_file",
"container_name",
"mounts",
)
template_fields_renderers = {"env_file": "yaml"}
template_ext: Sequence[str] = (
".sh",
Expand Down Expand Up @@ -291,6 +298,8 @@ def __init__(
self.tmp_dir = tmp_dir
self.user = user
self.mounts = mounts or []
for mount in self.mounts:
mount.template_fields = ("Source", "Target", "Type")
self.entrypoint = entrypoint
self.working_dir = working_dir
self.xcom_all = xcom_all
Expand Down
10 changes: 9 additions & 1 deletion providers/docker/tests/unit/docker/decorators/test_docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,14 @@ def f(num_results):

@pytest.mark.db_test
def test_basic_docker_operator_with_template_fields(self, dag_maker):
@task.docker(image="python:3.9-slim", container_name="python_{{dag_run.dag_id}}", auto_remove="force")
from docker.types import Mount

@task.docker(
image="python:3.9-slim",
container_name="python_{{dag_run.dag_id}}",
auto_remove="force",
mounts=[Mount(source="workspace", target="/{{task_instance.run_id}}")],
)
def f():
raise RuntimeError("Should not executed")

Expand All @@ -93,6 +100,7 @@ def f():
ti = TaskInstance(task=ret.operator, run_id=dr.run_id)
rendered = ti.render_templates()
assert rendered.container_name == f"python_{dr.dag_id}"
assert rendered.mounts[0]["Target"] == f"/{ti.run_id}"

@pytest.mark.db_test
def test_basic_docker_operator_multiple_output(self, dag_maker, session):
Expand Down
20 changes: 20 additions & 0 deletions providers/docker/tests/unit/docker/operators/test_docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from docker.types import DeviceRequest, LogConfig, Mount, Ulimit

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.models import TaskInstance
from airflow.providers.docker.exceptions import DockerContainerFailedException
from airflow.providers.docker.operators.docker import DockerOperator, fetch_logs

Expand Down Expand Up @@ -794,3 +795,22 @@ def test_labels(self, labels: dict[str, str] | list[str]):
self.client_mock.create_container.assert_called_once()
assert "labels" in self.client_mock.create_container.call_args.kwargs
assert labels == self.client_mock.create_container.call_args.kwargs["labels"]

@pytest.mark.db_test
def test_basic_docker_operator_with_template_fields(self, dag_maker):
from docker.types import Mount

with dag_maker():
operator = DockerOperator(
task_id="test",
image="test",
container_name="python_{{dag_run.dag_id}}",
mounts=[Mount(source="workspace", target="/{{task_instance.run_id}}")],
)
operator.execute({})

dr = dag_maker.create_dagrun()
ti = TaskInstance(task=operator, run_id=dr.run_id)
rendered = ti.render_templates()
assert rendered.container_name == f"python_{dr.dag_id}"
assert rendered.mounts[0]["Target"] == f"/{ti.run_id}"