diff --git a/providers/docker/src/airflow/providers/docker/operators/docker.py b/providers/docker/src/airflow/providers/docker/operators/docker.py index 546c8736b13bf..bf6f7a463fad6 100644 --- a/providers/docker/src/airflow/providers/docker/operators/docker.py +++ b/providers/docker/src/airflow/providers/docker/operators/docker.py @@ -145,7 +145,7 @@ class DockerOperator(BaseOperator): ``AIRFLOW_TMP_DIR`` inside the container. :param user: Default user inside the docker container. :param mounts: List of volumes to mount into the container. Each item should - be a :py:class:`docker.types.Mount` instance. + be a :py:class:`docker.types.Mount` instance. (templated) :param entrypoint: Overwrite the default ENTRYPOINT of the image :param working_dir: Working directory to set on the container (equivalent to the -w switch the docker client) @@ -198,7 +198,14 @@ class DockerOperator(BaseOperator): # - docs/apache-airflow-providers-docker/decorators/docker.rst # - airflow/decorators/__init__.pyi (by a separate PR) - template_fields: Sequence[str] = ("image", "command", "environment", "env_file", "container_name") + template_fields: Sequence[str] = ( + "image", + "command", + "environment", + "env_file", + "container_name", + "mounts", + ) template_fields_renderers = {"env_file": "yaml"} template_ext: Sequence[str] = ( ".sh", @@ -291,6 +298,8 @@ def __init__( self.tmp_dir = tmp_dir self.user = user self.mounts = mounts or [] + for mount in self.mounts: + mount.template_fields = ("Source", "Target", "Type") self.entrypoint = entrypoint self.working_dir = working_dir self.xcom_all = xcom_all diff --git a/providers/docker/tests/unit/docker/decorators/test_docker.py b/providers/docker/tests/unit/docker/decorators/test_docker.py index 2009361d8832d..c2264a6c76340 100644 --- a/providers/docker/tests/unit/docker/decorators/test_docker.py +++ b/providers/docker/tests/unit/docker/decorators/test_docker.py @@ -82,7 +82,14 @@ def f(num_results): @pytest.mark.db_test def test_basic_docker_operator_with_template_fields(self, dag_maker): - @task.docker(image="python:3.9-slim", container_name="python_{{dag_run.dag_id}}", auto_remove="force") + from docker.types import Mount + + @task.docker( + image="python:3.9-slim", + container_name="python_{{dag_run.dag_id}}", + auto_remove="force", + mounts=[Mount(source="workspace", target="/{{task_instance.run_id}}")], + ) def f(): raise RuntimeError("Should not executed") @@ -93,6 +100,7 @@ def f(): ti = TaskInstance(task=ret.operator, run_id=dr.run_id) rendered = ti.render_templates() assert rendered.container_name == f"python_{dr.dag_id}" + assert rendered.mounts[0]["Target"] == f"/{ti.run_id}" @pytest.mark.db_test def test_basic_docker_operator_multiple_output(self, dag_maker, session): diff --git a/providers/docker/tests/unit/docker/operators/test_docker.py b/providers/docker/tests/unit/docker/operators/test_docker.py index d3e5bf6328238..a8ab46e89af87 100644 --- a/providers/docker/tests/unit/docker/operators/test_docker.py +++ b/providers/docker/tests/unit/docker/operators/test_docker.py @@ -27,6 +27,7 @@ from docker.types import DeviceRequest, LogConfig, Mount, Ulimit from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.models import TaskInstance from airflow.providers.docker.exceptions import DockerContainerFailedException from airflow.providers.docker.operators.docker import DockerOperator, fetch_logs @@ -794,3 +795,22 @@ def test_labels(self, labels: dict[str, str] | list[str]): self.client_mock.create_container.assert_called_once() assert "labels" in self.client_mock.create_container.call_args.kwargs assert labels == self.client_mock.create_container.call_args.kwargs["labels"] + + @pytest.mark.db_test + def test_basic_docker_operator_with_template_fields(self, dag_maker): + from docker.types import Mount + + with dag_maker(): + operator = DockerOperator( + task_id="test", + image="test", + container_name="python_{{dag_run.dag_id}}", + mounts=[Mount(source="workspace", target="/{{task_instance.run_id}}")], + ) + operator.execute({}) + + dr = dag_maker.create_dagrun() + ti = TaskInstance(task=operator, run_id=dr.run_id) + rendered = ti.render_templates() + assert rendered.container_name == f"python_{dr.dag_id}" + assert rendered.mounts[0]["Target"] == f"/{ti.run_id}"