diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py index 47ef247fac6c2..30499d2c47790 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py @@ -36,7 +36,11 @@ from airflow.models import Connection from airflow.providers.cncf.kubernetes.exceptions import KubernetesApiError, KubernetesApiPermissionError from airflow.providers.cncf.kubernetes.kube_client import _disable_verify_ssl, _enable_tcp_keepalive -from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import generic_api_retry +from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import ( + API_TIMEOUT, + API_TIMEOUT_OFFSET_SERVER_SIDE, + generic_api_retry, +) from airflow.providers.cncf.kubernetes.utils.container import ( container_is_completed, container_is_running, @@ -68,6 +72,31 @@ def _load_body_to_dict(body: str) -> dict: return body_dict +def _get_request_timeout(timeout_seconds: int | None) -> float: + """Get the client-side request timeout.""" + if timeout_seconds is not None and timeout_seconds > API_TIMEOUT - API_TIMEOUT_OFFSET_SERVER_SIDE: + return timeout_seconds + API_TIMEOUT_OFFSET_SERVER_SIDE + return API_TIMEOUT + + +class _TimeoutK8sApiClient(client.ApiClient): + """Wrapper around kubernetes sync ApiClient to set default timeout.""" + + def call_api(self, *args, **kwargs): + timeout_seconds = kwargs.get("timeout_seconds") # get server-side timeout + kwargs.setdefault("_request_timeout", _get_request_timeout(timeout_seconds)) # client-side timeout + return super().call_api(*args, **kwargs) + + +class _TimeoutAsyncK8sApiClient(async_client.ApiClient): + """Wrapper around kubernetes async ApiClient to set default timeout.""" + + async def call_api(self, *args, **kwargs): + timeout_seconds = kwargs.get("timeout_seconds") # server-side timeout + kwargs.setdefault("_request_timeout", _get_request_timeout(timeout_seconds)) # client-side timeout + return await super().call_api(*args, **kwargs) + + class PodOperatorHookProtocol(Protocol): """ Protocol to define methods relied upon by KubernetesPodOperator. @@ -272,7 +301,7 @@ def get_conn(self) -> client.ApiClient: self.log.debug("loading kube_config from: in_cluster configuration") self._is_in_cluster = True config.load_incluster_config() - return client.ApiClient() + return _TimeoutK8sApiClient() if kubeconfig_path is not None: self.log.debug("loading kube_config from: %s", kubeconfig_path) @@ -282,7 +311,7 @@ def get_conn(self) -> client.ApiClient: client_configuration=self.client_configuration, context=cluster_context, ) - return client.ApiClient() + return _TimeoutK8sApiClient() if kubeconfig is not None: with tempfile.NamedTemporaryFile() as temp_config: @@ -297,7 +326,7 @@ def get_conn(self) -> client.ApiClient: client_configuration=self.client_configuration, context=cluster_context, ) - return client.ApiClient() + return _TimeoutK8sApiClient() if self.config_dict: self.log.debug(LOADING_KUBE_CONFIG_FILE_RESOURCE.format("config dictionary")) @@ -307,7 +336,7 @@ def get_conn(self) -> client.ApiClient: client_configuration=self.client_configuration, context=cluster_context, ) - return client.ApiClient() + return _TimeoutK8sApiClient() return self._get_default_client(cluster_context=cluster_context) @@ -326,7 +355,7 @@ def _get_default_client(self, *, cluster_context: str | None = None) -> client.A client_configuration=self.client_configuration, context=cluster_context, ) - return client.ApiClient() + return _TimeoutK8sApiClient() @property def is_in_cluster(self) -> bool: @@ -803,7 +832,7 @@ async def api_client_from_kubeconfig_file(_kubeconfig_path: str | None): client_configuration=self.client_configuration, context=cluster_context, ) - return async_client.ApiClient() + return _TimeoutAsyncK8sApiClient() if num_selected_configuration > 1: raise AirflowException( @@ -816,13 +845,13 @@ async def api_client_from_kubeconfig_file(_kubeconfig_path: str | None): self.log.debug(LOADING_KUBE_CONFIG_FILE_RESOURCE.format("within a pod")) self._is_in_cluster = True async_config.load_incluster_config() - return async_client.ApiClient() + return _TimeoutAsyncK8sApiClient() if self.config_dict: self.log.debug(LOADING_KUBE_CONFIG_FILE_RESOURCE.format("config dictionary")) self._is_in_cluster = False await async_config.load_kube_config_from_dict(self.config_dict, context=cluster_context) - return async_client.ApiClient() + return _TimeoutAsyncK8sApiClient() if kubeconfig_path is not None: self.log.debug("loading kube_config from: %s", kubeconfig_path) @@ -877,7 +906,7 @@ async def _get_field(self, field_name): async def get_conn(self) -> async_client.ApiClient: kube_client = None try: - kube_client = await self._load_config() or async_client.ApiClient() + kube_client = await self._load_config() or _TimeoutAsyncK8sApiClient() yield kube_client finally: if kube_client is not None: @@ -996,7 +1025,7 @@ async def watch_pod_events( :param name: Pod name to watch events for :param namespace: Kubernetes namespace :param resource_version: Only return events not older than this resource version - :param timeout_seconds: Timeout in seconds for the watch stream + :param timeout_seconds: Timeout in seconds for the watch stream. A small additional buffer may be applied internally. """ if self._event_polling_fallback: async for event_polled in self.watch_pod_events_polling_fallback( diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/kubernetes_helper_functions.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/kubernetes_helper_functions.py index 3646ff8b75c31..c2334b1c9a2e2 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/kubernetes_helper_functions.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/kubernetes_helper_functions.py @@ -52,6 +52,8 @@ class KubernetesApiException(AirflowException): """When communication with kubernetes API fails.""" +API_TIMEOUT = 60 # allow 1 min of timeout for kubernetes api calls +API_TIMEOUT_OFFSET_SERVER_SIDE = 5 # offset to the server side timeout for the client side timeout API_RETRIES = conf.getint("workers", "api_retries", fallback=5) API_RETRY_WAIT_MIN = conf.getfloat("workers", "api_retry_wait_min", fallback=1) API_RETRY_WAIT_MAX = conf.getfloat("workers", "api_retry_wait_max", fallback=15) diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/hooks/test_kubernetes.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/hooks/test_kubernetes.py index ea0a64938c69c..d1fac78704989 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/hooks/test_kubernetes.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/hooks/test_kubernetes.py @@ -35,7 +35,16 @@ from kubernetes_asyncio import client as async_client from airflow.models import Connection -from airflow.providers.cncf.kubernetes.hooks.kubernetes import AsyncKubernetesHook, KubernetesHook +from airflow.providers.cncf.kubernetes.hooks.kubernetes import ( + AsyncKubernetesHook, + KubernetesHook, + _TimeoutAsyncK8sApiClient, + _TimeoutK8sApiClient, +) +from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import ( + API_TIMEOUT, + API_TIMEOUT_OFFSET_SERVER_SIDE, +) from airflow.providers.common.compat.sdk import AirflowException, AirflowNotFoundException from tests_common.test_utils.db import clear_test_connections @@ -78,6 +87,75 @@ class DeprecationRemovalRequired(AirflowException): ... DEFAULT_CONN_ID = "kubernetes_default" +class TestTimeoutK8sApiClient: + @pytest.mark.parametrize( + ("kwargs", "expected_timeout"), + [ + pytest.param({}, API_TIMEOUT, id="default-timeout"), + pytest.param({"timeout_seconds": 5678, "_request_timeout": 1234}, 1234, id="explicit-timeout"), + pytest.param( + {"timeout_seconds": API_TIMEOUT - API_TIMEOUT_OFFSET_SERVER_SIDE}, + API_TIMEOUT, + id="server-side-timeout-limit", + ), + pytest.param( + {"timeout_seconds": API_TIMEOUT - API_TIMEOUT_OFFSET_SERVER_SIDE + 1}, + API_TIMEOUT + 1, + id="server-side-timeout-above-limit", + ), + ], + ) + def test_call_api_timeout_inject(self, kwargs, expected_timeout): + with mock.patch("kubernetes.client.ApiClient.call_api") as mocked_call_api: + mocked_call_api.return_value = "ok" + cli = _TimeoutK8sApiClient() + + out = cli.call_api("arg1", kwargs_arg1="fake", **kwargs) + + mocked_call_api.assert_called_once() + call_args, call_kwargs = mocked_call_api.call_args + assert call_args[0] == "arg1" + assert call_kwargs["kwargs_arg1"] == "fake" + assert call_kwargs["_request_timeout"] == expected_timeout + assert out == "ok" + + +class TestTimeoutAsyncK8sApiClient: + @pytest.mark.asyncio + @pytest.mark.parametrize( + ("kwargs", "expected_timeout"), + [ + pytest.param({}, API_TIMEOUT, id="default-timeout"), + pytest.param({"timeout_seconds": 5678, "_request_timeout": 1234}, 1234, id="explicit-timeout"), + pytest.param( + {"timeout_seconds": API_TIMEOUT - API_TIMEOUT_OFFSET_SERVER_SIDE}, + API_TIMEOUT, + id="server-side-timeout-limit", + ), + pytest.param( + {"timeout_seconds": API_TIMEOUT - API_TIMEOUT_OFFSET_SERVER_SIDE + 1}, + API_TIMEOUT + 1, + id="server-side-timeout-above-limit", + ), + ], + ) + async def test_call_api_timeout_inject(self, kwargs, expected_timeout): + with mock.patch( + "kubernetes_asyncio.client.ApiClient.call_api", new_callable=mock.AsyncMock + ) as mocked_call_api: + mocked_call_api.return_value = "ok" + cli = _TimeoutAsyncK8sApiClient() + + out = await cli.call_api("arg1", kwargs_arg1="fake", **kwargs) + + mocked_call_api.assert_called_once() + call_args, call_kwargs = mocked_call_api.call_args + assert call_args[0] == "arg1" + assert call_kwargs["kwargs_arg1"] == "fake" + assert call_kwargs["_request_timeout"] == expected_timeout + assert out == "ok" + + @pytest.fixture def remove_default_conn(monkeypatch): original_env_var = os.environ.get(f"AIRFLOW_CONN_{DEFAULT_CONN_ID.upper()}")