diff --git a/chart/templates/rbac/pod-launcher-role.yaml b/chart/templates/rbac/pod-launcher-role.yaml index 454c1d5f31bf2..c6f3a54d19fba 100644 --- a/chart/templates/rbac/pod-launcher-role.yaml +++ b/chart/templates/rbac/pod-launcher-role.yaml @@ -76,4 +76,5 @@ rules: - "events" verbs: - "list" + - "watch" {{- end }} diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py index 0e105d241b18d..bcaa9d40b9317 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py @@ -20,7 +20,6 @@ import contextlib import json import tempfile -from collections.abc import Generator from functools import cached_property from time import sleep from typing import TYPE_CHECKING, Any, Protocol @@ -31,7 +30,7 @@ from kubernetes import client, config, utils, watch from kubernetes.client.models import V1Deployment from kubernetes.config import ConfigException -from kubernetes_asyncio import client as async_client, config as async_config +from kubernetes_asyncio import client as async_client, config as async_config, watch as async_watch from urllib3.exceptions import HTTPError from airflow.exceptions import AirflowException, AirflowNotFoundException @@ -47,8 +46,10 @@ from airflow.utils import yaml if TYPE_CHECKING: + from collections.abc import AsyncGenerator, Generator + from kubernetes.client import V1JobList - from kubernetes.client.models import CoreV1EventList, V1Job, V1Pod + from kubernetes.client.models import CoreV1Event, CoreV1EventList, V1Job, V1Pod LOADING_KUBE_CONFIG_FILE_RESOURCE = "Loading Kubernetes configuration file kube_config from {}..." @@ -783,6 +784,7 @@ def __init__(self, config_dict: dict | None = None, *args, **kwargs): self.config_dict = config_dict self._extras: dict | None = None + self._event_polling_fallback = False async def _load_config(self): """Return Kubernetes API session for use with requests.""" @@ -954,14 +956,24 @@ async def read_logs( raise KubernetesApiError from e @generic_api_retry - async def get_pod_events(self, name: str, namespace: str) -> CoreV1EventList: - """Get pod's events.""" + async def get_pod_events( + self, name: str, namespace: str, resource_version: str | None = None + ) -> CoreV1EventList: + """ + Get pod events. + + :param name: Pod name to get events for + :param namespace: Kubernetes namespace + :param resource_version: Only return events not older than this resource version + """ async with self.get_conn() as connection: try: v1_api = async_client.CoreV1Api(connection) events: CoreV1EventList = await v1_api.list_namespaced_event( field_selector=f"involvedObject.name={name}", namespace=namespace, + resource_version=resource_version, + resource_version_match="NotOlderThan" if resource_version else None, ) return events except HTTPError as e: @@ -969,6 +981,80 @@ async def get_pod_events(self, name: str, namespace: str) -> CoreV1EventList: raise KubernetesApiPermissionError("Permission denied (403) from Kubernetes API.") from e raise KubernetesApiError from e + @generic_api_retry + async def watch_pod_events( + self, + name: str, + namespace: str, + resource_version: str | None = None, + timeout_seconds: int = 30, + ) -> AsyncGenerator[CoreV1Event]: + """ + Watch pod events using Kubernetes Watch API. 
+ + :param name: Pod name to watch events for + :param namespace: Kubernetes namespace + :param resource_version: Only return events not older than this resource version + :param timeout_seconds: Timeout in seconds for the watch stream + """ + if self._event_polling_fallback: + async for event_polled in self.watch_pod_events_polling_fallback( + name, namespace, resource_version, timeout_seconds + ): + yield event_polled + + try: + w = async_watch.Watch() + async with self.get_conn() as connection: + v1_api = async_client.CoreV1Api(connection) + + async for event_watched in w.stream( + v1_api.list_namespaced_event, + namespace=namespace, + field_selector=f"involvedObject.name={name}", + resource_version=resource_version, + timeout_seconds=timeout_seconds, + ): + event: CoreV1Event = event_watched.get("object") + yield event + + except async_client.exceptions.ApiException as e: + if hasattr(e, "status") and e.status == 403: + self.log.warning( + "Triggerer does not have Kubernetes API permission to 'watch' events: %s Falling back to polling.", + str(e), + ) + self._event_polling_fallback = True + async for event_polled in self.watch_pod_events_polling_fallback( + name, namespace, resource_version, timeout_seconds + ): + yield event_polled + + finally: + w.stop() + + async def watch_pod_events_polling_fallback( + self, + name: str, + namespace: str, + resource_version: str | None = None, + interval: int = 30, + ) -> AsyncGenerator[CoreV1Event]: + """ + Fallback method to poll pod event at regular intervals. + + This is required when the Airflow triggerer does not have permission to watch events. + + :param name: Pod name to watch events for + :param namespace: Kubernetes namespace + :param resource_version: Only return events not older than this resource version + :param interval: Polling interval in seconds + """ + events: CoreV1EventList = await self.get_pod_events(name, namespace, resource_version) + for event in events.items: + yield event + await asyncio.sleep(interval) + @generic_api_retry async def get_job_status(self, name: str, namespace: str) -> V1Job: """ diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py index 08182a32d5dbe..ba13bbd9a3881 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py @@ -630,14 +630,26 @@ def await_pod_start(self, pod: k8s.V1Pod) -> None: try: async def _await_pod_start(): - events_task = self.pod_manager.watch_pod_events(pod, self.startup_check_interval_seconds) - pod_start_task = self.pod_manager.await_pod_start( - pod=pod, - schedule_timeout=self.schedule_timeout_seconds, - startup_timeout=self.startup_timeout_seconds, - check_interval=self.startup_check_interval_seconds, + # Start event stream in background + events_task = asyncio.create_task( + self.pod_manager.watch_pod_events(pod, self.startup_check_interval_seconds) ) - await asyncio.gather(pod_start_task, events_task) + + # Await pod start completion + try: + await self.pod_manager.await_pod_start( + pod=pod, + schedule_timeout=self.schedule_timeout_seconds, + startup_timeout=self.startup_timeout_seconds, + check_interval=self.startup_check_interval_seconds, + ) + finally: + # Stop watching events + events_task.cancel() + try: + await events_task + except asyncio.CancelledError: + pass asyncio.run(_await_pod_start()) except PodLaunchFailedException: 
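
The operator hunk above swaps `asyncio.gather` for a background task that is explicitly cancelled once `await_pod_start` returns or raises. Below is a minimal, self-contained sketch of that pattern; the `watch_events` and `wait_for_start` coroutines are illustrative stand-ins, not provider APIs.

```python
# Minimal sketch of the background-watch pattern used in the operator/trigger:
# start a long-running watcher with asyncio.create_task(), await the main
# coroutine, and cancel the watcher in a finally block.
import asyncio


async def watch_events() -> None:
    # Stand-in for PodManager.watch_pod_events: runs until cancelled.
    while True:
        print("polling for events...")
        await asyncio.sleep(1)


async def wait_for_start() -> None:
    # Stand-in for PodManager.await_pod_start.
    await asyncio.sleep(3)
    print("pod started")


async def main() -> None:
    events_task = asyncio.create_task(watch_events())
    try:
        await wait_for_start()
    finally:
        # Stop watching events whether startup succeeded or raised.
        events_task.cancel()
        try:
            await events_task
        except asyncio.CancelledError:
            pass


if __name__ == "__main__":
    asyncio.run(main())
```

Cancelling in the `finally` block keeps the event watcher from outliving pod startup even when a launch timeout or failure exception propagates.
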
diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py index 29c1a382a0bf0..5aa496a772aa4 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py @@ -241,14 +241,25 @@ def _format_exception_description(self, exc: Exception) -> Any: async def _wait_for_pod_start(self) -> ContainerState: """Loops until pod phase leaves ``PENDING`` If timeout is reached, throws error.""" pod = await self._get_pod() - events_task = self.pod_manager.watch_pod_events(pod, self.startup_check_interval) - pod_start_task = self.pod_manager.await_pod_start( - pod=pod, - schedule_timeout=self.schedule_timeout, - startup_timeout=self.startup_timeout, - check_interval=self.startup_check_interval, - ) - await asyncio.gather(pod_start_task, events_task) + # Start event stream in background + events_task = asyncio.create_task(self.pod_manager.watch_pod_events(pod, self.startup_check_interval)) + + # Await pod start completion + try: + await self.pod_manager.await_pod_start( + pod=pod, + schedule_timeout=self.schedule_timeout, + startup_timeout=self.startup_timeout, + check_interval=self.startup_check_interval, + ) + finally: + # Stop watching events + events_task.cancel() + try: + await events_task + except asyncio.CancelledError: + pass + return self.define_container_state(await self._get_pod()) async def _wait_for_container_completion(self) -> TriggerEvent: diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py index 725ce6e48bca7..0e705a92617ce 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py @@ -56,6 +56,7 @@ from airflow.utils.timezone import utcnow if TYPE_CHECKING: + from kubernetes.client.models.core_v1_event import CoreV1Event from kubernetes.client.models.core_v1_event_list import CoreV1EventList from kubernetes.client.models.v1_container_state import V1ContainerState from kubernetes.client.models.v1_container_state_waiting import V1ContainerStateWaiting @@ -94,34 +95,21 @@ def check_exception_is_kubernetes_api_unauthorized(exc: BaseException): return isinstance(exc, ApiException) and exc.status and str(exc.status) == "401" -async def watch_pod_events( - pod_manager: PodManager | AsyncPodManager, - pod: V1Pod, - check_interval: float = 1, +def log_pod_event( + pod_manager: PodManager | AsyncPodManager, event: CoreV1Event, seen_events: set[str] ) -> None: """ - Read pod events and write them to the log. + Log a pod event if not already seen. - This function supports both asynchronous and synchronous pod managers. - - :param pod_manager: The pod manager instance (PodManager or AsyncPodManager). - :param pod: The pod object to monitor. - :param check_interval: Interval (in seconds) between checks. 
+ :param pod_manager: The pod manager instance for logging + :param event: Kubernetes event + :param seen_events: Set of event UIDs already logged to avoid duplicates """ - num_events = 0 - is_async = isinstance(pod_manager, AsyncPodManager) - while not pod_manager.stop_watching_events: - if is_async: - events = await pod_manager.read_pod_events(pod) - else: - events = pod_manager.read_pod_events(pod) - for new_event in events.items[num_events:]: - involved_object: V1ObjectReference = new_event.involved_object - pod_manager.log.info( - "The Pod has an Event: %s from %s", new_event.message, involved_object.field_path - ) - num_events = len(events.items) - await asyncio.sleep(check_interval) + event_uid = event.metadata.uid + if event_uid not in seen_events: + seen_events.add(event_uid) + involved_object: V1ObjectReference = event.involved_object + pod_manager.log.info("The Pod has an Event: %s from %s", event.message, involved_object.field_path) async def await_pod_start( @@ -170,12 +158,14 @@ async def await_pod_start( pod_manager.log.info("Waiting %ss to get the POD running...", startup_timeout) if time.time() - start_check_time >= startup_timeout: + pod_manager.stop_watching_events = True pod_manager.log.info("::endgroup::") raise PodLaunchTimeoutException( f"Pod took too long to start. More than {startup_timeout}s. Check the pod events in kubernetes." ) else: if time.time() - start_check_time >= schedule_timeout: + pod_manager.stop_watching_events = True pod_manager.log.info("::endgroup::") raise PodLaunchTimeoutException( f"Pod took too long to be scheduled on the cluster, giving up. More than {schedule_timeout}s. Check the pod events in kubernetes." @@ -188,6 +178,7 @@ async def await_pod_start( container_waiting: V1ContainerStateWaiting | None = container_state.waiting if container_waiting: if container_waiting.reason in ["ErrImagePull", "InvalidImageName"]: + pod_manager.stop_watching_events = True pod_manager.log.info("::endgroup::") raise PodLaunchFailedException( f"Pod docker image cannot be pulled, unable to start: {container_waiting.reason}" @@ -354,9 +345,16 @@ def create_pod(self, pod: V1Pod) -> V1Pod: """Launch the pod asynchronously.""" return self.run_pod_async(pod) - async def watch_pod_events(self, pod: V1Pod, check_interval: int = 1) -> None: - """Read pod events and writes into log.""" - await watch_pod_events(pod_manager=self, pod=pod, check_interval=check_interval) + async def watch_pod_events(self, pod: V1Pod, check_interval: float = 10) -> None: + """Read pod events and write into log.""" + resource_version = None + seen_events: set[str] = set() + while not self.stop_watching_events: + events = self.read_pod_events(pod, resource_version) + for event in events.items: + log_pod_event(self, event, seen_events) + resource_version = event.metadata.resource_version + await asyncio.sleep(check_interval) async def await_pod_start( self, pod: V1Pod, schedule_timeout: int = 120, startup_timeout: int = 120, check_interval: int = 1 @@ -772,11 +770,20 @@ def get_container_names(self, pod: V1Pod) -> list[str]: ] @generic_api_retry - def read_pod_events(self, pod: V1Pod) -> CoreV1EventList: - """Read events from the POD.""" + def read_pod_events(self, pod: V1Pod, resource_version: str | None = None) -> CoreV1EventList: + """ + Read events from the POD with optimization parameters to reduce API load. 
+ + :param pod: The pod to get events for + :param resource_version: Only return events newer than this resource version + :param limit: Maximum number of events to return + """ try: return self._client.list_namespaced_event( - namespace=pod.metadata.namespace, field_selector=f"involvedObject.name={pod.metadata.name}" + namespace=pod.metadata.namespace, + field_selector=f"involvedObject.name={pod.metadata.name}", + resource_version=resource_version, + resource_version_match="NotOlderThan" if resource_version else None, ) except HTTPError as e: raise KubernetesApiException(f"There was an error reading the kubernetes API: {e}") @@ -978,16 +985,28 @@ async def read_pod(self, pod: V1Pod) -> V1Pod: pod.metadata.namespace, ) - async def read_pod_events(self, pod: V1Pod) -> CoreV1EventList: + async def read_pod_events(self, pod: V1Pod, resource_version: str | None = None) -> CoreV1EventList: """Get pod's events.""" return await self._hook.get_pod_events( pod.metadata.name, pod.metadata.namespace, + resource_version=resource_version, ) - async def watch_pod_events(self, pod: V1Pod, check_interval: float = 1) -> None: - """Read pod events and writes into log.""" - await watch_pod_events(pod_manager=self, pod=pod, check_interval=check_interval) + async def watch_pod_events(self, pod: V1Pod, startup_check_interval: float = 30) -> None: + """Watch pod events and write to log.""" + seen_events: set[str] = set() + resource_version = None + while not self.stop_watching_events: + async for event in self._hook.watch_pod_events( + name=pod.metadata.name, + namespace=pod.metadata.namespace, + resource_version=resource_version, + timeout_seconds=startup_check_interval, + ): + if event: + log_pod_event(self, event, seen_events) + resource_version = event.metadata.resource_version async def await_pod_start( self, pod: V1Pod, schedule_timeout: int = 120, startup_timeout: int = 120, check_interval: float = 1 diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/hooks/test_kubernetes.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/hooks/test_kubernetes.py index 6d13f5eb38f21..f587ec0803511 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/hooks/test_kubernetes.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/hooks/test_kubernetes.py @@ -32,6 +32,7 @@ from kubernetes.client import V1Deployment, V1DeploymentStatus from kubernetes.client.rest import ApiException from kubernetes.config import ConfigException +from kubernetes_asyncio import client as async_client from airflow.exceptions import AirflowException, AirflowNotFoundException from airflow.models import Connection @@ -1011,6 +1012,216 @@ async def test_load_config_with_several_params(self, sdk_connection_not_found): with pytest.raises(AirflowException): await hook._load_config() + @pytest.mark.asyncio + @mock.patch(KUBE_API.format("list_namespaced_event")) + async def test_async_get_pod_events_with_resource_version( + self, mock_list_namespaced_event, kube_config_loader + ): + """Test getting pod events with resource_version parameter.""" + mock_event = mock.Mock() + mock_event.metadata.name = "test-event" + mock_events = mock.Mock() + mock_events.items = [mock_event] + mock_list_namespaced_event.return_value = self.mock_await_result(mock_events) + + hook = AsyncKubernetesHook( + conn_id=None, + in_cluster=False, + config_file=None, + cluster_context=None, + ) + + result = await hook.get_pod_events(name=POD_NAME, namespace=NAMESPACE, resource_version="12345") + + 
mock_list_namespaced_event.assert_called_once_with( + field_selector=f"involvedObject.name={POD_NAME}", + namespace=NAMESPACE, + resource_version="12345", + resource_version_match="NotOlderThan", + ) + assert result == mock_events + + @pytest.mark.asyncio + @mock.patch(KUBE_API.format("list_namespaced_event")) + async def test_async_get_pod_events_without_resource_version( + self, mock_list_namespaced_event, kube_config_loader + ): + """Test getting pod events without resource_version parameter.""" + mock_event = mock.Mock() + mock_event.metadata.name = "test-event" + mock_events = mock.Mock() + mock_events.items = [mock_event] + mock_list_namespaced_event.return_value = self.mock_await_result(mock_events) + + hook = AsyncKubernetesHook( + conn_id=None, + in_cluster=False, + config_file=None, + cluster_context=None, + ) + + result = await hook.get_pod_events(name=POD_NAME, namespace=NAMESPACE) + + mock_list_namespaced_event.assert_called_once_with( + field_selector=f"involvedObject.name={POD_NAME}", + namespace=NAMESPACE, + resource_version=None, + resource_version_match=None, + ) + assert result == mock_events + + @pytest.mark.asyncio + @mock.patch("kubernetes_asyncio.watch.Watch") + @mock.patch(KUBE_API.format("list_namespaced_event")) + async def test_async_watch_pod_events( + self, mock_list_namespaced_event, mock_watch_class, kube_config_loader + ): + """Test watching pod events using Watch API.""" + mock_event1 = mock.Mock() + mock_event1.metadata.uid = "event-1" + mock_event2 = mock.Mock() + mock_event2.metadata.uid = "event-2" + + async def async_generator(*_, **__): + yield {"object": mock_event1} + yield {"object": mock_event2} + + mock_watch = mock.Mock() + mock_watch_class.return_value = mock_watch + mock_watch.stream = mock.Mock(side_effect=async_generator) + + hook = AsyncKubernetesHook( + conn_id=None, + in_cluster=False, + config_file=None, + cluster_context=None, + ) + + events = [] + async for event in hook.watch_pod_events( + name=POD_NAME, namespace=NAMESPACE, resource_version="12345", timeout_seconds=30 + ): + events.append(event) + + assert len(events) == 2 + assert events[0] == mock_event1 + assert events[1] == mock_event2 + mock_watch.stop.assert_called_once() + + @pytest.mark.asyncio + @mock.patch("kubernetes_asyncio.watch.Watch") + @mock.patch(KUBE_API.format("list_namespaced_event")) + async def test_async_watch_pod_events_permission_error_fallback( + self, mock_list_namespaced_event, mock_watch_class, kube_config_loader + ): + """Test fallback to polling when watch permission is denied.""" + + # Simulate permission error on watch + async def async_generator_with_error(*_, **__): + raise async_client.exceptions.ApiException(status=403) + yield + + mock_watch = mock.Mock() + mock_watch_class.return_value = mock_watch + mock_watch.stream = mock.Mock(side_effect=async_generator_with_error) + + # Setup fallback polling + mock_event = mock.Mock() + mock_event.metadata.uid = "event-1" + mock_events = mock.Mock() + mock_events.items = [mock_event] + mock_list_namespaced_event.return_value = self.mock_await_result(mock_events) + + hook = AsyncKubernetesHook( + conn_id=None, + in_cluster=False, + config_file=None, + cluster_context=None, + ) + + events = [] + async for event in hook.watch_pod_events( + name=POD_NAME, namespace=NAMESPACE, resource_version="12345", timeout_seconds=30 + ): + events.append(event) + break + + assert len(events) == 1 + assert events[0] == mock_event + assert hook._event_polling_fallback is True + + @pytest.mark.asyncio + 
@mock.patch(KUBE_API.format("list_namespaced_event")) + @mock.patch("asyncio.sleep", new_callable=mock.AsyncMock) + async def test_async_watch_pod_events_polling_fallback( + self, mock_sleep, mock_list_namespaced_event, kube_config_loader + ): + """Test polling fallback method.""" + mock_event1 = mock.Mock() + mock_event1.metadata.uid = "event-1" + mock_event2 = mock.Mock() + mock_event2.metadata.uid = "event-2" + mock_events = mock.Mock() + mock_events.items = [mock_event1, mock_event2] + mock_list_namespaced_event.return_value = self.mock_await_result(mock_events) + + hook = AsyncKubernetesHook( + conn_id=None, + in_cluster=False, + config_file=None, + cluster_context=None, + ) + + events = [] + async for event in hook.watch_pod_events_polling_fallback( + name=POD_NAME, namespace=NAMESPACE, resource_version="12345", interval=10 + ): + events.append(event) + + assert len(events) == 2 + assert events[0] == mock_event1 + assert events[1] == mock_event2 + mock_list_namespaced_event.assert_called_once_with( + field_selector=f"involvedObject.name={POD_NAME}", + namespace=NAMESPACE, + resource_version="12345", + resource_version_match="NotOlderThan", + ) + mock_sleep.assert_called_once_with(10) + + @pytest.mark.asyncio + @mock.patch("kubernetes_asyncio.watch.Watch") + @mock.patch(KUBE_API.format("list_namespaced_event")) + async def test_async_watch_pod_events_uses_fallback_if_already_set( + self, mock_list_namespaced_event, mock_watch_class, kube_config_loader + ): + """Test that watch uses polling fallback if flag is already set.""" + hook = AsyncKubernetesHook( + conn_id=None, + in_cluster=False, + config_file=None, + cluster_context=None, + ) + + hook._event_polling_fallback = True + + mock_event = mock.Mock() + mock_event.metadata.uid = "event-1" + mock_events = mock.Mock() + mock_events.items = [mock_event] + mock_list_namespaced_event.return_value = self.mock_await_result(mock_events) + + events = [] + async for event in hook.watch_pod_events(name=POD_NAME, namespace=NAMESPACE, timeout_seconds=30): + events.append(event) + break + + # Watch API should not be called + mock_watch_class.assert_not_called() + # Polling should be used + assert len(events) == 1 + assert events[0] == mock_event + @pytest.mark.asyncio @mock.patch(KUBE_API.format("read_namespaced_pod")) async def test_get_pod(self, lib_method, kube_config_loader): diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_pod.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_pod.py index 406d7f9d02c01..f2ce3b00dc3eb 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_pod.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_pod.py @@ -358,16 +358,30 @@ def test_define_container_state_should_execute_successfully( @mock.patch(f"{TRIGGER_PATH}.define_container_state") @mock.patch(f"{TRIGGER_PATH}.hook") async def test_run_loop_read_events_during_start(self, mock_hook, mock_method, trigger): - event1 = mock.AsyncMock() + event1 = mock.Mock() + event1.metadata.uid = "event-uid-1" + event1.metadata.resource_version = "100" event1.message = "event 1" event1.involved_object.field_path = "object 1" - event2 = mock.AsyncMock() + event2 = mock.Mock() + event2.metadata.uid = "event-uid-2" + event2.metadata.resource_version = "101" event2.message = "event 2" event2.involved_object.field_path = "object 2" - events_list = mock.AsyncMock() - events_list.items = [event1, event2] - mock_hook.get_pod_events = mock.AsyncMock(return_value=events_list) + 
call_count = 0 + + async def async_event_generator(*_, **__): + nonlocal call_count + call_count += 1 + if call_count == 1: + # First call: return events + yield event1 + yield event2 + # Subsequent calls: return nothing and stop watching + trigger.pod_manager.stop_watching_events = True + + mock_hook.watch_pod_events = mock.Mock(side_effect=async_event_generator) pod_pending = mock.MagicMock() pod_pending.status.phase = PodPhase.PENDING diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/utils/test_pod_manager.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/utils/test_pod_manager.py index 1688470946be1..c3a4f45524f36 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/utils/test_pod_manager.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/utils/test_pod_manager.py @@ -36,6 +36,7 @@ PodLogsConsumer, PodManager, PodPhase, + log_pod_event, parse_log_line, ) from airflow.utils.timezone import utc @@ -58,6 +59,65 @@ def test_parse_log_line(): assert line == log_message +def test_log_pod_event(): + """Test logging a pod event.""" + mock_pod_manager = mock.Mock() + mock_event = mock.Mock() + mock_event.metadata.uid = "event-uid-1" + mock_event.message = "Test event message" + mock_event.involved_object.field_path = "Test field path" + + seen_events = set() + + log_pod_event(mock_pod_manager, mock_event, seen_events) + + assert "event-uid-1" in seen_events + mock_pod_manager.log.info.assert_called_once_with( + "The Pod has an Event: %s from %s", "Test event message", "Test field path" + ) + + +def test_log_pod_event_skips_duplicate(): + """Test that duplicate events are skipped.""" + mock_pod_manager = mock.Mock() + mock_event = mock.Mock() + mock_event.metadata.uid = "event-uid-1" + mock_event.message = "Test event message" + + seen_events = {"event-uid-1"} # Event already seen + + log_pod_event(mock_pod_manager, mock_event, seen_events) + + assert "event-uid-1" in seen_events + mock_pod_manager.log.info.assert_not_called() + + +def test_log_pod_event_multiple_events(): + """Test logging multiple different events.""" + mock_pod_manager = mock.Mock() + seen_events = set() + + # First event + mock_event1 = mock.Mock() + mock_event1.metadata.uid = "event-uid-1" + mock_event1.message = "First message" + mock_event1.involved_object.field_path = "Test field path 1" + + log_pod_event(mock_pod_manager, mock_event1, seen_events) + assert "event-uid-1" in seen_events + + # Second event + mock_event2 = mock.Mock() + mock_event2.metadata.uid = "event-uid-2" + mock_event2.message = "Second message" + mock_event2.involved_object.field_path = "Test field path 2" + + log_pod_event(mock_pod_manager, mock_event2, seen_events) + assert "event-uid-2" in seen_events + assert len(seen_events) == 2 + assert mock_pod_manager.log.info.call_count == 2 + + class TestPodManager: def setup_method(self): self.mock_kube_client = mock.Mock() @@ -183,7 +243,7 @@ async def test_watch_pod_events(self, mock_time_sleep): events.items.append(event) startup_check_interval = 10 - def mock_read_pod_events(pod): + def mock_read_pod_events(*_, **__): self.pod_manager.stop_watching_events = True return events @@ -210,6 +270,130 @@ def test_read_pod_events_successfully_returns_events(self): events = self.pod_manager.read_pod_events(mock.sentinel) assert mock.sentinel.events == events + def test_read_pod_events_with_resource_version(self): + """Test reading pod events with resource_version parameter.""" + mock_pod = mock.Mock() + mock_pod.metadata.namespace = "test-namespace" + 
mock_pod.metadata.name = "test-pod" + mock_events = mock.Mock() + self.mock_kube_client.list_namespaced_event.return_value = mock_events + + events = self.pod_manager.read_pod_events(mock_pod, resource_version="12345") + + assert events == mock_events + self.mock_kube_client.list_namespaced_event.assert_called_once_with( + namespace="test-namespace", + field_selector="involvedObject.name=test-pod", + resource_version="12345", + resource_version_match="NotOlderThan", + ) + + def test_read_pod_events_without_resource_version(self): + """Test reading pod events without resource_version parameter.""" + mock_pod = mock.Mock() + mock_pod.metadata.namespace = "test-namespace" + mock_pod.metadata.name = "test-pod" + mock_events = mock.Mock() + self.mock_kube_client.list_namespaced_event.return_value = mock_events + + events = self.pod_manager.read_pod_events(mock_pod) + + assert events == mock_events + self.mock_kube_client.list_namespaced_event.assert_called_once_with( + namespace="test-namespace", + field_selector="involvedObject.name=test-pod", + resource_version=None, + resource_version_match=None, + ) + + @pytest.mark.asyncio + @mock.patch("asyncio.sleep", new_callable=mock.AsyncMock) + async def test_watch_pod_events_tracks_resource_version(self, mock_sleep): + """Test that watch_pod_events tracks resource version.""" + mock_pod = mock.Mock() + mock_pod.metadata.namespace = "test-namespace" + mock_pod.metadata.name = "test-pod" + + mock_event_1 = mock.Mock() + mock_event_1.metadata.uid = "event-uid-1" + mock_event_1.metadata.resource_version = "100" + mock_event_1.message = "Event 1" + mock_event_1.involved_object.field_path = "spec" + + mock_events_1 = mock.Mock() + mock_events_1.items = [mock_event_1] + + mock_event_2 = mock.Mock() + mock_event_2.metadata.uid = "event-uid-2" + mock_event_2.metadata.resource_version = "101" + mock_event_2.message = "Event 2" + mock_event_2.involved_object.field_path = "spec" + + mock_events_2 = mock.Mock() + mock_events_2.items = [mock_event_2] + + self.mock_kube_client.list_namespaced_event.side_effect = [mock_events_1, mock_events_2] + self.pod_manager.stop_watching_events = False + + call_count = 0 + + async def side_effect_sleep(*_, **__): + nonlocal call_count + call_count += 1 + if call_count >= 2: + self.pod_manager.stop_watching_events = True + + mock_sleep.side_effect = side_effect_sleep + + await self.pod_manager.watch_pod_events(mock_pod, check_interval=1) + + # Check that resource_version was passed in second call + calls = self.mock_kube_client.list_namespaced_event.call_args_list + assert len(calls) == 2 + # First call should have no resource_version + assert calls[0][1]["resource_version"] is None + # Second call should use resource_version from first event + assert calls[1][1]["resource_version"] == "100" + + @pytest.mark.asyncio + @mock.patch("asyncio.sleep", new_callable=mock.AsyncMock) + async def test_watch_pod_events_deduplicates_events(self, mock_sleep): + """Test that watch_pod_events deduplicates events.""" + mock_pod = mock.Mock() + mock_pod.metadata.namespace = "test-namespace" + mock_pod.metadata.name = "test-pod" + + mock_event = mock.Mock() + mock_event.metadata.uid = "event-uid-1" + mock_event.metadata.resource_version = "100" + mock_event.message = "Duplicate event" + mock_event.involved_object.field_path = "spec" + + mock_events = mock.Mock() + mock_events.items = [mock_event] + + # Will return the same event on each invocation + self.mock_kube_client.list_namespaced_event.return_value = mock_events + 
self.pod_manager.stop_watching_events = False + + call_count = 0 + + async def side_effect_sleep(*_, **__): + nonlocal call_count + call_count += 1 + if call_count >= 2: + # Stop after 2 iterations -> same event is returned 2 times + self.pod_manager.stop_watching_events = True + + mock_sleep.side_effect = side_effect_sleep + + with mock.patch.object(self.pod_manager.log, "info") as mock_log_info: + await self.pod_manager.watch_pod_events(mock_pod, check_interval=1) + + # Event should only be logged once despite being returned twice + assert mock_log_info.call_count == 1 + mock_log_info.assert_called_with("The Pod has an Event: %s from %s", "Duplicate event", "spec") + def test_read_pod_events_retries_successfully(self): mock.sentinel.metadata = mock.MagicMock() self.mock_kube_client.list_namespaced_event.side_effect = [ @@ -223,10 +407,14 @@ def test_read_pod_events_retries_successfully(self): mock.call( namespace=mock.sentinel.metadata.namespace, field_selector=f"involvedObject.name={mock.sentinel.metadata.name}", + resource_version=None, + resource_version_match=None, ), mock.call( namespace=mock.sentinel.metadata.namespace, field_selector=f"involvedObject.name={mock.sentinel.metadata.name}", + resource_version=None, + resource_version_match=None, ), ] ) @@ -730,6 +918,178 @@ def setup_method(self): callbacks=[], ) + @pytest.mark.asyncio + async def test_read_pod_events_with_resource_version(self): + """Test async read_pod_events with resource_version parameter.""" + mock_pod = mock.Mock() + mock_pod.metadata.namespace = "test-namespace" + mock_pod.metadata.name = "test-pod" + mock_events = mock.Mock() + + self.mock_async_hook.get_pod_events.return_value = mock_events + + result = await self.async_pod_manager.read_pod_events(mock_pod, resource_version="12345") + + assert result == mock_events + self.mock_async_hook.get_pod_events.assert_called_once_with( + "test-pod", "test-namespace", resource_version="12345" + ) + + @pytest.mark.asyncio + async def test_read_pod_events_without_resource_version(self): + """Test async read_pod_events without resource_version parameter.""" + mock_pod = mock.Mock() + mock_pod.metadata.namespace = "test-namespace" + mock_pod.metadata.name = "test-pod" + mock_events = mock.Mock() + + self.mock_async_hook.get_pod_events.return_value = mock_events + + result = await self.async_pod_manager.read_pod_events(mock_pod) + + assert result == mock_events + self.mock_async_hook.get_pod_events.assert_called_once_with( + "test-pod", "test-namespace", resource_version=None + ) + + @pytest.mark.asyncio + async def test_watch_pod_events_uses_hook_watch(self): + """Test that watch_pod_events uses hook's watch_pod_events method.""" + mock_pod = mock.Mock() + mock_pod.metadata.namespace = "test-namespace" + mock_pod.metadata.name = "test-pod" + + mock_event1 = mock.Mock() + mock_event1.metadata.uid = "event-uid-1" + mock_event1.metadata.resource_version = "100" + mock_event1.message = "Event 1" + mock_event1.involved_object.field_path = "spec" + + mock_event2 = mock.Mock() + mock_event2.metadata.uid = "event-uid-2" + mock_event2.metadata.resource_version = "101" + mock_event2.message = "Event 2" + mock_event2.involved_object.field_path = "spec" + + async def async_event_generator(*_, **__): + yield mock_event1 + yield mock_event2 + self.async_pod_manager.stop_watching_events = True + + self.mock_async_hook.watch_pod_events = mock.Mock(side_effect=async_event_generator) + + with mock.patch.object(self.async_pod_manager.log, "info") as mock_log_info: + await 
self.async_pod_manager.watch_pod_events(mock_pod, startup_check_interval=30) + + # Both events should be logged + assert mock_log_info.call_count == 2 + calls = mock_log_info.call_args_list + assert calls[0][0] == ("The Pod has an Event: %s from %s", "Event 1", "spec") + assert calls[1][0] == ("The Pod has an Event: %s from %s", "Event 2", "spec") + + # Verify hook was called + self.mock_async_hook.watch_pod_events.assert_called() + + @pytest.mark.asyncio + async def test_watch_pod_events_tracks_resource_version(self): + """Test that watch_pod_events tracks and updates resource version.""" + mock_pod = mock.Mock() + mock_pod.metadata.namespace = "test-namespace" + mock_pod.metadata.name = "test-pod" + + # Create events for two iterations + mock_event1 = mock.Mock() + mock_event1.metadata.uid = "event-uid-1" + mock_event1.metadata.resource_version = "100" + mock_event1.message = "Event 1" + mock_event1.involved_object.field_path = "spec" + + mock_event2 = mock.Mock() + mock_event2.metadata.uid = "event-uid-2" + mock_event2.metadata.resource_version = "101" + mock_event2.message = "Event 2" + mock_event2.involved_object.field_path = "spec" + + call_count = 0 + + async def async_event_generator(*_, **__): + nonlocal call_count + call_count += 1 + if call_count == 1: + # First iteration + yield mock_event1 + else: + # Second iteration + yield mock_event2 + self.async_pod_manager.stop_watching_events = True + + self.mock_async_hook.watch_pod_events = mock.Mock(side_effect=async_event_generator) + self.async_pod_manager.stop_watching_events = False + + await self.async_pod_manager.watch_pod_events(mock_pod, startup_check_interval=30) + + # Verify hook was called twice with updated resource_version + assert self.mock_async_hook.watch_pod_events.call_count == 2 + calls = self.mock_async_hook.watch_pod_events.call_args_list + + # First call should have no resource_version + assert calls[0][1]["resource_version"] is None + # Second call should use resource_version from first event + assert calls[1][1]["resource_version"] == "100" + + @pytest.mark.asyncio + async def test_watch_pod_events_deduplicates_events(self): + """Test that watch_pod_events deduplicates events across iterations.""" + mock_pod = mock.Mock() + mock_pod.metadata.namespace = "test-namespace" + mock_pod.metadata.name = "test-pod" + + # Same event returned in two iterations + mock_event = mock.Mock() + mock_event.metadata.uid = "event-uid-1" + mock_event.metadata.resource_version = "100" + mock_event.message = "Duplicate event" + mock_event.involved_object.field_path = "spec" + + call_count = 0 + + async def async_event_generator(*_, **__): + nonlocal call_count + call_count += 1 + yield mock_event # Return same event + if call_count >= 2: + self.async_pod_manager.stop_watching_events = True + + self.mock_async_hook.watch_pod_events = mock.Mock(side_effect=async_event_generator) + self.async_pod_manager.stop_watching_events = False + + with mock.patch.object(self.async_pod_manager.log, "info") as mock_log_info: + await self.async_pod_manager.watch_pod_events(mock_pod, startup_check_interval=30) + + # Event should only be logged once despite being returned twice + assert mock_log_info.call_count == 1 + mock_log_info.assert_called_with("The Pod has an Event: %s from %s", "Duplicate event", "spec") + + @pytest.mark.asyncio + async def test_watch_pod_events_handles_none_event(self): + """Test that watch_pod_events handles None events gracefully.""" + mock_pod = mock.Mock() + mock_pod.metadata.namespace = "test-namespace" + 
mock_pod.metadata.name = "test-pod" + + async def async_event_generator(*_, **__): + yield None # None event should be skipped + self.async_pod_manager.stop_watching_events = True + + self.mock_async_hook.watch_pod_events = mock.Mock(side_effect=async_event_generator) + self.async_pod_manager.stop_watching_events = False + + with mock.patch.object(self.async_pod_manager.log, "info") as mock_log_info: + await self.async_pod_manager.watch_pod_events(mock_pod, startup_check_interval=30) + + # No events should be logged for None + mock_log_info.assert_not_called() + @pytest.mark.asyncio async def test_start_pod_raises_informative_error_on_scheduled_timeout(self): pod_response = mock.MagicMock() @@ -844,16 +1204,18 @@ async def test_watch_pod_events(self, mock_time_sleep, mock_log_info): events.items.append(event) startup_check_interval = 10 - def get_pod_events_side_effect(name, namespace): + async def watch_events_generator(*_, **__): + for event in events.items: + yield event self.async_pod_manager.stop_watching_events = True - return events - self.mock_async_hook.get_pod_events.side_effect = get_pod_events_side_effect + self.mock_async_hook.watch_pod_events = mock.Mock(side_effect=watch_events_generator) - await self.async_pod_manager.watch_pod_events(pod=mock_pod, check_interval=startup_check_interval) + await self.async_pod_manager.watch_pod_events( + pod=mock_pod, startup_check_interval=startup_check_interval + ) mock_log_info.assert_any_call("The Pod has an Event: %s from %s", "test event 1", "object event 1") mock_log_info.assert_any_call("The Pod has an Event: %s from %s", "test event 2", "object event 2") - mock_time_sleep.assert_called_once_with(startup_check_interval) @pytest.mark.asyncio @pytest.mark.parametrize(
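
Taken together, the hook changes expose an async generator that prefers the Kubernetes Watch API and falls back to polling after a 403. A hedged usage sketch follows, assuming a reachable cluster and default kube configuration; the connection arguments mirror the tests above and the pod coordinates are placeholders.

```python
# Usage sketch (not part of the diff): tailing pod events through the new
# AsyncKubernetesHook.watch_pod_events() async generator.
import asyncio

from airflow.providers.cncf.kubernetes.hooks.kubernetes import AsyncKubernetesHook


async def tail_pod_events(pod_name: str, namespace: str) -> None:
    hook = AsyncKubernetesHook(conn_id=None, in_cluster=False, config_file=None, cluster_context=None)
    resource_version = None
    while True:
        # Each stream call ends after timeout_seconds; the loop resumes from the
        # last seen resourceVersion so already-delivered events are not replayed.
        async for event in hook.watch_pod_events(
            name=pod_name,
            namespace=namespace,
            resource_version=resource_version,
            timeout_seconds=30,
        ):
            if event is None:
                continue
            print(f"The Pod has an Event: {event.message}")
            resource_version = event.metadata.resource_version


if __name__ == "__main__":
    asyncio.run(tail_pod_events("my-pod", "default"))
```

The polling paths in `PodManager.watch_pod_events` and `log_pod_event` rely on two pieces of bookkeeping: a set of seen event UIDs for deduplication and the last `resourceVersion` for incremental listing. A standalone sketch of that logic with stand-in event objects:

```python
# Self-contained sketch of the deduplication + resourceVersion bookkeeping;
# FakeEvent is a stand-in for kubernetes CoreV1Event.
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class FakeEvent:
    uid: str
    resource_version: str
    message: str


def process_batch(events, seen: set[str], resource_version: str | None) -> str | None:
    """Log unseen events and return the newest resourceVersion to poll from."""
    for event in events:
        if event.uid not in seen:
            seen.add(event.uid)
            print(f"The Pod has an Event: {event.message}")
        # Track the latest resourceVersion so the next list call can pass
        # resource_version_match="NotOlderThan" and skip already-read history.
        resource_version = event.resource_version
    return resource_version


seen: set[str] = set()
rv = process_batch([FakeEvent("a", "100", "Scheduled")], seen, None)
# A second poll returning the same event logs nothing new.
rv = process_batch([FakeEvent("a", "100", "Scheduled"), FakeEvent("b", "101", "Pulling")], seen, rv)
assert rv == "101" and len(seen) == 2
```
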