diff --git a/dask_kubernetes/operator/controller/controller.py b/dask_kubernetes/operator/controller/controller.py
index 41a99552a..7f809d14f 100644
--- a/dask_kubernetes/operator/controller/controller.py
+++ b/dask_kubernetes/operator/controller/controller.py
@@ -162,11 +162,11 @@ def build_worker_deployment_spec(
     return deployment_spec
 
 
-def get_job_runner_pod_name(job_name):
+def get_job_runner_job_name(job_name):
     return f"{job_name}-runner"
 
 
-def build_job_pod_spec(job_name, cluster_name, namespace, spec, annotations, labels):
+def build_job_spec(job_name, cluster_name, namespace, spec, annotations, labels):
     labels.update(
         **{
             "dask.org/cluster-name": cluster_name,
@@ -174,15 +174,21 @@ def build_job_pod_spec(job_name, cluster_name, namespace, spec, annotations, lab
             "sidecar.istio.io/inject": "false",
         }
     )
-    pod_spec = {
-        "apiVersion": "v1",
-        "kind": "Pod",
-        "metadata": {
-            "name": get_job_runner_pod_name(job_name),
-            "labels": labels,
-            "annotations": annotations,
+    metadata = {
+        "name": get_job_runner_job_name(job_name),
+        "labels": labels,
+        "annotations": annotations,
+    }
+    job_spec = {
+        "apiVersion": "batch/v1",
+        "kind": "Job",
+        "metadata": metadata,
+        "spec": {
+            "template": {
+                "metadata": metadata,
+                "spec": spec,
+            },
         },
-        "spec": spec,
     }
     env = [
         {
@@ -190,12 +196,12 @@ def build_job_pod_spec(job_name, cluster_name, namespace, spec, annotations, lab
             "value": f"tcp://{cluster_name}-scheduler.{namespace}.svc.cluster.local:8786",
         },
     ]
-    for i in range(len(pod_spec["spec"]["containers"])):
-        if "env" in pod_spec["spec"]["containers"][i]:
-            pod_spec["spec"]["containers"][i]["env"].extend(env)
+    for i in range(len(job_spec["spec"]["template"]["spec"]["containers"])):
+        if "env" in job_spec["spec"]["template"]["spec"]["containers"][i]:
+            job_spec["spec"]["template"]["spec"]["containers"][i]["env"].extend(env)
         else:
-            pod_spec["spec"]["containers"][i]["env"] = env
-    return pod_spec
+            job_spec["spec"]["template"]["spec"]["containers"][i]["env"] = env
+    return job_spec
 
 
 def build_default_worker_group_spec(cluster_name, spec, annotations, labels):
@@ -737,28 +743,28 @@ async def daskjob_create_components(
 
         labels = _get_labels(meta)
         annotations = _get_annotations(meta)
-        job_spec = spec["job"]
-        if "metadata" in job_spec:
-            if "annotations" in job_spec["metadata"]:
-                annotations.update(**job_spec["metadata"]["annotations"])
-            if "labels" in job_spec["metadata"]:
-                labels.update(**job_spec["metadata"]["labels"])
-        job_pod_spec = build_job_pod_spec(
+        dask_job_spec = spec["job"]
+        if "metadata" in dask_job_spec:
+            if "annotations" in dask_job_spec["metadata"]:
+                annotations.update(**dask_job_spec["metadata"]["annotations"])
+            if "labels" in dask_job_spec["metadata"]:
+                labels.update(**dask_job_spec["metadata"]["labels"])
+        job_spec = build_job_spec(
             job_name=name,
             cluster_name=cluster_name,
            namespace=namespace,
-            spec=job_spec["spec"],
+            spec=dask_job_spec["spec"],
             annotations=annotations,
             labels=labels,
         )
-        kopf.adopt(job_pod_spec)
-        await corev1api.create_namespaced_pod(
+        kopf.adopt(job_spec)
+        await kubernetes.client.BatchV1Api(api_client).create_namespaced_job(
             namespace=namespace,
-            body=job_pod_spec,
+            body=job_spec,
         )
         patch.status["clusterName"] = cluster_name
         patch.status["jobStatus"] = "ClusterCreated"
-        patch.status["jobRunnerPodName"] = get_job_runner_pod_name(name)
+        patch.status["jobRunnerPodName"] = get_job_runner_job_name(name)
 
 
 @kopf.on.field(
diff --git a/dask_kubernetes/operator/controller/tests/test_controller.py b/dask_kubernetes/operator/controller/tests/test_controller.py
index d0adf0545..e10929f13 100644
--- a/dask_kubernetes/operator/controller/tests/test_controller.py
+++ b/dask_kubernetes/operator/controller/tests/test_controller.py
@@ -12,7 +12,7 @@ from kr8s.asyncio.objects import Pod, Deployment, Service
 
 from dask_kubernetes.operator.controller import (
     KUBERNETES_DATETIME_FORMAT,
-    get_job_runner_pod_name,
+    get_job_runner_job_name,
 )
 from dask_kubernetes.operator._objects import DaskCluster, DaskWorkerGroup, DaskJob
 
@@ -425,13 +425,13 @@ def _assert_job_status_created(job_status):
 def _assert_job_status_cluster_created(job, job_status):
     assert "jobStatus" in job_status
     assert job_status["clusterName"] == job
-    assert job_status["jobRunnerPodName"] == get_job_runner_pod_name(job)
+    assert job_status["jobRunnerPodName"] == get_job_runner_job_name(job)
 
 
 def _assert_job_status_running(job, job_status):
     assert "jobStatus" in job_status
     assert job_status["clusterName"] == job
-    assert job_status["jobRunnerPodName"] == get_job_runner_pod_name(job)
+    assert job_status["jobRunnerPodName"] == get_job_runner_job_name(job)
     start_time = datetime.strptime(job_status["startTime"], KUBERNETES_DATETIME_FORMAT)
     assert datetime.utcnow() > start_time > (datetime.utcnow() - timedelta(seconds=10))
 
@@ -439,7 +439,7 @@ def _assert_final_job_status(job, job_status, expected_status):
     assert job_status["jobStatus"] == expected_status
     assert job_status["clusterName"] == job
-    assert job_status["jobRunnerPodName"] == get_job_runner_pod_name(job)
+    assert job_status["jobRunnerPodName"] == get_job_runner_job_name(job)
     start_time = datetime.strptime(job_status["startTime"], KUBERNETES_DATETIME_FORMAT)
     assert datetime.utcnow() > start_time > (datetime.utcnow() - timedelta(minutes=1))
     end_time = datetime.strptime(job_status["endTime"], KUBERNETES_DATETIME_FORMAT)
@@ -459,8 +459,6 @@ async def test_job(k8s_cluster, kopf_runner, gen_job):
     async with gen_job("simplejob.yaml") as (job, ns):
         assert job
 
-        runner_name = f"{job}-runner"
-
         # Assert that job was created
         while job not in k8s_cluster.kubectl(
             "get", "daskjobs.kubernetes.dask.org", "-n", ns
@@ -480,6 +478,10 @@
         job_status = _get_job_status(k8s_cluster, ns)
         _assert_job_status_cluster_created(job, job_status)
 
+        # Assert job is created
+        while job not in k8s_cluster.kubectl("get", "jobs", "-n", ns):
+            await asyncio.sleep(0.1)
+
         # Assert job pod is created
         while job not in k8s_cluster.kubectl("get", "po", "-n", ns):
             await asyncio.sleep(0.1)
@@ -507,7 +509,12 @@
 
         # Assert job pod runs to completion (will fail if doesn't connect to cluster)
         while "Completed" not in k8s_cluster.kubectl(
-            "get", "-n", ns, "po", runner_name
+            "get",
+            "-n",
+            ns,
+            "po",
+            "-l",
+            "dask.org/component=job-runner",
         ):
             await asyncio.sleep(0.1)
 
@@ -530,8 +537,6 @@ async def test_failed_job(k8s_cluster, kopf_runner, gen_job):
     async with gen_job("failedjob.yaml") as (job, ns):
         assert job
 
-        runner_name = f"{job}-runner"
-
         # Assert that job was created
         while job not in k8s_cluster.kubectl(
             "get", "daskjobs.kubernetes.dask.org", "-n", ns
@@ -551,6 +556,10 @@
         job_status = _get_job_status(k8s_cluster, ns)
         _assert_job_status_cluster_created(job, job_status)
 
+        # Assert job is created
+        while job not in k8s_cluster.kubectl("get", "jobs", "-n", ns):
+            await asyncio.sleep(0.1)
+
         # Assert job pod is created
         while job not in k8s_cluster.kubectl("get", "po", "-n", ns):
             await asyncio.sleep(0.1)
@@ -565,7 +574,12 @@
 
         # Assert job pod runs to failure
         while "Error" not in k8s_cluster.kubectl(
-            "get", "po", "-n", ns, runner_name
+            "get",
+            "po",
+            "-n",
+            ns,
+            "-l",
+            "dask.org/component=job-runner",
         ):
             await asyncio.sleep(0.1)
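
Reviewer note (not part of the diff): a minimal sketch of the manifest shape that build_job_spec now returns, assuming a DaskJob named "demo" in the "default" namespace with a single-container spec. The container name, image, the dask.org/component label, and the DASK_SCHEDULER_ADDRESS env var name are assumptions for illustration; only the fields visible in the hunks above come from the change itself.

# Hypothetical output of build_job_spec("demo", "demo", "default", spec, {}, {})
metadata = {
    "name": "demo-runner",  # get_job_runner_job_name("demo")
    "labels": {
        "dask.org/cluster-name": "demo",
        "sidecar.istio.io/inject": "false",
        # component label assumed; the tests select pods with -l dask.org/component=job-runner
        "dask.org/component": "job-runner",
    },
    "annotations": {},
}
job = {
    "apiVersion": "batch/v1",
    "kind": "Job",
    "metadata": metadata,
    "spec": {
        "template": {
            "metadata": metadata,  # the same metadata dict is reused for the pod template
            "spec": {
                "containers": [
                    {
                        "name": "job-runner",  # placeholder container name
                        "image": "ghcr.io/dask/dask:latest",  # placeholder image
                        "env": [
                            {
                                # env var name assumed; the value format comes from the hunk above
                                "name": "DASK_SCHEDULER_ADDRESS",
                                "value": "tcp://demo-scheduler.default.svc.cluster.local:8786",
                            }
                        ],
                    }
                ]
            },
        },
    },
}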