Fix Kubernetes job watch exit when no timeout given #8350

Merged
merged 5 commits on Feb 2, 2023
39 changes: 18 additions & 21 deletions src/prefect/infrastructure/kubernetes.py
@@ -1,11 +1,11 @@
import copy
import enum
import os
import time
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union

import anyio.abc
import pendulum
import yaml
from pydantic import Field, root_validator, validator
from typing_extensions import Literal
@@ -602,32 +602,35 @@ def _watch_job(self, job_name: str) -> int:
)

self.logger.debug(f"Job {job_name!r}: Starting watch for job completion")
start_time = pendulum.now("utc")
deadline = (
(time.time() + self.job_watch_timeout_seconds)
if self.job_watch_timeout_seconds is not None
else None
)
completed = False
while not completed:
elapsed = (pendulum.now("utc") - start_time).in_seconds()
if (
self.job_watch_timeout_seconds is not None
and elapsed > self.job_watch_timeout_seconds
):
self.logger.error(f"Job {job_name!r}: Job timed out after {elapsed}s.")
remaining_time = deadline - time.time() if deadline else None
if deadline and remaining_time <= 0:
self.logger.error(
f"Job {job_name!r}: Job did not complete within "
f"timeout of {self.job_watch_timeout_seconds}s."
)
return -1

watch = kubernetes.watch.Watch()
with self.get_batch_client() as batch_client:
remaining_timeout = (
( # subtract previous watch time
self.job_watch_timeout_seconds - elapsed
)
if self.job_watch_timeout_seconds
else None
# The kubernetes library will disable retries if the timeout kwarg is
# present, regardless of its value, so we do not pass it unless given
# https://github.com/kubernetes-client/python/blob/84f5fea2a3e4b161917aa597bf5e5a1d95e24f5a/kubernetes/base/watch/watch.py#LL160
timeout_seconds = (
{"timeout_seconds": remaining_time} if deadline else {}
)

for event in watch.stream(
func=batch_client.list_namespaced_job,
field_selector=f"metadata.name={job_name}",
namespace=self.namespace,
timeout_seconds=remaining_timeout,
**timeout_seconds,
):
if event["object"].status.completion_time:
if not event["object"].status.succeeded:
@@ -636,12 +639,6 @@ def _watch_job(self, job_name: str) -> int:
completed = True
watch.stop()
break
else:
self.logger.error(
f"Job {job_name!r}: Job did not complete within "
f"timeout of {self.job_watch_timeout_seconds}s."
)
return -1

with self.get_client() as client:
pod_status = client.read_namespaced_pod_status(
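A minimal sketch of the pattern the diff above adopts, using hypothetical names outside the Prefect codebase: compute a wall-clock deadline once, pass only the remaining time to each restarted watch, and omit the timeout_seconds kwarg entirely when no timeout was configured, since passing the kwarg disables the kubernetes client's watch retries.

import time

from kubernetes.watch import Watch


def watch_job_until_complete(batch_client, job_name, namespace, timeout=None):
    # Hypothetical helper; not the Prefect implementation.
    deadline = time.time() + timeout if timeout is not None else None
    while True:
        remaining = deadline - time.time() if deadline is not None else None
        if remaining is not None and remaining <= 0:
            return -1  # deadline exceeded before the job completed

        # Only include timeout_seconds when a timeout was actually given.
        kwargs = {"timeout_seconds": remaining} if remaining is not None else {}
        watch = Watch()
        for event in watch.stream(
            func=batch_client.list_namespaced_job,
            field_selector=f"metadata.name={job_name}",
            namespace=namespace,
            **kwargs,
        ):
            if event["object"].status.completion_time:
                watch.stop()
                return 0
        # The stream ended without seeing completion; restart the watch.
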
134 changes: 130 additions & 4 deletions tests/infrastructure/test_kubernetes_job.py
@@ -714,7 +714,7 @@ def test_uses_cluster_config_if_not_in_cluster(
mock_cluster_config.load_kube_config.assert_called_once()


@pytest.mark.parametrize("job_timeout", [24, None])
@pytest.mark.parametrize("job_timeout", [24, 100])
def test_allows_configurable_timeouts_for_pod_and_job_watches(
mock_k8s_client,
mock_watch,
@@ -728,9 +728,17 @@ def test_allows_configurable_timeouts_for_pod_and_job_watches(
command=["echo", "hello"],
pod_watch_timeout_seconds=42,
)
expected_job_call_kwargs = dict(
func=mock_k8s_batch_client.list_namespaced_job,
namespace=mock.ANY,
field_selector=mock.ANY,
)

if job_timeout is not None:
k8s_job_args["job_watch_timeout_seconds"] = job_timeout
expected_job_call_kwargs["timeout_seconds"] = pytest.approx(
job_timeout, abs=0.01
)

KubernetesJob(**k8s_job_args).run(MagicMock())

@@ -742,11 +750,41 @@ def test_allows_configurable_timeouts_for_pod_and_job_watches(
label_selector=mock.ANY,
timeout_seconds=42,
),
mock.call(**expected_job_call_kwargs),
]
)


@pytest.mark.parametrize("job_timeout", [None])
def test_excludes_timeout_from_job_watches_when_null(
mock_k8s_client,
mock_watch,
mock_k8s_batch_client,
job_timeout,
):
mock_watch.stream = mock.Mock(
side_effect=_mock_pods_stream_that_returns_running_pod
)
k8s_job_args = dict(
command=["echo", "hello"],
job_watch_timeout_seconds=job_timeout,
)

KubernetesJob(**k8s_job_args).run(MagicMock())

mock_watch.stream.assert_has_calls(
[
mock.call(
func=mock_k8s_client.list_namespaced_pod,
namespace=mock.ANY,
label_selector=mock.ANY,
timeout_seconds=mock.ANY,
),
mock.call(
func=mock_k8s_batch_client.list_namespaced_job,
namespace=mock.ANY,
field_selector=mock.ANY,
# Note: timeout_seconds is excluded here
),
]
)
@@ -771,13 +809,12 @@ def test_watches_the_right_namespace(
func=mock_k8s_client.list_namespaced_pod,
namespace="my-awesome-flows",
label_selector=mock.ANY,
timeout_seconds=mock.ANY,
timeout_seconds=60,
),
mock.call(
func=mock_k8s_batch_client.list_namespaced_job,
namespace="my-awesome-flows",
field_selector=mock.ANY,
timeout_seconds=mock.ANY,
),
]
)
@@ -828,6 +865,95 @@ def mock_stream(*args, **kwargs):
assert result.status_code == -1


def test_watch_is_restarted_until_job_is_complete(
mock_k8s_client, mock_watch, mock_k8s_batch_client
):
def mock_stream(*args, **kwargs):
if kwargs["func"] == mock_k8s_client.list_namespaced_pod:
job_pod = MagicMock(spec=kubernetes.client.V1Pod)
job_pod.status.phase = "Running"
yield {"object": job_pod}

if kwargs["func"] == mock_k8s_batch_client.list_namespaced_job:
job = MagicMock(spec=kubernetes.client.V1Job)

# Yield the job, then return to exit the stream
# After restarting the watch a few times, we'll report completion
job.status.completion_time = (
None if mock_watch.stream.call_count < 3 else True
)
yield {"object": job}

mock_watch.stream.side_effect = mock_stream
result = KubernetesJob(command=["echo", "hello"]).run(MagicMock())
assert result.status_code == 1
assert mock_watch.stream.call_count == 3


def test_watch_timeout_is_restarted_until_job_is_complete(
mock_k8s_client, mock_watch, mock_k8s_batch_client
):
def mock_stream(*args, **kwargs):
if kwargs["func"] == mock_k8s_client.list_namespaced_pod:
job_pod = MagicMock(spec=kubernetes.client.V1Pod)
job_pod.status.phase = "Running"
yield {"object": job_pod}

if kwargs["func"] == mock_k8s_batch_client.list_namespaced_job:
job = MagicMock(spec=kubernetes.client.V1Job)

# Sleep a little
sleep(0.25)

# Yield the job, then return to exit the stream
job.status.completion_time = None
yield {"object": job}

mock_watch.stream.side_effect = mock_stream
result = KubernetesJob(command=["echo", "hello"], job_watch_timeout_seconds=1).run(
MagicMock()
)
assert result.status_code == -1

mock_watch.stream.assert_has_calls(
[
mock.call(
func=mock_k8s_client.list_namespaced_pod,
namespace=mock.ANY,
label_selector=mock.ANY,
timeout_seconds=mock.ANY,
),
# Starts with the full timeout
# Approximate comparisons are needed since executing code takes some time
mock.call(
func=mock_k8s_batch_client.list_namespaced_job,
field_selector=mock.ANY,
namespace=mock.ANY,
timeout_seconds=pytest.approx(1, abs=0.01),
),
# Then, the elapsed time is subtracted on each subsequent call
mock.call(
func=mock_k8s_batch_client.list_namespaced_job,
field_selector=mock.ANY,
namespace=mock.ANY,
timeout_seconds=pytest.approx(0.75, abs=0.05),
),
mock.call(
func=mock_k8s_batch_client.list_namespaced_job,
field_selector=mock.ANY,
namespace=mock.ANY,
timeout_seconds=pytest.approx(0.5, abs=0.05),
),
mock.call(
func=mock_k8s_batch_client.list_namespaced_job,
field_selector=mock.ANY,
namespace=mock.ANY,
timeout_seconds=pytest.approx(0.25, abs=0.05),
),
]
)
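
For reference, the decreasing timeout_seconds values asserted above follow directly from the deadline arithmetic in _watch_job. A rough walkthrough, assuming the 0.25 s sleep in the mocked stream:

# deadline = start + 1.0 s (job_watch_timeout_seconds=1)
# call 1: remaining ~= 1.00 s
# call 2: remaining ~= 0.75 s  (one 0.25 s mocked stream has elapsed)
# call 3: remaining ~= 0.50 s
# call 4: remaining ~= 0.25 s
# next loop check: remaining <= 0, so _watch_job logs the timeout and returns -1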


class TestCustomizingBaseJob:
"""Tests scenarios where a user is providing a customized base Job template"""
