Skip to content

Commit

Permalink
feat: add authentication options to hooks (#3633)
Browse files Browse the repository at this point in the history
  • Loading branch information
zepatrik authored Sep 25, 2023
1 parent 3615e3d commit 5c8e792
Show file tree
Hide file tree
Showing 11 changed files with 250 additions and 115 deletions.
53 changes: 47 additions & 6 deletions driver/config/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package config

import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
Expand Down Expand Up @@ -102,8 +103,8 @@ const (
KeyOAuth2GrantJWTIDOptional = "oauth2.grant.jwt.jti_optional"
KeyOAuth2GrantJWTIssuedDateOptional = "oauth2.grant.jwt.iat_optional"
KeyOAuth2GrantJWTMaxDuration = "oauth2.grant.jwt.max_ttl"
KeyRefreshTokenHookURL = "oauth2.refresh_token_hook" // #nosec G101
KeyTokenHookURL = "oauth2.token_hook" // #nosec G101
KeyRefreshTokenHook = "oauth2.refresh_token_hook" // #nosec G101
KeyTokenHook = "oauth2.token_hook" // #nosec G101
KeyDevelopmentMode = "dev"
)

Expand Down Expand Up @@ -467,12 +468,52 @@ func (p *DefaultProvider) AccessTokenStrategy(ctx context.Context, additionalSou
return s
}

func (p *DefaultProvider) TokenHookURL(ctx context.Context) *url.URL {
return p.getProvider(ctx).RequestURIF(KeyTokenHookURL, nil)
type (
Auth struct {
Type string `json:"type"`
Config json.RawMessage `json:"config"`
}
HookConfig struct {
URL string `json:"url"`
Auth *Auth `json:"auth"`
}
)

func (p *DefaultProvider) getHookConfig(ctx context.Context, key string) *HookConfig {
if hookURL := p.getProvider(ctx).RequestURIF(key, nil); hookURL != nil {
return &HookConfig{
URL: hookURL.String(),
}
}

var hookConfig *HookConfig
if err := p.getProvider(ctx).Unmarshal(key, &hookConfig); err != nil {
p.l.WithError(errors.WithStack(err)).
Errorf("Configuration value from key %s could not be decoded.", key)
return nil
}
if hookConfig == nil {
return nil
}

// validate URL by parsing it
u, err := url.ParseRequestURI(hookConfig.URL)
if err != nil {
p.l.WithError(errors.WithStack(err)).
Errorf("Configuration value from key %s could not be decoded.", key)
return nil
}
hookConfig.URL = u.String()

return hookConfig
}

func (p *DefaultProvider) TokenHookConfig(ctx context.Context) *HookConfig {
return p.getHookConfig(ctx, KeyTokenHook)
}

func (p *DefaultProvider) TokenRefreshHookURL(ctx context.Context) *url.URL {
return p.getProvider(ctx).RequestURIF(KeyRefreshTokenHookURL, nil)
func (p *DefaultProvider) TokenRefreshHookConfig(ctx context.Context) *HookConfig {
return p.getHookConfig(ctx, KeyRefreshTokenHook)
}

func (p *DefaultProvider) DbIgnoreUnknownTableColumns() bool {
Expand Down
34 changes: 27 additions & 7 deletions driver/config/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ package config

import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"strings"
"testing"
Expand Down Expand Up @@ -425,17 +425,37 @@ func TestCookieSecure(t *testing.T) {
assert.True(t, c.CookieSecure(ctx))
}

func TestTokenRefreshHookURL(t *testing.T) {
func TestHookConfigs(t *testing.T) {
ctx := context.Background()
l := logrusx.New("", "")
l.Logrus().SetOutput(io.Discard)
c := MustNew(context.Background(), l, configx.SkipValidation())

assert.EqualValues(t, (*url.URL)(nil), c.TokenRefreshHookURL(ctx))
c.MustSet(ctx, KeyRefreshTokenHookURL, "")
assert.EqualValues(t, (*url.URL)(nil), c.TokenRefreshHookURL(ctx))
c.MustSet(ctx, KeyRefreshTokenHookURL, "http://localhost:8080/oauth/token_refresh")
assert.EqualValues(t, "http://localhost:8080/oauth/token_refresh", c.TokenRefreshHookURL(ctx).String())
for key, getFunc := range map[string]func(context.Context) *HookConfig{
KeyRefreshTokenHook: c.TokenRefreshHookConfig,
KeyTokenHook: c.TokenHookConfig,
} {
assert.Nil(t, getFunc(ctx))
c.MustSet(ctx, key, "")
assert.Nil(t, getFunc(ctx))
c.MustSet(ctx, key, "http://localhost:8080/hook")
hc := getFunc(ctx)
require.NotNil(t, hc)
assert.EqualValues(t, "http://localhost:8080/hook", hc.URL)

c.MustSet(ctx, key, map[string]any{
"url": "http://localhost:8080/hook2",
"auth": map[string]any{
"type": "api_key",
"config": json.RawMessage(`{"in":"header","name":"my-header","value":"my-value"}`),
},
})
hc = getFunc(ctx)
require.NotNil(t, hc)
assert.EqualValues(t, "http://localhost:8080/hook2", hc.URL)
assert.EqualValues(t, "api_key", hc.Auth.Type)
assert.JSONEq(t, `{"in":"header","name":"my-header","value":"my-value"}`, string(hc.Auth.Config))
}
}

func TestJWTBearer(t *testing.T) {
Expand Down
37 changes: 0 additions & 37 deletions go.sum

Large diffs are not rendered by default.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion oauth2/oauth2_auth_code_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func BenchmarkAuthCode(b *testing.B) {
reg := internal.NewRegistrySQLFromURL(b, dsn, true, new(contextx.Default)).WithTracer(tracer)
reg.Config().MustSet(ctx, config.KeyLogLevel, "error")
reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque")
reg.Config().MustSet(ctx, config.KeyRefreshTokenHookURL, "")
reg.Config().MustSet(ctx, config.KeyRefreshTokenHook, "")
oauth2Keys, err := jwk.GenerateJWK(ctx, jose.ES256, x.OAuth2JWTKeyName, "sig")
require.NoError(b, err)
oidcKeys, err := jwk.GenerateJWK(ctx, jose.ES256, x.OpenIDConnectKeyName, "sig")
Expand Down
65 changes: 36 additions & 29 deletions oauth2/oauth2_auth_code_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) {
ctx := context.Background()
reg := internal.NewMockedRegistry(t, &contextx.Default{})
reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque")
reg.Config().MustSet(ctx, config.KeyRefreshTokenHookURL, "")
reg.Config().MustSet(ctx, config.KeyRefreshTokenHook, "")
publicTS, adminTS := testhelpers.NewOAuth2Server(ctx, t, reg)

publicClient := hydra.NewAPIClient(hydra.NewConfiguration())
Expand Down Expand Up @@ -955,6 +955,7 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) {
return func(t *testing.T) {
hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, r.Header.Get("Content-Type"), "application/json; charset=UTF-8")
assert.Equal(t, r.Header.Get("Authorization"), "Bearer secret value")

var hookReq hydraoauth2.TokenHookRequest
require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq))
Expand All @@ -981,9 +982,15 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) {
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL)
reg.Config().MustSet(ctx, config.KeyTokenHook, &config.HookConfig{
URL: hs.URL,
Auth: &config.Auth{
Type: "api_key",
Config: json.RawMessage(`{"in": "header", "name": "Authorization", "value": "Bearer secret value"}`),
},
})

defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil)
defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil)

expectAud := "https://api.ory.sh/"
c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler))
Expand Down Expand Up @@ -1030,9 +1037,9 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) {
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL)
reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL)

defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil)
defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil)

expectAud := "https://api.ory.sh/"
c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler))
Expand Down Expand Up @@ -1070,9 +1077,9 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) {
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL)
reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL)

defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil)
defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil)

expectAud := "https://api.ory.sh/"
c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler))
Expand Down Expand Up @@ -1110,9 +1117,9 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) {
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL)
reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL)

defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil)
defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil)

expectAud := "https://api.ory.sh/"
c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler))
Expand Down Expand Up @@ -1657,11 +1664,11 @@ func TestAuthCodeWithMockStrategy(t *testing.T) {
defer hs.Close()

if hookType == "legacy" {
conf.MustSet(ctx, config.KeyRefreshTokenHookURL, hs.URL)
defer conf.MustSet(ctx, config.KeyRefreshTokenHookURL, nil)
conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL)
defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil)
} else {
conf.MustSet(ctx, config.KeyTokenHookURL, hs.URL)
defer conf.MustSet(ctx, config.KeyTokenHookURL, nil)
conf.MustSet(ctx, config.KeyTokenHook, hs.URL)
defer conf.MustSet(ctx, config.KeyTokenHook, nil)
}

res, err := testRefresh(t, &refreshedToken, ts.URL, false)
Expand Down Expand Up @@ -1699,11 +1706,11 @@ func TestAuthCodeWithMockStrategy(t *testing.T) {
defer hs.Close()

if hookType == "legacy" {
conf.MustSet(ctx, config.KeyRefreshTokenHookURL, hs.URL)
defer conf.MustSet(ctx, config.KeyRefreshTokenHookURL, nil)
conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL)
defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil)
} else {
conf.MustSet(ctx, config.KeyTokenHookURL, hs.URL)
defer conf.MustSet(ctx, config.KeyTokenHookURL, nil)
conf.MustSet(ctx, config.KeyTokenHook, hs.URL)
defer conf.MustSet(ctx, config.KeyTokenHook, nil)
}

