diff --git a/api/leaderworkerset/v1/leaderworkerset_types.go b/api/leaderworkerset/v1/leaderworkerset_types.go index 418d10e3..298e0070 100644 --- a/api/leaderworkerset/v1/leaderworkerset_types.go +++ b/api/leaderworkerset/v1/leaderworkerset_types.go @@ -67,6 +67,10 @@ const ( // address the leader via the headless service. LwsLeaderAddress string = "LWS_LEADER_ADDRESS" + // Environment variable added to all containers in the LeaderWorkerSet to + // track the size of the LWS group. + LwsGroupSize string = "LWS_GROUP_SIZE" + // Subgroup index tracks which subgroup the pod is part of. It will be added // as a label to the pod only if LeaderWorkerSet.Spec.SubGroupSize is set. SubGroupIndexLabelKey string = "leaderworkerset.sigs.k8s.io/subgroup-index" diff --git a/pkg/utils/pod/pod_utils.go b/pkg/utils/pod/pod_utils.go index 77540e42..e8d5eafd 100644 --- a/pkg/utils/pod/pod_utils.go +++ b/pkg/utils/pod/pod_utils.go @@ -20,6 +20,7 @@ import ( "fmt" corev1 "k8s.io/api/core/v1" + "k8s.io/klog/v2" leaderworkerset "sigs.k8s.io/lws/api/leaderworkerset/v1" ) @@ -91,25 +92,31 @@ func getPodConditionFromList(conditions []corev1.PodCondition, conditionType cor return -1, nil } -func addEnvVarIfNotExists(c *corev1.Container, e corev1.EnvVar) { - for _, env := range c.Env { - if env.Name == e.Name { - return +func addEnvVarsIfNotExists(c *corev1.Container, e ...corev1.EnvVar) { + for _, newEnv := range e { + exists := false + for _, env := range c.Env { + if env.Name == newEnv.Name { + exists = true + break + } + } + if !exists { + c.Env = append([]corev1.EnvVar{newEnv}, c.Env...) } } - c.Env = append([]corev1.EnvVar{e}, c.Env...) } // AddLWSVariables adds LWS_LEADER_ADDRESS environment variable to every container. func AddLWSVariables(pod *corev1.Pod) error { lwsName, found := pod.Labels[leaderworkerset.SetNameLabelKey] if !found { - return fmt.Errorf("Failure constructing environment variables, no name label found for pod %v", pod.Name) + return fmt.Errorf("Failure constructing environment variables, no name label found for pod %v", klog.KObj(pod)) } groupIndex, found := pod.Labels[leaderworkerset.GroupIndexLabelKey] if !found { - return fmt.Errorf("Failure constructing environment variables, no group index label found for pod %v", pod.Name) + return fmt.Errorf("Failure constructing environment variables, no group index label found for pod %v", klog.KObj(pod)) } // The headless service name is assumed to be the same as the LWS name. @@ -119,11 +126,24 @@ func AddLWSVariables(pod *corev1.Pod) error { Value: fmt.Sprintf("%s-%s.%s.%s", lwsName, groupIndex, lwsName, pod.ObjectMeta.Namespace), } + size, found := pod.Annotations[leaderworkerset.SizeAnnotationKey] + if !found { + return fmt.Errorf("Failure constructing environment variables, no size annotation found for pod %v", klog.KObj(pod)) + } + + // The group size is assumed to be the same as the number of replicas. + sizeEnvVar := corev1.EnvVar{ + Name: leaderworkerset.LwsGroupSize, + Value: size, + } + + // The order of injection needs attention, see + // https://github.com/kubernetes-sigs/lws/pull/152. for i := range pod.Spec.Containers { - addEnvVarIfNotExists(&pod.Spec.Containers[i], leaderAddressEnvVar) + addEnvVarsIfNotExists(&pod.Spec.Containers[i], sizeEnvVar, leaderAddressEnvVar) } for i := range pod.Spec.InitContainers { - addEnvVarIfNotExists(&pod.Spec.InitContainers[i], leaderAddressEnvVar) + addEnvVarsIfNotExists(&pod.Spec.InitContainers[i], sizeEnvVar, leaderAddressEnvVar) } return nil diff --git a/pkg/utils/pod/pod_utils_test.go b/pkg/utils/pod/pod_utils_test.go index 10c96cb1..61ec278c 100644 --- a/pkg/utils/pod/pod_utils_test.go +++ b/pkg/utils/pod/pod_utils_test.go @@ -17,6 +17,7 @@ limitations under the License. package pod import ( + "strconv" "testing" "github.com/google/go-cmp/cmp" @@ -104,36 +105,43 @@ func TestAddLWSVariables(t *testing.T) { name string pod *corev1.Pod expectedLwsLeaderAddress string + expectedGroupSize int }{ { name: "Leader pod", - pod: testutils.MakePodWithLabels("test-sample", "0", "", "default"), + pod: testutils.MakePodWithLabels("test-sample", "0", "", "default", 2), expectedLwsLeaderAddress: "test-sample-0.test-sample.default", + expectedGroupSize: 2, }, { name: "Worker pod", - pod: testutils.MakePodWithLabels("test-sample", "0", "1", "default"), + pod: testutils.MakePodWithLabels("test-sample", "0", "1", "default", 2), expectedLwsLeaderAddress: "test-sample-0.test-sample.default", + expectedGroupSize: 2, }, { name: "Leader pod, group 1", - pod: testutils.MakePodWithLabels("test-sample", "1", "", "default"), + pod: testutils.MakePodWithLabels("test-sample", "1", "", "default", 2), expectedLwsLeaderAddress: "test-sample-1.test-sample.default", + expectedGroupSize: 2, }, { name: "Worker pod, group 1", - pod: testutils.MakePodWithLabels("test-sample", "1", "3", "default"), + pod: testutils.MakePodWithLabels("test-sample", "1", "3", "default", 2), expectedLwsLeaderAddress: "test-sample-1.test-sample.default", + expectedGroupSize: 2, }, { name: "Leader pod, group 1, non-default namespace", - pod: testutils.MakePodWithLabels("test-sample", "1", "3", "lws"), + pod: testutils.MakePodWithLabels("test-sample", "1", "3", "lws", 2), expectedLwsLeaderAddress: "test-sample-1.test-sample.lws", + expectedGroupSize: 2, }, { name: "Worker pod, group 1, non-default namespace", - pod: testutils.MakePodWithLabels("test-sample", "1", "3", "lws"), + pod: testutils.MakePodWithLabels("test-sample", "1", "3", "lws", 2), expectedLwsLeaderAddress: "test-sample-1.test-sample.lws", + expectedGroupSize: 2, }, } @@ -152,11 +160,14 @@ func TestAddLWSVariables(t *testing.T) { if len(container.Env) == 0 { t.Errorf("Failed to add LWS Variables to container %+v", container) } - envVar := container.Env[0] if diff := cmp.Diff(envVar.Value, tc.expectedLwsLeaderAddress); diff != "" { t.Errorf("Unexpected lws leader address %s", diff) } + envVar = container.Env[1] + if diff := cmp.Diff(envVar.Value, strconv.Itoa(tc.expectedGroupSize)); diff != "" { + t.Errorf("Unexpected lws group size %s", diff) + } } }) } diff --git a/test/integration/webhooks/pod_test.go b/test/integration/webhooks/pod_test.go index 4fe53266..3741c3d9 100644 --- a/test/integration/webhooks/pod_test.go +++ b/test/integration/webhooks/pod_test.go @@ -730,7 +730,7 @@ var _ = ginkgo.Describe("leaderworkerset pod defaulting, creation and update", f return nil }, }), - ginkgo.Entry("Leader address env var should be populated and is the first env var", &testDefaultingCase{ + ginkgo.Entry("Leader env var should be populated and leader address env should be the first env var", &testDefaultingCase{ makePod: func(ns *corev1.Namespace) corev1.Pod { return corev1.Pod{ ObjectMeta: metav1.ObjectMeta{ @@ -750,12 +750,16 @@ var _ = ginkgo.Describe("leaderworkerset pod defaulting, creation and update", f }, checkExpectedPod: func(expected corev1.Pod, got corev1.Pod) error { if !testutils.HasLWSEnvVarsPopulated(got) { - return fmt.Errorf("should expect leader address env var for pod %s", got.Name) + return fmt.Errorf("should expect lws env var for pod %s", got.Name) } expectedLeaderAddress := fmt.Sprintf("test-sample-1.test-sample.%s", expected.ObjectMeta.Namespace) if err := testutils.CheckContainerHasCorrectEnvVar(got, corev1.EnvVar{Name: leaderworkerset.LwsLeaderAddress, Value: expectedLeaderAddress}); err != nil { return err } + expectedGroupSize := fmt.Sprintf("%d", 2) + if err := testutils.CheckContainerHasCorrectEnvVar(got, corev1.EnvVar{Name: leaderworkerset.LwsGroupSize, Value: expectedGroupSize}); err != nil { + return err + } if err := testutils.IsContainerFirstEnvVarLWSLeaderAddress(got); err != nil { return err } diff --git a/test/testutils/util.go b/test/testutils/util.go index e5ef4448..aa7e4982 100644 --- a/test/testutils/util.go +++ b/test/testutils/util.go @@ -342,7 +342,7 @@ func hasAllEnvVarPopulated(pod corev1.Pod, envVars []string) bool { } func HasLWSEnvVarsPopulated(pod corev1.Pod) bool { - return hasAllEnvVarPopulated(pod, []string{leaderworkerset.LwsLeaderAddress}) + return hasAllEnvVarPopulated(pod, []string{leaderworkerset.LwsLeaderAddress, leaderworkerset.LwsGroupSize}) } func CheckContainerHasCorrectEnvVar(pod corev1.Pod, expect corev1.EnvVar) error { diff --git a/test/testutils/wrappers.go b/test/testutils/wrappers.go index 3740c8b5..d17c6314 100644 --- a/test/testutils/wrappers.go +++ b/test/testutils/wrappers.go @@ -16,6 +16,7 @@ package testutils import ( "fmt" + "strconv" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" @@ -147,7 +148,7 @@ func BuildLeaderWorkerSet(nsName string) *LeaderWorkerSetWrapper { } } -func MakePodWithLabels(setName, groupIndex, workerIndex, namespace string) *corev1.Pod { +func MakePodWithLabels(setName, groupIndex, workerIndex, namespace string, size int) *corev1.Pod { podName := fmt.Sprintf("%s-%s-%s", setName, groupIndex, workerIndex) if workerIndex == "0" { podName = fmt.Sprintf("%s-%s", setName, groupIndex) @@ -161,6 +162,9 @@ func MakePodWithLabels(setName, groupIndex, workerIndex, namespace string) *core leaderworkerset.GroupIndexLabelKey: groupIndex, leaderworkerset.SetNameLabelKey: setName, }, + Annotations: map[string]string{ + leaderworkerset.SizeAnnotationKey: strconv.Itoa(size), + }, }, } }