diff --git a/providers/src/airflow/providers/cncf/kubernetes/callbacks.py b/providers/src/airflow/providers/cncf/kubernetes/callbacks.py index eb971b4a30469..d87e8065dbd1a 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/callbacks.py +++ b/providers/src/airflow/providers/cncf/kubernetes/callbacks.py @@ -17,11 +17,15 @@ from __future__ import annotations from enum import Enum -from typing import Union +from typing import TYPE_CHECKING, Union import kubernetes.client as k8s import kubernetes_asyncio.client as async_k8s +if TYPE_CHECKING: + from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator + from airflow.utils.context import Context + client_type = Union[k8s.CoreV1Api, async_k8s.CoreV1Api] @@ -41,7 +45,7 @@ class KubernetesPodOperatorCallback: """ @staticmethod - def on_sync_client_creation(*, client: k8s.CoreV1Api, **kwargs) -> None: + def on_sync_client_creation(*, client: k8s.CoreV1Api, operator: KubernetesPodOperator, **kwargs) -> None: """ Invoke this callback after creating the sync client. @@ -50,7 +54,34 @@ def on_sync_client_creation(*, client: k8s.CoreV1Api, **kwargs) -> None: pass @staticmethod - def on_pod_creation(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) -> None: + def on_pod_manifest_created( + *, + pod_request: k8s.V1Pod, + client: client_type, + mode: str, + operator: KubernetesPodOperator, + context: Context, + **kwargs, + ) -> None: + """ + Invoke this callback after KPO creates the V1Pod manifest but before the pod is created. + + :param pod_request: the kubernetes pod manifest + :param client: the Kubernetes client that can be used in the callback. + :param mode: the current execution mode, it's one of (`sync`, `async`). + """ + pass + + @staticmethod + def on_pod_creation( + *, + pod: k8s.V1Pod, + client: client_type, + mode: str, + operator: KubernetesPodOperator, + context: Context, + **kwargs, + ) -> None: """ Invoke this callback after creating the pod. @@ -61,7 +92,15 @@ def on_pod_creation(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) pass @staticmethod - def on_pod_starting(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) -> None: + def on_pod_starting( + *, + pod: k8s.V1Pod, + client: client_type, + mode: str, + operator: KubernetesPodOperator, + context: Context, + **kwargs, + ) -> None: """ Invoke this callback when the pod starts. @@ -72,7 +111,15 @@ def on_pod_starting(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) pass @staticmethod - def on_pod_completion(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) -> None: + def on_pod_completion( + *, + pod: k8s.V1Pod, + client: client_type, + mode: str, + operator: KubernetesPodOperator, + context: Context, + **kwargs, + ) -> None: """ Invoke this callback when the pod completes. @@ -83,7 +130,34 @@ def on_pod_completion(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwarg pass @staticmethod - def on_pod_cleanup(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs): + def on_pod_teardown( + *, + pod: k8s.V1Pod, + client: client_type, + mode: str, + operator: KubernetesPodOperator, + context: Context, + **kwargs, + ) -> None: + """ + Invoke this callback after all pod completion callbacks but before the pod is deleted. + + :param pod: the completed pod. + :param client: the Kubernetes client that can be used in the callback. + :param mode: the current execution mode, it's one of (`sync`, `async`). + """ + pass + + @staticmethod + def on_pod_cleanup( + *, + pod: k8s.V1Pod, + client: client_type, + mode: str, + operator: KubernetesPodOperator, + context: Context, + **kwargs, + ): """ Invoke this callback after cleaning/deleting the pod. @@ -95,7 +169,14 @@ def on_pod_cleanup(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs): @staticmethod def on_operator_resuming( - *, pod: k8s.V1Pod, event: dict, client: client_type, mode: str, **kwargs + *, + pod: k8s.V1Pod, + event: dict, + client: client_type, + mode: str, + operator: KubernetesPodOperator, + context: Context, + **kwargs, ) -> None: """ Invoke this callback when resuming the `KubernetesPodOperator` from deferred state. diff --git a/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py b/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py index 7243afe07ab9a..4fe70f0233c98 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py +++ b/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py @@ -324,7 +324,9 @@ def __init__( is_delete_operator_pod: None | bool = None, termination_message_policy: str = "File", active_deadline_seconds: int | None = None, - callbacks: type[KubernetesPodOperatorCallback] | None = None, + callbacks: ( + list[type[KubernetesPodOperatorCallback]] | type[KubernetesPodOperatorCallback] | None + ) = None, progress_callback: Callable[[str], None] | None = None, logging_interval: int | None = None, **kwargs, @@ -415,7 +417,7 @@ def __init__( self._config_dict: dict | None = None # TODO: remove it when removing convert_config_file_to_dict self._progress_callback = progress_callback - self.callbacks = callbacks + self.callbacks = [] if not callbacks else callbacks if isinstance(callbacks, list) else [callbacks] self._killed: bool = False @cached_property @@ -519,8 +521,9 @@ def hook(self) -> PodOperatorHookProtocol: @cached_property def client(self) -> CoreV1Api: client = self.hook.core_v1_client - if self.callbacks: - self.callbacks.on_sync_client_creation(client=client) + + for callback in self.callbacks: + callback.on_sync_client_creation(client=client, operator=self) return client def find_pod(self, namespace: str, context: Context, *, exclude_checked: bool = True) -> k8s.V1Pod | None: @@ -594,6 +597,14 @@ def execute_sync(self, context: Context): try: if self.pod_request_obj is None: self.pod_request_obj = self.build_pod_request_obj(context) + for callback in self.callbacks: + callback.on_pod_manifest_created( + pod_request=self.pod_request_obj, + client=self.client, + mode=ExecutionMode.SYNC, + context=context, + operator=self, + ) if self.pod is None: self.pod = self.get_or_create_pod( # must set `self.pod` for `on_kill` pod_request_obj=self.pod_request_obj, @@ -606,28 +617,48 @@ def execute_sync(self, context: Context): # get remote pod for use in cleanup methods self.remote_pod = self.find_pod(self.pod.metadata.namespace, context=context) - if self.callbacks: - self.callbacks.on_pod_creation( - pod=self.remote_pod, client=self.client, mode=ExecutionMode.SYNC + for callback in self.callbacks: + callback.on_pod_creation( + pod=self.remote_pod, + client=self.client, + mode=ExecutionMode.SYNC, + context=context, + operator=self, ) self.await_init_containers_completion(pod=self.pod) self.await_pod_start(pod=self.pod) if self.callbacks: - self.callbacks.on_pod_starting( - pod=self.find_pod(self.pod.metadata.namespace, context=context), - client=self.client, - mode=ExecutionMode.SYNC, - ) + pod = self.find_pod(self.pod.metadata.namespace, context=context) + for callback in self.callbacks: + callback.on_pod_starting( + pod=pod, + client=self.client, + mode=ExecutionMode.SYNC, + context=context, + operator=self, + ) self.await_pod_completion(pod=self.pod) if self.callbacks: - self.callbacks.on_pod_completion( - pod=self.find_pod(self.pod.metadata.namespace, context=context), - client=self.client, - mode=ExecutionMode.SYNC, - ) + pod = self.find_pod(self.pod.metadata.namespace, context=context) + for callback in self.callbacks: + callback.on_pod_completion( + pod=pod, + client=self.client, + mode=ExecutionMode.SYNC, + context=context, + operator=self, + ) + for callback in self.callbacks: + callback.on_pod_teardown( + pod=pod, + client=self.client, + mode=ExecutionMode.SYNC, + context=context, + operator=self, + ) if self.do_xcom_push: self.pod_manager.await_xcom_sidecar_container_start(pod=self.pod) @@ -642,8 +673,14 @@ def execute_sync(self, context: Context): pod=pod_to_clean, remote_pod=self.remote_pod, ) - if self.callbacks: - self.callbacks.on_pod_cleanup(pod=pod_to_clean, client=self.client, mode=ExecutionMode.SYNC) + for callback in self.callbacks: + callback.on_pod_cleanup( + pod=pod_to_clean, + client=self.client, + mode=ExecutionMode.SYNC, + context=context, + operator=self, + ) if self.do_xcom_push: return result @@ -710,11 +747,15 @@ def execute_async(self, context: Context) -> None: context=context, ) if self.callbacks: - self.callbacks.on_pod_creation( - pod=self.find_pod(self.pod.metadata.namespace, context=context), - client=self.client, - mode=ExecutionMode.SYNC, - ) + pod = self.find_pod(self.pod.metadata.namespace, context=context) + for callback in self.callbacks: + callback.on_pod_creation( + pod=pod, + client=self.client, + mode=ExecutionMode.SYNC, + context=context, + operator=self, + ) ti = context["ti"] ti.xcom_push(key="pod_name", value=self.pod.metadata.name) ti.xcom_push(key="pod_namespace", value=self.pod.metadata.namespace) @@ -775,10 +816,16 @@ def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any: if not self.pod: raise PodNotFoundException("Could not find pod after resuming from deferral") - if self.callbacks and event["status"] != "running": - self.callbacks.on_operator_resuming( - pod=self.pod, event=event, client=self.client, mode=ExecutionMode.SYNC - ) + if event["status"] != "running": + for callback in self.callbacks: + callback.on_operator_resuming( + pod=self.pod, + event=event, + client=self.client, + mode=ExecutionMode.SYNC, + context=context, + operator=self, + ) follow = self.logging_interval is None last_log_time = event.get("last_log_time") @@ -821,9 +868,9 @@ def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any: except TaskDeferred: raise finally: - self._clean(event) + self._clean(event, context) - def _clean(self, event: dict[str, Any]) -> None: + def _clean(self, event: dict[str, Any], context: Context) -> None: if event["status"] == "running": return istio_enabled = self.is_istio_enabled(self.pod) @@ -846,6 +893,7 @@ def _clean(self, event: dict[str, Any]) -> None: self.post_complete_action( pod=self.pod, remote_pod=self.pod, + context=context, ) def _write_logs(self, pod: k8s.V1Pod, follow: bool = False, since_time: DateTime | None = None) -> None: @@ -875,14 +923,16 @@ def _write_logs(self, pod: k8s.V1Pod, follow: bool = False, since_time: DateTime e if not isinstance(e, ApiException) else e.reason, ) - def post_complete_action(self, *, pod, remote_pod, **kwargs) -> None: + def post_complete_action(self, *, pod, remote_pod, context: Context, **kwargs) -> None: """Actions that must be done after operator finishes logic of the deferrable_execution.""" self.cleanup( pod=pod, remote_pod=remote_pod, ) - if self.callbacks: - self.callbacks.on_pod_cleanup(pod=pod, client=self.client, mode=ExecutionMode.SYNC) + for callback in self.callbacks: + callback.on_pod_cleanup( + pod=pod, client=self.client, mode=ExecutionMode.SYNC, operator=self, context=context + ) def cleanup(self, pod: k8s.V1Pod, remote_pod: k8s.V1Pod): # Skip cleaning the pod in the following scenarios. diff --git a/providers/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py b/providers/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py index a6ec3cd694d46..199d6a6d35dd3 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py +++ b/providers/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py @@ -320,7 +320,7 @@ class PodManager(LoggingMixin): def __init__( self, kube_client: client.CoreV1Api, - callbacks: type[KubernetesPodOperatorCallback] | None = None, + callbacks: list[type[KubernetesPodOperatorCallback]] | None = None, ): """ Create the launcher. @@ -331,7 +331,7 @@ def __init__( super().__init__() self._client = kube_client self._watch = watch.Watch() - self._callbacks = callbacks + self._callbacks = callbacks or [] def run_pod_async(self, pod: V1Pod, **kwargs) -> V1Pod: """Run POD asynchronously.""" @@ -466,8 +466,8 @@ def consume_logs(*, since_time: DateTime | None = None) -> tuple[DateTime | None progress_callback_lines.append(line) else: # previous log line is complete for line in progress_callback_lines: - if self._callbacks: - self._callbacks.progress_callback( + for callback in self._callbacks: + callback.progress_callback( line=line, client=self._client, mode=ExecutionMode.SYNC ) if message_to_log is not None: @@ -485,8 +485,8 @@ def consume_logs(*, since_time: DateTime | None = None) -> tuple[DateTime | None finally: # log the last line and update the last_captured_timestamp for line in progress_callback_lines: - if self._callbacks: - self._callbacks.progress_callback( + for callback in self._callbacks: + callback.progress_callback( line=line, client=self._client, mode=ExecutionMode.SYNC ) if message_to_log is not None: diff --git a/providers/tests/cncf/kubernetes/operators/test_pod.py b/providers/tests/cncf/kubernetes/operators/test_pod.py index 5cc4d2c7ba1fd..851abc29db393 100644 --- a/providers/tests/cncf/kubernetes/operators/test_pod.py +++ b/providers/tests/cncf/kubernetes/operators/test_pod.py @@ -20,6 +20,7 @@ import re from contextlib import contextmanager, nullcontext from io import BytesIO +from typing import TYPE_CHECKING from unittest import mock from unittest.mock import MagicMock, mock_open, patch @@ -47,6 +48,9 @@ from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction, PodLoggingStatus, PodPhase from airflow.providers.cncf.kubernetes.utils.xcom_sidecar import PodDefaults from airflow.utils import timezone + +if TYPE_CHECKING: + from airflow.utils.context import Context from airflow.utils.session import create_session from airflow.utils.types import DagRunType @@ -196,7 +200,7 @@ def test_templates(self, create_task_instance_of_operator, session): assert dag_id == rendered.volumes[0].name assert dag_id == rendered.volumes[0].config_map.name - def run_pod(self, operator: KubernetesPodOperator, map_index: int = -1) -> k8s.V1Pod: + def run_pod(self, operator: KubernetesPodOperator, map_index: int = -1) -> tuple[k8s.V1Pod, Context]: with self.dag_maker(dag_id="dag") as dag: operator.dag = dag @@ -211,7 +215,7 @@ def run_pod(self, operator: KubernetesPodOperator, map_index: int = -1) -> k8s.V remote_pod_mock.status.phase = "Succeeded" self.await_pod_mock.return_value = remote_pod_mock operator.execute(context=context) - return self.await_start_mock.call_args.kwargs["pod"] + return self.await_start_mock.call_args.kwargs["pod"], context def sanitize_for_serialization(self, obj): return ApiClient().sanitize_for_serialization(obj) @@ -357,7 +361,7 @@ def test_labels(self, hook_mock, in_cluster): in_cluster=in_cluster, do_xcom_push=False, ) - pod = self.run_pod(k) + pod, _ = self.run_pod(k) assert pod.metadata.labels == { "foo": "bar", "dag_id": "dag", @@ -1247,7 +1251,7 @@ def test_push_xcom_pod_info( do_xcom_push=do_xcom_push, ) - pod = self.run_pod(k) + pod, _ = self.run_pod(k) pod_name = XCom.get_one(run_id=self.dag_run.run_id, task_id="task", key="pod_name") pod_namespace = XCom.get_one(run_id=self.dag_run.run_id, task_id="task", key="pod_namespace") assert pod_name == pod.metadata.name @@ -1475,7 +1479,7 @@ def test_get_logs_but_not_for_base_container( remote_pod_mock = MagicMock() remote_pod_mock.status.phase = "Succeeded" self.await_pod_mock.return_value = remote_pod_mock - pod = self.run_pod(k) + pod, _ = self.run_pod(k) # check that the base container is not included in the logs mock_fetch_log.assert_called_once_with(pod=pod, containers=["some_init_container"], follow_logs=True) @@ -1513,11 +1517,14 @@ def test_execute_sync_callbacks(self, find_pod_mock): do_xcom_push=False, callbacks=MockKubernetesPodOperatorCallback, ) - self.run_pod(k) + _, context = self.run_pod(k) # check on_sync_client_creation callback mock_callbacks.on_sync_client_creation.assert_called_once() - assert mock_callbacks.on_sync_client_creation.call_args.kwargs == {"client": k.client} + assert mock_callbacks.on_sync_client_creation.call_args.kwargs == {"client": k.client, "operator": k} + + # check on_pod_manifest_created callback + mock_callbacks.on_pod_manifest_created.assert_called_once() # check on_pod_creation callback mock_callbacks.on_pod_creation.assert_called_once() @@ -1525,6 +1532,8 @@ def test_execute_sync_callbacks(self, find_pod_mock): "client": k.client, "mode": ExecutionMode.SYNC, "pod": found_pods[0], + "operator": k, + "context": context, } # check on_pod_starting callback @@ -1533,6 +1542,8 @@ def test_execute_sync_callbacks(self, find_pod_mock): "client": k.client, "mode": ExecutionMode.SYNC, "pod": found_pods[1], + "operator": k, + "context": context, } # check on_pod_completion callback @@ -1541,6 +1552,17 @@ def test_execute_sync_callbacks(self, find_pod_mock): "client": k.client, "mode": ExecutionMode.SYNC, "pod": found_pods[2], + "operator": k, + "context": context, + } + + mock_callbacks.on_pod_teardown.assert_called_once() + assert mock_callbacks.on_pod_teardown.call_args.kwargs == { + "client": k.client, + "mode": ExecutionMode.SYNC, + "pod": found_pods[2], + "operator": k, + "context": context, } # check on_pod_cleanup callback @@ -1549,6 +1571,95 @@ def test_execute_sync_callbacks(self, find_pod_mock): "client": k.client, "mode": ExecutionMode.SYNC, "pod": k.pod, + "operator": k, + "context": context, + } + + @patch(HOOK_CLASS, new=MagicMock) + @patch(KUB_OP_PATH.format("find_pod")) + def test_execute_sync_multiple_callbacks(self, find_pod_mock): + from airflow.providers.cncf.kubernetes.callbacks import ExecutionMode + + from providers.tests.cncf.kubernetes.test_callbacks import ( + MockKubernetesPodOperatorCallback, + MockWrapper, + ) + + MockWrapper.reset() + mock_callbacks = MockWrapper.mock_callbacks + found_pods = [MagicMock(), MagicMock(), MagicMock()] + find_pod_mock.side_effect = [None] + found_pods + + remote_pod_mock = MagicMock() + remote_pod_mock.status.phase = "Succeeded" + self.await_pod_mock.return_value = remote_pod_mock + k = KubernetesPodOperator( + namespace="default", + image="ubuntu:16.04", + cmds=["bash", "-cx"], + arguments=["echo 10"], + labels={"foo": "bar"}, + name="test", + task_id="task", + do_xcom_push=False, + callbacks=[MockKubernetesPodOperatorCallback, MockKubernetesPodOperatorCallback], + ) + _, context = self.run_pod(k) + + # check on_sync_client_creation callback + assert mock_callbacks.on_sync_client_creation.call_count == 2 + assert mock_callbacks.on_sync_client_creation.call_args.kwargs == {"client": k.client, "operator": k} + + # check on_pod_manifest_created callback + assert mock_callbacks.on_pod_manifest_created.call_count == 2 + + # check on_pod_creation callback + assert mock_callbacks.on_pod_creation.call_count == 2 + assert mock_callbacks.on_pod_creation.call_args.kwargs == { + "client": k.client, + "mode": ExecutionMode.SYNC, + "pod": found_pods[0], + "operator": k, + "context": context, + } + + # check on_pod_starting callback + assert mock_callbacks.on_pod_starting.call_count == 2 + assert mock_callbacks.on_pod_starting.call_args.kwargs == { + "client": k.client, + "mode": ExecutionMode.SYNC, + "pod": found_pods[1], + "operator": k, + "context": context, + } + + # check on_pod_completion callback + assert mock_callbacks.on_pod_completion.call_count == 2 + assert mock_callbacks.on_pod_completion.call_args.kwargs == { + "client": k.client, + "mode": ExecutionMode.SYNC, + "pod": found_pods[2], + "operator": k, + "context": context, + } + + assert mock_callbacks.on_pod_teardown.call_count == 2 + assert mock_callbacks.on_pod_teardown.call_args.kwargs == { + "client": k.client, + "mode": ExecutionMode.SYNC, + "pod": found_pods[2], + "operator": k, + "context": context, + } + + # check on_pod_cleanup callback + assert mock_callbacks.on_pod_cleanup.call_count == 2 + assert mock_callbacks.on_pod_cleanup.call_args.kwargs == { + "client": k.client, + "mode": ExecutionMode.SYNC, + "pod": k.pod, + "operator": k, + "context": context, } @patch(HOOK_CLASS, new=MagicMock) @@ -1577,8 +1688,10 @@ def test_execute_async_callbacks(self): do_xcom_push=False, callbacks=MockKubernetesPodOperatorCallback, ) + context = create_context(k) + k.trigger_reentry( - context=create_context(k), + context=context, event={ "status": "success", "message": TEST_SUCCESS_MESSAGE, @@ -1593,6 +1706,8 @@ def test_execute_async_callbacks(self): "client": k.client, "mode": ExecutionMode.SYNC, "pod": remote_pod_mock, + "operator": k, + "context": context, } # check on_pod_cleanup callback @@ -1601,6 +1716,8 @@ def test_execute_async_callbacks(self): "client": k.client, "mode": ExecutionMode.SYNC, "pod": remote_pod_mock, + "operator": k, + "context": context, } @pytest.mark.parametrize("get_logs", [True, False]) @@ -1611,7 +1728,7 @@ def test_await_container_completion_refreshes_properties_on_exception( self, mock_read_pod, mock_await_container_completion, fetch_requested_container_logs, get_logs ): k = KubernetesPodOperator(task_id="task", get_logs=get_logs) - pod = self.run_pod(k) + pod, _ = self.run_pod(k) client, hook, pod_manager = k.client, k.hook, k.pod_manager # no exception doesn't update properties @@ -1644,7 +1761,7 @@ def test_await_container_completion_raises_unauthorized_if_credentials_still_inv self, mock_read_pod, mock_await_container_completion ): k = KubernetesPodOperator(task_id="task", get_logs=False) - pod = self.run_pod(k) + pod, _ = self.run_pod(k) client, hook, pod_manager = k.client, k.hook, k.pod_manager mock_await_container_completion.side_effect = [ApiException(status=401)] @@ -1677,7 +1794,7 @@ def test_await_container_completion_retries_on_specific_exception( task_id="task", get_logs=False, ) - pod = self.run_pod(k) + pod, _ = self.run_pod(k) mock_await_container_completion.side_effect = side_effect if expect_exc: k.await_pod_completion(pod) diff --git a/providers/tests/cncf/kubernetes/test_callbacks.py b/providers/tests/cncf/kubernetes/test_callbacks.py index 2757b8296a7d0..313fae8ed4ddf 100644 --- a/providers/tests/cncf/kubernetes/test_callbacks.py +++ b/providers/tests/cncf/kubernetes/test_callbacks.py @@ -63,3 +63,11 @@ def on_operator_resuming(*args, **kwargs) -> None: @staticmethod def progress_callback(*args, **kwargs) -> None: MockWrapper.mock_callbacks.progress_callback(*args, **kwargs) + + @staticmethod + def on_pod_manifest_created(*args, **kwargs) -> None: + MockWrapper.mock_callbacks.on_pod_manifest_created(*args, **kwargs) + + @staticmethod + def on_pod_teardown(*args, **kwargs) -> None: + MockWrapper.mock_callbacks.on_pod_teardown(*args, **kwargs) diff --git a/providers/tests/cncf/kubernetes/utils/test_pod_manager.py b/providers/tests/cncf/kubernetes/utils/test_pod_manager.py index a3d937f095add..24ea794c2bffd 100644 --- a/providers/tests/cncf/kubernetes/utils/test_pod_manager.py +++ b/providers/tests/cncf/kubernetes/utils/test_pod_manager.py @@ -52,7 +52,7 @@ def setup_method(self): self.mock_kube_client = mock.Mock() self.pod_manager = PodManager( kube_client=self.mock_kube_client, - callbacks=MockKubernetesPodOperatorCallback, + callbacks=[MockKubernetesPodOperatorCallback], ) def test_read_pod_logs_successfully_returns_logs(self):