origAccessTokenClaims := testhelpers.IntrospectToken(t, oauthConfig, refreshedToken.AccessToken, ts)
Expand Down Expand Up @@ -1734,11 +1741,11 @@ func TestAuthCodeWithMockStrategy(t *testing.T) {
defer hs.Close()

if hookType == "legacy" {
conf.MustSet(ctx, config.KeyRefreshTokenHookURL, hs.URL)
defer conf.MustSet(ctx, config.KeyRefreshTokenHookURL, nil)
conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL)
defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil)
} else {
conf.MustSet(ctx, config.KeyTokenHookURL, hs.URL)
defer conf.MustSet(ctx, config.KeyTokenHookURL, nil)
conf.MustSet(ctx, config.KeyTokenHook, hs.URL)
defer conf.MustSet(ctx, config.KeyTokenHook, nil)
}

res, err := testRefresh(t, &refreshedToken, ts.URL, false)
Expand All @@ -1764,11 +1771,11 @@ func TestAuthCodeWithMockStrategy(t *testing.T) {
defer hs.Close()

if hookType == "legacy" {
conf.MustSet(ctx, config.KeyRefreshTokenHookURL, hs.URL)
defer conf.MustSet(ctx, config.KeyRefreshTokenHookURL, nil)
conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL)
defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil)
} else {
conf.MustSet(ctx, config.KeyTokenHookURL, hs.URL)
defer conf.MustSet(ctx, config.KeyTokenHookURL, nil)
conf.MustSet(ctx, config.KeyTokenHook, hs.URL)
defer conf.MustSet(ctx, config.KeyTokenHook, nil)
}

res, err := testRefresh(t, &refreshedToken, ts.URL, false)
Expand All @@ -1794,11 +1801,11 @@ func TestAuthCodeWithMockStrategy(t *testing.T) {
defer hs.Close()

if hookType == "legacy" {
conf.MustSet(ctx, config.KeyRefreshTokenHookURL, hs.URL)
defer conf.MustSet(ctx, config.KeyRefreshTokenHookURL, nil)
conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL)
defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil)
} else {
conf.MustSet(ctx, config.KeyTokenHookURL, hs.URL)
defer conf.MustSet(ctx, config.KeyTokenHookURL, nil)
conf.MustSet(ctx, config.KeyTokenHook, hs.URL)
defer conf.MustSet(ctx, config.KeyTokenHook, nil)
}

res, err := testRefresh(t, &refreshedToken, ts.URL, false)
Expand Down
23 changes: 15 additions & 8 deletions oauth2/oauth2_client_credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ func TestClientCredentials(t *testing.T) {

hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, r.Header.Get("Content-Type"), "application/json; charset=UTF-8")
assert.Equal(t, r.Header.Get("Authorization"), "Bearer secret value")

expectedGrantedScopes := []string{"foobar"}
expectedGrantedAudience := []string{"https://api.ory.sh/"}
Expand Down Expand Up @@ -286,9 +287,15 @@ func TestClientCredentials(t *testing.T) {
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL)
reg.Config().MustSet(ctx, config.KeyTokenHook, &config.HookConfig{
URL: hs.URL,
Auth: &config.Auth{
Type: "api_key",
Config: json.RawMessage(`{"in": "header", "name": "Authorization", "value": "Bearer secret value"}`),
},
})

defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil)
defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil)

secret := uuid.New().String()
cl, conf := newCustomClient(t, &hc.Client{
Expand Down Expand Up @@ -316,9 +323,9 @@ func TestClientCredentials(t *testing.T) {
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL)
reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL)

defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil)
defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil)

_, conf := newClient(t)

Expand All @@ -340,9 +347,9 @@ func TestClientCredentials(t *testing.T) {
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL)
reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL)

defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil)
defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil)

_, conf := newClient(t)

Expand All @@ -364,9 +371,9 @@ func TestClientCredentials(t *testing.T) {
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL)
reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL)

defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil)
defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil)

_, conf := newClient(t)

Expand Down
Loading

0 comments on commit 5c8e792

Please sign in to comment.