Skip to content

Commit

Permalink
fix: support allowed_cors_origins with client_secret_post (#3457)
Browse files Browse the repository at this point in the history
Closes #3456

Co-authored-by: hackerman <3372410+aeneasr@users.noreply.github.com>
  • Loading branch information
apexskier and aeneasr authored Mar 3, 2023
1 parent 97ac03a commit ffe4943
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 5 deletions.
18 changes: 14 additions & 4 deletions x/oauth2cors/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,18 @@ func Middleware(
return true
}

username, _, ok := r.BasicAuth()
if !ok || username == "" {
var clientID string

// if the client uses client_secret_post auth it will provide its client ID in form data
clientID = r.PostFormValue("client_id")

// if the client uses client_secret_basic auth the client ID will be the username component
if clientID == "" {
clientID, _, _ = r.BasicAuth()
}

// otherwise, this may be a bearer auth request, in which case we can introspect the token
if clientID == "" {
token := fosite.AccessTokenFromRequest(r)
if token == "" {
return false
Expand All @@ -94,10 +104,10 @@ func Middleware(
return false
}

username = ar.GetClient().GetID()
clientID = ar.GetClient().GetID()
}

cl, err := reg.ClientManager().GetConcreteClient(ctx, username)
cl, err := reg.ClientManager().GetConcreteClient(ctx, clientID)
if err != nil {
return false
}
Expand Down
36 changes: 35 additions & 1 deletion x/oauth2cors/cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
package oauth2cors_test

import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"

Expand Down Expand Up @@ -37,6 +40,7 @@ func TestOAuth2AwareCORSMiddleware(t *testing.T) {
header http.Header
expectHeader http.Header
method string
body io.Reader
}{
{
d: "should ignore when disabled",
Expand All @@ -55,6 +59,36 @@ func TestOAuth2AwareCORSMiddleware(t *testing.T) {
header: http.Header{"Origin": {"http://foobar.com"}, "Authorization": {fmt.Sprintf("Basic %s", x.BasicAuth("foo", "bar"))}},
expectHeader: http.Header{"Vary": {"Origin"}},
},
{
d: "should reject when post auth client exists but origin not allowed",
prep: func(t *testing.T, r driver.Registry) {
r.Config().MustSet(context.Background(), "serve.public.cors.enabled", true)
r.Config().MustSet(context.Background(), "serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"})

// Ignore unique violations
_ = r.ClientManager().CreateClient(context.Background(), &client.Client{LegacyClientID: "foo-2", Secret: "bar", AllowedCORSOrigins: []string{"http://not-foobar.com"}})
},
code: http.StatusNotImplemented,
header: http.Header{"Origin": {"http://foobar.com"}, "Content-Type": {"application/x-www-form-urlencoded"}},
expectHeader: http.Header{"Vary": {"Origin"}},
method: http.MethodPost,
body: bytes.NewBufferString(url.Values{"client_id": []string{"foo-2"}}.Encode()),
},
{
d: "should accept when post auth client exists and origin allowed",
prep: func(t *testing.T, r driver.Registry) {
r.Config().MustSet(context.Background(), "serve.public.cors.enabled", true)
r.Config().MustSet(context.Background(), "serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"})

// Ignore unique violations
_ = r.ClientManager().CreateClient(context.Background(), &client.Client{LegacyClientID: "foo-3", Secret: "bar", AllowedCORSOrigins: []string{"http://foobar.com"}})
},
code: http.StatusNotImplemented,
header: http.Header{"Origin": {"http://foobar.com"}, "Content-Type": {"application/x-www-form-urlencoded"}},
expectHeader: http.Header{"Access-Control-Allow-Credentials": []string{"true"}, "Access-Control-Allow-Origin": []string{"http://foobar.com"}, "Access-Control-Expose-Headers": []string{"Cache-Control, Expires, Last-Modified, Pragma, Content-Length, Content-Language, Content-Type"}, "Vary": []string{"Origin"}},
method: http.MethodPost,
body: bytes.NewBufferString(url.Values{"client_id": {"foo-3"}}.Encode()),
},
{
d: "should reject when basic auth client exists but origin not allowed",
prep: func(t *testing.T, r driver.Registry) {
Expand Down Expand Up @@ -237,7 +271,7 @@ func TestOAuth2AwareCORSMiddleware(t *testing.T) {
if tc.method != "" {
method = tc.method
}
req, err := http.NewRequest(method, "http://foobar.com/", nil)
req, err := http.NewRequest(method, "http://foobar.com/", tc.body)
require.NoError(t, err)
for k := range tc.header {
req.Header.Set(k, tc.header.Get(k))
Expand Down

0 comments on commit ffe4943

Please sign in to comment.