Skip to content

Commit

Permalink
Implement credential for client auth
Browse files Browse the repository at this point in the history
Signed-off-by: Heba Elayoty <hebaelayoty@gmail.com>
  • Loading branch information
helayoty committed Jan 20, 2023
1 parent 85f6251 commit 9042a86
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 20 deletions.
35 changes: 34 additions & 1 deletion pkg/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ import (
"strings"
"unicode/utf16"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
_ "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/adal"
"github.com/dimchansky/utfbom"
Expand All @@ -36,6 +38,38 @@ type Config struct {
Authorizer autorest.Authorizer
}

// GetMSICredential retrieve MSI credential
func (c *Config) GetMSICredential(ctx context.Context) (*azidentity.ManagedIdentityCredential, error) {
log.G(ctx).Debug("getting token using user identity")
opts := &azidentity.ManagedIdentityCredentialOptions{
ID: azidentity.ClientID(c.AuthConfig.UserIdentityClientId),
ClientOptions: azcore.ClientOptions{
Cloud: c.Cloud,
}}
msiCredential, err := azidentity.NewManagedIdentityCredential(opts)
if err != nil {
return nil, err
}

return msiCredential, nil
}

// GetSPCredential retrieve SP credential
func (c *Config) GetSPCredential(ctx context.Context) (*azidentity.ClientSecretCredential, error) {
log.G(ctx).Debug("getting token using service principal")
opts := &azidentity.ClientSecretCredentialOptions{
ClientOptions: azcore.ClientOptions{
Cloud: c.Cloud,
},
}
spCredential, err := azidentity.NewClientSecretCredential(c.AuthConfig.TenantID, c.AuthConfig.ClientID, c.AuthConfig.ClientSecret, opts)
if err != nil {
return nil, err
}

return spCredential, nil
}

// getAuthorizer return autorest authorizer.
func (c *Config) getAuthorizer(ctx context.Context, resource string) (autorest.Authorizer, error) {
var auth autorest.Authorizer
Expand Down Expand Up @@ -66,7 +100,6 @@ func (c *Config) getAuthorizer(ctx context.Context, resource string) (autorest.A
return nil, err
}
}

auth = autorest.NewBearerAuthorizer(token)
return auth, err
}
Expand Down
54 changes: 36 additions & 18 deletions pkg/client/client_apis.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
azaciv2 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerinstance/armcontainerinstance/v2"
"github.com/pkg/errors"
"github.com/virtual-kubelet/azure-aci/pkg/auth"
Expand Down Expand Up @@ -39,32 +38,50 @@ type AzClientsAPIs struct {
}

