Skip to content

Commit

Permalink
refactor: Add auth_config mock (#528)
Browse files Browse the repository at this point in the history
Signed-off-by: Heba Elayoty <hebaelayoty@gmail.com>
  • Loading branch information
helayoty authored Apr 11, 2023
1 parent 7e523b0 commit 47ee054
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 110 deletions.
113 changes: 9 additions & 104 deletions pkg/auth/auth.go → pkg/auth/auth_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,17 @@ Licensed under the Apache 2.0 license.
package auth

import (
"bytes"
"context"
"encoding/binary"
"encoding/json"
"fmt"
"io/ioutil"
"os"
"strings"
"unicode/utf16"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
_ "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/adal"
"github.com/dimchansky/utfbom"
"github.com/pkg/errors"
"github.com/virtual-kubelet/virtual-kubelet/log"
"github.com/virtual-kubelet/virtual-kubelet/trace"
Expand All @@ -35,6 +29,12 @@ const (
AzureChinaCloud CloudEnvironmentName = "AzureChinaCloud"
)

type ConfigInterface interface {
GetMSICredential(ctx context.Context) (*azidentity.ManagedIdentityCredential, error)
GetSPCredential(ctx context.Context) (*azidentity.ClientSecretCredential, error)
GetAuthorizer(ctx context.Context, resource string) (autorest.Authorizer, error)
}

type Config struct {
AKSCredential *aksCredential
AuthConfig *Authentication
Expand Down Expand Up @@ -74,8 +74,8 @@ func (c *Config) GetSPCredential(ctx context.Context) (*azidentity.ClientSecretC
return spCredential, nil
}

// getAuthorizer return autorest authorizer.
func (c *Config) getAuthorizer(ctx context.Context, resource string) (autorest.Authorizer, error) {
// GetAuthorizer return autorest authorizer.
func (c *Config) GetAuthorizer(ctx context.Context, resource string) (autorest.Authorizer, error) {
var auth autorest.Authorizer
var err error

Expand Down Expand Up @@ -186,109 +186,14 @@ func (c *Config) SetAuthConfig(ctx context.Context) error {

resource := c.Cloud.Services[cloud.ResourceManager].Endpoint

c.Authorizer, err = c.getAuthorizer(ctx, resource)
c.Authorizer, err = c.GetAuthorizer(ctx, resource)
if err != nil {
return err
}

return nil
}

// Authentication represents the Authentication file for Azure.
type Authentication struct {
ClientID string `json:"clientId,omitempty"`
ClientSecret string `json:"clientSecret,omitempty"`
SubscriptionID string `json:"subscriptionId,omitempty"`
TenantID string `json:"tenantId,omitempty"`
UserIdentityClientId string `json:"userIdentityClientId,omitempty"`
}

// newAuthenticationFromFile returns an Authentication struct from file path.
func (a *Authentication) newAuthenticationFromFile(filepath string) error {
b, err := ioutil.ReadFile(filepath)
if err != nil {
return fmt.Errorf("reading Authentication file %q failed: %v", filepath, err)
}

// Authentication file might be encoded.
decoded, err := a.decode(b)
if err != nil {
return fmt.Errorf("decoding Authentication file %q failed: %v", filepath, err)
}

// Unmarshal the Authentication file.
if err := json.Unmarshal(decoded, &a); err != nil {
return err
}
return nil
}

// NewAuthentication returns an Authentication struct from user provided credentials.
func NewAuthentication(clientID, clientSecret, subscriptionID, tenantID, userAssignedIdentityID string) *Authentication {
return &Authentication{
ClientID: clientID,
ClientSecret: clientSecret,
SubscriptionID: subscriptionID,
TenantID: tenantID,
UserIdentityClientId: userAssignedIdentityID,
}
}

// aksCredential represents the credential file for AKS
type aksCredential struct {
Cloud string `json:"cloud"`
TenantID string `json:"tenantId"`
SubscriptionID string `json:"subscriptionId"`
ClientID string `json:"aadClientId"`
ClientSecret string `json:"aadClientSecret"`
ResourceGroup string `json:"resourceGroup"`
Region string `json:"location"`
VNetName string `json:"vnetName"`
VNetResourceGroup string `json:"vnetResourceGroup"`
UserAssignedIdentityID string `json:"userAssignedIdentityID"`
}

// newAKSCredential returns an aksCredential struct from file path.
func newAKSCredential(ctx context.Context, filePath string) (*aksCredential, error) {
logger := log.G(ctx).WithField("method", "newAKSCredential").WithField("file", filePath)
logger.Debug("Reading AKS credential file")

b, err := ioutil.ReadFile(filePath)
if err != nil {
return nil, fmt.Errorf("reading AKS credential file %q failed: %v", filePath, err)
}

// Unmarshal the Authentication file.
var cred aksCredential
if err := json.Unmarshal(b, &cred); err != nil {
return nil, err
}
logger.Debug("load AKS credential file successfully")
return &cred, nil
}

func (a *Authentication) decode(b []byte) ([]byte, error) {
reader, enc := utfbom.Skip(bytes.NewReader(b))

switch enc {
case utfbom.UTF16LittleEndian:
u16 := make([]uint16, (len(b)/2)-1)
err := binary.Read(reader, binary.LittleEndian, &u16)
if err != nil {
return nil, err
}
return []byte(string(utf16.Decode(u16))), nil
case utfbom.UTF16BigEndian:
u16 := make([]uint16, (len(b)/2)-1)
err := binary.Read(reader, binary.BigEndian, &u16)
if err != nil {
return nil, err
}
return []byte(string(utf16.Decode(u16))), nil
}
return ioutil.ReadAll(reader)
}

func getCloudConfiguration(cloudName string) cloud.Configuration {
switch cloudName {
case string(AzurePublicCloud):
Expand Down
11 changes: 5 additions & 6 deletions pkg/auth/auth_test.go → pkg/auth/auth_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package auth

import (
"context"
"io/ioutil"
"os"
"testing"

Expand All @@ -25,7 +24,7 @@ const cred = `
}`

func TestAKSCred(t *testing.T) {
file, err := ioutil.TempFile("", "aks_auth_test")
file, err := os.CreateTemp("", "aks_auth_test")
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -76,7 +75,7 @@ func TestAKSCred(t *testing.T) {
}

func TestAKSCredFileNotFound(t *testing.T) {
file, err := ioutil.TempFile("", "AKS_test")
file, err := os.CreateTemp("", "AKS_test")
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -104,7 +103,7 @@ const credBad = `
"resourceGroup": "vk-test-rg",`

func TestAKSCredBadJson(t *testing.T) {
file, err := ioutil.TempFile("", "aks_auth_test")
file, err := os.CreateTemp("", "aks_auth_test")
if err != nil {
t.Error(err)
}
Expand All @@ -129,7 +128,7 @@ func TestSetAuthConfigWithAuthFile(t *testing.T) {
"tenantId": "######-###-####-####-######"
}`
file, err := ioutil.TempFile("", "aks_auth_test")
file, err := os.CreateTemp("", "aks_auth_test")
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -176,7 +175,7 @@ func TestSetAuthConfigWithAKSCredFile(t *testing.T) {
"vnetResourceGroup": "vk-aci-test-12917",
"userAssignedIdentityID": "######-tuhn-41af-re3e0-######"
}`
file, err := ioutil.TempFile("", "aks_auth_test")
file, err := os.CreateTemp("", "aks_auth_test")
if err != nil {
t.Error(err)
}
Expand Down
114 changes: 114 additions & 0 deletions pkg/auth/authentication.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
Copyright (c) Microsoft Corporation.
Licensed under the Apache 2.0 license.
*/
package auth

import (
"bytes"
"context"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"os"
"unicode/utf16"

"github.com/dimchansky/utfbom"
"github.com/virtual-kubelet/virtual-kubelet/log"
)

// Authentication represents the Authentication file for Azure.
type Authentication struct {
ClientID string `json:"clientId,omitempty"`
ClientSecret string `json:"clientSecret,omitempty"`
SubscriptionID string `json:"subscriptionId,omitempty"`
TenantID string `json:"tenantId,omitempty"`
UserIdentityClientId string `json:"userIdentityClientId,omitempty"`
}

// NewAuthentication returns an Authentication struct from user provided credentials.
func NewAuthentication(clientID, clientSecret, subscriptionID, tenantID, userAssignedIdentityID string) *Authentication {
return &Authentication{
ClientID: clientID,
ClientSecret: clientSecret,
SubscriptionID: subscriptionID,
TenantID: tenantID,
UserIdentityClientId: userAssignedIdentityID,
}
}

// newAuthenticationFromFile returns an Authentication struct from file path.
func (a *Authentication) newAuthenticationFromFile(filepath string) error {
b, err := os.ReadFile(filepath)
if err != nil {
return fmt.Errorf("reading Authentication file %q failed: %v", filepath, err)
}

// Authentication file might be encoded.
decoded, err := a.decode(b)
if err != nil {
return fmt.Errorf("decoding Authentication file %q failed: %v", filepath, err)
}

// Unmarshal the Authentication file.
if err := json.Unmarshal(decoded, &a); err != nil {
return err
}
return nil
}

// aksCredential represents the credential file for AKS
type aksCredential struct {
Cloud string `json:"cloud"`
TenantID string `json:"tenantId"`
SubscriptionID string `json:"subscriptionId"`
ClientID string `json:"aadClientId"`
ClientSecret string `json:"aadClientSecret"`
ResourceGroup string `json:"resourceGroup"`
Region string `json:"location"`
VNetName string `json:"vnetName"`
VNetResourceGroup string `json:"vnetResourceGroup"`
UserAssignedIdentityID string `json:"userAssignedIdentityID"`
}

// newAKSCredential returns an aksCredential struct from file path.
func newAKSCredential(ctx context.Context, filePath string) (*aksCredential, error) {
logger := log.G(ctx).WithField("method", "newAKSCredential").WithField("file", filePath)
logger.Debug("Reading AKS credential file")

b, err := os.ReadFile(filePath)
if err != nil {
return nil, fmt.Errorf("reading AKS credential file %q failed: %v", filePath, err)
}

// Unmarshal the Authentication file.
var cred aksCredential
if err := json.Unmarshal(b, &cred); err != nil {
return nil, err
}
logger.Debug("load AKS credential file successfully")
return &cred, nil
}

func (a *Authentication) decode(b []byte) ([]byte, error) {
reader, enc := utfbom.Skip(bytes.NewReader(b))

switch enc {
case utfbom.UTF16LittleEndian:
u16 := make([]uint16, (len(b)/2)-1)
err := binary.Read(reader, binary.LittleEndian, &u16)
if err != nil {
return nil, err
}
return []byte(string(utf16.Decode(u16))), nil
case utfbom.UTF16BigEndian:
u16 := make([]uint16, (len(b)/2)-1)
err := binary.Read(reader, binary.BigEndian, &u16)
if err != nil {
return nil, err
}
return []byte(string(utf16.Decode(u16))), nil
}
return io.ReadAll(reader)
}
43 changes: 43 additions & 0 deletions pkg/auth/mock_auth_config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
Copyright (c) Microsoft Corporation.
Licensed under the Apache 2.0 license.
*/
package auth

import (
"context"

"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/go-autorest/autorest"
)

type GetMSICredentialFunc func(ctx context.Context) (*azidentity.ManagedIdentityCredential, error)
type GetSPCredentialFunc func(ctx context.Context) (*azidentity.ClientSecretCredential, error)
type GetAuthorizerFunc func(ctx context.Context, resource string) (autorest.Authorizer, error)

type MockConfig struct {
MockGetMSICredential GetMSICredentialFunc
MockGetSPCredential GetSPCredentialFunc
MockGetAuthorizer GetAuthorizerFunc
}

func (m *MockConfig) GetMSICredential(ctx context.Context) (*azidentity.ManagedIdentityCredential, error) {
if m.MockGetMSICredential != nil {
return m.MockGetMSICredential(ctx)
}
return nil, nil
}

func (m *MockConfig) GetSPCredential(ctx context.Context) (*azidentity.ClientSecretCredential, error) {
if m.MockGetSPCredential != nil {
return m.MockGetSPCredential(ctx)
}
return nil, nil
}

func (m *MockConfig) GetAuthorizer(ctx context.Context, resource string) (autorest.Authorizer, error) {
if m.MockGetAuthorizer != nil {
return m.MockGetAuthorizer(ctx, resource)
}
return nil, nil
}

0 comments on commit 47ee054

Please sign in to comment.