Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions hack/test.env
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ GOTRUE_EXTERNAL_SNAPCHAT_ENABLED=true
GOTRUE_EXTERNAL_SNAPCHAT_CLIENT_ID=testclientid
GOTRUE_EXTERNAL_SNAPCHAT_SECRET=testsecret
GOTRUE_EXTERNAL_SNAPCHAT_REDIRECT_URI=https://identity.services.netlify.com/callback
GOTRUE_EXTERNAL_SNAPCHAT_EMAIL_OPTIONAL=true
GOTRUE_EXTERNAL_SPOTIFY_ENABLED=true
GOTRUE_EXTERNAL_SPOTIFY_CLIENT_ID=testclientid
GOTRUE_EXTERNAL_SPOTIFY_SECRET=testsecret
Expand Down
12 changes: 6 additions & 6 deletions internal/api/apierrors/errorcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,11 @@ const (
ErrorCodeMFAWebAuthnVerifyDisabled ErrorCode = "mfa_webauthn_verify_not_enabled"
ErrorCodeMFAVerifiedFactorExists ErrorCode = "mfa_verified_factor_exists"
//#nosec G101 -- Not a secret value.
ErrorCodeInvalidCredentials ErrorCode = "invalid_credentials"
ErrorCodeEmailAddressNotAuthorized ErrorCode = "email_address_not_authorized"
ErrorCodeEmailAddressInvalid ErrorCode = "email_address_invalid"
ErrorCodeWeb3ProviderDisabled ErrorCode = "web3_provider_disabled"
ErrorCodeWeb3UnsupportedChain ErrorCode = "web3_unsupported_chain"

ErrorCodeInvalidCredentials ErrorCode = "invalid_credentials"
ErrorCodeEmailAddressNotAuthorized ErrorCode = "email_address_not_authorized"
ErrorCodeEmailAddressInvalid ErrorCode = "email_address_invalid"
ErrorCodeWeb3ProviderDisabled ErrorCode = "web3_provider_disabled"
ErrorCodeWeb3UnsupportedChain ErrorCode = "web3_unsupported_chain"
ErrorCodeOAuthDynamicClientRegistrationDisabled ErrorCode = "oauth_dynamic_client_registration_disabled"
ErrorCodeEmailAddressNotProvided ErrorCode = "email_address_not_provided"
)
58 changes: 34 additions & 24 deletions internal/api/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,24 @@ func (c contextKey) String() string {
}

const (
tokenKey = contextKey("jwt")
inviteTokenKey = contextKey("invite_token")
signatureKey = contextKey("signature")
externalProviderTypeKey = contextKey("external_provider_type")
userKey = contextKey("user")
targetUserKey = contextKey("target_user")
factorKey = contextKey("factor")
sessionKey = contextKey("session")
externalReferrerKey = contextKey("external_referrer")
functionHooksKey = contextKey("function_hooks")
adminUserKey = contextKey("admin_user")
oauthTokenKey = contextKey("oauth_token") // for OAuth1.0, also known as request token
oauthVerifierKey = contextKey("oauth_verifier")
ssoProviderKey = contextKey("sso_provider")
externalHostKey = contextKey("external_host")
flowStateKey = contextKey("flow_state_id")
externalProviderTypeKey = contextKey("external_provider_type")
externalProviderEmailOptionalKey = contextKey("external_provider_allow_no_email")

tokenKey = contextKey("jwt")
inviteTokenKey = contextKey("invite_token")
signatureKey = contextKey("signature")
userKey = contextKey("user")
targetUserKey = contextKey("target_user")
factorKey = contextKey("factor")
sessionKey = contextKey("session")
externalReferrerKey = contextKey("external_referrer")
functionHooksKey = contextKey("function_hooks")
adminUserKey = contextKey("admin_user")
oauthTokenKey = contextKey("oauth_token") // for OAuth1.0, also known as request token
oauthVerifierKey = contextKey("oauth_verifier")
ssoProviderKey = contextKey("sso_provider")
externalHostKey = contextKey("external_host")
flowStateKey = contextKey("flow_state_id")
)

// withToken adds the JWT token to the context.
Expand Down Expand Up @@ -152,18 +154,26 @@ func getInviteToken(ctx context.Context) string {
}

