Skip to content

Commit

Permalink
Fix MSI token refresh (#4611)
Browse files Browse the repository at this point in the history
Signed-off-by: Philip Laine <philip.laine@xenit.se>
  • Loading branch information
Philip Laine authored Sep 1, 2021
1 parent 082c1e8 commit 5bda2d4
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 18 deletions.
4 changes: 2 additions & 2 deletions pkg/objstore/azure/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,15 +174,15 @@ func NewBucket(logger log.Logger, azureConfig []byte, component string) (*Bucket
}

ctx := context.Background()
container, err := createContainer(ctx, conf)
container, err := createContainer(ctx, logger, conf)
if err != nil {
ret, ok := err.(blob.StorageError)
if !ok {
return nil, errors.Wrapf(err, "Azure API return unexpected error: %T\n", err)
}
if ret.ServiceCode() == "ContainerAlreadyExists" {
level.Debug(logger).Log("msg", "Getting connection to existing Azure blob container", "container", conf.ContainerName)
container, err = getContainer(ctx, conf)
container, err = getContainer(ctx, logger, conf)
if err != nil {
return nil, errors.Wrapf(err, "cannot get existing Azure blob container: %s", container)
}
Expand Down
36 changes: 21 additions & 15 deletions pkg/objstore/azure/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import (
"github.com/Azure/azure-pipeline-go/pipeline"
blob "github.com/Azure/azure-storage-blob-go/azblob"
"github.com/Azure/go-autorest/autorest/azure/auth"
"github.com/go-kit/kit/log"
"github.com/go-kit/kit/log/level"
)

// DirDelim is the delimiter used to model a directory structure in an object store bucket.
Expand All @@ -34,24 +36,29 @@ func init() {
pipeline.SetForceLogEnabled(false)
}

func getAzureStorageCredentials(conf Config) (blob.Credential, error) {
func getAzureStorageCredentials(logger log.Logger, conf Config) (blob.Credential, error) {
if conf.MSIResource != "" {
msiConfig := auth.NewMSIConfig()
msiConfig.Resource = conf.MSIResource

azureServicePrincipalToken, err := msiConfig.ServicePrincipalToken()
spt, err := msiConfig.ServicePrincipalToken()
if err != nil {
return nil, err
}

// Get a new token.
err = azureServicePrincipalToken.Refresh()
if err != nil {
if err := spt.Refresh(); err != nil {
return nil, err
}
token := azureServicePrincipalToken.Token()

return blob.NewTokenCredential(token.AccessToken, nil), nil
return blob.NewTokenCredential(spt.Token().AccessToken, func(tc blob.TokenCredential) time.Duration {
err := spt.Refresh()
if err != nil {
level.Error(logger).Log("msg", "could not refresh MSI token", "err", err)
// Retry later as the error can be related to API throttling
return 30 * time.Second
}
tc.SetToken(spt.Token().AccessToken)
return spt.Token().Expires().Sub(time.Now().Add(2 * time.Minute))
}), nil
}

credential, err := blob.NewSharedKeyCredential(conf.StorageAccountName, conf.StorageAccountKey)
Expand All @@ -61,9 +68,8 @@ func getAzureStorageCredentials(conf Config) (blob.Credential, error) {
return credential, nil
}

func getContainerURL(ctx context.Context, conf Config) (blob.ContainerURL, error) {

credentials, err := getAzureStorageCredentials(conf)
func getContainerURL(ctx context.Context, logger log.Logger, conf Config) (blob.ContainerURL, error) {
credentials, err := getAzureStorageCredentials(logger, conf)

if err != nil {
return blob.ContainerURL{}, err
Expand Down Expand Up @@ -134,8 +140,8 @@ func DefaultTransport(config Config) *http.Transport {
}
}

func getContainer(ctx context.Context, conf Config) (blob.ContainerURL, error) {
c, err := getContainerURL(ctx, conf)
func getContainer(ctx context.Context, logger log.Logger, conf Config) (blob.ContainerURL, error) {
c, err := getContainerURL(ctx, logger, conf)
if err != nil {
return blob.ContainerURL{}, err
}
Expand All @@ -144,8 +150,8 @@ func getContainer(ctx context.Context, conf Config) (blob.ContainerURL, error) {
return c, err
}

func createContainer(ctx context.Context, conf Config) (blob.ContainerURL, error) {
c, err := getContainerURL(ctx, conf)
func createContainer(ctx context.Context, logger log.Logger, conf Config) (blob.ContainerURL, error) {
c, err := getContainerURL(ctx, logger, conf)
if err != nil {
return blob.ContainerURL{}, err
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/objstore/azure/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"
"testing"

"github.com/go-kit/kit/log"
"github.com/thanos-io/thanos/pkg/testutil"
)

Expand Down Expand Up @@ -50,7 +51,7 @@ func Test_getContainerURL(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
got, err := getContainerURL(ctx, tt.args.conf)
got, err := getContainerURL(ctx, log.NewNopLogger(), tt.args.conf)
if (err != nil) != tt.wantErr {
t.Errorf("getContainerURL() error = %v, wantErr %v", err, tt.wantErr)
return
Expand Down

0 comments on commit 5bda2d4

Please sign in to comment.