Skip to content

Commit

Permalink
feat: add ResponseModeHandler to support custom response modes (#592)
Browse files Browse the repository at this point in the history
Closes #591
  • Loading branch information
narg95 authored Jun 11, 2021
1 parent 7644a74 commit 10ec003
Show file tree
Hide file tree
Showing 9 changed files with 159 additions and 32 deletions.
5 changes: 5 additions & 0 deletions authorize_error.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ func (f *Fosite) WriteAuthorizeError(rw http.ResponseWriter, ar AuthorizeRequest
rw.Header().Set("Cache-Control", "no-store")
rw.Header().Set("Pragma", "no-cache")

if f.ResponseModeHandler().ResponseModes().Has(ar.GetResponseMode()) {
f.ResponseModeHandler().WriteAuthorizeError(rw, ar, err)
return
}

rfcerr := ErrorToRFC6749Error(err).WithLegacyFormat(f.UseLegacyErrorFormat).WithExposeDebug(f.SendDebugMessagesToClients)
if !ar.IsRedirectURIValid() {
rw.Header().Set("Content-Type", "application/json;charset=UTF-8")
Expand Down
3 changes: 2 additions & 1 deletion authorize_error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ func TestWriteAuthorizeError(t *testing.T) {
err: ErrInvalidGrant,
mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester) {
req.EXPECT().IsRedirectURIValid().Return(false)
req.EXPECT().GetResponseMode().Return(ResponseModeDefault)
rw.EXPECT().Header().Times(3).Return(header)
rw.EXPECT().WriteHeader(http.StatusBadRequest)
rw.EXPECT().Write(gomock.Any())
Expand Down Expand Up @@ -427,7 +428,7 @@ func TestWriteAuthorizeError(t *testing.T) {
req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1]))
req.EXPECT().GetState().Return("foostate")
req.EXPECT().GetResponseTypes().AnyTimes().Return(Arguments([]string{"token"}))
req.EXPECT().GetResponseMode().Return(ResponseModeFormPost).Times(1)
req.EXPECT().GetResponseMode().Return(ResponseModeFormPost).Times(2)
rw.EXPECT().Header().Times(3).Return(header)
rw.EXPECT().Write(gomock.Any()).AnyTimes()
},
Expand Down
5 changes: 5 additions & 0 deletions authorize_request_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,11 @@ func (f *Fosite) ParseResponseMode(r *http.Request, request *AuthorizeRequest) e
case string(ResponseModeFormPost):
request.ResponseMode = ResponseModeFormPost
default:
rm := ResponseModeType(responseMode)
if f.ResponseModeHandler().ResponseModes().Has(rm) {
request.ResponseMode = ResponseModeType(rm)
break
}
return errorsx.WithStack(ErrUnsupportedResponseMode.WithHintf("Request with unsupported response_mode \"%s\".", responseMode))
}

