40 changes: 39 additions & 1 deletion airflow/providers/cncf/kubernetes/hooks/kubernetes.py
@@ -37,7 +37,11 @@
from airflow.models import Connection
from airflow.providers.cncf.kubernetes.kube_client import _disable_verify_ssl, _enable_tcp_keepalive
from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import should_retry_creation
from airflow.providers.cncf.kubernetes.utils.pod_manager import PodOperatorHookProtocol
from airflow.providers.cncf.kubernetes.utils.pod_manager import (
PodOperatorHookProtocol,
container_is_completed,
container_is_running,
)
from airflow.utils import yaml

if TYPE_CHECKING:
@@ -834,3 +838,37 @@ async def wait_until_job_complete(self, name: str, namespace: str, poll_interval
return job
self.log.info("The job '%s' is incomplete. Sleeping for %i sec.", name, poll_interval)
await asyncio.sleep(poll_interval)

async def wait_until_container_complete(
self, name: str, namespace: str, container_name: str, poll_interval: float = 10
) -> None:
"""
Wait for the given container in the given pod to be completed.

:param name: Name of the Pod to fetch.
:param namespace: Namespace of the Pod.
:param container_name: Name of the container within the Pod to monitor.
:param poll_interval: Polling period in seconds to check the container state.
"""
while True:
pod = await self.get_pod(name=name, namespace=namespace)
if container_is_completed(pod=pod, container_name=container_name):
break
self.log.info("Waiting for container '%s' state to be completed", container_name)
await asyncio.sleep(poll_interval)

async def wait_until_container_started(
self, name: str, namespace: str, container_name: str, poll_interval: float = 10
) -> None:
"""
Wait for the given container in the given pod to be started.

:param name: Name of the Pod to fetch.
:param namespace: Namespace of the Pod.
:param container_name: Name of the container within the Pod to monitor.
:param poll_interval: Polling period in seconds to check the container state.
"""
while True:
pod = await self.get_pod(name=name, namespace=namespace)
if container_is_running(pod=pod, container_name=container_name):
break
self.log.info("Waiting for container '%s' state to be running", container_name)
await asyncio.sleep(poll_interval)
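
For reference, a minimal sketch of how these new async waiters can be driven from other asynchronous code, mirroring the trigger changes further down in this diff. The connection id, pod name, namespace, and container name are placeholders; only wait_until_container_complete, wait_until_container_started, and PodDefaults.SIDECAR_CONTAINER_NAME come from this PR and the provider package.

import asyncio

from airflow.providers.cncf.kubernetes.hooks.kubernetes import AsyncKubernetesHook
from airflow.providers.cncf.kubernetes.utils.xcom_sidecar import PodDefaults


async def wait_for_xcom_sidecar() -> None:
    # Placeholder connection/pod names; the flow mirrors KubernetesJobTrigger.run():
    # wait for the base container to finish, then for the xcom sidecar to be running.
    hook = AsyncKubernetesHook(conn_id="kubernetes_default")
    await hook.wait_until_container_complete(
        name="example-pod", namespace="default", container_name="base", poll_interval=5
    )
    await hook.wait_until_container_started(
        name="example-pod",
        namespace="default",
        container_name=PodDefaults.SIDECAR_CONTAINER_NAME,
        poll_interval=5,
    )


asyncio.run(wait_for_xcom_sidecar())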
55 changes: 53 additions & 2 deletions airflow/providers/cncf/kubernetes/operators/job.py
@@ -19,6 +19,7 @@
from __future__ import annotations

import copy
import json
import logging
import os
from functools import cached_property
@@ -39,6 +40,7 @@
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator, merge_objects
from airflow.providers.cncf.kubernetes.triggers.job import KubernetesJobTrigger
from airflow.providers.cncf.kubernetes.utils.pod_manager import EMPTY_XCOM_RESULT, PodNotFoundException
from airflow.utils import yaml
from airflow.utils.context import Context

@@ -135,7 +137,7 @@ def hook(self) -> KubernetesHook:
return hook

@cached_property
def client(self) -> BatchV1Api:
def job_client(self) -> BatchV1Api:
return self.hook.batch_v1_client

def create_job(self, job_request_obj: k8s.V1Job) -> k8s.V1Job:
@@ -150,6 +152,11 @@ def execute(self, context: Context):
"Deferrable mode is available only with parameter `wait_until_job_complete=True`. "
"Please, set it up."
)
if (self.get_logs or self.do_xcom_push) and not self.wait_until_job_complete:
self.log.warning(
"Getting Logs and pushing to XCom are available only with parameter `wait_until_job_complete=True`. "
"Please, set it up."
)
self.job_request_obj = self.build_job_request_obj(context)
self.job = self.create_job( # must set `self.job` for `on_kill`
job_request_obj=self.job_request_obj
@@ -159,34 +166,59 @@
ti.xcom_push(key="job_name", value=self.job.metadata.name)
ti.xcom_push(key="job_namespace", value=self.job.metadata.namespace)

if self.pod is None:
self.pod = self.get_or_create_pod( # must set `self.pod` for `on_kill`
pod_request_obj=self.pod_request_obj,
context=context,
)
Comment on lines +169 to +173

@NilsIrl (Nov 4, 2024):
How does this work if the pod wasn't already created by the time this is reached? Won't two pods eventually get created (when the job eventually creates the pod)?

Reply:
I have the same question. I thought the K8s Job would manage the pod creation, so why would we need to create the pod manually again? I'm actually running into bugs where I set parallelism to more than 1 and this line triggers self.find_pods(), which raises:

raise AirflowException(f"More than one pod running with labels {label_selector}")

Contributor:
@steinwaywhw can you explain your problem more? Would you like to take a stab at fixing it?

Contributor Author:
Hello @steinwaywhw,
It's a known issue (#44994) and the fix was already prepared by PR #49899.


if self.wait_until_job_complete and self.deferrable:
self.execute_deferrable()
return

if self.wait_until_job_complete:
if self.do_xcom_push:
self.pod_manager.await_container_completion(
pod=self.pod, container_name=self.base_container_name
)
self.pod_manager.await_xcom_sidecar_container_start(pod=self.pod)
xcom_result = self.extract_xcom(pod=self.pod)
self.job = self.hook.wait_until_job_complete(
job_name=self.job.metadata.name,
namespace=self.job.metadata.namespace,
job_poll_interval=self.job_poll_interval,
)
if self.get_logs:
self.pod_manager.fetch_requested_container_logs(
pod=self.pod,
containers=self.container_logs,
follow_logs=True,
)

ti.xcom_push(key="job", value=self.job.to_dict())
if self.wait_until_job_complete:
if error_message := self.hook.is_job_failed(job=self.job):
raise AirflowException(
f"Kubernetes job '{self.job.metadata.name}' is failed with error '{error_message}'"
)
if self.do_xcom_push:
return xcom_result

def execute_deferrable(self):
self.defer(
trigger=KubernetesJobTrigger(
job_name=self.job.metadata.name, # type: ignore[union-attr]
job_namespace=self.job.metadata.namespace, # type: ignore[union-attr]
pod_name=self.pod.metadata.name, # type: ignore[union-attr]
pod_namespace=self.pod.metadata.namespace, # type: ignore[union-attr]
base_container_name=self.base_container_name,
kubernetes_conn_id=self.kubernetes_conn_id,
cluster_context=self.cluster_context,
config_file=self.config_file,
in_cluster=self.in_cluster,
poll_interval=self.job_poll_interval,
get_logs=self.get_logs,
do_xcom_push=self.do_xcom_push,
),
method_name="execute_complete",
)
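
As a usage illustration of the combined behavior above, a minimal task definition; the image, command, and names are placeholders, and the arguments shown (wait_until_job_complete, get_logs, do_xcom_push, deferrable) are the ones this diff wires together.

from airflow.providers.cncf.kubernetes.operators.job import KubernetesJobOperator

# Placeholder DAG task: the command writes /airflow/xcom/return.json, which the
# operator extracts via the xcom sidecar once the base container has completed.
job_with_xcom = KubernetesJobOperator(
    task_id="job_with_xcom",
    name="job-with-xcom",
    namespace="default",
    image="alpine:3.19",
    cmds=["sh", "-c", "mkdir -p /airflow/xcom && echo '[1, 2, 3]' > /airflow/xcom/return.json"],
    wait_until_job_complete=True,
    get_logs=True,
    do_xcom_push=True,
    deferrable=True,
)

With deferrable=True the KubernetesJobTrigger returns the pod name and xcom payload in the trigger event and execute_complete (next hunk) parses it with json.loads; without deferral the same log streaming and xcom extraction happen inline in execute().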
@@ -197,6 +229,22 @@ def execute_complete(self, context: Context, event: dict, **kwargs):
if event["status"] == "error":
raise AirflowException(event["message"])

if self.get_logs:
pod_name = event["pod_name"]
pod_namespace = event["pod_namespace"]
self.pod = self.hook.get_pod(pod_name, pod_namespace)
if not self.pod:
raise PodNotFoundException("Could not find pod after resuming from deferral")
self._write_logs(self.pod)

if self.do_xcom_push:
xcom_result = event["xcom_result"]
if isinstance(xcom_result, str) and xcom_result.rstrip() == EMPTY_XCOM_RESULT:
self.log.info("xcom result file is empty.")
return None
self.log.info("xcom result: \n%s", xcom_result)
return json.loads(xcom_result)

@staticmethod
def deserialize_job_template_file(path: str) -> k8s.V1Job:
"""
Expand Down Expand Up @@ -229,7 +277,9 @@ def on_kill(self) -> None:
}
if self.termination_grace_period is not None:
kwargs.update(grace_period_seconds=self.termination_grace_period)
self.client.delete_namespaced_job(**kwargs)
self.job_client.delete_namespaced_job(**kwargs)
if self.pod:
super().on_kill()

def build_job_request_obj(self, context: Context | None = None) -> k8s.V1Job:
"""
@@ -254,6 +304,7 @@ def build_job_request_obj(self, context: Context | None = None) -> k8s.V1Job:
metadata=pod_template.metadata,
spec=pod_template.spec,
)
self.pod_request_obj = pod_template

job = k8s.V1Job(
api_version="batch/v1",
55 changes: 54 additions & 1 deletion airflow/providers/cncf/kubernetes/triggers/job.py
@@ -16,10 +16,13 @@
# under the License.
from __future__ import annotations

import asyncio
from functools import cached_property
from typing import TYPE_CHECKING, Any, AsyncIterator

from airflow.providers.cncf.kubernetes.hooks.kubernetes import AsyncKubernetesHook
from airflow.providers.cncf.kubernetes.hooks.kubernetes import AsyncKubernetesHook, KubernetesHook
from airflow.providers.cncf.kubernetes.utils.pod_manager import PodManager
from airflow.providers.cncf.kubernetes.utils.xcom_sidecar import PodDefaults
from airflow.triggers.base import BaseTrigger, TriggerEvent

if TYPE_CHECKING:
@@ -32,32 +35,49 @@ class KubernetesJobTrigger(BaseTrigger):

:param job_name: The name of the job.
:param job_namespace: The namespace of the job.
:param pod_name: The name of the Pod.
:param pod_namespace: The namespace of the Pod.
:param base_container_name: The name of the base container in the pod.
:param kubernetes_conn_id: The :ref:`kubernetes connection id <howto/connection:kubernetes>`
for the Kubernetes cluster.
:param cluster_context: Context that points to kubernetes cluster.
:param config_file: Path to kubeconfig file.
:param poll_interval: Polling period in seconds to check for the status.
:param in_cluster: run kubernetes client with in_cluster configuration.
:param get_logs: get the stdout of the base container as logs of the tasks.
:param do_xcom_push: If True, the content of the file
/airflow/xcom/return.json in the container will also be pushed to an
XCom when the container completes.
"""

def __init__(
self,
job_name: str,
job_namespace: str,
pod_name: str,
pod_namespace: str,
base_container_name: str,
kubernetes_conn_id: str | None = None,
poll_interval: float = 10.0,
cluster_context: str | None = None,
config_file: str | None = None,
in_cluster: bool | None = None,
get_logs: bool = True,
do_xcom_push: bool = False,
):
super().__init__()
self.job_name = job_name
self.job_namespace = job_namespace
self.pod_name = pod_name
self.pod_namespace = pod_namespace
self.base_container_name = base_container_name
self.kubernetes_conn_id = kubernetes_conn_id
self.poll_interval = poll_interval
self.cluster_context = cluster_context
self.config_file = config_file
self.in_cluster = in_cluster
self.get_logs = get_logs
self.do_xcom_push = do_xcom_push

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize KubernetesCreateJobTrigger arguments and classpath."""
@@ -66,28 +86,51 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
{
"job_name": self.job_name,
"job_namespace": self.job_namespace,
"pod_name": self.pod_name,
"pod_namespace": self.pod_namespace,
"base_container_name": self.base_container_name,
"kubernetes_conn_id": self.kubernetes_conn_id,
"poll_interval": self.poll_interval,
"cluster_context": self.cluster_context,
"config_file": self.config_file,
"in_cluster": self.in_cluster,
"get_logs": self.get_logs,
"do_xcom_push": self.do_xcom_push,
},
)

async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
"""Get current job status and yield a TriggerEvent."""
if self.get_logs or self.do_xcom_push:
pod = await self.hook.get_pod(name=self.pod_name, namespace=self.pod_namespace)
if self.do_xcom_push:
await self.hook.wait_until_container_complete(
name=self.pod_name, namespace=self.pod_namespace, container_name=self.base_container_name
)
self.log.info("Checking if xcom sidecar container is started.")
await self.hook.wait_until_container_started(
name=self.pod_name,
namespace=self.pod_namespace,
container_name=PodDefaults.SIDECAR_CONTAINER_NAME,
)
self.log.info("Extracting result from xcom sidecar container.")
loop = asyncio.get_running_loop()
xcom_result = await loop.run_in_executor(None, self.pod_manager.extract_xcom, pod)
job: V1Job = await self.hook.wait_until_job_complete(name=self.job_name, namespace=self.job_namespace)
job_dict = job.to_dict()
error_message = self.hook.is_job_failed(job=job)
yield TriggerEvent(
{
"name": job.metadata.name,
"namespace": job.metadata.namespace,
"pod_name": pod.metadata.name if self.get_logs else None,
"pod_namespace": pod.metadata.namespace if self.get_logs else None,
"status": "error" if error_message else "success",
"message": f"Job failed with error: {error_message}"
if error_message
else "Job completed successfully",
"job": job_dict,
"xcom_result": xcom_result if self.do_xcom_push else None,
}
)
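
One detail worth calling out in run() above: PodManager.extract_xcom is a synchronous, blocking call, so it is pushed onto the default thread-pool executor rather than awaited directly. A stripped-down sketch of that pattern (the pod_manager and pod arguments stand in for the attributes used above):

import asyncio


async def extract_xcom_without_blocking(pod_manager, pod):
    # Run the synchronous extract_xcom in a worker thread so the triggerer's
    # event loop keeps servicing other triggers while the sidecar is read.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, pod_manager.extract_xcom, pod)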

@@ -99,3 +142,13 @@ def hook(self) -> AsyncKubernetesHook:
config_file=self.config_file,
cluster_context=self.cluster_context,
)

@cached_property
def pod_manager(self) -> PodManager:
sync_hook = KubernetesHook(
conn_id=self.kubernetes_conn_id,
in_cluster=self.in_cluster,
config_file=self.config_file,
cluster_context=self.cluster_context,
)
return PodManager(kube_client=sync_hook.core_v1_client)
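
To show how the widened constructor surface travels through deferral, a small round-trip sketch; all values are placeholders, and the only requirement the diff adds is that every new constructor argument also appears in serialize().

from airflow.providers.cncf.kubernetes.triggers.job import KubernetesJobTrigger

trigger = KubernetesJobTrigger(
    job_name="example-job",
    job_namespace="default",
    pod_name="example-job-abcde",
    pod_namespace="default",
    base_container_name="base",
    kubernetes_conn_id="kubernetes_default",
    get_logs=True,
    do_xcom_push=True,
)

# The triggerer persists (classpath, kwargs) and later rebuilds the trigger from
# them, so a key missing from serialize() would silently drop that setting.
classpath, kwargs = trigger.serialize()
rehydrated = KubernetesJobTrigger(**kwargs)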
13 changes: 13 additions & 0 deletions airflow/providers/google/cloud/operators/kubernetes_engine.py
@@ -943,13 +943,26 @@ def execute_deferrable(self):
ssl_ca_cert=self._ssl_ca_cert,
job_name=self.job.metadata.name, # type: ignore[union-attr]
job_namespace=self.job.metadata.namespace, # type: ignore[union-attr]
pod_name=self.pod.metadata.name, # type: ignore[union-attr]
pod_namespace=self.pod.metadata.namespace, # type: ignore[union-attr]
base_container_name=self.base_container_name,
gcp_conn_id=self.gcp_conn_id,
poll_interval=self.job_poll_interval,
impersonation_chain=self.impersonation_chain,
get_logs=self.get_logs,
do_xcom_push=self.do_xcom_push,
),
method_name="execute_complete",
kwargs={"cluster_url": self._cluster_url, "ssl_ca_cert": self._ssl_ca_cert},
)

def execute_complete(self, context: Context, event: dict, **kwargs):
# It is required for hook to be initialized
self._cluster_url = kwargs["cluster_url"]
self._ssl_ca_cert = kwargs["ssl_ca_cert"]

return super().execute_complete(context, event)


class GKEDescribeJobOperator(GoogleCloudBaseOperator):
"""