Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

attempt to commonize Kubeflow jobs Multikueue support methods #2795

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pkg/controller/jobframework/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,11 @@ type MultiKueueAdapter interface {
GVK() schema.GroupVersionKind
}

// UpdateRemoteJob is implemented by MultiKueue adapters that share a generic
// job-sync routine. It supplies the two type-specific steps that routine
// cannot express generically.
//
// Both arguments of each method are expected to be pointers to the same
// concrete job type; implementations typically type-assert them, so passing a
// mismatched type will panic.
type UpdateRemoteJob interface {
	// UpdateRemoteJobStatus copies the status observed on remoteJob into
	// localJob. Note that, despite the name, it is localJob that is mutated.
	UpdateRemoteJobStatus(localJob, remoteJob any)
	// UpdateRemoteJobSpec populates remoteJob (the object about to be created
	// on the worker cluster) from localJob. remoteJob is mutated in place.
	UpdateRemoteJobSpec(localJob, remoteJob any)
}

// MultiKueueWatcher is an optional interface that can be implemented by a MultiKueueAdapter
// to receive job related watch events from the worker cluster.
// If not implemented, MultiKueue will only receive events related to the job's workload.
Expand Down
89 changes: 89 additions & 0 deletions pkg/controller/jobs/kubeflow/common/common.go
mszadkow marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package common

import (
"context"
"fmt"

metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/types"
"k8s.io/klog/v2"

"sigs.k8s.io/controller-runtime/pkg/client"
kueuealpha "sigs.k8s.io/kueue/apis/kueue/v1alpha1"
"sigs.k8s.io/kueue/pkg/controller/constants"
"sigs.k8s.io/kueue/pkg/controller/jobframework"
clientutil "sigs.k8s.io/kueue/pkg/util/client"
)

// objAsPtr constrains a type parameter to be the pointer type *T while also
// implementing metav1.Object and client.Object. This lets the generic helpers
// in this package allocate a job with new(T) and use the resulting pointer
// directly with a controller-runtime client and the metadata accessors.
type objAsPtr[T any] interface {
metav1.Object
client.Object
*T
}

func SyncJob[PtrT objAsPtr[T], T any](
mszadkow marked this conversation as resolved.
Show resolved Hide resolved
ctx context.Context,
localClient client.Client,
remoteClient client.Client,
key types.NamespacedName,
workloadName, origin string,
b jobframework.UpdateRemoteJob) error {

localJob := PtrT(new(T))
err := localClient.Get(ctx, key, localJob)
if err != nil {
return err
}

remoteJob := PtrT(new(T))
err = remoteClient.Get(ctx, key, remoteJob)
if client.IgnoreNotFound(err) != nil {
return err
}

if err == nil {
return clientutil.PatchStatus(ctx, localClient, localJob, func() (bool, error) {
// if the remote exists, just copy the status
b.UpdateRemoteJobStatus(localJob, remoteJob)
return true, nil
})
}

remoteJob = PtrT(new(T))
b.UpdateRemoteJobSpec(localJob, remoteJob)

// add the prebuilt workload
labels := remoteJob.GetLabels()
if remoteJob.GetLabels() == nil {
labels = make(map[string]string, 2)
}
labels[constants.PrebuiltWorkloadLabel] = workloadName
labels[kueuealpha.MultiKueueOriginLabel] = origin
remoteJob.SetLabels(labels)

return remoteClient.Create(ctx, remoteJob)
}

// DeleteRemoteObject removes the job identified by key through remoteClient.
//
// The object is fetched first and then deleted. A not-found result from
// either call is filtered out with client.IgnoreNotFound; any other error is
// returned to the caller.
//
// PtrT must be the pointer type of the concrete job type T.
func DeleteRemoteObject[PtrT objAsPtr[T], T any](ctx context.Context, remoteClient client.Client, key types.NamespacedName) error {
	remote := PtrT(new(T))
	if getErr := remoteClient.Get(ctx, key, remote); getErr != nil {
		return client.IgnoreNotFound(getErr)
	}

	deleteErr := remoteClient.Delete(ctx, remote)
	return client.IgnoreNotFound(deleteErr)
}

