Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Propagate Job pod template updates to suspended jobs when resuming #590

Merged
merged 1 commit into from
Jun 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 31 additions & 8 deletions pkg/controllers/jobset_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -394,10 +394,10 @@ func (r *JobSetReconciler) suspendJobs(ctx context.Context, js *jobset.JobSet, a
// resumeJobsIfNecessary iterates through each replicatedJob, resuming any suspended jobs if the JobSet
// is not suspended.
func (r *JobSetReconciler) resumeJobsIfNecessary(ctx context.Context, js *jobset.JobSet, activeJobs []*batchv1.Job, replicatedJobStatuses []jobset.ReplicatedJobStatus, updateStatusOpts *statusUpdateOpts) error {
// Store node selector for each replicatedJob template.
nodeAffinities := map[string]map[string]string{}
// Store pod template for each replicatedJob.
replicatedJobTemplateMap := map[string]corev1.PodTemplateSpec{}
for _, replicatedJob := range js.Spec.ReplicatedJobs {
nodeAffinities[replicatedJob.Name] = replicatedJob.Template.Spec.Template.Spec.NodeSelector
replicatedJobTemplateMap[replicatedJob.Name] = replicatedJob.Template.Spec.Template
}

// Map each replicatedJob to a list of its active jobs.
Expand All @@ -421,7 +421,7 @@ func (r *JobSetReconciler) resumeJobsIfNecessary(ctx context.Context, js *jobset
if !jobSuspended(job) {
continue
}
if err := r.resumeJob(ctx, job, nodeAffinities); err != nil {
if err := r.resumeJob(ctx, job, replicatedJobTemplateMap); err != nil {
return err
}
}
Expand All @@ -439,7 +439,7 @@ func (r *JobSetReconciler) resumeJobsIfNecessary(ctx context.Context, js *jobset
return nil
}

func (r *JobSetReconciler) resumeJob(ctx context.Context, job *batchv1.Job, nodeAffinities map[string]map[string]string) error {
func (r *JobSetReconciler) resumeJob(ctx context.Context, job *batchv1.Job, replicatedJobTemplateMap map[string]corev1.PodTemplateSpec) error {
log := ctrl.LoggerFrom(ctx)
// Kubernetes validates that a job template is immutable,
// so if the job has started (i.e., startTime != nil), we must set it to nil first.
Expand All @@ -449,10 +449,33 @@ func (r *JobSetReconciler) resumeJob(ctx context.Context, job *batchv1.Job, node
return err
}
}

// Get name of parent replicated job and use it to look up the pod template.
replicatedJobName := job.Labels[jobset.ReplicatedJobNameKey]
replicatedJobPodTemplate := replicatedJobTemplateMap[replicatedJobName]
if job.Labels != nil && job.Labels[jobset.ReplicatedJobNameKey] != "" {
// When resuming a job, its nodeSelectors should match that of the replicatedJob template
// that it was created from, which may have been updated while it was suspended.
job.Spec.Template.Spec.NodeSelector = nodeAffinities[job.Labels[jobset.ReplicatedJobNameKey]]
// Certain fields on the Job pod template may be mutated while a JobSet is suspended,
// for integration with Kueue. Ensure these updates are propagated to the child Jobs
// when the JobSet is resumed.
// Merge values rather than overwriting them, since a different controller
// (e.g., the Job controller) may have added labels/annotations/etc to the
// Job that do not exist in the ReplicatedJob pod template.
job.Spec.Template.Labels = collections.MergeMaps(
job.Spec.Template.Labels,
replicatedJobPodTemplate.Labels,
)
job.Spec.Template.Annotations = collections.MergeMaps(
job.Spec.Template.Annotations,
replicatedJobPodTemplate.Annotations,
)
job.Spec.Template.Spec.NodeSelector = collections.MergeMaps(
job.Spec.Template.Spec.NodeSelector,
replicatedJobPodTemplate.Spec.NodeSelector,
)
job.Spec.Template.Spec.Tolerations = collections.MergeSlices(
job.Spec.Template.Spec.Tolerations,
replicatedJobPodTemplate.Spec.Tolerations,
)
} else {
log.Error(nil, "job missing ReplicatedJobName label")
}
Expand Down
38 changes: 38 additions & 0 deletions pkg/util/collections/collections.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,41 @@ func IndexOf[T comparable](slice []T, item T) int {
}
return -1
}

// MergeMaps returns a newly allocated map holding every key-value pair
// from `old` and `new`. When the same key is present in both inputs, the
// value from `new` wins. Neither input map is modified, and the result is
// never nil, even when both inputs are nil or empty.
func MergeMaps[K comparable, V any](old, new map[K]V) map[K]V {
	result := make(map[K]V, len(old)+len(new))
	// Copy `old` first and `new` second so that entries from `new`
	// replace any duplicates coming from `old`.
	for _, src := range []map[K]V{old, new} {
		for key, val := range src {
			result[key] = val
		}
	}
	return result
}

// MergeSlices returns the union of s1 and s2 with duplicate elements
// removed. Elements appear in the order of their first occurrence: the
// distinct elements of s1 in their original order, followed by those
// elements of s2 that were not already present.
//
// The ordering is deterministic so that callers writing the result back
// to an API object do not see the slice reshuffled between calls.
// Neither input slice is modified, and the result is always a non-nil
// slice, even when both inputs are nil or empty.
func MergeSlices[T comparable](s1, s2 []T) []T {
	total := len(s1) + len(s2)
	seen := make(map[T]struct{}, total)
	merged := make([]T, 0, total)

	for _, src := range [][]T{s1, s2} {
		for _, item := range src {
			if _, dup := seen[item]; dup {
				continue
			}
			seen[item] = struct{}{}
			// Append at first sight so the output order follows the
			// inputs rather than map iteration order.
			merged = append(merged, item)
		}
	}

	return merged
}
88 changes: 88 additions & 0 deletions pkg/util/collections/collections_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"testing"

"github.com/google/go-cmp/cmp"
"golang.org/x/exp/slices"
)

func TestConcat(t *testing.T) {
Expand Down Expand Up @@ -151,3 +152,90 @@ func TestContains(t *testing.T) {
})
}
}