func NewAzClientsAPIs(ctx context.Context, azConfig auth.Config) (*AzClientsAPIs, error) {
logger := log.G(ctx).WithField("method", "NewAzClientsAPIs")
ctx, span := trace.StartSpan(ctx, "client.NewAzClientsAPIs")
defer span.End()

obj := AzClientsAPIs{}
cred, err := azidentity.NewDefaultAzureCredential(nil)

logger.Debug("getting azure credential")

var err error
var credential azcore.TokenCredential
isUserIdentity := len(azConfig.AuthConfig.ClientID) == 0

if isUserIdentity {
credential, err = azConfig.GetMSICredential(ctx)
} else {
credential, err = azConfig.GetSPCredential(ctx)
}
if err != nil {
return nil, errors.Wrap(err, "an error has occurred while creating getting credential ")
}
ua := os.Getenv("ACI_EXTRA_USER_AGENT")

logger.Debug("setting aci user agent")
userAgent := os.Getenv("ACI_EXTRA_USER_AGENT")
options := arm.ClientOptions{
ClientOptions: azcore.ClientOptions{
Cloud: azConfig.Cloud,
Telemetry: policy.TelemetryOptions{
ApplicationID: ua,
ApplicationID: userAgent,
},
},
}

cClient, err := azaciv2.NewContainersClient(azConfig.AuthConfig.SubscriptionID, cred, &options)
logger.Debug("initializing aci clients")
cClient, err := azaciv2.NewContainersClient(azConfig.AuthConfig.SubscriptionID, credential, &options)
if err != nil {
return nil, errors.Wrap(err, "failed to create container client ")
}

cgClient, err := azaciv2.NewContainerGroupsClient(azConfig.AuthConfig.SubscriptionID, cred, &options)
cgClient, err := azaciv2.NewContainerGroupsClient(azConfig.AuthConfig.SubscriptionID, credential, &options)
if err != nil {
return nil, errors.Wrap(err, "failed to create container group client ")
}

lClient, err := azaciv2.NewLocationClient(azConfig.AuthConfig.SubscriptionID, cred, &options)
lClient, err := azaciv2.NewLocationClient(azConfig.AuthConfig.SubscriptionID, credential, &options)
if err != nil {
return nil, errors.Wrap(err, "failed to create location client ")
}
Expand All @@ -73,12 +90,13 @@ func NewAzClientsAPIs(ctx context.Context, azConfig auth.Config) (*AzClientsAPIs
obj.ContainerGroupClient = cgClient
obj.LocationClient = lClient

logger.Debug("aci clients have been initialized successfully")
return &obj, nil
}

func (a *AzClientsAPIs) GetContainerGroup(ctx context.Context, resourceGroup, containerGroupName string) (*azaciv2.ContainerGroup, error) {
_ = log.G(ctx).WithField("method", "GetContainerGroup")
ctx, span := trace.StartSpan(ctx, "aci.GetContainerGroup")
logger := log.G(ctx).WithField("method", "GetContainerGroup")
ctx, span := trace.StartSpan(ctx, "client.GetContainerGroup")
defer span.End()

var rawResponse *http.Response
Expand All @@ -87,8 +105,8 @@ func (a *AzClientsAPIs) GetContainerGroup(ctx context.Context, resourceGroup, co
result, err := a.ContainerGroupClient.Get(ctxWithResp, resourceGroup, containerGroupName, nil)
if err != nil {
if rawResponse.StatusCode == http.StatusNotFound {
return nil, errors.Wrapf(err, "failed to query Container Group %s, not found it", containerGroupName)

logger.Errorf("failed to query Container Group %s, not found", containerGroupName)
return nil, err
}
return nil, err
}
Expand All @@ -98,7 +116,7 @@ func (a *AzClientsAPIs) GetContainerGroup(ctx context.Context, resourceGroup, co

func (a *AzClientsAPIs) CreateContainerGroup(ctx context.Context, resourceGroup, podNS, podName string, cg *ContainerGroupWrapper) error {
logger := log.G(ctx).WithField("method", "CreateContainerGroup")
ctx, span := trace.StartSpan(ctx, "aci.CreateContainerGroup")
ctx, span := trace.StartSpan(ctx, "client.CreateContainerGroup")
defer span.End()

containerGroup := azaciv2.ContainerGroup{
Expand Down Expand Up @@ -128,7 +146,7 @@ func (a *AzClientsAPIs) CreateContainerGroup(ctx context.Context, resourceGroup,
// GetContainerGroupInfo returns a container group from ACI.
func (a *AzClientsAPIs) GetContainerGroupInfo(ctx context.Context, resourceGroup, namespace, name, nodeName string) (*azaciv2.ContainerGroup, error) {
logger := log.G(ctx).WithField("method", "GetContainerGroupInfo")
ctx, span := trace.StartSpan(ctx, "aci.GetContainerGroupInfo")
ctx, span := trace.StartSpan(ctx, "client.GetContainerGroupInfo")
defer span.End()
var rawResponse *http.Response
ctxWithResp := runtime.WithCaptureResponse(ctx, &rawResponse)
Expand All @@ -154,7 +172,7 @@ func (a *AzClientsAPIs) GetContainerGroupInfo(ctx context.Context, resourceGroup

func (a *AzClientsAPIs) GetContainerGroupListResult(ctx context.Context, resourceGroup string) ([]*azaciv2.ContainerGroup, error) {
logger := log.G(ctx).WithField("method", "GetContainerGroupListResult")
ctx, span := trace.StartSpan(ctx, "aci.GetContainerGroupListResult")
ctx, span := trace.StartSpan(ctx, "client.GetContainerGroupListResult")
defer span.End()

var rawResponse *http.Response
Expand All @@ -176,7 +194,7 @@ func (a *AzClientsAPIs) GetContainerGroupListResult(ctx context.Context, resourc

func (a *AzClientsAPIs) ListCapabilities(ctx context.Context, region string) ([]*azaciv2.Capabilities, error) {
logger := log.G(ctx).WithField("method", "ListCapabilities")
ctx, span := trace.StartSpan(ctx, "aci.ListCapabilities")
ctx, span := trace.StartSpan(ctx, "client.ListCapabilities")
defer span.End()

var rawResponse *http.Response
Expand All @@ -202,7 +220,7 @@ func (a *AzClientsAPIs) ListCapabilities(ctx context.Context, region string) ([]

func (a *AzClientsAPIs) DeleteContainerGroup(ctx context.Context, resourceGroup, cgName string) error {
logger := log.G(ctx).WithField("method", "DeleteContainerGroup")
ctx, span := trace.StartSpan(ctx, "aci.DeleteContainerGroup")
ctx, span := trace.StartSpan(ctx, "client.DeleteContainerGroup")
defer span.End()

var rawResponse *http.Response
Expand All @@ -220,7 +238,7 @@ func (a *AzClientsAPIs) DeleteContainerGroup(ctx context.Context, resourceGroup,

func (a *AzClientsAPIs) ListLogs(ctx context.Context, resourceGroup, cgName, containerName string, opts api.ContainerLogOpts) (*string, error) {
logger := log.G(ctx).WithField("method", "ListLogs")
ctx, span := trace.StartSpan(ctx, "aci.ListLogs")
ctx, span := trace.StartSpan(ctx, "client.ListLogs")
defer span.End()

var rawResponse *http.Response
Expand Down Expand Up @@ -266,7 +284,7 @@ func (a *AzClientsAPIs) ListLogs(ctx context.Context, resourceGroup, cgName, con

func (a *AzClientsAPIs) ExecuteContainerCommand(ctx context.Context, resourceGroup, cgName, containerName string, containerReq azaciv2.ContainerExecRequest) (*azaciv2.ContainerExecResponse, error) {
logger := log.G(ctx).WithField("method", "ExecuteContainerCommand")
ctx, span := trace.StartSpan(ctx, "aci.ExecuteContainerCommand")
ctx, span := trace.StartSpan(ctx, "client.ExecuteContainerCommand")
defer span.End()

var rawResponse *http.Response
Expand Down
5 changes: 4 additions & 1 deletion pkg/provider/aci.go
Original file line number Diff line number Diff line change
Expand Up @@ -902,13 +902,16 @@ func (p *ACIProvider) verifyContainer(container *v1.Container) error {

//this method is used for both initConainers and containers
func (p *ACIProvider) getCommand(container v1.Container) []*string {
var command, args []*string
command := make([]*string, len(container.Command))
for c := range container.Command {
command[c] = &container.Command[c]
}

args := make([]*string, len(container.Command))
for a := range container.Args {
args[a] = &container.Args[a]
}

return append(command, args...)
}

Expand Down

0 comments on commit 9042a86

Please sign in to comment.