Skip to content

Commit

Permalink
Fix Kubernetes job watch exit when no timeout given (#8350)
Browse files Browse the repository at this point in the history
Co-authored-by: Nathan Nowack <thrast36@gmail.com>
  • Loading branch information
zanieb and zzstoatzz committed Feb 6, 2023
1 parent 66cec2b commit aa41652
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 1 deletion.
1 change: 0 additions & 1 deletion src/prefect/infrastructure/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union

import anyio.abc
import pendulum
import yaml
from pydantic import Field, root_validator, validator
from typing_extensions import Literal
Expand Down
89 changes: 89 additions & 0 deletions tests/infrastructure/test_kubernetes_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,95 @@ def mock_stream(*args, **kwargs):
assert result.status_code == -1


def test_watch_is_restarted_until_job_is_complete(
mock_k8s_client, mock_watch, mock_k8s_batch_client
):
def mock_stream(*args, **kwargs):
if kwargs["func"] == mock_k8s_client.list_namespaced_pod:
job_pod = MagicMock(spec=kubernetes.client.V1Pod)
job_pod.status.phase = "Running"
yield {"object": job_pod}

if kwargs["func"] == mock_k8s_batch_client.list_namespaced_job:
job = MagicMock(spec=kubernetes.client.V1Job)

# Yield the job then return exiting the stream
# After restarting the watch a few times, we'll report completion
job.status.completion_time = (
None if mock_watch.stream.call_count < 3 else True
)
yield {"object": job}

mock_watch.stream.side_effect = mock_stream
result = KubernetesJob(command=["echo", "hello"]).run(MagicMock())
assert result.status_code == 1
assert mock_watch.stream.call_count == 3


def test_watch_timeout_is_restarted_until_job_is_complete(
mock_k8s_client, mock_watch, mock_k8s_batch_client
):
def mock_stream(*args, **kwargs):
if kwargs["func"] == mock_k8s_client.list_namespaced_pod:
job_pod = MagicMock(spec=kubernetes.client.V1Pod)
job_pod.status.phase = "Running"
yield {"object": job_pod}

if kwargs["func"] == mock_k8s_batch_client.list_namespaced_job:
job = MagicMock(spec=kubernetes.client.V1Job)

# Sleep a little
sleep(0.25)

# Yield the job then return exiting the stream
job.status.completion_time = None
yield {"object": job}

mock_watch.stream.side_effect = mock_stream
result = KubernetesJob(command=["echo", "hello"], job_watch_timeout_seconds=1).run(
MagicMock()
)
assert result.status_code == -1

mock_watch.stream.assert_has_calls(
[
mock.call(
func=mock_k8s_client.list_namespaced_pod,
namespace=mock.ANY,
label_selector=mock.ANY,
timeout_seconds=mock.ANY,
),
# Starts with the full timeout
# Approximate comparisons are needed since executing code takes some time
mock.call(
func=mock_k8s_batch_client.list_namespaced_job,
field_selector=mock.ANY,
namespace=mock.ANY,
timeout_seconds=pytest.approx(1, abs=0.01),
),
# Then, elapsed time removed on each call
mock.call(
func=mock_k8s_batch_client.list_namespaced_job,
field_selector=mock.ANY,
namespace=mock.ANY,
timeout_seconds=pytest.approx(0.75, abs=0.05),
),
mock.call(
func=mock_k8s_batch_client.list_namespaced_job,
field_selector=mock.ANY,
namespace=mock.ANY,
timeout_seconds=pytest.approx(0.5, abs=0.05),
),
mock.call(
func=mock_k8s_batch_client.list_namespaced_job,
field_selector=mock.ANY,
namespace=mock.ANY,
timeout_seconds=pytest.approx(0.25, abs=0.05),
),
]
)


class TestCustomizingBaseJob:
"""Tests scenarios where a user is providing a customized base Job template"""

Expand Down

0 comments on commit aa41652

Please sign in to comment.