Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 48 additions & 14 deletions airflow/providers/cncf/kubernetes/decorators/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@
# under the License.
from __future__ import annotations

import base64
import inspect
import os
import pickle
import uuid
from shlex import quote
from tempfile import TemporaryDirectory
from textwrap import dedent
from typing import TYPE_CHECKING, Callable, Sequence

import dill
from kubernetes.client import models as k8s

from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory
Expand All @@ -37,21 +40,20 @@
from airflow.utils.context import Context

_PYTHON_SCRIPT_ENV = "__PYTHON_SCRIPT"
_PYTHON_INPUT_ENV = "__PYTHON_INPUT"

_FILENAME_IN_CONTAINER = "/tmp/script.py"


def _generate_decode_command() -> str:
def _generate_decoded_command(env_var: str, file: str) -> str:
return (
f'python -c "import base64, os;'
rf"x = os.environ[\"{_PYTHON_SCRIPT_ENV}\"];"
rf'f = open(\"{_FILENAME_IN_CONTAINER}\", \"w\"); f.write(x); f.close()"'
rf"x = base64.b64decode(os.environ[\"{env_var}\"]);"
rf'f = open(\"{file}\", \"wb\"); f.write(x); f.close()"'
)


def _read_file_contents(filename):
with open(filename) as script_file:
return script_file.read()
def _read_file_contents(filename: str) -> str:
    """Return the contents of ``filename`` as a base64-encoded string.

    The file is opened in binary mode, so its bytes are encoded exactly as
    stored on disk (no newline translation or text decoding takes place).

    :param filename: path of the file to read.
    :return: the base64 encoding of the file's bytes, as text.
    """
    with open(filename, "rb") as script_file:
        # base64 output is pure ASCII, so decoding it to ``str`` cannot fail.
        return base64.b64encode(script_file.read()).decode("utf-8")


class _KubernetesDecoratedOperator(DecoratedOperator, KubernetesPodOperator):
Expand All @@ -62,17 +64,16 @@ class _KubernetesDecoratedOperator(DecoratedOperator, KubernetesPodOperator):
{"op_args", "op_kwargs", *KubernetesPodOperator.template_fields} - {"cmds", "arguments"}
)

# since we won't mutate the arguments, we should just do the shallow copy
# Since we won't mutate the arguments, a shallow copy is sufficient;
# there are some cases where we can't deepcopy the objects (e.g. protobuf).
shallow_copy_attrs: Sequence[str] = ("python_callable",)

def __init__(self, namespace: str = "default", **kwargs) -> None:
self.pickling_library = pickle
def __init__(self, namespace: str = "default", use_dill: bool = False, **kwargs) -> None:
self.pickling_library = dill if use_dill else pickle
super().__init__(
namespace=namespace,
name=kwargs.pop("name", f"k8s_airflow_pod_{uuid.uuid4().hex}"),
cmds=["bash"],
arguments=["-cx", f"{_generate_decode_command()} && python {_FILENAME_IN_CONTAINER}"],
cmds=["dummy-command"],
**kwargs,
)

Expand All @@ -82,11 +83,41 @@ def _get_python_source(self):
res = remove_task_decorator(res, "@task.kubernetes")
return res

def _generate_cmds(self) -> list[str]:
script_filename = "/tmp/script.py"
input_filename = "/tmp/script.in"
output_filename = "/airflow/xcom/return.json"

write_local_script_file_cmd = (
f"{_generate_decoded_command(quote(_PYTHON_SCRIPT_ENV), quote(script_filename))}"
)
write_local_input_file_cmd = (
f"{_generate_decoded_command(quote(_PYTHON_INPUT_ENV), quote(input_filename))}"
Comment on lines +92 to +95
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The `shlex.quote` calls here aren't strictly necessary anymore, since the quoted values no longer come from user input, but I left them in anyway. I thought they made the command a bit more readable, but I can remove them if you prefer.

)
make_xcom_dir_cmd = "mkdir -p /airflow/xcom"
exec_python_cmd = f"python {script_filename} {input_filename} {output_filename}"
return [
"bash",
"-cx",
" && ".join(
[
write_local_script_file_cmd,
write_local_input_file_cmd,
make_xcom_dir_cmd,
exec_python_cmd,
]
),
]

def execute(self, context: Context):
with TemporaryDirectory(prefix="venv") as tmp_dir:
script_filename = os.path.join(tmp_dir, "script.py")
py_source = self._get_python_source()
input_filename = os.path.join(tmp_dir, "script.in")

with open(input_filename, "wb") as file:
self.pickling_library.dump({"args": self.op_args, "kwargs": self.op_kwargs}, file)

