Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@
from airflow.models import Connection
from airflow.providers.cncf.kubernetes.exceptions import KubernetesApiError, KubernetesApiPermissionError
from airflow.providers.cncf.kubernetes.kube_client import _disable_verify_ssl, _enable_tcp_keepalive
from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import generic_api_retry
from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import (
API_TIMEOUT,
API_TIMEOUT_OFFSET_SERVER_SIDE,
generic_api_retry,
)
from airflow.providers.cncf.kubernetes.utils.container import (
container_is_completed,
container_is_running,
Expand Down Expand Up @@ -68,6 +72,31 @@ def _load_body_to_dict(body: str) -> dict:
return body_dict


def _get_request_timeout(timeout_seconds: int | None) -> float:
    """
    Derive the client-side request timeout from an optional server-side timeout.

    :param timeout_seconds: Server-side timeout in seconds, or ``None`` when the call sets none.
    :return: ``API_TIMEOUT`` when no server-side timeout is given, or when the server-side
        timeout plus ``API_TIMEOUT_OFFSET_SERVER_SIDE`` does not exceed ``API_TIMEOUT``;
        otherwise ``timeout_seconds + API_TIMEOUT_OFFSET_SERVER_SIDE``.
    """
    if timeout_seconds is None:
        return API_TIMEOUT
    # The client waits the server-side timeout plus the offset, but never less than the default.
    return max(API_TIMEOUT, timeout_seconds + API_TIMEOUT_OFFSET_SERVER_SIDE)


class _TimeoutK8sApiClient(client.ApiClient):
    """Sync kubernetes ``ApiClient`` that fills in a client-side request timeout when none is passed."""

    def call_api(self, *args, **kwargs):
        """
        Forward to the parent ``call_api`` with ``_request_timeout`` defaulted.

        The default is computed by ``_get_request_timeout`` from the ``timeout_seconds`` keyword
        (the server-side timeout). A ``_request_timeout`` key already present in ``kwargs`` is
        left untouched.
        """
        # NOTE(review): only a missing key is filled in; an explicit ``_request_timeout=None``
        # is forwarded as ``None`` -- confirm callers omit the key rather than pass ``None``.
        derived_timeout = _get_request_timeout(kwargs.get("timeout_seconds"))
        if "_request_timeout" not in kwargs:
            kwargs["_request_timeout"] = derived_timeout
        return super().call_api(*args, **kwargs)


class _TimeoutAsyncK8sApiClient(async_client.ApiClient):
    """Async kubernetes ``ApiClient`` that fills in a client-side request timeout when none is passed."""

    async def call_api(self, *args, **kwargs):
        """
        Await the parent ``call_api`` with ``_request_timeout`` defaulted.

        The default is computed by ``_get_request_timeout`` from the ``timeout_seconds`` keyword
        (the server-side timeout). A ``_request_timeout`` key already present in ``kwargs`` is
        left untouched.
        """
        # NOTE(review): only a missing key is filled in; an explicit ``_request_timeout=None``
        # is forwarded as ``None`` -- confirm callers omit the key rather than pass ``None``.
        derived_timeout = _get_request_timeout(kwargs.get("timeout_seconds"))
        if "_request_timeout" not in kwargs:
            kwargs["_request_timeout"] = derived_timeout
        return await super().call_api(*args, **kwargs)


class PodOperatorHookProtocol(Protocol):
"""
Protocol to define methods relied upon by KubernetesPodOperator.
Expand Down Expand Up @@ -272,7 +301,7 @@ def get_conn(self) -> client.ApiClient:
self.log.debug("loading kube_config from: in_cluster configuration")
self._is_in_cluster = True
config.load_incluster_config()
return client.ApiClient()
return _TimeoutK8sApiClient()

if kubeconfig_path is not None:
self.log.debug("loading kube_config from: %s", kubeconfig_path)
Expand All @@ -282,7 +311,7 @@ def get_conn(self) -> client.ApiClient:
client_configuration=self.client_configuration,
context=cluster_context,
)
return client.ApiClient()
return _TimeoutK8sApiClient()

if kubeconfig is not None:
with tempfile.NamedTemporaryFile() as temp_config:
Expand All @@ -297,7 +326,7 @@ def get_conn(self) -> client.ApiClient:
client_configuration=self.client_configuration,
context=cluster_context,
)
return client.ApiClient()
return _TimeoutK8sApiClient()

if self.config_dict:
self.log.debug(LOADING_KUBE_CONFIG_FILE_RESOURCE.format("config dictionary"))
Expand All @@ -307,7 +336,7 @@ def get_conn(self) -> client.ApiClient:
client_configuration=self.client_configuration,
context=cluster_context,
)
return client.ApiClient()
return _TimeoutK8sApiClient()

return self._get_default_client(cluster_context=cluster_context)

Expand All @@ -326,7 +355,7 @@ def _get_default_client(self, *, cluster_context: str | None = None) -> client.A
client_configuration=self.client_configuration,
context=cluster_context,
)
return client.ApiClient()
return _TimeoutK8sApiClient()

@property
def is_in_cluster(self) -> bool:
Expand Down Expand Up @@ -803,7 +832,7 @@ async def api_client_from_kubeconfig_file(_kubeconfig_path: str | None):
client_configuration=self.client_configuration,
context=cluster_context,
)
return async_client.ApiClient()
return _TimeoutAsyncK8sApiClient()

if num_selected_configuration > 1:
raise AirflowException(
Expand All @@ -816,13 +845,13 @@ async def api_client_from_kubeconfig_file(_kubeconfig_path: str | None):
self.log.debug(LOADING_KUBE_CONFIG_FILE_RESOURCE.format("within a pod"))
self._is_in_cluster = True
async_config.load_incluster_config()
return async_client.ApiClient()
return _TimeoutAsyncK8sApiClient()

if self.config_dict:
self.log.debug(LOADING_KUBE_CONFIG_FILE_RESOURCE.format("config dictionary"))
self._is_in_cluster = False
await async_config.load_kube_config_from_dict(self.config_dict, context=cluster_context)
return async_client.ApiClient()
return _TimeoutAsyncK8sApiClient()

if kubeconfig_path is not None:
self.log.debug("loading kube_config from: %s", kubeconfig_path)
Expand Down Expand Up @@ -877,7 +906,7 @@ async def _get_field(self, field_name):
async def get_conn(self) -> async_client.ApiClient:
kube_client = None
try:
kube_client = await self._load_config() or async_client.ApiClient()
kube_client = await self._load_config() or _TimeoutAsyncK8sApiClient()
yield kube_client
finally:
if kube_client is not None:
Expand Down Expand Up @@ -996,7 +1025,7 @@ async def watch_pod_events(
:param name: Pod name to watch events for
:param namespace: Kubernetes namespace
:param resource_version: Only return events not older than this resource version
:param timeout_seconds: Timeout in seconds for the watch stream
:param timeout_seconds: Timeout in seconds for the watch stream. A small additional buffer may be applied internally.
"""
if self._event_polling_fallback:
async for event_polled in self.watch_pod_events_polling_fallback(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ class KubernetesApiException(AirflowException):
"""When communication with kubernetes API fails."""


API_TIMEOUT = 60 # default timeout in seconds (1 min) allowed for kubernetes api calls
API_TIMEOUT_OFFSET_SERVER_SIDE = 5 # seconds added to a server-side timeout to get the client-side timeout
# Retry settings are read from the [workers] config section, with the fallbacks given below.
API_RETRIES = conf.getint("workers", "api_retries", fallback=5)
API_RETRY_WAIT_MIN = conf.getfloat("workers", "api_retry_wait_min", fallback=1)
API_RETRY_WAIT_MAX = conf.getfloat("workers", "api_retry_wait_max", fallback=15)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,16 @@
from kubernetes_asyncio import client as async_client

from airflow.models import Connection
from airflow.providers.cncf.kubernetes.hooks.kubernetes import AsyncKubernetesHook, KubernetesHook
from airflow.providers.cncf.kubernetes.hooks.kubernetes import (
AsyncKubernetesHook,
KubernetesHook,
_TimeoutAsyncK8sApiClient,
_TimeoutK8sApiClient,
)
from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import (
API_TIMEOUT,
API_TIMEOUT_OFFSET_SERVER_SIDE,
)
from airflow.providers.common.compat.sdk import AirflowException, AirflowNotFoundException

from tests_common.test_utils.db import clear_test_connections
Expand Down Expand Up @@ -78,6 +87,75 @@ class DeprecationRemovalRequired(AirflowException): ...
DEFAULT_CONN_ID = "kubernetes_default"


class TestTimeoutK8sApiClient:
    """Checks which ``_request_timeout`` the sync timeout client forwards to the parent ``call_api``."""

    @pytest.mark.parametrize(
        ("kwargs", "expected_timeout"),
        [
            # No timeout given at all: the default is injected.
            pytest.param({}, API_TIMEOUT, id="default-timeout"),
            # An explicit client-side timeout is forwarded unchanged.
            pytest.param({"timeout_seconds": 5678, "_request_timeout": 1234}, 1234, id="explicit-timeout"),
            # Server-side timeout exactly at the threshold: the default still applies.
            pytest.param(
                {"timeout_seconds": API_TIMEOUT - API_TIMEOUT_OFFSET_SERVER_SIDE},
                API_TIMEOUT,
                id="server-side-timeout-limit",
            ),
            # One second above the threshold: server-side timeout plus the offset.
            pytest.param(
                {"timeout_seconds": API_TIMEOUT - API_TIMEOUT_OFFSET_SERVER_SIDE + 1},
                API_TIMEOUT + 1,
                id="server-side-timeout-above-limit",
            ),
        ],
    )
    def test_call_api_timeout_inject(self, kwargs, expected_timeout):
        with mock.patch("kubernetes.client.ApiClient.call_api", return_value="ok") as parent_call_api:
            result = _TimeoutK8sApiClient().call_api("arg1", kwargs_arg1="fake", **kwargs)

            parent_call_api.assert_called_once()
            forwarded = parent_call_api.call_args
            # Positional and unrelated keyword arguments pass through untouched.
            assert forwarded.args[0] == "arg1"
            assert forwarded.kwargs["kwargs_arg1"] == "fake"
            assert forwarded.kwargs["_request_timeout"] == expected_timeout
            assert result == "ok"


class TestTimeoutAsyncK8sApiClient:
    """Checks which ``_request_timeout`` the async timeout client forwards to the parent ``call_api``."""

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        ("kwargs", "expected_timeout"),
        [
            # No timeout given at all: the default is injected.
            pytest.param({}, API_TIMEOUT, id="default-timeout"),
            # An explicit client-side timeout is forwarded unchanged.
            pytest.param({"timeout_seconds": 5678, "_request_timeout": 1234}, 1234, id="explicit-timeout"),
            # Server-side timeout exactly at the threshold: the default still applies.
            pytest.param(
                {"timeout_seconds": API_TIMEOUT - API_TIMEOUT_OFFSET_SERVER_SIDE},
                API_TIMEOUT,
                id="server-side-timeout-limit",
            ),
            # One second above the threshold: server-side timeout plus the offset.
            pytest.param(
                {"timeout_seconds": API_TIMEOUT - API_TIMEOUT_OFFSET_SERVER_SIDE + 1},
                API_TIMEOUT + 1,
                id="server-side-timeout-above-limit",
            ),
        ],
    )
    async def test_call_api_timeout_inject(self, kwargs, expected_timeout):
        with mock.patch(
            "kubernetes_asyncio.client.ApiClient.call_api",
            new_callable=mock.AsyncMock,
            return_value="ok",
        ) as parent_call_api:
            result = await _TimeoutAsyncK8sApiClient().call_api("arg1", kwargs_arg1="fake", **kwargs)

            parent_call_api.assert_called_once()
            forwarded = parent_call_api.call_args
            # Positional and unrelated keyword arguments pass through untouched.
            assert forwarded.args[0] == "arg1"
            assert forwarded.kwargs["kwargs_arg1"] == "fake"
            assert forwarded.kwargs["_request_timeout"] == expected_timeout
            assert result == "ok"


@pytest.fixture
def remove_default_conn(monkeypatch):
original_env_var = os.environ.get(f"AIRFLOW_CONN_{DEFAULT_CONN_ID.upper()}")
Expand Down