diff --git a/kubernetes-tests/tests/kubernetes_tests/test_kubernetes_pod_operator.py b/kubernetes-tests/tests/kubernetes_tests/test_kubernetes_pod_operator.py
index c334e858124e7..12a4ee5bb24fe 100644
--- a/kubernetes-tests/tests/kubernetes_tests/test_kubernetes_pod_operator.py
+++ b/kubernetes-tests/tests/kubernetes_tests/test_kubernetes_pod_operator.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations

+import asyncio
 import json
 import logging
 import os
@@ -183,6 +184,7 @@ def test_do_xcom_push_defaults_false(self, kubeconfig_path, mock_get_connection,
         )
         assert not k.do_xcom_push

+    @pytest.mark.asyncio
     def test_config_path_move(self, kubeconfig_path, mock_get_connection, tmp_path):
         new_config_path = tmp_path / "kube_config.cfg"
         shutil.copy(kubeconfig_path, new_config_path)
@@ -205,6 +207,7 @@ def test_config_path_move(self, kubeconfig_path, mock_get_connection, tmp_path):
         actual_pod = self.api_client.sanitize_for_serialization(k.pod)
         assert actual_pod == expected_pod

+    @pytest.mark.asyncio
     def test_working_pod(self, mock_get_connection):
         k = KubernetesPodOperator(
             namespace="default",
@@ -222,6 +225,7 @@
         assert self.expected_pod["spec"] == actual_pod["spec"]
         assert self.expected_pod["metadata"]["labels"] == actual_pod["metadata"]["labels"]

+    @pytest.mark.asyncio
     def test_skip_cleanup(self, mock_get_connection):
         k = KubernetesPodOperator(
             namespace="unknown",
@@ -237,6 +241,7 @@ def test_skip_cleanup(self, mock_get_connection):
         with pytest.raises(ApiException):
             k.execute(context)

+    @pytest.mark.asyncio
     def test_delete_operator_pod(self, mock_get_connection):
         k = KubernetesPodOperator(
             namespace="default",
@@ -255,6 +260,7 @@
         assert self.expected_pod["spec"] == actual_pod["spec"]
         assert self.expected_pod["metadata"]["labels"] == actual_pod["metadata"]["labels"]

+    @pytest.mark.asyncio
     def test_skip_on_specified_exit_code(self, mock_get_connection):
         k = KubernetesPodOperator(
             namespace="default",
@@ -271,6 +277,7 @@
         with pytest.raises(AirflowSkipException):
             k.execute(context)

+    @pytest.mark.asyncio
     def test_already_checked_on_success(self, mock_get_connection):
         """
         When ``on_finish_action="keep_pod"``, pod should have 'already_checked'
@@ -293,6 +300,7 @@
         actual_pod = self.api_client.sanitize_for_serialization(actual_pod)
         assert actual_pod["metadata"]["labels"]["already_checked"] == "True"

+    @pytest.mark.asyncio
     def test_already_checked_on_failure(self, mock_get_connection):
         """
         When ``on_finish_action="keep_pod"``, pod should have 'already_checked'
@@ -318,6 +326,7 @@
         assert status["state"]["terminated"]["reason"] == "Error"
         assert actual_pod["metadata"]["labels"]["already_checked"] == "True"

+    @pytest.mark.asyncio
     def test_pod_hostnetwork(self, mock_get_connection):
         k = KubernetesPodOperator(
             namespace="default",
@@ -337,6 +346,7 @@
         assert self.expected_pod["spec"] == actual_pod["spec"]
         assert self.expected_pod["metadata"]["labels"] == actual_pod["metadata"]["labels"]

+    @pytest.mark.asyncio
     def test_pod_dnspolicy(self, mock_get_connection):
         dns_policy = "ClusterFirstWithHostNet"
         k = KubernetesPodOperator(
@@ -359,6 +369,7 @@
         assert self.expected_pod["spec"] == actual_pod["spec"]
         assert self.expected_pod["metadata"]["labels"] == actual_pod["metadata"]["labels"]

+    @pytest.mark.asyncio
     def test_pod_schedulername(self, mock_get_connection):
         scheduler_name = "default-scheduler"
         k = KubernetesPodOperator(
@@ -378,6 +389,7 @@
         self.expected_pod["spec"]["schedulerName"] = scheduler_name
         assert self.expected_pod == actual_pod

+    @pytest.mark.asyncio
     def test_pod_node_selector(self, mock_get_connection):
         node_selector = {"beta.kubernetes.io/os": "linux"}
         k = KubernetesPodOperator(
@@ -397,6 +409,7 @@
         self.expected_pod["spec"]["nodeSelector"] = node_selector
         assert self.expected_pod == actual_pod

+    @pytest.mark.asyncio
     def test_pod_resources(self, mock_get_connection):
         resources = k8s.V1ResourceRequirements(
             requests={"memory": "64Mi", "cpu": "250m", "ephemeral-storage": "1Gi"},
@@ -422,6 +435,7 @@
         }
         assert self.expected_pod == actual_pod

+    @pytest.mark.asyncio
     @pytest.mark.parametrize(
         "val",
         [
@@ -498,6 +512,7 @@ def test_pod_affinity(self, val, mock_get_connection):
         self.expected_pod["spec"]["affinity"] = expected
         assert self.expected_pod == actual_pod

+    @pytest.mark.asyncio
     def test_port(self, mock_get_connection):
         port = k8s.V1ContainerPort(
             name="http",
@@ -521,6 +536,7 @@ def test_port(self, mock_get_connection):
         self.expected_pod["spec"]["containers"][0]["ports"] = [{"name": "http", "containerPort": 80}]
         assert self.expected_pod == actual_pod

+    @pytest.mark.asyncio
     def test_volume_mount(self, mock_get_connection):
         with mock.patch.object(PodManager, "log") as mock_logger:
             volume_mount = k8s.V1VolumeMount(
@@ -560,6 +576,7 @@ def test_volume_mount(self, mock_get_connection):
         ]
         assert self.expected_pod == actual_pod

+    @pytest.mark.asyncio
     @pytest.mark.parametrize("uid", [0, 1000])
     def test_run_as_user(self, uid, mock_get_connection):
         security_context = V1PodSecurityContext(run_as_user=uid)
@@ -585,6 +602,7 @@ def test_run_as_user(self, uid, mock_get_connection):
         )
         assert pod.to_dict()["spec"]["security_context"]["run_as_user"] == uid

+    @pytest.mark.asyncio
     @pytest.mark.parametrize("gid", [0, 1000])
     def test_fs_group(self, gid, mock_get_connection):
         security_context = V1PodSecurityContext(fs_group=gid)
@@ -610,6 +628,7 @@ def test_fs_group(self, gid, mock_get_connection):
         )
         assert pod.to_dict()["spec"]["security_context"]["fs_group"] == gid

+    @pytest.mark.asyncio
     def test_disable_privilege_escalation(self, mock_get_connection):
         container_security_context = V1SecurityContext(allow_privilege_escalation=False)

@@ -632,6 +651,7 @@
         }
         assert self.expected_pod == actual_pod

+    @pytest.mark.asyncio
     def test_faulty_image(self, mock_get_connection):
         bad_image_name = "foobar"
         k = KubernetesPodOperator(
@@ -670,6 +690,7 @@ def test_faulty_service_account(self, mock_get_connection):
         with pytest.raises(ApiException, match="error looking up service account default/foobar"):
             k.get_or_create_pod(pod, context)

+    @pytest.mark.asyncio
     def test_pod_failure(self, mock_get_connection):
         """
         Tests that the task fails when a pod reports a failure
@@ -692,6 +713,7 @@
         self.expected_pod["spec"]["containers"][0]["args"] = bad_internal_command
         assert self.expected_pod == actual_pod

+    @pytest.mark.asyncio
     def test_xcom_push(self, test_label, mock_get_connection):
         expected = {"test_label": test_label, "buzz": 2}
         args = [f"echo '{json.dumps(expected)}' > /airflow/xcom/return.json"]
@@ -740,6 +762,7 @@ def test_env_vars(self, mock_get_connection):
         ]
         assert self.expected_pod == actual_pod

+    @pytest.mark.asyncio
     def test_pod_template_file_system(self, mock_get_connection, basic_pod_template):
         """Note: this test requires that you have a namespace ``mem-example`` in your cluster."""
         k = KubernetesPodOperator(
@@ -755,6 +778,7 @@ def test_pod_template_file_system(self, mock_get_connection, basic_pod_template)
         assert result is not None
         assert result == {"hello": "world"}

+    @pytest.mark.asyncio
     @pytest.mark.parametrize(
         "env_vars",
         [
@@ -790,6 +814,7 @@ def test_pod_template_file_with_overrides_system(
         assert k.pod.spec.containers[0].env == [k8s.V1EnvVar(name="env_name", value="value")]
         assert result == {"hello": "world"}

+    @pytest.mark.asyncio
     def test_pod_template_file_with_full_pod_spec(self, test_label, mock_get_connection, basic_pod_template):
         pod_spec = k8s.V1Pod(
             metadata=k8s.V1ObjectMeta(
@@ -830,6 +855,7 @@ def test_pod_template_file_with_full_pod_spec(self, test_label, mock_get_connect
         assert k.pod.spec.containers[0].env == [k8s.V1EnvVar(name="env_name", value="value")]
         assert result == {"hello": "world"}

+    @pytest.mark.asyncio
     def test_full_pod_spec(self, test_label, mock_get_connection):
         pod_spec = k8s.V1Pod(
             metadata=k8s.V1ObjectMeta(
@@ -875,6 +901,7 @@ def test_full_pod_spec(self, test_label, mock_get_connection):
         assert k.pod.spec.containers[0].env == [k8s.V1EnvVar(name="env_name", value="value")]
         assert result == {"hello": "world"}

+    @pytest.mark.asyncio
     def test_init_container(self, mock_get_connection):
         # GIVEN
         volume_mounts = [
@@ -929,6 +956,7 @@ def test_init_container(self, mock_get_connection):
         ]
         assert self.expected_pod == actual_pod

+    @pytest.mark.asyncio
     @mock.patch(f"{POD_MANAGER_CLASS}.await_xcom_sidecar_container_start")
     @mock.patch(f"{POD_MANAGER_CLASS}.extract_xcom")
     @mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion")
@@ -1036,6 +1064,7 @@ def test_pod_template_file(
         del actual_pod["metadata"]["labels"]["airflow_version"]
         assert expected_dict == actual_pod

+    @pytest.mark.asyncio
     @mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion")
     @mock.patch(f"{POD_MANAGER_CLASS}.create_pod", new=MagicMock)
     @mock.patch(HOOK_CLASS)
@@ -1070,6 +1099,7 @@ def test_pod_priority_class_name(self, hook_mock, await_pod_completion_mock):
         self.expected_pod["spec"]["priorityClassName"] = priority_class_name
         assert self.expected_pod == actual_pod

+    @pytest.mark.asyncio
     def test_pod_name(self, mock_get_connection):
         pod_name_too_long = "a" * 221
         k = KubernetesPodOperator(
@@ -1089,6 +1119,7 @@ def test_pod_name(self, mock_get_connection):
         with pytest.raises(AirflowException):
             k.execute(context)

+    @pytest.mark.asyncio
     def test_on_kill(self, mock_get_connection):
         hook = KubernetesHook(conn_id=None, in_cluster=False)
         client = hook.core_v1_client
@@ -1129,6 +1160,7 @@ class ShortCircuitException(Exception):
         with pytest.raises(ApiException, match=r'pods \\"test.[a-z0-9]+\\" not found'):
             client.read_namespaced_pod(name=name, namespace=namespace)

+    @pytest.mark.asyncio
     def test_reattach_failing_pod_once(self, mock_get_connection):
         hook = KubernetesHook(conn_id=None, in_cluster=False)
         client = hook.core_v1_client
@@ -1189,6 +1221,11 @@ def get_op():
         # recreate op just to ensure we're not relying on any statefulness
         k = get_op()

+        # Before the next attempt we need to re-create the event loop if it is closed.
+        loop = asyncio.get_event_loop()
+        if loop.is_closed():
+            asyncio.set_event_loop(asyncio.new_event_loop())
+
         # `create_pod` should be called because though there's still a pod to be found,
         # it will be `already_checked`
         with mock.patch(f"{POD_MANAGER_CLASS}.create_pod") as create_mock:
@@ -1196,6 +1233,7 @@ def get_op():
             k.execute(context)
             create_mock.assert_called_once()

+    @pytest.mark.asyncio
     def test_changing_base_container_name_with_get_logs(self, mock_get_connection):
         k = KubernetesPodOperator(
             namespace="default",
@@ -1221,6 +1259,7 @@ def test_changing_base_container_name_with_get_logs(self, mock_get_connection):
         self.expected_pod["spec"]["containers"][0]["name"] = "apple-sauce"
         assert self.expected_pod["spec"] == actual_pod["spec"]

+    @pytest.mark.asyncio
     def test_changing_base_container_name_no_logs(self, mock_get_connection):
         """
         This test checks BOTH a modified base container name AND the get_logs=False flow,
@@ -1251,6 +1290,7 @@ def test_changing_base_container_name_no_logs(self, mock_get_connection):
         self.expected_pod["spec"]["containers"][0]["name"] = "apple-sauce"
         assert self.expected_pod["spec"] == actual_pod["spec"]

+    @pytest.mark.asyncio
     def test_changing_base_container_name_no_logs_long(self, mock_get_connection):
         """
         Similar to test_changing_base_container_name_no_logs, but ensures that
@@ -1282,6 +1322,7 @@ def test_changing_base_container_name_no_logs_long(self, mock_get_connection):
         self.expected_pod["spec"]["containers"][0]["args"] = ["sleep 3"]
         assert self.expected_pod["spec"] == actual_pod["spec"]

+    @pytest.mark.asyncio
     def test_changing_base_container_name_failure(self, mock_get_connection):
         k = KubernetesPodOperator(
             namespace="default",
@@ -1328,6 +1369,7 @@ class MyK8SPodOperator(KubernetesPodOperator):
         )
         assert MyK8SPodOperator(task_id=str(uuid4())).base_container_name == "tomato-sauce"

+    @pytest.mark.asyncio
     def test_init_container_logs(self, mock_get_connection):
         marker_from_init_container = f"{uuid4()}"
         marker_from_main_container = f"{uuid4()}"
@@ -1359,6 +1401,7 @@ def test_init_container_logs(self, mock_get_connection):
         assert marker_from_init_container in calls_args
         assert marker_from_main_container in calls_args

+    @pytest.mark.asyncio
     def test_init_container_logs_filtered(self, mock_get_connection):
         marker_from_init_container_to_log_1 = f"{uuid4()}"
         marker_from_init_container_to_log_2 = f"{uuid4()}"
@@ -1456,6 +1499,7 @@ def __getattr__(self, name):


 class TestKubernetesPodOperator(BaseK8STest):
+    @pytest.mark.asyncio
     @pytest.mark.parametrize(
         "active_deadline_seconds,should_fail",
         [(3, True), (60, False)],
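Note on the event-loop bookkeeping added to test_reattach_failing_pod_once: with the pod.py change below, the operator's await_pod_start closes the loop it ran on, so a second execute() in the same process would otherwise find a closed loop. A minimal sketch of the same guard, assuming the test runs on the main thread where asyncio.get_event_loop() returns the previously set loop:

    import asyncio

    def ensure_open_event_loop() -> asyncio.AbstractEventLoop:
        # Hypothetical helper: replace the thread's event loop if a previous
        # run_until_complete()/close() cycle has already closed it.
        loop = asyncio.get_event_loop()
        if loop.is_closed():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        return loop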
diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py
index c315042cc7488..c24dcd716b0c9 100644
--- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py
+++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py
@@ -18,6 +18,7 @@

 from __future__ import annotations

+import asyncio
 import datetime
 import json
 import logging
@@ -602,12 +603,20 @@ def get_or_create_pod(self, pod_request_obj: k8s.V1Pod, context: Context) -> k8s

     def await_pod_start(self, pod: k8s.V1Pod) -> None:
         try:
-            self.pod_manager.await_pod_start(
-                pod=pod,
-                schedule_timeout=self.schedule_timeout_seconds,
-                startup_timeout=self.startup_timeout_seconds,
-                check_interval=self.startup_check_interval_seconds,
+            loop = asyncio.get_event_loop()
+            events_task = asyncio.ensure_future(
+                self.pod_manager.watch_pod_events(pod, self.startup_check_interval_seconds)
+            )
+            loop.run_until_complete(
+                self.pod_manager.await_pod_start(
+                    pod=pod,
+                    schedule_timeout=self.schedule_timeout_seconds,
+                    startup_timeout=self.startup_timeout_seconds,
+                    check_interval=self.startup_check_interval_seconds,
+                )
             )
+            loop.run_until_complete(events_task)
+            loop.close()
         except PodLaunchFailedException:
             if self.log_events_on_failure:
                 self._read_pod_events(pod, reraise=False)
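The rewritten wrapper multiplexes two coroutines on one loop: the event watcher is scheduled as a background task with asyncio.ensure_future, the startup wait is driven with run_until_complete, and the watcher task is then awaited so it can finish its last cycle before the loop is closed. A self-contained sketch of that shape, with hypothetical stream_events/poll_status standing in for watch_pod_events/await_pod_start:

    import asyncio

    async def stream_events(stop: asyncio.Event) -> None:
        # Background watcher: report new events until asked to stop.
        while not stop.is_set():
            print("fetching and logging new pod events...")
            await asyncio.sleep(1)

    async def poll_status(stop: asyncio.Event) -> None:
        # Foreground wait: poll until "running", then stop the watcher.
        await asyncio.sleep(3)  # stand-in for the phase-polling loop
        stop.set()

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    stop = asyncio.Event()
    watcher = loop.create_task(stream_events(stop))
    loop.run_until_complete(poll_status(stop))
    loop.run_until_complete(watcher)  # let the watcher flush its final cycle
    loop.close()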
diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py
index 66cd4eefbe947..5942b68afa89e 100644
--- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py
+++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py
@@ -18,6 +18,7 @@

 from __future__ import annotations

+import asyncio
 import enum
 import json
 import math
@@ -49,6 +50,7 @@
     from kubernetes.client.models.v1_container_state import V1ContainerState
     from kubernetes.client.models.v1_container_state_waiting import V1ContainerStateWaiting
     from kubernetes.client.models.v1_container_status import V1ContainerStatus
+    from kubernetes.client.models.v1_object_reference import V1ObjectReference
     from kubernetes.client.models.v1_pod import V1Pod
     from kubernetes.client.models.v1_pod_condition import V1PodCondition
     from urllib3.response import HTTPResponse
@@ -377,7 +379,21 @@ def create_pod(self, pod: V1Pod) -> V1Pod:
         """Launch the pod asynchronously."""
         return self.run_pod_async(pod)

-    def await_pod_start(
+    async def watch_pod_events(self, pod: V1Pod, check_interval: int = 1) -> None:
+        """Read pod events and write them to the log."""
+        self.keep_watching_for_events = True
+        num_events = 0
+        while self.keep_watching_for_events:
+            events = self.read_pod_events(pod)
+            for new_event in events.items[num_events:]:
+                involved_object: V1ObjectReference = new_event.involved_object
+                self.log.info(
+                    "The Pod has an Event: %s from %s", new_event.message, involved_object.field_path
+                )
+            num_events = len(events.items)
+            await asyncio.sleep(check_interval)
+
+    async def await_pod_start(
         self, pod: V1Pod, schedule_timeout: int = 120, startup_timeout: int = 120, check_interval: int = 1
     ) -> None:
         """
@@ -439,7 +455,7 @@ def await_pod_start(
                     f"\n{container_waiting.message}"
                 )

-            time.sleep(check_interval)
+            await asyncio.sleep(check_interval)

     def fetch_container_logs(
         self,
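watch_pod_events re-reads the full event list on every cycle and logs only the tail past num_events, so each event is reported once. The loop runs until keep_watching_for_events is cleared; that flip is not visible in these hunks, but the unit test at the bottom of this diff asserts that await_pod_start leaves the flag False. The incremental-tail idea in isolation, with a plain callable standing in for read_pod_events:

    import asyncio

    async def tail_new_items(fetch, stop: asyncio.Event, interval: float = 1.0) -> None:
        # Sketch: poll a growing list and act only on items not yet seen.
        seen = 0
        while not stop.is_set():
            items = fetch()            # full snapshot each cycle
            for item in items[seen:]:  # only the tail is new
                print("new event:", item)
            seen = len(items)
            await asyncio.sleep(interval)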
diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes.py
index b5c4448e19301..8743444f089a8 100644
--- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes.py
+++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes.py
@@ -19,6 +19,8 @@

 import base64
 import pickle

+import pytest
+
 from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_0_PLUS

 if AIRFLOW_V_3_0_PLUS:
@@ -71,6 +73,7 @@ def f():
         decoded_input = pickle.loads(base64.b64decode(containers[0].env[1].value))
         assert decoded_input == {"args": [], "kwargs": {}}

+    @pytest.mark.asyncio
     def test_kubernetes_with_input_output(self):
         """Verify @task.kubernetes will run XCom container if do_xcom_push is set."""
         with self.dag:
diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes_cmd.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes_cmd.py
index 9236ecf86e56a..a18c11abc7f7e 100644
--- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes_cmd.py
+++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes_cmd.py
@@ -33,6 +33,7 @@


 class TestKubernetesCmdDecorator(TestKubernetesDecoratorsBase):
+    @pytest.mark.asyncio
     @pytest.mark.parametrize(
         "args_only",
         [True, False],
@@ -77,6 +78,7 @@ def hello():
         assert containers[0].command == expected_command
         assert containers[0].args == expected_args

+    @pytest.mark.asyncio
     @pytest.mark.parametrize(
         "func_return, exception",
         [
@@ -117,6 +119,7 @@ def hello():
         with context_manager:
             self.execute_task(k8s_task)

+    @pytest.mark.asyncio
     def test_kubernetes_cmd_with_input_output(self):
         """Verify @task.kubernetes_cmd will run XCom container if do_xcom_push is set."""
         with self.dag:
@@ -167,6 +170,7 @@ def f(arg1: str, arg2: str, kwarg1: str | None = None, kwarg2: str | None = None
         assert containers[1].image == XCOM_IMAGE
         assert containers[1].volume_mounts[0].mount_path == "/airflow/xcom"

+    @pytest.mark.asyncio
     @pytest.mark.parametrize(
         "cmds",
         [None, ["ignored_cmd"], "ignored_cmd"],
@@ -229,6 +233,7 @@ def hello():
         assert containers[0].command == expected_command
         assert containers[0].args == expected_args

+    @pytest.mark.asyncio
     @pytest.mark.parametrize(
         argnames=["command", "op_arg", "expected_command"],
         argvalues=[
@@ -278,6 +283,7 @@ def hello(add_to_command: str):
         assert containers[0].command == expected_command
         assert containers[0].args == []

+    @pytest.mark.asyncio
     def test_basic_context_works(self):
         """Test that decorator works with context as kwargs unpcacked in function arguments"""
         with self.dag:
@@ -308,6 +314,7 @@ def hello(**context):
         assert containers[0].command == ["echo", "hello", DAG_ID]
         assert containers[0].args == []

+    @pytest.mark.asyncio
     def test_named_context_variables(self):
         """Test that decorator works with specific context variable as kwargs in function arguments"""
         with self.dag:
@@ -338,6 +345,7 @@ def hello(ti=None, dag_run=None):
         assert containers[0].command == ["echo", "hello", DAG_ID]
         assert containers[0].args == []

+    @pytest.mark.asyncio
     def test_rendering_kubernetes_cmd_decorator_params(self):
         """Test that templating works in decorator parameters"""
         with self.dag:
diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes_commons.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes_commons.py
index 16db4b120fb27..015486157e0ae 100644
--- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes_commons.py
+++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes_commons.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations

+import asyncio
 from typing import Callable
 from unittest import mock

@@ -90,7 +91,18 @@ def setup(self, dag_maker):
         self.dag = dag

         self.mock_create_pod = mock.patch(f"{POD_MANAGER_CLASS}.create_pod").start()
-        self.mock_await_pod_start = mock.patch(f"{POD_MANAGER_CLASS}.await_pod_start").start()
+        self.mock_watch_pod_events_patch = mock.patch(
+            f"{POD_MANAGER_CLASS}.watch_pod_events", new_callable=mock.AsyncMock
+        )
+        self.mock_watch_pod_events = self.mock_watch_pod_events_patch.start()
+        self.mock_watch_pod_events.return_value = asyncio.Future()
+        self.mock_watch_pod_events.return_value.set_result(None)
+        self.mock_await_pod_start_patch = mock.patch(
+            f"{POD_MANAGER_CLASS}.await_pod_start", new_callable=mock.AsyncMock
+        )
+        self.mock_await_pod_start = self.mock_await_pod_start_patch.start()
+        self.mock_await_pod_start.return_value = asyncio.Future()
+        self.mock_await_pod_start.return_value.set_result(None)
         self.mock_await_xcom_sidecar_container_start = mock.patch(
             f"{POD_MANAGER_CLASS}.await_xcom_sidecar_container_start"
         ).start()
@@ -108,6 +120,10 @@ def setup(self, dag_maker):
         self.mock_fetch_logs = mock.patch(f"{POD_MANAGER_CLASS}.fetch_requested_container_logs").start()
         self.mock_fetch_logs.return_value = "logs"

+        yield
+
+        mock.patch.stopall()
+
     def teardown_method(self):
         clear_db_runs()
         clear_db_dags()
@@ -185,6 +201,7 @@ def test_decorators_with_marked_as_teardown(self, task_decorator, decorator_name
         teardown_task = self.dag.task_group.children[TASK_FUNCTION_NAME_ID]
         assert teardown_task.is_teardown

+    @pytest.mark.asyncio
     @pytest.mark.parametrize(
         "name",
         ["no_name_in_args", None, "test_task_name"],
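About the AsyncMock plumbing repeated in these fixtures: calling an AsyncMock returns a coroutine, and awaiting that coroutine yields return_value. Strictly, an AsyncMock is therefore already awaitable; presetting return_value to an already-resolved asyncio.Future looks defensive, covering call sites that await or inspect the result object itself. A quick demonstration of the default behaviour:

    import asyncio
    from unittest import mock

    async def main() -> None:
        mgr = mock.AsyncMock()        # calling it returns a coroutine
        result = await mgr(pod=None)  # awaiting yields mgr.return_value
        print(type(result))           # a MagicMock unless return_value is preset

    asyncio.run(main())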
diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_pod.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_pod.py
index a0334293a3d9b..b6a014b647098 100644
--- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_pod.py
+++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_pod.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations

+import asyncio
 import datetime
 import re
 from contextlib import contextmanager, nullcontext
@@ -145,11 +146,17 @@ class TestKubernetesPodOperator:
     @pytest.fixture(autouse=True)
     def setup_tests(self, dag_maker):
         self.create_pod_patch = patch(f"{POD_MANAGER_CLASS}.create_pod")
-        self.await_pod_patch = patch(f"{POD_MANAGER_CLASS}.await_pod_start")
+        self.watch_pod_events = patch(f"{POD_MANAGER_CLASS}.watch_pod_events", new_callable=mock.AsyncMock)
+        self.await_pod_patch = patch(f"{POD_MANAGER_CLASS}.await_pod_start", new_callable=mock.AsyncMock)
         self.await_pod_completion_patch = patch(f"{POD_MANAGER_CLASS}.await_pod_completion")
         self._default_client_patch = patch(f"{HOOK_CLASS}._get_default_client")
+        self.watch_pod_events_mock = self.watch_pod_events.start()
+        self.watch_pod_events_mock.return_value = asyncio.Future()
+        self.watch_pod_events_mock.return_value.set_result(None)
         self.create_mock = self.create_pod_patch.start()
         self.await_start_mock = self.await_pod_patch.start()
+        self.await_start_mock.return_value = asyncio.Future()
+        self.await_start_mock.return_value.set_result(None)
         self.await_pod_mock = self.await_pod_completion_patch.start()
         self._default_client_mock = self._default_client_patch.start()
         self.dag_maker = dag_maker
@@ -372,6 +379,7 @@ def test_envs_from_secrets(self):
             k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(name=secret_ref))
         ]

+    @pytest.mark.asyncio
     @pytest.mark.parametrize(("in_cluster",), ([True], [False]))
     @patch(HOOK_CLASS)
     def test_labels(self, hook_mock, in_cluster):
@@ -427,6 +435,7 @@ def test_find_custom_pod_labels(self):
         assert "foo=bar" in label_selector
         assert "hello=airflow" in label_selector

+    @pytest.mark.asyncio
     @patch(HOOK_CLASS, new=MagicMock)
     def test_find_pod_labels(self):
         k = KubernetesPodOperator(
@@ -765,6 +774,7 @@ def test_termination_message_policy_default_value_correctly_set(self):
         pod = k.build_pod_request_obj(create_context(k))
         assert pod.spec.containers[0].termination_message_policy == "File"

+    @pytest.mark.asyncio
     @pytest.mark.parametrize(
         "task_kwargs, base_container_fail, expect_to_delete_pod",
         [
@@ -877,6 +887,7 @@ def test_pod_delete_after_await_container_error(
         else:
             delete_pod_mock.assert_not_called()

+    @pytest.mark.asyncio
     @pytest.mark.parametrize("should_fail", [True, False])
     @patch(f"{POD_MANAGER_CLASS}.delete_pod")
     @patch(f"{POD_MANAGER_CLASS}.await_pod_completion")
@@ -1222,6 +1233,7 @@ def test_no_handle_failure_on_success(self, fetch_container_mock):
         # assert does not raise
         self.run_pod(k)

+    @pytest.mark.asyncio
     @pytest.mark.parametrize("randomize", [True, False])
     @patch(f"{POD_MANAGER_CLASS}.await_container_completion", new=MagicMock)
     @patch(f"{POD_MANAGER_CLASS}.fetch_requested_container_logs")
@@ -1356,6 +1368,7 @@ def test_automount_service_account_token(self):
         assert isinstance(pod.spec.automount_service_account_token, bool)
         assert sanitized_pod["spec"]["automountServiceAccountToken"] == automount_service_account_token

+    @pytest.mark.asyncio
     @pytest.mark.parametrize("do_xcom_push", [True, False])
     @patch(f"{POD_MANAGER_CLASS}.extract_xcom")
     @patch(f"{POD_MANAGER_CLASS}.await_xcom_sidecar_container_start")
@@ -1386,6 +1399,7 @@ def test_push_xcom_pod_info(
         assert pod_name == pod.metadata.name
         assert pod_namespace == pod.metadata.namespace

+    @pytest.mark.asyncio
     @patch(HOOK_CLASS, new=MagicMock)
     def test_previous_pods_ignored_for_reattached(self):
         """
@@ -1420,6 +1434,7 @@ def test_mark_checked_unexpected_exception(
         mock_patch_already_checked.assert_called_once()
         mock_delete_pod.assert_not_called()

+    @pytest.mark.asyncio
     @pytest.mark.parametrize("do_xcom_push", [True, False])
     @patch(f"{POD_MANAGER_CLASS}.extract_xcom")
     @patch(f"{POD_MANAGER_CLASS}.await_xcom_sidecar_container_start")
@@ -1437,6 +1452,7 @@ def test_wait_for_xcom_sidecar_iff_push_xcom(self, mock_await, mock_extract_xcom
         else:
             mock_await.assert_not_called()

+    @pytest.mark.asyncio
     @pytest.mark.parametrize(
         "task_kwargs, should_fail, should_be_deleted",
         [
@@ -1533,6 +1549,7 @@ def test_task_id_as_name_dag_id_is_ignored(self):
         pod = k.build_pod_request_obj({})
         assert re.match(r"a-very-reasonable-task-name-[a-z0-9-]+", pod.metadata.name) is not None

+    @pytest.mark.asyncio
     @pytest.mark.parametrize(
         "kwargs, actual_exit_code, expected_exc",
         [
@@ -1577,6 +1594,7 @@ def test_task_skip_when_pod_exit_with_certain_code(
         with pytest.raises(expected_exc):
             self.run_pod(k)

+    @pytest.mark.asyncio
     @patch(f"{POD_MANAGER_CLASS}.extract_xcom")
     @patch(f"{POD_MANAGER_CLASS}.await_xcom_sidecar_container_start")
     @patch(f"{POD_MANAGER_CLASS}.await_container_completion")
@@ -1706,6 +1724,7 @@ def test_execute_sync_callbacks(self, find_pod_mock):
             "context": context,
         }

+    @pytest.mark.asyncio
     @patch(HOOK_CLASS, new=MagicMock)
     @patch(KUB_OP_PATH.format("find_pod"))
     def test_execute_sync_multiple_callbacks(self, find_pod_mock):
@@ -1851,6 +1870,7 @@ def test_execute_async_callbacks(self):
             "context": context,
         }

+    @pytest.mark.asyncio
     @pytest.mark.parametrize("get_logs", [True, False])
     @patch(f"{POD_MANAGER_CLASS}.fetch_requested_container_logs")
     @patch(f"{POD_MANAGER_CLASS}.await_container_completion")
@@ -1886,6 +1906,7 @@ def test_await_container_completion_refreshes_properties_on_exception(
         assert hook != k.hook
         assert pod_manager != k.pod_manager

+    @pytest.mark.asyncio
     @patch(f"{POD_MANAGER_CLASS}.await_container_completion")
     @patch(f"{POD_MANAGER_CLASS}.read_pod")
     def test_await_container_completion_raises_unauthorized_if_credentials_still_invalid_after_refresh(
@@ -1907,6 +1928,7 @@ def test_await_container_completion_raises_unauthorized_if_credentials_still_inv
         assert hook != k.hook
         assert pod_manager != k.pod_manager

+    @pytest.mark.asyncio
     @pytest.mark.parametrize(
         "side_effect, exception_type, expect_exc",
         [
diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_spark_kubernetes.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_spark_kubernetes.py
index b7172f873461a..e0edaf89a6b8a 100644
--- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_spark_kubernetes.py
+++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_spark_kubernetes.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations

+import asyncio
 import copy
 import json
 from datetime import date
@@ -39,6 +40,27 @@
 from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS

+POD_MANAGER_CLASS = "airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager"
+
+
+@pytest.fixture(autouse=True, scope="module")
+def patch_pod_manager_methods():
+    # Patch watch_pod_events
+    patch_watch_pod_events = mock.patch(f"{POD_MANAGER_CLASS}.watch_pod_events", new_callable=mock.AsyncMock)
+    mock_watch_pod_events = patch_watch_pod_events.start()
+    mock_watch_pod_events.return_value = asyncio.Future()
+    mock_watch_pod_events.return_value.set_result(None)
+
+    # Patch await_pod_start
+    patch_await_pod_start = mock.patch(f"{POD_MANAGER_CLASS}.await_pod_start", new_callable=mock.AsyncMock)
+    mock_await_pod_start = patch_await_pod_start.start()
+    mock_await_pod_start.return_value = asyncio.Future()
+    mock_await_pod_start.return_value.set_result(None)
+
+    yield
+
+    mock.patch.stopall()
+

 @patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.KubernetesHook")
 def test_spark_kubernetes_operator(mock_kubernetes_hook, data_file):
@@ -203,7 +225,6 @@ def create_context(task):
 @pytest.mark.db_test
 @patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.fetch_requested_container_logs")
 @patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion")
-@patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_start")
 @patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.create_pod")
 @patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.SparkKubernetesOperator.client")
 @patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.cleanup")
@@ -254,6 +275,7 @@ def call_commons(self):
             "version": "v1beta2",
         }

+    @pytest.mark.asyncio
     @pytest.mark.parametrize(
         "task_name, application_file_path",
         [
@@ -269,7 +291,6 @@ def test_create_application(
         self,
         mock_cleanup,
         mock_get_kube_client,
         mock_create_pod,
-        mock_await_pod_start,
         mock_await_pod_completion,
         mock_fetch_requested_container_logs,
         data_file,
@@ -294,6 +315,7 @@ def test_create_application(
             **self.call_commons,
         )

+    @pytest.mark.asyncio
     @pytest.mark.parametrize(
         "task_name, application_file_path",
         [
@@ -309,7 +331,6 @@ def test_create_application_and_use_name_from_operator_args(
         mock_cleanup,
         mock_get_kube_client,
         mock_create_pod,
-        mock_await_pod_start,
         mock_await_pod_completion,
         mock_fetch_requested_container_logs,
         data_file,
@@ -339,6 +360,7 @@ def test_create_application_and_use_name_from_operator_args(
             **self.call_commons,
         )

+    @pytest.mark.asyncio
     @pytest.mark.parametrize(
         "task_name, application_file_path",
         [
@@ -354,7 +376,6 @@ def test_create_application_and_use_name_task_id(
         mock_cleanup,
         mock_get_kube_client,
         mock_create_pod,
-        mock_await_pod_start,
         mock_await_pod_completion,
         mock_fetch_requested_container_logs,
         data_file,
@@ -381,6 +402,7 @@ def test_create_application_and_use_name_task_id(
             **self.call_commons,
         )

+    @pytest.mark.asyncio
     @pytest.mark.parametrize("random_name_suffix", [True, False])
     def test_new_template_from_yaml(
         self,
@@ -389,7 +411,6 @@ def test_new_template_from_yaml(
         mock_cleanup,
         mock_get_kube_client,
         mock_create_pod,
-        mock_await_pod_start,
         mock_await_pod_completion,
         mock_fetch_requested_container_logs,
         data_file,
@@ -415,6 +436,7 @@ def test_new_template_from_yaml(
             **self.call_commons,
         )

+    @pytest.mark.asyncio
     @pytest.mark.parametrize("random_name_suffix", [True, False])
     def test_template_spec(
         self,
@@ -423,7 +445,6 @@ def test_template_spec(
         mock_cleanup,
         mock_get_kube_client,
         mock_create_pod,
-        mock_await_pod_start,
         mock_await_pod_completion,
         mock_fetch_requested_container_logs,
         data_file,
@@ -454,7 +475,6 @@
 @pytest.mark.db_test
 @patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.fetch_requested_container_logs")
 @patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion")
-@patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_start")
 @patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.create_pod")
 @patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.SparkKubernetesOperator.client")
 @patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.SparkKubernetesOperator.create_job_name")
@@ -488,6 +508,7 @@ def execute_operator(self, task_name, mock_create_job_name, job_spec):
         op.execute(context)
         return op

+    @pytest.mark.asyncio
     def test_env(
         self,
         mock_create_namespaced_crd,
@@ -496,7 +517,6 @@ def test_env(
         mock_create_job_name,
         mock_get_kube_client,
         mock_create_pod,
-        mock_await_pod_start,
         mock_await_pod_completion,
         mock_fetch_requested_container_logs,
         data_file,
@@ -532,6 +552,7 @@ def test_env(
         assert op.launcher.body["spec"]["driver"]["envFrom"] == env_from
         assert op.launcher.body["spec"]["executor"]["envFrom"] == env_from

+    @pytest.mark.asyncio
     def test_volume(
         self,
         mock_create_namespaced_crd,
@@ -540,7 +561,6 @@ def test_volume(
         mock_create_job_name,
         mock_get_kube_client,
         mock_create_pod,
-        mock_await_pod_start,
         mock_await_pod_completion,
         mock_fetch_requested_container_logs,
         data_file,
@@ -578,6 +598,7 @@ def test_volume(
         assert op.launcher.body["spec"]["driver"]["volumeMounts"] == volume_mounts
         assert op.launcher.body["spec"]["executor"]["volumeMounts"] == volume_mounts

+    @pytest.mark.asyncio
     def test_pull_secret(
         self,
         mock_create_namespaced_crd,
@@ -586,7 +607,6 @@ def test_pull_secret(
         mock_create_job_name,
         mock_get_kube_client,
         mock_create_pod,
-        mock_await_pod_start,
         mock_await_pod_completion,
         mock_fetch_requested_container_logs,
         data_file,
@@ -599,6 +619,7 @@ def test_pull_secret(
         exp_secrets = [k8s.V1LocalObjectReference(name=secret) for secret in ["secret1", "secret2"]]
         assert op.launcher.body["spec"]["imagePullSecrets"] == exp_secrets

+    @pytest.mark.asyncio
     def test_affinity(
         self,
         mock_create_namespaced_crd,
@@ -607,7 +628,6 @@ def test_affinity(
         mock_create_job_name,
         mock_get_kube_client,
         mock_create_pod,
-        mock_await_pod_start,
         mock_await_pod_completion,
         mock_fetch_requested_container_logs,
         data_file,
@@ -653,6 +673,7 @@ def test_affinity(
         assert op.launcher.body["spec"]["driver"]["affinity"] == affinity
         assert op.launcher.body["spec"]["executor"]["affinity"] == affinity

+    @pytest.mark.asyncio
     def test_toleration(
         self,
         mock_create_namespaced_crd,
@@ -661,7 +682,6 @@ def test_toleration(
         mock_create_job_name,
         mock_get_kube_client,
         mock_create_pod,
-        mock_await_pod_start,
         mock_await_pod_completion,
         mock_fetch_requested_container_logs,
         data_file,
@@ -680,6 +700,7 @@ def test_toleration(
         assert op.launcher.body["spec"]["driver"]["tolerations"] == [toleration]
         assert op.launcher.body["spec"]["executor"]["tolerations"] == [toleration]

+    @pytest.mark.asyncio
     def test_get_logs_from_driver(
         self,
         mock_create_namespaced_crd,
@@ -688,7 +709,6 @@ def test_get_logs_from_driver(
         mock_create_job_name,
         mock_get_kube_client,
         mock_create_pod,
-        mock_await_pod_start,
         mock_await_pod_completion,
         mock_fetch_requested_container_logs,
         data_file,
@@ -703,6 +723,7 @@ def test_get_logs_from_driver(
             follow_logs=True,
         )

+    @pytest.mark.asyncio
     def test_find_custom_pod_labels(
         self,
         mock_create_namespaced_crd,
@@ -711,7 +732,6 @@ def test_find_custom_pod_labels(
         mock_create_job_name,
         mock_get_kube_client,
         mock_create_pod,
-        mock_await_pod_start,
         mock_await_pod_completion,
         mock_fetch_requested_container_logs,
         data_file,
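The module-scoped autouse fixture at the top of this file patches PodManager once for the whole test module and undoes everything with mock.patch.stopall() at teardown; note that stopall() stops every patch started via .start(), not just the two created here. The lifecycle in miniature (time.time is a hypothetical patch target):

    import pytest
    from unittest import mock

    @pytest.fixture(autouse=True, scope="module")
    def frozen_clock():
        mock.patch("time.time", return_value=0.0).start()  # hypothetical target
        yield                  # every test in this module sees the patch
        mock.patch.stopall()   # stops this and any other .start()-ed patches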
diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/utils/test_pod_manager.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/utils/test_pod_manager.py
index 0ffbbd709bf75..e873e8bbb5bb2 100644
--- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/utils/test_pod_manager.py
+++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/utils/test_pod_manager.py
@@ -22,7 +22,7 @@
 from types import SimpleNamespace
 from typing import TYPE_CHECKING, cast
 from unittest import mock
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, PropertyMock

 import pendulum
 import pytest
@@ -160,6 +160,39 @@ def test_read_pod_logs_successfully_with_since_seconds(self):
             ]
         )

+    @pytest.mark.asyncio
+    @mock.patch("asyncio.sleep", new_callable=mock.AsyncMock)
+    async def test_watch_pod_events(self, mock_time_sleep):
+        events = mock.MagicMock()
+        events.items = []
+        for id in ["event 1", "event 2"]:
+            event = mock.MagicMock()
+            event.message = f"test {id}"
+            event.involved_object.field_path = f"object {id}"
+            events.items.append(event)
+        startup_check_interval = 10
+
+        def mock_read_pod_events(pod):
+            self.pod_manager.keep_watching_for_events = False
+            return events
+
+        with (
+            mock.patch.object(self.pod_manager, "read_pod_events", side_effect=mock_read_pod_events),
+            mock.patch(
+                "airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.log",
+                new_callable=PropertyMock,
+            ) as log_mock,
+        ):
+            await self.pod_manager.watch_pod_events(pod=None, check_interval=startup_check_interval)
+
+        log_mock.return_value.info.assert_any_call(
+            "The Pod has an Event: %s from %s", "test event 1", "object event 1"
+        )
+        log_mock.return_value.info.assert_any_call(
+            "The Pod has an Event: %s from %s", "test event 2", "object event 2"
+        )
+        mock_time_sleep.assert_called_once_with(startup_check_interval)
+
     def test_read_pod_events_successfully_returns_events(self):
         mock.sentinel.metadata = mock.MagicMock()
         self.mock_kube_client.list_namespaced_event.return_value = mock.sentinel.events
@@ -392,20 +425,22 @@ def test_start_pod_retries_three_times(self, mock_run_pod_async):

         assert mock_run_pod_async.call_count == 3

-    def test_start_pod_raises_informative_error_on_scheduled_timeout(self):
+    @pytest.mark.asyncio
+    async def test_start_pod_raises_informative_error_on_scheduled_timeout(self):
         pod_response = mock.MagicMock()
         pod_response.status.phase = "Pending"
         self.mock_kube_client.read_namespaced_pod.return_value = pod_response
         expected_msg = "Pod took too long to be scheduled on the cluster, giving up. More than 0s. Check the pod events in kubernetes."
         mock_pod = MagicMock()
         with pytest.raises(AirflowException, match=expected_msg):
-            self.pod_manager.await_pod_start(
+            await self.pod_manager.await_pod_start(
                 pod=mock_pod,
                 schedule_timeout=0,
                 startup_timeout=0,
             )

-    def test_start_pod_raises_informative_error_on_startup_timeout(self):
+    @pytest.mark.asyncio
+    async def test_start_pod_raises_informative_error_on_startup_timeout(self):
         pod_response = mock.MagicMock()
         pod_response.status.phase = "Pending"
         condition = mock.MagicMock()
@@ -417,13 +452,14 @@ def test_start_pod_raises_informative_error_on_startup_timeout(self):
         expected_msg = "Pod took too long to start. More than 0s. Check the pod events in kubernetes."
         mock_pod = MagicMock()
         with pytest.raises(AirflowException, match=expected_msg):
-            self.pod_manager.await_pod_start(
+            await self.pod_manager.await_pod_start(
                 pod=mock_pod,
                 schedule_timeout=0,
                 startup_timeout=0,
             )

-    def test_start_pod_raises_fast_error_on_image_error(self):
+    @pytest.mark.asyncio
+    async def test_start_pod_raises_fast_error_on_image_error(self):
         pod_response = mock.MagicMock()
         pod_response.status.phase = "Pending"
         container_statuse = mock.MagicMock()
@@ -437,14 +473,15 @@ def test_start_pod_raises_fast_error_on_image_error(self):
         expected_msg = f"Pod docker image cannot be pulled, unable to start: {waiting_state.reason}\n{waiting_state.message}"
         mock_pod = MagicMock()
         with pytest.raises(AirflowException, match=expected_msg):
-            self.pod_manager.await_pod_start(
+            await self.pod_manager.await_pod_start(
                 pod=mock_pod,
                 schedule_timeout=60,
                 startup_timeout=60,
             )

-    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.time.sleep")
-    def test_start_pod_startup_interval_seconds(self, mock_time_sleep, caplog):
+    @pytest.mark.asyncio
+    @mock.patch("asyncio.sleep", new_callable=mock.AsyncMock)
+    async def test_start_pod_startup_interval_seconds(self, mock_time_sleep, caplog):
         condition_scheduled = mock.MagicMock()
         condition_scheduled.type = "PodScheduled"
         condition_scheduled.status = "True"
@@ -467,7 +504,7 @@ def pod_state_gen():
         schedule_timeout = 30
         startup_timeout = 60
         mock_pod = MagicMock()
-        self.pod_manager.await_pod_start(
+        await self.pod_manager.await_pod_start(
             pod=mock_pod,
             schedule_timeout=schedule_timeout,  # Never hit, any value is fine, as time.sleep is mocked to do nothing
             startup_timeout=startup_timeout,  # Never hit, any value is fine, as time.sleep is mocked to do nothing
@@ -477,6 +514,7 @@ def pod_state_gen():
         assert mock_time_sleep.call_count == 3
         assert f"::group::Waiting until {schedule_timeout}s to get the POD scheduled..." in caplog.text
         assert f"Waiting {startup_timeout}s to get the POD running..." in caplog.text
+        assert not self.pod_manager.keep_watching_for_events

     @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.container_is_running")
     def test_container_is_running(self, container_is_running_mock):
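One detail worth calling out in these pod-manager tests: asyncio.sleep is patched with new_callable=mock.AsyncMock because a default MagicMock returns a non-awaitable, and the await asyncio.sleep(...) inside the coroutine under test would raise TypeError. A minimal illustration of the pattern (assumes the pytest-asyncio plugin, which the @pytest.mark.asyncio markers throughout this diff presumably rely on):

    import asyncio
    from unittest import mock

    import pytest

    async def wait_once() -> str:
        await asyncio.sleep(10)
        return "done"

    @pytest.mark.asyncio
    @mock.patch("asyncio.sleep", new_callable=mock.AsyncMock)
    async def test_wait_once_is_instant(mock_sleep):
        assert await wait_once() == "done"  # returns without actually sleeping
        mock_sleep.assert_awaited_once_with(10)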