Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KubernetesPodOperator new callbacks and allow multiple callbacks #44357

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 81 additions & 7 deletions providers/src/airflow/providers/cncf/kubernetes/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,15 @@
from __future__ import annotations

from enum import Enum
from typing import Union
from typing import TYPE_CHECKING, Union

import kubernetes.client as k8s
import kubernetes_asyncio.client as async_k8s

if TYPE_CHECKING:
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from airflow.utils.context import Context

client_type = Union[k8s.CoreV1Api, async_k8s.CoreV1Api]


Expand All @@ -41,7 +45,7 @@ class KubernetesPodOperatorCallback:
"""

@staticmethod
def on_sync_client_creation(*, client: k8s.CoreV1Api, **kwargs) -> None:
def on_sync_client_creation(*, client: k8s.CoreV1Api, operator: KubernetesPodOperator, **kwargs) -> None:
"""
Invoke this callback after creating the sync client.

Expand All @@ -50,7 +54,27 @@ def on_sync_client_creation(*, client: k8s.CoreV1Api, **kwargs) -> None:
pass

@staticmethod
def on_pod_creation(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) -> None:
def on_pod_manifest_created(
    *, pod_request: k8s.V1Pod, mode: str, operator: KubernetesPodOperator, context: Context, **kwargs
) -> None:
    """
    Invoke this callback after KPO creates the V1Pod manifest but before the pod is created.

    :param pod_request: the kubernetes pod manifest
    :param mode: the current execution mode, it's one of (`sync`, `async`).
    :param operator: the `KubernetesPodOperator` instance that built the manifest.
    :param context: the Airflow context of the currently executing task.
    """
    pass

@staticmethod
def on_pod_creation(
*,
pod: k8s.V1Pod,
client: client_type,
mode: str,
operator: KubernetesPodOperator,
context: Context,
**kwargs,
) -> None:
"""
Invoke this callback after creating the pod.

Expand All @@ -61,7 +85,15 @@ def on_pod_creation(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs)
pass

@staticmethod
def on_pod_starting(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) -> None:
def on_pod_starting(
*,
pod: k8s.V1Pod,
client: client_type,
mode: str,
operator: KubernetesPodOperator,
context: Context,
**kwargs,
) -> None:
"""
Invoke this callback when the pod starts.

Expand All @@ -72,7 +104,15 @@ def on_pod_starting(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs)
pass

@staticmethod
def on_pod_completion(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) -> None:
def on_pod_completion(
*,
pod: k8s.V1Pod,
client: client_type,
mode: str,
operator: KubernetesPodOperator,
context: Context,
**kwargs,
) -> None:
"""
Invoke this callback when the pod completes.

Expand All @@ -83,7 +123,34 @@ def on_pod_completion(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwarg
pass

@staticmethod
def on_pod_cleanup(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs):
def on_pod_wrapup(
    *,
    pod: k8s.V1Pod,
    client: client_type,
    mode: str,
    operator: KubernetesPodOperator,
    context: Context,
    **kwargs,
) -> None:
    """
    Invoke this callback after all pod completion callbacks but before the pod is deleted.

    :param pod: the completed pod.
    :param client: the Kubernetes client that can be used in the callback.
    :param mode: the current execution mode, it's one of (`sync`, `async`).
    :param operator: the `KubernetesPodOperator` instance that ran the pod.
    :param context: the Airflow context of the currently executing task.
    """
    pass
Comment on lines +126 to +142
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How can we send the completed pod here? That would require some tracking and filtering to send one. Why can't this callback's role be achieved by on_pod_completion?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The existing code makes a call to find the pod if the callbacks are attached.

pod=self.find_pod(self.pod.metadata.namespace, context=context),

Honestly I'd prefer if I was sending a stale reference and that it was the responsibility of the callback to get an updated pod if it needs it since we're sending the client too. Especially since its possible a callback might not implement the on_pod_completion method. So to maintain existing behaviour I'm getting it once. An alternative might be to test if the on_pod_completion is implemented in the callback and get an updated pod for each callback call, but again this assumes we need an updated pod, which might not be the case.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As for the need for on_pod_wrapup, my thought here was that you might have multiple callbacks running before a sidecar container is killed. Rather than attaching multiple sidecars to the container, you could have a class that attaches a single sidecar and kills it in the on_pod_wrapup; any subclasses of it could pull whatever files they need or run any commands in the pod during the on_pod_completion callback.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not quite sure if I understand you well here. Are you talking about a specific case of a running sidecar?

Copy link
Author

@johnhoran johnhoran Dec 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry if I'm not being clear. The way I've been using this so far is that I have one class that extends off the KubernetesPodOperatorCallback and does the insertion of a sidecar on_pod_manifest_created, killing the sidecar on_pod_wrapup, and some code to ensure that the sidecar is only added/killed once. Then extending off that I have a bunch of other classes that are responsible for doing some actual work with the sidecar; in my case I pull DBT artifacts in the on_pod_completion method, then in on_pod_wrapup they call super().on_pod_wrapup() before extracting dataset events from DBT artifacts, and a separate callback uploads them to S3.


@staticmethod
def on_pod_cleanup(
*,
pod: k8s.V1Pod,
client: client_type,
mode: str,
operator: KubernetesPodOperator,
context: Context,
**kwargs,
):
"""
Invoke this callback after cleaning/deleting the pod.

