Skip to content

Commit

Permalink
feat(oauth): add support for dSTS authority type
Browse files Browse the repository at this point in the history
  • Loading branch information
handsomejack-42 authored and bgavrilMS committed Nov 5, 2024
1 parent 133b78f commit 06d3fb2
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 24 deletions.
61 changes: 59 additions & 2 deletions apps/confidential/confidential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,16 @@ import (
"testing"
"time"

"github.com/golang-jwt/jwt/v5"
"github.com/kylelemons/godebug/pretty"

"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported"
internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/mock"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/fake"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
"github.com/golang-jwt/jwt/v5"
"github.com/kylelemons/godebug/pretty"
)

// errorClient is an HTTP client for tests that should fail when confidential.Client sends a request
Expand Down Expand Up @@ -1405,3 +1406,59 @@ func TestWithAuthenticationScheme(t *testing.T) {
t.Fatalf(`unexpected access token "%s"`, result.AccessToken)
}
}

func TestAcquireTokenByCredentialFromDSTS(t *testing.T) {
tests := map[string]struct {
cred string
}{
"secret": {cred: "fake_secret"},
"signed assertion": {cred: "fake_assertion"},
}

for name, test := range tests {
t.Run(name, func(t *testing.T) {
cred, err := NewCredFromSecret(test.cred)
if err != nil {
t.Fatal(err)
}
client, err := fakeClient(accesstokens.TokenResponse{
AccessToken: token,
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
TokenType: "Bearer",
}, cred, "https://fake_authority/dstsv2/"+authority.DSTSTenant)
if err != nil {
t.Fatal(err)
}

// expect first attempt to fail
_, err = client.AcquireTokenSilent(context.Background(), tokenScope)
if err == nil {
t.Errorf("unexpected nil error from AcquireTokenSilent: %s", err)
}

tk, err := client.AcquireTokenByCredential(context.Background(), tokenScope)
if err != nil {
t.Errorf("got err == %s, want err == nil", err)
}
if tk.AccessToken != token {
t.Errorf("unexpected access token %s", tk.AccessToken)
}

tk, err = client.AcquireTokenSilent(context.Background(), tokenScope)
if err != nil {
t.Errorf("got err == %s, want err == nil", err)
}
if tk.AccessToken != token {
t.Errorf("unexpected access token %s", tk.AccessToken)
}

// fail for another tenant
tk, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithTenantID("other"))
if err == nil {
t.Errorf("unexpected nil error from AcquireTokenSilent: %s", err)
}
})
}
}
3 changes: 2 additions & 1 deletion apps/internal/oauth/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"io"
"time"

"github.com/google/uuid"

"github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported"
internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
Expand All @@ -18,7 +20,6 @@ import (
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust/defs"
"github.com/google/uuid"
)

