diff --git a/internal/api/provider/azure.go b/internal/api/provider/azure.go index 915509d3d..4a341f4d6 100644 --- a/internal/api/provider/azure.go +++ b/internal/api/provider/azure.go @@ -5,15 +5,10 @@ import ( "encoding/base64" "encoding/json" "fmt" - "io" - "net/http" - "net/url" "regexp" "strings" - "unicode/utf8" "github.com/coreos/go-oidc/v3/oidc" - "github.com/golang-jwt/jwt/v5" "github.com/supabase/auth/internal/conf" "golang.org/x/oauth2" ) @@ -167,208 +162,3 @@ func (g azureProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*Use return nil, fmt.Errorf("azure: no OIDC ID token present in response") } - -type AzureIDTokenClaimSource struct { - Endpoint string `json:"endpoint"` -} - -type AzureIDTokenClaims struct { - jwt.RegisteredClaims - - Email string `json:"email"` - Name string `json:"name"` - PreferredUsername string `json:"preferred_username"` - XMicrosoftEmailDomainOwnerVerified any `json:"xms_edov"` - - ClaimNames map[string]string `json:"_claim_names"` - ClaimSources map[string]AzureIDTokenClaimSource `json:"_claim_sources"` -} - -// ResolveIndirectClaims resolves claims in the Azure Token that require a call to the Microsoft Graph API. This is typically to an API like this: https://learn.microsoft.com/en-us/graph/api/directoryobject-getmemberobjects?view=graph-rest-1.0&tabs=http -func (c *AzureIDTokenClaims) ResolveIndirectClaims(ctx context.Context, httpClient *http.Client, accessToken string) (map[string]any, error) { - if len(c.ClaimNames) == 0 || len(c.ClaimSources) == 0 { - return nil, nil - } - - result := make(map[string]any) - - for claimName, claimSource := range c.ClaimNames { - claimEndpointObject, ok := c.ClaimSources[claimSource] - - if !ok || !strings.HasPrefix(claimEndpointObject.Endpoint, "https://") { - continue - } - - u, err := url.ParseRequestURI(claimEndpointObject.Endpoint) - if err != nil { - return nil, fmt.Errorf("azure: failed to parse endpoint URL %q (resolving overage claim %q): %w", claimEndpointObject.Endpoint, claimName, err) - } - - queryParams := u.Query() - if !queryParams.Has("api-version") { - // https://stackoverflow.com/questions/51085863/retrieve-group-claims-using-claim-sources-returns-the-specified-api-version-is - queryParams.Add("api-version", "1.6") - u.RawQuery = queryParams.Encode() - } - - claimEndpoint := u.String() - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, claimEndpoint, strings.NewReader(`{"securityEnabledOnly":true}`)) - if err != nil { - return nil, fmt.Errorf("azure: failed to create POST request to %q (resolving overage claim %q): %w", claimEndpoint, claimName, err) - } - - req.Header.Add("Authorization", "Bearer "+accessToken) - req.Header.Add("Content-Type", "application/json") - - resp, err := httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("azure: failed to send POST request to %q (resolving overage claim %q): %w", claimEndpoint, claimName, err) - } - - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - resBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2*1024)) - - body := "" - if len(resBody) > 0 { - if utf8.Valid(resBody) { - body = string(resBody) - } else { - body = "" - } - } - - readErrString := "" - if readErr != nil { - readErrString = fmt.Sprintf(" with read error %q", readErr.Error()) - } - - return nil, fmt.Errorf("azure: received %d but expected 200 HTTP status code when sending POST to %q (resolving overage claim %q) with response body %q%s", resp.StatusCode, claimEndpoint, claimName, body, readErrString) - } - - var responseResult struct { - Value any `json:"value"` - } - - if err := json.NewDecoder(resp.Body).Decode(&responseResult); err != nil { - return nil, fmt.Errorf("azure: failed to parse JSON response from POST to %q (resolving overage claim %q): %w", claimEndpoint, claimName, err) - } - - result[claimName] = responseResult.Value - } - - return result, nil -} - -func (c *AzureIDTokenClaims) IsEmailVerified() bool { - emailVerified := false - - edov := c.XMicrosoftEmailDomainOwnerVerified - - // If xms_edov is not set, and an email is present or xms_edov is true, - // only then is the email regarded as verified. - // https://learn.microsoft.com/en-us/azure/active-directory/develop/migrate-off-email-claim-authorization#using-the-xms_edov-optional-claim-to-determine-email-verification-status-and-migrate-users - if edov == nil { - // An email is provided, but xms_edov is not -- probably not - // configured, so we must assume the email is verified as Azure - // will only send out a potentially unverified email address in - // single-tenanat apps. - emailVerified = c.Email != "" - } else { - edovBool := false - - // Azure can't be trusted with how they encode the xms_edov - // claim. Sometimes it's "xms_edov": "1", sometimes "xms_edov": true. - switch v := edov.(type) { - case bool: - edovBool = v - - case string: - edovBool = v == "1" || v == "true" - - default: - edovBool = false - } - - emailVerified = c.Email != "" && edovBool - } - - return emailVerified -} - -// removeAzureClaimsFromCustomClaims contains the list of claims to be removed -// from the CustomClaims map. See: -// https://learn.microsoft.com/en-us/azure/active-directory/develop/id-token-claims-reference -var removeAzureClaimsFromCustomClaims = []string{ - "aud", - "iss", - "iat", - "nbf", - "exp", - "c_hash", - "at_hash", - "aio", - "nonce", - "rh", - "uti", - "jti", - "ver", - "sub", - "name", - "preferred_username", -} - -func parseAzureIDToken(ctx context.Context, token *oidc.IDToken, accessToken string) (*oidc.IDToken, *UserProvidedData, error) { - var data UserProvidedData - - var azureClaims AzureIDTokenClaims - if err := token.Claims(&azureClaims); err != nil { - return nil, nil, err - } - - data.Metadata = &Claims{ - Issuer: token.Issuer, - Subject: token.Subject, - ProviderId: token.Subject, - PreferredUsername: azureClaims.PreferredUsername, - FullName: azureClaims.Name, - CustomClaims: make(map[string]any), - } - - if azureClaims.Email != "" { - data.Emails = []Email{{ - Email: azureClaims.Email, - Verified: azureClaims.IsEmailVerified(), - Primary: true, - }} - } - - if err := token.Claims(&data.Metadata.CustomClaims); err != nil { - return nil, nil, err - } - - resolvedClaims, err := azureClaims.ResolveIndirectClaims(ctx, http.DefaultClient, accessToken) - if err != nil { - return nil, nil, err - } - - if data.Metadata.CustomClaims == nil { - if resolvedClaims != nil { - data.Metadata.CustomClaims = make(map[string]any, len(resolvedClaims)) - } - } - - if data.Metadata.CustomClaims != nil { - for _, claim := range removeAzureClaimsFromCustomClaims { - delete(data.Metadata.CustomClaims, claim) - } - } - - for k, v := range resolvedClaims { - data.Metadata.CustomClaims[k] = v - } - - return token, &data, nil -} diff --git a/internal/api/provider/azure_test.go b/internal/api/provider/azure_test.go index d15f14d48..316cb08ba 100644 --- a/internal/api/provider/azure_test.go +++ b/internal/api/provider/azure_test.go @@ -1,15 +1,6 @@ package provider -import ( - "context" - "net/http" - "net/http/httptest" - "net/url" - "strings" - "testing" - - "github.com/stretchr/testify/require" -) +import "testing" func TestIsAzureIssuer(t *testing.T) { positiveExamples := []string{ @@ -36,138 +27,3 @@ func TestIsAzureIssuer(t *testing.T) { } } } - -func TestAzureResolveIndirectClaims(t *testing.T) { - server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Add("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - - w.Write([]byte(`{ - "@odata.context": "https://graph.microsoft.com/v1.0/$metadata#Collection(Edm.String)", - "value": [ - "fee2c45b-915a-4a64-b130-f4eb9e75525e", - "4fe90ae7-065a-478b-9400-e0a0e1cbd540", - "c9ee2d50-9e8a-4352-b97c-4c2c99557c22", - "e0c3beaf-eeb4-43d8-abc5-94f037a65697" - ] -}`)) - })) - - defer server.Close() - - var claims AzureIDTokenClaims - - resolvedClaims, err := claims.ResolveIndirectClaims(context.Background(), server.Client(), "access-token") - require.Nil(t, resolvedClaims) - require.Nil(t, err) - - claims.ClaimNames = make(map[string]string) - - resolvedClaims, err = claims.ResolveIndirectClaims(context.Background(), server.Client(), "access-token") - require.Nil(t, resolvedClaims) - require.Nil(t, err) - - claims.ClaimNames = map[string]string{ - "groups": "src1", - "missing-source": "src2", - "not-https": "src3", - } - claims.ClaimSources = map[string]AzureIDTokenClaimSource{ - "src1": { - Endpoint: server.URL, - }, - "src3": { - Endpoint: "http://example.com", - }, - } - - resolvedClaims, err = claims.ResolveIndirectClaims(context.Background(), server.Client(), "access-token") - require.NoError(t, err) - require.NotNil(t, resolvedClaims) - require.Equal(t, 1, len(resolvedClaims)) - require.Equal(t, 4, len(resolvedClaims["groups"].([]interface{}))) -} - -func TestAzureResolveIndirectClaimsFailures(t *testing.T) { - examples := []struct { - name string - urlSuffix string - statusCode int - body []byte - expectedError string - }{ - { - name: "invalid url", - urlSuffix: "\000", - expectedError: "azure: failed to parse endpoint URL \"SERVER-URL\\x00\" (resolving overage claim \"groups\"): parse \"SERVER-URL\\x00\": net/url: invalid control character in URL", - }, - { - name: "no such server", - urlSuffix: "000", - expectedError: "azure: failed to send POST request to \"SERVER-URL000\" (resolving overage claim \"groups\"): Post \"SERVER-URL000\": dial tcp: address PORT000: invalid port", - }, - { - name: "non 200 status code", - statusCode: 500, - body: []byte(`something is wrong`), - expectedError: "azure: received 500 but expected 200 HTTP status code when sending POST to \"SERVER-URL\" (resolving overage claim \"groups\") with response body \"something is wrong\"", - }, - { - name: "non 200 status code, non utf8 valid body", - statusCode: 201, - body: []byte{255, 255, 255, 255}, - expectedError: "azure: received 201 but expected 200 HTTP status code when sending POST to \"SERVER-URL\" (resolving overage claim \"groups\") with response body \"\"", - }, - { - name: "non 200 status code, empty body", - statusCode: 201, - body: []byte{}, - expectedError: "azure: received 201 but expected 200 HTTP status code when sending POST to \"SERVER-URL\" (resolving overage claim \"groups\") with response body \"\"", - }, - { - name: "non 200 status code, body over 2KB", - statusCode: 201, - body: []byte(strings.Repeat("x", 2*1024+1)), - expectedError: "azure: received 201 but expected 200 HTTP status code when sending POST to \"SERVER-URL\" (resolving overage claim \"groups\") with response body \"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\"", - }, - { - name: "ok response, not json", - statusCode: 200, - body: []byte("not json"), - expectedError: "azure: failed to parse JSON response from POST to \"SERVER-URL\" (resolving overage claim \"groups\"): invalid character 'o' in literal null (expecting 'u')", - }, - } - - for _, example := range examples { - t.Run(example.name, func(t *testing.T) { - server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, "1.6", r.URL.Query().Get("api-version")) - - w.WriteHeader(example.statusCode) - - w.Write(example.body) - })) - - defer server.Close() - - u, _ := url.Parse(server.URL) - - var claims AzureIDTokenClaims - - claims.ClaimNames = map[string]string{ - "groups": "src1", - } - claims.ClaimSources = map[string]AzureIDTokenClaimSource{ - "src1": { - Endpoint: server.URL + example.urlSuffix, - }, - } - - resolvedClaims, err := claims.ResolveIndirectClaims(context.Background(), server.Client(), "access-token") - require.Nil(t, resolvedClaims) - require.Error(t, err) - require.Equal(t, example.expectedError, strings.ReplaceAll(strings.ReplaceAll(strings.ReplaceAll(err.Error(), server.URL, "SERVER-URL"), u.Port(), "PORT"), "?api-version=1.6", "")) - }) - } - -} diff --git a/internal/api/provider/oidc.go b/internal/api/provider/oidc.go index b4733df2c..7bccb451f 100644 --- a/internal/api/provider/oidc.go +++ b/internal/api/provider/oidc.go @@ -65,7 +65,7 @@ func ParseIDToken(ctx context.Context, provider *oidc.Provider, config *oidc.Con token, data, err = parseVercelMarketplaceIDToken(token) default: if IsAzureIssuer(token.Issuer) { - token, data, err = parseAzureIDToken(ctx, token, options.AccessToken) + token, data, err = parseAzureIDToken(token) } else { token, data, err = parseGenericIDToken(token) } @@ -211,6 +211,111 @@ func parseLinkedinIDToken(token *oidc.IDToken) (*oidc.IDToken, *UserProvidedData return token, &data, nil } +type AzureIDTokenClaims struct { + jwt.RegisteredClaims + + Email string `json:"email"` + Name string `json:"name"` + PreferredUsername string `json:"preferred_username"` + XMicrosoftEmailDomainOwnerVerified any `json:"xms_edov"` +} + +func (c *AzureIDTokenClaims) IsEmailVerified() bool { + emailVerified := false + + edov := c.XMicrosoftEmailDomainOwnerVerified + + // If xms_edov is not set, and an email is present or xms_edov is true, + // only then is the email regarded as verified. + // https://learn.microsoft.com/en-us/azure/active-directory/develop/migrate-off-email-claim-authorization#using-the-xms_edov-optional-claim-to-determine-email-verification-status-and-migrate-users + if edov == nil { + // An email is provided, but xms_edov is not -- probably not + // configured, so we must assume the email is verified as Azure + // will only send out a potentially unverified email address in + // single-tenanat apps. + emailVerified = c.Email != "" + } else { + edovBool := false + + // Azure can't be trusted with how they encode the xms_edov + // claim. Sometimes it's "xms_edov": "1", sometimes "xms_edov": true. + switch v := edov.(type) { + case bool: + edovBool = v + + case string: + edovBool = v == "1" || v == "true" + + default: + edovBool = false + } + + emailVerified = c.Email != "" && edovBool + } + + return emailVerified +} + +// removeAzureClaimsFromCustomClaims contains the list of claims to be removed +// from the CustomClaims map. See: +// https://learn.microsoft.com/en-us/azure/active-directory/develop/id-token-claims-reference +var removeAzureClaimsFromCustomClaims = []string{ + "aud", + "iss", + "iat", + "nbf", + "exp", + "c_hash", + "at_hash", + "aio", + "nonce", + "rh", + "uti", + "jti", + "ver", + "sub", + "name", + "preferred_username", +} + +func parseAzureIDToken(token *oidc.IDToken) (*oidc.IDToken, *UserProvidedData, error) { + var data UserProvidedData + + var azureClaims AzureIDTokenClaims + if err := token.Claims(&azureClaims); err != nil { + return nil, nil, err + } + + data.Metadata = &Claims{ + Issuer: token.Issuer, + Subject: token.Subject, + ProviderId: token.Subject, + PreferredUsername: azureClaims.PreferredUsername, + FullName: azureClaims.Name, + CustomClaims: make(map[string]any), + } + + if azureClaims.Email != "" { + data.Emails = []Email{{ + Email: azureClaims.Email, + Verified: azureClaims.IsEmailVerified(), + Primary: true, + }} + } + + if err := token.Claims(&data.Metadata.CustomClaims); err != nil { + return nil, nil, err + } + + if data.Metadata.CustomClaims != nil { + for _, claim := range removeAzureClaimsFromCustomClaims { + delete(data.Metadata.CustomClaims, claim) + } + } + + return token, &data, nil +} + type KakaoIDTokenClaims struct { jwt.RegisteredClaims