Expand Down
7 changes: 6 additions & 1 deletion authorize_write.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (f *Fosite) WriteAuthorizeResponse(rw http.ResponseWriter, ar AuthorizeRequ
wh.Set("Pragma", "no-cache")

redir := ar.GetRedirectURI()
switch ar.GetResponseMode() {
switch rm := ar.GetResponseMode(); rm {
case ResponseModeFormPost:
//form_post
rw.Header().Add("Content-Type", "text/html;charset=UTF-8")
Expand All @@ -60,6 +60,11 @@ func (f *Fosite) WriteAuthorizeResponse(rw http.ResponseWriter, ar AuthorizeRequ
URLSetFragment(redir, resp.GetParameters())
sendRedirect(redir.String(), rw)
return
default:
if f.ResponseModeHandler().ResponseModes().Has(rm) {
f.ResponseModeHandler().WriteAuthorizeResponse(rw, ar, resp)
return
}
}
}

Expand Down
1 change: 1 addition & 0 deletions compose/compose.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ func Compose(config *Config, storage interface{}, strategy interface{}, hasher f
MinParameterEntropy: config.GetMinParameterEntropy(),
UseLegacyErrorFormat: config.UseLegacyErrorFormat,
ClientAuthenticationStrategy: config.GetClientAuthenticationStrategy(),
ResponseModeHandlerExtension: config.ResponseModeHandlerExtension,
}

for _, factory := range factories {
Expand Down
3 changes: 3 additions & 0 deletions compose/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ type Config struct {

// ClientAuthenticationStrategy indicates the Strategy to authenticate client requests
ClientAuthenticationStrategy fosite.ClientAuthenticationStrategy

// ResponseModeHandlerExtension provides a handler for custom response modes
ResponseModeHandlerExtension fosite.ResponseModeHandler
}

// GetScopeStrategy returns the scope strategy to be used. Defaults to glob scope strategy.
Expand Down
11 changes: 11 additions & 0 deletions fosite.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ type Fosite struct {

// ClientAuthenticationStrategy provides an extension point to plug a strategy to authenticate clients
ClientAuthenticationStrategy ClientAuthenticationStrategy

ResponseModeHandlerExtension ResponseModeHandler
}

const MinParameterEntropy = 8
Expand All @@ -125,3 +127,12 @@ func (f *Fosite) GetMinParameterEntropy() int {
return f.MinParameterEntropy
}
}

var defaultResponseModeHandler = &DefaultResponseModeHandler{}

func (f *Fosite) ResponseModeHandler() ResponseModeHandler {
if f.ResponseModeHandlerExtension == nil {
return defaultResponseModeHandler
}
return f.ResponseModeHandlerExtension
}
109 changes: 79 additions & 30 deletions integration/authorize_form_post_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ package integration_test
import (
"fmt"
"net/http"
"net/url"
"strings"
"testing"

Expand All @@ -40,6 +41,15 @@ import (
"github.com/ory/fosite/compose"
)

type formPostTestCase struct {
description string
setup func()
check checkFunc
responseType string
}

type checkFunc func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, cparam url.Values, err map[string]string)

func TestAuthorizeFormPostResponseMode(t *testing.T) {
session := &defaultSession{
DefaultSession: &openid.DefaultSession{
Expand All @@ -49,7 +59,8 @@ func TestAuthorizeFormPostResponseMode(t *testing.T) {
Headers: &jwt.Headers{},
},
}
f := compose.ComposeAllEnabled(new(compose.Config), fositeStore, []byte("some-secret-thats-random-some-secret-thats-random-"), internal.MustRSAKey())
config := &compose.Config{ResponseModeHandlerExtension: &decoratedFormPostResponse{}}
f := compose.ComposeAllEnabled(config, fositeStore, []byte("some-secret-thats-random-some-secret-thats-random-"), internal.MustRSAKey())
ts := mockServer(t, f, session)
defer ts.Close()

Expand All @@ -58,26 +69,21 @@ func TestAuthorizeFormPostResponseMode(t *testing.T) {
defaultClient.RedirectURIs[0] = ts.URL + "/callback"
responseModeClient := &fosite.DefaultResponseModeClient{
DefaultClient: defaultClient,
ResponseModes: []fosite.ResponseModeType{fosite.ResponseModeFormPost},
ResponseModes: []fosite.ResponseModeType{fosite.ResponseModeFormPost, fosite.ResponseModeFormPost, "decorated_form_post"},
}
fositeStore.Clients["response-mode-client"] = responseModeClient
oauthClient.ClientID = "response-mode-client"

var state string
for k, c := range []struct {
description string
setup func()
check func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string)
responseType string
}{
for k, c := range []formPostTestCase{
{
description: "implicit grant #1 test with form_post",
responseType: "id_token%20token",
setup: func() {
state = "12345678901234567890"
oauthClient.Scopes = []string{"openid"}
},
check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) {
check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, cparam url.Values, err map[string]string) {
assert.EqualValues(t, state, stateFromServer)
assert.NotEmpty(t, token.TokenType)
assert.NotEmpty(t, token.AccessToken)
Expand All @@ -92,7 +98,7 @@ func TestAuthorizeFormPostResponseMode(t *testing.T) {
state = "12345678901234567890"
oauthClient.Scopes = []string{"openid"}
},
check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) {
check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, cparam url.Values, err map[string]string) {
assert.EqualValues(t, state, stateFromServer)
assert.NotEmpty(t, iDToken)
},
Expand All @@ -103,7 +109,7 @@ func TestAuthorizeFormPostResponseMode(t *testing.T) {
setup: func() {
state = "12345678901234567890"
},
check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) {
check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, cparam url.Values, err map[string]string) {
assert.EqualValues(t, state, stateFromServer)
assert.NotEmpty(t, code)
},
Expand All @@ -115,7 +121,7 @@ func TestAuthorizeFormPostResponseMode(t *testing.T) {
state = "12345678901234567890"
oauthClient.Scopes = []string{"openid"}
},
check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) {
check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, cparam url.Values, err map[string]string) {
assert.EqualValues(t, state, stateFromServer)
assert.NotEmpty(t, code)
assert.NotEmpty(t, token.TokenType)
Expand All @@ -130,7 +136,7 @@ func TestAuthorizeFormPostResponseMode(t *testing.T) {
state = "12345678901234567890"
oauthClient.Scopes = []string{"openid"}
},
check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) {
check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, cparam url.Values, err map[string]string) {
assert.EqualValues(t, state, stateFromServer)
assert.NotEmpty(t, code)
assert.NotEmpty(t, iDToken)
Expand All @@ -146,7 +152,7 @@ func TestAuthorizeFormPostResponseMode(t *testing.T) {
state = "12345678901234567890"
oauthClient.Scopes = []string{"openid"}
},
check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) {
check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, cparam url.Values, err map[string]string) {
assert.EqualValues(t, state, stateFromServer)
assert.NotEmpty(t, code)
assert.NotEmpty(t, iDToken)
Expand All @@ -158,27 +164,70 @@ func TestAuthorizeFormPostResponseMode(t *testing.T) {
setup: func() {
state = "12345678901234567890"
},
check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) {
check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, cparam url.Values, err map[string]string) {
assert.EqualValues(t, state, stateFromServer)
assert.NotEmpty(t, err["ErrorField"])
assert.NotEmpty(t, err["DescriptionField"])
},
},
} {
t.Run(fmt.Sprintf("case=%d/description=%s", k, c.description), func(t *testing.T) {
c.setup()
authURL := strings.Replace(oauthClient.AuthCodeURL(state, goauth.SetAuthURLParam("response_mode", "form_post"), goauth.SetAuthURLParam("nonce", "111111111")), "response_type=code", "response_type="+c.responseType, -1)
client := &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return errors.New("Dont follow redirects")
},
}
resp, err := client.Get(authURL)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
code, state, token, iDToken, _, errResp, err := internal.ParseFormPostResponse(fositeStore.Clients["response-mode-client"].GetRedirectURIs()[0], resp.Body)
require.NoError(t, err)
c.check(t, state, code, iDToken, token, errResp)
})
// Test canonical form_post
t.Run(fmt.Sprintf("case=%d/description=%s", k, c.description), testFormPost(&state, false, c, oauthClient, "form_post"))

// Test decorated form_post response
c.check = decorateCheck(c.check)
t.Run(fmt.Sprintf("case=%d/description=decorated_%s", k, c.description), testFormPost(&state, true, c, oauthClient, "decorated_form_post"))
}
}

func testFormPost(state *string, customResponse bool, c formPostTestCase, oauthClient *goauth.Config, responseMode string) func(t *testing.T) {
return func(t *testing.T) {
c.setup()
authURL := strings.Replace(oauthClient.AuthCodeURL(*state, goauth.SetAuthURLParam("response_mode", responseMode), goauth.SetAuthURLParam("nonce", "111111111")), "response_type=code", "response_type="+c.responseType, -1)
client := &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return errors.New("Dont follow redirects")
},
}
resp, err := client.Get(authURL)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
code, state, token, iDToken, cparam, errResp, err := internal.ParseFormPostResponse(fositeStore.Clients["response-mode-client"].GetRedirectURIs()[0], resp.Body)
require.NoError(t, err)
c.check(t, state, code, iDToken, token, cparam, errResp)
}
}

