Skip to content

Commit

Permalink
refresh token rotation
Browse files Browse the repository at this point in the history
Update refresh token flow to revoke old refresh token and generates a new one.

Fixes #519
  • Loading branch information
rsoletob committed Aug 16, 2016
1 parent 4429570 commit c91b37a
Show file tree
Hide file tree
Showing 5 changed files with 190 additions and 124 deletions.
187 changes: 115 additions & 72 deletions db/refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,47 +91,108 @@ func NewRefreshTokenRepoWithGenerator(dbm *gorp.DbMap, gen refresh.RefreshTokenG
}

func (r *refreshTokenRepo) Create(userID, clientID, connectorID string, scopes []string) (string, error) {
if userID == "" {
return "", refresh.ErrorInvalidUserID
return r.create(nil, userID, clientID, connectorID, scopes)
}

func (r *refreshTokenRepo) Verify(clientID, token string) (userID, connectorID string, scope scope.Scopes, err error) {
return r.verify(nil, clientID, token)
}

func (r *refreshTokenRepo) Revoke(userID, token string) error {
tx, err := r.begin()
if err != nil {
return err
}
if clientID == "" {
return "", refresh.ErrorInvalidClientID
defer tx.Rollback()
if err := r.revoke(tx, userID, token); err != nil {
return err
}

// TODO(yifan): Check the number of tokens given to the client-user pair.
tokenPayload, err := r.tokenGenerator.Generate()
return tx.Commit()
}

func (r *refreshTokenRepo) RenewRefreshToken(clientID, userID, oldToken string) (newRefreshToken string, err error) {
// Verify
userID, connectorID, scopes, err := r.verify(nil, clientID, oldToken)
if err != nil {
return "", err
}

payloadHash, err := bcrypt.GenerateFromPassword(tokenPayload, bcrypt.DefaultCost)
// Revoke old refresh token
tx, err := r.begin()
if err != nil {
return "", err
}

record := &refreshTokenModel{
PayloadHash: payloadHash,
UserID: userID,
ClientID: clientID,
ConnectorID: connectorID,
Scopes: strings.Join(scopes, " "),
defer tx.Rollback()
if err := r.revoke(tx, userID, oldToken); err != nil {
return "", err
}

if err := r.executor(nil).Insert(record); err != nil {
// Renew refresh token
newRefreshToken, err = r.create(tx, userID, clientID, connectorID, scopes)
if err != nil {
return "", err
}

return buildToken(record.ID, tokenPayload), nil
return newRefreshToken, tx.Commit()
}

func (r *refreshTokenRepo) Verify(clientID, token string) (userID, connectorID string, scope scope.Scopes, err error) {
func (r *refreshTokenRepo) RevokeTokensForClient(userID, clientID string) error {
q := fmt.Sprintf("DELETE FROM %s WHERE user_id = $1 AND client_id = $2", r.quote(refreshTokenTableName))
_, err := r.executor(nil).Exec(q, userID, clientID)
return err
}

func (r *refreshTokenRepo) ClientsWithRefreshTokens(userID string) ([]client.Client, error) {
q := `SELECT c.* FROM %s as c
INNER JOIN %s as r ON c.id = r.client_id WHERE r.user_id = $1;`
q = fmt.Sprintf(q, r.quote(clientTableName), r.quote(refreshTokenTableName))
var clients []clientModel
if _, err := r.executor(nil).Select(&clients, q, userID); err != nil {
return nil, err
}

c := make([]client.Client, len(clients))
for i, client := range clients {
ident, err := client.Client()
if err != nil {
return nil, err
}
c[i] = *ident
// Do not share the secret.
c[i].Credentials.Secret = ""
}

return c, nil
}

func (r *refreshTokenRepo) get(tx repo.Transaction, tokenID int64) (*refreshTokenModel, error) {
ex := r.executor(tx)
result, err := ex.Get(refreshTokenModel{}, tokenID)
if err != nil {
return nil, err
}

if result == nil {
return nil, refresh.ErrorInvalidToken
}

record, ok := result.(*refreshTokenModel)
if !ok {
log.Errorf("expected refreshTokenModel but found %v", reflect.TypeOf(result))
return nil, errors.New("unrecognized model")
}
return record, nil
}

func (r *refreshTokenRepo) verify(tx repo.Transaction, clientID, token string) (userID, connectorID string, scope scope.Scopes, err error) {
tokenID, tokenPayload, err := parseToken(token)

if err != nil {
return
}

record, err := r.get(nil, tokenID)
record, err := r.get(tx, tokenID)
if err != nil {
return
}
Expand All @@ -140,6 +201,7 @@ func (r *refreshTokenRepo) Verify(clientID, token string) (userID, connectorID s
return "", "", nil, refresh.ErrorInvalidClientID
}

// Check if the hash of token received is the same stored in database
if err = checkTokenPayload(record.PayloadHash, tokenPayload); err != nil {
return
}
Expand All @@ -152,17 +214,46 @@ func (r *refreshTokenRepo) Verify(clientID, token string) (userID, connectorID s
return record.UserID, record.ConnectorID, scopes, nil
}

func (r *refreshTokenRepo) Revoke(userID, token string) error {
tokenID, tokenPayload, err := parseToken(token)
func (r *refreshTokenRepo) create(tx repo.Transaction, userID, clientID, connectorID string, scopes []string) (string, error) {
if userID == "" {
return "", refresh.ErrorInvalidUserID
}
if clientID == "" {
return "", refresh.ErrorInvalidClientID
}

// TODO(yifan): Check the number of tokens given to the client-user pair.
tokenPayload, err := r.tokenGenerator.Generate()
if err != nil {
return err
return "", err
}

tx, err := r.begin()
payloadHash, err := bcrypt.GenerateFromPassword(tokenPayload, bcrypt.DefaultCost)
if err != nil {
return "", err
}

record := &refreshTokenModel{
PayloadHash: payloadHash,
UserID: userID,
ClientID: clientID,
ConnectorID: connectorID,
Scopes: strings.Join(scopes, " "),
}

if err := r.executor(tx).Insert(record); err != nil {
return "", err
}

return buildToken(record.ID, tokenPayload), nil
}

func (r *refreshTokenRepo) revoke(tx repo.Transaction, userID, token string) error {
tokenID, tokenPayload, err := parseToken(token)
if err != nil {
return err
}
defer tx.Rollback()

exec := r.executor(tx)
record, err := r.get(tx, tokenID)
if err != nil {
Expand All @@ -185,53 +276,5 @@ func (r *refreshTokenRepo) Revoke(userID, token string) error {
return refresh.ErrorInvalidToken
}

return tx.Commit()
}

func (r *refreshTokenRepo) RevokeTokensForClient(userID, clientID string) error {
q := fmt.Sprintf("DELETE FROM %s WHERE user_id = $1 AND client_id = $2", r.quote(refreshTokenTableName))
_, err := r.executor(nil).Exec(q, userID, clientID)
return err
}

func (r *refreshTokenRepo) ClientsWithRefreshTokens(userID string) ([]client.Client, error) {
q := `SELECT c.* FROM %s as c
INNER JOIN %s as r ON c.id = r.client_id WHERE r.user_id = $1;`
q = fmt.Sprintf(q, r.quote(clientTableName), r.quote(refreshTokenTableName))
var clients []clientModel
if _, err := r.executor(nil).Select(&clients, q, userID); err != nil {
return nil, err
}

c := make([]client.Client, len(clients))
for i, client := range clients {
ident, err := client.Client()
if err != nil {
return nil, err
}
c[i] = *ident
// Do not share the secret.
c[i].Credentials.Secret = ""
}

return c, nil
}

func (r *refreshTokenRepo) get(tx repo.Transaction, tokenID int64) (*refreshTokenModel, error) {
ex := r.executor(tx)
result, err := ex.Get(refreshTokenModel{}, tokenID)
if err != nil {
return nil, err
}

if result == nil {
return nil, refresh.ErrorInvalidToken
}

record, ok := result.(*refreshTokenModel)
if !ok {
log.Errorf("expected refreshTokenModel but found %v", reflect.TypeOf(result))
return nil, errors.New("unrecognized model")
}
return record, nil
return nil
}
3 changes: 3 additions & 0 deletions refresh/repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ type RefreshTokenRepo interface {
// Revoke deletes the refresh token if the token belongs to the given userID.
Revoke(userID, token string) error

// Revoke old refresh token and generates a new one
RenewRefreshToken(clientID, userID, oldToken string) (newRefreshToken string, err error)

// RevokeTokensForClient revokes all tokens issued for the userID for the provided client.
RevokeTokensForClient(userID, clientID string) error

Expand Down
2 changes: 1 addition & 1 deletion server/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ func handleTokenFunc(srv OIDCServer) http.HandlerFunc {
writeTokenError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state)
return
}
jwt, err = srv.RefreshToken(creds, strings.Split(scopes, " "), token)
jwt, refreshToken, err = srv.RefreshToken(creds, strings.Split(scopes, " "), token)
if err != nil {
writeTokenError(w, err, state)
return
Expand Down
42 changes: 24 additions & 18 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ type OIDCServer interface {

ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, error)

// RefreshToken takes a previously generated refresh token and returns a new ID token
// RefreshToken takes a previously generated refresh token and returns a new ID token and new refresh token
// if the token is valid.
RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, error)
RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, string, error)

KillSession(string) error

Expand Down Expand Up @@ -567,34 +567,34 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo
return jwt, refreshToken, nil
}

func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, error) {
func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, string, error) {
ok, err := s.ClientManager.Authenticate(creds)
if err != nil {
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
return nil, oauth2.NewError(oauth2.ErrorServerError)
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
}
if !ok {
log.Errorf("Failed to Authenticate client %s", creds.ID)
return nil, oauth2.NewError(oauth2.ErrorInvalidClient)
return nil, "", oauth2.NewError(oauth2.ErrorInvalidClient)
}

userID, connectorID, rtScopes, err := s.RefreshTokenRepo.Verify(creds.ID, token)
switch err {
case nil:
break
case refresh.ErrorInvalidToken:
return nil, oauth2.NewError(oauth2.ErrorInvalidRequest)
return nil, "", oauth2.NewError(oauth2.ErrorInvalidRequest)
case refresh.ErrorInvalidClientID:
return nil, oauth2.NewError(oauth2.ErrorInvalidClient)
return nil, "", oauth2.NewError(oauth2.ErrorInvalidClient)
default:
return nil, oauth2.NewError(oauth2.ErrorServerError)
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
}

if len(scopes) == 0 {
scopes = rtScopes
} else {
if !rtScopes.Contains(scopes) {
return nil, oauth2.NewError(oauth2.ErrorInvalidRequest)
return nil, "", oauth2.NewError(oauth2.ErrorInvalidRequest)
}
}

Expand All @@ -603,27 +603,27 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
// The error can be user.ErrorNotFound, but we are not deleting
// user at this moment, so this shouldn't happen.
log.Errorf("Failed to fetch user %q from repo: %v: ", userID, err)
return nil, oauth2.NewError(oauth2.ErrorServerError)
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
}

var groups []string
if rtScopes.HasScope(scope.ScopeGroups) {
conn, ok := s.connector(connectorID)
if !ok {
log.Errorf("refresh token contained invalid connector ID (%s)", connectorID)
return nil, oauth2.NewError(oauth2.ErrorServerError)
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
}

grouper, ok := conn.(connector.GroupsConnector)
if !ok {
log.Errorf("refresh token requested groups for connector (%s) that doesn't support groups", connectorID)
return nil, oauth2.NewError(oauth2.ErrorServerError)
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
}

remoteIdentities, err := s.UserRepo.GetRemoteIdentities(nil, userID)
if err != nil {
log.Errorf("failed to get remote identities: %v", err)
return nil, oauth2.NewError(oauth2.ErrorServerError)
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
}
remoteIdentity, ok := func() (user.RemoteIdentity, bool) {
for _, ri := range remoteIdentities {
Expand All @@ -635,18 +635,18 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
}()
if !ok {
log.Errorf("failed to get remote identity for connector %s", connectorID)
return nil, oauth2.NewError(oauth2.ErrorServerError)
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
}
if groups, err = grouper.Groups(remoteIdentity.ID); err != nil {
log.Errorf("failed to get groups for refresh token: %v", connectorID)
return nil, oauth2.NewError(oauth2.ErrorServerError)
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
}
}

signer, err := s.KeyManager.Signer()
if err != nil {
log.Errorf("Failed to refresh ID token: %v", err)
return nil, oauth2.NewError(oauth2.ErrorServerError)
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
}

now := time.Now()
Expand All @@ -666,12 +666,18 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
jwt, err := jose.NewSignedJWT(claims, signer)
if err != nil {
log.Errorf("Failed to generate ID token: %v", err)
return nil, oauth2.NewError(oauth2.ErrorServerError)
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
}

refreshToken, err := s.RefreshTokenRepo.RenewRefreshToken(creds.ID, userID, token)
if err != nil {
log.Errorf("Failed to generate new refresh token: %v", err)
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
}

log.Infof("New token sent: clientID=%s", creds.ID)

return jwt, nil
return jwt, refreshToken, nil
}

func (s *Server) CrossClientAuthAllowed(requestingClientID, authorizingClientID string) (bool, error) {
Expand Down
Loading

0 comments on commit c91b37a

Please sign in to comment.