// ResolveEndpointer contains the methods for resolving authority endpoints.
Expand Down
48 changes: 31 additions & 17 deletions apps/internal/oauth/ops/authority/authority.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,12 @@ const (
const (
AAD = "MSSTS"
ADFS = "ADFS"
DSTS = "DSTS"
)

// DSTSTenant is referenced throughout multiple files, let us use a const in case we ever need to change it.
const DSTSTenant = "7a433bfc-2514-4697-b467-e0933190487f"

// AuthenticationScheme is an extensibility mechanism designed to be used only by Azure Arc for proof of possession access tokens.
type AuthenticationScheme interface {
// Extra parameters that are added to the request to the /token endpoint.
Expand Down Expand Up @@ -251,6 +255,8 @@ func (p AuthParams) WithTenant(ID string) (AuthParams, error) {
authority = "https://" + path.Join(p.AuthorityInfo.Host, ID)
case ADFS:
return p, errors.New("ADFS authority doesn't support tenants")
case DSTS:
return p, errors.New("dSTS authority doesn't support tenants")
}

info, err := NewInfoFromAuthorityURI(authority, p.AuthorityInfo.ValidateAuthority, p.AuthorityInfo.InstanceDiscoveryDisabled)
Expand Down Expand Up @@ -350,35 +356,43 @@ type Info struct {
InstanceDiscoveryDisabled bool
}

func firstPathSegment(u *url.URL) (string, error) {
pathParts := strings.Split(u.EscapedPath(), "/")
if len(pathParts) >= 2 {
return pathParts[1], nil
}

return "", errors.New(`authority must be an https URL such as "https://login.microsoftonline.com/<your tenant>"`)
}

// NewInfoFromAuthorityURI creates an AuthorityInfo instance from the authority URL provided.
func NewInfoFromAuthorityURI(authority string, validateAuthority bool, instanceDiscoveryDisabled bool) (Info, error) {
u, err := url.Parse(strings.ToLower(authority))
if err != nil || u.Scheme != "https" {
return Info{}, errors.New(`authority must be an https URL such as "https://login.microsoftonline.com/<your tenant>"`)
if err != nil {
return Info{}, fmt.Errorf("couldn't parse authority url: %w", err)
}
if u.Scheme != "https" {
return Info{}, errors.New("authority url scheme must be https")
}

tenant, err := firstPathSegment(u)
if err != nil {
return Info{}, err
pathParts := strings.Split(u.EscapedPath(), "/")
if len(pathParts) < 2 {
return Info{}, errors.New(`authority must be an URL such as "https://login.microsoftonline.com/<your tenant>"`)
}
authorityType := AAD
if tenant == "adfs" {

var authorityType, tenant string
switch pathParts[1] {
case "adfs":
authorityType = ADFS
case "dstsv2":
if len(pathParts) != 3 {
return Info{}, fmt.Errorf("dSTS authority must be an https URL such as https://<authority>/dstsv2/%s", DSTSTenant)
}
if pathParts[2] != DSTSTenant {
return Info{}, fmt.Errorf("dSTS authority only accepts a single tenant %q", DSTSTenant)
}
authorityType = DSTS
tenant = DSTSTenant
default:
authorityType = AAD
tenant = pathParts[1]
}

// u.Host includes the port, if any, which is required for private cloud deployments
return Info{
Host: u.Host,
CanonicalAuthorityURI: fmt.Sprintf("https://%v/%v/", u.Host, tenant),
CanonicalAuthorityURI: authority,

This comment has been minimized.

Copy link
@ekindingwcar

ekindingwcar Nov 6, 2024

The difference here, causing the issue mentioned below, seems to be that the CanonicalAuthorityURI previously ended with a / but no longer does (not guaranteed, at least).

This comment has been minimized.

Copy link
@bgavrilMS

bgavrilMS Nov 6, 2024

Member

Fixing it now

AuthorityType: authorityType,
ValidateAuthority: validateAuthority,
Tenant: tenant,
Expand Down
1 change: 1 addition & 0 deletions apps/internal/oauth/ops/authority/authority_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ func TestAuthParamsWithTenant(t *testing.T) {
"tenant can't be consumers for AAD": {authority: host + uuid1, tenant: "consumers", expectError: true},
"tenant can't be organizations for AAD": {authority: host + uuid1, tenant: "organizations", expectError: true},
"can't override tenant for ADFS ever": {authority: host + "adfs", tenant: uuid1, expectError: true},
"can't override tenant for dSTS ever": {authority: host + "dstsv2/" + DSTSTenant, tenant: uuid1, expectError: true},
"can't override AAD tenant consumers": {authority: host + "consumers", tenant: uuid1, expectError: true},
}

Expand Down
10 changes: 6 additions & 4 deletions apps/internal/oauth/resolvers.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func (m *authorityEndpoint) ResolveEndpoints(ctx context.Context, authorityInfo
return endpoints, nil
}

endpoint, err := m.openIDConfigurationEndpoint(ctx, authorityInfo, userPrincipalName)
endpoint, err := m.openIDConfigurationEndpoint(ctx, authorityInfo)
if err != nil {
return authority.Endpoints{}, err
}
Expand Down Expand Up @@ -116,9 +116,12 @@ func (m *authorityEndpoint) addCachedEndpoints(authorityInfo authority.Info, use
m.cache[authorityInfo.CanonicalAuthorityURI] = updatedCacheEntry
}

func (m *authorityEndpoint) openIDConfigurationEndpoint(ctx context.Context, authorityInfo authority.Info, userPrincipalName string) (string, error) {
if authorityInfo.Tenant == "adfs" {
func (m *authorityEndpoint) openIDConfigurationEndpoint(ctx context.Context, authorityInfo authority.Info) (string, error) {
if authorityInfo.AuthorityType == authority.ADFS {
return fmt.Sprintf("https://%s/adfs/.well-known/openid-configuration", authorityInfo.Host), nil
} else if authorityInfo.AuthorityType == authority.DSTS {
return fmt.Sprintf("https://%s/dstsv2/%s/v2.0/.well-known/openid-configuration", authorityInfo.Host, authority.DSTSTenant), nil

} else if authorityInfo.ValidateAuthority && !authority.TrustedHost(authorityInfo.Host) {
resp, err := m.rest.Authority().AADInstanceDiscovery(ctx, authorityInfo)
if err != nil {
Expand All @@ -131,7 +134,6 @@ func (m *authorityEndpoint) openIDConfigurationEndpoint(ctx context.Context, aut
return "", err
}
return resp.TenantDiscoveryEndpoint, nil

}

return authorityInfo.CanonicalAuthorityURI + "v2.0/.well-known/openid-configuration", nil

This comment has been minimized.

Copy link
@ekindingwcar

ekindingwcar Nov 6, 2024

Missing a / here, we noticed while mistakenly bumping to 1.3.0.

Expand Down

0 comments on commit 06d3fb2

Please sign in to comment.