diff --git a/cmd/eksctl-anywhere/cmd/createcluster.go b/cmd/eksctl-anywhere/cmd/createcluster.go index c303091043984..bab390c26bf0c 100644 --- a/cmd/eksctl-anywhere/cmd/createcluster.go +++ b/cmd/eksctl-anywhere/cmd/createcluster.go @@ -119,7 +119,8 @@ func (cc *createClusterOptions) createCluster(cmd *cobra.Command, _ []string) er WithGitOpsFlux(clusterSpec.Cluster, clusterSpec.FluxConfig, cliConfig). WithWriter(). WithEksdInstaller(). - WithPackageInstaller(clusterSpec, cc.installPackages, cc.managementKubeconfig) + WithPackageInstaller(clusterSpec, cc.installPackages, cc.managementKubeconfig). + WithValidatorClients() if cc.timeoutOptions.noTimeouts { factory.WithNoTimeouts() @@ -142,7 +143,7 @@ func (cc *createClusterOptions) createCluster(cmd *cobra.Command, _ []string) er ) validationOpts := &validations.Opts{ - Kubectl: deps.Kubectl, + Kubectl: deps.UnAuthKubectlClient, Spec: clusterSpec, WorkloadCluster: &types.Cluster{ Name: clusterSpec.Cluster.Name, diff --git a/cmd/eksctl-anywhere/cmd/upgradecluster.go b/cmd/eksctl-anywhere/cmd/upgradecluster.go index 349e9f374cc8a..b5f634609fd4d 100644 --- a/cmd/eksctl-anywhere/cmd/upgradecluster.go +++ b/cmd/eksctl-anywhere/cmd/upgradecluster.go @@ -23,6 +23,7 @@ type upgradeClusterOptions struct { forceClean bool hardwareCSVPath string tinkerbellBootstrapIP string + skipValidations []string } var uc = &upgradeClusterOptions{} @@ -48,6 +49,7 @@ func init() { applyTinkerbellHardwareFlag(upgradeClusterCmd.Flags(), &uc.hardwareCSVPath) upgradeClusterCmd.Flags().StringVarP(&uc.wConfig, "w-config", "w", "", "Kubeconfig file to use when upgrading a workload cluster") upgradeClusterCmd.Flags().BoolVar(&uc.forceClean, "force-cleanup", false, "Force deletion of previously created bootstrap cluster") + upgradeClusterCmd.Flags().StringArrayVar(&uc.skipValidations, "skip-validations", []string{}, "Bypass upgrade validations by name. 
Valid arguments you can pass are --skip-validations=pod-disruption") if err := upgradeClusterCmd.MarkFlagRequired("filename"); err != nil { log.Fatalf("Error marking flag as required: %v", err) @@ -107,7 +109,8 @@ func (uc *upgradeClusterOptions) upgradeCluster(cmd *cobra.Command) error { WithCAPIManager(). WithEksdUpgrader(). WithEksdInstaller(). - WithKubectl() + WithKubectl(). + WithValidatorClients() if uc.timeoutOptions.noTimeouts { factory.WithNoTimeouts() @@ -144,13 +147,20 @@ func (uc *upgradeClusterOptions) upgradeCluster(cmd *cobra.Command) error { } validationOpts := &validations.Opts{ - Kubectl: deps.Kubectl, + Kubectl: deps.UnAuthKubectlClient, Spec: clusterSpec, WorkloadCluster: workloadCluster, ManagementCluster: managementCluster, Provider: deps.Provider, CliConfig: cliConfig, } + + if len(uc.skipValidations) != 0 { + validationOpts.SkippedValidations, err = upgradevalidations.ValidateSkippableUpgradeValidation(uc.skipValidations) + if err != nil { + return err + } + } upgradeValidations := upgradevalidations.New(validationOpts) err = upgradeCluster.Run(ctx, clusterSpec, managementCluster, workloadCluster, upgradeValidations, uc.forceClean) diff --git a/cmd/eksctl-anywhere/cmd/validatecreatecluster.go b/cmd/eksctl-anywhere/cmd/validatecreatecluster.go index 9b49c8eda74ca..a755de28b0706 100644 --- a/cmd/eksctl-anywhere/cmd/validatecreatecluster.go +++ b/cmd/eksctl-anywhere/cmd/validatecreatecluster.go @@ -75,6 +75,8 @@ func (valOpt *validateOptions) validateCreateCluster(cmd *cobra.Command, _ []str WithKubectl(). WithProvider(valOpt.fileName, clusterSpec.Cluster, false, valOpt.hardwareCSVPath, true, valOpt.tinkerbellBootstrapIP). WithGitOpsFlux(clusterSpec.Cluster, clusterSpec.FluxConfig, cliConfig). + WithUnAuthKubeClient(). + WithValidatorClients(). 
Build(ctx) if err != nil { cleanupDirectory(tmpPath) @@ -83,7 +85,7 @@ func (valOpt *validateOptions) validateCreateCluster(cmd *cobra.Command, _ []str defer close(ctx, deps) validationOpts := &validations.Opts{ - Kubectl: deps.Kubectl, + Kubectl: deps.UnAuthKubectlClient, Spec: clusterSpec, WorkloadCluster: &types.Cluster{ Name: clusterSpec.Cluster.Name, diff --git a/pkg/dependencies/factory.go b/pkg/dependencies/factory.go index e301305566420..ce163fa514448 100644 --- a/pkg/dependencies/factory.go +++ b/pkg/dependencies/factory.go @@ -103,6 +103,13 @@ type Dependencies struct { NutanixValidator *nutanix.Validator SnowValidator *snow.Validator IPValidator *validator.IPValidator + UnAuthKubectlClient KubeClients +} + +// KubeClients defines super struct that exposes all behavior. +type KubeClients struct { + *executables.Kubectl + *kubernetes.UnAuthClient } func (d *Dependencies) Close(ctx context.Context) error { @@ -1073,6 +1080,22 @@ func (f *Factory) WithKubeProxyCLIUpgrader() *Factory { return f } +// WithValidatorClients builds KubeClients. +func (f *Factory) WithValidatorClients() *Factory { + f.WithKubectl().WithUnAuthKubeClient() + + f.buildSteps = append(f.buildSteps, func(ctx context.Context) error { + f.dependencies.UnAuthKubectlClient = KubeClients{ + Kubectl: f.dependencies.Kubectl, + UnAuthClient: f.dependencies.UnAuthKubeClient, + } + + return nil + }) + + return f +} + // WithLogger setups a logger to be injected in constructors. It uses the logger // package level logger. func (f *Factory) WithLogger() *Factory { diff --git a/pkg/dependencies/factory_test.go b/pkg/dependencies/factory_test.go index d10681cdcaea9..1a4b20e5626b5 100644 --- a/pkg/dependencies/factory_test.go +++ b/pkg/dependencies/factory_test.go @@ -214,6 +214,7 @@ func TestFactoryBuildWithMultipleDependencies(t *testing.T) { WithCiliumTemplater(). WithIPValidator(). WithKubeProxyCLIUpgrader(). + WithValidatorClients(). 
Build(context.Background()) tt.Expect(err).To(BeNil()) @@ -235,6 +236,7 @@ func TestFactoryBuildWithMultipleDependencies(t *testing.T) { tt.Expect(deps.CiliumTemplater).NotTo(BeNil()) tt.Expect(deps.IPValidator).NotTo(BeNil()) tt.Expect(deps.KubeProxyCLIUpgrader).NotTo(BeNil()) + tt.Expect(deps.UnAuthKubectlClient).NotTo(BeNil()) } func TestFactoryBuildWithProxyConfiguration(t *testing.T) { diff --git a/pkg/providers/vsphere/workers_test.go b/pkg/providers/vsphere/workers_test.go index 8cecc2aaa4c15..55be5c985b756 100644 --- a/pkg/providers/vsphere/workers_test.go +++ b/pkg/providers/vsphere/workers_test.go @@ -240,7 +240,6 @@ func TestWorkersSpecRegistryMirrorConfiguration(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - spec.Cluster.Spec.RegistryMirrorConfiguration = tt.mirrorConfig workers, err := vsphere.WorkersSpec(ctx, logger, client, spec) g.Expect(err).NotTo(HaveOccurred()) diff --git a/pkg/validations/createvalidations/cluster_test.go b/pkg/validations/createvalidations/cluster_test.go index d59a237b0dfeb..0e35217fc16b6 100644 --- a/pkg/validations/createvalidations/cluster_test.go +++ b/pkg/validations/createvalidations/cluster_test.go @@ -12,13 +12,20 @@ import ( "github.com/aws/eks-anywhere/internal/test" "github.com/aws/eks-anywhere/pkg/api/v1alpha1" + "github.com/aws/eks-anywhere/pkg/clients/kubernetes" "github.com/aws/eks-anywhere/pkg/constants" + "github.com/aws/eks-anywhere/pkg/executables" "github.com/aws/eks-anywhere/pkg/validations" "github.com/aws/eks-anywhere/pkg/validations/createvalidations" ) const testclustername string = "testcluster" +type UnAuthKubectlClient struct { + *executables.Kubectl + *kubernetes.UnAuthClient +} + func TestValidateClusterPresent(t *testing.T) { tests := []struct { name string @@ -44,12 +51,13 @@ func TestValidateClusterPresent(t *testing.T) { } k, ctx, cluster, e := validations.NewKubectl(t) + uk := kubernetes.NewUnAuthClient(k) cluster.Name = testclustername for _, tc := range 
tests { t.Run(tc.name, func(tt *testing.T) { fileContent := test.ReadFile(t, tc.getClusterResponse) e.EXPECT().Execute(ctx, []string{"get", capiClustersResourceType, "-o", "json", "--kubeconfig", cluster.KubeconfigFile, "--namespace", constants.EksaSystemNamespace}).Return(*bytes.NewBufferString(fileContent), nil) - err := createvalidations.ValidateClusterNameIsUnique(ctx, k, cluster, testclustername) + err := createvalidations.ValidateClusterNameIsUnique(ctx, UnAuthKubectlClient{k, uk}, cluster, testclustername) if !reflect.DeepEqual(err, tc.wantErr) { t.Errorf("%v got = %v, \nwant %v", tc.name, err, tc.wantErr) } @@ -93,12 +101,13 @@ func TestValidateManagementClusterCRDs(t *testing.T) { } k, ctx, cluster, e := validations.NewKubectl(t) + uk := kubernetes.NewUnAuthClient(k) cluster.Name = testclustername for _, tc := range tests { t.Run(tc.name, func(tt *testing.T) { e.EXPECT().Execute(ctx, []string{"get", "customresourcedefinition", capiClustersResourceType, "--kubeconfig", cluster.KubeconfigFile}).Return(bytes.Buffer{}, tc.errGetClusterCRD).Times(tc.errGetClusterCRDCount) e.EXPECT().Execute(ctx, []string{"get", "customresourcedefinition", eksaClusterResourceType, "--kubeconfig", cluster.KubeconfigFile}).Return(bytes.Buffer{}, tc.errGetEKSAClusterCRD).Times(tc.errGetEKSAClusterCRDCount) - err := createvalidations.ValidateManagementCluster(ctx, k, cluster) + err := createvalidations.ValidateManagementCluster(ctx, UnAuthKubectlClient{k, uk}, cluster) if tc.wantErr { assert.Error(tt, err, "expected ValidateManagementCluster to return an error", "test", tc.name) } else { diff --git a/pkg/validations/createvalidations/identityproviders_test.go b/pkg/validations/createvalidations/identityproviders_test.go index a301a623e6adc..8617062a3f876 100644 --- a/pkg/validations/createvalidations/identityproviders_test.go +++ b/pkg/validations/createvalidations/identityproviders_test.go @@ -9,6 +9,7 @@ import ( "github.com/aws/eks-anywhere/internal/test" 
"github.com/aws/eks-anywhere/pkg/api/v1alpha1" + "github.com/aws/eks-anywhere/pkg/clients/kubernetes" "github.com/aws/eks-anywhere/pkg/cluster" "github.com/aws/eks-anywhere/pkg/validations" "github.com/aws/eks-anywhere/pkg/validations/createvalidations" @@ -62,6 +63,7 @@ func TestValidateIdendityProviderForWorkloadClusters(t *testing.T) { s.OIDCConfig = defaultOIDC }) k, ctx, cluster, e := validations.NewKubectl(t) + uk := kubernetes.NewUnAuthClient(k) cluster.Name = testclustername for _, tc := range tests { t.Run(tc.name, func(tt *testing.T) { @@ -73,7 +75,7 @@ func TestValidateIdendityProviderForWorkloadClusters(t *testing.T) { "--field-selector=metadata.name=oidc-config-test", }).Return(*bytes.NewBufferString(fileContent), nil) - err := createvalidations.ValidateIdentityProviderNameIsUnique(ctx, k, cluster, clusterSpec) + err := createvalidations.ValidateIdentityProviderNameIsUnique(ctx, UnAuthKubectlClient{k, uk}, cluster, clusterSpec) if !reflect.DeepEqual(err, tc.wantErr) { t.Errorf("%v got = %v, \nwant %v", tc.name, err, tc.wantErr) } @@ -123,6 +125,7 @@ func TestValidateIdentityProviderForSelfManagedCluster(t *testing.T) { s.Cluster.SetSelfManaged() }) k, ctx, cluster, e := validations.NewKubectl(t) + uk := kubernetes.NewUnAuthClient(k) cluster.Name = testclustername for _, tc := range tests { t.Run(tc.name, func(tt *testing.T) { @@ -133,7 +136,7 @@ func TestValidateIdentityProviderForSelfManagedCluster(t *testing.T) { "--field-selector=metadata.name=oidc-config-test", }).Times(0) - err := createvalidations.ValidateIdentityProviderNameIsUnique(ctx, k, cluster, clusterSpec) + err := createvalidations.ValidateIdentityProviderNameIsUnique(ctx, UnAuthKubectlClient{k, uk}, cluster, clusterSpec) if !reflect.DeepEqual(err, tc.wantErr) { t.Errorf("%v got = %v, \nwant %v", tc.name, err, tc.wantErr) } diff --git a/pkg/validations/createvalidations/preflightvalidations.go b/pkg/validations/createvalidations/preflightvalidations.go index ec2ec6dd4db35..97a8693c737c9 
100644 --- a/pkg/validations/createvalidations/preflightvalidations.go +++ b/pkg/validations/createvalidations/preflightvalidations.go @@ -32,7 +32,7 @@ func (v *CreateValidations) PreflightValidations(ctx context.Context) []validati return &validations.ValidationResult{ Name: "validate certificate for registry mirror", Remediation: fmt.Sprintf("provide a valid certificate for you registry endpoint using %s env var", anywherev1.RegistryMirrorCAKey), - Err: validations.ValidateCertForRegistryMirror(v.Opts.Spec, v.Opts.TlsValidator), + Err: validations.ValidateCertForRegistryMirror(v.Opts.Spec, v.Opts.TLSValidator), } }, func() *validations.ValidationResult { diff --git a/pkg/validations/kubectl.go b/pkg/validations/kubectl.go index f3edad7432f1b..5d352565c2cc4 100644 --- a/pkg/validations/kubectl.go +++ b/pkg/validations/kubectl.go @@ -8,6 +8,7 @@ import ( "k8s.io/apimachinery/pkg/runtime" "github.com/aws/eks-anywhere/pkg/api/v1alpha1" + "github.com/aws/eks-anywhere/pkg/clients/kubernetes" "github.com/aws/eks-anywhere/pkg/executables" mockexecutables "github.com/aws/eks-anywhere/pkg/executables/mocks" "github.com/aws/eks-anywhere/pkg/types" @@ -15,6 +16,7 @@ import ( ) type KubectlClient interface { + List(ctx context.Context, kubeconfig string, list kubernetes.ObjectList) error ValidateControlPlaneNodes(ctx context.Context, cluster *types.Cluster, clusterName string) error ValidateWorkerNodes(ctx context.Context, clusterName string, kubeconfig string) error ValidateNodes(ctx context.Context, kubeconfig string) error diff --git a/pkg/validations/mocks/kubectl.go b/pkg/validations/mocks/kubectl.go index 2346c8e0cba06..1c60831b219e8 100644 --- a/pkg/validations/mocks/kubectl.go +++ b/pkg/validations/mocks/kubectl.go @@ -9,6 +9,7 @@ import ( reflect "reflect" v1alpha1 "github.com/aws/eks-anywhere/pkg/api/v1alpha1" + kubernetes "github.com/aws/eks-anywhere/pkg/clients/kubernetes" executables "github.com/aws/eks-anywhere/pkg/executables" types 
"github.com/aws/eks-anywhere/pkg/types" v1alpha10 "github.com/aws/eks-anywhere/release/api/v1alpha1" @@ -203,6 +204,20 @@ func (mr *MockKubectlClientMockRecorder) GetObject(ctx, resourceType, name, name return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetObject", reflect.TypeOf((*MockKubectlClient)(nil).GetObject), ctx, resourceType, name, namespace, kubeconfig, obj) } +// List mocks base method. +func (m *MockKubectlClient) List(ctx context.Context, kubeconfig string, list kubernetes.ObjectList) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "List", ctx, kubeconfig, list) + ret0, _ := ret[0].(error) + return ret0 +} + +// List indicates an expected call of List. +func (mr *MockKubectlClientMockRecorder) List(ctx, kubeconfig, list interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockKubectlClient)(nil).List), ctx, kubeconfig, list) +} + // SearchIdentityProviderConfig mocks base method. func (m *MockKubectlClient) SearchIdentityProviderConfig(ctx context.Context, ipName, kind, kubeconfigFile, namespace string) ([]*v1alpha1.VSphereDatacenterConfig, error) { m.ctrl.T.Helper() diff --git a/pkg/validations/upgradevalidations/cluster_test.go b/pkg/validations/upgradevalidations/cluster_test.go index ce8c9795f4b11..47b722ddbb61e 100644 --- a/pkg/validations/upgradevalidations/cluster_test.go +++ b/pkg/validations/upgradevalidations/cluster_test.go @@ -11,6 +11,7 @@ import ( "github.com/aws/eks-anywhere/internal/test" "github.com/aws/eks-anywhere/pkg/api/v1alpha1" + "github.com/aws/eks-anywhere/pkg/clients/kubernetes" "github.com/aws/eks-anywhere/pkg/constants" "github.com/aws/eks-anywhere/pkg/validations" "github.com/aws/eks-anywhere/pkg/validations/upgradevalidations" @@ -43,12 +44,14 @@ func TestValidateClusterPresent(t *testing.T) { } k, ctx, cluster, e := validations.NewKubectl(t) + uk := kubernetes.NewUnAuthClient(k) + cluster.Name = testclustername for _, tc := range 
tests { t.Run(tc.name, func(tt *testing.T) { fileContent := test.ReadFile(t, tc.getClusterResponse) e.EXPECT().Execute(ctx, []string{"get", capiClustersResourceType, "-o", "json", "--kubeconfig", cluster.KubeconfigFile, "--namespace", constants.EksaSystemNamespace}).Return(*bytes.NewBufferString(fileContent), nil) - err := upgradevalidations.ValidateClusterObjectExists(ctx, k, cluster) + err := upgradevalidations.ValidateClusterObjectExists(ctx, UnAuthKubectlClient{k, uk}, cluster) if !reflect.DeepEqual(err, tc.wantErr) { t.Errorf("%v got = %v, \nwant %v", tc.name, err, tc.wantErr) } diff --git a/pkg/validations/upgradevalidations/poddisruptionbudgets.go b/pkg/validations/upgradevalidations/poddisruptionbudgets.go new file mode 100644 index 0000000000000..811e20871ff0f --- /dev/null +++ b/pkg/validations/upgradevalidations/poddisruptionbudgets.go @@ -0,0 +1,29 @@ +package upgradevalidations + +import ( + "context" + "fmt" + + "github.com/pkg/errors" + policy "k8s.io/api/policy/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + + "github.com/aws/eks-anywhere/pkg/types" + "github.com/aws/eks-anywhere/pkg/validations" +) + +// ValidatePodDisruptionBudgets returns an error if any pdbs are detected on a cluster. +func ValidatePodDisruptionBudgets(ctx context.Context, k validations.KubectlClient, cluster *types.Cluster) error { + podDisruptionBudgets := &policy.PodDisruptionBudgetList{} + if err := k.List(ctx, cluster.KubeconfigFile, podDisruptionBudgets); err != nil { + if !apierrors.IsNotFound(err) { + return errors.Wrap(err, "listing cluster pod disruption budgets for upgrade") + } + } + + if len(podDisruptionBudgets.Items) != 0 { + return fmt.Errorf("one or more pod disruption budgets were detected on the cluster. 
Use the --skip-validations=%s flag if you wish to skip the validations for pod disruption budgets and proceed with the upgrade operation", PDB) + } + + return nil +} diff --git a/pkg/validations/upgradevalidations/poddisruptionbudgets_test.go b/pkg/validations/upgradevalidations/poddisruptionbudgets_test.go new file mode 100644 index 0000000000000..0fd97acffb67a --- /dev/null +++ b/pkg/validations/upgradevalidations/poddisruptionbudgets_test.go @@ -0,0 +1,103 @@ +package upgradevalidations_test + +import ( + "context" + "errors" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/golang/mock/gomock" + policy "k8s.io/api/policy/v1" + "k8s.io/apimachinery/pkg/util/intstr" + + "github.com/aws/eks-anywhere/pkg/clients/kubernetes" + "github.com/aws/eks-anywhere/pkg/types" + "github.com/aws/eks-anywhere/pkg/validations" + "github.com/aws/eks-anywhere/pkg/validations/mocks" + "github.com/aws/eks-anywhere/pkg/validations/upgradevalidations" +) + +func TestValidatePodDisruptionBudgets(t *testing.T) { + type args struct { + ctx context.Context + k validations.KubectlClient + cluster *types.Cluster + pdbList *policy.PodDisruptionBudgetList + } + mockCtrl := gomock.NewController(t) + k := mocks.NewMockKubectlClient(mockCtrl) + c := types.Cluster{ + KubeconfigFile: "test.kubeconfig", + } + tests := []struct { + name string + args args + wantErr error + }{ + { + name: "PDBs exist on cluster", + args: args{ + ctx: context.Background(), + k: k, + cluster: &c, + pdbList: &policy.PodDisruptionBudgetList{ + Items: []policy.PodDisruptionBudget{ + { + Spec: policy.PodDisruptionBudgetSpec{ + MinAvailable: &intstr.IntOrString{ + Type: intstr.Int, + IntVal: 0, + }, + }, + }, + }, + }, + }, + wantErr: fmt.Errorf("one or more pod disruption budgets were detected on the cluster. 
Use the --skip-validations=%s flag if you wish to skip the validations for pod disruption budgets and proceed with the upgrade operation", upgradevalidations.PDB), + }, + { + name: "PDBs don't exist on cluster", + args: args{ + ctx: context.Background(), + k: k, + cluster: &c, + pdbList: &policy.PodDisruptionBudgetList{}, + }, + wantErr: nil, + }, + } + for _, tt := range tests { + podDisruptionBudgets := &policy.PodDisruptionBudgetList{} + k.EXPECT().List(tt.args.ctx, tt.args.cluster.KubeconfigFile, podDisruptionBudgets).DoAndReturn(func(_ context.Context, _ string, objs kubernetes.ObjectList) error { + tt.args.pdbList.DeepCopyInto(objs.(*policy.PodDisruptionBudgetList)) + return nil + }) + + t.Run(tt.name, func(t *testing.T) { + if err := upgradevalidations.ValidatePodDisruptionBudgets(tt.args.ctx, tt.args.k, tt.args.cluster); !reflect.DeepEqual(err, tt.wantErr) { + t.Errorf("ValidatePodDisruptionBudgets() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestValidatePodDisruptionBudgetsFailure(t *testing.T) { + mockCtrl := gomock.NewController(t) + k := mocks.NewMockKubectlClient(mockCtrl) + c := types.Cluster{ + KubeconfigFile: "test.kubeconfig", + } + ctx := context.Background() + pdbList := &policy.PodDisruptionBudgetList{} + + k.EXPECT().List(ctx, c.KubeconfigFile, pdbList).Return(errors.New("listing cluster pod disruption budgets for upgrade")) + + wantErr := errors.New("listing cluster pod disruption budgets for upgrade") + + err := upgradevalidations.ValidatePodDisruptionBudgets(ctx, k, &c) + if err != nil && !strings.Contains(err.Error(), wantErr.Error()) { + t.Errorf("ValidatePodDisruptionBudgets() error = %v, wantErr %v", err, wantErr) + } +} diff --git a/pkg/validations/upgradevalidations/preflightvalidations.go b/pkg/validations/upgradevalidations/preflightvalidations.go index 6b8bb355be01a..a08c0e4472ae2 100644 --- a/pkg/validations/upgradevalidations/preflightvalidations.go +++ 
b/pkg/validations/upgradevalidations/preflightvalidations.go @@ -33,7 +33,7 @@ func (u *UpgradeValidations) PreflightValidations(ctx context.Context) []validat return &validations.ValidationResult{ Name: "validate certificate for registry mirror", Remediation: fmt.Sprintf("provide a valid certificate for you registry endpoint using %s env var", anywherev1.RegistryMirrorCAKey), - Err: validations.ValidateCertForRegistryMirror(u.Opts.Spec, u.Opts.TlsValidator), + Err: validations.ValidateCertForRegistryMirror(u.Opts.Spec, u.Opts.TLSValidator), } }, func() *validations.ValidationResult { @@ -113,5 +113,17 @@ func (u *UpgradeValidations) PreflightValidations(ctx context.Context) []validat } }) } + + if !u.Opts.SkippedValidations[PDB] { + upgradeValidations = append( + upgradeValidations, + func() *validations.ValidationResult { + return &validations.ValidationResult{ + Name: "validate pod disruption budgets", + Remediation: "", + Err: ValidatePodDisruptionBudgets(ctx, k, targetCluster), + } + }) + } return upgradeValidations } diff --git a/pkg/validations/upgradevalidations/preflightvalidations_test.go b/pkg/validations/upgradevalidations/preflightvalidations_test.go index 9ec50c93cb83b..535918c99d69e 100644 --- a/pkg/validations/upgradevalidations/preflightvalidations_test.go +++ b/pkg/validations/upgradevalidations/preflightvalidations_test.go @@ -381,7 +381,7 @@ func TestPreflightValidationsTinkerbell(t *testing.T) { WorkloadCluster: workloadCluster, ManagementCluster: workloadCluster, Provider: provider, - TlsValidator: tlsValidator, + TLSValidator: tlsValidator, } clusterSpec.Cluster.Spec.KubernetesVersion = v1alpha1.KubernetesVersion(tc.upgradeVersion) @@ -410,6 +410,7 @@ func TestPreflightValidationsTinkerbell(t *testing.T) { // provider.EXPECT().ValidateNewSpec(ctx, workloadCluster, clusterSpec).Return(nil).MaxTimes(1) kubectl.EXPECT().GetEksaTinkerbellDatacenterConfig(ctx, clusterSpec.Cluster.Spec.DatacenterRef.Name, gomock.Any(), 
gomock.Any()).Return(existingProviderSpec, nil).MaxTimes(1) kubectl.EXPECT().GetEksaTinkerbellMachineConfig(ctx, clusterSpec.Cluster.Spec.ControlPlaneConfiguration.MachineGroupRef.Name, gomock.Any(), gomock.Any()).Return(existingMachineConfigSpec, nil).MaxTimes(1) + k.EXPECT().List(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) k.EXPECT().ValidateControlPlaneNodes(ctx, workloadCluster, clusterSpec.Cluster.Name).Return(tc.cpResponse) k.EXPECT().ValidateWorkerNodes(ctx, workloadCluster.Name, workloadCluster.KubeconfigFile).Return(tc.workerResponse) k.EXPECT().ValidateNodes(ctx, kubeconfigFilePath).Return(tc.nodeResponse) @@ -1170,7 +1171,7 @@ func TestPreflightValidationsVsphere(t *testing.T) { WorkloadCluster: workloadCluster, ManagementCluster: workloadCluster, Provider: provider, - TlsValidator: tlsValidator, + TLSValidator: tlsValidator, } clusterSpec.Cluster.Spec.KubernetesVersion = v1alpha1.KubernetesVersion(tc.upgradeVersion) @@ -1194,6 +1195,7 @@ func TestPreflightValidationsVsphere(t *testing.T) { provider.EXPECT().ValidateNewSpec(ctx, workloadCluster, clusterSpec).Return(nil).MaxTimes(1) k.EXPECT().GetEksaVSphereDatacenterConfig(ctx, clusterSpec.Cluster.Spec.DatacenterRef.Name, gomock.Any(), gomock.Any()).Return(existingProviderSpec, nil).MaxTimes(1) k.EXPECT().ValidateControlPlaneNodes(ctx, workloadCluster, clusterSpec.Cluster.Name).Return(tc.cpResponse) + k.EXPECT().List(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) k.EXPECT().ValidateWorkerNodes(ctx, workloadCluster.Name, workloadCluster.KubeconfigFile).Return(tc.workerResponse) k.EXPECT().ValidateNodes(ctx, kubeconfigFilePath).Return(tc.nodeResponse) k.EXPECT().ValidateClustersCRD(ctx, workloadCluster).Return(tc.crdResponse) @@ -1412,7 +1414,7 @@ func TestPreFlightValidationsGit(t *testing.T) { WorkloadCluster: workloadCluster, ManagementCluster: workloadCluster, Provider: provider, - TlsValidator: tlsValidator, + TLSValidator: tlsValidator, CliConfig: cliConfig, } @@ -1430,6 +1432,7 @@ 
func TestPreFlightValidationsGit(t *testing.T) { provider.EXPECT().DatacenterConfig(clusterSpec).Return(existingProviderSpec).MaxTimes(1) provider.EXPECT().ValidateNewSpec(ctx, workloadCluster, clusterSpec).Return(nil).MaxTimes(1) + k.EXPECT().List(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) k.EXPECT().GetEksaVSphereDatacenterConfig(ctx, clusterSpec.Cluster.Spec.DatacenterRef.Name, gomock.Any(), gomock.Any()).Return(existingProviderSpec, nil).MaxTimes(1) k.EXPECT().ValidateControlPlaneNodes(ctx, workloadCluster, clusterSpec.Cluster.Name).Return(tc.cpResponse) k.EXPECT().ValidateWorkerNodes(ctx, workloadCluster.Name, workloadCluster.KubeconfigFile).Return(tc.workerResponse) diff --git a/pkg/validations/upgradevalidations/upgradeskipvalidations.go b/pkg/validations/upgradevalidations/upgradeskipvalidations.go new file mode 100644 index 0000000000000..5780e5ef63418 --- /dev/null +++ b/pkg/validations/upgradevalidations/upgradeskipvalidations.go @@ -0,0 +1,43 @@ +package upgradevalidations + +import ( + "fmt" + "strings" +) + +// string values of supported validation names that can be skipped. +const ( + PDB = "pod-disruption" +) + +// SkippableValidations represents all the validations we offer for users to skip. +var SkippableValidations = []string{ + PDB, +} + +// validSkippableValidationsMap returns a map for all valid skippable validations as keys, defaulting values to false. Defaulting to false means these validations won't be skipped unless set to true. +func validSkippableValidationsMap() map[string]bool { + validationsMap := make(map[string]bool, len(SkippableValidations)) + 
	for i := range SkippableValidations { + validationsMap[SkippableValidations[i]] = false + } + 
	return validationsMap +} + +// ValidateSkippableUpgradeValidation validates if provided validations are supported by EKSA to skip for upgrades.
+func ValidateSkippableUpgradeValidation(skippedValidations []string) (map[string]bool, error) { + svMap := validSkippableValidationsMap() + + for i := range skippedValidations { + validationName := skippedValidations[i] + _, ok := svMap[validationName] + if !ok { + return nil, fmt.Errorf("invalid validation name to be skipped. The supported upgrade validations that can be skipped using --skip-validations are %s", strings.Join(SkippableValidations[:], ",")) + } + svMap[validationName] = true + } + + return svMap, nil +} diff --git a/pkg/validations/upgradevalidations/upgradeskipvalidations_test.go b/pkg/validations/upgradevalidations/upgradeskipvalidations_test.go new file mode 100644 index 0000000000000..7ce999b8e26a3 --- /dev/null +++ b/pkg/validations/upgradevalidations/upgradeskipvalidations_test.go @@ -0,0 +1,46 @@ +package upgradevalidations_test + +import ( + "fmt" + "reflect" + "strings" + "testing" + + "github.com/aws/eks-anywhere/pkg/validations/upgradevalidations" +) + +func TestValidateSkippableUpgradeValidation(t *testing.T) { + tests := []struct { + name string + want map[string]bool + wantErr error + skippedValidations []string + }{ + { + name: "invalid upgrade validation param", + want: nil, + wantErr: fmt.Errorf("invalid validation name to be skipped. 
The supported upgrade validations that can be skipped using --skip-validations are %s", strings.Join(upgradevalidations.SkippableValidations[:], ",")), + skippedValidations: []string{"test"}, + }, + { + name: "valid upgrade validation param", + want: map[string]bool{ + upgradevalidations.PDB: true, + }, + wantErr: nil, + skippedValidations: []string{upgradevalidations.PDB}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := upgradevalidations.ValidateSkippableUpgradeValidation(tt.skippedValidations) + if !reflect.DeepEqual(err, tt.wantErr) { + t.Errorf("ValidateSkippableUpgradeValidation() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("ValidateSkippableUpgradeValidation() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/validations/upgradevalidations/versions_test.go b/pkg/validations/upgradevalidations/versions_test.go index c466c842efb11..eb8ff407460e4 100644 --- a/pkg/validations/upgradevalidations/versions_test.go +++ b/pkg/validations/upgradevalidations/versions_test.go @@ -8,10 +8,17 @@ import ( "github.com/aws/eks-anywhere/internal/test" "github.com/aws/eks-anywhere/pkg/api/v1alpha1" + "github.com/aws/eks-anywhere/pkg/clients/kubernetes" + "github.com/aws/eks-anywhere/pkg/executables" "github.com/aws/eks-anywhere/pkg/validations" "github.com/aws/eks-anywhere/pkg/validations/upgradevalidations" ) +type UnAuthKubectlClient struct { + *executables.Kubectl + *kubernetes.UnAuthClient +} + func TestValidateVersionSkew(t *testing.T) { tests := []struct { name string @@ -46,11 +53,13 @@ func TestValidateVersionSkew(t *testing.T) { } k, ctx, cluster, e := validations.NewKubectl(t) + uk := kubernetes.NewUnAuthClient(k) + for _, tc := range tests { t.Run(tc.name, func(tt *testing.T) { fileContent := test.ReadFile(t, tc.serverVersionResponse) e.EXPECT().Execute(ctx, []string{"version", "-o", "json", "--kubeconfig", 
cluster.KubeconfigFile}).Return(*bytes.NewBufferString(fileContent), nil) - err := upgradevalidations.ValidateServerVersionSkew(ctx, tc.upgradeVersion, cluster, k) + err := upgradevalidations.ValidateServerVersionSkew(ctx, tc.upgradeVersion, cluster, UnAuthKubectlClient{k, uk}) if !reflect.DeepEqual(err, tc.wantErr) { t.Errorf("%v got = %v, \nwant %v", tc.name, err, tc.wantErr) } diff --git a/pkg/validations/validation_options.go b/pkg/validations/validation_options.go index bae5c7467b602..8b33f30c96034 100644 --- a/pkg/validations/validation_options.go +++ b/pkg/validations/validation_options.go @@ -9,17 +9,18 @@ import ( ) type Opts struct { - Kubectl KubectlClient - Spec *cluster.Spec - WorkloadCluster *types.Cluster - ManagementCluster *types.Cluster - Provider providers.Provider - TlsValidator TlsValidator - CliConfig *config.CliConfig + Kubectl KubectlClient + Spec *cluster.Spec + WorkloadCluster *types.Cluster + ManagementCluster *types.Cluster + Provider providers.Provider + TLSValidator TlsValidator + CliConfig *config.CliConfig + SkippedValidations map[string]bool } func (o *Opts) SetDefaults() { - if o.TlsValidator == nil { - o.TlsValidator = crypto.NewTlsValidator() + if o.TLSValidator == nil { + o.TLSValidator = crypto.NewTlsValidator() } }