// withExternalProviderType adds the provided request ID to the context.
func withExternalProviderType(ctx context.Context, id string) context.Context {
return context.WithValue(ctx, externalProviderTypeKey, id)
func withExternalProviderType(ctx context.Context, id string, emailOptional bool) context.Context {
return context.WithValue(context.WithValue(ctx, externalProviderTypeKey, id), externalProviderEmailOptionalKey, emailOptional)
}

// getExternalProviderType reads the request ID from the context.
func getExternalProviderType(ctx context.Context) string {
obj := ctx.Value(externalProviderTypeKey)
if obj == nil {
return ""
// getExternalProviderType returns the provider type and whether user data without email address should be allowed.
func getExternalProviderType(ctx context.Context) (string, bool) {
idValue := ctx.Value(externalProviderTypeKey)
emailOptionalValue := ctx.Value(externalProviderEmailOptionalKey)

id, okID := idValue.(string)
if !okID {
return "", false
}

return obj.(string)
emailOptional, okEmailOptional := emailOptionalValue.(bool)
if !okEmailOptional {
return "", false
}

return id, emailOptional
}

func withExternalReferrer(ctx context.Context, token string) context.Context {
Expand Down
117 changes: 74 additions & 43 deletions internal/api/external.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ type ExternalProviderClaims struct {
Referrer string `json:"referrer,omitempty"`
FlowStateID string `json:"flow_state_id"`
LinkingTargetID string `json:"linking_target_id,omitempty"`
EmailOptional bool `json:"email_optional,omitempty"`
}

// ExternalProviderRedirect redirects the request to the oauth provider
Expand All @@ -55,7 +56,7 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ
codeChallenge := query.Get("code_challenge")
codeChallengeMethod := query.Get("code_challenge_method")

p, err := a.Provider(ctx, providerType, scopes)
p, pConfig, err := a.Provider(ctx, providerType, scopes)
if err != nil {
return "", apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Unsupported provider: %+v", err).WithInternalError(err)
}
Expand Down Expand Up @@ -96,10 +97,11 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ
SiteURL: config.SiteURL,
InstanceID: uuid.Nil.String(),
},
Provider: providerType,
InviteToken: inviteToken,
Referrer: redirectURL,
FlowStateID: flowStateID,
Provider: providerType,
InviteToken: inviteToken,
Referrer: redirectURL,
FlowStateID: flowStateID,
EmailOptional: pConfig.EmailOptional,
}

if linkingTargetUser != nil {
Expand Down Expand Up @@ -144,7 +146,7 @@ func (a *API) ExternalProviderCallback(w http.ResponseWriter, r *http.Request) e

func (a *API) handleOAuthCallback(r *http.Request) (*OAuthProviderData, error) {
ctx := r.Context()
providerType := getExternalProviderType(ctx)
providerType, _ := getExternalProviderType(ctx)

var oAuthResponseData *OAuthProviderData
var err error
Expand All @@ -168,16 +170,18 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re
var grantParams models.GrantParams
grantParams.FillGrantParams(r)

providerType := getExternalProviderType(ctx)
providerType, emailOptional := getExternalProviderType(ctx)
data, err := a.handleOAuthCallback(r)
if err != nil {
return err
}

userData := data.userData
if len(userData.Emails) <= 0 {

if len(userData.Emails) == 0 && !emailOptional {
return apierrors.NewInternalServerError("Error getting user email from external provider")
}

userData.Metadata.EmailVerified = false
for _, email := range userData.Emails {
if email.Primary {
Expand Down Expand Up @@ -226,7 +230,7 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re
return terr
}
} else {
if user, terr = a.createAccountFromExternalIdentity(tx, r, userData, providerType); terr != nil {
if user, terr = a.createAccountFromExternalIdentity(tx, r, userData, providerType, emailOptional); terr != nil {
return terr
}
}
Expand Down Expand Up @@ -285,7 +289,7 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re
return nil
}

func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http.Request, userData *provider.UserProvidedData, providerType string) (*models.User, error) {
func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http.Request, userData *provider.UserProvidedData, providerType string, emailOptional bool) (*models.User, error) {
ctx := r.Context()
aud := a.requestAud(ctx, r)
config := a.config
Expand Down Expand Up @@ -378,8 +382,7 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http.
return nil, apierrors.NewForbiddenError(apierrors.ErrorCodeUserBanned, "User is banned")
}

// TODO(hf): Expand this boolean with all providers that may not have emails (like X/Twitter, Discord).
hasEmails := providerType != "web3" // intentionally not using len(userData.Emails) != 0 for better backward compatibility control
hasEmails := providerType != "web3" && !(emailOptional && decision.CandidateEmail.Email == "")

if hasEmails && !user.IsConfirmed() {
// The user may have other unconfirmed email + password
Expand All @@ -400,21 +403,19 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http.
return nil, apierrors.NewInternalServerError("Error updating user").WithInternalError(terr)
}
} else {
// Some providers, like web3 don't have email data.
// Treat these as if a confirmation email has been
// sent, although the user will be created without an
// email address.
emailConfirmationSent := false
if decision.CandidateEmail.Email != "" {
if terr = a.sendConfirmation(r, tx, user, models.ImplicitFlow); terr != nil {
return nil, terr
}
emailConfirmationSent = true
}

if !config.Mailer.AllowUnverifiedEmailSignIns {
if emailConfirmationSent {
return nil, storage.NewCommitWithError(apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeProviderEmailNeedsVerification, fmt.Sprintf("Unverified email with %v. A confirmation email has been sent to your %v email", providerType, providerType)))
}

return nil, storage.NewCommitWithError(apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeProviderEmailNeedsVerification, fmt.Sprintf("Unverified email with %v. Verify the email with %v in order to sign in", providerType, providerType)))
}
}
Expand Down Expand Up @@ -564,67 +565,97 @@ func (a *API) loadExternalState(ctx context.Context, r *http.Request) (context.C
}
ctx = withTargetUser(ctx, u)
}
ctx = withExternalProviderType(ctx, claims.Provider)
ctx = withExternalProviderType(ctx, claims.Provider, claims.EmailOptional)
return withSignature(ctx, state), nil
}

