Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: revoke refresh token on request only #766

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 32 additions & 3 deletions handler/oauth2/revocation.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,47 @@ type TokenRevocationHandler struct {
TokenRevocationStorage TokenRevocationStorage
RefreshTokenStrategy RefreshTokenStrategy
AccessTokenStrategy AccessTokenStrategy

// RevokeRefreshTokenOnRequestOnly is used to indicate if the refresh token should be revoked only if
// token passed to the request is a refresh token. The default behavior revokes both the access and refresh
// tokens if the token passed to the request is either.
//
// [RFC7009 - Section 2.1] Depending on the authorization server's revocation policy, the
// revocation of a particular token may cause the revocation of related
// tokens and the underlying authorization grant. If the particular
// token is a refresh token and the authorization server supports the
// revocation of access tokens, then the authorization server SHOULD
// also invalidate all access tokens based on the same authorization
// grant (see Implementation Note). If the token passed to the request
// is an access token, the server MAY revoke the respective refresh
// token as well.
RevokeRefreshTokenOnRequestOnly bool
}

// RevokeToken implements https://tools.ietf.org/html/rfc7009#section-2.1
// The token type hint indicates which token type check should be performed first.
func (r *TokenRevocationHandler) RevokeToken(ctx context.Context, token string, tokenType fosite.TokenType, client fosite.Client) error {
actualTokenType := tokenType
discoveryFuncs := []func() (request fosite.Requester, err error){
func() (request fosite.Requester, err error) {
// Refresh token
signature := r.RefreshTokenStrategy.RefreshTokenSignature(ctx, token)
return r.TokenRevocationStorage.GetRefreshTokenSession(ctx, signature, nil)
ar, err := r.TokenRevocationStorage.GetRefreshTokenSession(ctx, signature, nil)
if err == nil {
actualTokenType = fosite.RefreshToken
}

return ar, err
},
func() (request fosite.Requester, err error) {
// Access token
signature := r.AccessTokenStrategy.AccessTokenSignature(ctx, token)
return r.TokenRevocationStorage.GetAccessTokenSession(ctx, signature, nil)
ar, err := r.TokenRevocationStorage.GetAccessTokenSession(ctx, signature, nil)
if err == nil {
actualTokenType = fosite.AccessToken
}

return ar, err
},
}

Expand All @@ -55,7 +81,10 @@ func (r *TokenRevocationHandler) RevokeToken(ctx context.Context, token string,
}

requestID := ar.GetID()
err1 = r.TokenRevocationStorage.RevokeRefreshToken(ctx, requestID)
if !r.RevokeRefreshTokenOnRequestOnly || actualTokenType == fosite.RefreshToken {
err1 = r.TokenRevocationStorage.RevokeRefreshToken(ctx, requestID)
}

err2 = r.TokenRevocationStorage.RevokeAccessToken(ctx, requestID)

return storeErrorsToRevocationError(err1, err2)
Expand Down
230 changes: 230 additions & 0 deletions handler/oauth2/revocation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,3 +246,233 @@ func TestRevokeToken(t *testing.T) {
})
}
}

