Skip to content

Commit

Permalink
[jobs]Add max execution time (kubernetes-sigs#3191)
Browse files Browse the repository at this point in the history
* [controllers][jobs] Add MaxExecutionTime to workloads.

* [webhooks][jobs] Validate MaxExecutionTime.

* [controllers][jobs] Update the workload when MaxExecutionTime changes.

* Rename `MaxExecTime` to `MaximumExecutionTimeSeconds`

* Review Remarks
  • Loading branch information
trasc authored and PBundyra committed Nov 5, 2024
1 parent 14bf308 commit cce4e8d
Show file tree
Hide file tree
Showing 8 changed files with 365 additions and 4 deletions.
3 changes: 3 additions & 0 deletions pkg/controller/constants/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,7 @@ const (

// ProvReqAnnotationPrefix is the prefix for annotations that should be passed to ProvisioningRequest as Parameters.
ProvReqAnnotationPrefix = "provreq.kueue.x-k8s.io/"

// MaxExecTimeSecondsLabel is the label key in the job that holds the maximum execution time.
MaxExecTimeSecondsLabel = `kueue.x-k8s.io/max-exec-time-seconds`
)
16 changes: 16 additions & 0 deletions pkg/controller/jobframework/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ package jobframework

import (
"context"
"strconv"

metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/types"
"k8s.io/client-go/tools/record"
"k8s.io/utils/ptr"
"sigs.k8s.io/controller-runtime/pkg/client"

kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1"
Expand Down Expand Up @@ -144,6 +146,20 @@ func QueueNameForObject(object client.Object) string {
return object.GetAnnotations()[constants.QueueAnnotation]
}

// MaximumExecutionTimeSeconds reads the constants.MaxExecTimeSecondsLabel
// label from the job's object and returns its value as a pointer to int32.
//
// It returns nil when the label is absent, when its value cannot be parsed as
// a base-10 integer that fits in 32 bits, or when the parsed value is not
// strictly positive.
func MaximumExecutionTimeSeconds(job GenericJob) *int32 {
	raw, ok := job.Object().GetLabels()[constants.MaxExecTimeSecondsLabel]
	if !ok {
		return nil
	}

	// bitSize 32 makes ParseInt fail on values that would not fit in int32,
	// so the conversion below cannot truncate.
	if seconds, err := strconv.ParseInt(raw, 10, 32); err == nil && seconds > 0 {
		return ptr.To(int32(seconds))
	}
	return nil
}

func workloadPriorityClassName(job GenericJob) string {
object := job.Object()
if workloadPriorityClassLabel := object.GetLabels()[constants.WorkloadPriorityClassLabel]; workloadPriorityClassLabel != "" {
Expand Down
11 changes: 9 additions & 2 deletions pkg/controller/jobframework/reconciler.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"k8s.io/apimachinery/pkg/util/validation"
"k8s.io/client-go/tools/record"
"k8s.io/klog/v2"
"k8s.io/utils/ptr"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/builder"
"sigs.k8s.io/controller-runtime/pkg/client"
Expand Down Expand Up @@ -761,6 +762,11 @@ func equivalentToWorkload(ctx context.Context, c client.Client, job GenericJob,
return false
}

defaultDuration := int32(-1)
if ptr.Deref(wl.Spec.MaximumExecutionTimeSeconds, defaultDuration) != ptr.Deref(MaximumExecutionTimeSeconds(job), defaultDuration) {
return false
}

jobPodSets := clearMinCountsIfFeatureDisabled(job.PodSets())

if runningPodSets := expectedRunningPodSets(ctx, c, wl); runningPodSets != nil {
Expand Down Expand Up @@ -900,8 +906,9 @@ func (r *JobReconciler) constructWorkload(ctx context.Context, job GenericJob, o
Annotations: admissioncheck.FilterProvReqAnnotations(job.Object().GetAnnotations()),
},
Spec: kueue.WorkloadSpec{
PodSets: podSets,
QueueName: QueueName(job),
PodSets: podSets,
QueueName: QueueName(job),
MaximumExecutionTimeSeconds: MaximumExecutionTimeSeconds(job),
},
}
if wl.Labels == nil {
Expand Down
28 changes: 27 additions & 1 deletion pkg/controller/jobframework/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package jobframework

import (
"fmt"
"strconv"
"strings"

kfmpi "github.com/kubeflow/mpi-operator/pkg/apis/kubeflow/v2beta1"
Expand All @@ -36,6 +37,7 @@ var (
annotationsPath = field.NewPath("metadata", "annotations")
labelsPath = field.NewPath("metadata", "labels")
queueNameLabelPath = labelsPath.Key(constants.QueueLabel)
maxExecTimeLabelPath = labelsPath.Key(constants.MaxExecTimeSecondsLabel)
workloadPriorityClassNamePath = labelsPath.Key(constants.WorkloadPriorityClassLabel)
supportedPrebuiltWlJobGVKs = sets.New(
batchv1.SchemeGroupVersion.WithKind("Job").String(),
Expand All @@ -49,13 +51,16 @@ var (

// ValidateJobOnCreate runs every GenericJob validation that applies to a
// Create operation: first the queue-name check, then the maximum execution
// time label check. The errors of both are returned in that order.
func ValidateJobOnCreate(job GenericJob) field.ErrorList {
	return append(validateCreateForQueueName(job), validateCreateForMaxExecTime(job)...)
}

// ValidateJobOnUpdate runs every GenericJob validation that applies to an
// Update operation and returns the collected errors, in order: queue name,
// workload priority class name, maximum execution time label.
func ValidateJobOnUpdate(oldJob, newJob GenericJob) field.ErrorList {
	// Seed the list with the queue-name result, then append the remaining
	// update checks in their fixed order.
	errs := validateUpdateForQueueName(oldJob, newJob)
	remaining := []func(GenericJob, GenericJob) field.ErrorList{
		validateUpdateForWorkloadPriorityClassName,
		validateUpdateForMaxExecTime,
	}
	for _, check := range remaining {
		errs = append(errs, check(oldJob, newJob)...)
	}
	return errs
}

Expand Down Expand Up @@ -111,3 +116,24 @@ func validateUpdateForWorkloadPriorityClassName(oldJob, newJob GenericJob) field
allErrs := apivalidation.ValidateImmutableField(workloadPriorityClassName(newJob), workloadPriorityClassName(oldJob), workloadPriorityClassNamePath)
return allErrs
}

// validateCreateForMaxExecTime validates the optional
// constants.MaxExecTimeSecondsLabel label of the job.
//
// It returns nil when the label is absent or holds a base-10 integer in the
// range [1, 2147483647]. Otherwise it returns a single field.Invalid error
// rooted at maxExecTimeLabelPath: with the parse error when the value is not
// an integer, or with a range message when it is not strictly positive or
// does not fit in an int32.
func validateCreateForMaxExecTime(job GenericJob) field.ErrorList {
	strVal, found := job.Object().GetLabels()[constants.MaxExecTimeSecondsLabel]
	if !found {
		return nil
	}

	v, err := strconv.Atoi(strVal)
	if err != nil {
		return field.ErrorList{field.Invalid(maxExecTimeLabelPath, strVal, err.Error())}
	}

	if v <= 0 {
		return field.ErrorList{field.Invalid(maxExecTimeLabelPath, v, "should be greater than 0")}
	}

	// MaximumExecutionTimeSeconds parses this label with a 32-bit size and
	// returns nil when parsing fails. Atoi accepts the platform int range, so
	// without this bound a value above the int32 maximum would pass
	// validation and then be silently dropped instead of being enforced.
	const maxSeconds = 1<<31 - 1
	if v > maxSeconds {
		return field.ErrorList{field.Invalid(maxExecTimeLabelPath, v, "should be at most 2147483647")}
	}
	return nil
}

// validateUpdateForMaxExecTime validates the constants.MaxExecTimeSecondsLabel
// label on update. When both the old and the new job are suspended no check is
// performed and nil is returned. In every other case the new and old label
// values are passed to apivalidation.ValidateImmutableField and its result is
// returned.
func validateUpdateForMaxExecTime(oldJob, newJob GenericJob) field.ErrorList {
	if newJob.IsSuspended() && oldJob.IsSuspended() {
		return nil
	}

	updated := newJob.Object().GetLabels()[constants.MaxExecTimeSecondsLabel]
	previous := oldJob.Object().GetLabels()[constants.MaxExecTimeSecondsLabel]
	return apivalidation.ValidateImmutableField(updated, previous, maxExecTimeLabelPath)
}
62 changes: 62 additions & 0 deletions pkg/controller/jobs/job/job_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2747,6 +2747,68 @@ func TestReconciler(t *testing.T) {
},
},
},
"the maximum execution time is passed to the created workload": {
job: *baseJobWrapper.Clone().
Label(controllerconsts.MaxExecTimeSecondsLabel, "10").
Obj(),
wantJob: *baseJobWrapper.Clone().
Label(controllerconsts.MaxExecTimeSecondsLabel, "10").
Obj(),
wantWorkloads: []kueue.Workload{
*utiltesting.MakeWorkload("job", "ns").
MaximumExecutionTimeSeconds(10).
Finalizers(kueue.ResourceInUseFinalizerName).
PodSets(*utiltesting.MakePodSet(kueue.DefaultPodSetName, 10).Request(corev1.ResourceCPU, "1").Obj()).
Queue("foo").
Priority(0).
Labels(map[string]string{controllerconsts.JobUIDLabel: string(baseJobWrapper.GetUID())}).
Obj(),
},
wantEvents: []utiltesting.EventRecord{
{
Key: types.NamespacedName{Name: "job", Namespace: "ns"},
EventType: "Normal",
Reason: "CreatedWorkload",
Message: "Created Workload: ns/" + GetWorkloadNameForJob(baseJobWrapper.Name, baseJobWrapper.GetUID()),
},
},
},
"the maximum execution time is updated in the workload": {
job: *baseJobWrapper.Clone().
Label(controllerconsts.MaxExecTimeSecondsLabel, "10").
Obj(),
wantJob: *baseJobWrapper.Clone().
Label(controllerconsts.MaxExecTimeSecondsLabel, "10").
Obj(),
workloads: []kueue.Workload{
*utiltesting.MakeWorkload("job", "ns").
MaximumExecutionTimeSeconds(5).
Finalizers(kueue.ResourceInUseFinalizerName).
PodSets(*utiltesting.MakePodSet(kueue.DefaultPodSetName, 10).Request(corev1.ResourceCPU, "1").Obj()).
Queue("foo").
Priority(0).
Labels(map[string]string{controllerconsts.JobUIDLabel: string(baseJobWrapper.GetUID())}).
Obj(),
},
wantWorkloads: []kueue.Workload{
*utiltesting.MakeWorkload("job", "ns").
MaximumExecutionTimeSeconds(10).
Finalizers(kueue.ResourceInUseFinalizerName).
PodSets(*utiltesting.MakePodSet(kueue.DefaultPodSetName, 10).Request(corev1.ResourceCPU, "1").Obj()).
Queue("foo").
Priority(0).
Labels(map[string]string{controllerconsts.JobUIDLabel: string(baseJobWrapper.GetUID())}).
Obj(),
},
wantEvents: []utiltesting.EventRecord{
{
Key: types.NamespacedName{Name: "job", Namespace: "ns"},
EventType: "Normal",
Reason: "UpdatedWorkload",
Message: "Updated not matching Workload for suspended job: ns/job",
},
},
},
}
for name, tc := range cases {
t.Run(name, func(t *testing.T) {
Expand Down
85 changes: 85 additions & 0 deletions pkg/controller/jobs/job/job_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ var (
labelsPath = field.NewPath("metadata", "labels")
queueNameLabelPath = labelsPath.Key(constants.QueueLabel)
prebuiltWlNameLabelPath = labelsPath.Key(constants.PrebuiltWorkloadLabel)
maxExecTimeLabelPath = labelsPath.Key(constants.MaxExecTimeSecondsLabel)
queueNameAnnotationsPath = annotationsPath.Key(constants.QueueAnnotation)
workloadPriorityClassNamePath = labelsPath.Key(constants.WorkloadPriorityClassLabel)
)
Expand Down Expand Up @@ -221,6 +222,55 @@ func TestValidateCreate(t *testing.T) {
wantErr: nil,
serverVersion: "1.27.0",
},
{
name: "invalid maximum execution time",
job: testingutil.MakeJob("job", "default").
Parallelism(4).
Completions(4).
Label(constants.MaxExecTimeSecondsLabel, "NaN").
Indexed(true).
Obj(),
wantErr: field.ErrorList{
field.Invalid(maxExecTimeLabelPath, "NaN", `strconv.Atoi: parsing "NaN": invalid syntax`),
},
serverVersion: "1.31.0",
},
{
name: "zero maximum execution time",
job: testingutil.MakeJob("job", "default").
Parallelism(4).
Completions(4).
Label(constants.MaxExecTimeSecondsLabel, "0").
Indexed(true).
Obj(),
wantErr: field.ErrorList{
field.Invalid(maxExecTimeLabelPath, 0, "should be greater than 0"),
},
serverVersion: "1.31.0",
},
{
name: "negative maximum execution time",
job: testingutil.MakeJob("job", "default").
Parallelism(4).
Completions(4).
Label(constants.MaxExecTimeSecondsLabel, "-10").
Indexed(true).
Obj(),
wantErr: field.ErrorList{
field.Invalid(maxExecTimeLabelPath, -10, "should be greater than 0"),
},
serverVersion: "1.31.0",
},
{
name: "valid maximum execution time",
job: testingutil.MakeJob("job", "default").
Parallelism(4).
Completions(4).
Label(constants.MaxExecTimeSecondsLabel, "10").
Indexed(true).
Obj(),
serverVersion: "1.31.0",
},
}

for _, tc := range testcases {
Expand Down Expand Up @@ -430,6 +480,41 @@ func TestValidateUpdate(t *testing.T) {
field.Invalid(minPodsCountAnnotationsPath, "NaN", "strconv.Atoi: parsing \"NaN\": invalid syntax"),
},
},
{
name: "immutable max exec time while unsuspended",
oldJob: testingutil.MakeJob("job", "default").
Suspend(false).
Label(constants.MaxExecTimeSecondsLabel, "10").
Obj(),
newJob: testingutil.MakeJob("job", "default").
Suspend(false).
Label(constants.MaxExecTimeSecondsLabel, "20").
Obj(),
wantErr: apivalidation.ValidateImmutableField("20", "10", maxExecTimeLabelPath),
},
{
name: "immutable max exec time while transitioning to unsuspended",
oldJob: testingutil.MakeJob("job", "default").
Suspend(true).
Label(constants.MaxExecTimeSecondsLabel, "10").
Obj(),
newJob: testingutil.MakeJob("job", "default").
Suspend(false).
Label(constants.MaxExecTimeSecondsLabel, "20").
Obj(),
wantErr: apivalidation.ValidateImmutableField("20", "10", maxExecTimeLabelPath),
},
{
name: "mutable max exec time while suspended",
oldJob: testingutil.MakeJob("job", "default").
Suspend(true).
Label(constants.MaxExecTimeSecondsLabel, "10").
Obj(),
newJob: testingutil.MakeJob("job", "default").
Suspend(true).
Label(constants.MaxExecTimeSecondsLabel, "20").
Obj(),
},
}

for _, tc := range testcases {
Expand Down
8 changes: 7 additions & 1 deletion pkg/controller/jobs/pod/pod_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -972,7 +972,8 @@ func (p *Pod) ConstructComposableWorkload(ctx context.Context, c client.Client,
Annotations: admissioncheck.FilterProvReqAnnotations(p.pod.GetAnnotations()),
},
Spec: kueue.WorkloadSpec{
QueueName: jobframework.QueueName(p),
QueueName: jobframework.QueueName(p),
MaximumExecutionTimeSeconds: jobframework.MaximumExecutionTimeSeconds(p),
},
}

Expand Down Expand Up @@ -1110,6 +1111,11 @@ func (p *Pod) FindMatchingWorkloads(ctx context.Context, c client.Client, r reco
return nil, nil, err
}

defaultDuration := int32(-1)
if ptr.Deref(workload.Spec.MaximumExecutionTimeSeconds, defaultDuration) != ptr.Deref(jobframework.MaximumExecutionTimeSeconds(p), defaultDuration) {
return nil, []*kueue.Workload{workload}, nil
}

// Cleanup excess pods for each workload pod set (role)
activePods := p.runnableOrSucceededPods()
inactivePods := p.notRunnableNorSucceededPods()
Expand Down
Loading

0 comments on commit cce4e8d

Please sign in to comment.