// Provider returns a Provider interface for the given name.
func (a *API) Provider(ctx context.Context, name string, scopes string) (provider.Provider, error) {
func (a *API) Provider(ctx context.Context, name string, scopes string) (provider.Provider, conf.OAuthProviderConfiguration, error) {
config := a.config
name = strings.ToLower(name)

var err error
var p provider.Provider
var pConfig conf.OAuthProviderConfiguration

switch name {
case "apple":
return provider.NewAppleProvider(ctx, config.External.Apple)
pConfig = config.External.Apple
p, err = provider.NewAppleProvider(ctx, pConfig)
case "azure":
return provider.NewAzureProvider(config.External.Azure, scopes)
pConfig = config.External.Azure
p, err = provider.NewAzureProvider(pConfig, scopes)
case "bitbucket":
return provider.NewBitbucketProvider(config.External.Bitbucket)
pConfig = config.External.Bitbucket
p, err = provider.NewBitbucketProvider(pConfig)
case "discord":
return provider.NewDiscordProvider(config.External.Discord, scopes)
pConfig = config.External.Discord
p, err = provider.NewDiscordProvider(pConfig, scopes)
case "facebook":
return provider.NewFacebookProvider(config.External.Facebook, scopes)
pConfig = config.External.Facebook
p, err = provider.NewFacebookProvider(pConfig, scopes)
case "figma":
return provider.NewFigmaProvider(config.External.Figma, scopes)
pConfig = config.External.Figma
p, err = provider.NewFigmaProvider(pConfig, scopes)
case "fly":
return provider.NewFlyProvider(config.External.Fly, scopes)
pConfig = config.External.Fly
p, err = provider.NewFlyProvider(pConfig, scopes)
case "github":
return provider.NewGithubProvider(config.External.Github, scopes)
pConfig = config.External.Github
p, err = provider.NewGithubProvider(pConfig, scopes)
case "gitlab":
return provider.NewGitlabProvider(config.External.Gitlab, scopes)
pConfig = config.External.Gitlab
p, err = provider.NewGitlabProvider(pConfig, scopes)
case "google":
return provider.NewGoogleProvider(ctx, config.External.Google, scopes)
pConfig = config.External.Google
p, err = provider.NewGoogleProvider(ctx, pConfig, scopes)
case "kakao":
return provider.NewKakaoProvider(config.External.Kakao, scopes)
pConfig = config.External.Kakao
p, err = provider.NewKakaoProvider(pConfig, scopes)
case "keycloak":
return provider.NewKeycloakProvider(config.External.Keycloak, scopes)
pConfig = config.External.Keycloak
p, err = provider.NewKeycloakProvider(pConfig, scopes)
case "linkedin":
return provider.NewLinkedinProvider(config.External.Linkedin, scopes)
pConfig = config.External.Linkedin
p, err = provider.NewLinkedinProvider(pConfig, scopes)
case "linkedin_oidc":
return provider.NewLinkedinOIDCProvider(config.External.LinkedinOIDC, scopes)
pConfig = config.External.LinkedinOIDC
p, err = provider.NewLinkedinOIDCProvider(pConfig, scopes)
case "notion":
return provider.NewNotionProvider(config.External.Notion)
pConfig = config.External.Notion
p, err = provider.NewNotionProvider(pConfig)
case "snapchat":
return provider.NewSnapchatProvider(config.External.Snapchat, scopes)
pConfig = config.External.Snapchat
p, err = provider.NewSnapchatProvider(pConfig, scopes)
case "spotify":
return provider.NewSpotifyProvider(config.External.Spotify, scopes)
pConfig = config.External.Spotify
p, err = provider.NewSpotifyProvider(pConfig, scopes)
case "slack":
return provider.NewSlackProvider(config.External.Slack, scopes)
pConfig = config.External.Slack
p, err = provider.NewSlackProvider(pConfig, scopes)
case "slack_oidc":
return provider.NewSlackOIDCProvider(config.External.SlackOIDC, scopes)
pConfig = config.External.SlackOIDC
p, err = provider.NewSlackOIDCProvider(pConfig, scopes)
case "twitch":
return provider.NewTwitchProvider(config.External.Twitch, scopes)
pConfig = config.External.Twitch
p, err = provider.NewTwitchProvider(pConfig, scopes)
case "twitter":
return provider.NewTwitterProvider(config.External.Twitter, scopes)
pConfig = config.External.Twitter
p, err = provider.NewTwitterProvider(pConfig, scopes)
case "vercel_marketplace":
return provider.NewVercelMarketplaceProvider(config.External.VercelMarketplace, scopes)
pConfig = config.External.VercelMarketplace
p, err = provider.NewVercelMarketplaceProvider(pConfig, scopes)
case "workos":
return provider.NewWorkOSProvider(config.External.WorkOS)
pConfig = config.External.WorkOS
p, err = provider.NewWorkOSProvider(pConfig)
case "zoom":
return provider.NewZoomProvider(config.External.Zoom)
pConfig = config.External.Zoom
p, err = provider.NewZoomProvider(pConfig)
default:
return nil, fmt.Errorf("Provider %s could not be found", name)
return nil, pConfig, fmt.Errorf("Provider %s could not be found", name)
}

return p, pConfig, err
}

func redirectErrors(handler apiHandler, w http.ResponseWriter, r *http.Request, u *url.URL) {
Expand Down
15 changes: 8 additions & 7 deletions internal/api/external_oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/supabase/auth/internal/api/apierrors"
"github.com/supabase/auth/internal/api/provider"
"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/observability"
"github.com/supabase/auth/internal/utilities"
)
Expand Down Expand Up @@ -69,7 +70,7 @@ func (a *API) oAuthCallback(ctx context.Context, r *http.Request, providerType s
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeBadOAuthCallback, "OAuth callback with missing authorization code missing")
}