// TestMergeMaps exercises MergeMaps with disjoint, overlapping and empty
// inputs and compares the merged map against the expected contents.
func TestMergeMaps(t *testing.T) {
	cases := []struct {
		name     string
		base     map[string]int
		override map[string]int
		want     map[string]int
	}{
		{
			name:     "Basic merge",
			base:     map[string]int{"a": 1, "b": 2},
			override: map[string]int{"c": 3, "d": 4},
			want:     map[string]int{"a": 1, "b": 2, "c": 3, "d": 4},
		},
		{
			name:     "Overlapping keys",
			base:     map[string]int{"a": 1, "b": 2},
			override: map[string]int{"b": 3, "c": 4},
			// The value for 'b' comes from the second map.
			want: map[string]int{"a": 1, "b": 3, "c": 4},
		},
		{
			name:     "Empty maps",
			base:     map[string]int{},
			override: map[string]int{},
			want:     map[string]int{},
		},
		{
			name:     "One empty map",
			base:     map[string]int{"a": 1, "b": 2},
			override: map[string]int{},
			want:     map[string]int{"a": 1, "b": 2},
		},
	}

	for i := range cases {
		c := cases[i]
		t.Run(c.name, func(t *testing.T) {
			got := MergeMaps(c.base, c.override)
			if reflect.DeepEqual(got, c.want) {
				return
			}
			t.Errorf("expected %v, got %v", c.want, got)
		})
	}
}

