diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index f46b7acd..9a9c9ee1 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -21,6 +21,9 @@ import ( "testing" "time" + "github.com/golang-jwt/jwt/v5" + "github.com/kylelemons/godebug/pretty" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported" internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time" @@ -28,8 +31,6 @@ import ( "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/fake" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority" - "github.com/golang-jwt/jwt/v5" - "github.com/kylelemons/godebug/pretty" ) // errorClient is an HTTP client for tests that should fail when confidential.Client sends a request @@ -1405,3 +1406,59 @@ func TestWithAuthenticationScheme(t *testing.T) { t.Fatalf(`unexpected access token "%s"`, result.AccessToken) } } + +func TestAcquireTokenByCredentialFromDSTS(t *testing.T) { + tests := map[string]struct { + cred string + }{ + "secret": {cred: "fake_secret"}, + "signed assertion": {cred: "fake_assertion"}, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + cred, err := NewCredFromSecret(test.cred) + if err != nil { + t.Fatal(err) + } + client, err := fakeClient(accesstokens.TokenResponse{ + AccessToken: token, + ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, + ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, + GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, + TokenType: "Bearer", + }, cred, "https://fake_authority/dstsv2/"+authority.DSTSTenant) + if err != nil { + t.Fatal(err) + } + + // expect first attempt to fail + _, err = client.AcquireTokenSilent(context.Background(), tokenScope) + if err == nil { + t.Errorf("unexpected nil error from AcquireTokenSilent: %s", err) + } + + tk, err := client.AcquireTokenByCredential(context.Background(), tokenScope) + if err != nil { + t.Errorf("got err == %s, want err == nil", err) + } + if tk.AccessToken != token { + t.Errorf("unexpected access token %s", tk.AccessToken) + } + + tk, err = client.AcquireTokenSilent(context.Background(), tokenScope) + if err != nil { + t.Errorf("got err == %s, want err == nil", err) + } + if tk.AccessToken != token { + t.Errorf("unexpected access token %s", tk.AccessToken) + } + + // fail for another tenant + tk, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithTenantID("other")) + if err == nil { + t.Errorf("unexpected nil error from AcquireTokenSilent: %s", err) + } + }) + } +} diff --git a/apps/internal/oauth/oauth.go b/apps/internal/oauth/oauth.go index 5dd9fe08..e0653134 100644 --- a/apps/internal/oauth/oauth.go +++ b/apps/internal/oauth/oauth.go @@ -10,6 +10,8 @@ import ( "io" "time" + "github.com/google/uuid" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported" internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time" @@ -18,7 +20,6 @@ import ( "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust/defs" - "github.com/google/uuid" ) // ResolveEndpointer contains the methods for resolving authority endpoints. diff --git a/apps/internal/oauth/ops/authority/authority.go b/apps/internal/oauth/ops/authority/authority.go index a9a70186..a49e0357 100644 --- a/apps/internal/oauth/ops/authority/authority.go +++ b/apps/internal/oauth/ops/authority/authority.go @@ -136,8 +136,12 @@ const ( const ( AAD = "MSSTS" ADFS = "ADFS" + DSTS = "DSTS" ) +// DSTSTenant is referenced throughout multiple files, let us use a const in case we ever need to change it. +const DSTSTenant = "7a433bfc-2514-4697-b467-e0933190487f" + // AuthenticationScheme is an extensibility mechanism designed to be used only by Azure Arc for proof of possession access tokens. type AuthenticationScheme interface { // Extra parameters that are added to the request to the /token endpoint. @@ -251,6 +255,8 @@ func (p AuthParams) WithTenant(ID string) (AuthParams, error) { authority = "https://" + path.Join(p.AuthorityInfo.Host, ID) case ADFS: return p, errors.New("ADFS authority doesn't support tenants") + case DSTS: + return p, errors.New("dSTS authority doesn't support tenants") } info, err := NewInfoFromAuthorityURI(authority, p.AuthorityInfo.ValidateAuthority, p.AuthorityInfo.InstanceDiscoveryDisabled) @@ -350,35 +356,43 @@ type Info struct { InstanceDiscoveryDisabled bool } -func firstPathSegment(u *url.URL) (string, error) { - pathParts := strings.Split(u.EscapedPath(), "/") - if len(pathParts) >= 2 { - return pathParts[1], nil - } - - return "", errors.New(`authority must be an https URL such as "https://login.microsoftonline.com/"`) -} - // NewInfoFromAuthorityURI creates an AuthorityInfo instance from the authority URL provided. func NewInfoFromAuthorityURI(authority string, validateAuthority bool, instanceDiscoveryDisabled bool) (Info, error) { u, err := url.Parse(strings.ToLower(authority)) - if err != nil || u.Scheme != "https" { - return Info{}, errors.New(`authority must be an https URL such as "https://login.microsoftonline.com/"`) + if err != nil { + return Info{}, fmt.Errorf("couldn't parse authority url: %w", err) + } + if u.Scheme != "https" { + return Info{}, errors.New("authority url scheme must be https") } - tenant, err := firstPathSegment(u) - if err != nil { - return Info{}, err + pathParts := strings.Split(u.EscapedPath(), "/") + if len(pathParts) < 2 { + return Info{}, errors.New(`authority must be an URL such as "https://login.microsoftonline.com/"`) } - authorityType := AAD - if tenant == "adfs" { + + var authorityType, tenant string + switch pathParts[1] { + case "adfs": authorityType = ADFS + case "dstsv2": + if len(pathParts) != 3 { + return Info{}, fmt.Errorf("dSTS authority must be an https URL such as https:///dstsv2/%s", DSTSTenant) + } + if pathParts[2] != DSTSTenant { + return Info{}, fmt.Errorf("dSTS authority only accepts a single tenant %q", DSTSTenant) + } + authorityType = DSTS + tenant = DSTSTenant + default: + authorityType = AAD + tenant = pathParts[1] } // u.Host includes the port, if any, which is required for private cloud deployments return Info{ Host: u.Host, - CanonicalAuthorityURI: fmt.Sprintf("https://%v/%v/", u.Host, tenant), + CanonicalAuthorityURI: authority, AuthorityType: authorityType, ValidateAuthority: validateAuthority, Tenant: tenant, diff --git a/apps/internal/oauth/ops/authority/authority_test.go b/apps/internal/oauth/ops/authority/authority_test.go index 6071ed4f..6795a8f1 100644 --- a/apps/internal/oauth/ops/authority/authority_test.go +++ b/apps/internal/oauth/ops/authority/authority_test.go @@ -341,6 +341,7 @@ func TestAuthParamsWithTenant(t *testing.T) { "tenant can't be consumers for AAD": {authority: host + uuid1, tenant: "consumers", expectError: true}, "tenant can't be organizations for AAD": {authority: host + uuid1, tenant: "organizations", expectError: true}, "can't override tenant for ADFS ever": {authority: host + "adfs", tenant: uuid1, expectError: true}, + "can't override tenant for dSTS ever": {authority: host + "dstsv2/" + DSTSTenant, tenant: uuid1, expectError: true}, "can't override AAD tenant consumers": {authority: host + "consumers", tenant: uuid1, expectError: true}, } diff --git a/apps/internal/oauth/resolvers.go b/apps/internal/oauth/resolvers.go index 5a1aa9a7..4030ec8d 100644 --- a/apps/internal/oauth/resolvers.go +++ b/apps/internal/oauth/resolvers.go @@ -48,7 +48,7 @@ func (m *authorityEndpoint) ResolveEndpoints(ctx context.Context, authorityInfo return endpoints, nil } - endpoint, err := m.openIDConfigurationEndpoint(ctx, authorityInfo, userPrincipalName) + endpoint, err := m.openIDConfigurationEndpoint(ctx, authorityInfo) if err != nil { return authority.Endpoints{}, err } @@ -116,9 +116,12 @@ func (m *authorityEndpoint) addCachedEndpoints(authorityInfo authority.Info, use m.cache[authorityInfo.CanonicalAuthorityURI] = updatedCacheEntry } -func (m *authorityEndpoint) openIDConfigurationEndpoint(ctx context.Context, authorityInfo authority.Info, userPrincipalName string) (string, error) { - if authorityInfo.Tenant == "adfs" { +func (m *authorityEndpoint) openIDConfigurationEndpoint(ctx context.Context, authorityInfo authority.Info) (string, error) { + if authorityInfo.AuthorityType == authority.ADFS { return fmt.Sprintf("https://%s/adfs/.well-known/openid-configuration", authorityInfo.Host), nil + } else if authorityInfo.AuthorityType == authority.DSTS { + return fmt.Sprintf("https://%s/dstsv2/%s/v2.0/.well-known/openid-configuration", authorityInfo.Host, authority.DSTSTenant), nil + } else if authorityInfo.ValidateAuthority && !authority.TrustedHost(authorityInfo.Host) { resp, err := m.rest.Authority().AADInstanceDiscovery(ctx, authorityInfo) if err != nil { @@ -131,7 +134,6 @@ func (m *authorityEndpoint) openIDConfigurationEndpoint(ctx context.Context, aut return "", err } return resp.TenantDiscoveryEndpoint, nil - } return authorityInfo.CanonicalAuthorityURI + "v2.0/.well-known/openid-configuration", nil