diff --git a/airflow/providers/cncf/kubernetes/decorators/kubernetes.py b/airflow/providers/cncf/kubernetes/decorators/kubernetes.py
index f68927c676433..844d5300f56d9 100644
--- a/airflow/providers/cncf/kubernetes/decorators/kubernetes.py
+++ b/airflow/providers/cncf/kubernetes/decorators/kubernetes.py
@@ -16,14 +16,17 @@
 # under the License.
 from __future__ import annotations
 
+import base64
 import inspect
 import os
 import pickle
 import uuid
+from shlex import quote
 from tempfile import TemporaryDirectory
 from textwrap import dedent
 from typing import TYPE_CHECKING, Callable, Sequence
 
+import dill
 from kubernetes.client import models as k8s
 
 from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory
@@ -37,21 +40,20 @@
     from airflow.utils.context import Context
 
 _PYTHON_SCRIPT_ENV = "__PYTHON_SCRIPT"
+_PYTHON_INPUT_ENV = "__PYTHON_INPUT"
-_FILENAME_IN_CONTAINER = "/tmp/script.py"
-
 
-def _generate_decode_command() -> str:
+def _generate_decoded_command(env_var: str, file: str) -> str:
     return (
         f'python -c "import base64, os;'
-        rf"x = os.environ[\"{_PYTHON_SCRIPT_ENV}\"];"
-        rf'f = open(\"{_FILENAME_IN_CONTAINER}\", \"w\"); f.write(x); f.close()"'
+        rf"x = base64.b64decode(os.environ[\"{env_var}\"]);"
+        rf'f = open(\"{file}\", \"wb\"); f.write(x); f.close()"'
     )
 
 
-def _read_file_contents(filename):
-    with open(filename) as script_file:
-        return script_file.read()
+def _read_file_contents(filename: str) -> str:
+    with open(filename, "rb") as script_file:
+        return base64.b64encode(script_file.read()).decode("utf-8")
 
 
 class _KubernetesDecoratedOperator(DecoratedOperator, KubernetesPodOperator):
@@ -62,17 +64,16 @@ class _KubernetesDecoratedOperator(DecoratedOperator, KubernetesPodOperator):
         {"op_args", "op_kwargs", *KubernetesPodOperator.template_fields} - {"cmds", "arguments"}
     )
 
-    # since we won't mutate the arguments, we should just do the shallow copy
+    # Since we won't mutate the arguments, we should just do the shallow copy
     # there are some cases we can't deepcopy the objects (e.g protobuf).
     shallow_copy_attrs: Sequence[str] = ("python_callable",)
 
-    def __init__(self, namespace: str = "default", **kwargs) -> None:
-        self.pickling_library = pickle
+    def __init__(self, namespace: str = "default", use_dill: bool = False, **kwargs) -> None:
+        self.pickling_library = dill if use_dill else pickle
         super().__init__(
             namespace=namespace,
             name=kwargs.pop("name", f"k8s_airflow_pod_{uuid.uuid4().hex}"),
-            cmds=["bash"],
-            arguments=["-cx", f"{_generate_decode_command()} && python {_FILENAME_IN_CONTAINER}"],
+            cmds=["dummy-command"],
             **kwargs,
         )
 
@@ -82,11 +83,41 @@ def _get_python_source(self):
         res = remove_task_decorator(res, "@task.kubernetes")
         return res
 
+    def _generate_cmds(self) -> list[str]:
+        script_filename = "/tmp/script.py"
+        input_filename = "/tmp/script.in"
+        output_filename = "/airflow/xcom/return.json"
+
+        write_local_script_file_cmd = (
+            f"{_generate_decoded_command(quote(_PYTHON_SCRIPT_ENV), quote(script_filename))}"
+        )
+        write_local_input_file_cmd = (
+            f"{_generate_decoded_command(quote(_PYTHON_INPUT_ENV), quote(input_filename))}"
+        )
+        make_xcom_dir_cmd = "mkdir -p /airflow/xcom"
+        exec_python_cmd = f"python {script_filename} {input_filename} {output_filename}"
+        return [
+            "bash",
+            "-cx",
+            " && ".join(
+                [
+                    write_local_script_file_cmd,
+                    write_local_input_file_cmd,
+                    make_xcom_dir_cmd,
+                    exec_python_cmd,
+                ]
+            ),
+        ]
+
     def execute(self, context: Context):
         with TemporaryDirectory(prefix="venv") as tmp_dir:
             script_filename = os.path.join(tmp_dir, "script.py")
-            py_source = self._get_python_source()
+            input_filename = os.path.join(tmp_dir, "script.in")
+
+            with open(input_filename, "wb") as file:
+                self.pickling_library.dump({"args": self.op_args, "kwargs": self.op_kwargs}, file)
+            py_source = self._get_python_source()
             jinja_context = {
                 "op_args": self.op_args,
                 "op_kwargs": self.op_kwargs,
@@ -100,7 +131,10 @@ def execute(self, context: Context):
             self.env_vars = [
                 *self.env_vars,
                 k8s.V1EnvVar(name=_PYTHON_SCRIPT_ENV, value=_read_file_contents(script_filename)),
+                k8s.V1EnvVar(name=_PYTHON_INPUT_ENV, value=_read_file_contents(input_filename)),
             ]
+
+            self.cmds = self._generate_cmds()
             return super().execute(context)
diff --git a/airflow/providers/cncf/kubernetes/python_kubernetes_script.jinja2 b/airflow/providers/cncf/kubernetes/python_kubernetes_script.jinja2
index c961f10de4e5c..4042c07fc464b 100644
--- a/airflow/providers/cncf/kubernetes/python_kubernetes_script.jinja2
+++ b/airflow/providers/cncf/kubernetes/python_kubernetes_script.jinja2
@@ -17,6 +17,7 @@
  under the License.
 -#}
+import json
 import {{ pickling_library }}
 import sys
 
@@ -42,3 +43,8 @@ arg_dict = {"args": [], "kwargs": {}}
 # Script
 {{ python_callable_source }}
 res = {{ python_callable }}(*arg_dict["args"], **arg_dict["kwargs"])
+
+# Write output
+with open(sys.argv[2], "w") as file:
+    if res is not None:
+        file.write(json.dumps(res))
diff --git a/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py b/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py
index 46b087688cff9..584df8de038a3 100644
--- a/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py
+++ b/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py
@@ -16,6 +16,8 @@
 # under the License.
 from __future__ import annotations
 
+import base64
+import pickle
 from unittest import mock
 
 import pytest
@@ -29,6 +31,8 @@
 POD_MANAGER_CLASS = "airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager"
 HOOK_CLASS = "airflow.providers.cncf.kubernetes.operators.kubernetes_pod.KubernetesHook"
 
+XCOM_IMAGE = "XCOM_IMAGE"
+
 
 @pytest.fixture(autouse=True)
 def mock_create_pod() -> mock.Mock:
@@ -40,6 +44,18 @@ def mock_await_pod_start() -> mock.Mock:
     return mock.patch(f"{POD_MANAGER_CLASS}.await_pod_start").start()
 
 
+@pytest.fixture(autouse=True)
+def await_xcom_sidecar_container_start() -> mock.Mock:
+    return mock.patch(f"{POD_MANAGER_CLASS}.await_xcom_sidecar_container_start").start()
+
+
+@pytest.fixture(autouse=True)
+def extract_xcom() -> mock.Mock:
+    f = mock.patch(f"{POD_MANAGER_CLASS}.extract_xcom").start()
+    f.return_value = '{"key1": "value1", "key2": "value2"}'
+    return f
+
+
 @pytest.fixture(autouse=True)
 def mock_await_pod_completion() -> mock.Mock:
     f = mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion").start()
@@ -81,11 +97,65 @@ def f():
 
     containers = mock_create_pod.call_args[1]["pod"].spec.containers
     assert len(containers) == 1
-    assert containers[0].command == ["bash"]
+    assert containers[0].command[0] == "bash"
+    assert len(containers[0].args) == 0
+    assert containers[0].env[0].name == "__PYTHON_SCRIPT"
+    assert containers[0].env[0].value
+    assert containers[0].env[1].name == "__PYTHON_INPUT"
+
+    # Ensure we pass input through a b64 encoded env var
+    decoded_input = pickle.loads(base64.b64decode(containers[0].env[1].value))
+    assert decoded_input == {"args": [], "kwargs": {}}
+
+
+def test_kubernetes_with_input_output(
+    dag_maker, session, mock_create_pod: mock.Mock, mock_hook: mock.Mock
+) -> None:
+    with dag_maker(session=session) as dag:
+
+        @task.kubernetes(
+            image="python:3.10-slim-buster",
+            in_cluster=False,
+            cluster_context="default",
+            config_file="/tmp/fake_file",
+        )
+        def f(arg1, arg2, kwarg1=None, kwarg2=None):
+            return {"key1": "value1", "key2": "value2"}
+
+        f.override(task_id="my_task_id", do_xcom_push=True)("arg1", "arg2", kwarg1="kwarg1")
+
+    dr = dag_maker.create_dagrun()
+    (ti,) = dr.task_instances
+
+    mock_hook.return_value.get_xcom_sidecar_container_image.return_value = XCOM_IMAGE
+
+    dag.get_task("my_task_id").execute(context=ti.get_template_context(session=session))
+
+    mock_hook.assert_called_once_with(
+        conn_id=None,
+        in_cluster=False,
+        cluster_context="default",
+        config_file="/tmp/fake_file",
+    )
+    assert mock_create_pod.call_count == 1
+    assert mock_hook.return_value.get_xcom_sidecar_container_image.call_count == 1
+
+    containers = mock_create_pod.call_args[1]["pod"].spec.containers
+
+    # First container is Python script
+    assert len(containers) == 2
+    assert containers[0].command[0] == "bash"
+    assert len(containers[0].args) == 0
+
+    assert containers[0].env[0].name == "__PYTHON_SCRIPT"
+    assert containers[0].env[0].value
+    assert containers[0].env[1].name == "__PYTHON_INPUT"
+    assert containers[0].env[1].value
 
-    assert len(containers[0].args) == 2
-    assert containers[0].args[0] == "-cx"
-    assert containers[0].args[1].endswith("/tmp/script.py")
+    # Ensure we pass input through a b64 encoded env var
+    decoded_input = pickle.loads(base64.b64decode(containers[0].env[1].value))
+    assert decoded_input == {"args": ("arg1", "arg2"), "kwargs": {"kwarg1": "kwarg1"}}
 
-    assert containers[0].env[-1].name == "__PYTHON_SCRIPT"
-    assert containers[0].env[-1].value
+    # Second container is xcom image
+    assert containers[1].image == XCOM_IMAGE
+    assert containers[1].volume_mounts[0].mount_path == "/airflow/xcom"
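
Usage sketch (illustrative, not part of the patch): the snippet below shows how the input/output path added by this diff would be exercised from a DAG. Per the changes above, op_args/op_kwargs are pickled (dill when use_dill=True), base64-encoded into the __PYTHON_INPUT env var, decoded inside the pod by the generated "bash -cx" command, and the callable's return value is written to /airflow/xcom/return.json for the XCom sidecar to pass downstream when do_xcom_push=True. The dag_id, start_date, and image tag are placeholders.

# Hypothetical DAG illustrating the input/output flow added by this patch.
# dag_id, start_date, and the image tag are placeholders, not values from the diff.
from __future__ import annotations

import pendulum

from airflow.decorators import dag, task


@dag(schedule=None, start_date=pendulum.datetime(2023, 1, 1), catchup=False)
def k8s_io_example():
    @task.kubernetes(
        image="python:3.10-slim-buster",
        namespace="default",
        do_xcom_push=True,  # return value is read back from /airflow/xcom/return.json
        use_dill=False,  # new flag from this patch; use dill instead of pickle when True
    )
    def multiply(x: int, y: int) -> dict:
        # x and y arrive via the base64-encoded __PYTHON_INPUT env var
        return {"product": x * y}

    @task
    def report(result: dict) -> None:
        print(result["product"])

    report(multiply(6, 7))


k8s_io_example()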