func decorateCheck(cf checkFunc) checkFunc {
return func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, cparam url.Values, err map[string]string) {
cf(t, stateFromServer, code, token, iDToken, cparam, err)
if len(err) > 0 {
assert.Contains(t, cparam, "custom_err_param")
return
}
assert.Contains(t, cparam, "custom_param")
}
}

// This test type provides an example implementation
// of a custom response mode handler.
// In this case it decorates the `form_post` response mode
// with some additional custom parameters
type decoratedFormPostResponse struct {
}

func (m *decoratedFormPostResponse) ResponseModes() fosite.ResponseModeTypes {
return fosite.ResponseModeTypes{"decorated_form_post"}
}
func (m *decoratedFormPostResponse) WriteAuthorizeResponse(rw http.ResponseWriter, ar fosite.AuthorizeRequester, resp fosite.AuthorizeResponder) {
rw.Header().Add("Content-Type", "text/html;charset=UTF-8")
resp.AddParameter("custom_param", "foo")
fosite.WriteAuthorizeFormPostResponse(ar.GetRedirectURI().String(), resp.GetParameters(), fosite.GetPostFormHTMLTemplate(fosite.Fosite{}), rw)
}
func (m *decoratedFormPostResponse) WriteAuthorizeError(rw http.ResponseWriter, ar fosite.AuthorizeRequester, err error) {
rfcerr := fosite.ErrorToRFC6749Error(err)
errors := rfcerr.ToValues()
errors.Set("state", ar.GetState())
errors.Add("custom_err_param", "bar")
fosite.WriteAuthorizeFormPostResponse(ar.GetRedirectURI().String(), errors, fosite.GetPostFormHTMLTemplate(fosite.Fosite{}), rw)
}
47 changes: 47 additions & 0 deletions response_handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package fosite