oAuthProvider, err := a.OAuthProvider(ctx, providerType)
oAuthProvider, _, err := a.OAuthProvider(ctx, providerType)
if err != nil {
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeOAuthProviderNotSupported, "Unsupported provider: %+v", err).WithInternalError(err)
}
Expand Down Expand Up @@ -111,7 +112,7 @@ func (a *API) oAuthCallback(ctx context.Context, r *http.Request, providerType s
}

func (a *API) oAuth1Callback(ctx context.Context, providerType string) (*OAuthProviderData, error) {
oAuthProvider, err := a.OAuthProvider(ctx, providerType)
oAuthProvider, _, err := a.OAuthProvider(ctx, providerType)
if err != nil {
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeOAuthProviderNotSupported, "Unsupported provider: %+v", err).WithInternalError(err)
}
Expand Down Expand Up @@ -141,16 +142,16 @@ func (a *API) oAuth1Callback(ctx context.Context, providerType string) (*OAuthPr
}

// OAuthProvider returns the corresponding oauth provider as an OAuthProvider interface
func (a *API) OAuthProvider(ctx context.Context, name string) (provider.OAuthProvider, error) {
providerCandidate, err := a.Provider(ctx, name, "")
func (a *API) OAuthProvider(ctx context.Context, name string) (provider.OAuthProvider, conf.OAuthProviderConfiguration, error) {
providerCandidate, pConfig, err := a.Provider(ctx, name, "")
if err != nil {
return nil, err
return nil, pConfig, err
}

switch p := providerCandidate.(type) {
case provider.OAuthProvider:
return p, nil
return p, pConfig, nil
default:
return nil, fmt.Errorf("Provider %v cannot be used for OAuth", name)
return nil, pConfig, fmt.Errorf("Provider %v cannot be used for OAuth", name)
}
}
Loading