Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[jobs]Add max execution time #3191

Merged
merged 5 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pkg/controller/constants/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,7 @@ const (

// ProvReqAnnotationPrefix is the prefix for annotations that should be pass to ProvisioningRequest as Parameters.
ProvReqAnnotationPrefix = "provreq.kueue.x-k8s.io/"

// MaxExecTimeSecondsLabel is the label key in the job that holds the maximum execution time.
MaxExecTimeSecondsLabel = `kueue.x-k8s.io/max-exec-time-seconds`
)
16 changes: 16 additions & 0 deletions pkg/controller/jobframework/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ package jobframework

import (
"context"
"strconv"

metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/types"
"k8s.io/client-go/tools/record"
"k8s.io/utils/ptr"
"sigs.k8s.io/controller-runtime/pkg/client"

kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1"
Expand Down Expand Up @@ -144,6 +146,20 @@ func QueueNameForObject(object client.Object) string {
return object.GetAnnotations()[constants.QueueAnnotation]
}

func MaximumExecutionTimeSeconds(job GenericJob) *int32 {
strVal, found := job.Object().GetLabels()[constants.MaxExecTimeSecondsLabel]
if !found {
return nil
}

v, err := strconv.ParseInt(strVal, 10, 32)
if err != nil || v <= 0 {
return nil
}

return ptr.To(int32(v))
}

func workloadPriorityClassName(job GenericJob) string {
object := job.Object()
if workloadPriorityClassLabel := object.GetLabels()[constants.WorkloadPriorityClassLabel]; workloadPriorityClassLabel != "" {
Expand Down
11 changes: 9 additions & 2 deletions pkg/controller/jobframework/reconciler.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"k8s.io/apimachinery/pkg/util/validation"
"k8s.io/client-go/tools/record"
"k8s.io/klog/v2"
"k8s.io/utils/ptr"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/builder"
"sigs.k8s.io/controller-runtime/pkg/client"
Expand Down Expand Up @@ -761,6 +762,11 @@ func equivalentToWorkload(ctx context.Context, c client.Client, job GenericJob,
return false
}

defaultDuration := int32(-1)
if ptr.Deref(wl.Spec.MaximumExecutionTimeSeconds, defaultDuration) != ptr.Deref(MaximumExecutionTimeSeconds(job), defaultDuration) {
return false
}

jobPodSets := clearMinCountsIfFeatureDisabled(job.PodSets())

if runningPodSets := expectedRunningPodSets(ctx, c, wl); runningPodSets != nil {
Expand Down Expand Up @@ -900,8 +906,9 @@ func (r *JobReconciler) constructWorkload(ctx context.Context, job GenericJob, o
Annotations: admissioncheck.FilterProvReqAnnotations(job.Object().GetAnnotations()),
},
Spec: kueue.WorkloadSpec{
PodSets: podSets,
QueueName: QueueName(job),
PodSets: podSets,
QueueName: QueueName(job),
MaximumExecutionTimeSeconds: MaximumExecutionTimeSeconds(job),
},
}
if wl.Labels == nil {
Expand Down
28 changes: 27 additions & 1 deletion pkg/controller/jobframework/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package jobframework

import (
"fmt"
"strconv"
"strings"

kfmpi "github.com/kubeflow/mpi-operator/pkg/apis/kubeflow/v2beta1"
Expand All @@ -36,6 +37,7 @@ var (
annotationsPath = field.NewPath("metadata", "annotations")
labelsPath = field.NewPath("metadata", "labels")
queueNameLabelPath = labelsPath.Key(constants.QueueLabel)
maxExecTimeLabelPath = labelsPath.Key(constants.MaxExecTimeSecondsLabel)
workloadPriorityClassNamePath = labelsPath.Key(constants.WorkloadPriorityClassLabel)
supportedPrebuiltWlJobGVKs = sets.New(
batchv1.SchemeGroupVersion.WithKind("Job").String(),
Expand All @@ -49,13 +51,16 @@ var (

// ValidateJobOnCreate encapsulates all GenericJob validations that must be performed on a Create operation
func ValidateJobOnCreate(job GenericJob) field.ErrorList {
return validateCreateForQueueName(job)
allErrs := validateCreateForQueueName(job)
allErrs = append(allErrs, validateCreateForMaxExecTime(job)...)
return allErrs
}

// ValidateJobOnUpdate encapsulates all GenericJob validations that must be performed on a Update operation
func ValidateJobOnUpdate(oldJob, newJob GenericJob) field.ErrorList {
allErrs := validateUpdateForQueueName(oldJob, newJob)
allErrs = append(allErrs, validateUpdateForWorkloadPriorityClassName(oldJob, newJob)...)
allErrs = append(allErrs, validateUpdateForMaxExecTime(oldJob, newJob)...)
return allErrs
}

Expand Down Expand Up @@ -111,3 +116,24 @@ func validateUpdateForWorkloadPriorityClassName(oldJob, newJob GenericJob) field
allErrs := apivalidation.ValidateImmutableField(workloadPriorityClassName(newJob), workloadPriorityClassName(oldJob), workloadPriorityClassNamePath)
return allErrs
}

func validateCreateForMaxExecTime(job GenericJob) field.ErrorList {
if strVal, found := job.Object().GetLabels()[constants.MaxExecTimeSecondsLabel]; found {
v, err := strconv.Atoi(strVal)
if err != nil {
return field.ErrorList{field.Invalid(maxExecTimeLabelPath, strVal, err.Error())}
}

if v <= 0 {
return field.ErrorList{field.Invalid(maxExecTimeLabelPath, v, "should be greater than 0")}
}
}
return nil
}

func validateUpdateForMaxExecTime(oldJob, newJob GenericJob) field.ErrorList {
if !newJob.IsSuspended() || !oldJob.IsSuspended() {
return apivalidation.ValidateImmutableField(newJob.Object().GetLabels()[constants.MaxExecTimeSecondsLabel], oldJob.Object().GetLabels()[constants.MaxExecTimeSecondsLabel], maxExecTimeLabelPath)
}
return nil
}
62 changes: 62 additions & 0 deletions pkg/controller/jobs/job/job_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2747,6 +2747,68 @@ func TestReconciler(t *testing.T) {
},
},
},
"the maximum execution time is passed to the created workload": {
job: *baseJobWrapper.Clone().
Label(controllerconsts.MaxExecTimeSecondsLabel, "10").
Obj(),
wantJob: *baseJobWrapper.Clone().
Label(controllerconsts.MaxExecTimeSecondsLabel, "10").
Obj(),
wantWorkloads: []kueue.Workload{
*utiltesting.MakeWorkload("job", "ns").
MaximumExecutionTimeSeconds(10).
Finalizers(kueue.ResourceInUseFinalizerName).
PodSets(*utiltesting.MakePodSet(kueue.DefaultPodSetName, 10).Request(corev1.ResourceCPU, "1").Obj()).
Queue("foo").
Priority(0).
Labels(map[string]string{controllerconsts.JobUIDLabel: string(baseJobWrapper.GetUID())}).
Obj(),
},
wantEvents: []utiltesting.EventRecord{
{
Key: types.NamespacedName{Name: "job", Namespace: "ns"},
EventType: "Normal",
Reason: "CreatedWorkload",
Message: "Created Workload: ns/" + GetWorkloadNameForJob(baseJobWrapper.Name, baseJobWrapper.GetUID()),
},
},
},
"the maximum execution time is updated in the workload": {
job: *baseJobWrapper.Clone().
Label(controllerconsts.MaxExecTimeSecondsLabel, "10").
Obj(),
wantJob: *baseJobWrapper.Clone().
Label(controllerconsts.MaxExecTimeSecondsLabel, "10").
Obj(),
workloads: []kueue.Workload{
*utiltesting.MakeWorkload("job", "ns").
MaximumExecutionTimeSeconds(5).
Finalizers(kueue.ResourceInUseFinalizerName).
PodSets(*utiltesting.MakePodSet(kueue.DefaultPodSetName, 10).Request(corev1.ResourceCPU, "1").Obj()).
Queue("foo").
Priority(0).
Labels(map[string]string{controllerconsts.JobUIDLabel: string(baseJobWrapper.GetUID())}).
Obj(),
},
wantWorkloads: []kueue.Workload{
*utiltesting.MakeWorkload("job", "ns").
MaximumExecutionTimeSeconds(10).
Finalizers(kueue.ResourceInUseFinalizerName).
PodSets(*utiltesting.MakePodSet(kueue.DefaultPodSetName, 10).Request(corev1.ResourceCPU, "1").Obj()).
Queue("foo").
Priority(0).
Labels(map[string]string{controllerconsts.JobUIDLabel: string(baseJobWrapper.GetUID())}).
Obj(),
},
wantEvents: []utiltesting.EventRecord{
{
Key: types.NamespacedName{Name: "job", Namespace: "ns"},
EventType: "Normal",
Reason: "UpdatedWorkload",
Message: "Updated not matching Workload for suspended job: ns/job",
},
},
},
}
for name, tc := range cases {
t.Run(name, func(t *testing.T) {
Expand Down
85 changes: 85 additions & 0 deletions pkg/controller/jobs/job/job_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ var (
labelsPath = field.NewPath("metadata", "labels")
queueNameLabelPath = labelsPath.Key(constants.QueueLabel)
prebuiltWlNameLabelPath = labelsPath.Key(constants.PrebuiltWorkloadLabel)
maxExecTimeLabelPath = labelsPath.Key(constants.MaxExecTimeSecondsLabel)
queueNameAnnotationsPath = annotationsPath.Key(constants.QueueAnnotation)
workloadPriorityClassNamePath = labelsPath.Key(constants.WorkloadPriorityClassLabel)
)
Expand Down Expand Up @@ -221,6 +222,55 @@ func TestValidateCreate(t *testing.T) {
wantErr: nil,
serverVersion: "1.27.0",
},
{
name: "invalid maximum execution time",
job: testingutil.MakeJob("job", "default").
Parallelism(4).
Completions(4).
Label(constants.MaxExecTimeSecondsLabel, "NaN").
Indexed(true).
Obj(),
wantErr: field.ErrorList{
field.Invalid(maxExecTimeLabelPath, "NaN", `strconv.Atoi: parsing "NaN": invalid syntax`),
},
serverVersion: "1.31.0",
},
{
name: "zero maximum execution time",
job: testingutil.MakeJob("job", "default").
Parallelism(4).
Completions(4).
Label(constants.MaxExecTimeSecondsLabel, "0").
Indexed(true).
Obj(),
wantErr: field.ErrorList{
field.Invalid(maxExecTimeLabelPath, 0, "should be greater than 0"),
},
serverVersion: "1.31.0",
},
{
name: "negative maximum execution time",
job: testingutil.MakeJob("job", "default").
Parallelism(4).
Completions(4).
Label(constants.MaxExecTimeSecondsLabel, "-10").
Indexed(true).
Obj(),
wantErr: field.ErrorList{
field.Invalid(maxExecTimeLabelPath, -10, "should be greater than 0"),
},
serverVersion: "1.31.0",
},
{
name: "valid maximum execution time",
job: testingutil.MakeJob("job", "default").
Parallelism(4).
Completions(4).
Label(constants.MaxExecTimeSecondsLabel, "10").
Indexed(true).
Obj(),
serverVersion: "1.31.0",
},
}

for _, tc := range testcases {
Expand Down Expand Up @@ -430,6 +480,41 @@ func TestValidateUpdate(t *testing.T) {
field.Invalid(minPodsCountAnnotationsPath, "NaN", "strconv.Atoi: parsing \"NaN\": invalid syntax"),
},
},
{
name: "immutable max exec time while unsuspended",
oldJob: testingutil.MakeJob("job", "default").
Suspend(false).
Label(constants.MaxExecTimeSecondsLabel, "10").
Obj(),
newJob: testingutil.MakeJob("job", "default").
Suspend(false).
Label(constants.MaxExecTimeSecondsLabel, "20").
Obj(),
wantErr: apivalidation.ValidateImmutableField("20", "10", maxExecTimeLabelPath),
},
{
name: "immutable max exec time while transitioning to unsuspended",
oldJob: testingutil.MakeJob("job", "default").
Suspend(true).
Label(constants.MaxExecTimeSecondsLabel, "10").
Obj(),
newJob: testingutil.MakeJob("job", "default").
Suspend(false).
Label(constants.MaxExecTimeSecondsLabel, "20").
Obj(),
wantErr: apivalidation.ValidateImmutableField("20", "10", maxExecTimeLabelPath),
},
{
name: "mutable max exec time while suspended",
oldJob: testingutil.MakeJob("job", "default").
Suspend(true).
Label(constants.MaxExecTimeSecondsLabel, "10").
Obj(),
newJob: testingutil.MakeJob("job", "default").
Suspend(true).
Label(constants.MaxExecTimeSecondsLabel, "20").
Obj(),
},
}

for _, tc := range testcases {
Expand Down
8 changes: 7 additions & 1 deletion pkg/controller/jobs/pod/pod_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -972,7 +972,8 @@ func (p *Pod) ConstructComposableWorkload(ctx context.Context, c client.Client,
Annotations: admissioncheck.FilterProvReqAnnotations(p.pod.GetAnnotations()),
},
Spec: kueue.WorkloadSpec{
QueueName: jobframework.QueueName(p),
QueueName: jobframework.QueueName(p),
MaximumExecutionTimeSeconds: jobframework.MaximumExecutionTimeSeconds(p),
},
}

Expand Down Expand Up @@ -1110,6 +1111,11 @@ func (p *Pod) FindMatchingWorkloads(ctx context.Context, c client.Client, r reco
return nil, nil, err
}

defaultDuration := int32(-1)
if ptr.Deref(workload.Spec.MaximumExecutionTimeSeconds, defaultDuration) != ptr.Deref(jobframework.MaximumExecutionTimeSeconds(p), defaultDuration) {
return nil, []*kueue.Workload{workload}, nil
}

// Cleanup excess pods for each workload pod set (role)
activePods := p.runnableOrSucceededPods()
inactivePods := p.notRunnableNorSucceededPods()
Expand Down
Loading