Skip to content

Commit

Permalink
feat: hot-reload Oauth2 CORS settings (#3537)
Browse files Browse the repository at this point in the history
  • Loading branch information
zepatrik authored Jun 13, 2023
1 parent 898aa00 commit a8ecf80
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 144 deletions.
2 changes: 2 additions & 0 deletions driver/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package driver

import (
"context"
"net/http"

"go.opentelemetry.io/otel/trace"

Expand Down Expand Up @@ -74,6 +75,7 @@ type Registry interface {
ConsentHandler() *consent.Handler
OAuth2Handler() *oauth2.Handler
HealthHandler() *healthx.Handler
OAuth2AwareMiddleware() func(h http.Handler) http.Handler

OAuth2HMACStrategy() *foauth2.HMACSHAStrategy
WithOAuth2Provider(f fosite.OAuth2Provider)
Expand Down
8 changes: 4 additions & 4 deletions driver/registry_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,9 @@ func (m *RegistryBase) WithBuildInfo(version, hash, date string) Registry {
return m.r
}

func (m *RegistryBase) OAuth2AwareMiddleware(ctx context.Context) func(h http.Handler) http.Handler {
func (m *RegistryBase) OAuth2AwareMiddleware() func(h http.Handler) http.Handler {
if m.oa2mw == nil {
m.oa2mw = oauth2cors.Middleware(ctx, m.r)
m.oa2mw = oauth2cors.Middleware(m.r)
}
return m.oa2mw
}
Expand Down Expand Up @@ -153,9 +153,9 @@ func (m *RegistryBase) RegisterRoutes(ctx context.Context, admin *httprouterx.Ro
admin.Handler("GET", prometheus.MetricsPrometheusPath, promhttp.Handler())

m.ConsentHandler().SetRoutes(admin)
m.KeyHandler().SetRoutes(admin, public, m.OAuth2AwareMiddleware(ctx))
m.KeyHandler().SetRoutes(admin, public, m.OAuth2AwareMiddleware())
m.ClientHandler().SetRoutes(admin, public)
m.OAuth2Handler().SetRoutes(admin, public, m.OAuth2AwareMiddleware(ctx))
m.OAuth2Handler().SetRoutes(admin, public, m.OAuth2AwareMiddleware())
m.JWTGrantHandler().SetRoutes(admin)
}

Expand Down
195 changes: 104 additions & 91 deletions x/oauth2cors/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
package oauth2cors

import (
"context"
"fmt"
"net/http"
"strings"

Expand All @@ -21,113 +19,128 @@ import (
)

func Middleware(
ctx context.Context,
reg interface {
x.RegistryLogger
oauth2.Registry
client.Registry
}) func(h http.Handler) http.Handler {
opts, enabled := reg.Config().CORS(ctx, config.PublicInterface)
if !enabled {
return func(h http.Handler) http.Handler {
return h
}
}

var alwaysAllow = len(opts.AllowedOrigins) == 0
var patterns []glob.Glob
for _, o := range opts.AllowedOrigins {
if o == "*" {
alwaysAllow = true
}
// if the protocol (http or https) is specified, but the url is wildcard, use special ** glob, which ignore the '.' separator.
// This way g := glob.Compile("http://**") g.Match("http://google.com") returns true.
if splittedO := strings.Split(o, "://"); len(splittedO) != 1 && splittedO[1] == "*" {
o = fmt.Sprintf("%s://**", splittedO[0])
}
g, err := glob.Compile(strings.ToLower(o), '.')
if err != nil {
reg.Logger().WithError(err).Fatalf("Unable to parse cors origin: %s", o)
}

patterns = append(patterns, g)
}

options := cors.Options{
AllowedOrigins: opts.AllowedOrigins,
AllowedMethods: opts.AllowedMethods,
AllowedHeaders: opts.AllowedHeaders,
ExposedHeaders: opts.ExposedHeaders,
MaxAge: opts.MaxAge,
AllowCredentials: opts.AllowCredentials,
OptionsPassthrough: opts.OptionsPassthrough,
Debug: opts.Debug,
AllowOriginRequestFunc: func(r *http.Request, origin string) bool {
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if alwaysAllow {
return true
}

origin = strings.ToLower(origin)
for _, p := range patterns {
if p.Match(origin) {
return true
}
opts, enabled := reg.Config().CORS(ctx, config.PublicInterface)
if !enabled {
reg.Logger().Debug("not enhancing CORS per client, as CORS is disabled")
h.ServeHTTP(w, r)
return
}

// pre-flight requests do not contain credentials (cookies, HTTP authorization)
// so we return true in all cases here.
if r.Method == http.MethodOptions {
return true
}

var clientID string

// if the client uses client_secret_post auth it will provide its client ID in form data
clientID = r.PostFormValue("client_id")

// if the client uses client_secret_basic auth the client ID will be the username component
if clientID == "" {
clientID, _, _ = r.BasicAuth()
}

// otherwise, this may be a bearer auth request, in which case we can introspect the token
if clientID == "" {
token := fosite.AccessTokenFromRequest(r)
if token == "" {
return false
alwaysAllow := len(opts.AllowedOrigins) == 0
patterns := make([]glob.Glob, 0, len(opts.AllowedOrigins))
for _, o := range opts.AllowedOrigins {
if o == "*" {
alwaysAllow = true
break
}

session := oauth2.NewSessionWithCustomClaims("", reg.Config().AllowedTopLevelClaims(ctx))
_, ar, err := reg.OAuth2Provider().IntrospectToken(ctx, token, fosite.AccessToken, session)
// if the protocol (http or https) is specified, but the url is wildcard, use special ** glob, which ignore the '.' separator.
// This way g := glob.Compile("http://**") g.Match("http://google.com") returns true.
if scheme, rest, found := strings.Cut(o, "://"); found && rest == "*" {
o = scheme + "://**"
}
g, err := glob.Compile(strings.ToLower(o), '.')
if err != nil {
return false
reg.Logger().WithError(err).WithField("pattern", o).Error("Unable to parse CORS origin, ignoring it")
continue
}

clientID = ar.GetClient().GetID()
patterns = append(patterns, g)
}

cl, err := reg.ClientManager().GetConcreteClient(ctx, clientID)
if err != nil {
return false
}
options := cors.Options{
AllowedOrigins: opts.AllowedOrigins,
AllowedMethods: opts.AllowedMethods,
AllowedHeaders: opts.AllowedHeaders,
ExposedHeaders: opts.ExposedHeaders,
MaxAge: opts.MaxAge,
AllowCredentials: opts.AllowCredentials,
OptionsPassthrough: opts.OptionsPassthrough,
Debug: opts.Debug,
AllowOriginRequestFunc: func(r *http.Request, origin string) bool {
ctx := r.Context()
if alwaysAllow {
return true
}

origin = strings.ToLower(origin)
for _, p := range patterns {
if p.Match(origin) {
return true
}
}

// pre-flight requests do not contain credentials (cookies, HTTP authorization)
// so we return true in all cases here.
if r.Method == http.MethodOptions {
return true
}

var clientID string

// if the client uses client_secret_post auth it will provide its client ID in form data
clientID = r.PostFormValue("client_id")

// if the client uses client_secret_basic auth the client ID will be the username component
if clientID == "" {
clientID, _, _ = r.BasicAuth()
}

// otherwise, this may be a bearer auth request, in which case we can introspect the token
if clientID == "" {
token := fosite.AccessTokenFromRequest(r)
if token == "" {
return false
}

session := oauth2.NewSessionWithCustomClaims("", reg.Config().AllowedTopLevelClaims(ctx))
_, ar, err := reg.OAuth2Provider().IntrospectToken(ctx, token, fosite.AccessToken, session)
if err != nil {
return false
}

clientID = ar.GetClient().GetID()
}

cl, err := reg.ClientManager().GetConcreteClient(ctx, clientID)
if err != nil {
return false
}

for _, o := range cl.AllowedCORSOrigins {
if o == "*" {
return true
}

// if the protocol (http or https) is specified, but the url is wildcard, use special ** glob, which ignore the '.' separator.
// This way g := glob.Compile("http://**") g.Match("http://google.com") returns true.
if scheme, rest, found := strings.Cut(o, "://"); found && rest == "*" {
o = scheme + "://**"
}

g, err := glob.Compile(strings.ToLower(o), '.')
if err != nil {
return false
}
if g.Match(origin) {
return true
}
}

for _, o := range cl.AllowedCORSOrigins {
if o == "*" {
return true
}
g, err := glob.Compile(strings.ToLower(o), '.')
if err != nil {
return false
}
if g.Match(origin) {
return true
}
},
}

return false
},
reg.Logger().Debug("enhancing CORS per client")
cors.New(options).Handler(h).ServeHTTP(w, r)
})
}

return cors.New(options).Handler
}
Loading

0 comments on commit a8ecf80

Please sign in to comment.