diff --git a/manifests/base/webhook/manifests.yaml b/manifests/base/webhook/manifests.yaml index 45dc562128..03103083c7 100644 --- a/manifests/base/webhook/manifests.yaml +++ b/manifests/base/webhook/manifests.yaml @@ -24,3 +24,23 @@ webhooks: resources: - pytorchjobs sideEffects: None +- admissionReviewVersions: + - v1 + clientConfig: + service: + name: webhook-service + namespace: system + path: /validate-kubeflow-org-v1-xgboostjob + failurePolicy: Fail + name: vxgboostjob.kb.io + rules: + - apiGroups: + - kubeflow.org + apiVersions: + - v1 + operations: + - CREATE + - UPDATE + resources: + - xgboostjobs + sideEffects: None diff --git a/pkg/apis/kubeflow.org/v1/xgboost_validation.go b/pkg/apis/kubeflow.org/v1/xgboost_validation.go deleted file mode 100644 index b862c52c85..0000000000 --- a/pkg/apis/kubeflow.org/v1/xgboost_validation.go +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright 2021 The Kubeflow Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package v1 - -import ( - "fmt" - - apimachineryvalidation "k8s.io/apimachinery/pkg/api/validation" -) - -func ValidateV1XGBoostJob(xgboostJob *XGBoostJob) error { - if errors := apimachineryvalidation.NameIsDNS1035Label(xgboostJob.ObjectMeta.Name, false); errors != nil { - return fmt.Errorf("XGBoostJob name is invalid: %v", errors) - } - if err := validateXGBoostReplicaSpecs(xgboostJob.Spec.XGBReplicaSpecs); err != nil { - return err - } - return nil -} - -func validateXGBoostReplicaSpecs(specs map[ReplicaType]*ReplicaSpec) error { - if specs == nil { - return fmt.Errorf("XGBoostJobSpec is not valid") - } - masterExists := false - for rType, value := range specs { - if value == nil || len(value.Template.Spec.Containers) == 0 { - return fmt.Errorf("XGBoostJobSpec is not valid: containers definition expected in %v", rType) - } - // Make sure the replica type is valid. - validReplicaTypes := []ReplicaType{XGBoostJobReplicaTypeMaster, XGBoostJobReplicaTypeWorker} - - isValidReplicaType := false - for _, t := range validReplicaTypes { - if t == rType { - isValidReplicaType = true - break - } - } - - if !isValidReplicaType { - return fmt.Errorf("XGBoostReplicaType is %v but must be one of %v", rType, validReplicaTypes) - } - - //Make sure the image is defined in the container - defaultContainerPresent := false - for _, container := range value.Template.Spec.Containers { - if container.Image == "" { - msg := fmt.Sprintf("XGBoostReplicaType is not valid: Image is undefined in the container of %v", rType) - return fmt.Errorf(msg) - } - if container.Name == XGBoostJobDefaultContainerName { - defaultContainerPresent = true - } - } - //Make sure there has at least one container named "xgboost" - if !defaultContainerPresent { - msg := fmt.Sprintf("XGBoostReplicaType is not valid: There is no container named %s in %v", XGBoostJobDefaultContainerName, rType) - return fmt.Errorf(msg) - } - if rType == XGBoostJobReplicaTypeMaster { - masterExists = true - if value.Replicas != nil && int(*value.Replicas) != 1 { - return fmt.Errorf("XGBoostReplicaType is not valid: There must be only 1 master replica") - } - } - - } - - if !masterExists { - return fmt.Errorf("XGBoostReplicaType is not valid: Master ReplicaSpec must be present") - } - return nil - -} diff --git a/pkg/apis/kubeflow.org/v1/xgboost_validation_test.go b/pkg/apis/kubeflow.org/v1/xgboost_validation_test.go deleted file mode 100644 index 7447b3c4da..0000000000 --- a/pkg/apis/kubeflow.org/v1/xgboost_validation_test.go +++ /dev/null @@ -1,222 +0,0 @@ -// Copyright 2021 The Kubeflow Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package v1 - -import ( - "testing" - - corev1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/utils/ptr" -) - -func TestValidateV1XGBoostJob(t *testing.T) { - validXGBoostReplicaSpecs := map[ReplicaType]*ReplicaSpec{ - XGBoostJobReplicaTypeMaster: { - Replicas: ptr.To[int32](1), - RestartPolicy: RestartPolicyNever, - Template: corev1.PodTemplateSpec{ - Spec: corev1.PodSpec{ - Containers: []corev1.Container{{ - Name: "xgboost", - Image: "docker.io/kubeflow/xgboost-dist-iris:latest", - Ports: []corev1.ContainerPort{{ - Name: "xgboostjob-port", - ContainerPort: 9991, - }}, - ImagePullPolicy: corev1.PullAlways, - Args: []string{ - "--job_type=Train", - "--xgboost_parameter=objective:multi:softprob,num_class:3", - "--n_estimators=10", - "--learning_rate=0.1", - "--model_path=/tmp/xgboost-model", - "--model_storage_type=local", - }, - }}, - }, - }, - }, - XGBoostJobReplicaTypeWorker: { - Replicas: ptr.To[int32](2), - RestartPolicy: RestartPolicyExitCode, - Template: corev1.PodTemplateSpec{ - Spec: corev1.PodSpec{ - Containers: []corev1.Container{{ - Name: "xgboost", - Image: "docker.io/kubeflow/xgboost-dist-iris:latest", - Ports: []corev1.ContainerPort{{ - Name: "xgboostjob-port", - ContainerPort: 9991, - }}, - ImagePullPolicy: corev1.PullAlways, - Args: []string{ - "--job_type=Train", - "--xgboost_parameter=objective:multi:softprob,num_class:3", - "--n_estimators=10", - "--learning_rate=0.1", - }, - }}, - }, - }, - }, - } - - testCases := map[string]struct { - xgboostJob *XGBoostJob - wantErr bool - }{ - "valid XGBoostJob": { - xgboostJob: &XGBoostJob{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - }, - Spec: XGBoostJobSpec{ - XGBReplicaSpecs: validXGBoostReplicaSpecs, - }, - }, - wantErr: false, - }, - "XGBoostJob name does not meet DNS1035": { - xgboostJob: &XGBoostJob{ - ObjectMeta: metav1.ObjectMeta{ - Name: "-test", - }, - Spec: XGBoostJobSpec{ - XGBReplicaSpecs: validXGBoostReplicaSpecs, - }, - }, - wantErr: true, - }, - "empty containers": { - xgboostJob: &XGBoostJob{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - }, - Spec: XGBoostJobSpec{ - XGBReplicaSpecs: map[ReplicaType]*ReplicaSpec{ - XGBoostJobReplicaTypeWorker: { - Template: corev1.PodTemplateSpec{ - Spec: corev1.PodSpec{ - Containers: []corev1.Container{}, - }, - }, - }, - }, - }, - }, - wantErr: true, - }, - "image is empty": { - xgboostJob: &XGBoostJob{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - }, - Spec: XGBoostJobSpec{ - XGBReplicaSpecs: map[ReplicaType]*ReplicaSpec{ - XGBoostJobReplicaTypeWorker: { - Template: corev1.PodTemplateSpec{ - Spec: corev1.PodSpec{ - Containers: []corev1.Container{{ - Name: "xgboost", - Image: "", - }}, - }, - }, - }, - }, - }, - }, - wantErr: true, - }, - "xgboostJob default container name doesn't present": { - xgboostJob: &XGBoostJob{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - }, - Spec: XGBoostJobSpec{ - XGBReplicaSpecs: map[ReplicaType]*ReplicaSpec{ - XGBoostJobReplicaTypeWorker: { - Template: corev1.PodTemplateSpec{ - Spec: corev1.PodSpec{ - Containers: []corev1.Container{{ - Name: "", - Image: "gcr.io/kubeflow-ci/xgboost-dist-mnist_test:1.0", - }}, - }, - }, - }, - }, - }, - }, - wantErr: true, - }, - "the number of replicas in masterReplica is other than 1": { - xgboostJob: &XGBoostJob{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - }, - Spec: XGBoostJobSpec{ - XGBReplicaSpecs: map[ReplicaType]*ReplicaSpec{ - XGBoostJobReplicaTypeMaster: { - Replicas: ptr.To[int32](2), - Template: corev1.PodTemplateSpec{ - Spec: corev1.PodSpec{ - Containers: []corev1.Container{{ - Name: "xgboost", - Image: "gcr.io/kubeflow-ci/xgboost-dist-mnist_test:1.0", - }}, - }, - }, - }, - }, - }, - }, - wantErr: true, - }, - "masterReplica does not exist": { - xgboostJob: &XGBoostJob{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - }, - Spec: XGBoostJobSpec{ - XGBReplicaSpecs: map[ReplicaType]*ReplicaSpec{ - XGBoostJobReplicaTypeWorker: { - Replicas: ptr.To[int32](1), - Template: corev1.PodTemplateSpec{ - Spec: corev1.PodSpec{ - Containers: []corev1.Container{{ - Name: "xgboost", - Image: "gcr.io/kubeflow-ci/xgboost-dist-mnist_test:1.0", - }}, - }, - }, - }, - }, - }, - }, - wantErr: true, - }, - } - - for name, tc := range testCases { - t.Run(name, func(t *testing.T) { - got := ValidateV1XGBoostJob(tc.xgboostJob) - if (got != nil) != tc.wantErr { - t.Fatalf("ValidateV1XGBoostJob() error = %v, wantErr %v", got, tc.wantErr) - } - }) - } -} diff --git a/pkg/controller.v1/xgboost/suite_test.go b/pkg/controller.v1/xgboost/suite_test.go index 42f57d2b09..a5a0614a37 100644 --- a/pkg/controller.v1/xgboost/suite_test.go +++ b/pkg/controller.v1/xgboost/suite_test.go @@ -16,7 +16,11 @@ package xgboost import ( "context" + "crypto/tls" + "fmt" + "net" "testing" + "time" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -28,10 +32,12 @@ import ( logf "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/log/zap" metricsserver "sigs.k8s.io/controller-runtime/pkg/metrics/server" + "sigs.k8s.io/controller-runtime/pkg/webhook" "volcano.sh/apis/pkg/apis/scheduling/v1beta1" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" "github.com/kubeflow/training-operator/pkg/controller.v1/common" + xgboostwebhook "github.com/kubeflow/training-operator/pkg/webhooks/xgboost" //+kubebuilder:scaffold:imports ) @@ -60,6 +66,9 @@ var _ = BeforeSuite(func() { testEnv = &envtest.Environment{ CRDDirectoryPaths: []string{filepath.Join("..", "..", "..", "manifests", "base", "crds")}, ErrorIfCRDPathMissing: true, + WebhookInstallOptions: envtest.WebhookInstallOptions{ + Paths: []string{filepath.Join("..", "..", "..", "manifests", "base", "webhook")}, + }, } cfg, err := testEnv.Start() @@ -81,6 +90,12 @@ var _ = BeforeSuite(func() { Metrics: metricsserver.Options{ BindAddress: "0", }, + WebhookServer: webhook.NewServer( + webhook.Options{ + Host: testEnv.WebhookInstallOptions.LocalServingHost, + Port: testEnv.WebhookInstallOptions.LocalServingPort, + CertDir: testEnv.WebhookInstallOptions.LocalServingCertDir, + }), }) Expect(err).NotTo(HaveOccurred()) @@ -88,12 +103,21 @@ var _ = BeforeSuite(func() { r := NewReconciler(mgr, gangSchedulingSetupFunc) Expect(r.SetupWithManager(mgr, 1)).NotTo(HaveOccurred()) + Expect(xgboostwebhook.SetupWebhook(mgr)).NotTo(HaveOccurred()) go func() { defer GinkgoRecover() err = mgr.Start(testCtx) Expect(err).ToNot(HaveOccurred(), "failed to run manager") }() + + dialer := &net.Dialer{Timeout: time.Second} + addrPort := fmt.Sprintf("%s:%d", testEnv.WebhookInstallOptions.LocalServingHost, testEnv.WebhookInstallOptions.LocalServingPort) + Eventually(func(g Gomega) { + conn, err := tls.DialWithDialer(dialer, "tcp", addrPort, &tls.Config{InsecureSkipVerify: true}) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(conn.Close()).NotTo(HaveOccurred()) + }).Should(Succeed()) }) var _ = AfterSuite(func() { diff --git a/pkg/controller.v1/xgboost/xgboostjob_controller.go b/pkg/controller.v1/xgboost/xgboostjob_controller.go index f8e2018311..fb860e462f 100644 --- a/pkg/controller.v1/xgboost/xgboostjob_controller.go +++ b/pkg/controller.v1/xgboost/xgboostjob_controller.go @@ -130,13 +130,6 @@ func (r *XGBoostJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) return ctrl.Result{}, client.IgnoreNotFound(err) } - if err = kubeflowv1.ValidateV1XGBoostJob(xgboostjob); err != nil { - logger.Error(err, "XGBoostJob failed validation") - r.Recorder.Eventf(xgboostjob, corev1.EventTypeWarning, commonutil.NewReason(kubeflowv1.XGBoostJobKind, commonutil.JobFailedValidationReason), - "XGBoostJob failed validation because %s", err) - return ctrl.Result{}, err - } - // Check reconcile is required. jobKey, err := common.KeyFunc(xgboostjob) if err != nil { diff --git a/pkg/webhooks/webhooks.go b/pkg/webhooks/webhooks.go index 5e97a3d3f3..733c9d04e4 100644 --- a/pkg/webhooks/webhooks.go +++ b/pkg/webhooks/webhooks.go @@ -21,6 +21,7 @@ import ( trainingoperator "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" "github.com/kubeflow/training-operator/pkg/webhooks/pytorch" + "github.com/kubeflow/training-operator/pkg/webhooks/xgboost" ) type WebhookSetupFunc func(manager manager.Manager) error @@ -30,7 +31,7 @@ var ( trainingoperator.PyTorchJobKind: pytorch.SetupWebhook, trainingoperator.TFJobKind: scaffold, trainingoperator.MXJobKind: scaffold, - trainingoperator.XGBoostJobKind: scaffold, + trainingoperator.XGBoostJobKind: xgboost.SetupWebhook, trainingoperator.MPIJobKind: scaffold, trainingoperator.PaddleJobKind: scaffold, } diff --git a/pkg/webhooks/xgboost/xgboostjob_webhook.go b/pkg/webhooks/xgboost/xgboostjob_webhook.go new file mode 100644 index 0000000000..95af722fc6 --- /dev/null +++ b/pkg/webhooks/xgboost/xgboostjob_webhook.go @@ -0,0 +1,136 @@ +/* +Copyright 2024 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package xgboost + +import ( + "context" + "fmt" + "strings" + + "golang.org/x/exp/slices" + + apimachineryvalidation "k8s.io/apimachinery/pkg/api/validation" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/util/validation/field" + "k8s.io/klog/v2" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/webhook" + "sigs.k8s.io/controller-runtime/pkg/webhook/admission" + + trainingoperator "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" +) + +var ( + specPath = field.NewPath("spec") + xgbReplicaSpecPath = specPath.Child("xgbReplicaSpecs") +) + +type Webhook struct{} + +func SetupWebhook(mgr ctrl.Manager) error { + return ctrl.NewWebhookManagedBy(mgr). + For(&trainingoperator.XGBoostJob{}). + WithValidator(&Webhook{}). + Complete() +} + +// +kubebuilder:webhook:path=/validate-kubeflow-org-v1-xgboostjob,mutating=false,failurePolicy=fail,sideEffects=None,groups=kubeflow.org,resources=xgboostjobs,verbs=create;update,versions=v1,name=vxgboostjob.kb.io,admissionReviewVersions=v1 + +var _ webhook.CustomValidator = &Webhook{} + +func (w *Webhook) ValidateCreate(ctx context.Context, obj runtime.Object) (admission.Warnings, error) { + job := obj.(*trainingoperator.XGBoostJob) + log := ctrl.LoggerFrom(ctx).WithName("xgboostjob-webhook") + log.V(5).Info("Validating create", "xgboostJob", klog.KObj(job)) + return nil, validateXGBoostJob(job).ToAggregate() +} + +func (w *Webhook) ValidateUpdate(ctx context.Context, _, newObj runtime.Object) (admission.Warnings, error) { + job := newObj.(*trainingoperator.XGBoostJob) + log := ctrl.LoggerFrom(ctx).WithName("xgboostjob-webhook") + log.V(5).Info("Validating create", "xgboostJob", klog.KObj(job)) + return nil, validateXGBoostJob(job).ToAggregate() +} + +func (w *Webhook) ValidateDelete(context.Context, runtime.Object) (admission.Warnings, error) { + return nil, nil +} + +func validateXGBoostJob(job *trainingoperator.XGBoostJob) field.ErrorList { + var allErrs field.ErrorList + + if errors := apimachineryvalidation.NameIsDNS1035Label(job.Name, false); len(errors) != 0 { + allErrs = append(allErrs, field.Invalid(field.NewPath("metadata").Child("name"), job.Name, fmt.Sprintf("should match: %v", strings.Join(errors, ",")))) + } + allErrs = append(allErrs, validateSpec(job.Spec)...) + return allErrs +} + +func validateSpec(spec trainingoperator.XGBoostJobSpec) field.ErrorList { + return validateXGBReplicaSpecs(spec.XGBReplicaSpecs) +} + +func validateXGBReplicaSpecs(rSpecs map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec) field.ErrorList { + var allErrs field.ErrorList + + if rSpecs == nil { + allErrs = append(allErrs, field.Required(xgbReplicaSpecPath, "must be required")) + } + masterExists := false + for rType, rSpec := range rSpecs { + rolePath := xgbReplicaSpecPath.Key(string(rType)) + containersPath := rolePath.Child("template").Child("spec").Child("containers") + + // Make sure the replica type is valid. + validReplicaTypes := []string{ + string(trainingoperator.XGBoostJobReplicaTypeMaster), + string(trainingoperator.XGBoostJobReplicaTypeWorker), + } + if !slices.Contains(validReplicaTypes, string(rType)) { + allErrs = append(allErrs, field.NotSupported(rolePath, rType, validReplicaTypes)) + } + + if rSpec == nil || len(rSpec.Template.Spec.Containers) == 0 { + allErrs = append(allErrs, field.Required(containersPath, "must be specified")) + } + + // Make sure the image is defined in the container + defaultContainerPresent := false + for idx, container := range rSpec.Template.Spec.Containers { + if container.Image == "" { + allErrs = append(allErrs, field.Required(containersPath.Index(idx).Child("image"), "must be required")) + } + if container.Name == trainingoperator.XGBoostJobDefaultContainerName { + defaultContainerPresent = true + } + } + // Make sure there has at least one container named "xgboost" + if !defaultContainerPresent { + allErrs = append(allErrs, field.Required(containersPath, fmt.Sprintf("must have at least one container with name %s", trainingoperator.XGBoostJobDefaultContainerName))) + } + if rType == trainingoperator.XGBoostJobReplicaTypeMaster { + masterExists = true + if rSpec.Replicas == nil || int(*rSpec.Replicas) != 1 { + allErrs = append(allErrs, field.Forbidden(rolePath.Child("replicas"), "must be 1")) + } + } + } + if !masterExists { + allErrs = append(allErrs, field.Required(xgbReplicaSpecPath.Key(string(trainingoperator.XGBoostJobReplicaTypeMaster)), "must be present")) + } + return allErrs +} diff --git a/pkg/webhooks/xgboost/xgboostjob_webhook_test.go b/pkg/webhooks/xgboost/xgboostjob_webhook_test.go new file mode 100644 index 0000000000..9bd95893c6 --- /dev/null +++ b/pkg/webhooks/xgboost/xgboostjob_webhook_test.go @@ -0,0 +1,243 @@ +/* +Copyright 2024 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package xgboost + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/validation/field" + "k8s.io/utils/ptr" + + trainingoperator "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" +) + +func TestValidateXGBoostJob(t *testing.T) { + validXGBoostReplicaSpecs := map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{ + trainingoperator.XGBoostJobReplicaTypeMaster: { + Replicas: ptr.To[int32](1), + RestartPolicy: trainingoperator.RestartPolicyNever, + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{{ + Name: "xgboost", + Image: "docker.io/kubeflow/xgboost-dist-iris:latest", + Ports: []corev1.ContainerPort{{ + Name: "xgboostjob-port", + ContainerPort: 9991, + }}, + ImagePullPolicy: corev1.PullAlways, + Args: []string{ + "--job_type=Train", + "--xgboost_parameter=objective:multi:softprob,num_class:3", + "--n_estimators=10", + "--learning_rate=0.1", + "--model_path=/tmp/xgboost-model", + "--model_storage_type=local", + }, + }}, + }, + }, + }, + trainingoperator.XGBoostJobReplicaTypeWorker: { + Replicas: ptr.To[int32](2), + RestartPolicy: trainingoperator.RestartPolicyExitCode, + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{{ + Name: "xgboost", + Image: "docker.io/kubeflow/xgboost-dist-iris:latest", + Ports: []corev1.ContainerPort{{ + Name: "xgboostjob-port", + ContainerPort: 9991, + }}, + ImagePullPolicy: corev1.PullAlways, + Args: []string{ + "--job_type=Train", + "--xgboost_parameter=objective:multi:softprob,num_class:3", + "--n_estimators=10", + "--learning_rate=0.1", + }, + }}, + }, + }, + }, + } + + testCases := map[string]struct { + xgboostJob *trainingoperator.XGBoostJob + wantErr field.ErrorList + }{ + "valid XGBoostJob": { + xgboostJob: &trainingoperator.XGBoostJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + }, + Spec: trainingoperator.XGBoostJobSpec{ + XGBReplicaSpecs: validXGBoostReplicaSpecs, + }, + }, + }, + "XGBoostJob name does not meet DNS1035": { + xgboostJob: &trainingoperator.XGBoostJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "-test", + }, + Spec: trainingoperator.XGBoostJobSpec{ + XGBReplicaSpecs: validXGBoostReplicaSpecs, + }, + }, + wantErr: field.ErrorList{ + field.Invalid(field.NewPath("metadata").Child("name"), "", ""), + }, + }, + "empty containers": { + xgboostJob: &trainingoperator.XGBoostJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + }, + Spec: trainingoperator.XGBoostJobSpec{ + XGBReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{ + trainingoperator.XGBoostJobReplicaTypeMaster: { + Replicas: ptr.To[int32](1), + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{}, + }, + }, + }, + }, + }, + }, + wantErr: field.ErrorList{ + field.Required(xgbReplicaSpecPath.Key(string(trainingoperator.XGBoostJobReplicaTypeMaster)).Child("template").Child("spec").Child("containers"), ""), + field.Required(xgbReplicaSpecPath.Key(string(trainingoperator.XGBoostJobReplicaTypeMaster)).Child("template").Child("spec").Child("containers"), ""), + }, + }, + "image is empty": { + xgboostJob: &trainingoperator.XGBoostJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + }, + Spec: trainingoperator.XGBoostJobSpec{ + XGBReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{ + trainingoperator.XGBoostJobReplicaTypeMaster: { + Replicas: ptr.To[int32](1), + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{{ + Name: "xgboost", + Image: "", + }}, + }, + }, + }, + }, + }, + }, + wantErr: field.ErrorList{ + field.Required(xgbReplicaSpecPath.Key(string(trainingoperator.XGBoostJobReplicaTypeMaster)).Child("template").Child("spec").Child("containers").Index(0).Child("image"), ""), + }, + }, + "xgboostJob default container name doesn't present": { + xgboostJob: &trainingoperator.XGBoostJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + }, + Spec: trainingoperator.XGBoostJobSpec{ + XGBReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{ + trainingoperator.XGBoostJobReplicaTypeMaster: { + Replicas: ptr.To[int32](1), + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{{ + Name: "", + Image: "gcr.io/kubeflow-ci/xgboost-dist-mnist_test:1.0", + }}, + }, + }, + }, + }, + }, + }, + wantErr: field.ErrorList{ + field.Required(xgbReplicaSpecPath.Key(string(trainingoperator.XGBoostJobReplicaTypeMaster)).Child("template").Child("spec").Child("containers"), ""), + }, + }, + "the number of replicas in masterReplica is other than 1": { + xgboostJob: &trainingoperator.XGBoostJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + }, + Spec: trainingoperator.XGBoostJobSpec{ + XGBReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{ + trainingoperator.XGBoostJobReplicaTypeMaster: { + Replicas: ptr.To[int32](2), + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{{ + Name: "xgboost", + Image: "gcr.io/kubeflow-ci/xgboost-dist-mnist_test:1.0", + }}, + }, + }, + }, + }, + }, + }, + wantErr: field.ErrorList{ + field.Forbidden(xgbReplicaSpecPath.Key(string(trainingoperator.XGBoostJobReplicaTypeMaster)).Child("replicas"), ""), + }, + }, + "masterReplica does not exist": { + xgboostJob: &trainingoperator.XGBoostJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + }, + Spec: trainingoperator.XGBoostJobSpec{ + XGBReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{ + trainingoperator.XGBoostJobReplicaTypeWorker: { + Replicas: ptr.To[int32](1), + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{{ + Name: "xgboost", + Image: "gcr.io/kubeflow-ci/xgboost-dist-mnist_test:1.0", + }}, + }, + }, + }, + }, + }, + }, + wantErr: field.ErrorList{ + field.Required(xgbReplicaSpecPath.Key(string(trainingoperator.XGBoostJobReplicaTypeMaster)), ""), + }, + }, + } + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + got := validateXGBoostJob(tc.xgboostJob) + if diff := cmp.Diff(tc.wantErr, got, cmpopts.IgnoreFields(field.Error{}, "Detail", "BadValue")); len(diff) != 0 { + t.Errorf("Unexpected errors (-want,+got):\n%s", diff) + } + }) + } +}