Record original node selectors #660

Merged
6 changes: 6 additions & 0 deletions pkg/controller/jobframework/constants.go
Original file line number Diff line number Diff line change
@@ -32,4 +32,10 @@ const (
// ignores this Job from admission, and takes control of its suspension
// status based on the admission status of the parent workload.
ParentWorkloadAnnotation = "kueue.x-k8s.io/parent-workload"

// OriginalNodeSelectorsAnnotation is the annotation in which the original
// node selectors are recorded upon a workload admission. This information
// will be used to restore them when the job is suspended.
// The content is a JSON-marshaled slice of selectors.
OriginalNodeSelectorsAnnotation = "kueue.x-k8s.io/original-node-selectors"
Contributor:

Have we weighed the suggestion? #518 (comment)

Contributor Author:

Yes, but it would unnecessarily complicate the workload lifecycle. Also, setting the annotation is done without any extra API calls.

Contributor:

The annotation is also useful as a record for users to look at.

)
4 changes: 2 additions & 2 deletions pkg/controller/jobframework/interface.go
@@ -32,9 +32,9 @@ type GenericJob interface {
// If true, status is modified, if not, status is as it was.
ResetStatus() bool
// RunWithNodeAffinity will inject the node affinity extracting from workload to job and unsuspend the job.
RunWithNodeAffinity(nodeSelectors []map[string]string)
RunWithNodeAffinity(nodeSelectors []PodSetNodeSelector)
// RestoreNodeAffinity will restore the original node affinity of job.
RestoreNodeAffinity(podSets []kueue.PodSet)
RestoreNodeAffinity(nodeSelectors []PodSetNodeSelector)
// Finished means whether the job is completed/failed or not,
// condition represents the workload finished condition.
Finished() (condition metav1.Condition, finished bool)
98 changes: 90 additions & 8 deletions pkg/controller/jobframework/reconciler.go
@@ -15,6 +15,7 @@ package jobframework

import (
"context"
"encoding/json"
"fmt"

corev1 "k8s.io/api/core/v1"
@@ -34,6 +35,10 @@ import (
"sigs.k8s.io/kueue/pkg/workload"
)

var (
errNodeSelectorsNotFound = fmt.Errorf("annotation %s not found", OriginalNodeSelectorsAnnotation)
)

// JobReconciler reconciles a GenericJob object
type JobReconciler struct {
client client.Client
@@ -300,7 +305,13 @@ func (r *JobReconciler) equivalentToWorkload(job GenericJob, object client.Objec

// startJob will unsuspend the job, and also inject the node affinity.
func (r *JobReconciler) startJob(ctx context.Context, job GenericJob, object client.Object, wl *kueue.Workload) error {
nodeSelectors, err := r.getNodeSelectors(ctx, wl)
// get the original selectors and store them in the job object
originalSelectors := r.getNodeSelectorsFromPodSets(wl)
if err := setNodeSelectorsInAnnotation(object, originalSelectors); err != nil {
return fmt.Errorf("startJob, record original node selectors: %w", err)
}

nodeSelectors, err := r.getNodeSelectorsFromAdmission(ctx, wl)
if err != nil {
return err
}
@@ -318,6 +329,7 @@

// stopJob will suspend the job, and also restore node affinity, reset job status if needed.
func (r *JobReconciler) stopJob(ctx context.Context, job GenericJob, object client.Object, wl *kueue.Workload, eventMsg string) error {
log := ctrl.LoggerFrom(ctx)
// Suspend the job at first then we're able to update the scheduling directives.
job.Suspend()

@@ -333,8 +345,12 @@
}
}

if wl != nil {
job.RestoreNodeAffinity(wl.Spec.PodSets)
log.V(3).Info("restore node selectors from annotation")
selectors, err := getNodeSelectorsFromObjectAnnotation(object)
if err != nil {
log.V(3).Error(err, "Unable to get original node selectors")
} else {
job.RestoreNodeAffinity(selectors)
Contributor:

We could fall back to getting the selectors from the workload object.

Contributor Author:

It was my original approach, but it was changed during the review.

Contributor (@mimowo, Apr 3, 2023):

IIRC, in the original approach you first looked up the workload object and then fell back to the annotation.

Anyway, I think that once we made the annotation immutable, we can fully rely on it, unless I'm missing some scenario.

If such a scenario exists, my point was that we should add a comment explaining why this is done (what the scenario is). Otherwise we will end up with suspicious code that no one remembers or knows why it is needed.

Contributor Author:

The order was inverted for performance reasons; getting the selectors from the workload would not have needed additional unmarshaling.

The scenario would be a job missing the annotation, which could happen if the job (other than the core Job or MPIJob) in question does not block changes to the annotation while running.

Contributor (@mimowo, Apr 3, 2023):

Yes, but we have a webhook to prevent the change. So the only scenario is that the webhook is malfunctioning for some reason and we have a bad actor.

However, if the webhook is malfunctioning and we have a bad actor, the actor could both modify the annotation and delete the workload, so the fallback would not work either. So, IIUC, it would not be bulletproof either, just misleadingly making that impression.

return r.client.Update(ctx, object)
}

@@ -369,17 +385,25 @@ func (r *JobReconciler) constructWorkload(ctx context.Context, job GenericJob, o
return wl, nil
}

// getNodeSelectors will extract node selectors from admitted workloads.
func (r *JobReconciler) getNodeSelectors(ctx context.Context, w *kueue.Workload) ([]map[string]string, error) {
type PodSetNodeSelector struct {
Name string `json:"name"`
NodeSelector map[string]string `json:"nodeSelector"`
}

// getNodeSelectorsFromAdmission will extract node selectors from admitted workloads.
func (r *JobReconciler) getNodeSelectorsFromAdmission(ctx context.Context, w *kueue.Workload) ([]PodSetNodeSelector, error) {
if len(w.Status.Admission.PodSetAssignments) == 0 {
return nil, nil
}

nodeSelectors := make([]map[string]string, len(w.Status.Admission.PodSetAssignments))
nodeSelectors := make([]PodSetNodeSelector, len(w.Status.Admission.PodSetAssignments))

for i, podSetFlavor := range w.Status.Admission.PodSetAssignments {
processedFlvs := sets.NewString()
nodeSelector := map[string]string{}
nodeSelector := PodSetNodeSelector{
Name: podSetFlavor.Name,
NodeSelector: make(map[string]string),
}
for _, flvRef := range podSetFlavor.Flavors {
flvName := string(flvRef)
if processedFlvs.Has(flvName) {
@@ -391,7 +415,7 @@ func (r *JobReconciler) getNodeSelectors(ctx context.Context, w *kueue.Workload)
return nil, err
}
for k, v := range flv.Spec.NodeLabels {
nodeSelector[k] = v
nodeSelector.NodeSelector[k] = v
}
processedFlvs.Insert(flvName)
}
@@ -401,6 +425,23 @@
return nodeSelectors, nil
}

// getNodeSelectorsFromPodSets will extract node selectors from a workload's podSets.
func (r *JobReconciler) getNodeSelectorsFromPodSets(w *kueue.Workload) []PodSetNodeSelector {
podSets := w.Spec.PodSets
if len(podSets) == 0 {
return nil
}
ret := make([]PodSetNodeSelector, len(podSets))
for psi := range podSets {
ps := &podSets[psi]
ret[psi] = PodSetNodeSelector{
Name: ps.Name,
NodeSelector: cloneNodeSelector(ps.Template.Spec.NodeSelector),
}
}
return ret
}

func (r *JobReconciler) handleJobWithNoWorkload(ctx context.Context, job GenericJob, object client.Object) error {
log := ctrl.LoggerFrom(ctx)

@@ -442,3 +483,44 @@ func generatePodsReadyCondition(job GenericJob, wl *kueue.Workload) metav1.Condi
Message: message,
}
}

func cloneNodeSelector(src map[string]string) map[string]string {
ret := make(map[string]string, len(src))
for k, v := range src {
ret[k] = v
}
return ret
}

// getNodeSelectorsFromObjectAnnotation tries to retrieve a node selectors slice from the
// object's annotations; it fails if the annotation is not found or cannot be unmarshaled.
func getNodeSelectorsFromObjectAnnotation(obj client.Object) ([]PodSetNodeSelector, error) {
str, found := obj.GetAnnotations()[OriginalNodeSelectorsAnnotation]
if !found {
return nil, errNodeSelectorsNotFound
}
// unmarshal
ret := []PodSetNodeSelector{}
if err := json.Unmarshal([]byte(str), &ret); err != nil {
return nil, err
}
return ret, nil
}

// setNodeSelectorsInAnnotation sets an annotation containing the provided node selectors on
// a job object. Although unlikely, it can return a JSON marshaling error.
func setNodeSelectorsInAnnotation(obj client.Object, nodeSelectors []PodSetNodeSelector) error {
nodeSelectorsBytes, err := json.Marshal(nodeSelectors)
if err != nil {
return err
}

annotations := obj.GetAnnotations()
if annotations == nil {
annotations = map[string]string{OriginalNodeSelectorsAnnotation: string(nodeSelectorsBytes)}
} else {
annotations[OriginalNodeSelectorsAnnotation] = string(nodeSelectorsBytes)
}
obj.SetAnnotations(annotations)
return nil
}
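The two helpers above form a simple round-trip through the object's annotations. A self-contained sketch of that round-trip, using a plain map in place of `client.Object` to avoid the controller-runtime dependency (the annotation key and selector type are from this PR; the helper names here are simplified stand-ins):

```go
package main

import (
	"encoding/json"
	"fmt"
)

type PodSetNodeSelector struct {
	Name         string            `json:"name"`
	NodeSelector map[string]string `json:"nodeSelector"`
}

const annotationKey = "kueue.x-k8s.io/original-node-selectors"

// setSelectors marshals the selectors into the annotations map,
// mirroring setNodeSelectorsInAnnotation without the client.Object dependency.
func setSelectors(annotations map[string]string, sel []PodSetNodeSelector) error {
	b, err := json.Marshal(sel)
	if err != nil {
		return err
	}
	annotations[annotationKey] = string(b)
	return nil
}

// getSelectors reads the selectors back, mirroring
// getNodeSelectorsFromObjectAnnotation, including the not-found error.
func getSelectors(annotations map[string]string) ([]PodSetNodeSelector, error) {
	str, found := annotations[annotationKey]
	if !found {
		return nil, fmt.Errorf("annotation %s not found", annotationKey)
	}
	var ret []PodSetNodeSelector
	if err := json.Unmarshal([]byte(str), &ret); err != nil {
		return nil, err
	}
	return ret, nil
}

func main() {
	ann := map[string]string{}
	// Hypothetical pod set and label, for illustration only.
	_ = setSelectors(ann, []PodSetNodeSelector{{Name: "main", NodeSelector: map[string]string{"zone": "us-east-1a"}}})
	got, _ := getSelectors(ann)
	fmt.Println(got[0].Name, got[0].NodeSelector["zone"])
	// → main us-east-1a
}
```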
19 changes: 19 additions & 0 deletions pkg/controller/jobframework/validation.go
@@ -14,6 +14,7 @@ limitations under the License.
package jobframework

import (
"encoding/json"
"strings"

apivalidation "k8s.io/apimachinery/pkg/api/validation"
@@ -26,6 +27,8 @@ var (
labelsPath = field.NewPath("metadata", "labels")
parentWorkloadKeyPath = annotationsPath.Key(ParentWorkloadAnnotation)
queueNameLabelPath = labelsPath.Key(QueueLabel)

originalNodeSelectorsWorkloadKeyPath = annotationsPath.Key(OriginalNodeSelectorsAnnotation)
)

func ValidateCreateForQueueName(job GenericJob) field.ErrorList {
@@ -71,3 +74,19 @@ func ValidateUpdateForParentWorkload(oldJob, newJob GenericJob) field.ErrorList
}
return allErrs
}

func ValidateUpdateForOriginalNodeSelectors(oldJob, newJob GenericJob) field.ErrorList {
var allErrs field.ErrorList
if oldJob.IsSuspended() == newJob.IsSuspended() {
if errList := apivalidation.ValidateImmutableField(oldJob.Object().GetAnnotations()[OriginalNodeSelectorsAnnotation],
newJob.Object().GetAnnotations()[OriginalNodeSelectorsAnnotation], originalNodeSelectorsWorkloadKeyPath); len(errList) > 0 {
allErrs = append(allErrs, field.Forbidden(originalNodeSelectorsWorkloadKeyPath, "this annotation is immutable while the job is not changing its suspended state"))
}
} else if av, found := newJob.Object().GetAnnotations()[OriginalNodeSelectorsAnnotation]; found {
out := []PodSetNodeSelector{}
if err := json.Unmarshal([]byte(av), &out); err != nil {
allErrs = append(allErrs, field.Invalid(originalNodeSelectorsWorkloadKeyPath, av, err.Error()))
}
}
return allErrs
}
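The update rule above can be summarized as: the annotation may only change while the job's suspended state is also flipping, and any new value must unmarshal as a `[]PodSetNodeSelector`. A simplified sketch of the suspension half of that rule, without the apimachinery `field` helpers (the function name here is a stand-in, not the PR's API):

```go
package main

import "fmt"

// validateSelectorAnnotationUpdate sketches the rule from
// ValidateUpdateForOriginalNodeSelectors: the annotation value may only
// change while the job's suspended state is also changing.
func validateSelectorAnnotationUpdate(oldSuspended, newSuspended bool, oldVal, newVal string) error {
	if oldSuspended == newSuspended && oldVal != newVal {
		return fmt.Errorf("annotation is immutable while the job is not changing its suspended state")
	}
	return nil
}

func main() {
	// Setting the annotation while unsuspending is allowed.
	fmt.Println(validateSelectorAnnotationUpdate(true, false, "", `[{"name":"ps"}]`))
	// Changing it while staying unsuspended is rejected.
	fmt.Println(validateSelectorAnnotationUpdate(false, false, `[{"name":"ps"}]`, ""))
}
```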
12 changes: 6 additions & 6 deletions pkg/controller/jobs/job/job_controller.go
@@ -150,29 +150,29 @@ func (j *Job) PodSets() []kueue.PodSet {
}
}

func (j *Job) RunWithNodeAffinity(nodeSelectors []map[string]string) {
func (j *Job) RunWithNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelector) {
j.Spec.Suspend = pointer.Bool(false)
if len(nodeSelectors) == 0 {
return
}

if j.Spec.Template.Spec.NodeSelector == nil {
j.Spec.Template.Spec.NodeSelector = nodeSelectors[0]
j.Spec.Template.Spec.NodeSelector = nodeSelectors[0].NodeSelector
} else {
for k, v := range nodeSelectors[0] {
for k, v := range nodeSelectors[0].NodeSelector {
j.Spec.Template.Spec.NodeSelector[k] = v
}
}
}

func (j *Job) RestoreNodeAffinity(podSets []kueue.PodSet) {
if len(podSets) == 0 || equality.Semantic.DeepEqual(j.Spec.Template.Spec.NodeSelector, podSets[0].Template.Spec.NodeSelector) {
func (j *Job) RestoreNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelector) {
if len(nodeSelectors) == 0 || equality.Semantic.DeepEqual(j.Spec.Template.Spec.NodeSelector, nodeSelectors[0].NodeSelector) {
return
}

j.Spec.Template.Spec.NodeSelector = map[string]string{}

for k, v := range podSets[0].Template.Spec.NodeSelector {
for k, v := range nodeSelectors[0].NodeSelector {
j.Spec.Template.Spec.NodeSelector[k] = v
}
}
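`RunWithNodeAffinity` merges the admission-derived selector into the pod template's existing one, with injected entries winning on key conflicts. A minimal sketch of that merge, assuming plain maps and hypothetical label values:

```go
package main

import "fmt"

// mergeNodeSelector overlays the admission-derived selector onto the
// job's existing one, mirroring the loop in (*Job).RunWithNodeAffinity.
// Keys from the injected selector overwrite existing ones.
func mergeNodeSelector(existing, injected map[string]string) map[string]string {
	if existing == nil {
		return injected
	}
	for k, v := range injected {
		existing[k] = v
	}
	return existing
}

func main() {
	// Illustrative labels only.
	existing := map[string]string{"disk": "ssd", "zone": "us-east-1a"}
	injected := map[string]string{"zone": "us-east-1b"}
	fmt.Println(mergeNodeSelector(existing, injected))
	// → map[disk:ssd zone:us-east-1b]
}
```

This also shows why recording the original selectors matters: the merge is lossy (the original `zone` value is overwritten), so suspending the job needs the annotation to restore it.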
1 change: 1 addition & 0 deletions pkg/controller/jobs/job/job_webhook.go
@@ -106,6 +106,7 @@ func (w *JobWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.
func validateUpdate(oldJob, newJob jobframework.GenericJob) field.ErrorList {
allErrs := validateCreate(newJob)
allErrs = append(allErrs, jobframework.ValidateUpdateForParentWorkload(oldJob, newJob)...)
allErrs = append(allErrs, jobframework.ValidateUpdateForOriginalNodeSelectors(oldJob, newJob)...)
allErrs = append(allErrs, jobframework.ValidateUpdateForQueueName(oldJob, newJob)...)
return allErrs
}
41 changes: 41 additions & 0 deletions pkg/controller/jobs/job/job_webhook_test.go
@@ -38,6 +38,8 @@ var (
parentWorkloadKeyPath = annotationsPath.Key(jobframework.ParentWorkloadAnnotation)
queueNameLabelPath = labelsPath.Key(jobframework.QueueLabel)
queueNameAnnotationsPath = annotationsPath.Key(jobframework.QueueAnnotation)

originalNodeSelectorsKeyPath = annotationsPath.Key(jobframework.OriginalNodeSelectorsAnnotation)
)

func TestValidateCreate(t *testing.T) {
@@ -93,6 +95,17 @@
}

func TestValidateUpdate(t *testing.T) {

validPodSelectors := `
[
{
"name": "podSetName",
"nodeSelector": {
"l1": "v1"
}
}
]
`
testcases := []struct {
name string
oldJob *batchv1.Job
@@ -156,6 +169,34 @@
field.Forbidden(parentWorkloadKeyPath, "this annotation is immutable"),
},
},
{
name: "original node selectors can be set while unsuspending",
oldJob: testingutil.MakeJob("job", "default").Suspend(true).Obj(),
newJob: testingutil.MakeJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation(validPodSelectors).Obj(),
wantErr: nil,
},
{
name: "original node selectors can be set while suspending",
oldJob: testingutil.MakeJob("job", "default").Suspend(false).Obj(),
newJob: testingutil.MakeJob("job", "default").Suspend(true).OriginalNodeSelectorsAnnotation(validPodSelectors).Obj(),
wantErr: nil,
},
{
name: "immutable original node selectors while not suspended",
oldJob: testingutil.MakeJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation(validPodSelectors).Obj(),
newJob: testingutil.MakeJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation("").Obj(),
wantErr: field.ErrorList{
field.Forbidden(originalNodeSelectorsKeyPath, "this annotation is immutable while the job is not changing its suspended state"),
},
},
{
name: "immutable original node selectors while suspended",
oldJob: testingutil.MakeJob("job", "default").Suspend(true).OriginalNodeSelectorsAnnotation(validPodSelectors).Obj(),
newJob: testingutil.MakeJob("job", "default").Suspend(true).OriginalNodeSelectorsAnnotation("").Obj(),
wantErr: field.ErrorList{
field.Forbidden(originalNodeSelectorsKeyPath, "this annotation is immutable while the job is not changing its suspended state"),
},
},
}

for _, tc := range testcases {
17 changes: 9 additions & 8 deletions pkg/controller/jobs/mpijob/mpijob_controller.go
@@ -104,35 +104,36 @@ func (j *MPIJob) PodSets() []kueue.PodSet {
return podSets
}

func (j *MPIJob) RunWithNodeAffinity(nodeSelectors []map[string]string) {
func (j *MPIJob) RunWithNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelector) {
j.Spec.RunPolicy.Suspend = pointer.Bool(false)
if len(nodeSelectors) == 0 {
return
}
// The node selectors are provided in the same order as the generated list of
// podSets, use the same ordering logic to restore them.
orderedReplicaTypes := orderedReplicaTypes(&j.Spec)
for index := range nodeSelectors {
replicaType := orderedReplicaTypes[index]
nodeSelector := nodeSelectors[index]
if len(nodeSelector) != 0 {
if len(nodeSelector.NodeSelector) != 0 {
if j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector == nil {
j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector = nodeSelector
j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector = nodeSelector.NodeSelector
} else {
for k, v := range nodeSelector {
for k, v := range nodeSelector.NodeSelector {
j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector[k] = v
}
}
}
}
}

func (j *MPIJob) RestoreNodeAffinity(podSets []kueue.PodSet) {
func (j *MPIJob) RestoreNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelector) {
orderedReplicaTypes := orderedReplicaTypes(&j.Spec)
for index := range podSets {
for index, nodeSelector := range nodeSelectors {
replicaType := orderedReplicaTypes[index]
nodeSelector := podSets[index].Template.Spec.NodeSelector
if !equality.Semantic.DeepEqual(j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector, nodeSelector) {
j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector = map[string]string{}
for k, v := range nodeSelector {
for k, v := range nodeSelector.NodeSelector {
j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector[k] = v
}
}
6 changes: 5 additions & 1 deletion pkg/controller/jobs/mpijob/mpijob_webhook.go
@@ -82,10 +82,14 @@ func validateCreate(job jobframework.GenericJob) field.ErrorList {
// ValidateUpdate implements webhook.CustomValidator so a webhook will be registered for the type
func (w *MPIJobWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Object) error {
oldJob := oldObj.(*kubeflow.MPIJob)
oldGenJob := &MPIJob{*oldJob}
newJob := newObj.(*kubeflow.MPIJob)
newGenJob := &MPIJob{*newJob}
log := ctrl.LoggerFrom(ctx).WithName("job-webhook")
log.Info("Validating update", "job", klog.KObj(newJob))
return jobframework.ValidateUpdateForQueueName(&MPIJob{*oldJob}, &MPIJob{*newJob}).ToAggregate()
allErrs := jobframework.ValidateUpdateForQueueName(oldGenJob, newGenJob)
allErrs = append(allErrs, jobframework.ValidateUpdateForOriginalNodeSelectors(oldGenJob, newGenJob)...)
return allErrs.ToAggregate()
}

// ValidateDelete implements webhook.CustomValidator so a webhook will be registered for the type