From 3568614c0c1bf3b2d16b7761468849c4f915c8b7 Mon Sep 17 00:00:00 2001 From: Traian Schiau Date: Thu, 23 Mar 2023 16:11:51 +0200 Subject: [PATCH 1/4] [jobframework] Improve node selectors tracking Save and try to restore the original node selectors in/from a job annotation "kueue.x-k8s.io/original-node-selectors". --- pkg/controller/jobframework/constants.go | 6 ++ pkg/controller/jobframework/interface.go | 2 +- pkg/controller/jobframework/reconciler.go | 80 +++++++++++++++++-- pkg/controller/jobs/job/job_controller.go | 6 +- .../jobs/mpijob/mpijob_controller.go | 5 +- 5 files changed, 87 insertions(+), 12 deletions(-) diff --git a/pkg/controller/jobframework/constants.go b/pkg/controller/jobframework/constants.go index 8142b183e8..8fbbec5a1d 100644 --- a/pkg/controller/jobframework/constants.go +++ b/pkg/controller/jobframework/constants.go @@ -32,4 +32,10 @@ const ( // ignores this Job from admission, and takes control of its suspension // status based on the admission status of the parent workload. ParentWorkloadAnnotation = "kueue.x-k8s.io/parent-workload" + + // OriginalNodeSelectorsAnnotation is the annotation in which the original + // node selectors are recorded upon a workload admission. This information + // will be used to restore them when the job is suspended. + // The content is a json marshaled slice of selectors. + OriginalNodeSelectorsAnnotation = "kueue.x-k8s.io/original-node-selectors" ) diff --git a/pkg/controller/jobframework/interface.go b/pkg/controller/jobframework/interface.go index 7c6f92fbe2..e80c2c0490 100644 --- a/pkg/controller/jobframework/interface.go +++ b/pkg/controller/jobframework/interface.go @@ -34,7 +34,7 @@ type GenericJob interface { // RunWithNodeAffinity will inject the node affinity extracting from workload to job and unsuspend the job. RunWithNodeAffinity(nodeSelectors []map[string]string) // RestoreNodeAffinity will restore the original node affinity of job. 
- RestoreNodeAffinity(podSets []kueue.PodSet) + RestoreNodeAffinity(nodeSelectors []map[string]string) // Finished means whether the job is completed/failed or not, // condition represents the workload finished condition. Finished() (condition metav1.Condition, finished bool) diff --git a/pkg/controller/jobframework/reconciler.go b/pkg/controller/jobframework/reconciler.go index eb9a816f89..6afe388f85 100644 --- a/pkg/controller/jobframework/reconciler.go +++ b/pkg/controller/jobframework/reconciler.go @@ -15,6 +15,7 @@ package jobframework import ( "context" + "encoding/json" "fmt" corev1 "k8s.io/api/core/v1" @@ -34,6 +35,10 @@ import ( "sigs.k8s.io/kueue/pkg/workload" ) +var ( + errNodeSelectorsNotFound = fmt.Errorf("annotation %s not found", OriginalNodeSelectorsAnnotation) +) + // JobReconciler reconciles a GenericJob object type JobReconciler struct { client client.Client @@ -300,7 +305,13 @@ func (r *JobReconciler) equivalentToWorkload(job GenericJob, object client.Objec // startJob will unsuspend the job, and also inject the node affinity. func (r *JobReconciler) startJob(ctx context.Context, job GenericJob, object client.Object, wl *kueue.Workload) error { - nodeSelectors, err := r.getNodeSelectors(ctx, wl) + //get the original selectors and store them in the job object + originalSelectors := r.getNodeSelectorsFromPodSets(wl) + if err := nodeSelectorsSetToObject(object, originalSelectors); err != nil { + return fmt.Errorf("startJob, record original node selectors: %w", err) + } + + nodeSelectors, err := r.getNodeSelectorsFromAdmission(ctx, wl) if err != nil { return err } @@ -318,6 +329,7 @@ func (r *JobReconciler) startJob(ctx context.Context, job GenericJob, object cli // stopJob will suspend the job, and also restore node affinity, reset job status if needed. 
func (r *JobReconciler) stopJob(ctx context.Context, job GenericJob, object client.Object, wl *kueue.Workload, eventMsg string) error { + log := ctrl.LoggerFrom(ctx) // Suspend the job at first then we're able to update the scheduling directives. job.Suspend() @@ -333,8 +345,12 @@ func (r *JobReconciler) stopJob(ctx context.Context, job GenericJob, object clie } } - if wl != nil { - job.RestoreNodeAffinity(wl.Spec.PodSets) + log.V(3).Info("restore node selectors from annotation") + selectors, err := getNodeSelectorsFromObjectAnnotation(object) + if err != nil { + log.V(3).Error(err, "Unable to get original node selectors") + } else { + job.RestoreNodeAffinity(selectors) return r.client.Update(ctx, object) } @@ -369,8 +385,8 @@ func (r *JobReconciler) constructWorkload(ctx context.Context, job GenericJob, o return wl, nil } -// getNodeSelectors will extract node selectors from admitted workloads. -func (r *JobReconciler) getNodeSelectors(ctx context.Context, w *kueue.Workload) ([]map[string]string, error) { +// getNodeSelectorsFromAdmission will extract node selectors from admitted workloads. +func (r *JobReconciler) getNodeSelectorsFromAdmission(ctx context.Context, w *kueue.Workload) ([]map[string]string, error) { if len(w.Status.Admission.PodSetAssignments) == 0 { return nil, nil } @@ -401,6 +417,19 @@ func (r *JobReconciler) getNodeSelectors(ctx context.Context, w *kueue.Workload) return nodeSelectors, nil } +// getNodeSelectorsFromPodSets will extract node selectors from a workload's podSets. 
+func (r *JobReconciler) getNodeSelectorsFromPodSets(w *kueue.Workload) []map[string]string { + podSets := w.Spec.PodSets + if len(podSets) == 0 { + return nil + } + ret := make([]map[string]string, len(podSets)) + for psi := range podSets { + ret[psi] = cloneNodeSelector(podSets[psi].Template.Spec.NodeSelector) + } + return ret +} + func (r *JobReconciler) handleJobWithNoWorkload(ctx context.Context, job GenericJob, object client.Object) error { log := ctrl.LoggerFrom(ctx) @@ -442,3 +471,44 @@ func generatePodsReadyCondition(job GenericJob, wl *kueue.Workload) metav1.Condi Message: message, } } + +func cloneNodeSelector(src map[string]string) map[string]string { + ret := make(map[string]string, len(src)) + for k, v := range src { + ret[k] = v + } + return ret +} + +// getNodeSelectorsFromObjectAnnotation tries to retrieve a node selectors slice from the +// object's annotations fails if it's not found or is unable to unmarshal +func getNodeSelectorsFromObjectAnnotation(obj client.Object) ([]map[string]string, error) { + str, found := obj.GetAnnotations()[OriginalNodeSelectorsAnnotation] + if !found { + return nil, errNodeSelectorsNotFound + } + // unmarshal + ret := []map[string]string{} + if err := json.Unmarshal([]byte(str), &ret); err != nil { + return nil, err + } + return ret, nil +} + +// nodeSelectorsSetToObject - sets an annotation containing the provided node selectors into +// a job object, even if very unlikely it could return an error related to json.marshaling +func nodeSelectorsSetToObject(obj client.Object, nodeSelectors []map[string]string) error { + nodeSelectorsBytes, err := json.Marshal(nodeSelectors) + if err != nil { + return err + } + + annotations := obj.GetAnnotations() + if annotations == nil { + annotations = map[string]string{OriginalNodeSelectorsAnnotation: string(nodeSelectorsBytes)} + } else { + annotations[OriginalNodeSelectorsAnnotation] = string(nodeSelectorsBytes) + } + obj.SetAnnotations(annotations) + return nil +} diff --git 
a/pkg/controller/jobs/job/job_controller.go b/pkg/controller/jobs/job/job_controller.go index 3b19563dfd..5e9234beb8 100644 --- a/pkg/controller/jobs/job/job_controller.go +++ b/pkg/controller/jobs/job/job_controller.go @@ -165,14 +165,14 @@ func (j *Job) RunWithNodeAffinity(nodeSelectors []map[string]string) { } } -func (j *Job) RestoreNodeAffinity(podSets []kueue.PodSet) { - if len(podSets) == 0 || equality.Semantic.DeepEqual(j.Spec.Template.Spec.NodeSelector, podSets[0].Template.Spec.NodeSelector) { +func (j *Job) RestoreNodeAffinity(nodeSelectors []map[string]string) { + if len(nodeSelectors) == 0 || equality.Semantic.DeepEqual(j.Spec.Template.Spec.NodeSelector, nodeSelectors[0]) { return } j.Spec.Template.Spec.NodeSelector = map[string]string{} - for k, v := range podSets[0].Template.Spec.NodeSelector { + for k, v := range nodeSelectors[0] { j.Spec.Template.Spec.NodeSelector[k] = v } } diff --git a/pkg/controller/jobs/mpijob/mpijob_controller.go b/pkg/controller/jobs/mpijob/mpijob_controller.go index e4f0ece595..8d4b0ebc33 100644 --- a/pkg/controller/jobs/mpijob/mpijob_controller.go +++ b/pkg/controller/jobs/mpijob/mpijob_controller.go @@ -125,11 +125,10 @@ func (j *MPIJob) RunWithNodeAffinity(nodeSelectors []map[string]string) { } } -func (j *MPIJob) RestoreNodeAffinity(podSets []kueue.PodSet) { +func (j *MPIJob) RestoreNodeAffinity(nodeSelectors []map[string]string) { orderedReplicaTypes := orderedReplicaTypes(&j.Spec) - for index := range podSets { + for index, nodeSelector := range nodeSelectors { replicaType := orderedReplicaTypes[index] - nodeSelector := podSets[index].Template.Spec.NodeSelector if !equality.Semantic.DeepEqual(j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector, nodeSelector) { j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector = map[string]string{} for k, v := range nodeSelector { From 79a287152992a09c82e9d7f663139f0385d45d70 Mon Sep 17 00:00:00 2001 From: Traian Schiau Date: Fri, 24 Mar 2023 15:19:09 +0200 
Subject: [PATCH 2/4] [test/integration/jobs] Check selectors restoration on workload deletion --- .../controller/job/job_controller_test.go | 59 ++++++++++++++++ .../mpijob/mpijob_controller_test.go | 70 +++++++++++++++++++ 2 files changed, 129 insertions(+) diff --git a/test/integration/controller/job/job_controller_test.go b/test/integration/controller/job/job_controller_test.go index 1f321d3be8..0641b82fe4 100644 --- a/test/integration/controller/job/job_controller_test.go +++ b/test/integration/controller/job/job_controller_test.go @@ -831,4 +831,63 @@ var _ = ginkgo.Describe("Job controller interacting with scheduler", func() { return createdProdJob.Spec.Suspend }, util.Timeout, util.Interval).Should(gomega.Equal(pointer.Bool(false))) }) + + ginkgo.When("The workload is deleted while it's admitted", func() { + ginkgo.It("Should restore the original node selectors", func() { + localQueue := testing.MakeLocalQueue("local-queue", ns.Name).ClusterQueue(prodClusterQ.Name).Obj() + job := testingjob.MakeJob(jobName, ns.Name).Queue(localQueue.Name).Request(corev1.ResourceCPU, "2").Obj() + lookupKey := types.NamespacedName{Name: job.Name, Namespace: job.Namespace} + createdJob := &batchv1.Job{} + + ginkgo.By("create a job", func() { + gomega.Expect(k8sClient.Create(ctx, job)).Should(gomega.Succeed()) + }) + + ginkgo.By("job should be suspend", func() { + gomega.Eventually(func() *bool { + gomega.Expect(k8sClient.Get(ctx, lookupKey, createdJob)).Should(gomega.Succeed()) + return createdJob.Spec.Suspend + }, util.Timeout, util.Interval).Should(gomega.Equal(pointer.Bool(true))) + }) + + // backup the podSet's node selector + originalNodeSelector := createdJob.Spec.Template.Spec.NodeSelector + + ginkgo.By("create a localQueue", func() { + gomega.Expect(k8sClient.Create(ctx, localQueue)).Should(gomega.Succeed()) + }) + + ginkgo.By("job should be unsuspended", func() { + gomega.Eventually(func() *bool { + gomega.Expect(k8sClient.Get(ctx, lookupKey, 
createdJob)).Should(gomega.Succeed()) + return createdJob.Spec.Suspend + }, util.Timeout, util.Interval).Should(gomega.Equal(pointer.Bool(false))) + }) + + ginkgo.By("the node selector should be updated", func() { + gomega.Eventually(func() map[string]string { + gomega.Expect(k8sClient.Get(ctx, lookupKey, createdJob)).Should(gomega.Succeed()) + return createdJob.Spec.Template.Spec.NodeSelector + }, util.Timeout, util.Interval).ShouldNot(gomega.Equal(originalNodeSelector)) + }) + + ginkgo.By("delete the localQueue to prevent readmission", func() { + gomega.Expect(util.DeleteLocalQueue(ctx, k8sClient, localQueue)).Should(gomega.Succeed()) + }) + + ginkgo.By("delete the workload to stop the job", func() { + wl := &kueue.Workload{} + wlKey := types.NamespacedName{Name: workloadjob.GetWorkloadNameForJob(job.Name), Namespace: job.Namespace} + gomega.Expect(k8sClient.Get(ctx, wlKey, wl)).Should(gomega.Succeed()) + gomega.Expect(util.DeleteWorkload(ctx, k8sClient, wl)).Should(gomega.Succeed()) + }) + + ginkgo.By("the node selector should be restored", func() { + gomega.Eventually(func() map[string]string { + gomega.Expect(k8sClient.Get(ctx, lookupKey, createdJob)).Should(gomega.Succeed()) + return createdJob.Spec.Template.Spec.NodeSelector + }, util.Timeout, util.Interval).Should(gomega.Equal(originalNodeSelector)) + }) + }) + }) }) diff --git a/test/integration/controller/mpijob/mpijob_controller_test.go b/test/integration/controller/mpijob/mpijob_controller_test.go index 427a388c3b..a02c807aa0 100644 --- a/test/integration/controller/mpijob/mpijob_controller_test.go +++ b/test/integration/controller/mpijob/mpijob_controller_test.go @@ -561,4 +561,74 @@ var _ = ginkgo.Describe("Job controller interacting with scheduler", func() { }) + ginkgo.When("The workload is deleted while it's admitted", func() { + ginkgo.It("Should restore the original node selectors", func() { + + localQueue := testing.MakeLocalQueue("local-queue", ns.Name).ClusterQueue(clusterQueue.Name).Obj() + 
job := testingmpijob.MakeMPIJob(jobName, ns.Name).Queue(localQueue.Name). + Request(kubeflow.MPIReplicaTypeLauncher, corev1.ResourceCPU, "3"). + Request(kubeflow.MPIReplicaTypeWorker, corev1.ResourceCPU, "4"). + Obj() + lookupKey := types.NamespacedName{Name: job.Name, Namespace: job.Namespace} + createdJob := &kubeflow.MPIJob{} + + nodeSelectors := func(j *kubeflow.MPIJob) map[kubeflow.MPIReplicaType]map[string]string { + ret := map[kubeflow.MPIReplicaType]map[string]string{} + for k := range j.Spec.MPIReplicaSpecs { + ret[k] = j.Spec.MPIReplicaSpecs[k].Template.Spec.NodeSelector + } + return ret + } + + ginkgo.By("create a job", func() { + gomega.Expect(k8sClient.Create(ctx, job)).Should(gomega.Succeed()) + }) + + ginkgo.By("job should be suspend", func() { + gomega.Eventually(func() *bool { + gomega.Expect(k8sClient.Get(ctx, lookupKey, createdJob)).Should(gomega.Succeed()) + return createdJob.Spec.RunPolicy.Suspend + }, util.Timeout, util.Interval).Should(gomega.Equal(pointer.Bool(true))) + }) + + // backup the node selectors + originalNodeSelectors := nodeSelectors(createdJob) + + ginkgo.By("create a localQueue", func() { + gomega.Expect(k8sClient.Create(ctx, localQueue)).Should(gomega.Succeed()) + }) + + ginkgo.By("job should be unsuspended", func() { + gomega.Eventually(func() *bool { + gomega.Expect(k8sClient.Get(ctx, lookupKey, createdJob)).Should(gomega.Succeed()) + return createdJob.Spec.RunPolicy.Suspend + }, util.Timeout, util.Interval).Should(gomega.Equal(pointer.Bool(false))) + }) + + ginkgo.By("the node selectors should be updated", func() { + gomega.Eventually(func() map[kubeflow.MPIReplicaType]map[string]string { + gomega.Expect(k8sClient.Get(ctx, lookupKey, createdJob)).Should(gomega.Succeed()) + return nodeSelectors(createdJob) + }, util.Timeout, util.Interval).ShouldNot(gomega.Equal(originalNodeSelectors)) + }) + + ginkgo.By("delete the localQueue to prevent readmission", func() { + gomega.Expect(util.DeleteLocalQueue(ctx, k8sClient, 
localQueue)).Should(gomega.Succeed()) + }) + + ginkgo.By("delete the workload to stop the job", func() { + wl := &kueue.Workload{} + wlKey := types.NamespacedName{Name: workloadmpijob.GetWorkloadNameForMPIJob(job.Name), Namespace: job.Namespace} + gomega.Expect(k8sClient.Get(ctx, wlKey, wl)).Should(gomega.Succeed()) + gomega.Expect(util.DeleteWorkload(ctx, k8sClient, wl)).Should(gomega.Succeed()) + }) + + ginkgo.By("the node selectors should be restored", func() { + gomega.Eventually(func() map[kubeflow.MPIReplicaType]map[string]string { + gomega.Expect(k8sClient.Get(ctx, lookupKey, createdJob)).Should(gomega.Succeed()) + return nodeSelectors(createdJob) + }, util.Timeout, util.Interval).Should(gomega.Equal(originalNodeSelectors)) + }) + }) + }) }) From 1022e23763c25fa015d327676da36f3b3590d140 Mon Sep 17 00:00:00 2001 From: Traian Schiau Date: Wed, 29 Mar 2023 10:51:29 +0300 Subject: [PATCH 3/4] [jobframework] Validate original node selectors --- pkg/controller/jobframework/validation.go | 19 +++++ pkg/controller/jobs/job/job_webhook.go | 1 + pkg/controller/jobs/job/job_webhook_test.go | 30 ++++++++ pkg/controller/jobs/mpijob/mpijob_webhook.go | 6 +- .../jobs/mpijob/mpijob_webhook_test.go | 77 +++++++++++++++++++ pkg/util/testingjobs/job/wrappers.go | 5 ++ .../testingjobs/mpijob/wrappers_mpijob.go | 12 +++ 7 files changed, 149 insertions(+), 1 deletion(-) create mode 100644 pkg/controller/jobs/mpijob/mpijob_webhook_test.go diff --git a/pkg/controller/jobframework/validation.go b/pkg/controller/jobframework/validation.go index 3ae56e6332..c8c67d7652 100644 --- a/pkg/controller/jobframework/validation.go +++ b/pkg/controller/jobframework/validation.go @@ -14,6 +14,7 @@ limitations under the License. 
package jobframework import ( + "encoding/json" "strings" apivalidation "k8s.io/apimachinery/pkg/api/validation" @@ -26,6 +27,8 @@ var ( labelsPath = field.NewPath("metadata", "labels") parentWorkloadKeyPath = annotationsPath.Key(ParentWorkloadAnnotation) queueNameLabelPath = labelsPath.Key(QueueLabel) + + originalNodeSelectorsWorkloadKeyPath = annotationsPath.Key(OriginalNodeSelectorsAnnotation) ) func ValidateCreateForQueueName(job GenericJob) field.ErrorList { @@ -71,3 +74,19 @@ func ValidateUpdateForParentWorkload(oldJob, newJob GenericJob) field.ErrorList } return allErrs } + +func ValidateUpdateForOriginalNodeSelectors(oldJob, newJob GenericJob) field.ErrorList { + var allErrs field.ErrorList + if oldJob.IsSuspended() == newJob.IsSuspended() { + if errList := apivalidation.ValidateImmutableField(oldJob.Object().GetAnnotations()[OriginalNodeSelectorsAnnotation], + newJob.Object().GetAnnotations()[OriginalNodeSelectorsAnnotation], originalNodeSelectorsWorkloadKeyPath); len(errList) > 0 { + allErrs = append(allErrs, field.Forbidden(originalNodeSelectorsWorkloadKeyPath, "this annotation is immutable while the job is not changing its suspended state")) + } + } else if av, found := newJob.Object().GetAnnotations()[OriginalNodeSelectorsAnnotation]; found { + out := []map[string]string{} + if err := json.Unmarshal([]byte(av), &out); err != nil { + allErrs = append(allErrs, field.Invalid(originalNodeSelectorsWorkloadKeyPath, av, err.Error())) + } + } + return allErrs +} diff --git a/pkg/controller/jobs/job/job_webhook.go b/pkg/controller/jobs/job/job_webhook.go index adf569835d..41d56fcc7b 100644 --- a/pkg/controller/jobs/job/job_webhook.go +++ b/pkg/controller/jobs/job/job_webhook.go @@ -106,6 +106,7 @@ func (w *JobWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime. 
func validateUpdate(oldJob, newJob jobframework.GenericJob) field.ErrorList { allErrs := validateCreate(newJob) allErrs = append(allErrs, jobframework.ValidateUpdateForParentWorkload(oldJob, newJob)...) + allErrs = append(allErrs, jobframework.ValidateUpdateForOriginalNodeSelectors(oldJob, newJob)...) allErrs = append(allErrs, jobframework.ValidateUpdateForQueueName(oldJob, newJob)...) return allErrs } diff --git a/pkg/controller/jobs/job/job_webhook_test.go b/pkg/controller/jobs/job/job_webhook_test.go index e627785102..b605e62bde 100644 --- a/pkg/controller/jobs/job/job_webhook_test.go +++ b/pkg/controller/jobs/job/job_webhook_test.go @@ -38,6 +38,8 @@ var ( parentWorkloadKeyPath = annotationsPath.Key(jobframework.ParentWorkloadAnnotation) queueNameLabelPath = labelsPath.Key(jobframework.QueueLabel) queueNameAnnotationsPath = annotationsPath.Key(jobframework.QueueAnnotation) + + originalNodeSelectorsKeyPath = annotationsPath.Key(jobframework.OriginalNodeSelectorsAnnotation) ) func TestValidateCreate(t *testing.T) { @@ -156,6 +158,34 @@ func TestValidateUpdate(t *testing.T) { field.Forbidden(parentWorkloadKeyPath, "this annotation is immutable"), }, }, + { + name: "original node selectors can be set while unsuspending", + oldJob: testingutil.MakeJob("job", "default").Suspend(true).Obj(), + newJob: testingutil.MakeJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation("[{\"l1\":\"v1\"}]").Obj(), + wantErr: nil, + }, + { + name: "original node selectors can be set while suspending", + oldJob: testingutil.MakeJob("job", "default").Suspend(true).Obj(), + newJob: testingutil.MakeJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation("[{\"l1\":\"v1\"}]").Obj(), + wantErr: nil, + }, + { + name: "immutable original node selectors while not suspended", + oldJob: testingutil.MakeJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation("[{\"l1\":\"v1\"}]").Obj(), + newJob: testingutil.MakeJob("job", 
"default").Suspend(false).OriginalNodeSelectorsAnnotation("").Obj(), + wantErr: field.ErrorList{ + field.Forbidden(originalNodeSelectorsKeyPath, "this annotation is immutable while the job is not changing its suspended state"), + }, + }, + { + name: "immutable original node selectors while suspended", + oldJob: testingutil.MakeJob("job", "default").Suspend(true).OriginalNodeSelectorsAnnotation("[{\"l1\":\"v1\"}]").Obj(), + newJob: testingutil.MakeJob("job", "default").Suspend(true).OriginalNodeSelectorsAnnotation("").Obj(), + wantErr: field.ErrorList{ + field.Forbidden(originalNodeSelectorsKeyPath, "this annotation is immutable while the job is not changing its suspended state"), + }, + }, } for _, tc := range testcases { diff --git a/pkg/controller/jobs/mpijob/mpijob_webhook.go b/pkg/controller/jobs/mpijob/mpijob_webhook.go index 34a7de0d69..21b859c705 100644 --- a/pkg/controller/jobs/mpijob/mpijob_webhook.go +++ b/pkg/controller/jobs/mpijob/mpijob_webhook.go @@ -82,10 +82,14 @@ func validateCreate(job jobframework.GenericJob) field.ErrorList { // ValidateUpdate implements webhook.CustomValidator so a webhook will be registered for the type func (w *MPIJobWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Object) error { oldJob := oldObj.(*kubeflow.MPIJob) + oldGenJob := &MPIJob{*oldJob} newJob := newObj.(*kubeflow.MPIJob) + newGenJob := &MPIJob{*newJob} log := ctrl.LoggerFrom(ctx).WithName("job-webhook") log.Info("Validating update", "job", klog.KObj(newJob)) - return jobframework.ValidateUpdateForQueueName(&MPIJob{*oldJob}, &MPIJob{*newJob}).ToAggregate() + allErrs := jobframework.ValidateUpdateForQueueName(oldGenJob, newGenJob) + allErrs = append(allErrs, jobframework.ValidateUpdateForOriginalNodeSelectors(oldGenJob, newGenJob)...) 
+ return allErrs.ToAggregate() } // ValidateDelete implements webhook.CustomValidator so a webhook will be registered for the type diff --git a/pkg/controller/jobs/mpijob/mpijob_webhook_test.go b/pkg/controller/jobs/mpijob/mpijob_webhook_test.go new file mode 100644 index 0000000000..5e6291fdd4 --- /dev/null +++ b/pkg/controller/jobs/mpijob/mpijob_webhook_test.go @@ -0,0 +1,77 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package mpijob + +import ( + "context" + "testing" + + "github.com/google/go-cmp/cmp" + kubeflow "github.com/kubeflow/mpi-operator/pkg/apis/kubeflow/v2beta1" + "k8s.io/apimachinery/pkg/util/validation/field" + + "sigs.k8s.io/kueue/pkg/controller/jobframework" + testingutil "sigs.k8s.io/kueue/pkg/util/testingjobs/mpijob" +) + +var ( + originalNodeSelectorsKeyPath = field.NewPath("metadata", "annotations").Key(jobframework.OriginalNodeSelectorsAnnotation) +) + +func TestUpdate(t *testing.T) { + testcases := map[string]struct { + oldJob *kubeflow.MPIJob + newJob *kubeflow.MPIJob + wantErr error + }{ + "original node selectors can be set while unsuspending": { + oldJob: testingutil.MakeMPIJob("job", "default").Suspend(true).Obj(), + newJob: testingutil.MakeMPIJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation("[{\"l1\":\"v1\"}]").Obj(), + wantErr: nil, + }, + "original node selectors can be set while suspending": { + oldJob: testingutil.MakeMPIJob("job", "default").Suspend(true).Obj(), + newJob: 
testingutil.MakeMPIJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation("[{\"l1\":\"v1\"}]").Obj(), + wantErr: nil, + }, + "immutable original node selectors while not suspended": { + oldJob: testingutil.MakeMPIJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation("[{\"l1\":\"v1\"}]").Obj(), + newJob: testingutil.MakeMPIJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation("").Obj(), + wantErr: field.ErrorList{ + field.Forbidden(originalNodeSelectorsKeyPath, "this annotation is immutable while the job is not changing its suspended state"), + }.ToAggregate(), + }, + "immutable original node selectors while suspended": { + oldJob: testingutil.MakeMPIJob("job", "default").Suspend(true).OriginalNodeSelectorsAnnotation("[{\"l1\":\"v1\"}]").Obj(), + newJob: testingutil.MakeMPIJob("job", "default").Suspend(true).OriginalNodeSelectorsAnnotation("").Obj(), + wantErr: field.ErrorList{ + field.Forbidden(originalNodeSelectorsKeyPath, "this annotation is immutable while the job is not changing its suspended state"), + }.ToAggregate(), + }, + } + + for name, tc := range testcases { + t.Run(name, func(t *testing.T) { + wh := &MPIJobWebhook{} + result := wh.ValidateUpdate(context.Background(), tc.oldJob, tc.newJob) + + if diff := cmp.Diff(tc.wantErr, result); diff != "" { + t.Errorf("ValidateUpdate() mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/pkg/util/testingjobs/job/wrappers.go b/pkg/util/testingjobs/job/wrappers.go index 2085a9c3f1..64dd9bf4ec 100644 --- a/pkg/util/testingjobs/job/wrappers.go +++ b/pkg/util/testingjobs/job/wrappers.go @@ -102,6 +102,11 @@ func (j *JobWrapper) ParentWorkload(parentWorkload string) *JobWrapper { return j } +func (j *JobWrapper) OriginalNodeSelectorsAnnotation(content string) *JobWrapper { + j.Annotations[jobframework.OriginalNodeSelectorsAnnotation] = content + return j +} + // Toleration adds a toleration to the job. 
func (j *JobWrapper) Toleration(t corev1.Toleration) *JobWrapper { j.Spec.Template.Spec.Tolerations = append(j.Spec.Template.Spec.Tolerations, t) diff --git a/pkg/util/testingjobs/mpijob/wrappers_mpijob.go b/pkg/util/testingjobs/mpijob/wrappers_mpijob.go index 2ea8020602..602ce6a6d7 100644 --- a/pkg/util/testingjobs/mpijob/wrappers_mpijob.go +++ b/pkg/util/testingjobs/mpijob/wrappers_mpijob.go @@ -116,3 +116,15 @@ func (j *MPIJobWrapper) Parallelism(p int32) *MPIJobWrapper { j.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeWorker].Replicas = pointer.Int32(p) return j } + +// OriginalNodeSelectorsAnnotation updates the original node selectors annotation +func (j *MPIJobWrapper) OriginalNodeSelectorsAnnotation(content string) *MPIJobWrapper { + j.Annotations[jobframework.OriginalNodeSelectorsAnnotation] = content + return j +} + +// Suspend updates the suspend status of the job +func (j *MPIJobWrapper) Suspend(s bool) *MPIJobWrapper { + j.Spec.RunPolicy.Suspend = &s + return j +} From 776aefa8073116f645859cacc2ee1b628c9b7722 Mon Sep 17 00:00:00 2001 From: Traian Schiau Date: Mon, 3 Apr 2023 12:02:21 +0300 Subject: [PATCH 4/4] [jobframework] Record podSet name in originalNodeSelectors annotation --- pkg/controller/jobframework/interface.go | 4 +-- pkg/controller/jobframework/reconciler.go | 36 ++++++++++++------- pkg/controller/jobframework/validation.go | 2 +- pkg/controller/jobs/job/job_controller.go | 12 +++---- pkg/controller/jobs/job/job_webhook_test.go | 21 ++++++++--- .../jobs/mpijob/mpijob_controller.go | 14 ++++---- .../jobs/mpijob/mpijob_webhook_test.go | 20 ++++++++--- 7 files changed, 72 insertions(+), 37 deletions(-) diff --git a/pkg/controller/jobframework/interface.go b/pkg/controller/jobframework/interface.go index e80c2c0490..e919125da0 100644 --- a/pkg/controller/jobframework/interface.go +++ b/pkg/controller/jobframework/interface.go @@ -32,9 +32,9 @@ type GenericJob interface { // If true, status is modified, if not, status is as it was. 
ResetStatus() bool // RunWithNodeAffinity will inject the node affinity extracting from workload to job and unsuspend the job. - RunWithNodeAffinity(nodeSelectors []map[string]string) + RunWithNodeAffinity(nodeSelectors []PodSetNodeSelector) // RestoreNodeAffinity will restore the original node affinity of job. - RestoreNodeAffinity(nodeSelectors []map[string]string) + RestoreNodeAffinity(nodeSelectors []PodSetNodeSelector) // Finished means whether the job is completed/failed or not, // condition represents the workload finished condition. Finished() (condition metav1.Condition, finished bool) diff --git a/pkg/controller/jobframework/reconciler.go b/pkg/controller/jobframework/reconciler.go index 6afe388f85..dc59aaaa03 100644 --- a/pkg/controller/jobframework/reconciler.go +++ b/pkg/controller/jobframework/reconciler.go @@ -307,7 +307,7 @@ func (r *JobReconciler) equivalentToWorkload(job GenericJob, object client.Objec func (r *JobReconciler) startJob(ctx context.Context, job GenericJob, object client.Object, wl *kueue.Workload) error { //get the original selectors and store them in the job object originalSelectors := r.getNodeSelectorsFromPodSets(wl) - if err := nodeSelectorsSetToObject(object, originalSelectors); err != nil { + if err := setNodeSelectorsInAnnotation(object, originalSelectors); err != nil { return fmt.Errorf("startJob, record original node selectors: %w", err) } @@ -385,17 +385,25 @@ func (r *JobReconciler) constructWorkload(ctx context.Context, job GenericJob, o return wl, nil } +type PodSetNodeSelector struct { + Name string `json:"name"` + NodeSelector map[string]string `json:"nodeSelector"` +} + // getNodeSelectorsFromAdmission will extract node selectors from admitted workloads. 
-func (r *JobReconciler) getNodeSelectorsFromAdmission(ctx context.Context, w *kueue.Workload) ([]map[string]string, error) { +func (r *JobReconciler) getNodeSelectorsFromAdmission(ctx context.Context, w *kueue.Workload) ([]PodSetNodeSelector, error) { if len(w.Status.Admission.PodSetAssignments) == 0 { return nil, nil } - nodeSelectors := make([]map[string]string, len(w.Status.Admission.PodSetAssignments)) + nodeSelectors := make([]PodSetNodeSelector, len(w.Status.Admission.PodSetAssignments)) for i, podSetFlavor := range w.Status.Admission.PodSetAssignments { processedFlvs := sets.NewString() - nodeSelector := map[string]string{} + nodeSelector := PodSetNodeSelector{ + Name: podSetFlavor.Name, + NodeSelector: make(map[string]string), + } for _, flvRef := range podSetFlavor.Flavors { flvName := string(flvRef) if processedFlvs.Has(flvName) { @@ -407,7 +415,7 @@ func (r *JobReconciler) getNodeSelectorsFromAdmission(ctx context.Context, w *ku return nil, err } for k, v := range flv.Spec.NodeLabels { - nodeSelector[k] = v + nodeSelector.NodeSelector[k] = v } processedFlvs.Insert(flvName) } @@ -418,14 +426,18 @@ func (r *JobReconciler) getNodeSelectorsFromAdmission(ctx context.Context, w *ku } // getNodeSelectorsFromPodSets will extract node selectors from a workload's podSets. 
-func (r *JobReconciler) getNodeSelectorsFromPodSets(w *kueue.Workload) []map[string]string { +func (r *JobReconciler) getNodeSelectorsFromPodSets(w *kueue.Workload) []PodSetNodeSelector { podSets := w.Spec.PodSets if len(podSets) == 0 { return nil } - ret := make([]map[string]string, len(podSets)) + ret := make([]PodSetNodeSelector, len(podSets)) for psi := range podSets { - ret[psi] = cloneNodeSelector(podSets[psi].Template.Spec.NodeSelector) + ps := &podSets[psi] + ret[psi] = PodSetNodeSelector{ + Name: ps.Name, + NodeSelector: cloneNodeSelector(ps.Template.Spec.NodeSelector), + } } return ret } @@ -482,22 +494,22 @@ func cloneNodeSelector(src map[string]string) map[string]string { // getNodeSelectorsFromObjectAnnotation tries to retrieve a node selectors slice from the // object's annotations fails if it's not found or is unable to unmarshal -func getNodeSelectorsFromObjectAnnotation(obj client.Object) ([]map[string]string, error) { +func getNodeSelectorsFromObjectAnnotation(obj client.Object) ([]PodSetNodeSelector, error) { str, found := obj.GetAnnotations()[OriginalNodeSelectorsAnnotation] if !found { return nil, errNodeSelectorsNotFound } // unmarshal - ret := []map[string]string{} + ret := []PodSetNodeSelector{} if err := json.Unmarshal([]byte(str), &ret); err != nil { return nil, err } return ret, nil } -// nodeSelectorsSetToObject - sets an annotation containing the provided node selectors into +// setNodeSelectorsInAnnotation - sets an annotation containing the provided node selectors into // a job object, even if very unlikely it could return an error related to json.marshaling -func nodeSelectorsSetToObject(obj client.Object, nodeSelectors []map[string]string) error { +func setNodeSelectorsInAnnotation(obj client.Object, nodeSelectors []PodSetNodeSelector) error { nodeSelectorsBytes, err := json.Marshal(nodeSelectors) if err != nil { return err diff --git a/pkg/controller/jobframework/validation.go b/pkg/controller/jobframework/validation.go index 
c8c67d7652..4a745777c9 100644 --- a/pkg/controller/jobframework/validation.go +++ b/pkg/controller/jobframework/validation.go @@ -83,7 +83,7 @@ func ValidateUpdateForOriginalNodeSelectors(oldJob, newJob GenericJob) field.Err allErrs = append(allErrs, field.Forbidden(originalNodeSelectorsWorkloadKeyPath, "this annotation is immutable while the job is not changing its suspended state")) } } else if av, found := newJob.Object().GetAnnotations()[OriginalNodeSelectorsAnnotation]; found { - out := []map[string]string{} + out := []PodSetNodeSelector{} if err := json.Unmarshal([]byte(av), &out); err != nil { allErrs = append(allErrs, field.Invalid(originalNodeSelectorsWorkloadKeyPath, av, err.Error())) } diff --git a/pkg/controller/jobs/job/job_controller.go b/pkg/controller/jobs/job/job_controller.go index 5e9234beb8..06790c65a7 100644 --- a/pkg/controller/jobs/job/job_controller.go +++ b/pkg/controller/jobs/job/job_controller.go @@ -150,29 +150,29 @@ func (j *Job) PodSets() []kueue.PodSet { } } -func (j *Job) RunWithNodeAffinity(nodeSelectors []map[string]string) { +func (j *Job) RunWithNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelector) { j.Spec.Suspend = pointer.Bool(false) if len(nodeSelectors) == 0 { return } if j.Spec.Template.Spec.NodeSelector == nil { - j.Spec.Template.Spec.NodeSelector = nodeSelectors[0] + j.Spec.Template.Spec.NodeSelector = nodeSelectors[0].NodeSelector } else { - for k, v := range nodeSelectors[0] { + for k, v := range nodeSelectors[0].NodeSelector { j.Spec.Template.Spec.NodeSelector[k] = v } } } -func (j *Job) RestoreNodeAffinity(nodeSelectors []map[string]string) { - if len(nodeSelectors) == 0 || equality.Semantic.DeepEqual(j.Spec.Template.Spec.NodeSelector, nodeSelectors[0]) { +func (j *Job) RestoreNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelector) { + if len(nodeSelectors) == 0 || equality.Semantic.DeepEqual(j.Spec.Template.Spec.NodeSelector, nodeSelectors[0].NodeSelector) { return } j.Spec.Template.Spec.NodeSelector = 
map[string]string{} - for k, v := range nodeSelectors[0] { + for k, v := range nodeSelectors[0].NodeSelector { j.Spec.Template.Spec.NodeSelector[k] = v } } diff --git a/pkg/controller/jobs/job/job_webhook_test.go b/pkg/controller/jobs/job/job_webhook_test.go index b605e62bde..2150dd9026 100644 --- a/pkg/controller/jobs/job/job_webhook_test.go +++ b/pkg/controller/jobs/job/job_webhook_test.go @@ -95,6 +95,17 @@ func TestValidateCreate(t *testing.T) { } func TestValidateUpdate(t *testing.T) { + + validPodSelectors := ` +[ + { + "name": "podSetName", + "nodeSelector": { + "l1": "v1" + } + } +] + ` testcases := []struct { name string oldJob *batchv1.Job @@ -161,18 +172,18 @@ func TestValidateUpdate(t *testing.T) { { name: "original node selectors can be set while unsuspending", oldJob: testingutil.MakeJob("job", "default").Suspend(true).Obj(), - newJob: testingutil.MakeJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation("[{\"l1\":\"v1\"}]").Obj(), + newJob: testingutil.MakeJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation(validPodSelectors).Obj(), wantErr: nil, }, { name: "original node selectors can be set while suspending", - oldJob: testingutil.MakeJob("job", "default").Suspend(true).Obj(), - newJob: testingutil.MakeJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation("[{\"l1\":\"v1\"}]").Obj(), + oldJob: testingutil.MakeJob("job", "default").Suspend(false).Obj(), + newJob: testingutil.MakeJob("job", "default").Suspend(true).OriginalNodeSelectorsAnnotation(validPodSelectors).Obj(), wantErr: nil, }, { name: "immutable original node selectors while not suspended", - oldJob: testingutil.MakeJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation("[{\"l1\":\"v1\"}]").Obj(), + oldJob: testingutil.MakeJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation(validPodSelectors).Obj(), newJob: testingutil.MakeJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation("").Obj(), wantErr: 
field.ErrorList{ field.Forbidden(originalNodeSelectorsKeyPath, "this annotation is immutable while the job is not changing its suspended state"), @@ -180,7 +191,7 @@ func TestValidateUpdate(t *testing.T) { }, { name: "immutable original node selectors while suspended", - oldJob: testingutil.MakeJob("job", "default").Suspend(true).OriginalNodeSelectorsAnnotation("[{\"l1\":\"v1\"}]").Obj(), + oldJob: testingutil.MakeJob("job", "default").Suspend(true).OriginalNodeSelectorsAnnotation(validPodSelectors).Obj(), newJob: testingutil.MakeJob("job", "default").Suspend(true).OriginalNodeSelectorsAnnotation("").Obj(), wantErr: field.ErrorList{ field.Forbidden(originalNodeSelectorsKeyPath, "this annotation is immutable while the job is not changing its suspended state"), diff --git a/pkg/controller/jobs/mpijob/mpijob_controller.go b/pkg/controller/jobs/mpijob/mpijob_controller.go index 8d4b0ebc33..55b638231c 100644 --- a/pkg/controller/jobs/mpijob/mpijob_controller.go +++ b/pkg/controller/jobs/mpijob/mpijob_controller.go @@ -104,20 +104,22 @@ func (j *MPIJob) PodSets() []kueue.PodSet { return podSets } -func (j *MPIJob) RunWithNodeAffinity(nodeSelectors []map[string]string) { +func (j *MPIJob) RunWithNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelector) { j.Spec.RunPolicy.Suspend = pointer.Bool(false) if len(nodeSelectors) == 0 { return } + // The node selectors are provided in the same order as the generated list of + // podSets, use the same ordering logic to restore them. 
orderedReplicaTypes := orderedReplicaTypes(&j.Spec) for index := range nodeSelectors { replicaType := orderedReplicaTypes[index] nodeSelector := nodeSelectors[index] - if len(nodeSelector) != 0 { + if len(nodeSelector.NodeSelector) != 0 { if j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector == nil { - j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector = nodeSelector + j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector = nodeSelector.NodeSelector } else { - for k, v := range nodeSelector { + for k, v := range nodeSelector.NodeSelector { j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector[k] = v } } @@ -125,13 +127,14 @@ func (j *MPIJob) RunWithNodeAffinity(nodeSelectors []map[string]string) { } } -func (j *MPIJob) RestoreNodeAffinity(nodeSelectors []map[string]string) { +func (j *MPIJob) RestoreNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelector) { orderedReplicaTypes := orderedReplicaTypes(&j.Spec) for index, nodeSelector := range nodeSelectors { replicaType := orderedReplicaTypes[index] - if !equality.Semantic.DeepEqual(j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector, nodeSelector) { + // Compare against the selector map itself: a map never equals the wrapping PodSetNodeSelector struct. + if !equality.Semantic.DeepEqual(j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector, nodeSelector.NodeSelector) { j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector = map[string]string{} - for k, v := range nodeSelector { + for k, v := range nodeSelector.NodeSelector { j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector[k] = v } } diff --git a/pkg/controller/jobs/mpijob/mpijob_webhook_test.go b/pkg/controller/jobs/mpijob/mpijob_webhook_test.go index 5e6291fdd4..98522f3179 100644 --- a/pkg/controller/jobs/mpijob/mpijob_webhook_test.go +++ b/pkg/controller/jobs/mpijob/mpijob_webhook_test.go @@ -33,6 +33,16 @@ var ( ) func TestUpdate(t *testing.T) { + validPodSelectors := ` +[ + { + "name": "podSetName", + "nodeSelector": { + "l1": "v1" + } + } +] +` testcases := map[string]struct { oldJob *kubeflow.MPIJob newJob *kubeflow.MPIJob @@ -40,23 +50,23 @@ func TestUpdate(t *testing.T) { }{ "original node selectors
can be set while unsuspending": { oldJob: testingutil.MakeMPIJob("job", "default").Suspend(true).Obj(), - newJob: testingutil.MakeMPIJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation("[{\"l1\":\"v1\"}]").Obj(), + newJob: testingutil.MakeMPIJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation(validPodSelectors).Obj(), wantErr: nil, }, "original node selectors can be set while suspending": { - oldJob: testingutil.MakeMPIJob("job", "default").Suspend(true).Obj(), - newJob: testingutil.MakeMPIJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation("[{\"l1\":\"v1\"}]").Obj(), + oldJob: testingutil.MakeMPIJob("job", "default").Suspend(false).Obj(), + newJob: testingutil.MakeMPIJob("job", "default").Suspend(true).OriginalNodeSelectorsAnnotation(validPodSelectors).Obj(), wantErr: nil, }, "immutable original node selectors while not suspended": { - oldJob: testingutil.MakeMPIJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation("[{\"l1\":\"v1\"}]").Obj(), + oldJob: testingutil.MakeMPIJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation(validPodSelectors).Obj(), newJob: testingutil.MakeMPIJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation("").Obj(), wantErr: field.ErrorList{ field.Forbidden(originalNodeSelectorsKeyPath, "this annotation is immutable while the job is not changing its suspended state"), }.ToAggregate(), }, "immutable original node selectors while suspended": { - oldJob: testingutil.MakeMPIJob("job", "default").Suspend(true).OriginalNodeSelectorsAnnotation("[{\"l1\":\"v1\"}]").Obj(), + oldJob: testingutil.MakeMPIJob("job", "default").Suspend(true).OriginalNodeSelectorsAnnotation(validPodSelectors).Obj(), newJob: testingutil.MakeMPIJob("job", "default").Suspend(true).OriginalNodeSelectorsAnnotation("").Obj(), wantErr: field.ErrorList{ field.Forbidden(originalNodeSelectorsKeyPath, "this annotation is immutable while the job is not changing its suspended state"),