// WorkloadKeyFor returns the namespaced name of the prebuilt workload that
// the job o refers to.
//
// o must be of type PtrT; otherwise an error naming JobName is returned. The
// workload name is read from the job's prebuilt-workload label and the
// namespace is taken from the job itself. An error is returned when the label
// is absent. JobName is only used to build the error messages.
func WorkloadKeyFor[PtrT objAsPtr[T], T any](o runtime.Object, JobName string) (types.NamespacedName, error) {
	var key types.NamespacedName

	typed, ok := o.(PtrT)
	if !ok {
		return key, fmt.Errorf("not a %s", JobName)
	}

	workloadName, labelled := typed.GetLabels()[constants.PrebuiltWorkloadLabel]
	if !labelled {
		return key, fmt.Errorf("no prebuilt workload found for %s: %s", JobName, klog.KObj(typed))
	}

	key.Name = workloadName
	key.Namespace = typed.GetNamespace()
	return key, nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,70 +18,40 @@ package paddlejob

import (
"context"
"errors"
"fmt"

kftraining "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/types"
"k8s.io/klog/v2"
"sigs.k8s.io/controller-runtime/pkg/client"

kueuealpha "sigs.k8s.io/kueue/apis/kueue/v1alpha1"
"sigs.k8s.io/kueue/pkg/controller/constants"
"sigs.k8s.io/kueue/pkg/controller/jobframework"
kfcommon "sigs.k8s.io/kueue/pkg/controller/jobs/kubeflow/common"
"sigs.k8s.io/kueue/pkg/util/api"
clientutil "sigs.k8s.io/kueue/pkg/util/client"
)

type multikueueAdapter struct{}

var _ jobframework.MultiKueueAdapter = (*multikueueAdapter)(nil)
var _ jobframework.UpdateRemoteJob = (*multikueueAdapter)(nil)

func (b *multikueueAdapter) SyncJob(ctx context.Context, localClient client.Client, remoteClient client.Client, key types.NamespacedName, workloadName, origin string) error {
localJob := kftraining.PaddleJob{}
err := localClient.Get(ctx, key, &localJob)
if err != nil {
return err
}

remoteJob := &kftraining.PaddleJob{}
err = remoteClient.Get(ctx, key, remoteJob)
if client.IgnoreNotFound(err) != nil {
return err
}

// if the remote exists, just copy the status
if err == nil {
return clientutil.PatchStatus(ctx, localClient, &localJob, func() (bool, error) {
localJob.Status = remoteJob.Status
return true, nil
})
}

remoteJob = &kftraining.PaddleJob{
ObjectMeta: api.CloneObjectMetaForCreation(&localJob.ObjectMeta),
Spec: *localJob.Spec.DeepCopy(),
}
func (b *multikueueAdapter) UpdateRemoteJobStatus(localJob, remoteJob interface{}) {
localJob.(*kftraining.PaddleJob).Status = remoteJob.(*kftraining.PaddleJob).Status
}

// add the prebuilt workload
if remoteJob.Labels == nil {
remoteJob.Labels = make(map[string]string, 2)
func (b *multikueueAdapter) UpdateRemoteJobSpec(localJob, remoteJob interface{}) {
*remoteJob.(*kftraining.PaddleJob) = kftraining.PaddleJob{
ObjectMeta: api.CloneObjectMetaForCreation(&localJob.(*kftraining.PaddleJob).ObjectMeta),
Spec: *localJob.(*kftraining.PaddleJob).Spec.DeepCopy(),
}
remoteJob.Labels[constants.PrebuiltWorkloadLabel] = workloadName
remoteJob.Labels[kueuealpha.MultiKueueOriginLabel] = origin
}

return remoteClient.Create(ctx, remoteJob)
func (b *multikueueAdapter) SyncJob(ctx context.Context, localClient client.Client, remoteClient client.Client, key types.NamespacedName, workloadName, origin string) error {
return kfcommon.SyncJob[*kftraining.PaddleJob](ctx, localClient, remoteClient, key, workloadName, origin, b)
}

func (b *multikueueAdapter) DeleteRemoteObject(ctx context.Context, remoteClient client.Client, key types.NamespacedName) error {
job := kftraining.PaddleJob{}
err := remoteClient.Get(ctx, key, &job)
if err != nil {
return client.IgnoreNotFound(err)
}
return client.IgnoreNotFound(remoteClient.Delete(ctx, &job))
return kfcommon.DeleteRemoteObject[*kftraining.PaddleJob](ctx, remoteClient, key)
}

func (b *multikueueAdapter) KeepAdmissionCheckPending() bool {
Expand All @@ -103,15 +73,5 @@ func (*multikueueAdapter) GetEmptyList() client.ObjectList {
}

func (*multikueueAdapter) WorkloadKeyFor(o runtime.Object) (types.NamespacedName, error) {
paddleJob, isPaddleJob := o.(*kftraining.PaddleJob)
if !isPaddleJob {
return types.NamespacedName{}, errors.New("not a PaddleJob")
}

prebuiltWl, hasPrebuiltWorkload := paddleJob.Labels[constants.PrebuiltWorkloadLabel]
if !hasPrebuiltWorkload {
return types.NamespacedName{}, fmt.Errorf("no prebuilt workload found for PaddleJob: %s", klog.KObj(paddleJob))
}

return types.NamespacedName{Name: prebuiltWl, Namespace: paddleJob.Namespace}, nil
return kfcommon.WorkloadKeyFor[*kftraining.PaddleJob](o, kftraining.PaddleJobKind)
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,70 +18,40 @@ package pytorchjob

import (
"context"
"errors"
"fmt"

kftraining "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/types"
"k8s.io/klog/v2"
"sigs.k8s.io/controller-runtime/pkg/client"

kueuealpha "sigs.k8s.io/kueue/apis/kueue/v1alpha1"
"sigs.k8s.io/kueue/pkg/controller/constants"
"sigs.k8s.io/kueue/pkg/controller/jobframework"
kfcommon "sigs.k8s.io/kueue/pkg/controller/jobs/kubeflow/common"
"sigs.k8s.io/kueue/pkg/util/api"
clientutil "sigs.k8s.io/kueue/pkg/util/client"
)

type multikueueAdapter struct{}

var _ jobframework.MultiKueueAdapter = (*multikueueAdapter)(nil)
var _ jobframework.UpdateRemoteJob = (*multikueueAdapter)(nil)

func (b *multikueueAdapter) SyncJob(ctx context.Context, localClient client.Client, remoteClient client.Client, key types.NamespacedName, workloadName, origin string) error {
localJob := kftraining.PyTorchJob{}
err := localClient.Get(ctx, key, &localJob)
if err != nil {
return err
}

remoteJob := &kftraining.PyTorchJob{}
err = remoteClient.Get(ctx, key, remoteJob)
if client.IgnoreNotFound(err) != nil {
return err
}

// if the remote exists, just copy the status
if err == nil {
return clientutil.PatchStatus(ctx, localClient, &localJob, func() (bool, error) {
localJob.Status = remoteJob.Status
return true, nil
})
}

remoteJob = &kftraining.PyTorchJob{
ObjectMeta: api.CloneObjectMetaForCreation(&localJob.ObjectMeta),
Spec: *localJob.Spec.DeepCopy(),
}
func (b *multikueueAdapter) UpdateRemoteJobStatus(localJob, remoteJob interface{}) {
localJob.(*kftraining.PyTorchJob).Status = remoteJob.(*kftraining.PyTorchJob).Status
}

// add the prebuilt workload
if remoteJob.Labels == nil {
remoteJob.Labels = make(map[string]string, 2)
func (b *multikueueAdapter) UpdateRemoteJobSpec(localJob, remoteJob interface{}) {
*remoteJob.(*kftraining.PyTorchJob) = kftraining.PyTorchJob{
ObjectMeta: api.CloneObjectMetaForCreation(&localJob.(*kftraining.PyTorchJob).ObjectMeta),
Spec: *localJob.(*kftraining.PyTorchJob).Spec.DeepCopy(),
}
remoteJob.Labels[constants.PrebuiltWorkloadLabel] = workloadName
remoteJob.Labels[kueuealpha.MultiKueueOriginLabel] = origin
}

return remoteClient.Create(ctx, remoteJob)
func (b *multikueueAdapter) SyncJob(ctx context.Context, localClient client.Client, remoteClient client.Client, key types.NamespacedName, workloadName, origin string) error {
return kfcommon.SyncJob[*kftraining.PyTorchJob](ctx, localClient, remoteClient, key, workloadName, origin, b)
}

func (b *multikueueAdapter) DeleteRemoteObject(ctx context.Context, remoteClient client.Client, key types.NamespacedName) error {
job := kftraining.PyTorchJob{}
err := remoteClient.Get(ctx, key, &job)
if err != nil {
return client.IgnoreNotFound(err)
}
return client.IgnoreNotFound(remoteClient.Delete(ctx, &job))
return kfcommon.DeleteRemoteObject[*kftraining.PyTorchJob](ctx, remoteClient, key)
}

func (b *multikueueAdapter) KeepAdmissionCheckPending() bool {
Expand All @@ -103,15 +73,5 @@ func (*multikueueAdapter) GetEmptyList() client.ObjectList {
}

func (*multikueueAdapter) WorkloadKeyFor(o runtime.Object) (types.NamespacedName, error) {
pyTorchJob, isPyTorchJob := o.(*kftraining.PyTorchJob)
if !isPyTorchJob {
return types.NamespacedName{}, errors.New("not a PyTorchJob")
}

prebuiltWl, hasPrebuiltWorkload := pyTorchJob.Labels[constants.PrebuiltWorkloadLabel]
if !hasPrebuiltWorkload {
return types.NamespacedName{}, fmt.Errorf("no prebuilt workload found for PyTorchJob: %s", klog.KObj(pyTorchJob))
}

return types.NamespacedName{Name: prebuiltWl, Namespace: pyTorchJob.Namespace}, nil
return kfcommon.WorkloadKeyFor[*kftraining.PyTorchJob](o, kftraining.PyTorchJobKind)
}
Loading