diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/job.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/job.py index 7c19421594aa2..79e9b83581243 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/job.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/job.py @@ -22,16 +22,17 @@ import json import logging import os +import warnings from collections.abc import Sequence from functools import cached_property -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Literal from kubernetes.client import BatchV1Api, models as k8s from kubernetes.client.api_client import ApiClient from kubernetes.client.rest import ApiException from airflow.configuration import conf -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import ( add_unique_suffix, @@ -81,6 +82,17 @@ class KubernetesJobOperator(KubernetesPodOperator): Used if the parameter `wait_until_job_complete` set True. :param deferrable: Run operator in the deferrable mode. Note that the parameter `wait_until_job_complete` must be set True. + :param on_kill_propagation_policy: Whether and how garbage collection will be performed. Default is 'Foreground'. + Acceptable values are: + 'Orphan' - orphan the dependents; + 'Background' - allow the garbage collector to delete the dependents in the background; + 'Foreground' - a cascading policy that deletes all dependents in the foreground. + Default value is 'Foreground'. + :param discover_pods_retry_number: Number of time list_namespaced_pod will be performed to discover + already running pods. + :param unwrap_single: Unwrap single result from the pod. 
@property
def pod(self):
    """Deprecated accessor for the (first) pod of this job; use ``pods`` instead.

    Returns the pod explicitly assigned via the (deprecated) setter if any,
    otherwise the first discovered/created pod, or ``None`` when no pod is
    known yet.
    """
    warnings.warn(
        "`pod` parameter is deprecated, please use `pods`",
        AirflowProviderDeprecationWarning,
        stacklevel=2,
    )
    # BUGFIX: the setter stores into `_pod`, but the getter previously ignored
    # it and read only `self.pods` — a pod assigned through the deprecated
    # setter was silently lost. Honor `_pod` first.
    if self._pod is not None:
        return self._pod
    # BUGFIX: `pods` is only assigned inside execute(); guard with getattr so
    # reading `op.pod` on a freshly constructed operator returns None instead
    # of raising AttributeError.
    pods = getattr(self, "pods", None)
    return pods[0] if pods else None

@pod.setter
def pod(self, value):
    # Deprecated write path; kept for backward compatibility with callers
    # that pre-assigned `op.pod` before execute().
    self._pod = value
else: + self.pods = self.get_pods(pod_request_obj=self.pod_request_obj, context=context) if self.wait_until_job_complete and self.deferrable: self.execute_deferrable() @@ -180,22 +216,25 @@ def execute(self, context: Context): if self.wait_until_job_complete: if self.do_xcom_push: - self.pod_manager.await_container_completion( - pod=self.pod, container_name=self.base_container_name - ) - self.pod_manager.await_xcom_sidecar_container_start(pod=self.pod) - xcom_result = self.extract_xcom(pod=self.pod) + xcom_result = [] + for pod in self.pods: + self.pod_manager.await_container_completion( + pod=pod, container_name=self.base_container_name + ) + self.pod_manager.await_xcom_sidecar_container_start(pod=pod) + xcom_result.append(self.extract_xcom(pod=pod)) self.job = self.hook.wait_until_job_complete( job_name=self.job.metadata.name, namespace=self.job.metadata.namespace, job_poll_interval=self.job_poll_interval, ) if self.get_logs: - self.pod_manager.fetch_requested_container_logs( - pod=self.pod, - containers=self.container_logs, - follow_logs=True, - ) + for pod in self.pods: + self.pod_manager.fetch_requested_container_logs( + pod=pod, + containers=self.container_logs, + follow_logs=True, + ) ti.xcom_push(key="job", value=self.job.to_dict()) if self.wait_until_job_complete: @@ -211,8 +250,8 @@ def execute_deferrable(self): trigger=KubernetesJobTrigger( job_name=self.job.metadata.name, job_namespace=self.job.metadata.namespace, - pod_name=self.pod.metadata.name, - pod_namespace=self.pod.metadata.namespace, + pod_names=[pod.metadata.name for pod in self.pods], + pod_namespace=self.pods[0].metadata.namespace, base_container_name=self.base_container_name, kubernetes_conn_id=self.kubernetes_conn_id, cluster_context=self.cluster_context, @@ -232,20 +271,23 @@ def execute_complete(self, context: Context, event: dict, **kwargs): raise AirflowException(event["message"]) if self.get_logs: - pod_name = event["pod_name"] - pod_namespace = event["pod_namespace"] - self.pod = 
def get_pods(
    self, pod_request_obj: k8s.V1Pod, context: Context, *, exclude_checked: bool = True
) -> Sequence[k8s.V1Pod]:
    """Return the already-running pods belonging to this job, if they exist.

    Polls ``list_namespaced_pod`` with the job's label selector until the
    expected number of pods (``self.parallelism``) is discovered or the retry
    budget (``self.discover_pods_retry_number``) is exhausted.

    :param pod_request_obj: pod request object; only ``metadata.namespace`` is read here.
    :param context: Airflow task context, forwarded to the label-selector builder.
    :param exclude_checked: forwarded to ``_build_find_pod_label_selector``.
    :return: the discovered pods (may be fewer than ``parallelism`` if retries ran out).
    :raises AirflowException: if no pod at all matches the label selector.
    """
    label_selector = self._build_find_pod_label_selector(context, exclude_checked=exclude_checked)
    pod_list: Sequence[k8s.V1Pod] = []
    retry_number: int = 0

    # BUGFIX: this loop previously combined its conditions with `or`, which
    # only terminates when BOTH are false — so it always kept polling after all
    # pods were found, and looped forever whenever the pod count never reached
    # `parallelism`. With `and` we stop as soon as every pod is discovered or
    # the retry budget is spent.
    while len(pod_list) != self.parallelism and retry_number <= self.discover_pods_retry_number:
        pod_list = self.client.list_namespaced_pod(
            namespace=pod_request_obj.metadata.namespace,
            label_selector=label_selector,
        ).items
        retry_number += 1

    if len(pod_list) == 0:
        raise AirflowException(f"No pods running with labels {label_selector}")

    for pod_instance in pod_list:
        self.log_matching_pod(pod=pod_instance, context=context)

    return pod_list
@property
def pod_name(self):
    """Deprecated: name of the first pod of the job; use ``pod_names`` instead."""
    warnings.warn(
        "`pod_name` parameter is deprecated, please use `pod_names`",
        AirflowProviderDeprecationWarning,
        stacklevel=2,
    )
    # BUGFIX: `_pod_name` is only assigned when the deprecated `pod_name`
    # constructor argument was supplied; previously this getter raised
    # AttributeError for triggers created with `pod_names` (the normal path).
    # Fall back to the first entry of `pod_names`.
    deprecated_name = getattr(self, "_pod_name", None)
    if deprecated_name is not None:
        return deprecated_name
    return self.pod_names[0] if self.pod_names else None
async def run(self) -> AsyncIterator[TriggerEvent]:
    """Get current job status and yield a TriggerEvent.

    When ``do_xcom_push`` is enabled, first waits for each pod's base
    container to finish, then reads the XCom sidecar of every pod (in
    ``pod_names`` order) before waiting for the job itself to complete.
    The emitted event carries the job dict, a status/message pair, and —
    depending on ``get_logs`` / ``do_xcom_push`` — the pod names/namespace
    and the list of per-pod xcom results.
    """
    if self.do_xcom_push:
        # Collected per pod, preserving the order of self.pod_names.
        xcom_results = []
        for pod_name in self.pod_names:
            pod = await self.hook.get_pod(name=pod_name, namespace=self.pod_namespace)
            # Base container must be done before the sidecar holds the result.
            await self.hook.wait_until_container_complete(
                name=pod_name, namespace=self.pod_namespace, container_name=self.base_container_name
            )
            self.log.info("Checking if xcom sidecar container is started.")
            await self.hook.wait_until_container_started(
                name=pod_name,
                namespace=self.pod_namespace,
                container_name=PodDefaults.SIDECAR_CONTAINER_NAME,
            )
            self.log.info("Extracting result from xcom sidecar container.")
            # extract_xcom is synchronous; run it in the default executor so
            # the event loop is not blocked.
            loop = asyncio.get_running_loop()
            xcom_result = await loop.run_in_executor(None, self.pod_manager.extract_xcom, pod)
            xcom_results.append(xcom_result)
    job: V1Job = await self.hook.wait_until_job_complete(name=self.job_name, namespace=self.job_namespace)
    job_dict = job.to_dict()
    error_message = self.hook.is_job_failed(job=job)
    # NOTE: `xcom_results` is only bound inside the do_xcom_push branch; the
    # payload below guards its use with the same flag, so this is safe.
    yield TriggerEvent(
        {
            "name": job.metadata.name,
            "namespace": job.metadata.namespace,
            "pod_names": [pod_name for pod_name in self.pod_names] if self.get_logs else None,
            "pod_namespace": self.pod_namespace if self.get_logs else None,
            "status": "error" if error_message else "success",
            "message": f"Job failed with error: {error_message}"
            if error_message
            else "Job completed successfully",
            "job": job_dict,
            "xcom_result": xcom_results if self.do_xcom_push else None,
        }
    )
@pytest.mark.non_db_test_override
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_pods"))
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj"))
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job"))
@patch(HOOK_CLASS)
def test_execute_with_parallelism(
    self, mock_hook, mock_create_job, mock_build_job_request_obj, mock_get_pods
):
    """execute() with parallelism set discovers pods via get_pods().

    Verifies the job request/creation flow, the xcom pushes, that `op.pods`
    holds the discovered pods, and that the deprecated `op.pod` accessor
    still exposes the first pod while emitting a deprecation warning.
    """
    # Job reported healthy so execute() completes without raising.
    mock_hook.return_value.is_job_failed.return_value = False
    mock_job_request_obj = mock_build_job_request_obj.return_value
    mock_job_expected = mock_create_job.return_value
    # Two fake pods simulate parallelism=2; get_pods replaces get_or_create_pod.
    mock_get_pods.return_value = [mock.MagicMock(), mock.MagicMock()]
    mock_pods_expected = mock_get_pods.return_value
    mock_ti = mock.MagicMock()
    context = dict(ti=mock_ti)

    op = KubernetesJobOperator(
        task_id="test_task_id",
        parallelism=2,
    )
    execute_result = op.execute(context=context)

    mock_build_job_request_obj.assert_called_once_with(context)
    mock_create_job.assert_called_once_with(job_request_obj=mock_job_request_obj)
    mock_ti.xcom_push.assert_has_calls(
        [
            mock.call(key="job_name", value=mock_job_expected.metadata.name),
            mock.call(key="job_namespace", value=mock_job_expected.metadata.namespace),
            mock.call(key="job", value=mock_job_expected.to_dict.return_value),
        ]
    )

    assert op.job_request_obj == mock_job_request_obj
    assert op.job == mock_job_expected
    assert op.pods == mock_pods_expected
    # Deprecated accessor still works, but must warn.
    with pytest.warns(AirflowProviderDeprecationWarning):
        assert op.pod is mock_pods_expected[0]
    assert not op.wait_until_job_complete
    assert execute_result is None
    assert not mock_hook.wait_until_job_complete.called
pytest.raises(AirflowException): + op.execute(context=dict(ti=mock.MagicMock())) @pytest.mark.non_db_test_override @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.defer")) @@ -616,6 +661,71 @@ def test_execute_deferrable(self, mock_trigger, mock_execute_deferrable): ) op.job = mock_job op.pod = mock_pod + op.pods = [ + mock_pod, + ] + + actual_result = op.execute_deferrable() + + mock_execute_deferrable.assert_called_once_with( + trigger=mock_trigger_instance, + method_name="execute_complete", + ) + mock_trigger.assert_called_once_with( + job_name=JOB_NAME, + job_namespace=JOB_NAMESPACE, + pod_names=[ + POD_NAME, + ], + pod_namespace=POD_NAMESPACE, + base_container_name=op.BASE_CONTAINER_NAME, + kubernetes_conn_id=KUBERNETES_CONN_ID, + cluster_context=mock_cluster_context, + config_file=mock_config_file, + in_cluster=mock_in_cluster, + poll_interval=POLL_INTERVAL, + get_logs=True, + do_xcom_push=False, + ) + assert actual_result is None + + @pytest.mark.non_db_test_override + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.defer")) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobTrigger")) + def test_execute_deferrable_with_parallelism(self, mock_trigger, mock_execute_deferrable): + mock_cluster_context = mock.MagicMock() + mock_config_file = mock.MagicMock() + mock_in_cluster = mock.MagicMock() + + mock_job = mock.MagicMock() + mock_job.metadata.name = JOB_NAME + mock_job.metadata.namespace = JOB_NAMESPACE + + pod_name_1 = POD_NAME + "-1" + mock_pod_1 = mock.MagicMock() + mock_pod_1.metadata.name = pod_name_1 + mock_pod_1.metadata.namespace = POD_NAMESPACE + + pod_name_2 = POD_NAME + "-2" + mock_pod_2 = mock.MagicMock() + mock_pod_2.metadata.name = pod_name_2 + mock_pod_2.metadata.namespace = POD_NAMESPACE + + mock_trigger_instance = mock_trigger.return_value + + op = KubernetesJobOperator( + task_id="test_task_id", + kubernetes_conn_id=KUBERNETES_CONN_ID, + cluster_context=mock_cluster_context, + config_file=mock_config_file, + 
in_cluster=mock_in_cluster, + job_poll_interval=POLL_INTERVAL, + parallelism=2, + wait_until_job_complete=True, + deferrable=True, + ) + op.job = mock_job + op.pods = [mock_pod_1, mock_pod_2] actual_result = op.execute_deferrable() @@ -626,7 +736,7 @@ def test_execute_deferrable(self, mock_trigger, mock_execute_deferrable): mock_trigger.assert_called_once_with( job_name=JOB_NAME, job_namespace=JOB_NAMESPACE, - pod_name=POD_NAME, + pod_names=[pod_name_1, pod_name_2], pod_namespace=POD_NAMESPACE, base_container_name=op.BASE_CONTAINER_NAME, kubernetes_conn_id=KUBERNETES_CONN_ID, @@ -656,7 +766,8 @@ def test_wait_until_job_complete( op = KubernetesJobOperator( task_id="test_task_id", wait_until_job_complete=True, job_poll_interval=POLL_INTERVAL ) - op.execute(context=dict(ti=mock_ti)) + with pytest.warns(AirflowProviderDeprecationWarning): + op.execute(context=dict(ti=mock_ti)) assert op.wait_until_job_complete assert op.job_poll_interval == POLL_INTERVAL @@ -676,9 +787,17 @@ def test_execute_complete(self, mocked_write_logs, get_logs, do_xcom_push): event = { "job": mock_job, "status": "success", - "pod_name": POD_NAME if get_logs else None, + "pod_names": [ + POD_NAME, + ] + if get_logs + else None, "pod_namespace": POD_NAMESPACE if get_logs else None, - "xcom_result": TEST_XCOM_RESULT if do_xcom_push else None, + "xcom_result": [ + TEST_XCOM_RESULT, + ] + if do_xcom_push + else None, } KubernetesJobOperator( @@ -718,6 +837,7 @@ def test_on_kill(self, mock_client): mock_client.delete_namespaced_job.assert_called_once_with( name=JOB_NAME, namespace=JOB_NAMESPACE, + propagation_policy=ON_KILL_PROPAGATION_POLICY, ) @pytest.mark.non_db_test_override @@ -737,6 +857,7 @@ def test_on_kill_termination_grace_period(self, mock_client): mock_client.delete_namespaced_job.assert_called_once_with( name=JOB_NAME, namespace=JOB_NAMESPACE, + propagation_policy=ON_KILL_PROPAGATION_POLICY, grace_period_seconds=mock_termination_grace_period, ) @@ -752,9 +873,11 @@ def 
test_on_kill_none_job(self, mock_hook, mock_client): mock_client.delete_namespaced_job.assert_not_called() mock_serialize.assert_not_called() + @pytest.mark.parametrize("parallelism", [None, 2]) @pytest.mark.parametrize("do_xcom_push", [True, False]) @pytest.mark.parametrize("get_logs", [True, False]) @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.extract_xcom")) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_pods")) @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_or_create_pod")) @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj")) @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job")) @@ -771,10 +894,16 @@ def test_execute_xcom_and_logs( mock_create_job, mock_build_job_request_obj, mock_get_or_create_pod, + mock_get_pods, mock_extract_xcom, get_logs, do_xcom_push, + parallelism, ): + if parallelism == 2: + mock_pod_1 = mock.MagicMock() + mock_pod_2 = mock.MagicMock() + mock_get_pods.return_value = [mock_pod_1, mock_pod_2] mock_ti = mock.MagicMock() op = KubernetesJobOperator( task_id="test_task_id", @@ -782,16 +911,26 @@ def test_execute_xcom_and_logs( job_poll_interval=POLL_INTERVAL, get_logs=get_logs, do_xcom_push=do_xcom_push, + parallelism=parallelism, ) - op.execute(context=dict(ti=mock_ti)) - if do_xcom_push: + if not parallelism: + with pytest.warns(AirflowProviderDeprecationWarning): + op.execute(context=dict(ti=mock_ti)) + else: + op.execute(context=dict(ti=mock_ti)) + + if do_xcom_push and not parallelism: mock_extract_xcom.assert_called_once() + elif do_xcom_push and parallelism is not None: + assert mock_extract_xcom.call_count == parallelism else: mock_extract_xcom.assert_not_called() - if get_logs: + if get_logs and not parallelism: mocked_fetch_logs.assert_called_once() + elif get_logs and parallelism is not None: + assert mocked_fetch_logs.call_count == parallelism else: mocked_fetch_logs.assert_not_called() diff --git 
a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_job.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_job.py index 7a85e0905369a..4f6f6597fcfcc 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_job.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_job.py @@ -45,7 +45,9 @@ def trigger(): return KubernetesJobTrigger( job_name=JOB_NAME, job_namespace=NAMESPACE, - pod_name=POD_NAME, + pod_names=[ + POD_NAME, + ], pod_namespace=NAMESPACE, base_container_name=CONTAINER_NAME, kubernetes_conn_id=CONN_ID, @@ -66,7 +68,9 @@ def test_serialize(self, trigger): assert kwargs_dict == { "job_name": JOB_NAME, "job_namespace": NAMESPACE, - "pod_name": POD_NAME, + "pod_names": [ + POD_NAME, + ], "pod_namespace": NAMESPACE, "base_container_name": CONTAINER_NAME, "kubernetes_conn_id": CONN_ID, @@ -105,7 +109,9 @@ async def test_run_success(self, mock_hook, trigger): { "name": JOB_NAME, "namespace": NAMESPACE, - "pod_name": POD_NAME, + "pod_names": [ + POD_NAME, + ], "pod_namespace": NAMESPACE, "status": "success", "message": "Job completed successfully", @@ -141,7 +147,9 @@ async def test_run_fail(self, mock_hook, trigger): { "name": JOB_NAME, "namespace": NAMESPACE, - "pod_name": POD_NAME, + "pod_names": [ + POD_NAME, + ], "pod_namespace": NAMESPACE, "status": "error", "message": "Job failed with error: Error", diff --git a/providers/google/src/airflow/providers/google/cloud/operators/kubernetes_engine.py b/providers/google/src/airflow/providers/google/cloud/operators/kubernetes_engine.py index fc044568a4cd2..b2303374a5256 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/kubernetes_engine.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/kubernetes_engine.py @@ -789,8 +789,8 @@ def execute_deferrable(self): ssl_ca_cert=self.ssl_ca_cert, job_name=self.job.metadata.name, job_namespace=self.job.metadata.namespace, - 
@property
def pod_name(self):
    """Deprecated: name of the first pod of the job; use ``pod_names`` instead."""
    warnings.warn(
        "`pod_name` parameter is deprecated, please use `pod_names`",
        AirflowProviderDeprecationWarning,
        stacklevel=2,
    )
    # BUGFIX: `_pod_name` is only assigned when the deprecated `pod_name`
    # constructor argument was supplied; previously this getter raised
    # AttributeError for triggers created with `pod_names` (the normal path).
    # Fall back to the first entry of `pod_names`, mirroring the cncf
    # KubernetesJobTrigger behavior.
    deprecated_name = getattr(self, "_pod_name", None)
    if deprecated_name is not None:
        return deprecated_name
    return self.pod_names[0] if self.pod_names else None
self.job_namespace, - "pod_name": self.pod_name, + "pod_names": self.pod_names, "pod_namespace": self.pod_namespace, "base_container_name": self.base_container_name, "gcp_conn_id": self.gcp_conn_id, @@ -305,8 +321,6 @@ def serialize(self) -> tuple[str, dict[str, Any]]: async def run(self) -> AsyncIterator[TriggerEvent]: """Get current job status and yield a TriggerEvent.""" - if self.get_logs or self.do_xcom_push: - pod = await self.hook.get_pod(name=self.pod_name, namespace=self.pod_namespace) if self.do_xcom_push: kubernetes_provider = ProvidersManager().providers["apache-airflow-providers-cncf-kubernetes"] kubernetes_provider_name = kubernetes_provider.data["package-name"] @@ -318,22 +332,26 @@ async def run(self) -> AsyncIterator[TriggerEvent]: f"package {kubernetes_provider_name}=={kubernetes_provider_version} which doesn't " f"support this feature. Please upgrade it to version higher than or equal to {min_version}." ) - await self.hook.wait_until_container_complete( - name=self.pod_name, - namespace=self.pod_namespace, - container_name=self.base_container_name, - poll_interval=self.poll_interval, - ) - self.log.info("Checking if xcom sidecar container is started.") - await self.hook.wait_until_container_started( - name=self.pod_name, - namespace=self.pod_namespace, - container_name=PodDefaults.SIDECAR_CONTAINER_NAME, - poll_interval=self.poll_interval, - ) - self.log.info("Extracting result from xcom sidecar container.") - loop = asyncio.get_running_loop() - xcom_result = await loop.run_in_executor(None, self.pod_manager.extract_xcom, pod) + xcom_results = [] + for pod_name in self.pod_names: + pod = await self.hook.get_pod(name=pod_name, namespace=self.pod_namespace) + await self.hook.wait_until_container_complete( + name=pod_name, + namespace=self.pod_namespace, + container_name=self.base_container_name, + poll_interval=self.poll_interval, + ) + self.log.info("Checking if xcom sidecar container is started.") + await self.hook.wait_until_container_started( + 
name=pod_name, + namespace=self.pod_namespace, + container_name=PodDefaults.SIDECAR_CONTAINER_NAME, + poll_interval=self.poll_interval, + ) + self.log.info("Extracting result from xcom sidecar container.") + loop = asyncio.get_running_loop() + xcom_result = await loop.run_in_executor(None, self.pod_manager.extract_xcom, pod) + xcom_results.append(xcom_result) job: V1Job = await self.hook.wait_until_job_complete( name=self.job_name, namespace=self.job_namespace, poll_interval=self.poll_interval ) @@ -345,12 +363,12 @@ async def run(self) -> AsyncIterator[TriggerEvent]: { "name": job.metadata.name, "namespace": job.metadata.namespace, - "pod_name": pod.metadata.name if self.get_logs else None, - "pod_namespace": pod.metadata.namespace if self.get_logs else None, + "pod_names": [pod_name for pod_name in self.pod_names] if self.get_logs else None, + "pod_namespace": self.pod_namespace if self.get_logs else None, "status": status, "message": message, "job": job_dict, - "xcom_result": xcom_result if self.do_xcom_push else None, + "xcom_result": xcom_results if self.do_xcom_push else None, } ) diff --git a/providers/google/tests/system/google/cloud/kubernetes_engine/example_kubernetes_engine_job.py b/providers/google/tests/system/google/cloud/kubernetes_engine/example_kubernetes_engine_job.py index 578a8e62b849b..ccf1fbabeaf54 100644 --- a/providers/google/tests/system/google/cloud/kubernetes_engine/example_kubernetes_engine_job.py +++ b/providers/google/tests/system/google/cloud/kubernetes_engine/example_kubernetes_engine_job.py @@ -48,8 +48,13 @@ JOB_NAME = "test-pi" JOB_NAME_DEF = "test-pi-def" +JOB_NAME_WITH_PARALLELISM = "test-pi-with-parallelism" +JOB_NAME_DEF_WITH_PARALLELISM = "test-pi-def-with-parallelism" JOB_NAMESPACE = "default" +PARALLELISM = 2 +COMPLETION_MODE = "Indexed" + with DAG( DAG_ID, schedule="@once", # Override to match your needs @@ -92,6 +97,45 @@ ) # [END howto_operator_gke_start_job_def] + # [START howto_operator_gke_start_job_parallelism] + 
job_task_with_parallelism = GKEStartJobOperator( + task_id="job_task_with_parallelism", + project_id=GCP_PROJECT_ID, + location=GCP_LOCATION, + cluster_name=CLUSTER_NAME, + namespace=JOB_NAMESPACE, + image="perl:5.34.0", + cmds=["perl", "-Mbignum=bpi", "-wle", "print bpi(2000)"], + name=JOB_NAME_WITH_PARALLELISM, + wait_until_job_complete=True, + parallelism=PARALLELISM, + completions=PARALLELISM, + completion_mode=COMPLETION_MODE, + get_logs=True, + do_xcom_push=True, + ) + # [END howto_operator_gke_start_job_with_parallelism] + + # [START howto_operator_gke_start_job_def_with_parallelism] + job_task_def_with_parallelism = GKEStartJobOperator( + task_id="job_task_def_with_parallelism", + project_id=GCP_PROJECT_ID, + location=GCP_LOCATION, + cluster_name=CLUSTER_NAME, + namespace=JOB_NAMESPACE, + image="perl:5.34.0", + cmds=["perl", "-Mbignum=bpi", "-wle", "print bpi(2000)"], + name=JOB_NAME_DEF_WITH_PARALLELISM, + wait_until_job_complete=True, + deferrable=True, + parallelism=PARALLELISM, + completions=PARALLELISM, + completion_mode=COMPLETION_MODE, + get_logs=True, + do_xcom_push=True, + ) + # [END howto_operator_gke_start_job_def_with_parallelism] + # [START howto_operator_gke_list_jobs] list_job_task = GKEListJobsOperator( task_id="list_job_task", project_id=GCP_PROJECT_ID, location=GCP_LOCATION, cluster_name=CLUSTER_NAME @@ -104,7 +148,7 @@ project_id=GCP_PROJECT_ID, location=GCP_LOCATION, job_name=job_task.output["job_name"], - namespace="default", + namespace=JOB_NAMESPACE, cluster_name=CLUSTER_NAME, ) # [END howto_operator_gke_describe_job] @@ -114,7 +158,25 @@ project_id=GCP_PROJECT_ID, location=GCP_LOCATION, job_name=job_task_def.output["job_name"], - namespace="default", + namespace=JOB_NAMESPACE, + cluster_name=CLUSTER_NAME, + ) + + describe_job_with_parallelism_task = GKEDescribeJobOperator( + task_id="describe_job_with_parallelism_task", + project_id=GCP_PROJECT_ID, + location=GCP_LOCATION, + job_name=job_task_with_parallelism.output["job_name"], + 
namespace=JOB_NAMESPACE, + cluster_name=CLUSTER_NAME, + ) + + describe_job_task_def_with_parallelism = GKEDescribeJobOperator( + task_id="describe_job_task_def_with_parallelism", + project_id=GCP_PROJECT_ID, + location=GCP_LOCATION, + job_name=job_task_def_with_parallelism.output["job_name"], + namespace=JOB_NAMESPACE, cluster_name=CLUSTER_NAME, ) @@ -125,7 +187,7 @@ location=GCP_LOCATION, cluster_name=CLUSTER_NAME, name=job_task.output["job_name"], - namespace="default", + namespace=JOB_NAMESPACE, ) # [END howto_operator_gke_suspend_job] @@ -136,7 +198,7 @@ location=GCP_LOCATION, cluster_name=CLUSTER_NAME, name=job_task.output["job_name"], - namespace="default", + namespace=JOB_NAMESPACE, ) # [END howto_operator_gke_resume_job] @@ -156,7 +218,25 @@ project_id=GCP_PROJECT_ID, location=GCP_LOCATION, cluster_name=CLUSTER_NAME, - name=JOB_NAME, + name=JOB_NAME_DEF, + namespace=JOB_NAMESPACE, + ) + + delete_job_with_parallelism = GKEDeleteJobOperator( + task_id="delete_job_with_parallelism", + project_id=GCP_PROJECT_ID, + location=GCP_LOCATION, + cluster_name=CLUSTER_NAME, + name=JOB_NAME_WITH_PARALLELISM, + namespace=JOB_NAMESPACE, + ) + + delete_job_def_with_parallelism = GKEDeleteJobOperator( + task_id="delete_job_def_with_parallelism", + project_id=GCP_PROJECT_ID, + location=GCP_LOCATION, + cluster_name=CLUSTER_NAME, + name=JOB_NAME_DEF_WITH_PARALLELISM, namespace=JOB_NAMESPACE, ) @@ -170,12 +250,17 @@ chain( create_cluster, - [job_task, job_task_def], + [job_task, job_task_def, job_task_with_parallelism, job_task_def_with_parallelism], list_job_task, - [describe_job_task, describe_job_task_def], + [ + describe_job_task, + describe_job_task_def, + describe_job_with_parallelism_task, + describe_job_task_def_with_parallelism, + ], suspend_job, resume_job, - [delete_job, delete_job_def], + [delete_job, delete_job_def, delete_job_with_parallelism, delete_job_def_with_parallelism], delete_cluster, ) diff --git 
a/providers/google/tests/unit/google/cloud/operators/test_kubernetes_engine.py b/providers/google/tests/unit/google/cloud/operators/test_kubernetes_engine.py index a8fafd6af5dc1..3c3775025e1cd 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_kubernetes_engine.py +++ b/providers/google/tests/unit/google/cloud/operators/test_kubernetes_engine.py @@ -862,7 +862,9 @@ def test_execute_deferrable(self, mock_trigger, mock_cluster_hook, mock_fetch_cl mock_pod_metadata = mock.MagicMock() mock_pod_metadata.name = K8S_POD_NAME mock_pod_metadata.namespace = K8S_NAMESPACE - self.operator.pod = mock.MagicMock(metadata=mock_pod_metadata) + self.operator.pods = [ + mock.MagicMock(metadata=mock_pod_metadata), + ] mock_job_metadata = mock.MagicMock() mock_job_metadata.name = K8S_JOB_NAME @@ -880,7 +882,69 @@ def test_execute_deferrable(self, mock_trigger, mock_cluster_hook, mock_fetch_cl ssl_ca_cert=GKE_SSL_CA_CERT, job_name=K8S_JOB_NAME, job_namespace=K8S_NAMESPACE, - pod_name=K8S_POD_NAME, + pod_names=[ + K8S_POD_NAME, + ], + pod_namespace=K8S_NAMESPACE, + base_container_name="base", + gcp_conn_id=TEST_CONN_ID, + poll_interval=10.0, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + get_logs=mock_get_logs, + do_xcom_push=False, + ) + mock_defer.assert_called_once_with( + trigger=mock_trigger.return_value, + method_name="execute_complete", + ) + + @mock.patch(GKE_OPERATORS_PATH.format("GKEStartJobOperator.defer")) + @mock.patch(GKE_OPERATORS_PATH.format("GKEClusterAuthDetails.fetch_cluster_info")) + @mock.patch(GKE_OPERATORS_PATH.format("GKEHook")) + @mock.patch(GKE_OPERATORS_PATH.format("GKEJobTrigger")) + def test_execute_deferrable_with_parallelism( + self, mock_trigger, mock_cluster_hook, mock_fetch_cluster_info, mock_defer + ): + op = GKEStartJobOperator( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_name=GKE_CLUSTER_NAME, + task_id=TEST_TASK_ID, + name=K8S_JOB_NAME, + namespace=K8S_NAMESPACE, + image=TEST_IMAGE, + 
gcp_conn_id=TEST_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + parallelism=2, + ) + mock_pod_name_1 = K8S_POD_NAME + "-1" + mock_pod_metadata_1 = mock.MagicMock() + mock_pod_metadata_1.name = mock_pod_name_1 + mock_pod_metadata_1.namespace = K8S_NAMESPACE + + mock_pod_name_2 = K8S_POD_NAME + "-2" + mock_pod_metadata_2 = mock.MagicMock() + mock_pod_metadata_2.name = mock_pod_name_2 + mock_pod_metadata_2.namespace = K8S_NAMESPACE + op.pods = [mock.MagicMock(metadata=mock_pod_metadata_1), mock.MagicMock(metadata=mock_pod_metadata_2)] + + mock_job_metadata = mock.MagicMock() + mock_job_metadata.name = K8S_JOB_NAME + mock_job_metadata.namespace = K8S_NAMESPACE + op.job = mock.MagicMock(metadata=mock_job_metadata) + + mock_fetch_cluster_info.return_value = GKE_CLUSTER_URL, GKE_SSL_CA_CERT + mock_get_logs = mock.MagicMock() + op.get_logs = mock_get_logs + + op.execute_deferrable() + + mock_trigger.assert_called_once_with( + cluster_url=GKE_CLUSTER_URL, + ssl_ca_cert=GKE_SSL_CA_CERT, + job_name=K8S_JOB_NAME, + job_namespace=K8S_NAMESPACE, + pod_names=[mock_pod_name_1, mock_pod_name_2], pod_namespace=K8S_NAMESPACE, base_container_name="base", gcp_conn_id=TEST_CONN_ID, diff --git a/providers/google/tests/unit/google/cloud/triggers/test_kubernetes_engine.py b/providers/google/tests/unit/google/cloud/triggers/test_kubernetes_engine.py index ffe749ef397bb..3704c5c320624 100644 --- a/providers/google/tests/unit/google/cloud/triggers/test_kubernetes_engine.py +++ b/providers/google/tests/unit/google/cloud/triggers/test_kubernetes_engine.py @@ -98,7 +98,9 @@ def job_trigger(): ssl_ca_cert=SSL_CA_CERT, job_name=JOB_NAME, job_namespace=NAMESPACE, - pod_name=POD_NAME, + pod_names=[ + POD_NAME, + ], pod_namespace=NAMESPACE, base_container_name=BASE_CONTAINER_NAME, gcp_conn_id=GCP_CONN_ID, @@ -482,7 +484,9 @@ def test_serialize(self, job_trigger): "ssl_ca_cert": SSL_CA_CERT, "job_name": JOB_NAME, "job_namespace": NAMESPACE, - "pod_name": POD_NAME, + "pod_names": [ + 
POD_NAME, + ], "pod_namespace": NAMESPACE, "base_container_name": BASE_CONTAINER_NAME, "gcp_conn_id": GCP_CONN_ID, @@ -521,7 +525,9 @@ async def test_run_success(self, mock_hook, job_trigger): { "name": JOB_NAME, "namespace": NAMESPACE, - "pod_name": POD_NAME, + "pod_names": [ + POD_NAME, + ], "pod_namespace": NAMESPACE, "status": "success", "message": "Job completed successfully", @@ -559,7 +565,9 @@ async def test_run_fail(self, mock_hook, job_trigger): { "name": JOB_NAME, "namespace": NAMESPACE, - "pod_name": POD_NAME, + "pod_names": [ + POD_NAME, + ], "pod_namespace": NAMESPACE, "status": "error", "message": "Job failed with error: Error",