py_source = self._get_python_source()
jinja_context = {
"op_args": self.op_args,
"op_kwargs": self.op_kwargs,
Expand All @@ -100,7 +131,10 @@ def execute(self, context: Context):
self.env_vars = [
*self.env_vars,
k8s.V1EnvVar(name=_PYTHON_SCRIPT_ENV, value=_read_file_contents(script_filename)),
k8s.V1EnvVar(name=_PYTHON_INPUT_ENV, value=_read_file_contents(input_filename)),
]

self.cmds = self._generate_cmds()
return super().execute(context)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
under the License.
-#}

import json
import {{ pickling_library }}
import sys

Expand All @@ -42,3 +43,8 @@ arg_dict = {"args": [], "kwargs": {}}
# Script
{{ python_callable_source }}
res = {{ python_callable }}(*arg_dict["args"], **arg_dict["kwargs"])

# Write output
with open(sys.argv[2], "w") as file:
if res is not None:
file.write(json.dumps(res))
82 changes: 76 additions & 6 deletions tests/providers/cncf/kubernetes/decorators/test_kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# under the License.
from __future__ import annotations

import base64
import pickle
from unittest import mock

import pytest
Expand All @@ -29,6 +31,8 @@
POD_MANAGER_CLASS = "airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager"
HOOK_CLASS = "airflow.providers.cncf.kubernetes.operators.kubernetes_pod.KubernetesHook"

XCOM_IMAGE = "XCOM_IMAGE"


@pytest.fixture(autouse=True)
def mock_create_pod() -> mock.Mock:
Expand All @@ -40,6 +44,18 @@ def mock_await_pod_start() -> mock.Mock:
return mock.patch(f"{POD_MANAGER_CLASS}.await_pod_start").start()


@pytest.fixture(autouse=True)
def await_xcom_sidecar_container_start() -> mock.Mock:
    """Patch ``PodManager.await_xcom_sidecar_container_start`` for every test.

    Returns the mock that replaces the method so tests can inspect its calls.
    """
    # NOTE(review): the patch is started here but not stopped in this fixture;
    # presumably teardown happens elsewhere in the module - confirm.
    patcher = mock.patch(f"{POD_MANAGER_CLASS}.await_xcom_sidecar_container_start")
    return patcher.start()


@pytest.fixture(autouse=True)
def extract_xcom() -> mock.Mock:
    """Patch ``PodManager.extract_xcom`` to return a fixed JSON document.

    Returns the mock that replaces the method; every call to it yields the
    JSON text ``{"key1": "value1", "key2": "value2"}``.
    """
    # NOTE(review): the patch is started here but not stopped in this fixture;
    # presumably teardown happens elsewhere in the module - confirm.
    patched_extract = mock.patch(f"{POD_MANAGER_CLASS}.extract_xcom").start()
    patched_extract.return_value = '{"key1": "value1", "key2": "value2"}'
    return patched_extract


@pytest.fixture(autouse=True)
def mock_await_pod_completion() -> mock.Mock:
f = mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion").start()
Expand Down Expand Up @@ -81,11 +97,65 @@ def f():

containers = mock_create_pod.call_args[1]["pod"].spec.containers
assert len(containers) == 1
assert containers[0].command == ["bash"]
assert containers[0].command[0] == "bash"
assert len(containers[0].args) == 0
assert containers[0].env[0].name == "__PYTHON_SCRIPT"
assert containers[0].env[0].value
assert containers[0].env[1].name == "__PYTHON_INPUT"

# Ensure we pass input through a b64 encoded env var
decoded_input = pickle.loads(base64.b64decode(containers[0].env[1].value))
assert decoded_input == {"args": [], "kwargs": {}}


def test_kubernetes_with_input_output(
dag_maker, session, mock_create_pod: mock.Mock, mock_hook: mock.Mock
) -> None:
with dag_maker(session=session) as dag:

@task.kubernetes(
image="python:3.10-slim-buster",
in_cluster=False,
cluster_context="default",
config_file="/tmp/fake_file",
)
def f(arg1, arg2, kwarg1=None, kwarg2=None):
return {"key1": "value1", "key2": "value2"}

f.override(task_id="my_task_id", do_xcom_push=True)("arg1", "arg2", kwarg1="kwarg1")

dr = dag_maker.create_dagrun()
(ti,) = dr.task_instances

mock_hook.return_value.get_xcom_sidecar_container_image.return_value = XCOM_IMAGE

dag.get_task("my_task_id").execute(context=ti.get_template_context(session=session))

mock_hook.assert_called_once_with(
conn_id=None,
in_cluster=False,
cluster_context="default",
config_file="/tmp/fake_file",
)
assert mock_create_pod.call_count == 1
assert mock_hook.return_value.get_xcom_sidecar_container_image.call_count == 1

containers = mock_create_pod.call_args[1]["pod"].spec.containers

# First container is Python script
assert len(containers) == 2
assert containers[0].command[0] == "bash"
assert len(containers[0].args) == 0

assert containers[0].env[0].name == "__PYTHON_SCRIPT"
assert containers[0].env[0].value
assert containers[0].env[1].name == "__PYTHON_INPUT"
assert containers[0].env[1].value

assert len(containers[0].args) == 2
assert containers[0].args[0] == "-cx"
assert containers[0].args[1].endswith("/tmp/script.py")
# Ensure we pass input through a b64 encoded env var
decoded_input = pickle.loads(base64.b64decode(containers[0].env[1].value))
assert decoded_input == {"args": ("arg1", "arg2"), "kwargs": {"kwarg1": "kwarg1"}}

assert containers[0].env[-1].name == "__PYTHON_SCRIPT"
assert containers[0].env[-1].value
# Second container is xcom image
assert containers[1].image == XCOM_IMAGE
assert containers[1].volume_mounts[0].mount_path == "/airflow/xcom"