diff --git a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py index 2f4516cedfc64..eb513a91e43e4 100644 --- a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py +++ b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py @@ -37,7 +37,11 @@ from airflow.models import Connection from airflow.providers.cncf.kubernetes.kube_client import _disable_verify_ssl, _enable_tcp_keepalive from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import should_retry_creation -from airflow.providers.cncf.kubernetes.utils.pod_manager import PodOperatorHookProtocol +from airflow.providers.cncf.kubernetes.utils.pod_manager import ( + PodOperatorHookProtocol, + container_is_completed, + container_is_running, +) from airflow.utils import yaml if TYPE_CHECKING: @@ -834,3 +838,37 @@ async def wait_until_job_complete(self, name: str, namespace: str, poll_interval return job self.log.info("The job '%s' is incomplete. Sleeping for %i sec.", name, poll_interval) await asyncio.sleep(poll_interval) + + async def wait_until_container_complete( + self, name: str, namespace: str, container_name: str, poll_interval: float = 10 + ) -> None: + """ + Wait for the given container in the given pod to be completed. + + :param name: Name of Pod to fetch. + :param namespace: Namespace of the Pod. + :param container_name: name of the container within the pod to monitor + """ + while True: + pod = await self.get_pod(name=name, namespace=namespace) + if container_is_completed(pod=pod, container_name=container_name): + break + self.log.info("Waiting for container '%s' state to be completed", container_name) + await asyncio.sleep(poll_interval) + + async def wait_until_container_started( + self, name: str, namespace: str, container_name: str, poll_interval: float = 10 + ) -> None: + """ + Wait for the given container in the given pod to be started. + + :param name: Name of Pod to fetch. + :param namespace: Namespace of the Pod. + :param container_name: name of the container within the pod to monitor + """ + while True: + pod = await self.get_pod(name=name, namespace=namespace) + if container_is_running(pod=pod, container_name=container_name): + break + self.log.info("Waiting for container '%s' state to be running", container_name) + await asyncio.sleep(poll_interval) diff --git a/airflow/providers/cncf/kubernetes/operators/job.py b/airflow/providers/cncf/kubernetes/operators/job.py index eb7c64614615f..91ff5fa40ec0b 100644 --- a/airflow/providers/cncf/kubernetes/operators/job.py +++ b/airflow/providers/cncf/kubernetes/operators/job.py @@ -19,6 +19,7 @@ from __future__ import annotations import copy +import json import logging import os from functools import cached_property @@ -39,6 +40,7 @@ from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator, merge_objects from airflow.providers.cncf.kubernetes.triggers.job import KubernetesJobTrigger +from airflow.providers.cncf.kubernetes.utils.pod_manager import EMPTY_XCOM_RESULT, PodNotFoundException from airflow.utils import yaml from airflow.utils.context import Context @@ -135,7 +137,7 @@ def hook(self) -> KubernetesHook: return hook @cached_property - def client(self) -> BatchV1Api: + def job_client(self) -> BatchV1Api: return self.hook.batch_v1_client def create_job(self, job_request_obj: k8s.V1Job) -> k8s.V1Job: @@ -150,6 +152,11 @@ def execute(self, context: Context): "Deferrable mode is available only with parameter `wait_until_job_complete=True`. " "Please, set it up." ) + if (self.get_logs or self.do_xcom_push) and not self.wait_until_job_complete: + self.log.warning( + "Getting Logs and pushing to XCom are available only with parameter `wait_until_job_complete=True`. " + "Please, set it up." + ) self.job_request_obj = self.build_job_request_obj(context) self.job = self.create_job( # must set `self.job` for `on_kill` job_request_obj=self.job_request_obj @@ -159,16 +166,34 @@ def execute(self, context: Context): ti.xcom_push(key="job_name", value=self.job.metadata.name) ti.xcom_push(key="job_namespace", value=self.job.metadata.namespace) + if self.pod is None: + self.pod = self.get_or_create_pod( # must set `self.pod` for `on_kill` + pod_request_obj=self.pod_request_obj, + context=context, + ) + if self.wait_until_job_complete and self.deferrable: self.execute_deferrable() return if self.wait_until_job_complete: + if self.do_xcom_push: + self.pod_manager.await_container_completion( + pod=self.pod, container_name=self.base_container_name + ) + self.pod_manager.await_xcom_sidecar_container_start(pod=self.pod) + xcom_result = self.extract_xcom(pod=self.pod) self.job = self.hook.wait_until_job_complete( job_name=self.job.metadata.name, namespace=self.job.metadata.namespace, job_poll_interval=self.job_poll_interval, ) + if self.get_logs: + self.pod_manager.fetch_requested_container_logs( + pod=self.pod, + containers=self.container_logs, + follow_logs=True, + ) ti.xcom_push(key="job", value=self.job.to_dict()) if self.wait_until_job_complete: @@ -176,17 +201,24 @@ def execute(self, context: Context): raise AirflowException( f"Kubernetes job '{self.job.metadata.name}' is failed with error '{error_message}'" ) + if self.do_xcom_push: + return xcom_result def execute_deferrable(self): self.defer( trigger=KubernetesJobTrigger( job_name=self.job.metadata.name, # type: ignore[union-attr] job_namespace=self.job.metadata.namespace, # type: ignore[union-attr] + pod_name=self.pod.metadata.name, # type: ignore[union-attr] + pod_namespace=self.pod.metadata.namespace, # type: ignore[union-attr] + base_container_name=self.base_container_name, kubernetes_conn_id=self.kubernetes_conn_id, cluster_context=self.cluster_context, config_file=self.config_file, in_cluster=self.in_cluster, poll_interval=self.job_poll_interval, + get_logs=self.get_logs, + do_xcom_push=self.do_xcom_push, ), method_name="execute_complete", ) @@ -197,6 +229,22 @@ def execute_complete(self, context: Context, event: dict, **kwargs): if event["status"] == "error": raise AirflowException(event["message"]) + if self.get_logs: + pod_name = event["pod_name"] + pod_namespace = event["pod_namespace"] + self.pod = self.hook.get_pod(pod_name, pod_namespace) + if not self.pod: + raise PodNotFoundException("Could not find pod after resuming from deferral") + self._write_logs(self.pod) + + if self.do_xcom_push: + xcom_result = event["xcom_result"] + if isinstance(xcom_result, str) and xcom_result.rstrip() == EMPTY_XCOM_RESULT: + self.log.info("xcom result file is empty.") + return None + self.log.info("xcom result: \n%s", xcom_result) + return json.loads(xcom_result) + @staticmethod def deserialize_job_template_file(path: str) -> k8s.V1Job: """ @@ -229,7 +277,9 @@ def on_kill(self) -> None: } if self.termination_grace_period is not None: kwargs.update(grace_period_seconds=self.termination_grace_period) - self.client.delete_namespaced_job(**kwargs) + self.job_client.delete_namespaced_job(**kwargs) + if self.pod: + super().on_kill() def build_job_request_obj(self, context: Context | None = None) -> k8s.V1Job: """ @@ -254,6 +304,7 @@ def build_job_request_obj(self, context: Context | None = None) -> k8s.V1Job: metadata=pod_template.metadata, spec=pod_template.spec, ) + self.pod_request_obj = pod_template job = k8s.V1Job( api_version="batch/v1", diff --git a/airflow/providers/cncf/kubernetes/triggers/job.py b/airflow/providers/cncf/kubernetes/triggers/job.py index f229017df1499..d8d1db3567cea 100644 --- a/airflow/providers/cncf/kubernetes/triggers/job.py +++ b/airflow/providers/cncf/kubernetes/triggers/job.py @@ -16,10 +16,13 @@ # under the License. from __future__ import annotations +import asyncio from functools import cached_property from typing import TYPE_CHECKING, Any, AsyncIterator -from airflow.providers.cncf.kubernetes.hooks.kubernetes import AsyncKubernetesHook +from airflow.providers.cncf.kubernetes.hooks.kubernetes import AsyncKubernetesHook, KubernetesHook +from airflow.providers.cncf.kubernetes.utils.pod_manager import PodManager +from airflow.providers.cncf.kubernetes.utils.xcom_sidecar import PodDefaults from airflow.triggers.base import BaseTrigger, TriggerEvent if TYPE_CHECKING: @@ -32,32 +35,49 @@ class KubernetesJobTrigger(BaseTrigger): :param job_name: The name of the job. :param job_namespace: The namespace of the job. + :param pod_name: The name of the Pod. + :param pod_namespace: The namespace of the Pod. + :param base_container_name: The name of the base container in the pod. :param kubernetes_conn_id: The :ref:`kubernetes connection id ` for the Kubernetes cluster. :param cluster_context: Context that points to kubernetes cluster. :param config_file: Path to kubeconfig file. :param poll_interval: Polling period in seconds to check for the status. :param in_cluster: run kubernetes client with in_cluster configuration. + :param get_logs: get the stdout of the base container as logs of the tasks. + :param do_xcom_push: If True, the content of the file + /airflow/xcom/return.json in the container will also be pushed to an + XCom when the container completes. """ def __init__( self, job_name: str, job_namespace: str, + pod_name: str, + pod_namespace: str, + base_container_name: str, kubernetes_conn_id: str | None = None, poll_interval: float = 10.0, cluster_context: str | None = None, config_file: str | None = None, in_cluster: bool | None = None, + get_logs: bool = True, + do_xcom_push: bool = False, ): super().__init__() self.job_name = job_name self.job_namespace = job_namespace + self.pod_name = pod_name + self.pod_namespace = pod_namespace + self.base_container_name = base_container_name self.kubernetes_conn_id = kubernetes_conn_id self.poll_interval = poll_interval self.cluster_context = cluster_context self.config_file = config_file self.in_cluster = in_cluster + self.get_logs = get_logs + self.do_xcom_push = do_xcom_push def serialize(self) -> tuple[str, dict[str, Any]]: """Serialize KubernetesCreateJobTrigger arguments and classpath.""" @@ -66,16 +86,36 @@ def serialize(self) -> tuple[str, dict[str, Any]]: { "job_name": self.job_name, "job_namespace": self.job_namespace, + "pod_name": self.pod_name, + "pod_namespace": self.pod_namespace, + "base_container_name": self.base_container_name, "kubernetes_conn_id": self.kubernetes_conn_id, "poll_interval": self.poll_interval, "cluster_context": self.cluster_context, "config_file": self.config_file, "in_cluster": self.in_cluster, + "get_logs": self.get_logs, + "do_xcom_push": self.do_xcom_push, }, ) async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] """Get current job status and yield a TriggerEvent.""" + if self.get_logs or self.do_xcom_push: + pod = await self.hook.get_pod(name=self.pod_name, namespace=self.pod_namespace) + if self.do_xcom_push: + await self.hook.wait_until_container_complete( + name=self.pod_name, namespace=self.pod_namespace, container_name=self.base_container_name + ) + self.log.info("Checking if xcom sidecar container is started.") + await self.hook.wait_until_container_started( + name=self.pod_name, + namespace=self.pod_namespace, + container_name=PodDefaults.SIDECAR_CONTAINER_NAME, + ) + self.log.info("Extracting result from xcom sidecar container.") + loop = asyncio.get_running_loop() + xcom_result = await loop.run_in_executor(None, self.pod_manager.extract_xcom, pod) job: V1Job = await self.hook.wait_until_job_complete(name=self.job_name, namespace=self.job_namespace) job_dict = job.to_dict() error_message = self.hook.is_job_failed(job=job) @@ -83,11 +123,14 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] { "name": job.metadata.name, "namespace": job.metadata.namespace, + "pod_name": pod.metadata.name if self.get_logs else None, + "pod_namespace": pod.metadata.namespace if self.get_logs else None, "status": "error" if error_message else "success", "message": f"Job failed with error: {error_message}" if error_message else "Job completed successfully", "job": job_dict, + "xcom_result": xcom_result if self.do_xcom_push else None, } ) @@ -99,3 +142,13 @@ def hook(self) -> AsyncKubernetesHook: config_file=self.config_file, cluster_context=self.cluster_context, ) + + @cached_property + def pod_manager(self) -> PodManager: + sync_hook = KubernetesHook( + conn_id=self.kubernetes_conn_id, + in_cluster=self.in_cluster, + config_file=self.config_file, + cluster_context=self.cluster_context, + ) + return PodManager(kube_client=sync_hook.core_v1_client) diff --git a/airflow/providers/google/cloud/operators/kubernetes_engine.py b/airflow/providers/google/cloud/operators/kubernetes_engine.py index f9a75d8b28b1d..e4715ded3343e 100644 --- a/airflow/providers/google/cloud/operators/kubernetes_engine.py +++ b/airflow/providers/google/cloud/operators/kubernetes_engine.py @@ -943,13 +943,26 @@ def execute_deferrable(self): ssl_ca_cert=self._ssl_ca_cert, job_name=self.job.metadata.name, # type: ignore[union-attr] job_namespace=self.job.metadata.namespace, # type: ignore[union-attr] + pod_name=self.pod.metadata.name, # type: ignore[union-attr] + pod_namespace=self.pod.metadata.namespace, # type: ignore[union-attr] + base_container_name=self.base_container_name, gcp_conn_id=self.gcp_conn_id, poll_interval=self.job_poll_interval, impersonation_chain=self.impersonation_chain, + get_logs=self.get_logs, + do_xcom_push=self.do_xcom_push, ), method_name="execute_complete", + kwargs={"cluster_url": self._cluster_url, "ssl_ca_cert": self._ssl_ca_cert}, ) + def execute_complete(self, context: Context, event: dict, **kwargs): + # It is required for hook to be initialized + self._cluster_url = kwargs["cluster_url"] + self._ssl_ca_cert = kwargs["ssl_ca_cert"] + + return super().execute_complete(context, event) + class GKEDescribeJobOperator(GoogleCloudBaseOperator): """ diff --git a/airflow/providers/google/cloud/triggers/kubernetes_engine.py b/airflow/providers/google/cloud/triggers/kubernetes_engine.py index f05bb0dc6c731..e4998d6957d69 100644 --- a/airflow/providers/google/cloud/triggers/kubernetes_engine.py +++ b/airflow/providers/google/cloud/triggers/kubernetes_engine.py @@ -26,10 +26,12 @@ from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.providers.cncf.kubernetes.triggers.pod import KubernetesPodTrigger -from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction +from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction, PodManager +from airflow.providers.cncf.kubernetes.utils.xcom_sidecar import PodDefaults from airflow.providers.google.cloud.hooks.kubernetes_engine import ( GKEAsyncHook, GKEKubernetesAsyncHook, + GKEKubernetesHook, ) from airflow.triggers.base import BaseTrigger, TriggerEvent @@ -255,18 +257,28 @@ def __init__( ssl_ca_cert: str, job_name: str, job_namespace: str, + pod_name: str, + pod_namespace: str, + base_container_name: str, gcp_conn_id: str = "google_cloud_default", poll_interval: float = 2, impersonation_chain: str | Sequence[str] | None = None, + get_logs: bool = True, + do_xcom_push: bool = False, ) -> None: super().__init__() self.cluster_url = cluster_url self.ssl_ca_cert = ssl_ca_cert self.job_name = job_name self.job_namespace = job_namespace + self.pod_name = pod_name + self.pod_namespace = pod_namespace + self.base_container_name = base_container_name self.gcp_conn_id = gcp_conn_id self.poll_interval = poll_interval self.impersonation_chain = impersonation_chain + self.get_logs = get_logs + self.do_xcom_push = do_xcom_push def serialize(self) -> tuple[str, dict[str, Any]]: """Serialize KubernetesCreateJobTrigger arguments and classpath.""" @@ -277,14 +289,34 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "ssl_ca_cert": self.ssl_ca_cert, "job_name": self.job_name, "job_namespace": self.job_namespace, + "pod_name": self.pod_name, + "pod_namespace": self.pod_namespace, + "base_container_name": self.base_container_name, "gcp_conn_id": self.gcp_conn_id, "poll_interval": self.poll_interval, "impersonation_chain": self.impersonation_chain, + "get_logs": self.get_logs, + "do_xcom_push": self.do_xcom_push, }, ) async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] """Get current job status and yield a TriggerEvent.""" + if self.get_logs or self.do_xcom_push: + pod = await self.hook.get_pod(name=self.pod_name, namespace=self.pod_namespace) + if self.do_xcom_push: + await self.hook.wait_until_container_complete( + name=self.pod_name, namespace=self.pod_namespace, container_name=self.base_container_name + ) + self.log.info("Checking if xcom sidecar container is started.") + await self.hook.wait_until_container_started( + name=self.pod_name, + namespace=self.pod_namespace, + container_name=PodDefaults.SIDECAR_CONTAINER_NAME, + ) + self.log.info("Extracting result from xcom sidecar container.") + loop = asyncio.get_running_loop() + xcom_result = await loop.run_in_executor(None, self.pod_manager.extract_xcom, pod) job: V1Job = await self.hook.wait_until_job_complete(name=self.job_name, namespace=self.job_namespace) job_dict = job.to_dict() error_message = self.hook.is_job_failed(job=job) @@ -294,9 +326,12 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] { "name": job.metadata.name, "namespace": job.metadata.namespace, + "pod_name": pod.metadata.name if self.get_logs else None, + "pod_namespace": pod.metadata.namespace if self.get_logs else None, "status": status, "message": message, "job": job_dict, + "xcom_result": xcom_result if self.do_xcom_push else None, } ) @@ -308,3 +343,13 @@ def hook(self) -> GKEKubernetesAsyncHook: gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) + + @cached_property + def pod_manager(self) -> PodManager: + sync_hook = GKEKubernetesHook( + gcp_conn_id=self.gcp_conn_id, + cluster_url=self.cluster_url, + ssl_ca_cert=self.ssl_ca_cert, + impersonation_chain=self.impersonation_chain, + ) + return PodManager(kube_client=sync_hook.core_v1_client) diff --git a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py index c5aa62e1439e2..26926e7e22834 100644 --- a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py +++ b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py @@ -50,6 +50,7 @@ POD_NAME = "test-pod" NAMESPACE = "test-namespace" JOB_NAME = "test-job" +CONTAINER_NAME = "test-container" POLL_INTERVAL = 100 @@ -921,3 +922,81 @@ async def test_wait_until_job_complete( mock_is_job_complete.assert_has_calls([mock.call(job=mock_job_0), mock.call(job=mock_job_1)]) mock_sleep.assert_awaited_once_with(10) assert job_actual == mock_job_1 + + @pytest.mark.asyncio + @mock.patch(HOOK_MODULE + ".asyncio.sleep") + @mock.patch(HOOK_MODULE + ".container_is_completed") + @mock.patch(KUBE_ASYNC_HOOK.format("get_pod")) + async def test_wait_until_container_complete( + self, mock_get_pod, mock_container_is_completed, mock_sleep, kube_config_loader + ): + mock_pod_0, mock_pod_1 = mock.MagicMock(), mock.MagicMock() + mock_get_pod.side_effect = mock.AsyncMock(side_effect=[mock_pod_0, mock_pod_1]) + mock_container_is_completed.side_effect = [False, True] + + hook = AsyncKubernetesHook( + conn_id=None, + in_cluster=False, + config_file=None, + cluster_context=None, + ) + + await hook.wait_until_container_complete( + name=POD_NAME, + namespace=NAMESPACE, + container_name=CONTAINER_NAME, + poll_interval=10, + ) + + mock_get_pod.assert_has_awaits( + [ + mock.call(name=POD_NAME, namespace=NAMESPACE), + mock.call(name=POD_NAME, namespace=NAMESPACE), + ] + ) + mock_container_is_completed.assert_has_calls( + [ + mock.call(pod=mock_pod_0, container_name=CONTAINER_NAME), + mock.call(pod=mock_pod_1, container_name=CONTAINER_NAME), + ] + ) + mock_sleep.assert_awaited_once_with(10) + + @pytest.mark.asyncio + @mock.patch(HOOK_MODULE + ".asyncio.sleep") + @mock.patch(HOOK_MODULE + ".container_is_running") + @mock.patch(KUBE_ASYNC_HOOK.format("get_pod")) + async def test_wait_until_container_started( + self, mock_get_pod, mock_container_is_running, mock_sleep, kube_config_loader + ): + mock_pod_0, mock_pod_1 = mock.MagicMock(), mock.MagicMock() + mock_get_pod.side_effect = mock.AsyncMock(side_effect=[mock_pod_0, mock_pod_1]) + mock_container_is_running.side_effect = [False, True] + + hook = AsyncKubernetesHook( + conn_id=None, + in_cluster=False, + config_file=None, + cluster_context=None, + ) + + await hook.wait_until_container_started( + name=POD_NAME, + namespace=NAMESPACE, + container_name=CONTAINER_NAME, + poll_interval=10, + ) + + mock_get_pod.assert_has_awaits( + [ + mock.call(name=POD_NAME, namespace=NAMESPACE), + mock.call(name=POD_NAME, namespace=NAMESPACE), + ] + ) + mock_container_is_running.assert_has_calls( + [ + mock.call(pod=mock_pod_0, container_name=CONTAINER_NAME), + mock.call(pod=mock_pod_1, container_name=CONTAINER_NAME), + ] + ) + mock_sleep.assert_awaited_once_with(10) diff --git a/tests/providers/cncf/kubernetes/operators/test_job.py b/tests/providers/cncf/kubernetes/operators/test_job.py index d776da248263e..60322672e218b 100644 --- a/tests/providers/cncf/kubernetes/operators/test_job.py +++ b/tests/providers/cncf/kubernetes/operators/test_job.py @@ -45,6 +45,10 @@ JOB_NAMESPACE = "test-namespace" JOB_POLL_INTERVAL = 20.0 KUBERNETES_CONN_ID = "test-conn_id" +POD_NAME = "test-pod" +POD_NAMESPACE = "test-namespace" +TEST_XCOM_RESULT = '{"result": "test-xcom-result"}' +POD_MANAGER_CLASS = "airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager" def create_context(task, persist_to_db=False, map_index=None): @@ -466,10 +470,11 @@ def test_task_id_as_name_dag_id_is_ignored(self): assert re.match(r"job-a-very-reasonable-task-name-[a-z0-9-]+", job.metadata.name) is not None @pytest.mark.non_db_test_override + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_or_create_pod")) @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj")) @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job")) @patch(HOOK_CLASS) - def test_execute(self, mock_hook, mock_create_job, mock_build_job_request_obj): + def test_execute(self, mock_hook, mock_create_job, mock_build_job_request_obj, mock_get_or_create_pod): mock_hook.return_value.is_job_failed.return_value = False mock_job_request_obj = mock_build_job_request_obj.return_value mock_job_expected = mock_create_job.return_value @@ -498,12 +503,18 @@ def test_execute(self, mock_hook, mock_create_job, mock_build_job_request_obj): assert not mock_hook.wait_until_job_complete.called @pytest.mark.non_db_test_override + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_or_create_pod")) @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj")) @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job")) @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.execute_deferrable")) @patch(HOOK_CLASS) def test_execute_in_deferrable( - self, mock_hook, mock_execute_deferrable, mock_create_job, mock_build_job_request_obj + self, + mock_hook, + mock_execute_deferrable, + mock_create_job, + mock_build_job_request_obj, + mock_get_or_create_pod, ): mock_hook.return_value.is_job_failed.return_value = False mock_job_request_obj = mock_build_job_request_obj.return_value @@ -534,10 +545,13 @@ def test_execute_in_deferrable( assert not mock_hook.wait_until_job_complete.called @pytest.mark.non_db_test_override + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_or_create_pod")) @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj")) @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job")) @patch(HOOK_CLASS) - def test_execute_fail(self, mock_hook, mock_create_job, mock_build_job_request_obj): + def test_execute_fail( + self, mock_hook, mock_create_job, mock_build_job_request_obj, mock_get_or_create_pod + ): mock_hook.return_value.is_job_failed.return_value = "Error" op = KubernetesJobOperator( @@ -560,6 +574,10 @@ def test_execute_deferrable(self, mock_trigger, mock_execute_deferrable): mock_job.metadata.name = JOB_NAME mock_job.metadata.namespace = JOB_NAMESPACE + mock_pod = mock.MagicMock() + mock_pod.metadata.name = POD_NAME + mock_pod.metadata.namespace = POD_NAMESPACE + mock_trigger_instance = mock_trigger.return_value op = KubernetesJobOperator( @@ -573,6 +591,7 @@ def test_execute_deferrable(self, mock_trigger, mock_execute_deferrable): deferrable=True, ) op.job = mock_job + op.pod = mock_pod actual_result = op.execute_deferrable() @@ -583,20 +602,29 @@ def test_execute_deferrable(self, mock_trigger, mock_execute_deferrable): mock_trigger.assert_called_once_with( job_name=JOB_NAME, job_namespace=JOB_NAMESPACE, + pod_name=POD_NAME, + pod_namespace=POD_NAMESPACE, + base_container_name=op.BASE_CONTAINER_NAME, kubernetes_conn_id=KUBERNETES_CONN_ID, cluster_context=mock_cluster_context, config_file=mock_config_file, in_cluster=mock_in_cluster, poll_interval=POLL_INTERVAL, + get_logs=True, + do_xcom_push=False, ) assert actual_result is None - @pytest.mark.non_db_test_override + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_or_create_pod")) @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj")) @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job")) @patch(f"{HOOK_CLASS}.wait_until_job_complete") def test_wait_until_job_complete( - self, mock_wait_until_job_complete, mock_create_job, mock_build_job_request_obj + self, + mock_wait_until_job_complete, + mock_create_job, + mock_build_job_request_obj, + mock_get_or_create_pod, ): mock_job_expected = mock_create_job.return_value mock_ti = mock.MagicMock() @@ -614,17 +642,32 @@ def test_wait_until_job_complete( job_poll_interval=POLL_INTERVAL, ) - @pytest.mark.non_db_test_override - def test_execute_complete(self): + @pytest.mark.parametrize("do_xcom_push", [True, False]) + @pytest.mark.parametrize("get_logs", [True, False]) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator._write_logs")) + def test_execute_complete(self, mocked_write_logs, get_logs, do_xcom_push): mock_ti = mock.MagicMock() context = {"ti": mock_ti} mock_job = mock.MagicMock() - event = {"job": mock_job, "status": "success"} + event = { + "job": mock_job, + "status": "success", + "pod_name": POD_NAME if get_logs else None, + "pod_namespace": POD_NAMESPACE if get_logs else None, + "xcom_result": TEST_XCOM_RESULT if do_xcom_push else None, + } - KubernetesJobOperator(task_id="test_task_id").execute_complete(context=context, event=event) + KubernetesJobOperator( + task_id="test_task_id", get_logs=get_logs, do_xcom_push=do_xcom_push + ).execute_complete(context=context, event=event) mock_ti.xcom_push.assert_called_once_with(key="job", value=mock_job) + if get_logs: + mocked_write_logs.assert_called_once() + else: + mocked_write_logs.assert_not_called() + @pytest.mark.non_db_test_override def test_execute_complete_fail(self): mock_ti = mock.MagicMock() @@ -638,7 +681,7 @@ def test_execute_complete_fail(self): mock_ti.xcom_push.assert_called_once_with(key="job", value=mock_job) @pytest.mark.non_db_test_override - @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.client")) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.job_client")) @patch(HOOK_CLASS) def test_on_kill(self, mock_hook, mock_client): mock_job = mock.MagicMock() @@ -659,7 +702,7 @@ def test_on_kill(self, mock_hook, mock_client): mock_serialize.assert_called_once_with(mock_job) @pytest.mark.non_db_test_override - @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.client")) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.job_client")) @patch(HOOK_CLASS) def test_on_kill_termination_grace_period(self, mock_hook, mock_client): mock_job = mock.MagicMock() @@ -695,6 +738,49 @@ def test_on_kill_none_job(self, mock_hook, mock_client): mock_client.delete_namespaced_job.assert_not_called() mock_serialize.assert_not_called() + @pytest.mark.parametrize("do_xcom_push", [True, False]) + @pytest.mark.parametrize("get_logs", [True, False]) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.extract_xcom")) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_or_create_pod")) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj")) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job")) + @patch(f"{POD_MANAGER_CLASS}.fetch_requested_container_logs") + @patch(f"{POD_MANAGER_CLASS}.await_xcom_sidecar_container_start") + @patch(f"{POD_MANAGER_CLASS}.await_container_completion") + @patch(f"{HOOK_CLASS}.wait_until_job_complete") + def test_execute_xcom_and_logs( + self, + mock_wait_until_job_complete, + mock_await_container_completion, + mock_await_xcom_sidecar_container_start, + mocked_fetch_logs, + mock_create_job, + mock_build_job_request_obj, + mock_get_or_create_pod, + mock_extract_xcom, + get_logs, + do_xcom_push, + ): + mock_ti = mock.MagicMock() + op = KubernetesJobOperator( + task_id="test_task_id", + wait_until_job_complete=True, + job_poll_interval=POLL_INTERVAL, + get_logs=get_logs, + do_xcom_push=do_xcom_push, + ) + op.execute(context=dict(ti=mock_ti)) + + if do_xcom_push: + mock_extract_xcom.assert_called_once() + else: + mock_extract_xcom.assert_not_called() + + if get_logs: + mocked_fetch_logs.assert_called_once() + else: + mocked_fetch_logs.assert_not_called() + @pytest.mark.db_test @pytest.mark.execution_timeout(300) diff --git a/tests/providers/cncf/kubernetes/triggers/test_job.py b/tests/providers/cncf/kubernetes/triggers/test_job.py index 6124f5471c889..7a85e0905369a 100644 --- a/tests/providers/cncf/kubernetes/triggers/test_job.py +++ b/tests/providers/cncf/kubernetes/triggers/test_job.py @@ -28,12 +28,16 @@ TRIGGER_CLASS = TRIGGER_PATH.format("KubernetesJobTrigger") HOOK_PATH = "airflow.providers.cncf.kubernetes.hooks.kubernetes.AsyncKubernetesHook" JOB_NAME = "test-job-name" +POD_NAME = "test-pod-name" +CONTAINER_NAME = "test-container-name" NAMESPACE = "default" CONN_ID = "test_kubernetes_conn_id" POLL_INTERVAL = 2 CLUSTER_CONTEXT = "test-context" CONFIG_FILE = "/path/to/config/file" IN_CLUSTER = False +GET_LOGS = True +XCOM_PUSH = False @pytest.fixture @@ -41,11 +45,16 @@ def trigger(): return KubernetesJobTrigger( job_name=JOB_NAME, job_namespace=NAMESPACE, + pod_name=POD_NAME, + pod_namespace=NAMESPACE, + base_container_name=CONTAINER_NAME, kubernetes_conn_id=CONN_ID, poll_interval=POLL_INTERVAL, cluster_context=CLUSTER_CONTEXT, config_file=CONFIG_FILE, in_cluster=IN_CLUSTER, + get_logs=GET_LOGS, + do_xcom_push=XCOM_PUSH, ) @@ -57,11 +66,16 @@ def test_serialize(self, trigger): assert kwargs_dict == { "job_name": JOB_NAME, "job_namespace": NAMESPACE, + "pod_name": POD_NAME, + "pod_namespace": NAMESPACE, + "base_container_name": CONTAINER_NAME, "kubernetes_conn_id": CONN_ID, "poll_interval": POLL_INTERVAL, "cluster_context": CLUSTER_CONTEXT, "config_file": CONFIG_FILE, "in_cluster": IN_CLUSTER, + "get_logs": GET_LOGS, + "do_xcom_push": XCOM_PUSH, } @pytest.mark.asyncio @@ -72,6 +86,11 @@ async def test_run_success(self, mock_hook, trigger): mock_job.metadata.namespace = NAMESPACE mock_hook.wait_until_job_complete.side_effect = mock.AsyncMock(return_value=mock_job) + mock_pod = mock.MagicMock() + mock_pod.metadata.name = POD_NAME + mock_pod.metadata.namespace = NAMESPACE + mock_hook.get_pod.side_effect = mock.AsyncMock(return_value=mock_pod) + mock_is_job_failed = mock_hook.is_job_failed mock_is_job_failed.return_value = False @@ -86,9 +105,12 @@ async def test_run_success(self, mock_hook, trigger): { "name": JOB_NAME, "namespace": NAMESPACE, + "pod_name": POD_NAME, + "pod_namespace": NAMESPACE, "status": "success", "message": "Job completed successfully", "job": mock_job_dict, + "xcom_result": None, } ) @@ -100,6 +122,11 @@ async def test_run_fail(self, mock_hook, trigger): mock_job.metadata.namespace = NAMESPACE mock_hook.wait_until_job_complete.side_effect = mock.AsyncMock(return_value=mock_job) + mock_pod = mock.MagicMock() + mock_pod.metadata.name = POD_NAME + mock_pod.metadata.namespace = NAMESPACE + mock_hook.get_pod.side_effect = mock.AsyncMock(return_value=mock_pod) + mock_is_job_failed = mock_hook.is_job_failed mock_is_job_failed.return_value = "Error" @@ -114,9 +141,12 @@ async def test_run_fail(self, mock_hook, trigger): { "name": JOB_NAME, "namespace": NAMESPACE, + "pod_name": POD_NAME, + "pod_namespace": NAMESPACE, "status": "error", "message": "Job failed with error: Error", "job": mock_job_dict, + "xcom_result": None, } ) diff --git a/tests/providers/google/cloud/operators/test_kubernetes_engine.py b/tests/providers/google/cloud/operators/test_kubernetes_engine.py index f8416c2189f4a..c1d9655e45fbe 100644 --- a/tests/providers/google/cloud/operators/test_kubernetes_engine.py +++ b/tests/providers/google/cloud/operators/test_kubernetes_engine.py @@ -71,6 +71,7 @@ TASK_NAME = "test-task-name" JOB_NAME = "test-job" +POD_NAME = "test-pod" NAMESPACE = ("default",) IMAGE = "bash" JOB_POLL_INTERVAL = 20.0 @@ -893,6 +894,12 @@ def test_execute_deferrable(self, mock_trigger): mock_metadata = mock_job.metadata mock_metadata.name = TASK_NAME mock_metadata.namespace = NAMESPACE + + mock_pod = mock.MagicMock() + mock_pod.metadata.name = POD_NAME + mock_pod.metadata.namespace = NAMESPACE + op.pod = mock_pod + with mock.patch.object(op, "defer") as mock_defer: op.execute_deferrable() @@ -901,13 +908,19 @@ def test_execute_deferrable(self, mock_trigger): ssl_ca_cert=SSL_CA_CERT, job_name=TASK_NAME, job_namespace=NAMESPACE, + pod_name=POD_NAME, + pod_namespace=NAMESPACE, + base_container_name=op.BASE_CONTAINER_NAME, gcp_conn_id="google_cloud_default", poll_interval=JOB_POLL_INTERVAL, impersonation_chain=None, + get_logs=True, + do_xcom_push=False, ) mock_defer.assert_called_once_with( trigger=mock_trigger_instance, method_name="execute_complete", + kwargs={"cluster_url": CLUSTER_URL, "ssl_ca_cert": SSL_CA_CERT}, ) def test_config_file_throws_error(self): diff --git a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py index 88f92e780caf2..a41eab654bbd2 100644 --- a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py +++ b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py @@ -55,6 +55,7 @@ FAILED_RESULT_MSG = "Test message that appears when trigger have failed event." BASE_CONTAINER_NAME = "base" ON_FINISH_ACTION = "delete_pod" +XCOM_PUSH = False OPERATION_NAME = "test-operation-name" PROJECT_ID = "test-project-id" @@ -92,9 +93,14 @@ def job_trigger(): ssl_ca_cert=SSL_CA_CERT, job_name=JOB_NAME, job_namespace=NAMESPACE, + pod_name=POD_NAME, + pod_namespace=NAMESPACE, + base_container_name=BASE_CONTAINER_NAME, gcp_conn_id=GCP_CONN_ID, poll_interval=POLL_INTERVAL, impersonation_chain=IMPERSONATION_CHAIN, + get_logs=GET_LOGS, + do_xcom_push=XCOM_PUSH, ) @@ -472,9 +478,14 @@ def test_serialize(self, job_trigger): "ssl_ca_cert": SSL_CA_CERT, "job_name": JOB_NAME, "job_namespace": NAMESPACE, + "pod_name": POD_NAME, + "pod_namespace": NAMESPACE, + "base_container_name": BASE_CONTAINER_NAME, "gcp_conn_id": GCP_CONN_ID, "poll_interval": POLL_INTERVAL, "impersonation_chain": IMPERSONATION_CHAIN, + "get_logs": GET_LOGS, + "do_xcom_push": XCOM_PUSH, } @pytest.mark.asyncio @@ -485,6 +496,11 @@ async def test_run_success(self, mock_hook, job_trigger): mock_job.metadata.namespace = NAMESPACE mock_hook.wait_until_job_complete.side_effect = mock.AsyncMock(return_value=mock_job) + mock_pod = mock.MagicMock() + mock_pod.metadata.name = POD_NAME + mock_pod.metadata.namespace = NAMESPACE + mock_hook.get_pod.side_effect = mock.AsyncMock(return_value=mock_pod) + mock_is_job_failed = mock_hook.is_job_failed mock_is_job_failed.return_value = False @@ -499,9 +515,12 @@ async def test_run_success(self, mock_hook, job_trigger): { "name": JOB_NAME, "namespace": NAMESPACE, + "pod_name": POD_NAME, + "pod_namespace": NAMESPACE, "status": "success", "message": "Job completed successfully", "job": mock_job_dict, + "xcom_result": None, } ) @@ -513,6 +532,11 @@ async def test_run_fail(self, mock_hook, job_trigger): mock_job.metadata.namespace = NAMESPACE mock_hook.wait_until_job_complete.side_effect = mock.AsyncMock(return_value=mock_job) + mock_pod = mock.MagicMock() + mock_pod.metadata.name = POD_NAME + mock_pod.metadata.namespace = NAMESPACE + mock_hook.get_pod.side_effect = mock.AsyncMock(return_value=mock_pod) + mock_is_job_failed = mock_hook.is_job_failed mock_is_job_failed.return_value = "Error" @@ -527,9 +551,12 @@ async def test_run_fail(self, mock_hook, job_trigger): { "name": JOB_NAME, "namespace": NAMESPACE, + "pod_name": POD_NAME, + "pod_namespace": NAMESPACE, "status": "error", "message": "Job failed with error: Error", "job": mock_job_dict, + "xcom_result": None, } )