import "net/http"

// ResponseModeHandler provides a contract for handling custom response modes
type ResponseModeHandler interface {
// ResponseModes returns a set of supported response modes handled
// by the interface implementation.
//
// In an authorize request with any of the provide response modes
// methods `WriteAuthorizeResponse` and `WriteAuthorizeError` will be
// invoked to write the successful or error authorization responses respectively.
ResponseModes() ResponseModeTypes

// WriteAuthorizeResponse writes successful responses
//
// Following headers are expected to be set by default:
// header.Set("Cache-Control", "no-store")
// header.Set("Pragma", "no-cache")
WriteAuthorizeResponse(rw http.ResponseWriter, ar AuthorizeRequester, resp AuthorizeResponder)

// WriteAuthorizeError writes error responses
//
// Following headers are expected to be set by default:
// header.Set("Cache-Control", "no-store")
// header.Set("Pragma", "no-cache")
WriteAuthorizeError(rw http.ResponseWriter, ar AuthorizeRequester, err error)
}

type ResponseModeTypes []ResponseModeType

func (rs ResponseModeTypes) Has(item ResponseModeType) bool {
for _, r := range rs {
if r == item {
return true
}
}
return false
}

type DefaultResponseModeHandler struct{}

func (d *DefaultResponseModeHandler) ResponseModes() ResponseModeTypes { return nil }
func (d *DefaultResponseModeHandler) WriteAuthorizeResponse(rw http.ResponseWriter, ar AuthorizeRequester, resp AuthorizeResponder) {
}
func (d *DefaultResponseModeHandler) WriteAuthorizeError(rw http.ResponseWriter, ar AuthorizeRequester, err error) {
}

0 comments on commit 10ec003

Please sign in to comment.