From 9828750fd4d79f9fa44181733065616a9523324d Mon Sep 17 00:00:00 2001 From: John Horan Date: Mon, 25 Nov 2024 14:45:18 +0000 Subject: [PATCH 01/12] callbacks list --- .../providers/cncf/kubernetes/callbacks.py | 19 +++++ .../cncf/kubernetes/operators/pod.py | 78 ++++++++++++------- .../cncf/kubernetes/utils/pod_manager.py | 12 +-- .../cncf/kubernetes/operators/test_pod.py | 8 ++ .../tests/cncf/kubernetes/test_callbacks.py | 8 ++ .../cncf/kubernetes/utils/test_pod_manager.py | 2 +- 6 files changed, 91 insertions(+), 36 deletions(-) diff --git a/providers/src/airflow/providers/cncf/kubernetes/callbacks.py b/providers/src/airflow/providers/cncf/kubernetes/callbacks.py index eb971b4a30469..5fd826f957a95 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/callbacks.py +++ b/providers/src/airflow/providers/cncf/kubernetes/callbacks.py @@ -49,6 +49,15 @@ def on_sync_client_creation(*, client: k8s.CoreV1Api, **kwargs) -> None: """ pass + @staticmethod + def on_manifest_finalization(*, pod_request: k8s.V1Pod, mode: str, **kwargs) -> None: + """ + Invoke this callback after KPO creates the V1Pod manifest but before the pod is created. + + :param pod_request: the kubernetes pod manifest + :param mode: the current execution mode, it's one of (`sync`, `async`). + """ + @staticmethod def on_pod_creation(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) -> None: """ @@ -82,6 +91,16 @@ def on_pod_completion(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwarg """ pass + @staticmethod + def on_pod_wrapup(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) -> None: + """ + Invoked after all pod completion callbacks but before the pod is deleted. + + :param pod: the completed pod. + :param client: the Kubernetes client that can be used in the callback. + :param mode: the current execution mode, it's one of (`sync`, `async`). + """ + @staticmethod def on_pod_cleanup(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs): """ diff --git a/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py b/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py index cbafc72f3455c..3a01c5a678314 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py +++ b/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py @@ -315,7 +315,7 @@ def __init__( is_delete_operator_pod: None | bool = None, termination_message_policy: str = "File", active_deadline_seconds: int | None = None, - callbacks: type[KubernetesPodOperatorCallback] | None = None, + callbacks: list[KubernetesPodOperatorCallback] | KubernetesPodOperatorCallback | None = None, progress_callback: Callable[[str], None] | None = None, logging_interval: int | None = None, **kwargs, @@ -403,7 +403,11 @@ def __init__( self._config_dict: dict | None = None # TODO: remove it when removing convert_config_file_to_dict self._progress_callback = progress_callback - self.callbacks = callbacks + self.callbacks = ( + [] if not callbacks + else callbacks if isinstance(callbacks, list) + else [callbacks] + ) self._killed: bool = False @cached_property @@ -507,8 +511,9 @@ def hook(self) -> PodOperatorHookProtocol: @cached_property def client(self) -> CoreV1Api: client = self.hook.core_v1_client - if self.callbacks: - self.callbacks.on_sync_client_creation(client=client) + + for callback in self.callbacks: + callback.on_sync_client_creation(client=client) return client def find_pod(self, namespace: str, context: Context, *, exclude_checked: bool = True) -> k8s.V1Pod | None: @@ -582,6 +587,8 @@ def execute_sync(self, context: Context): try: if self.pod_request_obj is None: self.pod_request_obj = self.build_pod_request_obj(context) + for callback in self.callbacks: + callback.on_manifest_finalization(pod_request=self.pod_request_obj, mode=ExecutionMode.SYNC) if self.pod is None: self.pod = self.get_or_create_pod( # must set `self.pod` for `on_kill` pod_request_obj=self.pod_request_obj, @@ -594,25 +601,35 @@ def execute_sync(self, context: Context): # get remote pod for use in cleanup methods self.remote_pod = self.find_pod(self.pod.metadata.namespace, context=context) - if self.callbacks: - self.callbacks.on_pod_creation( + for callback in self.callbacks: + callback.on_pod_creation( pod=self.remote_pod, client=self.client, mode=ExecutionMode.SYNC ) self.await_pod_start(pod=self.pod) if self.callbacks: - self.callbacks.on_pod_starting( - pod=self.find_pod(self.pod.metadata.namespace, context=context), - client=self.client, - mode=ExecutionMode.SYNC, - ) + pod = self.find_pod(self.pod.metadata.namespace, context=context) + for callback in self.callbacks: + callback.on_pod_starting( + pod=pod, + client=self.client, + mode=ExecutionMode.SYNC, + ) self.await_pod_completion(pod=self.pod) if self.callbacks: - self.callbacks.on_pod_completion( - pod=self.find_pod(self.pod.metadata.namespace, context=context), - client=self.client, - mode=ExecutionMode.SYNC, - ) + pod=self.find_pod(self.pod.metadata.namespace, context=context) + for callback in self.callbacks: + callback.on_pod_completion( + pod=pod, + client=self.client, + mode=ExecutionMode.SYNC, + ) + for callback in self.callbacks: + callback.on_pod_wrapup( + pod=pod, + client=self.client, + mode=ExecutionMode.SYNC, + ) if self.do_xcom_push: self.pod_manager.await_xcom_sidecar_container_start(pod=self.pod) @@ -627,8 +644,8 @@ def execute_sync(self, context: Context): pod=pod_to_clean, remote_pod=self.remote_pod, ) - if self.callbacks: - self.callbacks.on_pod_cleanup(pod=pod_to_clean, client=self.client, mode=ExecutionMode.SYNC) + for callback in self.callbacks: + callback.on_pod_cleanup(pod=pod_to_clean, client=self.client, mode=ExecutionMode.SYNC) if self.do_xcom_push: return result @@ -674,11 +691,13 @@ def execute_async(self, context: Context) -> None: context=context, ) if self.callbacks: - self.callbacks.on_pod_creation( - pod=self.find_pod(self.pod.metadata.namespace, context=context), - client=self.client, - mode=ExecutionMode.SYNC, - ) + pod=self.find_pod(self.pod.metadata.namespace, context=context) + for callback in self.callbacks: + callback.on_pod_creation( + pod=pod, + client=self.client, + mode=ExecutionMode.SYNC, + ) ti = context["ti"] ti.xcom_push(key="pod_name", value=self.pod.metadata.name) ti.xcom_push(key="pod_namespace", value=self.pod.metadata.namespace) @@ -739,10 +758,11 @@ def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any: if not self.pod: raise PodNotFoundException("Could not find pod after resuming from deferral") - if self.callbacks and event["status"] != "running": - self.callbacks.on_operator_resuming( - pod=self.pod, event=event, client=self.client, mode=ExecutionMode.SYNC - ) + if event["status"] != "running": + for callback in self.callbacks: + callback.on_operator_resuming( + pod=self.pod, event=event, client=self.client, mode=ExecutionMode.SYNC + ) follow = self.logging_interval is None last_log_time = event.get("last_log_time") @@ -845,8 +865,8 @@ def post_complete_action(self, *, pod, remote_pod, **kwargs) -> None: pod=pod, remote_pod=remote_pod, ) - if self.callbacks: - self.callbacks.on_pod_cleanup(pod=pod, client=self.client, mode=ExecutionMode.SYNC) + for callback in self.callbacks: + callback.on_pod_cleanup(pod=pod, client=self.client, mode=ExecutionMode.SYNC) def cleanup(self, pod: k8s.V1Pod, remote_pod: k8s.V1Pod): # Skip cleaning the pod in the following scenarios. diff --git a/providers/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py b/providers/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py index e123ef0b0d284..ab321d1e1db37 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py +++ b/providers/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py @@ -300,7 +300,7 @@ class PodManager(LoggingMixin): def __init__( self, kube_client: client.CoreV1Api, - callbacks: type[KubernetesPodOperatorCallback] | None = None, + callbacks: list[KubernetesPodOperatorCallback] | None = None, ): """ Create the launcher. @@ -311,7 +311,7 @@ def __init__( super().__init__() self._client = kube_client self._watch = watch.Watch() - self._callbacks = callbacks + self._callbacks = callbacks or [] def run_pod_async(self, pod: V1Pod, **kwargs) -> V1Pod: """Run POD asynchronously.""" @@ -446,8 +446,8 @@ def consume_logs(*, since_time: DateTime | None = None) -> tuple[DateTime | None progress_callback_lines.append(line) else: # previous log line is complete for line in progress_callback_lines: - if self._callbacks: - self._callbacks.progress_callback( + for callback in self._callbacks: + callback.progress_callback( line=line, client=self._client, mode=ExecutionMode.SYNC ) if message_to_log is not None: @@ -462,8 +462,8 @@ def consume_logs(*, since_time: DateTime | None = None) -> tuple[DateTime | None finally: # log the last line and update the last_captured_timestamp for line in progress_callback_lines: - if self._callbacks: - self._callbacks.progress_callback( + for callback in self._callbacks: + callback.progress_callback( line=line, client=self._client, mode=ExecutionMode.SYNC ) if message_to_log is not None: diff --git a/providers/tests/cncf/kubernetes/operators/test_pod.py b/providers/tests/cncf/kubernetes/operators/test_pod.py index d6b7378c80c80..22709140d0859 100644 --- a/providers/tests/cncf/kubernetes/operators/test_pod.py +++ b/providers/tests/cncf/kubernetes/operators/test_pod.py @@ -1519,6 +1519,7 @@ def test_execute_sync_callbacks(self, find_pod_mock): mock_callbacks.on_sync_client_creation.assert_called_once() assert mock_callbacks.on_sync_client_creation.call_args.kwargs == {"client": k.client} + # check on_pod_creation callback mock_callbacks.on_pod_creation.assert_called_once() assert mock_callbacks.on_pod_creation.call_args.kwargs == { @@ -1543,6 +1544,13 @@ def test_execute_sync_callbacks(self, find_pod_mock): "pod": found_pods[2], } + mock_callbacks.on_pod_wrapup.assert_called_once() + assert mock_callbacks.on_pod_wrapup.call_args.kwargs == { + "client": k.client, + "mode": ExecutionMode.SYNC, + "pod": found_pods[2], + } + # check on_pod_cleanup callback mock_callbacks.on_pod_cleanup.assert_called_once() assert mock_callbacks.on_pod_cleanup.call_args.kwargs == { diff --git a/providers/tests/cncf/kubernetes/test_callbacks.py b/providers/tests/cncf/kubernetes/test_callbacks.py index 2757b8296a7d0..30f69751a37f1 100644 --- a/providers/tests/cncf/kubernetes/test_callbacks.py +++ b/providers/tests/cncf/kubernetes/test_callbacks.py @@ -63,3 +63,11 @@ def on_operator_resuming(*args, **kwargs) -> None: @staticmethod def progress_callback(*args, **kwargs) -> None: MockWrapper.mock_callbacks.progress_callback(*args, **kwargs) + + @staticmethod + def on_manifest_finalization(*args, **kwargs) -> None: + MockWrapper.mock_callbacks.on_manifest_finalization(*args, **kwargs) + + @staticmethod + def on_pod_wrapup(*args, **kwargs) -> None: + MockWrapper.mock_callbacks.on_pod_wrapup(*args, **kwargs) \ No newline at end of file diff --git a/providers/tests/cncf/kubernetes/utils/test_pod_manager.py b/providers/tests/cncf/kubernetes/utils/test_pod_manager.py index d220befc2191e..7db2cb78637d2 100644 --- a/providers/tests/cncf/kubernetes/utils/test_pod_manager.py +++ b/providers/tests/cncf/kubernetes/utils/test_pod_manager.py @@ -52,7 +52,7 @@ def setup_method(self): self.mock_kube_client = mock.Mock() self.pod_manager = PodManager( kube_client=self.mock_kube_client, - callbacks=MockKubernetesPodOperatorCallback, + callbacks=[MockKubernetesPodOperatorCallback], ) def test_read_pod_logs_successfully_returns_logs(self): From f9b54349366698e208588cbc4c929171ea31ec4c Mon Sep 17 00:00:00 2001 From: John Horan Date: Mon, 25 Nov 2024 14:57:05 +0000 Subject: [PATCH 02/12] pass --- providers/src/airflow/providers/cncf/kubernetes/callbacks.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/providers/src/airflow/providers/cncf/kubernetes/callbacks.py b/providers/src/airflow/providers/cncf/kubernetes/callbacks.py index 5fd826f957a95..b65a1707681a9 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/callbacks.py +++ b/providers/src/airflow/providers/cncf/kubernetes/callbacks.py @@ -57,6 +57,7 @@ def on_manifest_finalization(*, pod_request: k8s.V1Pod, mode: str, **kwargs) -> :param pod_request: the kubernetes pod manifest :param mode: the current execution mode, it's one of (`sync`, `async`). """ + pass @staticmethod def on_pod_creation(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) -> None: @@ -100,6 +101,7 @@ def on_pod_wrapup(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) - :param client: the Kubernetes client that can be used in the callback. :param mode: the current execution mode, it's one of (`sync`, `async`). """ + pass @staticmethod def on_pod_cleanup(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs): From 9ac5625cb89bf65c6a3520411d957fe9c4ad6a14 Mon Sep 17 00:00:00 2001 From: John Horan Date: Mon, 25 Nov 2024 15:47:52 +0000 Subject: [PATCH 03/12] add test --- providers/tests/cncf/kubernetes/operators/test_pod.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/providers/tests/cncf/kubernetes/operators/test_pod.py b/providers/tests/cncf/kubernetes/operators/test_pod.py index 22709140d0859..63c73e4cc0e0d 100644 --- a/providers/tests/cncf/kubernetes/operators/test_pod.py +++ b/providers/tests/cncf/kubernetes/operators/test_pod.py @@ -1519,6 +1519,8 @@ def test_execute_sync_callbacks(self, find_pod_mock): mock_callbacks.on_sync_client_creation.assert_called_once() assert mock_callbacks.on_sync_client_creation.call_args.kwargs == {"client": k.client} + # check on_manifest_finalization callback + mock_callbacks.on_manifest_finalization.assert_called_once() # check on_pod_creation callback mock_callbacks.on_pod_creation.assert_called_once() From 28a5e8a5b7cb274e3fec069ea41fe96e8b3484fe Mon Sep 17 00:00:00 2001 From: John Horan Date: Mon, 25 Nov 2024 15:50:21 +0000 Subject: [PATCH 04/12] fmt --- .../providers/cncf/kubernetes/operators/pod.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py b/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py index 3a01c5a678314..a27aca7ee5c81 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py +++ b/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py @@ -403,11 +403,7 @@ def __init__( self._config_dict: dict | None = None # TODO: remove it when removing convert_config_file_to_dict self._progress_callback = progress_callback - self.callbacks = ( - [] if not callbacks - else callbacks if isinstance(callbacks, list) - else [callbacks] - ) + self.callbacks = [] if not callbacks else callbacks if isinstance(callbacks, list) else [callbacks] self._killed: bool = False @cached_property @@ -602,9 +598,7 @@ def execute_sync(self, context: Context): # get remote pod for use in cleanup methods self.remote_pod = self.find_pod(self.pod.metadata.namespace, context=context) for callback in self.callbacks: - callback.on_pod_creation( - pod=self.remote_pod, client=self.client, mode=ExecutionMode.SYNC - ) + callback.on_pod_creation(pod=self.remote_pod, client=self.client, mode=ExecutionMode.SYNC) self.await_pod_start(pod=self.pod) if self.callbacks: pod = self.find_pod(self.pod.metadata.namespace, context=context) @@ -617,7 +611,7 @@ def execute_sync(self, context: Context): self.await_pod_completion(pod=self.pod) if self.callbacks: - pod=self.find_pod(self.pod.metadata.namespace, context=context) + pod = self.find_pod(self.pod.metadata.namespace, context=context) for callback in self.callbacks: callback.on_pod_completion( pod=pod, @@ -691,7 +685,7 @@ def execute_async(self, context: Context) -> None: context=context, ) if self.callbacks: - pod=self.find_pod(self.pod.metadata.namespace, context=context) + pod = self.find_pod(self.pod.metadata.namespace, context=context) for callback in self.callbacks: callback.on_pod_creation( pod=pod, From 08a4a4ba8df49f2f2709c1fdabe2f374bc21dd07 Mon Sep 17 00:00:00 2001 From: John Horan Date: Mon, 25 Nov 2024 17:26:29 +0000 Subject: [PATCH 05/12] pass context --- .../providers/cncf/kubernetes/callbacks.py | 72 ++++++++++++++++--- .../cncf/kubernetes/operators/pod.py | 52 ++++++++++---- .../cncf/kubernetes/operators/test_pod.py | 39 +++++++--- 3 files changed, 132 insertions(+), 31 deletions(-) diff --git a/providers/src/airflow/providers/cncf/kubernetes/callbacks.py b/providers/src/airflow/providers/cncf/kubernetes/callbacks.py index b65a1707681a9..4eef4bce527fa 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/callbacks.py +++ b/providers/src/airflow/providers/cncf/kubernetes/callbacks.py @@ -22,6 +22,13 @@ import kubernetes.client as k8s import kubernetes_asyncio.client as async_k8s + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator + from airflow.utils.context import Context + client_type = Union[k8s.CoreV1Api, async_k8s.CoreV1Api] @@ -41,7 +48,7 @@ class KubernetesPodOperatorCallback: """ @staticmethod - def on_sync_client_creation(*, client: k8s.CoreV1Api, **kwargs) -> None: + def on_sync_client_creation(*, client: k8s.CoreV1Api, operator: KubernetesPodOperator, **kwargs) -> None: """ Invoke this callback after creating the sync client. @@ -50,7 +57,9 @@ def on_sync_client_creation(*, client: k8s.CoreV1Api, **kwargs) -> None: pass @staticmethod - def on_manifest_finalization(*, pod_request: k8s.V1Pod, mode: str, **kwargs) -> None: + def on_manifest_finalization( + *, pod_request: k8s.V1Pod, mode: str, operator: KubernetesPodOperator, context: Context, **kwargs + ) -> None: """ Invoke this callback after KPO creates the V1Pod manifest but before the pod is created. @@ -60,7 +69,15 @@ def on_manifest_finalization(*, pod_request: k8s.V1Pod, mode: str, **kwargs) -> pass @staticmethod - def on_pod_creation(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) -> None: + def on_pod_creation( + *, + pod: k8s.V1Pod, + client: client_type, + mode: str, + operator: KubernetesPodOperator, + context: Context, + **kwargs, + ) -> None: """ Invoke this callback after creating the pod. @@ -71,7 +88,15 @@ def on_pod_creation(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) pass @staticmethod - def on_pod_starting(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) -> None: + def on_pod_starting( + *, + pod: k8s.V1Pod, + client: client_type, + mode: str, + operator: KubernetesPodOperator, + context: Context, + **kwargs, + ) -> None: """ Invoke this callback when the pod starts. @@ -82,7 +107,15 @@ def on_pod_starting(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) pass @staticmethod - def on_pod_completion(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) -> None: + def on_pod_completion( + *, + pod: k8s.V1Pod, + client: client_type, + mode: str, + operator: KubernetesPodOperator, + context: Context, + **kwargs, + ) -> None: """ Invoke this callback when the pod completes. @@ -93,7 +126,15 @@ def on_pod_completion(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwarg pass @staticmethod - def on_pod_wrapup(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) -> None: + def on_pod_wrapup( + *, + pod: k8s.V1Pod, + client: client_type, + mode: str, + operator: KubernetesPodOperator, + context: Context, + **kwargs, + ) -> None: """ Invoked after all pod completion callbacks but before the pod is deleted. @@ -104,7 +145,15 @@ def on_pod_wrapup(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) - pass @staticmethod - def on_pod_cleanup(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs): + def on_pod_cleanup( + *, + pod: k8s.V1Pod, + client: client_type, + mode: str, + operator: KubernetesPodOperator, + context: Context, + **kwargs, + ): """ Invoke this callback after cleaning/deleting the pod. @@ -116,7 +165,14 @@ def on_pod_cleanup(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs): @staticmethod def on_operator_resuming( - *, pod: k8s.V1Pod, event: dict, client: client_type, mode: str, **kwargs + *, + pod: k8s.V1Pod, + event: dict, + client: client_type, + mode: str, + operator: KubernetesPodOperator, + context: Context, + **kwargs, ) -> None: """ Invoke this callback when resuming the `KubernetesPodOperator` from deferred state. diff --git a/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py b/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py index a27aca7ee5c81..99051621dd82e 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py +++ b/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py @@ -387,9 +387,7 @@ def __init__( self.skip_on_exit_code = ( skip_on_exit_code if isinstance(skip_on_exit_code, Container) - else [skip_on_exit_code] - if skip_on_exit_code is not None - else [] + else [skip_on_exit_code] if skip_on_exit_code is not None else [] ) self.deferrable = deferrable self.poll_interval = poll_interval @@ -509,7 +507,7 @@ def client(self) -> CoreV1Api: client = self.hook.core_v1_client for callback in self.callbacks: - callback.on_sync_client_creation(client=client) + callback.on_sync_client_creation(client=client, operator=self) return client def find_pod(self, namespace: str, context: Context, *, exclude_checked: bool = True) -> k8s.V1Pod | None: @@ -584,7 +582,9 @@ def execute_sync(self, context: Context): if self.pod_request_obj is None: self.pod_request_obj = self.build_pod_request_obj(context) for callback in self.callbacks: - callback.on_manifest_finalization(pod_request=self.pod_request_obj, mode=ExecutionMode.SYNC) + callback.on_manifest_finalization( + pod_request=self.pod_request_obj, mode=ExecutionMode.SYNC, context=context, operator=self + ) if self.pod is None: self.pod = self.get_or_create_pod( # must set `self.pod` for `on_kill` pod_request_obj=self.pod_request_obj, @@ -598,7 +598,13 @@ def execute_sync(self, context: Context): # get remote pod for use in cleanup methods self.remote_pod = self.find_pod(self.pod.metadata.namespace, context=context) for callback in self.callbacks: - callback.on_pod_creation(pod=self.remote_pod, client=self.client, mode=ExecutionMode.SYNC) + callback.on_pod_creation( + pod=self.remote_pod, + client=self.client, + mode=ExecutionMode.SYNC, + context=context, + operator=self, + ) self.await_pod_start(pod=self.pod) if self.callbacks: pod = self.find_pod(self.pod.metadata.namespace, context=context) @@ -607,6 +613,8 @@ def execute_sync(self, context: Context): pod=pod, client=self.client, mode=ExecutionMode.SYNC, + context=context, + operator=self, ) self.await_pod_completion(pod=self.pod) @@ -617,12 +625,16 @@ def execute_sync(self, context: Context): pod=pod, client=self.client, mode=ExecutionMode.SYNC, + context=context, + operator=self, ) for callback in self.callbacks: callback.on_pod_wrapup( pod=pod, client=self.client, mode=ExecutionMode.SYNC, + context=context, + operator=self, ) if self.do_xcom_push: @@ -639,7 +651,13 @@ def execute_sync(self, context: Context): remote_pod=self.remote_pod, ) for callback in self.callbacks: - callback.on_pod_cleanup(pod=pod_to_clean, client=self.client, mode=ExecutionMode.SYNC) + callback.on_pod_cleanup( + pod=pod_to_clean, + client=self.client, + mode=ExecutionMode.SYNC, + context=context, + operator=self, + ) if self.do_xcom_push: return result @@ -691,6 +709,8 @@ def execute_async(self, context: Context) -> None: pod=pod, client=self.client, mode=ExecutionMode.SYNC, + context=context, + operator=self, ) ti = context["ti"] ti.xcom_push(key="pod_name", value=self.pod.metadata.name) @@ -755,7 +775,12 @@ def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any: if event["status"] != "running": for callback in self.callbacks: callback.on_operator_resuming( - pod=self.pod, event=event, client=self.client, mode=ExecutionMode.SYNC + pod=self.pod, + event=event, + client=self.client, + mode=ExecutionMode.SYNC, + context=context, + operator=self, ) follow = self.logging_interval is None @@ -799,9 +824,9 @@ def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any: except TaskDeferred: raise finally: - self._clean(event) + self._clean(event, context) - def _clean(self, event: dict[str, Any]) -> None: + def _clean(self, event: dict[str, Any], context: Context) -> None: if event["status"] == "running": return istio_enabled = self.is_istio_enabled(self.pod) @@ -824,6 +849,7 @@ def _clean(self, event: dict[str, Any]) -> None: self.post_complete_action( pod=self.pod, remote_pod=self.pod, + context=context, ) def _write_logs(self, pod: k8s.V1Pod, follow: bool = False, since_time: DateTime | None = None) -> None: @@ -853,14 +879,16 @@ def _write_logs(self, pod: k8s.V1Pod, follow: bool = False, since_time: DateTime e if not isinstance(e, ApiException) else e.reason, ) - def post_complete_action(self, *, pod, remote_pod, **kwargs) -> None: + def post_complete_action(self, *, pod, remote_pod, context: Context, **kwargs) -> None: """Actions that must be done after operator finishes logic of the deferrable_execution.""" self.cleanup( pod=pod, remote_pod=remote_pod, ) for callback in self.callbacks: - callback.on_pod_cleanup(pod=pod, client=self.client, mode=ExecutionMode.SYNC) + callback.on_pod_cleanup( + pod=pod, client=self.client, mode=ExecutionMode.SYNC, operator=self, context=context + ) def cleanup(self, pod: k8s.V1Pod, remote_pod: k8s.V1Pod): # Skip cleaning the pod in the following scenarios. diff --git a/providers/tests/cncf/kubernetes/operators/test_pod.py b/providers/tests/cncf/kubernetes/operators/test_pod.py index 63c73e4cc0e0d..bbc305a059446 100644 --- a/providers/tests/cncf/kubernetes/operators/test_pod.py +++ b/providers/tests/cncf/kubernetes/operators/test_pod.py @@ -47,6 +47,7 @@ from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction, PodLoggingStatus, PodPhase from airflow.providers.cncf.kubernetes.utils.xcom_sidecar import PodDefaults from airflow.utils import timezone +from airflow.utils.context import Context from airflow.utils.session import create_session from airflow.utils.types import DagRunType @@ -196,7 +197,7 @@ def test_templates(self, create_task_instance_of_operator, session): assert dag_id == rendered.volumes[0].name assert dag_id == rendered.volumes[0].config_map.name - def run_pod(self, operator: KubernetesPodOperator, map_index: int = -1) -> k8s.V1Pod: + def run_pod(self, operator: KubernetesPodOperator, map_index: int = -1) -> tuple[k8s.V1Pod, Context]: with self.dag_maker(dag_id="dag") as dag: operator.dag = dag @@ -211,7 +212,7 @@ def run_pod(self, operator: KubernetesPodOperator, map_index: int = -1) -> k8s.V remote_pod_mock.status.phase = "Succeeded" self.await_pod_mock.return_value = remote_pod_mock operator.execute(context=context) - return self.await_start_mock.call_args.kwargs["pod"] + return self.await_start_mock.call_args.kwargs["pod"], context def sanitize_for_serialization(self, obj): return ApiClient().sanitize_for_serialization(obj) @@ -357,7 +358,7 @@ def test_labels(self, hook_mock, in_cluster): in_cluster=in_cluster, do_xcom_push=False, ) - pod = self.run_pod(k) + pod, _ = self.run_pod(k) assert pod.metadata.labels == { "foo": "bar", "dag_id": "dag", @@ -1247,7 +1248,7 @@ def test_push_xcom_pod_info( do_xcom_push=do_xcom_push, ) - pod = self.run_pod(k) + pod, _ = self.run_pod(k) pod_name = XCom.get_one(run_id=self.dag_run.run_id, task_id="task", key="pod_name") pod_namespace = XCom.get_one(run_id=self.dag_run.run_id, task_id="task", key="pod_namespace") assert pod_name == pod.metadata.name @@ -1475,7 +1476,7 @@ def test_get_logs_but_not_for_base_container( remote_pod_mock = MagicMock() remote_pod_mock.status.phase = "Succeeded" self.await_pod_mock.return_value = remote_pod_mock - pod = self.run_pod(k) + pod, _ = self.run_pod(k) # check that the base container is not included in the logs mock_fetch_log.assert_called_once_with(pod=pod, containers=["some_init_container"], follow_logs=True) @@ -1513,11 +1514,11 @@ def test_execute_sync_callbacks(self, find_pod_mock): do_xcom_push=False, callbacks=MockKubernetesPodOperatorCallback, ) - self.run_pod(k) + _, context = self.run_pod(k) # check on_sync_client_creation callback mock_callbacks.on_sync_client_creation.assert_called_once() - assert mock_callbacks.on_sync_client_creation.call_args.kwargs == {"client": k.client} + assert mock_callbacks.on_sync_client_creation.call_args.kwargs == {"client": k.client, "operator": k} # check on_manifest_finalization callback mock_callbacks.on_manifest_finalization.assert_called_once() @@ -1528,6 +1529,8 @@ def test_execute_sync_callbacks(self, find_pod_mock): "client": k.client, "mode": ExecutionMode.SYNC, "pod": found_pods[0], + "operator": k, + "context": context, } # check on_pod_starting callback @@ -1536,6 +1539,8 @@ def test_execute_sync_callbacks(self, find_pod_mock): "client": k.client, "mode": ExecutionMode.SYNC, "pod": found_pods[1], + "operator": k, + "context": context, } # check on_pod_completion callback @@ -1544,6 +1549,8 @@ def test_execute_sync_callbacks(self, find_pod_mock): "client": k.client, "mode": ExecutionMode.SYNC, "pod": found_pods[2], + "operator": k, + "context": context, } mock_callbacks.on_pod_wrapup.assert_called_once() @@ -1551,6 +1558,8 @@ def test_execute_sync_callbacks(self, find_pod_mock): "client": k.client, "mode": ExecutionMode.SYNC, "pod": found_pods[2], + "operator": k, + "context": context, } # check on_pod_cleanup callback @@ -1559,6 +1568,8 @@ def test_execute_sync_callbacks(self, find_pod_mock): "client": k.client, "mode": ExecutionMode.SYNC, "pod": k.pod, + "operator": k, + "context": context, } @patch(HOOK_CLASS, new=MagicMock) @@ -1587,8 +1598,10 @@ def test_execute_async_callbacks(self): do_xcom_push=False, callbacks=MockKubernetesPodOperatorCallback, ) + context = create_context(k) + k.trigger_reentry( - context=create_context(k), + context=context, event={ "status": "success", "message": TEST_SUCCESS_MESSAGE, @@ -1603,6 +1616,8 @@ def test_execute_async_callbacks(self): "client": k.client, "mode": ExecutionMode.SYNC, "pod": remote_pod_mock, + "operator": k, + "context": context, } # check on_pod_cleanup callback @@ -1611,6 +1626,8 @@ def test_execute_async_callbacks(self): "client": k.client, "mode": ExecutionMode.SYNC, "pod": remote_pod_mock, + "operator": k, + "context": context, } @pytest.mark.parametrize("get_logs", [True, False]) @@ -1621,7 +1638,7 @@ def test_await_container_completion_refreshes_properties_on_exception( self, mock_read_pod, mock_await_container_completion, fetch_requested_container_logs, get_logs ): k = KubernetesPodOperator(task_id="task", get_logs=get_logs) - pod = self.run_pod(k) + pod, _ = self.run_pod(k) client, hook, pod_manager = k.client, k.hook, k.pod_manager # no exception doesn't update properties @@ -1654,7 +1671,7 @@ def test_await_container_completion_raises_unauthorized_if_credentials_still_inv self, mock_read_pod, mock_await_container_completion ): k = KubernetesPodOperator(task_id="task", get_logs=False) - pod = self.run_pod(k) + pod, _ = self.run_pod(k) client, hook, pod_manager = k.client, k.hook, k.pod_manager mock_await_container_completion.side_effect = [ApiException(status=401)] @@ -1687,7 +1704,7 @@ def test_await_container_completion_retries_on_specific_exception( task_id="task", get_logs=False, ) - pod = self.run_pod(k) + pod, _ = self.run_pod(k) mock_await_container_completion.side_effect = side_effect if expect_exc: k.await_pod_completion(pod) From 9847da9aa6d974ff95b2731e1eab7d1583cdae74 Mon Sep 17 00:00:00 2001 From: John Horan Date: Thu, 28 Nov 2024 16:55:25 +0000 Subject: [PATCH 06/12] only static callbacks allowed --- .../src/airflow/providers/cncf/kubernetes/operators/pod.py | 2 +- .../src/airflow/providers/cncf/kubernetes/utils/pod_manager.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py b/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py index 99051621dd82e..fbe67cff43a98 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py +++ b/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py @@ -315,7 +315,7 @@ def __init__( is_delete_operator_pod: None | bool = None, termination_message_policy: str = "File", active_deadline_seconds: int | None = None, - callbacks: list[KubernetesPodOperatorCallback] | KubernetesPodOperatorCallback | None = None, + callbacks: list[type[KubernetesPodOperatorCallback]] | type[KubernetesPodOperatorCallback] | None = None, progress_callback: Callable[[str], None] | None = None, logging_interval: int | None = None, **kwargs, diff --git a/providers/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py b/providers/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py index ab321d1e1db37..ef34e6ea21796 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py +++ b/providers/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py @@ -300,7 +300,7 @@ class PodManager(LoggingMixin): def __init__( self, kube_client: client.CoreV1Api, - callbacks: list[KubernetesPodOperatorCallback] | None = None, + callbacks: list[type[KubernetesPodOperatorCallback]] | None = None, ): """ Create the launcher. From b888ba2145d6dc33b7a46c41637baeba45644904 Mon Sep 17 00:00:00 2001 From: knasher Date: Thu, 28 Nov 2024 21:12:24 +0000 Subject: [PATCH 07/12] fix lint error --- .../src/airflow/providers/cncf/kubernetes/callbacks.py | 7 ++----- .../airflow/providers/cncf/kubernetes/operators/pod.py | 8 ++++++-- providers/tests/cncf/kubernetes/operators/test_pod.py | 5 ++++- providers/tests/cncf/kubernetes/test_callbacks.py | 2 +- 4 files changed, 13 insertions(+), 9 deletions(-) diff --git a/providers/src/airflow/providers/cncf/kubernetes/callbacks.py b/providers/src/airflow/providers/cncf/kubernetes/callbacks.py index 4eef4bce527fa..9d0e58680af9b 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/callbacks.py +++ b/providers/src/airflow/providers/cncf/kubernetes/callbacks.py @@ -17,14 +17,11 @@ from __future__ import annotations from enum import Enum -from typing import Union +from typing import TYPE_CHECKING, Union import kubernetes.client as k8s import kubernetes_asyncio.client as async_k8s - -from typing import TYPE_CHECKING - if TYPE_CHECKING: from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator from airflow.utils.context import Context @@ -136,7 +133,7 @@ def on_pod_wrapup( **kwargs, ) -> None: """ - Invoked after all pod completion callbacks but before the pod is deleted. + Invoke this callback after all pod completion callbacks but before the pod is deleted. :param pod: the completed pod. :param client: the Kubernetes client that can be used in the callback. diff --git a/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py b/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py index fbe67cff43a98..2fd4511a83dac 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py +++ b/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py @@ -315,7 +315,9 @@ def __init__( is_delete_operator_pod: None | bool = None, termination_message_policy: str = "File", active_deadline_seconds: int | None = None, - callbacks: list[type[KubernetesPodOperatorCallback]] | type[KubernetesPodOperatorCallback] | None = None, + callbacks: list[type[KubernetesPodOperatorCallback]] + | type[KubernetesPodOperatorCallback] + | None = None, progress_callback: Callable[[str], None] | None = None, logging_interval: int | None = None, **kwargs, @@ -387,7 +389,9 @@ def __init__( self.skip_on_exit_code = ( skip_on_exit_code if isinstance(skip_on_exit_code, Container) - else [skip_on_exit_code] if skip_on_exit_code is not None else [] + else [skip_on_exit_code] + if skip_on_exit_code is not None + else [] ) self.deferrable = deferrable self.poll_interval = poll_interval diff --git a/providers/tests/cncf/kubernetes/operators/test_pod.py b/providers/tests/cncf/kubernetes/operators/test_pod.py index bbc305a059446..29e9e53d9b57f 100644 --- a/providers/tests/cncf/kubernetes/operators/test_pod.py +++ b/providers/tests/cncf/kubernetes/operators/test_pod.py @@ -20,6 +20,7 @@ import re from contextlib import contextmanager, nullcontext from io import BytesIO +from typing import TYPE_CHECKING from unittest import mock from unittest.mock import MagicMock, mock_open, patch @@ -47,7 +48,9 @@ from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction, PodLoggingStatus, PodPhase from airflow.providers.cncf.kubernetes.utils.xcom_sidecar import PodDefaults from airflow.utils import timezone -from airflow.utils.context import Context + +if TYPE_CHECKING: + from airflow.utils.context import Context from airflow.utils.session import create_session from airflow.utils.types import DagRunType diff --git a/providers/tests/cncf/kubernetes/test_callbacks.py b/providers/tests/cncf/kubernetes/test_callbacks.py index 30f69751a37f1..a4faa21ceba83 100644 --- a/providers/tests/cncf/kubernetes/test_callbacks.py +++ b/providers/tests/cncf/kubernetes/test_callbacks.py @@ -70,4 +70,4 @@ def on_manifest_finalization(*args, **kwargs) -> None: @staticmethod def on_pod_wrapup(*args, **kwargs) -> None: - MockWrapper.mock_callbacks.on_pod_wrapup(*args, **kwargs) \ No newline at end of file + MockWrapper.mock_callbacks.on_pod_wrapup(*args, **kwargs) From bd5cfebed22ccad42765a7d962e918e3cb25b6bf Mon Sep 17 00:00:00 2001 From: John Horan Date: Mon, 16 Dec 2024 13:10:17 +0000 Subject: [PATCH 08/12] rename --- providers/src/airflow/providers/cncf/kubernetes/callbacks.py | 2 +- .../src/airflow/providers/cncf/kubernetes/operators/pod.py | 2 +- providers/tests/cncf/kubernetes/operators/test_pod.py | 4 ++-- providers/tests/cncf/kubernetes/test_callbacks.py | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/providers/src/airflow/providers/cncf/kubernetes/callbacks.py b/providers/src/airflow/providers/cncf/kubernetes/callbacks.py index 9d0e58680af9b..c6871eee69c09 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/callbacks.py +++ b/providers/src/airflow/providers/cncf/kubernetes/callbacks.py @@ -54,7 +54,7 @@ def on_sync_client_creation(*, client: k8s.CoreV1Api, operator: KubernetesPodOpe pass @staticmethod - def on_manifest_finalization( + def on_pod_manifest_created( *, pod_request: k8s.V1Pod, mode: str, operator: KubernetesPodOperator, context: Context, **kwargs ) -> None: """ diff --git a/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py b/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py index 6a913df2eb611..d673efee46eca 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py +++ b/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py @@ -588,7 +588,7 @@ def execute_sync(self, context: Context): if self.pod_request_obj is None: self.pod_request_obj = self.build_pod_request_obj(context) for callback in self.callbacks: - callback.on_manifest_finalization( + callback.on_pod_manifest_created( pod_request=self.pod_request_obj, mode=ExecutionMode.SYNC, context=context, operator=self ) if self.pod is None: diff --git a/providers/tests/cncf/kubernetes/operators/test_pod.py b/providers/tests/cncf/kubernetes/operators/test_pod.py index 155ff2638af93..879fed343cbd3 100644 --- a/providers/tests/cncf/kubernetes/operators/test_pod.py +++ b/providers/tests/cncf/kubernetes/operators/test_pod.py @@ -1523,8 +1523,8 @@ def test_execute_sync_callbacks(self, find_pod_mock): mock_callbacks.on_sync_client_creation.assert_called_once() assert mock_callbacks.on_sync_client_creation.call_args.kwargs == {"client": k.client, "operator": k} - # check on_manifest_finalization callback - mock_callbacks.on_manifest_finalization.assert_called_once() + # check on_pod_manifest_created callback + mock_callbacks.on_pod_manifest_created.assert_called_once() # check on_pod_creation callback mock_callbacks.on_pod_creation.assert_called_once() diff --git a/providers/tests/cncf/kubernetes/test_callbacks.py b/providers/tests/cncf/kubernetes/test_callbacks.py index a4faa21ceba83..1e7b541ba03a1 100644 --- a/providers/tests/cncf/kubernetes/test_callbacks.py +++ b/providers/tests/cncf/kubernetes/test_callbacks.py @@ -65,8 +65,8 @@ def progress_callback(*args, **kwargs) -> None: MockWrapper.mock_callbacks.progress_callback(*args, **kwargs) @staticmethod - def on_manifest_finalization(*args, **kwargs) -> None: - MockWrapper.mock_callbacks.on_manifest_finalization(*args, **kwargs) + def on_pod_manifest_created(*args, **kwargs) -> None: + MockWrapper.mock_callbacks.on_pod_manifest_created(*args, **kwargs) @staticmethod def on_pod_wrapup(*args, **kwargs) -> None: From 078bc94c480dd197c42540298b2b6638d7770089 Mon Sep 17 00:00:00 2001 From: John Horan Date: Mon, 16 Dec 2024 14:14:12 +0000 Subject: [PATCH 09/12] add test --- .../cncf/kubernetes/operators/test_pod.py | 87 +++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/providers/tests/cncf/kubernetes/operators/test_pod.py b/providers/tests/cncf/kubernetes/operators/test_pod.py index 879fed343cbd3..994ea89204637 100644 --- a/providers/tests/cncf/kubernetes/operators/test_pod.py +++ b/providers/tests/cncf/kubernetes/operators/test_pod.py @@ -1575,6 +1575,93 @@ def test_execute_sync_callbacks(self, find_pod_mock): "context": context, } + @patch(HOOK_CLASS, new=MagicMock) + @patch(KUB_OP_PATH.format("find_pod")) + def test_execute_sync_multiple_callbacks(self, find_pod_mock): + from airflow.providers.cncf.kubernetes.callbacks import ExecutionMode + + from providers.tests.cncf.kubernetes.test_callbacks import ( + MockKubernetesPodOperatorCallback, + MockWrapper, + ) + + MockWrapper.reset() + mock_callbacks = MockWrapper.mock_callbacks + found_pods = [MagicMock(), MagicMock(), MagicMock()] + find_pod_mock.side_effect = [None] + found_pods + + remote_pod_mock = MagicMock() + remote_pod_mock.status.phase = "Succeeded" + self.await_pod_mock.return_value = remote_pod_mock + k = KubernetesPodOperator( + namespace="default", + image="ubuntu:16.04", + cmds=["bash", "-cx"], + arguments=["echo 10"], + labels={"foo": "bar"}, + name="test", + task_id="task", + do_xcom_push=False, + callbacks=[MockKubernetesPodOperatorCallback, MockKubernetesPodOperatorCallback] + ) + _, context = self.run_pod(k) + + # check on_sync_client_creation callback + assert mock_callbacks.on_sync_client_creation.call_count == 2 + assert mock_callbacks.on_sync_client_creation.call_args.kwargs == {"client": k.client, "operator": k} + + # check on_pod_manifest_created callback + assert mock_callbacks.on_pod_manifest_created.call_count == 2 + + # check on_pod_creation callback + assert mock_callbacks.on_pod_creation.call_count == 2 + assert mock_callbacks.on_pod_creation.call_args.kwargs == { + "client": k.client, + "mode": ExecutionMode.SYNC, + "pod": found_pods[0], + "operator": k, + "context": context, + } + + # check on_pod_starting callback + assert mock_callbacks.on_pod_starting.call_count == 2 + assert mock_callbacks.on_pod_starting.call_args.kwargs == { + "client": k.client, + "mode": ExecutionMode.SYNC, + "pod": found_pods[1], + "operator": k, + "context": context, + } + + # check on_pod_completion callback + assert mock_callbacks.on_pod_completion.call_count == 2 + assert mock_callbacks.on_pod_completion.call_args.kwargs == { + "client": k.client, + "mode": ExecutionMode.SYNC, + "pod": found_pods[2], + "operator": k, + "context": context, + } + + assert mock_callbacks.on_pod_wrapup.call_count == 2 + assert mock_callbacks.on_pod_wrapup.call_args.kwargs == { + "client": k.client, + "mode": ExecutionMode.SYNC, + "pod": found_pods[2], + "operator": k, + "context": context, + } + + # check on_pod_cleanup callback + assert mock_callbacks.on_pod_cleanup.call_count == 2 + assert mock_callbacks.on_pod_cleanup.call_args.kwargs == { + "client": k.client, + "mode": ExecutionMode.SYNC, + "pod": k.pod, + "operator": k, + "context": context, + } + @patch(HOOK_CLASS, new=MagicMock) def test_execute_async_callbacks(self): from airflow.providers.cncf.kubernetes.callbacks import ExecutionMode From 3f00f0efe6d789cab458a91d74b56a2ad14be2b5 Mon Sep 17 00:00:00 2001 From: John Horan Date: Tue, 17 Dec 2024 11:04:31 +0000 Subject: [PATCH 10/12] rename method --- .../src/airflow/providers/cncf/kubernetes/callbacks.py | 2 +- .../airflow/providers/cncf/kubernetes/operators/pod.py | 2 +- providers/tests/cncf/kubernetes/operators/test_pod.py | 8 ++++---- providers/tests/cncf/kubernetes/test_callbacks.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/providers/src/airflow/providers/cncf/kubernetes/callbacks.py b/providers/src/airflow/providers/cncf/kubernetes/callbacks.py index c6871eee69c09..687ef1c3ccd4e 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/callbacks.py +++ b/providers/src/airflow/providers/cncf/kubernetes/callbacks.py @@ -123,7 +123,7 @@ def on_pod_completion( pass @staticmethod - def on_pod_wrapup( + def on_pod_teardown( *, pod: k8s.V1Pod, client: client_type, diff --git a/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py b/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py index d673efee46eca..be77cf2af3db9 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py +++ b/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py @@ -635,7 +635,7 @@ def execute_sync(self, context: Context): operator=self, ) for callback in self.callbacks: - callback.on_pod_wrapup( + callback.on_pod_teardown( pod=pod, client=self.client, mode=ExecutionMode.SYNC, diff --git a/providers/tests/cncf/kubernetes/operators/test_pod.py b/providers/tests/cncf/kubernetes/operators/test_pod.py index 994ea89204637..5810a1c8e74ac 100644 --- a/providers/tests/cncf/kubernetes/operators/test_pod.py +++ b/providers/tests/cncf/kubernetes/operators/test_pod.py @@ -1556,8 +1556,8 @@ def test_execute_sync_callbacks(self, find_pod_mock): "context": context, } - mock_callbacks.on_pod_wrapup.assert_called_once() - assert mock_callbacks.on_pod_wrapup.call_args.kwargs == { + mock_callbacks.on_pod_teardown.assert_called_once() + assert mock_callbacks.on_pod_teardown.call_args.kwargs == { "client": k.client, "mode": ExecutionMode.SYNC, "pod": found_pods[2], @@ -1643,8 +1643,8 @@ def test_execute_sync_multiple_callbacks(self, find_pod_mock): "context": context, } - assert mock_callbacks.on_pod_wrapup.call_count == 2 - assert mock_callbacks.on_pod_wrapup.call_args.kwargs == { + assert mock_callbacks.on_pod_teardown.call_count == 2 + assert mock_callbacks.on_pod_teardown.call_args.kwargs == { "client": k.client, "mode": ExecutionMode.SYNC, "pod": found_pods[2], diff --git a/providers/tests/cncf/kubernetes/test_callbacks.py b/providers/tests/cncf/kubernetes/test_callbacks.py index 1e7b541ba03a1..313fae8ed4ddf 100644 --- a/providers/tests/cncf/kubernetes/test_callbacks.py +++ b/providers/tests/cncf/kubernetes/test_callbacks.py @@ -69,5 +69,5 @@ def on_pod_manifest_created(*args, **kwargs) -> None: MockWrapper.mock_callbacks.on_pod_manifest_created(*args, **kwargs) @staticmethod - def on_pod_wrapup(*args, **kwargs) -> None: - MockWrapper.mock_callbacks.on_pod_wrapup(*args, **kwargs) + def on_pod_teardown(*args, **kwargs) -> None: + MockWrapper.mock_callbacks.on_pod_teardown(*args, **kwargs) From e650d6d66a17d0eb338d21d845b2f5cd67023564 Mon Sep 17 00:00:00 2001 From: John Horan Date: Sat, 21 Dec 2024 14:42:09 +0000 Subject: [PATCH 11/12] send client for consistency --- .../airflow/providers/cncf/kubernetes/callbacks.py | 9 ++++++++- .../providers/cncf/kubernetes/operators/pod.py | 12 ++++++++---- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/providers/src/airflow/providers/cncf/kubernetes/callbacks.py b/providers/src/airflow/providers/cncf/kubernetes/callbacks.py index 687ef1c3ccd4e..d87e8065dbd1a 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/callbacks.py +++ b/providers/src/airflow/providers/cncf/kubernetes/callbacks.py @@ -55,12 +55,19 @@ def on_sync_client_creation(*, client: k8s.CoreV1Api, operator: KubernetesPodOpe @staticmethod def on_pod_manifest_created( - *, pod_request: k8s.V1Pod, mode: str, operator: KubernetesPodOperator, context: Context, **kwargs + *, + pod_request: k8s.V1Pod, + client: client_type, + mode: str, + operator: KubernetesPodOperator, + context: Context, + **kwargs, ) -> None: """ Invoke this callback after KPO creates the V1Pod manifest but before the pod is created. :param pod_request: the kubernetes pod manifest + :param client: the Kubernetes client that can be used in the callback. :param mode: the current execution mode, it's one of (`sync`, `async`). """ pass diff --git a/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py b/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py index be77cf2af3db9..5e84ada7646dd 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py +++ b/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py @@ -315,9 +315,9 @@ def __init__( is_delete_operator_pod: None | bool = None, termination_message_policy: str = "File", active_deadline_seconds: int | None = None, - callbacks: list[type[KubernetesPodOperatorCallback]] - | type[KubernetesPodOperatorCallback] - | None = None, + callbacks: ( + list[type[KubernetesPodOperatorCallback]] | type[KubernetesPodOperatorCallback] | None + ) = None, progress_callback: Callable[[str], None] | None = None, logging_interval: int | None = None, **kwargs, @@ -589,7 +589,11 @@ def execute_sync(self, context: Context): self.pod_request_obj = self.build_pod_request_obj(context) for callback in self.callbacks: callback.on_pod_manifest_created( - pod_request=self.pod_request_obj, mode=ExecutionMode.SYNC, context=context, operator=self + pod_request=self.pod_request_obj, + client=self.client, + mode=ExecutionMode.SYNC, + context=context, + operator=self, ) if self.pod is None: self.pod = self.get_or_create_pod( # must set `self.pod` for `on_kill` From fb41d157dc60f63fe6a64221af48c4701676c258 Mon Sep 17 00:00:00 2001 From: Elad Kalif <45845474+eladkal@users.noreply.github.com> Date: Mon, 27 Jan 2025 04:27:07 +0200 Subject: [PATCH 12/12] fix static checks --- providers/tests/cncf/kubernetes/operators/test_pod.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/providers/tests/cncf/kubernetes/operators/test_pod.py b/providers/tests/cncf/kubernetes/operators/test_pod.py index 5810a1c8e74ac..851abc29db393 100644 --- a/providers/tests/cncf/kubernetes/operators/test_pod.py +++ b/providers/tests/cncf/kubernetes/operators/test_pod.py @@ -1602,7 +1602,7 @@ def test_execute_sync_multiple_callbacks(self, find_pod_mock): name="test", task_id="task", do_xcom_push=False, - callbacks=[MockKubernetesPodOperatorCallback, MockKubernetesPodOperatorCallback] + callbacks=[MockKubernetesPodOperatorCallback, MockKubernetesPodOperatorCallback], ) _, context = self.run_pod(k)