diff --git a/src/prefect/infrastructure/kubernetes.py b/src/prefect/infrastructure/kubernetes.py
index 6e09e2e8f1e0..b451994cd24f 100644
--- a/src/prefect/infrastructure/kubernetes.py
+++ b/src/prefect/infrastructure/kubernetes.py
@@ -1,11 +1,11 @@
 import copy
 import enum
 import os
+import time
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union
 
 import anyio.abc
-import pendulum
 import yaml
 from pydantic import Field, root_validator, validator
 from typing_extensions import Literal
@@ -602,32 +602,35 @@ def _watch_job(self, job_name: str) -> int:
         )
 
         self.logger.debug(f"Job {job_name!r}: Starting watch for job completion")
-        start_time = pendulum.now("utc")
+        deadline = (
+            (time.time() + self.job_watch_timeout_seconds)
+            if self.job_watch_timeout_seconds is not None
+            else None
+        )
         completed = False
         while not completed:
-            elapsed = (pendulum.now("utc") - start_time).in_seconds()
-            if (
-                self.job_watch_timeout_seconds is not None
-                and elapsed > self.job_watch_timeout_seconds
-            ):
-                self.logger.error(f"Job {job_name!r}: Job timed out after {elapsed}s.")
+            remaining_time = deadline - time.time() if deadline else None
+            if deadline and remaining_time <= 0:
+                self.logger.error(
+                    f"Job {job_name!r}: Job did not complete within "
+                    f"timeout of {self.job_watch_timeout_seconds}s."
+                )
                 return -1
 
             watch = kubernetes.watch.Watch()
             with self.get_batch_client() as batch_client:
-                remaining_timeout = (
-                    (  # subtract previous watch time
-                        self.job_watch_timeout_seconds - elapsed
-                    )
-                    if self.job_watch_timeout_seconds
-                    else None
-                )
+                # The kubernetes library will disable retries if the timeout kwarg is
+                # present regardless of the value so we do not pass it unless given
+                # https://github.com/kubernetes-client/python/blob/84f5fea2a3e4b161917aa597bf5e5a1d95e24f5a/kubernetes/base/watch/watch.py#LL160
+                timeout_seconds = (
+                    {"timeout_seconds": remaining_time} if deadline else {}
+                )
                 for event in watch.stream(
                     func=batch_client.list_namespaced_job,
                     field_selector=f"metadata.name={job_name}",
                     namespace=self.namespace,
-                    timeout_seconds=remaining_timeout,
+                    **timeout_seconds,
                 ):
                     if event["object"].status.completion_time:
                         if not event["object"].status.succeeded:
@@ -636,12 +639,6 @@ def _watch_job(self, job_name: str) -> int:
                         completed = True
                         watch.stop()
                         break
-                else:
-                    self.logger.error(
-                        f"Job {job_name!r}: Job did not complete within "
-                        f"timeout of {self.job_watch_timeout_seconds}s."
-                    )
-                    return -1
 
         with self.get_client() as client:
             pod_status = client.read_namespaced_pod_status(
diff --git a/tests/infrastructure/test_kubernetes_job.py b/tests/infrastructure/test_kubernetes_job.py
index df26672a5f7b..3d9f72f2a995 100644
--- a/tests/infrastructure/test_kubernetes_job.py
+++ b/tests/infrastructure/test_kubernetes_job.py
@@ -714,7 +714,7 @@ def test_uses_cluster_config_if_not_in_cluster(
     mock_cluster_config.load_kube_config.assert_called_once()
 
 
-@pytest.mark.parametrize("job_timeout", [24, None])
+@pytest.mark.parametrize("job_timeout", [24, 100])
 def test_allows_configurable_timeouts_for_pod_and_job_watches(
     mock_k8s_client,
     mock_watch,
@@ -728,9 +728,17 @@ def test_allows_configurable_timeouts_for_pod_and_job_watches(
         command=["echo", "hello"],
         pod_watch_timeout_seconds=42,
     )
+    expected_job_call_kwargs = dict(
+        func=mock_k8s_batch_client.list_namespaced_job,
+        namespace=mock.ANY,
+        field_selector=mock.ANY,
+    )
     if job_timeout is not None:
         k8s_job_args["job_watch_timeout_seconds"] = job_timeout
+        expected_job_call_kwargs["timeout_seconds"] = pytest.approx(
+            job_timeout, abs=0.01
+        )
 
     KubernetesJob(**k8s_job_args).run(MagicMock())
 
@@ -742,11 +750,41 @@
                 label_selector=mock.ANY,
                 timeout_seconds=42,
             ),
+            mock.call(**expected_job_call_kwargs),
+        ]
+    )
+
+
+@pytest.mark.parametrize("job_timeout", [None])
+def test_excludes_timeout_from_job_watches_when_null(
+    mock_k8s_client,
+    mock_watch,
+    mock_k8s_batch_client,
+    job_timeout,
+):
+    mock_watch.stream = mock.Mock(
+        side_effect=_mock_pods_stream_that_returns_running_pod
+    )
+    k8s_job_args = dict(
+        command=["echo", "hello"],
+        job_watch_timeout_seconds=job_timeout,
+    )
+
+    KubernetesJob(**k8s_job_args).run(MagicMock())
+
+    mock_watch.stream.assert_has_calls(
+        [
+            mock.call(
+                func=mock_k8s_client.list_namespaced_pod,
+                namespace=mock.ANY,
+                label_selector=mock.ANY,
+                timeout_seconds=mock.ANY,
+            ),
             mock.call(
                 func=mock_k8s_batch_client.list_namespaced_job,
                 namespace=mock.ANY,
                 field_selector=mock.ANY,
-                timeout_seconds=job_timeout,
+                # Note: timeout_seconds is excluded here
             ),
         ]
     )
