Skip to content

Commit

Permalink
Merge pull request #1 from kahun/feature/azure_acr
Browse files Browse the repository at this point in the history
Add acr token support
  • Loading branch information
kahun authored May 23, 2023
2 parents b0d8a50 + c497b20 commit bebdfbc
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 34 deletions.
32 changes: 32 additions & 0 deletions pkg/cluster/internal/create/actions/createworker/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,15 @@ package createworker

import (
"bytes"
"context"
"encoding/base64"
b64 "encoding/base64"
"os"
"strings"

"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/ecr"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ec2"
Expand Down Expand Up @@ -158,3 +163,30 @@ func (b *AWSBuilder) getAzs() ([]string, error) {
}
return azs, nil
}

func getEcrToken(p commons.ProviderParams) (string, error) {
customProvider := credentials.NewStaticCredentialsProvider(
p.Credentials["AccessKey"], p.Credentials["SecretKey"], "",
)
cfg, err := config.LoadDefaultConfig(
context.TODO(),
config.WithCredentialsProvider(customProvider),
config.WithRegion(p.Region),
)
if err != nil {
return "", err
}

svc := ecr.NewFromConfig(cfg)
token, err := svc.GetAuthorizationToken(context.TODO(), &ecr.GetAuthorizationTokenInput{})
if err != nil {
return "", err
}
authData := token.AuthorizationData[0].AuthorizationToken
data, err := base64.StdEncoding.DecodeString(*authData)
if err != nil {
return "", err
}
parts := strings.SplitN(string(data), ":", 2)
return parts[1], nil
}
33 changes: 33 additions & 0 deletions pkg/cluster/internal/create/actions/createworker/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@ package createworker

import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v3"
Expand Down Expand Up @@ -129,3 +134,31 @@ func assignUserIdentity(i string, c string, r string, s map[string]string) error

return nil
}

func getAcrToken(p commons.ProviderParams, acrService string) (string, error) {
creds, err := azidentity.NewClientSecretCredential(
p.Credentials["TenantID"], p.Credentials["ClientID"], p.Credentials["ClientSecret"], nil,
)
if err != nil {
return "", err
}
ctx := context.Background()

aadToken, err := creds.GetToken(ctx, policy.TokenRequestOptions{Scopes: []string{"https://management.azure.com/.default"}})
if err != nil {
return "", err
}
formData := url.Values{
"grant_type": {"access_token"},
"service": {acrService},
"tenant": {p.Credentials["TenantID"]},
"access_token": {aadToken.Token},
}
jsonResponse, err := http.PostForm(fmt.Sprintf("https://%s/oauth2/exchange", acrService), formData)
if err != nil {
return "", err
}
var response map[string]interface{}
json.NewDecoder(jsonResponse.Body).Decode(&response)
return response["refresh_token"].(string), nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,20 @@ func (a *action) Execute(ctx *actions.ActionContext) error {
}

if registryType == "ecr" {
ecrToken, err := commons.GetEcrAuthToken(providerParams)
ecrToken, err := getEcrToken(providerParams)
if err != nil {
return errors.Wrap(err, "failed to get ECR auth token")
}
registryUser = "AWS"
registryPass = ecrToken
} else if registryType == "acr" {
acrService := strings.Split(registryUrl, "/")[0]
acrToken, err := getAcrToken(providerParams, acrService)
if err != nil {
return errors.Wrap(err, "failed to get ACR auth token")
}
registryUser = "00000000-0000-0000-0000-000000000000"
registryPass = acrToken
} else {
registryUser = keosRegistry["User"]
registryPass = keosRegistry["Pass"]
Expand Down
33 changes: 0 additions & 33 deletions pkg/commons/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,12 @@ package commons

import (
"bytes"
"context"
"encoding/base64"
"log"
"unicode"

"os"
"reflect"
"strings"

"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/ecr"
"github.com/fatih/structs"
"github.com/oleiade/reflections"
"gopkg.in/yaml.v3"
Expand Down Expand Up @@ -396,30 +390,3 @@ func convertMapKeysToSnakeCase(m map[string]interface{}) map[string]interface{}
}
return newMap
}

func GetEcrAuthToken(p ProviderParams) (string, error) {
customProvider := credentials.NewStaticCredentialsProvider(
p.Credentials["AccessKey"], p.Credentials["SecretKey"], "",
)
cfg, err := config.LoadDefaultConfig(
context.TODO(),
config.WithCredentialsProvider(customProvider),
config.WithRegion(p.Region),
)
if err != nil {
panic("unable to load SDK config, " + err.Error())
}

svc := ecr.NewFromConfig(cfg)
token, err := svc.GetAuthorizationToken(context.TODO(), &ecr.GetAuthorizationTokenInput{})
if err != nil {
log.Fatal(err)
}
authData := token.AuthorizationData[0].AuthorizationToken
data, err := base64.StdEncoding.DecodeString(*authData)
if err != nil {
log.Fatal(err)
}
parts := strings.SplitN(string(data), ":", 2)
return parts[1], nil
}

0 comments on commit bebdfbc

Please sign in to comment.