Skip to content

Commit

Permalink
Stateful set integration
Browse files Browse the repository at this point in the history
* Unit tests for webhook update
  • Loading branch information
vladikkuzn committed Oct 22, 2024
1 parent 48870dd commit a475247
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 13 deletions.
6 changes: 3 additions & 3 deletions config/components/manager/controller_manager_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ integrations:
- "kubeflow.org/pytorchjob"
- "kubeflow.org/tfjob"
- "kubeflow.org/xgboostjob"
# - "pod" (requires enabling pod integration)
# - "deployment" (requires enabling pod integration)
# - "statefulset" (requires enabling pod integration)
# - "pod" # requires enabling pod integration
# - "deployment" # requires enabling pod integration
# - "statefulset" # requires enabling pod integration
# externalFrameworks:
# - "Foo.v1.example.com"
# podOptions:
Expand Down
2 changes: 1 addition & 1 deletion pkg/controller/jobframework/integrationmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ type IntegrationCallbacks struct {
CanSupportIntegration func(opts ...Option) (bool, error)
// The job's MultiKueue adapter (optional)
MultiKueueAdapter MultiKueueAdapter
// The list of integrations that needs to be enabled along with the current one.
// The list of integration that need to be enabled along with the current one.
DependencyList []string
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/controller/jobs/deployment/deployment_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ var (
deploymentLabelsPath = field.NewPath("metadata", "labels")
deploymentQueueNameLabelPath = deploymentLabelsPath.Key(constants.QueueLabel)

podSpecLabelsPath = field.NewPath("spec", "template", "metadata", "labels")
podSpecQueueNameLabelPath = podSpecLabelsPath.Key(constants.QueueLabel)
podSpecQueueNameLabelPath = field.NewPath("spec", "template", "metadata", "labels").
Key(constants.QueueLabel)
)

func (wh *Webhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Object) (warnings admission.Warnings, err error) {
Expand Down
12 changes: 7 additions & 5 deletions pkg/controller/jobs/statefulset/statefulset_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,11 @@ func (wh *Webhook) ValidateCreate(context.Context, runtime.Object) (warnings adm
var (
statefulsetLabelsPath = field.NewPath("metadata", "labels")
statefulsetQueueNameLabelPath = statefulsetLabelsPath.Key(constants.QueueLabel)
statefulsetReplicasPath = field.NewPath("spec", "replicas")
statefulsetGroupNameLabelPath = statefulsetLabelsPath.Key(pod.GroupNameLabel)

podSpecLabelsPath = field.NewPath("spec", "template", "metadata", "labels")
podSpecQueueNameLabelPath = podSpecLabelsPath.Key(constants.QueueLabel)
podSpecQueueNameLabelPath = field.NewPath("spec", "template", "metadata", "labels").
Key(constants.QueueLabel)
)

func (wh *Webhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Object) (warnings admission.Warnings, err error) {
Expand All @@ -124,14 +126,14 @@ func (wh *Webhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Ob
allErrs = append(allErrs, apivalidation.ValidateImmutableField(
newStatefulSet.GetLabels()[pod.GroupNameLabel],
oldStatefulSet.GetLabels()[pod.GroupNameLabel],
statefulsetLabelsPath.Key(pod.GroupNameLabel),
statefulsetGroupNameLabelPath,
)...)

// Temporarily restrict to update replicas
// TODO(#3279): support resizes later
allErrs = append(allErrs, apivalidation.ValidateImmutableField(
newStatefulSet.Spec.Replicas,
oldStatefulSet.Spec.Replicas,
field.NewPath("spec", "replicas"),
statefulsetReplicasPath,
)...)

return warnings, allErrs.ToAggregate()
Expand Down
163 changes: 162 additions & 1 deletion pkg/controller/jobs/statefulset/statefulset_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,20 @@ limitations under the License.
package statefulset

import (
"context"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
appsv1 "k8s.io/api/apps/v1"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/validation/field"
"k8s.io/utils/ptr"

"sigs.k8s.io/kueue/pkg/controller/constants"
"sigs.k8s.io/kueue/pkg/controller/jobframework"
"sigs.k8s.io/kueue/pkg/controller/jobs/pod"
utiltesting "sigs.k8s.io/kueue/pkg/util/testing"
testingstatefulset "sigs.k8s.io/kueue/pkg/util/testingjobs/statefulset"
)
Expand All @@ -34,7 +42,7 @@ func TestDefault(t *testing.T) {
enableIntegrations []string
want *appsv1.StatefulSet
}{
"pod with queue": {
"statefulset with queue": {
enableIntegrations: []string{"pod"},
statefulset: testingstatefulset.MakeStatefulSet("test-pod", "").
Replicas(10).
Expand Down Expand Up @@ -86,3 +94,156 @@ func TestDefault(t *testing.T) {
})
}
}

func TestValidateUpdate(t *testing.T) {
testCases := map[string]struct {
oldObj *appsv1.StatefulSet
newObj *appsv1.StatefulSet
wantErr error
}{
"no changes": {
oldObj: &appsv1.StatefulSet{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
constants.QueueLabel: "queue1",
pod.GroupNameLabel: "group1",
},
},
Spec: appsv1.StatefulSetSpec{
Replicas: ptr.To(int32(3)),
Template: corev1.PodTemplateSpec{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
constants.QueueLabel: "queue1",
},
},
},
},
},
newObj: &appsv1.StatefulSet{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
constants.QueueLabel: "queue1",
pod.GroupNameLabel: "group1",
},
},
Spec: appsv1.StatefulSetSpec{
Replicas: ptr.To(int32(3)),
Template: corev1.PodTemplateSpec{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
constants.QueueLabel: "queue1",
},
},
},
},
},
wantErr: nil,
},
"change in queue label": {
oldObj: &appsv1.StatefulSet{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
constants.QueueLabel: "queue1",
},
},
},
newObj: &appsv1.StatefulSet{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
constants.QueueLabel: "queue2",
},
},
},
wantErr: field.ErrorList{
&field.Error{
Type: field.ErrorTypeInvalid,
Field: statefulsetQueueNameLabelPath.String(),
},
}.ToAggregate(),
},
"change in pod template queue label": {
oldObj: &appsv1.StatefulSet{
Spec: appsv1.StatefulSetSpec{
Template: corev1.PodTemplateSpec{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
constants.QueueLabel: "queue1",
},
},
},
},
},
newObj: &appsv1.StatefulSet{
Spec: appsv1.StatefulSetSpec{
Template: corev1.PodTemplateSpec{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
constants.QueueLabel: "queue2",
},
},
},
},
},
wantErr: field.ErrorList{
&field.Error{
Type: field.ErrorTypeInvalid,
Field: podSpecQueueNameLabelPath.String(),
},
}.ToAggregate(),
},
"change in group name label": {
oldObj: &appsv1.StatefulSet{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
pod.GroupNameLabel: "group1",
},
},
},
newObj: &appsv1.StatefulSet{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
pod.GroupNameLabel: "group2",
},
},
},
wantErr: field.ErrorList{
&field.Error{
Type: field.ErrorTypeInvalid,
Field: statefulsetGroupNameLabelPath.String(),
},
}.ToAggregate(),
},
"change in replicas": {
oldObj: &appsv1.StatefulSet{
Spec: appsv1.StatefulSetSpec{
Replicas: ptr.To(int32(3)),
},
},
newObj: &appsv1.StatefulSet{
Spec: appsv1.StatefulSetSpec{
Replicas: ptr.To(int32(4)),
},
},
wantErr: field.ErrorList{
&field.Error{
Type: field.ErrorTypeInvalid,
Field: statefulsetReplicasPath.String(),
},
}.ToAggregate(),
},
}

for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
ctx := context.Background()

wh := &Webhook{}

_, err := wh.ValidateUpdate(ctx, tc.oldObj, tc.newObj)
if diff := cmp.Diff(tc.wantErr, err, cmpopts.IgnoreFields(field.Error{}, "BadValue", "Detail")); diff != "" {
t.Errorf("Unexpected error (-want,+got):\n%s", diff)
}
})
}
}
2 changes: 1 addition & 1 deletion test/e2e/singlecluster/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,5 @@ var _ = ginkgo.BeforeSuite(func() {
waitForAvailableStart := time.Now()
util.WaitForKueueAvailability(ctx, k8sClient)
util.WaitForJobSetAvailability(ctx, k8sClient)
ginkgo.GinkgoLogr.Info("Kueue and JobSet operators are available in the cluster", "waitingTime", time.Since(waitForAvailableStart))
ginkgo.GinkgoLogr.Info("Kueue and JobSet oprators are available in the cluster", "waitingTime", time.Since(waitForAvailableStart))
})

0 comments on commit a475247

Please sign in to comment.