Skip to content

Commit

Permalink
[jobframework] Record podSet name in originalNodeSelectors annotation
Browse files Browse the repository at this point in the history
  • Loading branch information
trasc committed Apr 3, 2023
1 parent 1022e23 commit aac2e55
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 35 deletions.
4 changes: 2 additions & 2 deletions pkg/controller/jobframework/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ type GenericJob interface {
// If true, status is modified, if not, status is as it was.
ResetStatus() bool
// RunWithNodeAffinity will inject the node affinity extracted from the workload into the job and unsuspend the job.
RunWithNodeAffinity(nodeSelectors []map[string]string)
RunWithNodeAffinity(nodeSelectors []PodSetNodeSelector)
// RestoreNodeAffinity will restore the original node affinity of job.
RestoreNodeAffinity(nodeSelectors []map[string]string)
RestoreNodeAffinity(nodeSelectors []PodSetNodeSelector)
// Finished means whether the job is completed/failed or not,
// condition represents the workload finished condition.
Finished() (condition metav1.Condition, finished bool)
Expand Down
32 changes: 22 additions & 10 deletions pkg/controller/jobframework/reconciler.go
Original file line number Diff line number Diff line change
Expand Up @@ -385,17 +385,25 @@ func (r *JobReconciler) constructWorkload(ctx context.Context, job GenericJob, o
return wl, nil
}

// PodSetNodeSelector associates the name of a pod set with a node selector.
// A JSON-encoded slice of this type is what getNodeSelectorsFromObjectAnnotation
// decodes from the OriginalNodeSelectorsAnnotation annotation.
type PodSetNodeSelector struct {
// Name is the name of the pod set the selector belongs to.
Name string `json:"name"`
// NodeSelector holds the node label key/value pairs for that pod set.
NodeSelector map[string]string `json:"nodeSelector"`
}

// getNodeSelectorsFromAdmission will extract node selectors from admitted workloads.
func (r *JobReconciler) getNodeSelectorsFromAdmission(ctx context.Context, w *kueue.Workload) ([]map[string]string, error) {
func (r *JobReconciler) getNodeSelectorsFromAdmission(ctx context.Context, w *kueue.Workload) ([]PodSetNodeSelector, error) {
if len(w.Status.Admission.PodSetAssignments) == 0 {
return nil, nil
}

nodeSelectors := make([]map[string]string, len(w.Status.Admission.PodSetAssignments))
nodeSelectors := make([]PodSetNodeSelector, len(w.Status.Admission.PodSetAssignments))

for i, podSetFlavor := range w.Status.Admission.PodSetAssignments {
processedFlvs := sets.NewString()
nodeSelector := map[string]string{}
nodeSelector := PodSetNodeSelector{
Name: podSetFlavor.Name,
NodeSelector: make(map[string]string),
}
for _, flvRef := range podSetFlavor.Flavors {
flvName := string(flvRef)
if processedFlvs.Has(flvName) {
Expand All @@ -407,7 +415,7 @@ func (r *JobReconciler) getNodeSelectorsFromAdmission(ctx context.Context, w *ku
return nil, err
}
for k, v := range flv.Spec.NodeLabels {
nodeSelector[k] = v
nodeSelector.NodeSelector[k] = v
}
processedFlvs.Insert(flvName)
}
Expand All @@ -418,14 +426,18 @@ func (r *JobReconciler) getNodeSelectorsFromAdmission(ctx context.Context, w *ku
}

// getNodeSelectorsFromPodSets will extract node selectors from a workload's podSets.
func (r *JobReconciler) getNodeSelectorsFromPodSets(w *kueue.Workload) []map[string]string {
func (r *JobReconciler) getNodeSelectorsFromPodSets(w *kueue.Workload) []PodSetNodeSelector {
podSets := w.Spec.PodSets
if len(podSets) == 0 {
return nil
}
ret := make([]map[string]string, len(podSets))
ret := make([]PodSetNodeSelector, len(podSets))
for psi := range podSets {
ret[psi] = cloneNodeSelector(podSets[psi].Template.Spec.NodeSelector)
ps := &podSets[psi]
ret[psi] = PodSetNodeSelector{
Name: ps.Name,
NodeSelector: cloneNodeSelector(ps.Template.Spec.NodeSelector),
}
}
return ret
}
Expand Down Expand Up @@ -482,13 +494,13 @@ func cloneNodeSelector(src map[string]string) map[string]string {

// getNodeSelectorsFromObjectAnnotation tries to retrieve a node selectors slice from the
// object's annotations; it fails if the annotation is not found or cannot be unmarshaled.
func getNodeSelectorsFromObjectAnnotation(obj client.Object) ([]map[string]string, error) {
func getNodeSelectorsFromObjectAnnotation(obj client.Object) ([]PodSetNodeSelector, error) {
str, found := obj.GetAnnotations()[OriginalNodeSelectorsAnnotation]
if !found {
return nil, errNodeSelectorsNotFound
}
// unmarshal
ret := []map[string]string{}
ret := []PodSetNodeSelector{}
if err := json.Unmarshal([]byte(str), &ret); err != nil {
return nil, err
}
Expand All @@ -497,7 +509,7 @@ func getNodeSelectorsFromObjectAnnotation(obj client.Object) ([]map[string]strin

// nodeSelectorsSetToObject sets an annotation containing the provided node selectors on
// a job object. Although very unlikely, it can return an error from JSON marshaling.
func nodeSelectorsSetToObject(obj client.Object, nodeSelectors []map[string]string) error {
func nodeSelectorsSetToObject(obj client.Object, nodeSelectors []PodSetNodeSelector) error {
nodeSelectorsBytes, err := json.Marshal(nodeSelectors)
if err != nil {
return err
Expand Down
2 changes: 1 addition & 1 deletion pkg/controller/jobframework/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func ValidateUpdateForOriginalNodeSelectors(oldJob, newJob GenericJob) field.Err
allErrs = append(allErrs, field.Forbidden(originalNodeSelectorsWorkloadKeyPath, "this annotation is immutable while the job is not changing its suspended state"))
}
} else if av, found := newJob.Object().GetAnnotations()[OriginalNodeSelectorsAnnotation]; found {
out := []map[string]string{}
out := []PodSetNodeSelector{}
if err := json.Unmarshal([]byte(av), &out); err != nil {
allErrs = append(allErrs, field.Invalid(originalNodeSelectorsWorkloadKeyPath, av, err.Error()))
}
Expand Down
12 changes: 6 additions & 6 deletions pkg/controller/jobs/job/job_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,29 +150,29 @@ func (j *Job) PodSets() []kueue.PodSet {
}
}

func (j *Job) RunWithNodeAffinity(nodeSelectors []map[string]string) {
func (j *Job) RunWithNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelector) {
j.Spec.Suspend = pointer.Bool(false)
if len(nodeSelectors) == 0 {
return
}

if j.Spec.Template.Spec.NodeSelector == nil {
j.Spec.Template.Spec.NodeSelector = nodeSelectors[0]
j.Spec.Template.Spec.NodeSelector = nodeSelectors[0].NodeSelector
} else {
for k, v := range nodeSelectors[0] {
for k, v := range nodeSelectors[0].NodeSelector {
j.Spec.Template.Spec.NodeSelector[k] = v
}
}
}

func (j *Job) RestoreNodeAffinity(nodeSelectors []map[string]string) {
if len(nodeSelectors) == 0 || equality.Semantic.DeepEqual(j.Spec.Template.Spec.NodeSelector, nodeSelectors[0]) {
func (j *Job) RestoreNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelector) {
if len(nodeSelectors) == 0 || equality.Semantic.DeepEqual(j.Spec.Template.Spec.NodeSelector, nodeSelectors[0].NodeSelector) {
return
}

j.Spec.Template.Spec.NodeSelector = map[string]string{}

for k, v := range nodeSelectors[0] {
for k, v := range nodeSelectors[0].NodeSelector {
j.Spec.Template.Spec.NodeSelector[k] = v
}
}
Expand Down
21 changes: 16 additions & 5 deletions pkg/controller/jobs/job/job_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,17 @@ func TestValidateCreate(t *testing.T) {
}

func TestValidateUpdate(t *testing.T) {

validPodSelectors := `
[
{
"name": "podSetName",
"nodeSelector": {
"l1": "v1"
}
}
]
`
testcases := []struct {
name string
oldJob *batchv1.Job
Expand Down Expand Up @@ -161,26 +172,26 @@ func TestValidateUpdate(t *testing.T) {
{
name: "original node selectors can be set while unsuspending",
oldJob: testingutil.MakeJob("job", "default").Suspend(true).Obj(),
newJob: testingutil.MakeJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation("[{\"l1\":\"v1\"}]").Obj(),
newJob: testingutil.MakeJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation(validPodSelectors).Obj(),
wantErr: nil,
},
{
name: "original node selectors can be set while suspending",
oldJob: testingutil.MakeJob("job", "default").Suspend(true).Obj(),
newJob: testingutil.MakeJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation("[{\"l1\":\"v1\"}]").Obj(),
oldJob: testingutil.MakeJob("job", "default").Suspend(false).Obj(),
newJob: testingutil.MakeJob("job", "default").Suspend(true).OriginalNodeSelectorsAnnotation(validPodSelectors).Obj(),
wantErr: nil,
},
{
name: "immutable original node selectors while not suspended",
oldJob: testingutil.MakeJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation("[{\"l1\":\"v1\"}]").Obj(),
oldJob: testingutil.MakeJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation(validPodSelectors).Obj(),
newJob: testingutil.MakeJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation("").Obj(),
wantErr: field.ErrorList{
field.Forbidden(originalNodeSelectorsKeyPath, "this annotation is immutable while the job is not changing its suspended state"),
},
},
{
name: "immutable original node selectors while suspended",
oldJob: testingutil.MakeJob("job", "default").Suspend(true).OriginalNodeSelectorsAnnotation("[{\"l1\":\"v1\"}]").Obj(),
oldJob: testingutil.MakeJob("job", "default").Suspend(true).OriginalNodeSelectorsAnnotation(validPodSelectors).Obj(),
newJob: testingutil.MakeJob("job", "default").Suspend(true).OriginalNodeSelectorsAnnotation("").Obj(),
wantErr: field.ErrorList{
field.Forbidden(originalNodeSelectorsKeyPath, "this annotation is immutable while the job is not changing its suspended state"),
Expand Down
12 changes: 6 additions & 6 deletions pkg/controller/jobs/mpijob/mpijob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func (j *MPIJob) PodSets() []kueue.PodSet {
return podSets
}

func (j *MPIJob) RunWithNodeAffinity(nodeSelectors []map[string]string) {
func (j *MPIJob) RunWithNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelector) {
j.Spec.RunPolicy.Suspend = pointer.Bool(false)
if len(nodeSelectors) == 0 {
return
Expand All @@ -113,25 +113,25 @@ func (j *MPIJob) RunWithNodeAffinity(nodeSelectors []map[string]string) {
for index := range nodeSelectors {
replicaType := orderedReplicaTypes[index]
nodeSelector := nodeSelectors[index]
if len(nodeSelector) != 0 {
if len(nodeSelector.NodeSelector) != 0 {
if j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector == nil {
j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector = nodeSelector
j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector = nodeSelector.NodeSelector
} else {
for k, v := range nodeSelector {
for k, v := range nodeSelector.NodeSelector {
j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector[k] = v
}
}
}
}
}

func (j *MPIJob) RestoreNodeAffinity(nodeSelectors []map[string]string) {
func (j *MPIJob) RestoreNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelector) {
orderedReplicaTypes := orderedReplicaTypes(&j.Spec)
for index, nodeSelector := range nodeSelectors {
replicaType := orderedReplicaTypes[index]
if !equality.Semantic.DeepEqual(j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector, nodeSelector) {
j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector = map[string]string{}
for k, v := range nodeSelector {
for k, v := range nodeSelector.NodeSelector {
j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector[k] = v
}
}
Expand Down
20 changes: 15 additions & 5 deletions pkg/controller/jobs/mpijob/mpijob_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,30 +33,40 @@ var (
)

func TestUpdate(t *testing.T) {
validPodSelectors := `
[
{
"name": "podSetName",
"nodeSelector": {
"l1": "v1"
}
}
]
`
testcases := map[string]struct {
oldJob *kubeflow.MPIJob
newJob *kubeflow.MPIJob
wantErr error
}{
"original node selectors can be set while unsuspending": {
oldJob: testingutil.MakeMPIJob("job", "default").Suspend(true).Obj(),
newJob: testingutil.MakeMPIJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation("[{\"l1\":\"v1\"}]").Obj(),
newJob: testingutil.MakeMPIJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation(validPodSelectors).Obj(),
wantErr: nil,
},
"original node selectors can be set while suspending": {
oldJob: testingutil.MakeMPIJob("job", "default").Suspend(true).Obj(),
newJob: testingutil.MakeMPIJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation("[{\"l1\":\"v1\"}]").Obj(),
oldJob: testingutil.MakeMPIJob("job", "default").Suspend(false).Obj(),
newJob: testingutil.MakeMPIJob("job", "default").Suspend(true).OriginalNodeSelectorsAnnotation(validPodSelectors).Obj(),
wantErr: nil,
},
"immutable original node selectors while not suspended": {
oldJob: testingutil.MakeMPIJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation("[{\"l1\":\"v1\"}]").Obj(),
oldJob: testingutil.MakeMPIJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation(validPodSelectors).Obj(),
newJob: testingutil.MakeMPIJob("job", "default").Suspend(false).OriginalNodeSelectorsAnnotation("").Obj(),
wantErr: field.ErrorList{
field.Forbidden(originalNodeSelectorsKeyPath, "this annotation is immutable while the job is not changing its suspended state"),
}.ToAggregate(),
},
"immutable original node selectors while suspended": {
oldJob: testingutil.MakeMPIJob("job", "default").Suspend(true).OriginalNodeSelectorsAnnotation("[{\"l1\":\"v1\"}]").Obj(),
oldJob: testingutil.MakeMPIJob("job", "default").Suspend(true).OriginalNodeSelectorsAnnotation(validPodSelectors).Obj(),
newJob: testingutil.MakeMPIJob("job", "default").Suspend(true).OriginalNodeSelectorsAnnotation("").Obj(),
wantErr: field.ErrorList{
field.Forbidden(originalNodeSelectorsKeyPath, "this annotation is immutable while the job is not changing its suspended state"),
Expand Down

0 comments on commit aac2e55

Please sign in to comment.