Skip to content

Commit

Permalink
Save pod name to xcom for KubernetesPodOperator (#15755)
Browse files Browse the repository at this point in the history
* Save pod name to xcom for KubernetesPodOperator

* fix kubernetes test
  • Loading branch information
junnplus authored May 14, 2021
1 parent c493b4d commit 37d549b
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 24 deletions.
4 changes: 2 additions & 2 deletions airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,8 +345,6 @@ def execute(self, context) -> Optional[str]:

label_selector = self._get_pod_identifying_label_string(labels)

self.namespace = self.pod.metadata.namespace

pod_list = client.list_namespaced_pod(self.namespace, label_selector=label_selector)

if len(pod_list.items) > 1 and self.reattach_on_restart:
Expand All @@ -367,6 +365,8 @@ def execute(self, context) -> Optional[str]:
if final_state != State.SUCCESS:
status = self.client.read_namespaced_pod(self.pod.metadata.name, self.namespace)
raise AirflowException(f'Pod {self.pod.metadata.name} returned a failure: {status}')
context['task_instance'].xcom_push(key='pod_name', value=self.pod.metadata.name)
context['task_instance'].xcom_push(key='pod_namespace', value=self.namespace)
return result
except AirflowException as ex:
raise AirflowException(f'Pod Launching failed: {ex}')
Expand Down
2 changes: 2 additions & 0 deletions kubernetes_tests/test_kubernetes_pod_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,13 @@ def create_context(task):
tzinfo = pendulum.timezone("Europe/Amsterdam")
execution_date = timezone.datetime(2016, 1, 1, 1, 0, 0, tzinfo=tzinfo)
task_instance = TaskInstance(task=task, execution_date=execution_date)
task_instance.xcom_push = mock.Mock()
return {
"dag": dag,
"ts": execution_date.isoformat(),
"task": task,
"ti": task_instance,
"task_instance": task_instance,
}


Expand Down
24 changes: 7 additions & 17 deletions kubernetes_tests/test_kubernetes_pod_operator_backcompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,13 @@ def create_context(task):
tzinfo = pendulum.timezone("Europe/Amsterdam")
execution_date = timezone.datetime(2016, 1, 1, 1, 0, 0, tzinfo=tzinfo)
task_instance = TaskInstance(task=task, execution_date=execution_date)
task_instance.xcom_push = mock.Mock()
return {
"dag": dag,
"ts": execution_date.isoformat(),
"task": task,
"ti": task_instance,
"task_instance": task_instance,
}


Expand Down Expand Up @@ -116,18 +118,6 @@ def tearDown(self):
client = kube_client.get_kube_client(in_cluster=False)
client.delete_collection_namespaced_pod(namespace="default")

def create_context(self, task):
dag = DAG(dag_id="dag")
tzinfo = pendulum.timezone("Europe/Amsterdam")
execution_date = timezone.datetime(2016, 1, 1, 1, 0, 0, tzinfo=tzinfo)
task_instance = TaskInstance(task=task, execution_date=execution_date)
return {
"dag": dag,
"ts": execution_date.isoformat(),
"task": task,
"ti": task_instance,
}

@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.start_pod")
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.monitor_pod")
@mock.patch("airflow.kubernetes.kube_client.get_kube_client")
Expand All @@ -147,7 +137,7 @@ def test_image_pull_secrets_correctly_set(self, mock_client, monitor_mock, start
cluster_context='default',
)
monitor_mock.return_value = (State.SUCCESS, None)
context = self.create_context(k)
context = create_context(k)
k.execute(context=context)
assert start_mock.call_args[0][0].spec.image_pull_secrets == [
k8s.V1LocalObjectReference(name=fake_pull_secrets)
Expand Down Expand Up @@ -212,7 +202,7 @@ def test_pod_resources(self):
do_xcom_push=False,
resources=resources,
)
context = self.create_context(k)
context = create_context(k)
k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['containers'][0]['resources'] = {
Expand Down Expand Up @@ -268,7 +258,7 @@ def test_port(self):
do_xcom_push=False,
ports=[port],
)
context = self.create_context(k)
context = create_context(k)
k.execute(context=context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['containers'][0]['ports'] = [{'name': 'http', 'containerPort': 80}]
Expand Down Expand Up @@ -479,7 +469,7 @@ def test_envs_from_configmaps(self, mock_client, mock_monitor, mock_start):
)
# THEN
mock_monitor.return_value = (State.SUCCESS, None)
context = self.create_context(k)
context = create_context(k)
k.execute(context)
assert mock_start.call_args[0][0].spec.containers[0].env_from == [
k8s.V1EnvFromSource(config_map_ref=k8s.V1ConfigMapEnvSource(name=configmap))
Expand Down Expand Up @@ -507,7 +497,7 @@ def test_envs_from_secrets(self, mock_client, monitor_mock, start_mock):
)
# THEN
monitor_mock.return_value = (State.SUCCESS, None)
context = self.create_context(k)
context = create_context(k)
k.execute(context)
assert start_mock.call_args[0][0].spec.containers[0].env_from == [
k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(name=secret_ref))
Expand Down
27 changes: 22 additions & 5 deletions tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from tempfile import NamedTemporaryFile
from unittest import mock

import pendulum
import pytest
from kubernetes.client import ApiClient, models as k8s

Expand All @@ -29,6 +28,8 @@
from airflow.utils import timezone
from airflow.utils.state import State

DEFAULT_DATE = timezone.datetime(2016, 1, 1, 1, 0, 0)


class TestKubernetesPodOperator(unittest.TestCase):
def setUp(self):
Expand All @@ -49,14 +50,13 @@ def setUp(self):
@staticmethod
def create_context(task):
dag = DAG(dag_id="dag")
tzinfo = pendulum.timezone("Europe/Amsterdam")
execution_date = timezone.datetime(2016, 1, 1, 1, 0, 0, tzinfo=tzinfo)
task_instance = TaskInstance(task=task, execution_date=execution_date)
task_instance = TaskInstance(task=task, execution_date=DEFAULT_DATE)
return {
"dag": dag,
"ts": execution_date.isoformat(),
"ts": DEFAULT_DATE.isoformat(),
"task": task,
"ti": task_instance,
"task_instance": task_instance,
}

def run_pod(self, operator) -> k8s.V1Pod:
Expand Down Expand Up @@ -605,3 +605,20 @@ def test_node_selector(self):
sanitized_pod = self.sanitize_for_serialization(pod)
assert isinstance(pod.spec.node_selector, dict)
assert sanitized_pod["spec"]["nodeSelector"] == node_selector

def test_push_xcom_pod_info(self):
k = KubernetesPodOperator(
namespace="default",
image="ubuntu:16.04",
cmds=["bash", "-cx"],
name="test",
task_id="task",
in_cluster=False,
do_xcom_push=False,
)
pod = self.run_pod(k)
ti = TaskInstance(task=k, execution_date=DEFAULT_DATE)
pod_name = ti.xcom_pull(task_ids=k.task_id, key='pod_name')
pod_namespace = ti.xcom_pull(task_ids=k.task_id, key='pod_namespace')
assert pod_name and pod_name == pod.metadata.name
assert pod_namespace and pod_namespace == pod.metadata.namespace

0 comments on commit 37d549b

Please sign in to comment.