// TestMergeSlices exercises MergeSlices with overlapping and empty inputs.
// Both the result and the expectation are sorted before comparison, so
// the test checks only which elements are present, not their order.
func TestMergeSlices(t *testing.T) {
	cases := []struct {
		name  string
		left  []int
		right []int
		want  []int
	}{
		{
			name:  "merge with overlapping elements should not result in duplicates",
			left:  []int{1, 2, 3},
			right: []int{3, 4, 5},
			want:  []int{1, 2, 3, 4, 5},
		},
		{
			name:  "empty slices",
			left:  []int{},
			right: []int{},
			want:  []int{},
		},
		{
			name:  "one empty slice",
			left:  []int{1, 2},
			right: []int{},
			want:  []int{1, 2},
		},
	}

	for i := range cases {
		c := cases[i]
		t.Run(c.name, func(t *testing.T) {
			got := MergeSlices(c.left, c.right)

			// Normalize ordering on both sides so that slices holding
			// the same elements compare equal.
			slices.Sort(got)
			slices.Sort(c.want)

			if reflect.DeepEqual(got, c.want) {
				return
			}
			t.Errorf("Expected %v, got %v", c.want, got)
		})
	}
}
80 changes: 63 additions & 17 deletions test/integration/controller/jobset_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,16 @@ var _ = ginkgo.Describe("JobSet controller", func() {
updates []*update
}

nodeSelectors := map[string]map[string]string{
"replicated-job-a": {"node-selector-test-a": "node-selector-test-a"},
"replicated-job-b": {"node-selector-test-b": "node-selector-test-b"},
var podTemplateUpdates = &updatePodTemplateOpts{
labels: map[string]string{"label": "value"},
annotations: map[string]string{"annotation": "value"},
nodeSelector: map[string]string{"node-selector-test-a": "node-selector-test-a"},
tolerations: []corev1.Toleration{
{
Key: "key",
Operator: corev1.TolerationOpExists,
},
},
}

ginkgo.DescribeTable("jobset is created and its jobs go through a series of updates",
Expand Down Expand Up @@ -917,7 +924,7 @@ var _ = ginkgo.Describe("JobSet controller", func() {
},
{
jobSetUpdateFn: func(js *jobset.JobSet) {
updateJobSetNodeSelectors(js, nodeSelectors)
updatePodTemplates(js, podTemplateUpdates)
},
checkJobSetState: func(js *jobset.JobSet) {
ginkgo.By("Check ReplicatedJobStatus for suspend")
Expand Down Expand Up @@ -945,7 +952,7 @@ var _ = ginkgo.Describe("JobSet controller", func() {
{
checkJobSetState: func(js *jobset.JobSet) {
ginkgo.By("checking jobs have expected node selectors")
gomega.Eventually(matchJobsNodeSelectors, timeout, interval).WithArguments(js, nodeSelectors).Should(gomega.Equal(true))
gomega.Eventually(checkPodTemplateUpdates, timeout, interval).WithArguments(js, podTemplateUpdates).Should(gomega.Equal(true))
},
jobUpdateFn: completeAllJobs,
checkJobSetCondition: testutil.JobSetCompleted,
Expand Down Expand Up @@ -1905,15 +1912,35 @@ func suspendJobSet(js *jobset.JobSet, suspend bool) {
}, timeout, interval).Should(gomega.Succeed())
}

func updateJobSetNodeSelectors(js *jobset.JobSet, nodeSelectors map[string]map[string]string) {
// updatePodTemplateOpts contains pod template values
// which can be mutated on a ReplicatedJob template
// while a JobSet is suspended.
type updatePodTemplateOpts struct {
// labels is the label set assigned to the pod template metadata.
labels map[string]string
// annotations is the annotation set assigned to the pod template metadata.
annotations map[string]string
// nodeSelector is the node selector assigned to the pod template spec.
nodeSelector map[string]string
// tolerations is the toleration list assigned to the pod template spec.
tolerations []corev1.Toleration
}

func updatePodTemplates(js *jobset.JobSet, opts *updatePodTemplateOpts) {
gomega.Eventually(func() error {
var jsGet jobset.JobSet
if err := k8sClient.Get(ctx, types.NamespacedName{Name: js.Name, Namespace: js.Namespace}, &jsGet); err != nil {
return err
}
for index := range jsGet.Spec.ReplicatedJobs {
jsGet.Spec.ReplicatedJobs[index].
Template.Spec.Template.Spec.NodeSelector = nodeSelectors[jsGet.Spec.ReplicatedJobs[index].Name]
podTemplate := &jsGet.Spec.ReplicatedJobs[index].Template.Spec.Template
// Update labels.
podTemplate.Labels = opts.labels

// Update annotations.
podTemplate.Annotations = opts.annotations

// Update node selector.
podTemplate.Spec.NodeSelector = opts.nodeSelector

// Update tolerations.
podTemplate.Spec.Tolerations = opts.tolerations
}
return k8sClient.Update(ctx, &jsGet)
}, timeout, interval).Should(gomega.Succeed())
Expand All @@ -1937,29 +1964,48 @@ func matchJobsSuspendState(js *jobset.JobSet, suspend bool) (bool, error) {
return true, nil
}

func matchJobsNodeSelectors(js *jobset.JobSet, nodeSelectors map[string]map[string]string) (bool, error) {
func checkPodTemplateUpdates(js *jobset.JobSet, podTemplateUpdates *updatePodTemplateOpts) (bool, error) {
var jobList batchv1.JobList
if err := k8sClient.List(ctx, &jobList, client.InNamespace(js.Namespace)); err != nil {
return false, err
}
// Count number of updated jobs
jobsUpdated := 0
for _, job := range jobList.Items {
rjobName, ok := job.Labels[jobset.ReplicatedJobNameKey]
if !ok {
return false, fmt.Errorf(fmt.Sprintf("%s job missing ReplicatedJobName label", job.Name))
// Check label was added.
for label, value := range podTemplateUpdates.labels {
if job.Spec.Template.Labels[label] != value {
return false, fmt.Errorf("%s != %s", job.Spec.Template.Labels[label], value)
}
}
if !apiequality.Semantic.DeepEqual(job.Spec.Template.Spec.NodeSelector, nodeSelectors[rjobName]) {
return false, nil

// Check annotation was added.
for annotation, value := range podTemplateUpdates.annotations {
	// Report the value read from the Annotations map; looking the key up
	// in Labels would print an unrelated (usually empty) value.
	if got := job.Spec.Template.Annotations[annotation]; got != value {
		return false, fmt.Errorf("%s != %s", got, value)
	}
}

// Check nodeSelector was updated.
for label, value := range podTemplateUpdates.nodeSelector {
if job.Spec.Template.Spec.NodeSelector[label] != value {
return false, fmt.Errorf("%s != %s", job.Spec.Template.Spec.NodeSelector[label], value)
}
}

// Check tolerations were updated.
for _, toleration := range podTemplateUpdates.tolerations {
if !collections.Contains(job.Spec.Template.Spec.Tolerations, toleration) {
return false, fmt.Errorf("missing toleration %v", toleration)
}
}

jobsUpdated++
}
// Calculate expected number of updated jobs
wantJobsUpdated := 0
for _, rjob := range js.Spec.ReplicatedJobs {
if _, exists := nodeSelectors[rjob.Name]; exists {
wantJobsUpdated += int(rjob.Replicas)
}
wantJobsUpdated += int(rjob.Replicas)
}
return wantJobsUpdated == jobsUpdated, nil
}
Expand Down