func TestRevokeTokenWithRefreshTokenOnRequestOnly(t *testing.T) {
ctrl := gomock.NewController(t)
store := internal.NewMockTokenRevocationStorage(ctrl)
atStrat := internal.NewMockAccessTokenStrategy(ctrl)
rtStrat := internal.NewMockRefreshTokenStrategy(ctrl)
ar := internal.NewMockAccessRequester(ctrl)
defer ctrl.Finish()

h := TokenRevocationHandler{
TokenRevocationStorage: store,
RefreshTokenStrategy: rtStrat,
AccessTokenStrategy: atStrat,
RevokeRefreshTokenOnRequestOnly: true,
}

var token string
var tokenType fosite.TokenType

for k, c := range []struct {
description string
mock func()
expectErr error
client fosite.Client
}{
{
description: "should fail - token was issued to another client",
expectErr: fosite.ErrUnauthorizedClient,
client: &fosite.DefaultClient{ID: "bar"},
mock: func() {
token = "foo"
tokenType = fosite.RefreshToken
rtStrat.EXPECT().RefreshTokenSignature(gomock.Any(), token)
store.EXPECT().GetRefreshTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(ar, nil)
ar.EXPECT().GetClient().Return(&fosite.DefaultClient{ID: "foo"})
},
},
{
description: "should pass - refresh token discovery first; refresh token found",
expectErr: nil,
client: &fosite.DefaultClient{ID: "bar"},
mock: func() {
token = "foo"
tokenType = fosite.RefreshToken
rtStrat.EXPECT().RefreshTokenSignature(gomock.Any(), token)
store.EXPECT().GetRefreshTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(ar, nil)
ar.EXPECT().GetID().Return("refresh token discovery first; refresh token found")
ar.EXPECT().GetClient().Return(&fosite.DefaultClient{ID: "bar"})
store.EXPECT().RevokeRefreshToken(gomock.Any(), gomock.Any())
store.EXPECT().RevokeAccessToken(gomock.Any(), gomock.Any())
},
},
{
description: "should pass - access token discovery first; access token found",
expectErr: nil,
client: &fosite.DefaultClient{ID: "bar"},
mock: func() {
token = "foo"
tokenType = fosite.AccessToken
atStrat.EXPECT().AccessTokenSignature(gomock.Any(), token)
store.EXPECT().GetAccessTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(ar, nil)
ar.EXPECT().GetID().Return("access token discovery first; access token found")
ar.EXPECT().GetClient().Return(&fosite.DefaultClient{ID: "bar"})
store.EXPECT().RevokeRefreshToken(gomock.Any(), gomock.Any())
store.EXPECT().RevokeAccessToken(gomock.Any(), gomock.Any())
},
},
{
description: "should pass - refresh token discovery first; refresh token not found",
expectErr: nil,
client: &fosite.DefaultClient{ID: "bar"},
mock: func() {
token = "foo"
tokenType = fosite.AccessToken
atStrat.EXPECT().AccessTokenSignature(gomock.Any(), token)
store.EXPECT().GetAccessTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fosite.ErrNotFound)

rtStrat.EXPECT().RefreshTokenSignature(gomock.Any(), token)
store.EXPECT().GetRefreshTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(ar, nil)
ar.EXPECT().GetID().Return("refresh token discovery first; refresh token not found")
ar.EXPECT().GetClient().Return(&fosite.DefaultClient{ID: "bar"})
store.EXPECT().RevokeAccessToken(gomock.Any(), gomock.Any())
},
},
{
description: "should pass - access token discovery first; access token not found",
expectErr: nil,
client: &fosite.DefaultClient{ID: "bar"},
mock: func() {
token = "foo"
tokenType = fosite.RefreshToken
rtStrat.EXPECT().RefreshTokenSignature(gomock.Any(), token)
store.EXPECT().GetRefreshTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fosite.ErrNotFound)

atStrat.EXPECT().AccessTokenSignature(gomock.Any(), token)
store.EXPECT().GetAccessTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(ar, nil)
ar.EXPECT().GetID().Return("access token discovery first; access token not found")
ar.EXPECT().GetClient().Return(&fosite.DefaultClient{ID: "bar"})
store.EXPECT().RevokeAccessToken(gomock.Any(), gomock.Any())
},
},
{
description: "should pass - refresh token discovery first; both tokens not found",
expectErr: nil,
client: &fosite.DefaultClient{ID: "bar"},
mock: func() {
token = "foo"
tokenType = fosite.RefreshToken
rtStrat.EXPECT().RefreshTokenSignature(gomock.Any(), token)
store.EXPECT().GetRefreshTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fosite.ErrNotFound)

atStrat.EXPECT().AccessTokenSignature(gomock.Any(), token)
store.EXPECT().GetAccessTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fosite.ErrNotFound)
},
},
{
description: "should pass - access token discovery first; both tokens not found",
expectErr: nil,
client: &fosite.DefaultClient{ID: "bar"},
mock: func() {
token = "foo"
tokenType = fosite.AccessToken
atStrat.EXPECT().AccessTokenSignature(gomock.Any(), token)
store.EXPECT().GetAccessTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fosite.ErrNotFound)

rtStrat.EXPECT().RefreshTokenSignature(gomock.Any(), token)
store.EXPECT().GetRefreshTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fosite.ErrNotFound)
},
},
{

description: "should pass - refresh token discovery first; refresh token is inactive",
expectErr: nil,
client: &fosite.DefaultClient{ID: "bar"},
mock: func() {
token = "foo"
tokenType = fosite.RefreshToken
rtStrat.EXPECT().RefreshTokenSignature(gomock.Any(), token)
store.EXPECT().GetRefreshTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fosite.ErrInactiveToken)

atStrat.EXPECT().AccessTokenSignature(gomock.Any(), token)
store.EXPECT().GetAccessTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fosite.ErrNotFound)
},
},
{
description: "should pass - access token discovery first; refresh token is inactive",
expectErr: nil,
client: &fosite.DefaultClient{ID: "bar"},
mock: func() {
token = "foo"
tokenType = fosite.AccessToken
atStrat.EXPECT().AccessTokenSignature(gomock.Any(), token)
store.EXPECT().GetAccessTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fosite.ErrNotFound)

rtStrat.EXPECT().RefreshTokenSignature(gomock.Any(), token)
store.EXPECT().GetRefreshTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fosite.ErrInactiveToken)
},
},
{
description: "should fail - store error for access token get",
expectErr: fosite.ErrTemporarilyUnavailable,
client: &fosite.DefaultClient{ID: "bar"},
mock: func() {
token = "foo"
tokenType = fosite.AccessToken
atStrat.EXPECT().AccessTokenSignature(gomock.Any(), token)
store.EXPECT().GetAccessTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("random error"))

rtStrat.EXPECT().RefreshTokenSignature(gomock.Any(), token)
store.EXPECT().GetRefreshTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fosite.ErrNotFound)
},
},
{
description: "should fail - store error for refresh token get",
expectErr: fosite.ErrTemporarilyUnavailable,
client: &fosite.DefaultClient{ID: "bar"},
mock: func() {
token = "foo"
tokenType = fosite.RefreshToken
atStrat.EXPECT().AccessTokenSignature(gomock.Any(), token)
store.EXPECT().GetAccessTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fosite.ErrNotFound)

rtStrat.EXPECT().RefreshTokenSignature(gomock.Any(), token)
store.EXPECT().GetRefreshTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("random error"))
},
},
{
description: "should fail - store error for access token revoke",
expectErr: fosite.ErrTemporarilyUnavailable,
client: &fosite.DefaultClient{ID: "bar"},
mock: func() {
token = "foo"
tokenType = fosite.AccessToken
atStrat.EXPECT().AccessTokenSignature(gomock.Any(), token)
store.EXPECT().GetAccessTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(ar, nil)

ar.EXPECT().GetID().Return("access token revoke")
ar.EXPECT().GetClient().Return(&fosite.DefaultClient{ID: "bar"})
store.EXPECT().RevokeAccessToken(gomock.Any(), gomock.Any()).Return(fmt.Errorf("random error"))
},
},
{
description: "should fail - store error for refresh token revoke",
expectErr: fosite.ErrTemporarilyUnavailable,
client: &fosite.DefaultClient{ID: "bar"},
mock: func() {
token = "foo"
tokenType = fosite.RefreshToken
rtStrat.EXPECT().RefreshTokenSignature(gomock.Any(), token)
store.EXPECT().GetRefreshTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(ar, nil)

ar.EXPECT().GetID()
ar.EXPECT().GetClient().Return(&fosite.DefaultClient{ID: "bar"})
store.EXPECT().RevokeRefreshToken(gomock.Any(), gomock.Any()).Return(fmt.Errorf("random error"))
store.EXPECT().RevokeAccessToken(gomock.Any(), gomock.Any()).Return(fosite.ErrNotFound)
},
},
} {
t.Run(fmt.Sprintf("case=%d/description=%s", k, c.description), func(tt *testing.T) {
c.mock()
err := h.RevokeToken(context.Background(), token, tokenType, c.client)

if c.expectErr != nil {
require.EqualError(tt, err, c.expectErr.Error())
} else {
require.NoError(tt, err)
}
})
}
}
Loading