Expand All @@ -95,7 +162,14 @@ def on_pod_cleanup(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs):

@staticmethod
def on_operator_resuming(
*, pod: k8s.V1Pod, event: dict, client: client_type, mode: str, **kwargs
*,
pod: k8s.V1Pod,
event: dict,
client: client_type,
mode: str,
operator: KubernetesPodOperator,
context: Context,
**kwargs,
) -> None:
"""
Invoke this callback when resuming the `KubernetesPodOperator` from deferred state.
Expand Down
112 changes: 79 additions & 33 deletions providers/src/airflow/providers/cncf/kubernetes/operators/pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,9 @@ def __init__(
is_delete_operator_pod: None | bool = None,
termination_message_policy: str = "File",
active_deadline_seconds: int | None = None,
callbacks: type[KubernetesPodOperatorCallback] | None = None,
callbacks: list[type[KubernetesPodOperatorCallback]]
| type[KubernetesPodOperatorCallback]
| None = None,
progress_callback: Callable[[str], None] | None = None,
logging_interval: int | None = None,
**kwargs,
Expand Down Expand Up @@ -405,7 +407,7 @@ def __init__(

self._config_dict: dict | None = None # TODO: remove it when removing convert_config_file_to_dict
self._progress_callback = progress_callback
self.callbacks = callbacks
self.callbacks = [] if not callbacks else callbacks if isinstance(callbacks, list) else [callbacks]
self._killed: bool = False

@cached_property
Expand Down Expand Up @@ -509,8 +511,9 @@ def hook(self) -> PodOperatorHookProtocol:
@cached_property
def client(self) -> CoreV1Api:
    """Build the sync Kubernetes client and notify every registered callback that it exists."""
    api_client = self.hook.core_v1_client

    # Let each callback react to (or instrument) the freshly created client.
    for callback in self.callbacks:
        callback.on_sync_client_creation(client=api_client, operator=self)
    return api_client

def find_pod(self, namespace: str, context: Context, *, exclude_checked: bool = True) -> k8s.V1Pod | None:
Expand Down Expand Up @@ -584,6 +587,10 @@ def execute_sync(self, context: Context):
try:
if self.pod_request_obj is None:
self.pod_request_obj = self.build_pod_request_obj(context)
for callback in self.callbacks:
callback.on_pod_manifest_created(
pod_request=self.pod_request_obj, mode=ExecutionMode.SYNC, context=context, operator=self
)
if self.pod is None:
self.pod = self.get_or_create_pod( # must set `self.pod` for `on_kill`
pod_request_obj=self.pod_request_obj,
Expand All @@ -596,25 +603,45 @@ def execute_sync(self, context: Context):

# get remote pod for use in cleanup methods
self.remote_pod = self.find_pod(self.pod.metadata.namespace, context=context)
if self.callbacks:
self.callbacks.on_pod_creation(
pod=self.remote_pod, client=self.client, mode=ExecutionMode.SYNC
)
self.await_pod_start(pod=self.pod)
if self.callbacks:
self.callbacks.on_pod_starting(
pod=self.find_pod(self.pod.metadata.namespace, context=context),
for callback in self.callbacks:
callback.on_pod_creation(
pod=self.remote_pod,
client=self.client,
mode=ExecutionMode.SYNC,
context=context,
operator=self,
)
self.await_pod_start(pod=self.pod)
if self.callbacks:
pod = self.find_pod(self.pod.metadata.namespace, context=context)
for callback in self.callbacks:
callback.on_pod_starting(
pod=pod,
client=self.client,
mode=ExecutionMode.SYNC,
context=context,
operator=self,
)

self.await_pod_completion(pod=self.pod)
if self.callbacks:
self.callbacks.on_pod_completion(
pod=self.find_pod(self.pod.metadata.namespace, context=context),
client=self.client,
mode=ExecutionMode.SYNC,
)
pod = self.find_pod(self.pod.metadata.namespace, context=context)
for callback in self.callbacks:
callback.on_pod_completion(
pod=pod,
client=self.client,
mode=ExecutionMode.SYNC,
context=context,
operator=self,
)
for callback in self.callbacks:
callback.on_pod_wrapup(
pod=pod,
client=self.client,
mode=ExecutionMode.SYNC,
context=context,
operator=self,
)

if self.do_xcom_push:
self.pod_manager.await_xcom_sidecar_container_start(pod=self.pod)
Expand All @@ -629,8 +656,14 @@ def execute_sync(self, context: Context):
pod=pod_to_clean,
remote_pod=self.remote_pod,
)
if self.callbacks:
self.callbacks.on_pod_cleanup(pod=pod_to_clean, client=self.client, mode=ExecutionMode.SYNC)
for callback in self.callbacks:
callback.on_pod_cleanup(
pod=pod_to_clean,
client=self.client,
mode=ExecutionMode.SYNC,
context=context,
operator=self,
)

if self.do_xcom_push:
return result
Expand Down Expand Up @@ -676,11 +709,15 @@ def execute_async(self, context: Context) -> None:
context=context,
)
if self.callbacks:
self.callbacks.on_pod_creation(
pod=self.find_pod(self.pod.metadata.namespace, context=context),
client=self.client,
mode=ExecutionMode.SYNC,
)
pod = self.find_pod(self.pod.metadata.namespace, context=context)
for callback in self.callbacks:
callback.on_pod_creation(
pod=pod,
client=self.client,
mode=ExecutionMode.SYNC,
context=context,
operator=self,
)
ti = context["ti"]
ti.xcom_push(key="pod_name", value=self.pod.metadata.name)
ti.xcom_push(key="pod_namespace", value=self.pod.metadata.namespace)
Expand Down Expand Up @@ -741,10 +778,16 @@ def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any:
if not self.pod:
raise PodNotFoundException("Could not find pod after resuming from deferral")

if self.callbacks and event["status"] != "running":
self.callbacks.on_operator_resuming(
pod=self.pod, event=event, client=self.client, mode=ExecutionMode.SYNC
)
if event["status"] != "running":
for callback in self.callbacks:
callback.on_operator_resuming(
pod=self.pod,
event=event,
client=self.client,
mode=ExecutionMode.SYNC,
context=context,
operator=self,
)

follow = self.logging_interval is None
last_log_time = event.get("last_log_time")
Expand Down Expand Up @@ -787,9 +830,9 @@ def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any:
except TaskDeferred:
raise
finally:
self._clean(event)
self._clean(event, context)

def _clean(self, event: dict[str, Any]) -> None:
def _clean(self, event: dict[str, Any], context: Context) -> None:
if event["status"] == "running":
return
istio_enabled = self.is_istio_enabled(self.pod)
Expand All @@ -812,6 +855,7 @@ def _clean(self, event: dict[str, Any]) -> None:
self.post_complete_action(
pod=self.pod,
remote_pod=self.pod,
context=context,
)

def _write_logs(self, pod: k8s.V1Pod, follow: bool = False, since_time: DateTime | None = None) -> None:
Expand Down Expand Up @@ -841,14 +885,16 @@ def _write_logs(self, pod: k8s.V1Pod, follow: bool = False, since_time: DateTime
e if not isinstance(e, ApiException) else e.reason,
)

def post_complete_action(self, *, pod, remote_pod, **kwargs) -> None:
def post_complete_action(self, *, pod, remote_pod, context: Context, **kwargs) -> None:
    """Run the actions that must happen after the operator finishes its deferrable-execution logic."""
    self.cleanup(pod=pod, remote_pod=remote_pod)
    # Cleanup callbacks fire last, once the pod itself has been cleaned up/deleted.
    for cb in self.callbacks:
        cb.on_pod_cleanup(
            pod=pod,
            client=self.client,
            mode=ExecutionMode.SYNC,
            operator=self,
            context=context,
        )

def cleanup(self, pod: k8s.V1Pod, remote_pod: k8s.V1Pod):
# Skip cleaning the pod in the following scenarios.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ class PodManager(LoggingMixin):
def __init__(
self,
kube_client: client.CoreV1Api,
callbacks: type[KubernetesPodOperatorCallback] | None = None,
callbacks: list[type[KubernetesPodOperatorCallback]] | None = None,
):
"""
Create the launcher.
Expand All @@ -311,7 +311,7 @@ def __init__(
super().__init__()
self._client = kube_client
self._watch = watch.Watch()
self._callbacks = callbacks
self._callbacks = callbacks or []

def run_pod_async(self, pod: V1Pod, **kwargs) -> V1Pod:
"""Run POD asynchronously."""
Expand Down Expand Up @@ -446,8 +446,8 @@ def consume_logs(*, since_time: DateTime | None = None) -> tuple[DateTime | None
progress_callback_lines.append(line)
else: # previous log line is complete
for line in progress_callback_lines:
if self._callbacks:
self._callbacks.progress_callback(
for callback in self._callbacks:
callback.progress_callback(
line=line, client=self._client, mode=ExecutionMode.SYNC
Comment on lines +449 to 451
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems ok when callbacks are running in SYNC mode. What about async?
Would probably require some more thinking

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Callbacks aren't really implemented for async operation at the moment unfortunately. #35714 (comment).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Thanks!
Yea in that case, this will do

)
if message_to_log is not None:
Expand All @@ -465,8 +465,8 @@ def consume_logs(*, since_time: DateTime | None = None) -> tuple[DateTime | None
finally:
# log the last line and update the last_captured_timestamp
for line in progress_callback_lines:
if self._callbacks:
self._callbacks.progress_callback(
for callback in self._callbacks:
callback.progress_callback(
line=line, client=self._client, mode=ExecutionMode.SYNC
)
if message_to_log is not None:
Expand Down
Loading