Skip to content

Commit

Permalink
fix validation for driver calls
Browse files Browse the repository at this point in the history
  • Loading branch information
rishabh-11 committed May 21, 2024
1 parent 295ac09 commit c97784b
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 89 deletions.
12 changes: 8 additions & 4 deletions pkg/gcp/integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
api "github.com/gardener/machine-controller-manager-provider-gcp/pkg/api/v1alpha1"
"github.com/gardener/machine-controller-manager-provider-gcp/pkg/gcp"
"github.com/gardener/machine-controller-manager-provider-gcp/pkg/gcp/errors"
validation "github.com/gardener/machine-controller-manager-provider-gcp/pkg/gcp/validation"
"github.com/gardener/machine-controller-manager-provider-gcp/pkg/gcp/validation"
corev1 "k8s.io/api/core/v1"
"sigs.k8s.io/yaml"
)
Expand Down Expand Up @@ -49,9 +49,13 @@ func TestPluginSPIImpl(t *testing.T) {
ms := gcp.NewGCPPlugin(&gcp.PluginSPIImpl{})
ctx := context.TODO()

ValidationErr := validation.ValidateGCPProviderSpec(cfg.ProviderSpec, cfg.Secrets)
if ValidationErr != nil {
t.Errorf("Error while validating ProviderSpec %v", ValidationErr)
validationErr := validation.ValidateProviderSpec(cfg.ProviderSpec)
if validationErr != nil {
t.Errorf("Error while validating ProviderSpec %v", validationErr)
return
}
if validationErr = validation.ValidateSecret(cfg.Secrets); validationErr != nil {
t.Errorf("Error while validating Secret %v", validationErr)
return
}

Expand Down
32 changes: 20 additions & 12 deletions pkg/gcp/machine_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,13 @@ func (ms *MachinePlugin) CreateMachine(ctx context.Context, req *driver.CreateMa
return nil, status.Error(codes.InvalidArgument, err.Error())
}

providerSpec, err := decodeProviderSpecAndSecret(req.MachineClass, req.Secret)
providerSpec, err := decodeProviderSpecAndValidateSecret(req.MachineClass, req.Secret)
if err != nil {
return nil, prepareErrorf(err, "Create machine %q failed on decodeProviderSpecAndSecret", req.Machine.Name)
return nil, prepareErrorf(err, "Create machine %q failed on decodeProviderSpecAndValidateSecret", req.Machine.Name)
}
if err = validateProviderSpec(providerSpec); err != nil {
return nil, prepareErrorf(err, "Create machine %q failed on validateProviderSpec", req.Machine.Name)
}

providerID, err := ms.CreateMachineUtil(ctx, req.Machine.Name, providerSpec, req.Secret)
if err != nil {
return nil, prepareErrorf(err, "Create machine %q failed", req.Machine.Name)
Expand Down Expand Up @@ -123,11 +125,13 @@ func (ms *MachinePlugin) DeleteMachine(ctx context.Context, req *driver.DeleteMa
return nil, status.Error(codes.InvalidArgument, err.Error())
}

providerSpec, err := decodeProviderSpecAndSecret(req.MachineClass, req.Secret)
providerSpec, err := decodeProviderSpecAndValidateSecret(req.MachineClass, req.Secret)
if err != nil {
return nil, prepareErrorf(err, "Delete machine %q failed on decodeProviderSpecAndSecret", req.Machine.Name)
return nil, prepareErrorf(err, "Delete machine %q failed on decodeProviderSpecAndValidateSecret", req.Machine.Name)
}
if err = validateZone(providerSpec.Zone); err != nil {
return nil, prepareErrorf(err, "Delete machine %q failed on validateZone", req.Machine.Name)
}

providerID, err := ms.DeleteMachineUtil(ctx, req.Machine.Name, req.Machine.Spec.ProviderID, providerSpec, req.Secret)
if err != nil {
return nil, prepareErrorf(err, "Delete machine %q failed", req.Machine.Name)
Expand Down Expand Up @@ -166,11 +170,13 @@ func (ms *MachinePlugin) GetMachineStatus(ctx context.Context, req *driver.GetMa
return nil, status.Error(codes.InvalidArgument, err.Error())
}

providerSpec, err := decodeProviderSpecAndSecret(req.MachineClass, req.Secret)
providerSpec, err := decodeProviderSpecAndValidateSecret(req.MachineClass, req.Secret)
if err != nil {
return nil, prepareErrorf(err, "Machine status %q failed on decodeProviderSpecAndSecret", req.Machine.Name)
return nil, prepareErrorf(err, "Machine status %q failed on decodeProviderSpecAndValidateSecret", req.Machine.Name)
}
if err = validateZone(providerSpec.Zone); err != nil {
return nil, prepareErrorf(err, "Machine status %q failed on validateZone", req.Machine.Name)
}

providerID, err := ms.GetMachineStatusUtil(ctx, req.Machine.Name, req.Machine.Spec.ProviderID, providerSpec, req.Secret)
if err != nil {
return nil, prepareErrorf(err, "Machine status %q failed", req.Machine.Name)
Expand Down Expand Up @@ -209,11 +215,13 @@ func (ms *MachinePlugin) ListMachines(ctx context.Context, req *driver.ListMachi
return nil, status.Error(codes.InvalidArgument, err.Error())
}

providerSpec, err := decodeProviderSpecAndSecret(req.MachineClass, req.Secret)
providerSpec, err := decodeProviderSpecAndValidateSecret(req.MachineClass, req.Secret)
if err != nil {
return nil, prepareErrorf(err, "List machines failed on decodeProviderSpecAndSecret")
return nil, prepareErrorf(err, "List machines failed on decodeProviderSpecAndValidateSecret")
}
if err = validateZone(providerSpec.Zone); err != nil {
return nil, prepareErrorf(err, "List machines failed on validateZone")
}

machineList, err := ms.ListMachinesUtil(ctx, providerSpec, req.Secret)
if err != nil {
return nil, prepareErrorf(err, "List machines failed")
Expand Down
22 changes: 9 additions & 13 deletions pkg/gcp/machine_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,20 @@ import (
)

const (
// TestNamaspace is the test namespace used
TestNamaspace string = "test"
// TestClassName is the test class name used
TestMachineClassName string = "test-mc"
// FailAtNotFound is the error message returned when a resource is not found
FailAtNotFound string = "machine codes error: code = [NotFound] message = [machine name=non-existent-dummy-machine, uuid= not found]"
// FailAtJSONUnmarshalling is the error message returned when an malformed JSON is sent to the plugin by the caller
FailAtJSONUnmarshalling string = "machine codes error: code = [Internal] message = [Machine status \"dummy-machine\" failed on decodeProviderSpecAndSecret: machine codes error: code = [Internal] message = [unexpected end of JSON input]]"
FailAtJSONUnmarshalling string = "machine codes error: code = [Internal] message = [Machine status \"dummy-machine\" failed on decodeProviderSpecAndValidateSecret: machine codes error: code = [Internal] message = [unexpected end of JSON input]]"
// CreateFailAtJSONUnmarshalling is the error message returned when an malformed JSON is sent to the plugin by the caller
CreateFailAtJSONUnmarshalling string = "machine codes error: code = [Internal] message = [Create machine \"dummy-machine\" failed on decodeProviderSpecAndSecret: machine codes error: code = [Internal] message = [unexpected end of JSON input]]"
CreateFailAtJSONUnmarshalling string = "machine codes error: code = [Internal] message = [Create machine \"dummy-machine\" failed on decodeProviderSpecAndValidateSecret: machine codes error: code = [Internal] message = [unexpected end of JSON input]]"
// DeleteFailAtJSONUnmarshalling is the error message returned when an malformed JSON is sent to the plugin by the caller
DeleteFailAtJSONUnmarshalling string = "machine codes error: code = [Internal] message = [Delete machine \"dummy-machine\" failed on decodeProviderSpecAndSecret: machine codes error: code = [Internal] message = [unexpected end of JSON input]]"
DeleteFailAtJSONUnmarshalling string = "machine codes error: code = [Internal] message = [Delete machine \"dummy-machine\" failed on decodeProviderSpecAndValidateSecret: machine codes error: code = [Internal] message = [unexpected end of JSON input]]"
// ListFailAtJSONUnmarshalling is the error message returned when an malformed JSON is sent to the plugin by the caller
ListFailAtJSONUnmarshalling string = "machine codes error: code = [Internal] message = [List machines failed on decodeProviderSpecAndSecret: machine codes error: code = [Internal] message = [unexpected end of JSON input]]"
ListFailAtJSONUnmarshalling string = "machine codes error: code = [Internal] message = [List machines failed on decodeProviderSpecAndValidateSecret: machine codes error: code = [Internal] message = [unexpected end of JSON input]]"
// FailAtNoSecretsPassed is the error message returned when no secrets are passed to the the plugin by the caller
FailAtNoSecretsPassed string = "machine codes error: code = [Internal] message = [Create machine \"dummy-machine\" failed on decodeProviderSpecAndSecret: machine codes error: code = [Internal] message = [Error while validating ProviderSpec [secret serviceAccountJSON or serviceaccount.json is required field secret userData is required field]]]"
FailAtNoSecretsPassed string = "machine codes error: code = [Internal] message = [Create machine \"dummy-machine\" failed on decodeProviderSpecAndValidateSecret: machine codes error: code = [Internal] message = [error while validating Secret [secret serviceAccountJSON or serviceaccount.json is required field secret userData is required field]]]"
// FailAtSecretsWithNoUserData is the error message returned when secrets map has no userdata provided by the caller
FailAtSecretsWithNoUserData string = "machine codes error: code = [Internal] message = [Create machine \"dummy-machine\" failed on decodeProviderSpecAndSecret: machine codes error: code = [Internal] message = [Error while validating ProviderSpec [secret userData is required field]]]"
FailAtSecretsWithNoUserData string = "machine codes error: code = [Internal] message = [Create machine \"dummy-machine\" failed on decodeProviderSpecAndValidateSecret: machine codes error: code = [Internal] message = [error while validating Secret [secret userData is required field]]]"
// FailAtInvalidProjectID is the error returned when an invalid project id value is provided by the caller
FailAtInvalidProjectID string = "machine codes error: code = [Internal] message = [Create machine \"dummy-machine\" failed: json: cannot unmarshal number into Go struct field .project_id of type string]"
// FailAtInvalidZonePostCall is the error returned when a post call should fail with an invalid zone is sent in the POST call -- this is used to simulate server error
Expand All @@ -54,13 +50,13 @@ const (
// FailAtMethodNotImplemented is the error returned for methods which are not yet implemented
FailAtMethodNotImplemented string = "rpc error: code = Unimplemented desc = "
// FailAtSpecValidation fails at spec validation
FailAtSpecValidation string = "machine codes error: code = [Internal] message = [Create machine \"dummy-machine\" failed on decodeProviderSpecAndSecret: machine codes error: code = [Internal] message = [Error while validating ProviderSpec [spec.zone: Required value: zone is required]]]"
FailAtSpecValidation string = "machine codes error: code = [Internal] message = [Create machine \"dummy-machine\" failed on validateProviderSpec: machine codes error: code = [Internal] message = [error while validating ProviderSpec [spec.zone: Required value: zone is required]]]"
// FailAtNonExistingMachine because existing machine is not found
FailAtNonExistingMachine string = "rpc error: code = NotFound desc = Machine with the name \"non-existent-dummy-machine\" not found"
// FailAtSpecValidationNoKmsKeyName if kmsKeyName missing
FailAtSpecValidationNoKmsKeyName string = "machine codes error: code = [Internal] message = [Create machine \"dummy-machine\" failed on decodeProviderSpecAndSecret: machine codes error: code = [Internal] message = [Error while validating ProviderSpec [spec.disks[0].kmsKeyName: Required value: kmsKeyName is required to be specified]]]"
FailAtSpecValidationNoKmsKeyName string = "machine codes error: code = [Internal] message = [Create machine \"dummy-machine\" failed on validateProviderSpec: machine codes error: code = [Internal] message = [error while validating ProviderSpec [spec.disks[0].kmsKeyName: Required value: kmsKeyName is required to be specified]]]"
// FailAtSpecValidationInvalidKmsServiceAccount if kmsKeyServiceAccount invalid
FailAtSpecValidationInvalidKmsServiceAccount string = "machine codes error: code = [Internal] message = [Create machine \"dummy-machine\" failed on decodeProviderSpecAndSecret: machine codes error: code = [Internal] message = [Error while validating ProviderSpec [spec.disks[0].kmsKeyServiceAccount: Required value: kmsKeyServiceAccount should either be explicitly specified without spaces or left un-specified to default to the Compute Service Agent]]]"
FailAtSpecValidationInvalidKmsServiceAccount string = "machine codes error: code = [Internal] message = [Create machine \"dummy-machine\" failed on validateProviderSpec: machine codes error: code = [Internal] message = [error while validating ProviderSpec [spec.disks[0].kmsKeyServiceAccount: Required value: kmsKeyServiceAccount should either be explicitly specified without spaces or left un-specified to default to the Compute Service Agent]]]"

UnsupportedProviderError string = "machine codes error: code = [InvalidArgument] message = [Requested for Provider 'aws', we only support 'GCP']"
)
Expand Down
54 changes: 25 additions & 29 deletions pkg/gcp/machine_controller_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,8 @@ import (
)

const (
// GCPProviderPrefix is the prefix used by the GCP provider
GCPProviderPrefix = "gce://"

// GCPMachineClassKind for GCPMachineClass
GCPMachineClassKind = "GCPMachineClass"

// MachineClassKind for MachineClass
MachineClassKind = "MachineClass"
// ProviderPrefix is the prefix used by the GCP provider
ProviderPrefix = "gce://"
)

// CreateMachineUtil method is used to create a GCP machine
Expand Down Expand Up @@ -182,21 +176,7 @@ func encodeMachineID(project, zone, name string) string {
if name == "" {
return ""
}
return fmt.Sprintf("%s/%s/%s/%s", GCPProviderPrefix, project, zone, name)
}

func decodeMachineID(id string) (string, string, string, error) {
gceSplit := strings.Split(id, "gce:///")
if len(gceSplit) != 2 {
return "", "", "", fmt.Errorf("Invalid format of machine id: %s", id)
}

gce := strings.Split(gceSplit[1], "/")
if len(gce) != 3 {
return "", "", "", fmt.Errorf("Invalid format of machine id: %s", id)
}

return gce[0], gce[1], gce[2], nil
return fmt.Sprintf("%s/%s/%s/%s", ProviderPrefix, project, zone, name)
}

// DeleteMachineUtil deletes a VM by name
Expand Down Expand Up @@ -329,8 +309,8 @@ func getVMs(ctx context.Context, machineID string, providerSpec *api.GCPProvider
return listOfVMs, nil
}

// decodeProviderSpecAndSecret converts request parameters to api.ProviderSpec
func decodeProviderSpecAndSecret(machineClass *v1alpha1.MachineClass, secret *corev1.Secret) (*api.GCPProviderSpec, error) {
// decodeProviderSpecAndValidateSecret converts request parameters to api.ProviderSpec
func decodeProviderSpecAndValidateSecret(machineClass *v1alpha1.MachineClass, secret *corev1.Secret) (*api.GCPProviderSpec, error) {
var providerSpec *api.GCPProviderSpec

// If machineClass is nil
Expand All @@ -344,16 +324,32 @@ func decodeProviderSpecAndSecret(machineClass *v1alpha1.MachineClass, secret *co
return nil, status.Error(codes.Internal, err.Error())
}

// Validate the Spec and Secrets
ValidationErr := validation.ValidateGCPProviderSpec(providerSpec, secret)
if ValidationErr != nil {
err = fmt.Errorf("Error while validating ProviderSpec %v", ValidationErr)
// Validate the Secret
validationErr := validation.ValidateSecret(secret)
if validationErr != nil {
err = fmt.Errorf("error while validating Secret %v", validationErr)
return nil, status.Error(codes.Internal, err.Error())
}

return providerSpec, nil
}

func validateProviderSpec(providerSpec *api.GCPProviderSpec) error {
if validationErr := validation.ValidateProviderSpec(providerSpec); validationErr != nil {
err := fmt.Errorf("error while validating ProviderSpec %v", validationErr)
return status.Error(codes.Internal, err.Error())
}
return nil
}

func validateZone(zone string) error {
if err := validation.ValidateZone(zone); err != nil {
err = fmt.Errorf("error while validating Zone %v", err)
return status.Error(codes.InvalidArgument, err.Error())
}
return nil
}

func prepareErrorf(err error, format string, args ...interface{}) error {
var (
code codes.Code
Expand Down
61 changes: 30 additions & 31 deletions pkg/gcp/validation/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,39 @@ const (
DiskInterfaceSCSI = "SCSI"
)

// ValidateGCPProviderSpec validates gcp provider spec
func ValidateGCPProviderSpec(spec *api.GCPProviderSpec, secrets *corev1.Secret) []error {
allErrs := validateGCPMachineClassSpec(spec, field.NewPath("spec"))
allErrs = append(allErrs, validateSecrets(secrets)...)
// ValidateProviderSpec validates gcp provider spec
func ValidateProviderSpec(spec *api.GCPProviderSpec) []error {
fldPath := field.NewPath("spec")
var allErrs []error

allErrs = append(allErrs, validateGCPDisks(spec.Disks, fldPath.Child("disks"))...)

if "" == spec.MachineType {
allErrs = append(allErrs, field.Required(fldPath.Child("machineType"), "machineType is required"))
}
if "" == spec.Region {
allErrs = append(allErrs, field.Required(fldPath.Child("region"), "region is required"))
}
if "" == spec.Zone {
allErrs = append(allErrs, field.Required(fldPath.Child("zone"), "zone is required"))
}

allErrs = append(allErrs, validateGCPNetworkInterfaces(spec.NetworkInterfaces, fldPath.Child("networkInterfaces"))...)
allErrs = append(allErrs, validateGCPMetadata(spec.Metadata, fldPath.Child("networkInterfaces"))...)
allErrs = append(allErrs, validateGCPGpu(spec.Gpu, fldPath.Child("gpu"))...)
allErrs = append(allErrs, validateGCPScheduling(spec.Scheduling, spec.Gpu, fldPath.Child("scheduling"))...)

return allErrs
}

func validateSecrets(secret *corev1.Secret) []error {
func ValidateZone(zone string) error {
if zone == "" {
return fmt.Errorf("zone cannot be empty")
}
return nil
}

func ValidateSecret(secret *corev1.Secret) []error {
var allErrs []error

if secret == nil {
Expand All @@ -51,29 +76,6 @@ func validateSecrets(secret *corev1.Secret) []error {
return allErrs
}

func validateGCPMachineClassSpec(spec *api.GCPProviderSpec, fldPath *field.Path) []error {
var allErrs []error

allErrs = append(allErrs, validateGCPDisks(spec.Disks, fldPath.Child("disks"))...)

if "" == spec.MachineType {
allErrs = append(allErrs, field.Required(fldPath.Child("machineType"), "machineType is required"))
}
if "" == spec.Region {
allErrs = append(allErrs, field.Required(fldPath.Child("region"), "region is required"))
}
if "" == spec.Zone {
allErrs = append(allErrs, field.Required(fldPath.Child("zone"), "zone is required"))
}

allErrs = append(allErrs, validateGCPNetworkInterfaces(spec.NetworkInterfaces, fldPath.Child("networkInterfaces"))...)
allErrs = append(allErrs, validateGCPMetadata(spec.Metadata, fldPath.Child("networkInterfaces"))...)
allErrs = append(allErrs, validateGCPGpu(spec.Gpu, fldPath.Child("gpu"))...)
allErrs = append(allErrs, validateGCPScheduling(spec.Scheduling, spec.Gpu, fldPath.Child("scheduling"))...)

return allErrs
}

func validateGCPDisks(disks []*api.GCPDisk, fldPath *field.Path) []error {
var allErrs []error

Expand All @@ -83,9 +85,6 @@ func validateGCPDisks(disks []*api.GCPDisk, fldPath *field.Path) []error {

for i, disk := range disks {
idxPath := fldPath.Index(i)
if disk.SizeGb < 20 {
allErrs = append(allErrs, field.Invalid(idxPath.Child("sizeGb"), disk.SizeGb, "disk size must be at least 20 GB"))
}
if disk.Type == DiskTypeScratch && (disk.Interface != DiskInterfaceNVME && disk.Interface != DiskInterfaceSCSI) {
allErrs = append(allErrs, field.NotSupported(idxPath.Child("interface"), disk.Interface, []string{DiskInterfaceNVME, DiskInterfaceSCSI}))
}
Expand Down

0 comments on commit c97784b

Please sign in to comment.