Skip to content

Commit

Permalink
fix: get user first name and last name from Apple (#2331)
Browse files Browse the repository at this point in the history
  • Loading branch information
JiggyDown authored Apr 21, 2022
1 parent 5ed4ca4 commit 4779909
Show file tree
Hide file tree
Showing 17 changed files with 108 additions and 20 deletions.
3 changes: 2 additions & 1 deletion selfservice/strategy/oidc/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package oidc

import (
"context"
"net/url"

"golang.org/x/oauth2"

Expand All @@ -11,7 +12,7 @@ import (
type Provider interface {
Config() *Configuration
OAuth2(ctx context.Context) (*oauth2.Config, error)
Claims(ctx context.Context, exchange *oauth2.Token) (*Claims, error)
Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error)
AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption
}

Expand Down
51 changes: 46 additions & 5 deletions selfservice/strategy/oidc/provider_apple.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@ import (
"context"
"crypto/ecdsa"
"crypto/x509"
"encoding/json"
"encoding/pem"
"net/url"
"time"

"github.com/form3tech-oss/jwt-go"
"github.com/golang-jwt/jwt/v4"

"github.com/pkg/errors"

"golang.org/x/oauth2"
Expand Down Expand Up @@ -49,10 +52,10 @@ func (a *ProviderApple) newClientSecret() (string, error) {
expirationTime := time.Now().Add(5 * time.Minute)

appleToken := jwt.NewWithClaims(jwt.SigningMethodES256,
jwt.StandardClaims{
Audience: []string{a.config.IssuerURL},
ExpiresAt: expirationTime.Unix(),
IssuedAt: now.Unix(),
jwt.RegisteredClaims{
Audience: []string{"https://appleid.apple.com"},
ExpiresAt: jwt.NewNumericDate(expirationTime),
IssuedAt: jwt.NewNumericDate(now),
Issuer: a.config.TeamId,
Subject: a.config.ClientID,
})
Expand Down Expand Up @@ -101,3 +104,41 @@ func (a *ProviderApple) AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption {

return options
}

func (a *ProviderApple) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) {
claims, err := a.ProviderGenericOIDC.Claims(ctx, exchange, query)
if err != nil {
return claims, err
}
decodeQuery(query, claims)

return claims, nil
}

// decodeQuery decodes extra user info from Apple into the given `Claims`.
// The info is sent as an extra query parameter to the redirect URL.
// See https://developer.apple.com/documentation/sign_in_with_apple/sign_in_with_apple_js/configuring_your_webpage_for_sign_in_with_apple#3331292
// Note that there's no way to make sure the info hasn't been tampered with.
func decodeQuery(query url.Values, claims *Claims) {
var user struct {
Name *struct {
FirstName *string `json:"firstName"`
LastName *string `json:"lastName"`
} `json:"name"`
}
if err := json.Unmarshal([]byte(query.Get("user")), &user); err == nil {
if name := user.Name; name != nil {
if firstName := name.FirstName; firstName != nil {
if claims.GivenName == "" {
claims.GivenName = *firstName
}
if claims.FamilyName == "" {
claims.FamilyName = *firstName
}
}
if lastName := name.LastName; lastName != nil && claims.LastName == "" {
claims.LastName = *lastName
}
}
}
}
36 changes: 36 additions & 0 deletions selfservice/strategy/oidc/provider_apple_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package oidc

import (
"fmt"
"github.com/stretchr/testify/assert"
"net/url"
"testing"
)

func TestDecodeQuery(t *testing.T) {
query := url.Values{
"user": []string{`{"name": {"firstName": "first", "lastName": "last"}, "email": "email@email.com"}`},
}

for k, tc := range []struct {
claims *Claims
familyName string
givenName string
lastName string
}{
{claims: &Claims{}, familyName: "first", givenName: "first", lastName: "last"},
{claims: &Claims{FamilyName: "fam"}, familyName: "fam", givenName: "first", lastName: "last"},
{claims: &Claims{FamilyName: "fam", GivenName: "giv"}, familyName: "fam", givenName: "giv", lastName: "last"},
{claims: &Claims{FamilyName: "fam", GivenName: "giv", LastName: "las"}, familyName: "fam", givenName: "giv", lastName: "las"},
} {
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
decodeQuery(query, tc.claims)
assert.Equal(t, tc.familyName, tc.claims.FamilyName)
assert.Equal(t, tc.givenName, tc.claims.GivenName)
assert.Equal(t, tc.lastName, tc.claims.LastName)
// Never extract email from the query, as the same info can be extracted and verified from the ID token.
assert.Empty(t, tc.claims.Email)
})
}

}
2 changes: 1 addition & 1 deletion selfservice/strategy/oidc/provider_auth0.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func (g *ProviderAuth0) OAuth2(ctx context.Context) (*oauth2.Config, error) {
return g.oauth2(ctx)
}

func (g *ProviderAuth0) Claims(ctx context.Context, exchange *oauth2.Token) (*Claims, error) {
func (g *ProviderAuth0) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) {
o, err := g.OAuth2(ctx)
if err != nil {
return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err))
Expand Down
3 changes: 2 additions & 1 deletion selfservice/strategy/oidc/provider_discord.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package oidc
import (
"context"
"fmt"
"net/url"

"github.com/ory/kratos/x"

Expand Down Expand Up @@ -62,7 +63,7 @@ func (d *ProviderDiscord) AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption {
}
}

func (d *ProviderDiscord) Claims(ctx context.Context, exchange *oauth2.Token) (*Claims, error) {
func (d *ProviderDiscord) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) {
grantedScopes := stringsx.Splitx(fmt.Sprintf("%s", exchange.Extra("scope")), " ")
for _, check := range d.Config().Scope {
if !stringslice.Has(grantedScopes, check) {
Expand Down
2 changes: 1 addition & 1 deletion selfservice/strategy/oidc/provider_facebook.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func (g *ProviderFacebook) OAuth2(ctx context.Context) (*oauth2.Config, error) {
return g.oauth2ConfigFromEndpoint(ctx, endpoint), nil
}

func (g *ProviderFacebook) Claims(ctx context.Context, exchange *oauth2.Token) (*Claims, error) {
func (g *ProviderFacebook) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) {
o, err := g.OAuth2(ctx)
if err != nil {
return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err))
Expand Down
3 changes: 2 additions & 1 deletion selfservice/strategy/oidc/provider_generic_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package oidc

import (
"context"
"net/url"

"github.com/pkg/errors"
"golang.org/x/oauth2"
Expand Down Expand Up @@ -97,7 +98,7 @@ func (g *ProviderGenericOIDC) verifyAndDecodeClaimsWithProvider(ctx context.Cont
return &claims, nil
}

func (g *ProviderGenericOIDC) Claims(ctx context.Context, exchange *oauth2.Token) (*Claims, error) {
func (g *ProviderGenericOIDC) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) {
raw, ok := exchange.Extra("id_token").(string)
if !ok || len(raw) == 0 {
return nil, errors.WithStack(ErrIDTokenMissing)
Expand Down
3 changes: 2 additions & 1 deletion selfservice/strategy/oidc/provider_github.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package oidc
import (
"context"
"fmt"
"net/url"

"github.com/ory/kratos/x"

Expand Down Expand Up @@ -55,7 +56,7 @@ func (g *ProviderGitHub) AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption {
return []oauth2.AuthCodeOption{}
}

func (g *ProviderGitHub) Claims(ctx context.Context, exchange *oauth2.Token) (*Claims, error) {
func (g *ProviderGitHub) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) {
grantedScopes := stringsx.Splitx(fmt.Sprintf("%s", exchange.Extra("scope")), ",")
for _, check := range g.Config().Scope {
if !stringslice.Has(grantedScopes, check) {
Expand Down
3 changes: 2 additions & 1 deletion selfservice/strategy/oidc/provider_github_app.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package oidc
import (
"context"
"fmt"
"net/url"

"github.com/ory/kratos/x"

Expand Down Expand Up @@ -52,7 +53,7 @@ func (g *ProviderGitHubApp) AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption {
return []oauth2.AuthCodeOption{}
}

func (g *ProviderGitHubApp) Claims(ctx context.Context, exchange *oauth2.Token) (*Claims, error) {
func (g *ProviderGitHubApp) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) {
gh := ghapi.NewClient(g.oauth2(ctx).Client(ctx, exchange))

user, _, err := gh.Users.Get(ctx, "")
Expand Down
2 changes: 1 addition & 1 deletion selfservice/strategy/oidc/provider_gitlab.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func (g *ProviderGitLab) OAuth2(ctx context.Context) (*oauth2.Config, error) {
return g.oauth2(ctx)
}

func (g *ProviderGitLab) Claims(ctx context.Context, exchange *oauth2.Token) (*Claims, error) {
func (g *ProviderGitLab) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) {
o, err := g.OAuth2(ctx)
if err != nil {
return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err))
Expand Down
3 changes: 2 additions & 1 deletion selfservice/strategy/oidc/provider_microsoft.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package oidc
import (
"context"
"encoding/json"
"net/url"
"strings"

"github.com/hashicorp/go-retryablehttp"
Expand Down Expand Up @@ -49,7 +50,7 @@ func (m *ProviderMicrosoft) OAuth2(ctx context.Context) (*oauth2.Config, error)
return m.oauth2ConfigFromEndpoint(ctx, endpoint), nil
}

func (m *ProviderMicrosoft) Claims(ctx context.Context, exchange *oauth2.Token) (*Claims, error) {
func (m *ProviderMicrosoft) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) {
raw, ok := exchange.Extra("id_token").(string)
if !ok || len(raw) == 0 {
return nil, errors.WithStack(ErrIDTokenMissing)
Expand Down
3 changes: 2 additions & 1 deletion selfservice/strategy/oidc/provider_private_net_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package oidc_test
import (
"context"
"fmt"
"net/url"
"testing"
"time"

Expand Down Expand Up @@ -77,7 +78,7 @@ func TestProviderPrivateIP(t *testing.T) {
p := tc.p(tc.c)
_, err := p.Claims(context.Background(), (&oauth2.Token{RefreshToken: "foo", Expiry: time.Now().Add(-time.Hour)}).WithExtra(map[string]interface{}{
"id_token": tc.id,
}))
}), url.Values{})
require.Error(t, err)
assert.Contains(t, fmt.Sprintf("%+v", err), tc.e)
})
Expand Down
3 changes: 2 additions & 1 deletion selfservice/strategy/oidc/provider_slack.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package oidc
import (
"context"
"fmt"
"net/url"

"github.com/ory/herodot"

Expand Down Expand Up @@ -57,7 +58,7 @@ func (d *ProviderSlack) AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption {
return []oauth2.AuthCodeOption{}
}

func (d *ProviderSlack) Claims(ctx context.Context, exchange *oauth2.Token) (*Claims, error) {
func (d *ProviderSlack) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) {
grantedScopes := stringsx.Splitx(fmt.Sprintf("%s", exchange.Extra("scope")), ",")
for _, check := range d.Config().Scope {
if !stringslice.Has(grantedScopes, check) {
Expand Down
3 changes: 2 additions & 1 deletion selfservice/strategy/oidc/provider_spotify.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package oidc
import (
"context"
"fmt"
"net/url"

"golang.org/x/oauth2/spotify"

Expand Down Expand Up @@ -55,7 +56,7 @@ func (g *ProviderSpotify) AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption {
return []oauth2.AuthCodeOption{}
}

func (g *ProviderSpotify) Claims(ctx context.Context, exchange *oauth2.Token) (*Claims, error) {
func (g *ProviderSpotify) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) {
grantedScopes := stringsx.Splitx(fmt.Sprintf("%s", exchange.Extra("scope")), " ")
for _, check := range g.Config().Scope {
if !stringslice.Has(grantedScopes, check) {
Expand Down
3 changes: 2 additions & 1 deletion selfservice/strategy/oidc/provider_vk.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package oidc
import (
"context"
"encoding/json"
"net/url"
"strconv"

"github.com/hashicorp/go-retryablehttp"
Expand Down Expand Up @@ -55,7 +56,7 @@ func (g *ProviderVK) OAuth2(ctx context.Context) (*oauth2.Config, error) {
return g.oauth2(ctx), nil
}

func (g *ProviderVK) Claims(ctx context.Context, exchange *oauth2.Token) (*Claims, error) {
func (g *ProviderVK) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) {

o, err := g.OAuth2(ctx)
if err != nil {
Expand Down
3 changes: 2 additions & 1 deletion selfservice/strategy/oidc/provider_yandex.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package oidc
import (
"context"
"encoding/json"
"net/url"

"github.com/hashicorp/go-retryablehttp"
"github.com/pkg/errors"
Expand Down Expand Up @@ -53,7 +54,7 @@ func (g *ProviderYandex) OAuth2(ctx context.Context) (*oauth2.Config, error) {
return g.oauth2(ctx), nil
}

func (g *ProviderYandex) Claims(ctx context.Context, exchange *oauth2.Token) (*Claims, error) {
func (g *ProviderYandex) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) {
o, err := g.OAuth2(ctx)
if err != nil {
return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err))
Expand Down
2 changes: 1 addition & 1 deletion selfservice/strategy/oidc/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ func (s *Strategy) handleCallback(w http.ResponseWriter, r *http.Request, ps htt
return
}

claims, err := provider.Claims(r.Context(), token)
claims, err := provider.Claims(r.Context(), token, r.URL.Query())
if err != nil {
s.forwardError(w, r, req, s.handleError(w, r, req, pid, nil, err))
return
Expand Down

0 comments on commit 4779909

Please sign in to comment.