Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Token Reuse Detection #567

Merged
merged 2 commits into from
Mar 8, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 51 additions & 9 deletions handler/oauth2/flow_refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,14 @@ func (c *RefreshTokenGrantHandler) HandleTokenEndpointRequest(ctx context.Contex
refresh := request.GetRequestForm().Get("refresh_token")
signature := c.RefreshTokenStrategy.RefreshTokenSignature(refresh)
originalRequest, err := c.TokenRevocationStorage.GetRefreshTokenSession(ctx, signature, request.GetSession())
if errors.Is(err, fosite.ErrNotFound) {
if errors.Is(err, fosite.ErrInactiveToken) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand correctly, the idea is that the storage would return a new error code ErrInactiveToken when a refresh token was used or revoked - am I correct with that assumption?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, exactly, same way it does for auth codes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice makes sense!

// Detected refresh token reuse
if rErr := c.handleRefreshTokenReuse(ctx, signature, originalRequest); rErr != nil {
return errorsx.WithStack(fosite.ErrServerError.WithWrap(rErr).WithDebug(rErr.Error()))
}

return errorsx.WithStack(fosite.ErrInactiveToken.WithWrap(err).WithDebug(err.Error()))
} else if errors.Is(err, fosite.ErrNotFound) {
return errorsx.WithStack(fosite.ErrInvalidGrant.WithWrap(err).WithDebugf("The refresh token has not been found: %s", err.Error()))
} else if err != nil {
return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
Expand Down Expand Up @@ -138,22 +145,22 @@ func (c *RefreshTokenGrantHandler) PopulateTokenEndpointResponse(ctx context.Con

ts, err := c.TokenRevocationStorage.GetRefreshTokenSession(ctx, signature, nil)
if err != nil {
return handleRefreshTokenEndpointResponseStorageError(ctx, true, c.TokenRevocationStorage, err)
return c.handleRefreshTokenEndpointStorageError(ctx, true, err)
} else if err := c.TokenRevocationStorage.RevokeAccessToken(ctx, ts.GetID()); err != nil {
return handleRefreshTokenEndpointResponseStorageError(ctx, true, c.TokenRevocationStorage, err)
return c.handleRefreshTokenEndpointStorageError(ctx, true, err)
} else if err := c.TokenRevocationStorage.RevokeRefreshToken(ctx, ts.GetID()); err != nil {
return handleRefreshTokenEndpointResponseStorageError(ctx, true, c.TokenRevocationStorage, err)
return c.handleRefreshTokenEndpointStorageError(ctx, true, err)
}

storeReq := requester.Sanitize([]string{})
storeReq.SetID(ts.GetID())

if err := c.TokenRevocationStorage.CreateAccessTokenSession(ctx, accessSignature, storeReq); err != nil {
return handleRefreshTokenEndpointResponseStorageError(ctx, true, c.TokenRevocationStorage, err)
return c.handleRefreshTokenEndpointStorageError(ctx, true, err)
}

if err := c.TokenRevocationStorage.CreateRefreshTokenSession(ctx, refreshSignature, storeReq); err != nil {
return handleRefreshTokenEndpointResponseStorageError(ctx, true, c.TokenRevocationStorage, err)
return c.handleRefreshTokenEndpointStorageError(ctx, true, err)
}

responder.SetAccessToken(accessToken)
Expand All @@ -163,16 +170,51 @@ func (c *RefreshTokenGrantHandler) PopulateTokenEndpointResponse(ctx context.Con
responder.SetExtra("refresh_token", refreshToken)

if err := storage.MaybeCommitTx(ctx, c.TokenRevocationStorage); err != nil {
return handleRefreshTokenEndpointResponseStorageError(ctx, false, c.TokenRevocationStorage, err)
return c.handleRefreshTokenEndpointStorageError(ctx, false, err)
}

return nil
}

// Reference: https://tools.ietf.org/html/rfc6819#section-5.2.2.3
//
// The basic idea is to change the refresh token
// value with every refresh request in order to detect attempts to
// obtain access tokens using old refresh tokens. Since the
// authorization server cannot determine whether the attacker or the
// legitimate client is trying to access, in case of such an access
// attempt the valid refresh token and the access authorization
// associated with it are both revoked.
//
func (c *RefreshTokenGrantHandler) handleRefreshTokenReuse(ctx context.Context, signature string, req fosite.Requester) error {
ctx, err := storage.MaybeBeginTx(ctx, c.TokenRevocationStorage)
if err != nil {
return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
}

if err := c.TokenRevocationStorage.DeleteRefreshTokenSession(ctx, signature); err != nil {
return c.handleRefreshTokenEndpointStorageError(ctx, true, err)
} else if err := c.TokenRevocationStorage.RevokeRefreshToken(
ctx, req.GetID(),
); err != nil && !errors.Is(err, fosite.ErrNotFound) {
return c.handleRefreshTokenEndpointStorageError(ctx, true, err)
} else if err := c.TokenRevocationStorage.RevokeAccessToken(
ctx, req.GetID(),
); err != nil && !errors.Is(err, fosite.ErrNotFound) {
return c.handleRefreshTokenEndpointStorageError(ctx, true, err)
}

if err := storage.MaybeCommitTx(ctx, c.TokenRevocationStorage); err != nil {
return c.handleRefreshTokenEndpointStorageError(ctx, false, err)
}

return nil
}

func handleRefreshTokenEndpointResponseStorageError(ctx context.Context, rollback bool, store TokenRevocationStorage, storageErr error) (err error) {
func (c *RefreshTokenGrantHandler) handleRefreshTokenEndpointStorageError(ctx context.Context, rollback bool, storageErr error) (err error) {
defer func() {
if rollback {
if rbErr := storage.MaybeRollbackTx(ctx, store); rbErr != nil {
if rbErr := storage.MaybeRollbackTx(ctx, c.TokenRevocationStorage); rbErr != nil {
err = errorsx.WithStack(fosite.ErrServerError.WithWrap(rbErr).WithDebug(rbErr.Error()))
}
}
Expand Down
115 changes: 115 additions & 0 deletions handler/oauth2/flow_refresh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,36 @@ func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) {
assert.Equal(t, time.Now().Add(time.Hour).UTC().Round(time.Second), areq.GetSession().GetExpiresAt(fosite.RefreshToken))
},
},
{
description: "should deny access on token reuse",
setup: func() {
areq.GrantTypes = fosite.Arguments{"refresh_token"}
areq.Client = &fosite.DefaultClient{
ID: "foo",
GrantTypes: fosite.Arguments{"refresh_token"},
Scopes: []string{"foo", "bar", "offline"},
}

token, sig, err := strategy.GenerateRefreshToken(nil, nil)
require.NoError(t, err)

areq.Form.Add("refresh_token", token)
req := &fosite.Request{
Client: areq.Client,
GrantedScope: fosite.Arguments{"foo", "offline"},
RequestedScope: fosite.Arguments{"foo", "bar", "offline"},
Session: sess,
Form: url.Values{"foo": []string{"bar"}},
RequestedAt: time.Now().UTC().Add(-time.Hour).Round(time.Hour),
}
err = store.CreateRefreshTokenSession(nil, sig, req)
require.NoError(t, err)

err = store.RevokeRefreshToken(nil, req.ID)
require.NoError(t, err)
},
expectErr: fosite.ErrInactiveToken,
},
} {
t.Run("case="+c.description, func(t *testing.T) {
h = RefreshTokenGrantHandler{
Expand Down Expand Up @@ -261,6 +291,91 @@ func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) {
}
}

func TestRefreshFlowTransactional_HandleTokenEndpointRequest(t *testing.T) {
var mockTransactional *internal.MockTransactional
var mockRevocationStore *internal.MockTokenRevocationStorage
request := fosite.NewAccessRequest(&fosite.DefaultSession{})
propagatedContext := context.Background()

type transactionalStore struct {
storage.Transactional
TokenRevocationStorage
}

for _, testCase := range []struct {
description string
setup func()
expectError error
}{
{
description: "should revoke session on token reuse",
setup: func() {
request.GrantTypes = fosite.Arguments{"refresh_token"}
request.Client = &fosite.DefaultClient{
ID: "foo",
GrantTypes: fosite.Arguments{"refresh_token"},
}
mockRevocationStore.
EXPECT().
GetRefreshTokenSession(propagatedContext, gomock.Any(), gomock.Any()).
Return(request, fosite.ErrInactiveToken).
Times(1)
mockTransactional.
EXPECT().
BeginTX(propagatedContext).
Return(propagatedContext, nil).
Times(1)
mockRevocationStore.
EXPECT().
DeleteRefreshTokenSession(propagatedContext, gomock.Any()).
Return(nil).
Times(1)
mockRevocationStore.
EXPECT().
RevokeRefreshToken(propagatedContext, gomock.Any()).
Return(nil).
Times(1)
mockRevocationStore.
EXPECT().
RevokeAccessToken(propagatedContext, gomock.Any()).
Return(nil).
Times(1)
mockTransactional.
EXPECT().
Commit(propagatedContext).
Return(nil).
Times(1)
},
expectError: fosite.ErrInactiveToken,
},
} {
t.Run(fmt.Sprintf("scenario=%s", testCase.description), func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockTransactional = internal.NewMockTransactional(ctrl)
mockRevocationStore = internal.NewMockTokenRevocationStorage(ctrl)
testCase.setup()

handler := RefreshTokenGrantHandler{
TokenRevocationStorage: transactionalStore{
mockTransactional,
mockRevocationStore,
},
AccessTokenStrategy: &hmacshaStrategy,
RefreshTokenStrategy: &hmacshaStrategy,
AccessTokenLifespan: time.Hour,
ScopeStrategy: fosite.HierarchicScopeStrategy,
AudienceMatchingStrategy: fosite.DefaultAudienceMatchingStrategy,
}

if err := handler.HandleTokenEndpointRequest(propagatedContext, request); testCase.expectError != nil {
assert.EqualError(t, err, testCase.expectError.Error())
}
})
}
}

func TestRefreshFlow_PopulateTokenEndpointResponse(t *testing.T) {
var areq *fosite.AccessRequest
var aresp *fosite.AccessResponse
Expand Down
2 changes: 1 addition & 1 deletion integration/helper_setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ var fositeStore = &storage.MemoryStore{
AuthorizeCodes: map[string]storage.StoreAuthorizeCode{},
PKCES: map[string]fosite.Requester{},
AccessTokens: map[string]fosite.Requester{},
RefreshTokens: map[string]fosite.Requester{},
RefreshTokens: map[string]storage.StoreRefreshToken{},
IDSessions: map[string]fosite.Requester{},
AccessTokenRequestIDs: map[string]string{},
RefreshTokenRequestIDs: map[string]string{},
Expand Down
28 changes: 17 additions & 11 deletions storage/memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ type MemoryStore struct {
AuthorizeCodes map[string]StoreAuthorizeCode
IDSessions map[string]fosite.Requester
AccessTokens map[string]fosite.Requester
RefreshTokens map[string]fosite.Requester
RefreshTokens map[string]StoreRefreshToken
PKCES map[string]fosite.Requester
Users map[string]MemoryUserRelation
BlacklistedJTIs map[string]time.Time
Expand Down Expand Up @@ -86,7 +86,7 @@ func NewMemoryStore() *MemoryStore {
AuthorizeCodes: make(map[string]StoreAuthorizeCode),
IDSessions: make(map[string]fosite.Requester),
AccessTokens: make(map[string]fosite.Requester),
RefreshTokens: make(map[string]fosite.Requester),
RefreshTokens: make(map[string]StoreRefreshToken),
PKCES: make(map[string]fosite.Requester),
Users: make(map[string]MemoryUserRelation),
AccessTokenRequestIDs: make(map[string]string),
Expand All @@ -101,6 +101,11 @@ type StoreAuthorizeCode struct {
fosite.Requester
}

type StoreRefreshToken struct {
active bool
fosite.Requester
}

func NewExampleStore() *MemoryStore {
return &MemoryStore{
IDSessions: make(map[string]fosite.Requester),
Expand Down Expand Up @@ -132,7 +137,7 @@ func NewExampleStore() *MemoryStore {
},
AuthorizeCodes: map[string]StoreAuthorizeCode{},
AccessTokens: map[string]fosite.Requester{},
RefreshTokens: map[string]fosite.Requester{},
RefreshTokens: map[string]StoreRefreshToken{},
PKCES: map[string]fosite.Requester{},
AccessTokenRequestIDs: map[string]string{},
RefreshTokenRequestIDs: map[string]string{},
Expand Down Expand Up @@ -311,7 +316,7 @@ func (s *MemoryStore) CreateRefreshTokenSession(_ context.Context, signature str
s.refreshTokensMutex.Lock()
defer s.refreshTokensMutex.Unlock()

s.RefreshTokens[signature] = req
s.RefreshTokens[signature] = StoreRefreshToken{active: true, Requester: req}
s.RefreshTokenRequestIDs[req.GetID()] = signature
return nil
}
Expand All @@ -324,6 +329,9 @@ func (s *MemoryStore) GetRefreshTokenSession(_ context.Context, signature string
if !ok {
return nil, fosite.ErrNotFound
}
if !rel.active {
return rel, fosite.ErrInactiveToken
}
return rel, nil
}

Expand Down Expand Up @@ -354,14 +362,12 @@ func (s *MemoryStore) RevokeRefreshToken(ctx context.Context, requestID string)
defer s.refreshTokenRequestIDsMutex.Unlock()

if signature, exists := s.RefreshTokenRequestIDs[requestID]; exists {
err1 := s.DeleteRefreshTokenSession(ctx, signature)
err2 := s.DeleteAccessTokenSession(ctx, signature)
if err1 != nil {
return err1
}
if err2 != nil {
return err2
rel, ok := s.RefreshTokens[signature]
if !ok {
return fosite.ErrNotFound
}
rel.active = false
s.RefreshTokens[signature] = rel
}
return nil
}
Expand Down