diff --git a/pkg/controller/admissionchecks/multikueue/indexer_test.go b/pkg/controller/admissionchecks/multikueue/indexer_test.go index e65370d444..118243285f 100644 --- a/pkg/controller/admissionchecks/multikueue/indexer_test.go +++ b/pkg/controller/admissionchecks/multikueue/indexer_test.go @@ -22,6 +22,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + kftraining "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" @@ -47,6 +48,7 @@ func getClientBuilder() (*fake.ClientBuilder, context.Context) { utilruntime.Must(kueue.AddToScheme(scheme)) utilruntime.Must(kueuealpha.AddToScheme(scheme)) utilruntime.Must(jobset.AddToScheme(scheme)) + utilruntime.Must(kftraining.AddToScheme(scheme)) ctx := context.Background() builder := fake.NewClientBuilder().WithScheme(scheme).WithObjects(&corev1.Namespace{ diff --git a/pkg/controller/jobframework/validation.go b/pkg/controller/jobframework/validation.go index c41873ff41..fad264f217 100644 --- a/pkg/controller/jobframework/validation.go +++ b/pkg/controller/jobframework/validation.go @@ -20,6 +20,7 @@ import ( "fmt" "strings" + kftraining "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" batchv1 "k8s.io/api/batch/v1" apivalidation "k8s.io/apimachinery/pkg/api/validation" "k8s.io/apimachinery/pkg/util/sets" @@ -35,8 +36,10 @@ var ( labelsPath = field.NewPath("metadata", "labels") queueNameLabelPath = labelsPath.Key(constants.QueueLabel) workloadPriorityClassNamePath = labelsPath.Key(constants.WorkloadPriorityClassLabel) - supportedPrebuiltWlJobGVKs = sets.New(batchv1.SchemeGroupVersion.WithKind("Job").String(), - jobset.SchemeGroupVersion.WithKind("JobSet").String()) + supportedPrebuiltWlJobGVKs = sets.New( + batchv1.SchemeGroupVersion.WithKind("Job").String(), + jobset.SchemeGroupVersion.WithKind("JobSet").String(), + kftraining.SchemeGroupVersion.WithKind("TFJob").String()) ) // ValidateJobOnCreate encapsulates all GenericJob validations that must be performed on a Create operation diff --git a/pkg/controller/jobs/kubeflow/jobs/tfjob/tfjob_controller.go b/pkg/controller/jobs/kubeflow/jobs/tfjob/tfjob_controller.go index 2a8cf68885..f605eb82d3 100644 --- a/pkg/controller/jobs/kubeflow/jobs/tfjob/tfjob_controller.go +++ b/pkg/controller/jobs/kubeflow/jobs/tfjob/tfjob_controller.go @@ -45,6 +45,7 @@ func init() { JobType: &kftraining.TFJob{}, AddToScheme: kftraining.AddToScheme, IsManagingObjectsOwner: isTFJob, + MultiKueueAdapter: &multikueueAdapter{}, })) } diff --git a/pkg/controller/jobs/kubeflow/jobs/tfjob/tfjob_multikueue_adapter.go b/pkg/controller/jobs/kubeflow/jobs/tfjob/tfjob_multikueue_adapter.go new file mode 100644 index 0000000000..f4092b65f3 --- /dev/null +++ b/pkg/controller/jobs/kubeflow/jobs/tfjob/tfjob_multikueue_adapter.go @@ -0,0 +1,114 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tfjob + +import ( + "context" + "errors" + "fmt" + + kftraining "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/types" + "k8s.io/klog/v2" + "sigs.k8s.io/controller-runtime/pkg/client" + + kueuealpha "sigs.k8s.io/kueue/apis/kueue/v1alpha1" + "sigs.k8s.io/kueue/pkg/controller/constants" + "sigs.k8s.io/kueue/pkg/controller/jobframework" + "sigs.k8s.io/kueue/pkg/util/api" +) + +type multikueueAdapter struct{} + +var _ jobframework.MultiKueueAdapter = (*multikueueAdapter)(nil) + +func (b *multikueueAdapter) SyncJob(ctx context.Context, localClient client.Client, remoteClient client.Client, key types.NamespacedName, workloadName, origin string) error { + localJob := kftraining.TFJob{} + err := localClient.Get(ctx, key, &localJob) + if err != nil { + return err + } + + remoteJob := &kftraining.TFJob{} + err = remoteClient.Get(ctx, key, remoteJob) + if client.IgnoreNotFound(err) != nil { + return err + } + + // if the remote exists, just copy the status + if err == nil { + localJob.Status = remoteJob.Status + return localClient.Status().Update(ctx, &localJob) + } + + remoteJob = &kftraining.TFJob{ + ObjectMeta: api.CloneObjectMetaForCreation(&localJob.ObjectMeta), + Spec: *localJob.Spec.DeepCopy(), + } + + // add the prebuilt workload + if remoteJob.Labels == nil { + remoteJob.Labels = map[string]string{} + } + remoteJob.Labels[constants.PrebuiltWorkloadLabel] = workloadName + remoteJob.Labels[kueuealpha.MultiKueueOriginLabel] = origin + + return remoteClient.Create(ctx, remoteJob) +} + +func (b *multikueueAdapter) DeleteRemoteObject(ctx context.Context, remoteClient client.Client, key types.NamespacedName) error { + job := kftraining.TFJob{} + err := remoteClient.Get(ctx, key, &job) + if err != nil { + return client.IgnoreNotFound(err) + } + return client.IgnoreNotFound(remoteClient.Delete(ctx, &job)) +} + +func (b *multikueueAdapter) KeepAdmissionCheckPending() bool { + return false +} + +func (b *multikueueAdapter) IsJobManagedByKueue(ctx context.Context, c client.Client, key types.NamespacedName) (bool, string, error) { + return true, "", nil +} + +func (b *multikueueAdapter) GVK() schema.GroupVersionKind { + return gvk +} + +var _ jobframework.MultiKueueWatcher = (*multikueueAdapter)(nil) + +func (*multikueueAdapter) GetEmptyList() client.ObjectList { + return &kftraining.TFJobList{} +} + +func (*multikueueAdapter) WorkloadKeyFor(o runtime.Object) (types.NamespacedName, error) { + tfJob, isTfJob := o.(*kftraining.TFJob) + if !isTfJob { + return types.NamespacedName{}, errors.New("not a TF Job") + } + + prebuiltWl, hasPrebuiltWorkload := tfJob.Labels[constants.PrebuiltWorkloadLabel] + if !hasPrebuiltWorkload { + return types.NamespacedName{}, fmt.Errorf("no prebuilt workload found for TF Job: %s", klog.KObj(tfJob)) + } + + return types.NamespacedName{Name: prebuiltWl, Namespace: tfJob.Namespace}, nil +} diff --git a/pkg/controller/jobs/kubeflow/jobs/tfjob/tfjob_multikueue_adapter_test.go b/pkg/controller/jobs/kubeflow/jobs/tfjob/tfjob_multikueue_adapter_test.go new file mode 100644 index 0000000000..4026237cf1 --- /dev/null +++ b/pkg/controller/jobs/kubeflow/jobs/tfjob/tfjob_multikueue_adapter_test.go @@ -0,0 +1,160 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tfjob + +import ( + "context" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + kftraining "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/interceptor" + + kueuealpha "sigs.k8s.io/kueue/apis/kueue/v1alpha1" + "sigs.k8s.io/kueue/pkg/controller/constants" + "sigs.k8s.io/kueue/pkg/util/slices" + utiltesting "sigs.k8s.io/kueue/pkg/util/testing" + + kfutiltesting "sigs.k8s.io/kueue/pkg/util/testingjobs/tfjob" +) + +const ( + TestNamespace = "ns" +) + +func TestMultikueueAdapter(t *testing.T) { + objCheckOpts := []cmp.Option{ + cmpopts.IgnoreFields(metav1.ObjectMeta{}, "ResourceVersion"), + cmpopts.EquateEmpty(), + } + + tfJobBuilder := kfutiltesting.MakeTFJob("tfjob1", TestNamespace).Queue("queue").Suspend(false) + + cases := map[string]struct { + managersTFJobs []kftraining.TFJob + workerTFJobs []kftraining.TFJob + + operation func(ctx context.Context, adapter *multikueueAdapter, managerClient, workerClient client.Client) error + + wantError error + wantManagersTFJobs []kftraining.TFJob + wantWorkerTFJobs []kftraining.TFJob + }{ + "sync creates missing remote tfjob": { + managersTFJobs: []kftraining.TFJob{ + *tfJobBuilder.Clone().Obj(), + }, + operation: func(ctx context.Context, adapter *multikueueAdapter, managerClient, workerClient client.Client) error { + return adapter.SyncJob(ctx, managerClient, workerClient, types.NamespacedName{Name: "tfjob1", Namespace: TestNamespace}, "wl1", "origin1") + }, + + wantManagersTFJobs: []kftraining.TFJob{ + *tfJobBuilder.Clone().Obj(), + }, + wantWorkerTFJobs: []kftraining.TFJob{ + *tfJobBuilder.Clone(). + Label(constants.PrebuiltWorkloadLabel, "wl1"). + Label(kueuealpha.MultiKueueOriginLabel, "origin1"). + Obj(), + }, + }, + "sync status from remote tfjob": { + managersTFJobs: []kftraining.TFJob{ + *tfJobBuilder.Clone().Obj(), + }, + workerTFJobs: []kftraining.TFJob{ + *tfJobBuilder.Clone(). + Label(constants.PrebuiltWorkloadLabel, "wl1"). + Label(kueuealpha.MultiKueueOriginLabel, "origin1"). + StatusConditions(kftraining.JobCondition{Type: kftraining.JobSucceeded, Status: corev1.ConditionTrue}). + Obj(), + }, + operation: func(ctx context.Context, adapter *multikueueAdapter, managerClient, workerClient client.Client) error { + return adapter.SyncJob(ctx, managerClient, workerClient, types.NamespacedName{Name: "tfjob1", Namespace: TestNamespace}, "wl1", "origin1") + }, + + wantManagersTFJobs: []kftraining.TFJob{ + *tfJobBuilder.Clone(). + StatusConditions(kftraining.JobCondition{Type: kftraining.JobSucceeded, Status: corev1.ConditionTrue}). + Obj(), + }, + wantWorkerTFJobs: []kftraining.TFJob{ + *tfJobBuilder.Clone(). + Label(constants.PrebuiltWorkloadLabel, "wl1"). + Label(kueuealpha.MultiKueueOriginLabel, "origin1"). + StatusConditions(kftraining.JobCondition{Type: kftraining.JobSucceeded, Status: corev1.ConditionTrue}). + Obj(), + }, + }, + "remote tfjob is deleted": { + workerTFJobs: []kftraining.TFJob{ + *tfJobBuilder.Clone(). + Label(constants.PrebuiltWorkloadLabel, "wl1"). + Label(kueuealpha.MultiKueueOriginLabel, "origin1"). + Obj(), + }, + operation: func(ctx context.Context, adapter *multikueueAdapter, managerClient, workerClient client.Client) error { + return adapter.DeleteRemoteObject(ctx, workerClient, types.NamespacedName{Name: "tfjob1", Namespace: TestNamespace}) + }, + }, + } + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + managerBuilder := utiltesting.NewClientBuilder(kftraining.AddToScheme).WithInterceptorFuncs(interceptor.Funcs{SubResourcePatch: utiltesting.TreatSSAAsStrategicMerge}) + managerBuilder = managerBuilder.WithLists(&kftraining.TFJobList{Items: tc.managersTFJobs}) + managerBuilder = managerBuilder.WithStatusSubresource(slices.Map(tc.managersTFJobs, func(w *kftraining.TFJob) client.Object { return w })...) + managerClient := managerBuilder.Build() + + workerBuilder := utiltesting.NewClientBuilder(kftraining.AddToScheme).WithInterceptorFuncs(interceptor.Funcs{SubResourcePatch: utiltesting.TreatSSAAsStrategicMerge}) + workerBuilder = workerBuilder.WithLists(&kftraining.TFJobList{Items: tc.workerTFJobs}) + workerClient := workerBuilder.Build() + + ctx, _ := utiltesting.ContextWithLog(t) + + adapter := &multikueueAdapter{} + + gotErr := tc.operation(ctx, adapter, managerClient, workerClient) + + if diff := cmp.Diff(tc.wantError, gotErr, cmpopts.EquateErrors()); diff != "" { + t.Errorf("unexpected error (-want/+got):\n%s", diff) + } + + gotManagersTFJob := &kftraining.TFJobList{} + if err := managerClient.List(ctx, gotManagersTFJob); err != nil { + t.Errorf("unexpected list manager's tfjobs error %s", err) + } else { + if diff := cmp.Diff(tc.wantManagersTFJobs, gotManagersTFJob.Items, objCheckOpts...); diff != "" { + t.Errorf("unexpected manager's tfjobs (-want/+got):\n%s", diff) + } + } + + gotWorkerTFJobs := &kftraining.TFJobList{} + if err := workerClient.List(ctx, gotWorkerTFJobs); err != nil { + t.Errorf("unexpected list worker's tfjobs error %s", err) + } else { + if diff := cmp.Diff(tc.wantWorkerTFJobs, gotWorkerTFJobs.Items, objCheckOpts...); diff != "" { + t.Errorf("unexpected worker's tfjobs (-want/+got):\n%s", diff) + } + } + }) + } +}