@@ -771,13 +809,12 @@ def test_watches_the_right_namespace(
                 func=mock_k8s_client.list_namespaced_pod,
                 namespace="my-awesome-flows",
                 label_selector=mock.ANY,
-                timeout_seconds=mock.ANY,
+                timeout_seconds=60,
             ),
             mock.call(
                 func=mock_k8s_batch_client.list_namespaced_job,
                 namespace="my-awesome-flows",
                 field_selector=mock.ANY,
-                timeout_seconds=mock.ANY,
             ),
         ]
     )
@@ -828,6 +865,95 @@ def mock_stream(*args, **kwargs):
     assert result.status_code == -1
 
 
+def test_watch_is_restarted_until_job_is_complete(
+    mock_k8s_client, mock_watch, mock_k8s_batch_client
+):
+    def mock_stream(*args, **kwargs):
+        if kwargs["func"] == mock_k8s_client.list_namespaced_pod:
+            job_pod = MagicMock(spec=kubernetes.client.V1Pod)
+            job_pod.status.phase = "Running"
+            yield {"object": job_pod}
+
+        if kwargs["func"] == mock_k8s_batch_client.list_namespaced_job:
+            job = MagicMock(spec=kubernetes.client.V1Job)
+
+            # Yield the job then return exiting the stream
+            # After restarting the watch a few times, we'll report completion
+            job.status.completion_time = (
+                None if mock_watch.stream.call_count < 3 else True
+            )
+            yield {"object": job}
+
+    mock_watch.stream.side_effect = mock_stream
+    result = KubernetesJob(command=["echo", "hello"]).run(MagicMock())
+    assert result.status_code == 1
+    assert mock_watch.stream.call_count == 3
+
+
+def test_watch_timeout_is_restarted_until_job_is_complete(
+    mock_k8s_client, mock_watch, mock_k8s_batch_client
+):
+    def mock_stream(*args, **kwargs):
+        if kwargs["func"] == mock_k8s_client.list_namespaced_pod:
+            job_pod = MagicMock(spec=kubernetes.client.V1Pod)
+            job_pod.status.phase = "Running"
+            yield {"object": job_pod}
+
+        if kwargs["func"] == mock_k8s_batch_client.list_namespaced_job:
+            job = MagicMock(spec=kubernetes.client.V1Job)
+
+            # Sleep a little
+            sleep(0.25)
+
+            # Yield the job then return exiting the stream
+            job.status.completion_time = None
+            yield {"object": job}
+
+    mock_watch.stream.side_effect = mock_stream
+    result = KubernetesJob(command=["echo", "hello"], job_watch_timeout_seconds=1).run(
+        MagicMock()
+    )
+    assert result.status_code == -1
+
+    mock_watch.stream.assert_has_calls(
+        [
+            mock.call(
+                func=mock_k8s_client.list_namespaced_pod,
+                namespace=mock.ANY,
+                label_selector=mock.ANY,
+                timeout_seconds=mock.ANY,
+            ),
+            # Starts with the full timeout
+            # Approximate comparisons are needed since executing code takes some time
+            mock.call(
+                func=mock_k8s_batch_client.list_namespaced_job,
+                field_selector=mock.ANY,
+                namespace=mock.ANY,
+                timeout_seconds=pytest.approx(1, abs=0.01),
+            ),
+            # Then, elapsed time removed on each call
+            mock.call(
+                func=mock_k8s_batch_client.list_namespaced_job,
+                field_selector=mock.ANY,
+                namespace=mock.ANY,
+                timeout_seconds=pytest.approx(0.75, abs=0.05),
+            ),
+            mock.call(
+                func=mock_k8s_batch_client.list_namespaced_job,
+                field_selector=mock.ANY,
+                namespace=mock.ANY,
+                timeout_seconds=pytest.approx(0.5, abs=0.05),
+            ),
+            mock.call(
+                func=mock_k8s_batch_client.list_namespaced_job,
+                field_selector=mock.ANY,
+                namespace=mock.ANY,
+                timeout_seconds=pytest.approx(0.25, abs=0.05),
+            ),
+        ]
+    )
+
+
 class TestCustomizingBaseJob:
     """Tests scenarios where a user is providing a customized base Job template"""