From 734562db58342e3e55fe00b3a774c7734212e982 Mon Sep 17 00:00:00 2001 From: drivebyer Date: Sat, 24 Jun 2023 18:15:12 +0800 Subject: [PATCH] Fix lost toleration when key duplicated --- pkg/util/merge/merge.go | 18 +++++--- pkg/util/merge/merge_test.go | 79 ++++++++++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+), 5 deletions(-) diff --git a/pkg/util/merge/merge.go b/pkg/util/merge/merge.go index c2d7e190a..59020e3dd 100644 --- a/pkg/util/merge/merge.go +++ b/pkg/util/merge/merge.go @@ -507,11 +507,15 @@ func Tolerations(defaultTolerations, overrideTolerations []corev1.Toleration) [] mergedTolerations := make([]corev1.Toleration, 0) defaultMap := createTolerationsMap(defaultTolerations) for _, v := range overrideTolerations { - defaultMap[v.Key] = v + if _, ok := defaultMap[v.Key]; ok { + defaultMap[v.Key] = append(defaultMap[v.Key], v) + } else { + defaultMap[v.Key] = []corev1.Toleration{v} + } } for _, v := range defaultMap { - mergedTolerations = append(mergedTolerations, v) + mergedTolerations = append(mergedTolerations, v...) } if len(mergedTolerations) == 0 { @@ -525,10 +529,14 @@ func Tolerations(defaultTolerations, overrideTolerations []corev1.Toleration) [] return mergedTolerations } -func createTolerationsMap(tolerations []corev1.Toleration) map[string]corev1.Toleration { - tolerationsMap := make(map[string]corev1.Toleration) +func createTolerationsMap(tolerations []corev1.Toleration) map[string][]corev1.Toleration { + tolerationsMap := make(map[string][]corev1.Toleration) for _, t := range tolerations { - tolerationsMap[t.Key] = t + if _, ok := tolerationsMap[t.Key]; ok { + tolerationsMap[t.Key] = append(tolerationsMap[t.Key], t) + } else { + tolerationsMap[t.Key] = []corev1.Toleration{t} + } } return tolerationsMap } diff --git a/pkg/util/merge/merge_test.go b/pkg/util/merge/merge_test.go index 2d35355ac..d22445c56 100644 --- a/pkg/util/merge/merge_test.go +++ b/pkg/util/merge/merge_test.go @@ -670,3 +670,82 @@ func TestMergeHostAliases(t *testing.T) { assert.Equal(t, "1.2.3.5", merged[1].IP) assert.Equal(t, []string{"abc"}, merged[1].Hostnames) } + +func TestTolerations(t *testing.T) { + type args struct { + defaultTolerations []corev1.Toleration + overrideTolerations []corev1.Toleration + } + tests := []struct { + name string + args args + want []corev1.Toleration + }{ + { + name: "override tolerations is nil", + args: args{ + defaultTolerations: []corev1.Toleration{ + { + Key: "key1", + Value: "value1", + Operator: corev1.TolerationOpEqual, + }, + { + Key: "key1", + Value: "value2", + Operator: corev1.TolerationOpExists, + }, + }, + overrideTolerations: nil, + }, + want: []corev1.Toleration{ + { + Key: "key1", + Value: "value1", + Operator: corev1.TolerationOpEqual, + }, + { + Key: "key1", + Value: "value2", + Operator: corev1.TolerationOpExists, + }, + }, + }, + + { + name: "default tolerations is nil", + args: args{ + defaultTolerations: nil, + overrideTolerations: []corev1.Toleration{ + { + Key: "key1", + Value: "value1", + Operator: corev1.TolerationOpEqual, + }, + { + Key: "key1", + Value: "value2", + Operator: corev1.TolerationOpExists, + }, + }, + }, + want: []corev1.Toleration{ + { + Key: "key1", + Value: "value1", + Operator: corev1.TolerationOpEqual, + }, + { + Key: "key1", + Value: "value2", + Operator: corev1.TolerationOpExists, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, Tolerations(tt.args.defaultTolerations, tt.args.overrideTolerations), "Tolerations(%v, %v)", tt.args.defaultTolerations, tt.args.overrideTolerations) + }) + } +}