Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Add auth_config mock #528

Merged
merged 2 commits into from
Apr 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 9 additions & 104 deletions pkg/auth/auth.go → pkg/auth/auth_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,17 @@ Licensed under the Apache 2.0 license.
package auth

import (
"bytes"
"context"
"encoding/binary"
"encoding/json"
"fmt"
"io/ioutil"
"os"
"strings"
"unicode/utf16"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
_ "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/adal"
"github.com/dimchansky/utfbom"
"github.com/pkg/errors"
"github.com/virtual-kubelet/virtual-kubelet/log"
"github.com/virtual-kubelet/virtual-kubelet/trace"
Expand All @@ -35,6 +29,12 @@ const (
AzureChinaCloud CloudEnvironmentName = "AzureChinaCloud"
)

type ConfigInterface interface {
GetMSICredential(ctx context.Context) (*azidentity.ManagedIdentityCredential, error)
GetSPCredential(ctx context.Context) (*azidentity.ClientSecretCredential, error)
GetAuthorizer(ctx context.Context, resource string) (autorest.Authorizer, error)
}

type Config struct {
AKSCredential *aksCredential
AuthConfig *Authentication
Expand Down Expand Up @@ -74,8 +74,8 @@ func (c *Config) GetSPCredential(ctx context.Context) (*azidentity.ClientSecretC
return spCredential, nil
}

// getAuthorizer return autorest authorizer.
func (c *Config) getAuthorizer(ctx context.Context, resource string) (autorest.Authorizer, error) {
// GetAuthorizer return autorest authorizer.
func (c *Config) GetAuthorizer(ctx context.Context, resource string) (autorest.Authorizer, error) {
var auth autorest.Authorizer
var err error

Expand Down Expand Up @@ -186,109 +186,14 @@ func (c *Config) SetAuthConfig(ctx context.Context) error {

resource := c.Cloud.Services[cloud.ResourceManager].Endpoint

c.Authorizer, err = c.getAuthorizer(ctx, resource)
c.Authorizer, err = c.GetAuthorizer(ctx, resource)
if err != nil {
return err
}

return nil
}

// Authentication represents the Authentication file for Azure.
type Authentication struct {
ClientID string `json:"clientId,omitempty"`
ClientSecret string `json:"clientSecret,omitempty"`
SubscriptionID string `json:"subscriptionId,omitempty"`
TenantID string `json:"tenantId,omitempty"`
UserIdentityClientId string `json:"userIdentityClientId,omitempty"`
}

// newAuthenticationFromFile returns an Authentication struct from file path.
func (a *Authentication) newAuthenticationFromFile(filepath string) error {
b, err := ioutil.ReadFile(filepath)
if err != nil {
return fmt.Errorf("reading Authentication file %q failed: %v", filepath, err)
}

// Authentication file might be encoded.
decoded, err := a.decode(b)
if err != nil {
return fmt.Errorf("decoding Authentication file %q failed: %v", filepath, err)
}

// Unmarshal the Authentication file.
if err := json.Unmarshal(decoded, &a); err != nil {
return err
}
return nil
}

// NewAuthentication returns an Authentication struct from user provided credentials.
func NewAuthentication(clientID, clientSecret, subscriptionID, tenantID, userAssignedIdentityID string) *Authentication {
return &Authentication{
ClientID: clientID,
ClientSecret: clientSecret,
SubscriptionID: subscriptionID,
TenantID: tenantID,
UserIdentityClientId: userAssignedIdentityID,
}
}

// aksCredential represents the credential file for AKS
type aksCredential struct {
Cloud string `json:"cloud"`
TenantID string `json:"tenantId"`
SubscriptionID string `json:"subscriptionId"`
ClientID string `json:"aadClientId"`
ClientSecret string `json:"aadClientSecret"`
ResourceGroup string `json:"resourceGroup"`
Region string `json:"location"`
VNetName string `json:"vnetName"`
VNetResourceGroup string `json:"vnetResourceGroup"`
UserAssignedIdentityID string `json:"userAssignedIdentityID"`
}

// newAKSCredential returns an aksCredential struct from file path.
func newAKSCredential(ctx context.Context, filePath string) (*aksCredential, error) {
logger := log.G(ctx).WithField("method", "newAKSCredential").WithField("file", filePath)
logger.Debug("Reading AKS credential file")

b, err := ioutil.ReadFile(filePath)
if err != nil {
return nil, fmt.Errorf("reading AKS credential file %q failed: %v", filePath, err)
}

// Unmarshal the Authentication file.
var cred aksCredential
if err := json.Unmarshal(b, &cred); err != nil {
return nil, err
}
logger.Debug("load AKS credential file successfully")
return &cred, nil
}

func (a *Authentication) decode(b []byte) ([]byte, error) {
reader, enc := utfbom.Skip(bytes.NewReader(b))

switch enc {
case utfbom.UTF16LittleEndian:
u16 := make([]uint16, (len(b)/2)-1)
err := binary.Read(reader, binary.LittleEndian, &u16)
if err != nil {
return nil, err
}
return []byte(string(utf16.Decode(u16))), nil
case utfbom.UTF16BigEndian:
u16 := make([]uint16, (len(b)/2)-1)
err := binary.Read(reader, binary.BigEndian, &u16)
if err != nil {
return nil, err
}
return []byte(string(utf16.Decode(u16))), nil
}
return ioutil.ReadAll(reader)
}

func getCloudConfiguration(cloudName string) cloud.Configuration {
switch cloudName {
case string(AzurePublicCloud):
Expand Down
11 changes: 5 additions & 6 deletions pkg/auth/auth_test.go → pkg/auth/auth_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package auth

import (
"context"
"io/ioutil"
"os"
"testing"

Expand All @@ -25,7 +24,7 @@ const cred = `
}`

func TestAKSCred(t *testing.T) {
file, err := ioutil.TempFile("", "aks_auth_test")
file, err := os.CreateTemp("", "aks_auth_test")
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -76,7 +75,7 @@ func TestAKSCred(t *testing.T) {
}

func TestAKSCredFileNotFound(t *testing.T) {
file, err := ioutil.TempFile("", "AKS_test")
file, err := os.CreateTemp("", "AKS_test")
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -104,7 +103,7 @@ const credBad = `
"resourceGroup": "vk-test-rg",`

func TestAKSCredBadJson(t *testing.T) {
file, err := ioutil.TempFile("", "aks_auth_test")
file, err := os.CreateTemp("", "aks_auth_test")
if err != nil {
t.Error(err)
}
Expand All @@ -129,7 +128,7 @@ func TestSetAuthConfigWithAuthFile(t *testing.T) {
"tenantId": "######-###-####-####-######"

}`
file, err := ioutil.TempFile("", "aks_auth_test")
file, err := os.CreateTemp("", "aks_auth_test")
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -176,7 +175,7 @@ func TestSetAuthConfigWithAKSCredFile(t *testing.T) {
"vnetResourceGroup": "vk-aci-test-12917",
"userAssignedIdentityID": "######-tuhn-41af-re3e0-######"
}`
file, err := ioutil.TempFile("", "aks_auth_test")
file, err := os.CreateTemp("", "aks_auth_test")
if err != nil {
t.Error(err)
}
Expand Down
114 changes: 114 additions & 0 deletions pkg/auth/authentication.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
Copyright (c) Microsoft Corporation.
Licensed under the Apache 2.0 license.
*/
package auth

import (
"bytes"
"context"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"os"
"unicode/utf16"

"github.com/dimchansky/utfbom"
"github.com/virtual-kubelet/virtual-kubelet/log"
)

// Authentication represents the Authentication file for Azure.
type Authentication struct {
ClientID string `json:"clientId,omitempty"`
ClientSecret string `json:"clientSecret,omitempty"`
SubscriptionID string `json:"subscriptionId,omitempty"`
TenantID string `json:"tenantId,omitempty"`
UserIdentityClientId string `json:"userIdentityClientId,omitempty"`
}

// NewAuthentication returns an Authentication struct from user provided credentials.
func NewAuthentication(clientID, clientSecret, subscriptionID, tenantID, userAssignedIdentityID string) *Authentication {
return &Authentication{
ClientID: clientID,
ClientSecret: clientSecret,
SubscriptionID: subscriptionID,
TenantID: tenantID,
UserIdentityClientId: userAssignedIdentityID,
}
}

// newAuthenticationFromFile returns an Authentication struct from file path.
func (a *Authentication) newAuthenticationFromFile(filepath string) error {
b, err := os.ReadFile(filepath)
if err != nil {
return fmt.Errorf("reading Authentication file %q failed: %v", filepath, err)
}

// Authentication file might be encoded.
decoded, err := a.decode(b)
if err != nil {
return fmt.Errorf("decoding Authentication file %q failed: %v", filepath, err)
}

// Unmarshal the Authentication file.
if err := json.Unmarshal(decoded, &a); err != nil {
return err
}
return nil
}

// aksCredential represents the credential file for AKS
type aksCredential struct {
Cloud string `json:"cloud"`
TenantID string `json:"tenantId"`
SubscriptionID string `json:"subscriptionId"`
ClientID string `json:"aadClientId"`
ClientSecret string `json:"aadClientSecret"`
ResourceGroup string `json:"resourceGroup"`
Region string `json:"location"`
VNetName string `json:"vnetName"`
VNetResourceGroup string `json:"vnetResourceGroup"`
UserAssignedIdentityID string `json:"userAssignedIdentityID"`
}

// newAKSCredential returns an aksCredential struct from file path.
func newAKSCredential(ctx context.Context, filePath string) (*aksCredential, error) {
logger := log.G(ctx).WithField("method", "newAKSCredential").WithField("file", filePath)
logger.Debug("Reading AKS credential file")

b, err := os.ReadFile(filePath)
if err != nil {
return nil, fmt.Errorf("reading AKS credential file %q failed: %v", filePath, err)
}

// Unmarshal the Authentication file.
var cred aksCredential
if err := json.Unmarshal(b, &cred); err != nil {
return nil, err
}
logger.Debug("load AKS credential file successfully")
return &cred, nil
}

func (a *Authentication) decode(b []byte) ([]byte, error) {
reader, enc := utfbom.Skip(bytes.NewReader(b))

switch enc {
case utfbom.UTF16LittleEndian:
u16 := make([]uint16, (len(b)/2)-1)
err := binary.Read(reader, binary.LittleEndian, &u16)
if err != nil {
return nil, err
}
return []byte(string(utf16.Decode(u16))), nil
case utfbom.UTF16BigEndian:
u16 := make([]uint16, (len(b)/2)-1)
err := binary.Read(reader, binary.BigEndian, &u16)
if err != nil {
return nil, err
}
return []byte(string(utf16.Decode(u16))), nil
}
return io.ReadAll(reader)
}
43 changes: 43 additions & 0 deletions pkg/auth/mock_auth_config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
Copyright (c) Microsoft Corporation.
Licensed under the Apache 2.0 license.
*/
package auth

import (
"context"

"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/go-autorest/autorest"
)

type GetMSICredentialFunc func(ctx context.Context) (*azidentity.ManagedIdentityCredential, error)
type GetSPCredentialFunc func(ctx context.Context) (*azidentity.ClientSecretCredential, error)
type GetAuthorizerFunc func(ctx context.Context, resource string) (autorest.Authorizer, error)

type MockConfig struct {
MockGetMSICredential GetMSICredentialFunc
MockGetSPCredential GetSPCredentialFunc
MockGetAuthorizer GetAuthorizerFunc
}

func (m *MockConfig) GetMSICredential(ctx context.Context) (*azidentity.ManagedIdentityCredential, error) {
if m.MockGetMSICredential != nil {
return m.MockGetMSICredential(ctx)
}
return nil, nil
}

func (m *MockConfig) GetSPCredential(ctx context.Context) (*azidentity.ClientSecretCredential, error) {
if m.MockGetSPCredential != nil {
return m.MockGetSPCredential(ctx)
}
return nil, nil
}

func (m *MockConfig) GetAuthorizer(ctx context.Context, resource string) (autorest.Authorizer, error) {
if m.MockGetAuthorizer != nil {
return m.MockGetAuthorizer(ctx, resource)
}
return nil, nil
}