Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,17 @@
import json
import logging
import os
import warnings
from collections.abc import Sequence
from functools import cached_property
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Literal

from kubernetes.client import BatchV1Api, models as k8s
from kubernetes.client.api_client import ApiClient
from kubernetes.client.rest import ApiException

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook
from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import (
add_unique_suffix,
Expand Down Expand Up @@ -81,6 +82,17 @@ class KubernetesJobOperator(KubernetesPodOperator):
Used if the parameter `wait_until_job_complete` set True.
:param deferrable: Run operator in the deferrable mode. Note that the parameter
`wait_until_job_complete` must be set True.
:param on_kill_propagation_policy: Whether and how garbage collection will be performed when the job is killed.
Acceptable values are:
'Orphan' - orphan the dependents;
'Background' - allow the garbage collector to delete the dependents in the background;
'Foreground' - a cascading policy that deletes all dependents in the foreground.
Default value is 'Foreground'.
:param discover_pods_retry_number: Number of times list_namespaced_pod will be performed to discover
already running pods.
:param unwrap_single: Unwrap single result from the pod. For example, when set to `True` - if the XCom
result should be `['res']`, the final result would be `'res'`. Default is True to support backward
compatibility.
"""

template_fields: Sequence[str] = tuple({"job_template_file"} | set(KubernetesPodOperator.template_fields))
Expand All @@ -101,8 +113,12 @@ def __init__(
wait_until_job_complete: bool = False,
job_poll_interval: float = 10,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
on_kill_propagation_policy: Literal["Foreground", "Background", "Orphan"] = "Foreground",
discover_pods_retry_number: int = 3,
unwrap_single: bool = True,
**kwargs,
) -> None:
self._pod = None
super().__init__(**kwargs)
self.job_template_file = job_template_file
self.full_job_spec = full_job_spec
Expand All @@ -119,6 +135,22 @@ def __init__(
self.wait_until_job_complete = wait_until_job_complete
self.job_poll_interval = job_poll_interval
self.deferrable = deferrable
self.on_kill_propagation_policy = on_kill_propagation_policy
self.discover_pods_retry_number = discover_pods_retry_number
self.unwrap_single = unwrap_single

@property
def pod(self):
    """Deprecated single-pod accessor; use ``pods`` instead.

    Returns the first discovered pod when ``pods`` is populated, otherwise
    falls back to the value stored by the setter (``self._pod``) so that a
    legacy assignment ``op.pod = x`` is still observable via ``op.pod``.
    """
    warnings.warn(
        "`pod` parameter is deprecated, please use `pods`",
        AirflowProviderDeprecationWarning,
        stacklevel=2,
    )
    # `pods` is only assigned inside `execute`; use getattr so the property
    # is safe to read before execution as well.
    pods = getattr(self, "pods", None)
    if pods:
        return pods[0]
    # Fall back to the explicitly-assigned pod so the setter is not a no-op
    # (previously the getter ignored `_pod`, making `op.pod = x; op.pod`
    # return None).
    return self._pod

@pod.setter
def pod(self, value):
    # Kept for backward compatibility with callers that assigned `self.pod`.
    self._pod = value

@cached_property
def _incluster_namespace(self):
Expand Down Expand Up @@ -167,35 +199,42 @@ def execute(self, context: Context):
ti.xcom_push(key="job_name", value=self.job.metadata.name)
ti.xcom_push(key="job_namespace", value=self.job.metadata.namespace)

self.pod: k8s.V1Pod | None
if self.pod is None:
self.pod = self.get_or_create_pod( # must set `self.pod` for `on_kill`
pod_request_obj=self.pod_request_obj,
context=context,
)
self.pods: Sequence[k8s.V1Pod] | None = None
if self.parallelism is None and self.pod is None:
self.pods = [
self.get_or_create_pod(
pod_request_obj=self.pod_request_obj,
context=context,
)
]
else:
self.pods = self.get_pods(pod_request_obj=self.pod_request_obj, context=context)

if self.wait_until_job_complete and self.deferrable:
self.execute_deferrable()
return

if self.wait_until_job_complete:
if self.do_xcom_push:
self.pod_manager.await_container_completion(
pod=self.pod, container_name=self.base_container_name
)
self.pod_manager.await_xcom_sidecar_container_start(pod=self.pod)
xcom_result = self.extract_xcom(pod=self.pod)
xcom_result = []
for pod in self.pods:
self.pod_manager.await_container_completion(
pod=pod, container_name=self.base_container_name
)
self.pod_manager.await_xcom_sidecar_container_start(pod=pod)
xcom_result.append(self.extract_xcom(pod=pod))
self.job = self.hook.wait_until_job_complete(
job_name=self.job.metadata.name,
namespace=self.job.metadata.namespace,
job_poll_interval=self.job_poll_interval,
)
if self.get_logs:
self.pod_manager.fetch_requested_container_logs(
pod=self.pod,
containers=self.container_logs,
follow_logs=True,
)
for pod in self.pods:
self.pod_manager.fetch_requested_container_logs(
pod=pod,
containers=self.container_logs,
follow_logs=True,
)

ti.xcom_push(key="job", value=self.job.to_dict())
if self.wait_until_job_complete:
Expand All @@ -211,8 +250,8 @@ def execute_deferrable(self):
trigger=KubernetesJobTrigger(
job_name=self.job.metadata.name,
job_namespace=self.job.metadata.namespace,
pod_name=self.pod.metadata.name,
pod_namespace=self.pod.metadata.namespace,
pod_names=[pod.metadata.name for pod in self.pods],
pod_namespace=self.pods[0].metadata.namespace,
base_container_name=self.base_container_name,
kubernetes_conn_id=self.kubernetes_conn_id,
cluster_context=self.cluster_context,
Expand All @@ -232,20 +271,23 @@ def execute_complete(self, context: Context, event: dict, **kwargs):
raise AirflowException(event["message"])

if self.get_logs:
pod_name = event["pod_name"]
pod_namespace = event["pod_namespace"]
self.pod = self.hook.get_pod(pod_name, pod_namespace)
if not self.pod:
raise PodNotFoundException("Could not find pod after resuming from deferral")
self._write_logs(self.pod)
for pod_name in event["pod_names"]:
pod_namespace = event["pod_namespace"]
pod = self.hook.get_pod(pod_name, pod_namespace)
if not pod:
raise PodNotFoundException("Could not find pod after resuming from deferral")
self._write_logs(pod)

if self.do_xcom_push:
xcom_result = event["xcom_result"]
if isinstance(xcom_result, str) and xcom_result.rstrip() == EMPTY_XCOM_RESULT:
self.log.info("xcom result file is empty.")
return None
self.log.info("xcom result: \n%s", xcom_result)
return json.loads(xcom_result)
xcom_results: list[Any | None] = []
for xcom_result in event["xcom_result"]:
if isinstance(xcom_result, str) and xcom_result.rstrip() == EMPTY_XCOM_RESULT:
self.log.info("xcom result file is empty.")
xcom_results.append(None)
continue
self.log.info("xcom result: \n%s", xcom_result)
xcom_results.append(json.loads(xcom_result))
return xcom_results[0] if self.unwrap_single and len(xcom_results) == 1 else xcom_results

@staticmethod
def deserialize_job_template_file(path: str) -> k8s.V1Job:
Expand Down Expand Up @@ -275,12 +317,11 @@ def on_kill(self) -> None:
kwargs = {
"name": job.metadata.name,
"namespace": job.metadata.namespace,
"propagation_policy": self.on_kill_propagation_policy,
}
if self.termination_grace_period is not None:
kwargs.update(grace_period_seconds=self.termination_grace_period)
self.job_client.delete_namespaced_job(**kwargs)
if self.pod:
super().on_kill()

def build_job_request_obj(self, context: Context | None = None) -> k8s.V1Job:
"""
Expand Down Expand Up @@ -400,6 +441,29 @@ def reconcile_job_specs(

return None

def get_pods(
    self, pod_request_obj: k8s.V1Pod, context: Context, *, exclude_checked: bool = True
) -> Sequence[k8s.V1Pod]:
    """Discover the already-running pods belonging to this job.

    Lists pods in the request object's namespace matching the operator's
    label selector, retrying up to ``discover_pods_retry_number`` times
    until the number of discovered pods equals ``parallelism``.

    :param pod_request_obj: Pod request object; only its namespace is used for the lookup.
    :param context: Airflow task context, used to build the label selector.
    :param exclude_checked: Exclude pods already labeled as checked.
    :raises AirflowException: If no pods matching the labels are found.
    """
    label_selector = self._build_find_pod_label_selector(context, exclude_checked=exclude_checked)
    pod_list: Sequence[k8s.V1Pod] = []
    retry_number = 0

    # Retry while attempts remain AND the expected pod count has not been
    # reached. (`and`, not `or`: with `or` the loop never terminates when
    # the pod count never reaches `parallelism`, and always exhausts every
    # retry even when the first listing already succeeded.)
    while len(pod_list) != self.parallelism and retry_number <= self.discover_pods_retry_number:
        pod_list = self.client.list_namespaced_pod(
            namespace=pod_request_obj.metadata.namespace,
            label_selector=label_selector,
        ).items
        retry_number += 1

    if len(pod_list) == 0:
        raise AirflowException(f"No pods running with labels {label_selector}")

    for pod_instance in pod_list:
        self.log_matching_pod(pod=pod_instance, context=context)

    return pod_list


class KubernetesDeleteJobOperator(BaseOperator):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
from __future__ import annotations

import asyncio
import warnings
from collections.abc import AsyncIterator
from functools import cached_property
from typing import TYPE_CHECKING, Any

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.cncf.kubernetes.hooks.kubernetes import AsyncKubernetesHook, KubernetesHook
from airflow.providers.cncf.kubernetes.utils.pod_manager import PodManager
from airflow.providers.cncf.kubernetes.utils.xcom_sidecar import PodDefaults
Expand All @@ -36,7 +38,8 @@ class KubernetesJobTrigger(BaseTrigger):

:param job_name: The name of the job.
:param job_namespace: The namespace of the job.
:param pod_name: The name of the Pod.
:param pod_name: The name of the Pod. Parameter is deprecated, please use pod_names instead.
:param pod_names: The names of the Pods.
:param pod_namespace: The namespace of the Pod.
:param base_container_name: The name of the base container in the pod.
:param kubernetes_conn_id: The :ref:`kubernetes connection id <howto/connection:kubernetes>`
Expand All @@ -55,9 +58,10 @@ def __init__(
self,
job_name: str,
job_namespace: str,
pod_name: str,
pod_names: list[str],
pod_namespace: str,
base_container_name: str,
pod_name: str | None = None,
kubernetes_conn_id: str | None = None,
poll_interval: float = 10.0,
cluster_context: str | None = None,
Expand All @@ -69,7 +73,13 @@ def __init__(
super().__init__()
self.job_name = job_name
self.job_namespace = job_namespace
self.pod_name = pod_name
if pod_name is not None:
self._pod_name = pod_name
self.pod_names = [
self.pod_name,
]
else:
self.pod_names = pod_names
self.pod_namespace = pod_namespace
self.base_container_name = base_container_name
self.kubernetes_conn_id = kubernetes_conn_id
Expand All @@ -80,14 +90,23 @@ def __init__(
self.get_logs = get_logs
self.do_xcom_push = do_xcom_push

@property
def pod_name(self):
    """Deprecated accessor for the legacy single pod name; use ``pod_names``.

    Only set when the trigger was constructed with the deprecated
    ``pod_name`` argument.
    """
    deprecation_message = "`pod_name` parameter is deprecated, please use `pod_names`"
    warnings.warn(deprecation_message, AirflowProviderDeprecationWarning, stacklevel=2)
    return self._pod_name

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize KubernetesCreateJobTrigger arguments and classpath."""
return (
"airflow.providers.cncf.kubernetes.triggers.job.KubernetesJobTrigger",
{
"job_name": self.job_name,
"job_namespace": self.job_namespace,
"pod_name": self.pod_name,
"pod_names": self.pod_names,
"pod_namespace": self.pod_namespace,
"base_container_name": self.base_container_name,
"kubernetes_conn_id": self.kubernetes_conn_id,
Expand All @@ -102,36 +121,38 @@ def serialize(self) -> tuple[str, dict[str, Any]]:

async def run(self) -> AsyncIterator[TriggerEvent]:
    """Get current job status and yield a TriggerEvent.

    When ``do_xcom_push`` is enabled, first waits for each pod's base
    container to complete and extracts the XCom result from its sidecar
    container; then waits for the job itself to finish and yields a single
    event describing the outcome.
    """
    xcom_results: list[Any] = []
    if self.do_xcom_push:
        # The xcom sidecar container only exists when do_xcom_push is
        # enabled; waiting for it unconditionally would block forever on
        # jobs launched without the sidecar.
        for pod_name in self.pod_names:
            pod = await self.hook.get_pod(name=pod_name, namespace=self.pod_namespace)
            await self.hook.wait_until_container_complete(
                name=pod_name, namespace=self.pod_namespace, container_name=self.base_container_name
            )
            self.log.info("Checking if xcom sidecar container is started.")
            await self.hook.wait_until_container_started(
                name=pod_name,
                namespace=self.pod_namespace,
                container_name=PodDefaults.SIDECAR_CONTAINER_NAME,
            )
            self.log.info("Extracting result from xcom sidecar container.")
            loop = asyncio.get_running_loop()
            # extract_xcom is blocking; run it in the default executor.
            xcom_result = await loop.run_in_executor(None, self.pod_manager.extract_xcom, pod)
            xcom_results.append(xcom_result)
    job: V1Job = await self.hook.wait_until_job_complete(name=self.job_name, namespace=self.job_namespace)
    job_dict = job.to_dict()
    error_message = self.hook.is_job_failed(job=job)
    yield TriggerEvent(
        {
            "name": job.metadata.name,
            "namespace": job.metadata.namespace,
            "pod_names": list(self.pod_names) if self.get_logs else None,
            "pod_namespace": self.pod_namespace if self.get_logs else None,
            "status": "error" if error_message else "success",
            "message": f"Job failed with error: {error_message}"
            if error_message
            else "Job completed successfully",
            "job": job_dict,
            "xcom_result": xcom_results if self.do_xcom_push else None,
        }
    )

Expand Down
Loading