diff --git a/sdk/azidentity/syncer.go b/sdk/azidentity/syncer.go index ae38555994b0..7de25838e01e 100644 --- a/sdk/azidentity/syncer.go +++ b/sdk/azidentity/syncer.go @@ -23,8 +23,7 @@ type authFn func(context.Context, policy.TokenRequestOptions) (azcore.AccessToke // syncer synchronizes authentication calls so that goroutines can share a credential instance type syncer struct { addlTenants []string - authing bool - cond *sync.Cond + mu *sync.Mutex reqToken, silent authFn name, tenant string } @@ -32,7 +31,7 @@ type syncer struct { func newSyncer(name, tenant string, additionalTenants []string, reqToken, silentAuth authFn) *syncer { return &syncer{ addlTenants: resolveAdditionalTenants(additionalTenants), - cond: &sync.Cond{L: &sync.Mutex{}}, + mu: &sync.Mutex{}, name: name, reqToken: reqToken, silent: silentAuth, @@ -42,40 +41,23 @@ func newSyncer(name, tenant string, additionalTenants []string, reqToken, silent // GetToken ensures that only one goroutine authenticates at a time func (s *syncer) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) { - var at azcore.AccessToken - var err error + s.mu.Lock() + defer s.mu.Unlock() if len(opts.Scopes) == 0 { - return at, errors.New(s.name + ".GetToken() requires at least one scope") + return azcore.AccessToken{}, errors.New(s.name + ".GetToken() requires at least one scope") } // we don't resolve the tenant for managed identities because they can acquire tokens only from their home tenants if s.name != credNameManagedIdentity { tenant, err := s.resolveTenant(opts.TenantID) if err != nil { - return at, err + return azcore.AccessToken{}, err } opts.TenantID = tenant } - auth := false - s.cond.L.Lock() - defer s.cond.L.Unlock() - for { - at, err = s.silent(ctx, opts) - if err == nil { - // got a token - break - } - if !s.authing { - // this goroutine will request a token - s.authing, auth = true, true - break - } - // another goroutine is acquiring a token; wait for it to finish, then try silent auth again - s.cond.Wait() - } - if auth { - s.authing = false + at, err := s.silent(ctx, opts) + if err != nil { + // cache miss; request a new token at, err = s.reqToken(ctx, opts) - s.cond.Broadcast() } if err != nil { // Return credentialUnavailableError directly because that type affects the behavior of credential chains.