diff --git a/pkg/azurefile/azurefile.go b/pkg/azurefile/azurefile.go index d7bf0de057..4868cbbb24 100644 --- a/pkg/azurefile/azurefile.go +++ b/pkg/azurefile/azurefile.go @@ -145,6 +145,7 @@ const ( shareNamePrefixField = "sharenameprefix" requireInfraEncryptionField = "requireinfraencryption" enableMultichannelField = "enablemultichannel" + standard = "standard" premium = "premium" selectRandomMatchingAccountField = "selectrandommatchingaccount" accountQuotaField = "accountquota" diff --git a/pkg/azurefile/controllerserver.go b/pkg/azurefile/controllerserver.go index 08b5a6f097..47a4c026ef 100644 --- a/pkg/azurefile/controllerserver.go +++ b/pkg/azurefile/controllerserver.go @@ -338,12 +338,15 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) } var vnetResourceIDs []string if fsType == nfs || protocol == nfs { - protocol = nfs - enableHTTPSTrafficOnly = false - if !strings.HasPrefix(strings.ToLower(sku), premium) { + if sku == "" { // NFS protocol only supports Premium storage sku = string(storage.SkuNamePremiumLRS) + } else if strings.HasPrefix(strings.ToLower(sku), standard) { + return nil, status.Errorf(codes.InvalidArgument, "nfs protocol only supports premium storage, current account type: %s", sku) } + + protocol = nfs + enableHTTPSTrafficOnly = false shareProtocol = storage.EnabledProtocolsNFS // NFS protocol does not need account key storeAccountKey = false diff --git a/pkg/azurefile/controllerserver_test.go b/pkg/azurefile/controllerserver_test.go index 340777c5ce..5f94c8e1cc 100644 --- a/pkg/azurefile/controllerserver_test.go +++ b/pkg/azurefile/controllerserver_test.go @@ -265,6 +265,29 @@ func TestCreateVolume(t *testing.T) { } }, }, + { + name: "nfs protocol only supports premium storage", + testFunc: func(t *testing.T) { + allParam := map[string]string{ + protocolField: "nfs", + skuNameField: "Standard_LRS", + } + + req := &csi.CreateVolumeRequest{ + Name: "random-vol-name-nfs-protocol-standard-sku", + CapacityRange: stdCapRange, + VolumeCapabilities: stdVolCap, + Parameters: allParam, + } + + d := NewFakeDriver() + expectedErr := status.Errorf(codes.InvalidArgument, "nfs protocol only supports premium storage, current account type: Standard_LRS") + _, err := d.CreateVolume(ctx, req) + if !reflect.DeepEqual(err, expectedErr) { + t.Errorf("Unexpected error: %v", err) + } + }, + }, { name: "Invalid accessTier", testFunc: func(t *testing.T) {