Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add authentication options to hooks #3633

Merged
merged 6 commits into from
Sep 25, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 40 additions & 6 deletions driver/config/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ const (
KeyOAuth2GrantJWTIDOptional = "oauth2.grant.jwt.jti_optional"
KeyOAuth2GrantJWTIssuedDateOptional = "oauth2.grant.jwt.iat_optional"
KeyOAuth2GrantJWTMaxDuration = "oauth2.grant.jwt.max_ttl"
KeyRefreshTokenHookURL = "oauth2.refresh_token_hook" // #nosec G101
KeyTokenHookURL = "oauth2.token_hook" // #nosec G101
KeyRefreshTokenHook = "oauth2.refresh_token_hook" // #nosec G101
KeyTokenHook = "oauth2.token_hook" // #nosec G101
KeyDevelopmentMode = "dev"
)

Expand Down Expand Up @@ -467,12 +467,46 @@ func (p *DefaultProvider) AccessTokenStrategy(ctx context.Context, additionalSou
return s
}

func (p *DefaultProvider) TokenHookURL(ctx context.Context) *url.URL {
return p.getProvider(ctx).RequestURIF(KeyTokenHookURL, nil)
type HookConfig struct {
URL string `json:"url"`
Headers map[string]string `json:"headers"`
zepatrik marked this conversation as resolved.
Show resolved Hide resolved
}

func (p *DefaultProvider) TokenRefreshHookURL(ctx context.Context) *url.URL {
return p.getProvider(ctx).RequestURIF(KeyRefreshTokenHookURL, nil)
func (p *DefaultProvider) getHookConfig(ctx context.Context, key string) *HookConfig {
if hookURL := p.getProvider(ctx).RequestURIF(key, nil); hookURL != nil {
return &HookConfig{
URL: hookURL.String(),
}
}

var hookConfig *HookConfig
if err := p.getProvider(ctx).Unmarshal(key, &hookConfig); err != nil {
p.l.WithError(errors.WithStack(err)).
Errorf("Configuration value from key %s could not be decoded.", key)
return nil
}
if hookConfig == nil {
return nil
}

// validate URL by parsing it
u, err := url.ParseRequestURI(hookConfig.URL)
if err != nil {
p.l.WithError(errors.WithStack(err)).
Errorf("Configuration value from key %s could not be decoded.", key)
return nil
}
hookConfig.URL = u.String()

return hookConfig
}

func (p *DefaultProvider) TokenHookConfig(ctx context.Context) *HookConfig {
return p.getHookConfig(ctx, KeyTokenHook)
}

func (p *DefaultProvider) TokenRefreshHookConfig(ctx context.Context) *HookConfig {
return p.getHookConfig(ctx, KeyRefreshTokenHook)
}

func (p *DefaultProvider) DbIgnoreUnknownTableColumns() bool {
Expand Down
31 changes: 24 additions & 7 deletions driver/config/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"os"
"strings"
"testing"
Expand Down Expand Up @@ -425,17 +424,35 @@ func TestCookieSecure(t *testing.T) {
assert.True(t, c.CookieSecure(ctx))
}

func TestTokenRefreshHookURL(t *testing.T) {
func TestHookConfigs(t *testing.T) {
ctx := context.Background()
l := logrusx.New("", "")
l.Logrus().SetOutput(io.Discard)
c := MustNew(context.Background(), l, configx.SkipValidation())

assert.EqualValues(t, (*url.URL)(nil), c.TokenRefreshHookURL(ctx))
c.MustSet(ctx, KeyRefreshTokenHookURL, "")
assert.EqualValues(t, (*url.URL)(nil), c.TokenRefreshHookURL(ctx))
c.MustSet(ctx, KeyRefreshTokenHookURL, "http://localhost:8080/oauth/token_refresh")
assert.EqualValues(t, "http://localhost:8080/oauth/token_refresh", c.TokenRefreshHookURL(ctx).String())
for key, getFunc := range map[string]func(context.Context) *HookConfig{
KeyRefreshTokenHook: c.TokenRefreshHookConfig,
KeyTokenHook: c.TokenHookConfig,
} {
assert.Nil(t, getFunc(ctx))
c.MustSet(ctx, key, "")
assert.Nil(t, getFunc(ctx))
c.MustSet(ctx, key, "http://localhost:8080/hook")
hc := getFunc(ctx)
require.NotNil(t, hc)
assert.EqualValues(t, "http://localhost:8080/hook", hc.URL)

c.MustSet(ctx, key, map[string]any{
"url": "http://localhost:8080/hook2",
"headers": map[string]any{
"My-Headers": "my-value",
},
})
hc = getFunc(ctx)
require.NotNil(t, hc)
assert.EqualValues(t, "http://localhost:8080/hook2", hc.URL)
assert.EqualValues(t, "my-value", hc.Headers["My-Headers"])
}
}

func TestJWTBearer(t *testing.T) {
Expand Down
37 changes: 0 additions & 37 deletions go.sum

Large diffs are not rendered by default.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion oauth2/oauth2_auth_code_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func BenchmarkAuthCode(b *testing.B) {
reg := internal.NewRegistrySQLFromURL(b, dsn, true, new(contextx.Default)).WithTracer(tracer)
reg.Config().MustSet(ctx, config.KeyLogLevel, "error")
reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque")
reg.Config().MustSet(ctx, config.KeyRefreshTokenHookURL, "")
reg.Config().MustSet(ctx, config.KeyRefreshTokenHook, "")
oauth2Keys, err := jwk.GenerateJWK(ctx, jose.ES256, x.OAuth2JWTKeyName, "sig")
require.NoError(b, err)
oidcKeys, err := jwk.GenerateJWK(ctx, jose.ES256, x.OpenIDConnectKeyName, "sig")
Expand Down
62 changes: 33 additions & 29 deletions oauth2/oauth2_auth_code_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) {
ctx := context.Background()
reg := internal.NewMockedRegistry(t, &contextx.Default{})
reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque")
reg.Config().MustSet(ctx, config.KeyRefreshTokenHookURL, "")
reg.Config().MustSet(ctx, config.KeyRefreshTokenHook, "")
publicTS, adminTS := testhelpers.NewOAuth2Server(ctx, t, reg)

publicClient := hydra.NewAPIClient(hydra.NewConfiguration())
Expand Down Expand Up @@ -955,6 +955,7 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) {
return func(t *testing.T) {
hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, r.Header.Get("Content-Type"), "application/json; charset=UTF-8")
assert.Equal(t, r.Header.Get("Bearer"), "secret value")

var hookReq hydraoauth2.TokenHookRequest
require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq))
Expand All @@ -981,9 +982,12 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) {
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL)
reg.Config().MustSet(ctx, config.KeyTokenHook, &config.HookConfig{
URL: hs.URL,
Headers: map[string]string{"Bearer": "secret value"},
})

defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil)
defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil)

expectAud := "https://api.ory.sh/"
c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler))
Expand Down Expand Up @@ -1030,9 +1034,9 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) {
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL)
reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL)

defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil)
defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil)

expectAud := "https://api.ory.sh/"
c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler))
Expand Down Expand Up @@ -1070,9 +1074,9 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) {
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL)
reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL)

defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil)
defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil)

expectAud := "https://api.ory.sh/"
c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler))
Expand Down Expand Up @@ -1110,9 +1114,9 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) {
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL)
reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL)

defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil)
defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil)

expectAud := "https://api.ory.sh/"
c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler))
Expand Down Expand Up @@ -1657,11 +1661,11 @@ func TestAuthCodeWithMockStrategy(t *testing.T) {
defer hs.Close()

if hookType == "legacy" {
conf.MustSet(ctx, config.KeyRefreshTokenHookURL, hs.URL)
defer conf.MustSet(ctx, config.KeyRefreshTokenHookURL, nil)
conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL)
defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil)
} else {
conf.MustSet(ctx, config.KeyTokenHookURL, hs.URL)
defer conf.MustSet(ctx, config.KeyTokenHookURL, nil)
conf.MustSet(ctx, config.KeyTokenHook, hs.URL)
defer conf.MustSet(ctx, config.KeyTokenHook, nil)
}

res, err := testRefresh(t, &refreshedToken, ts.URL, false)
Expand Down Expand Up @@ -1699,11 +1703,11 @@ func TestAuthCodeWithMockStrategy(t *testing.T) {
defer hs.Close()

if hookType == "legacy" {
conf.MustSet(ctx, config.KeyRefreshTokenHookURL, hs.URL)
defer conf.MustSet(ctx, config.KeyRefreshTokenHookURL, nil)
conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL)
defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil)
} else {
conf.MustSet(ctx, config.KeyTokenHookURL, hs.URL)
defer conf.MustSet(ctx, config.KeyTokenHookURL, nil)
conf.MustSet(ctx, config.KeyTokenHook, hs.URL)
defer conf.MustSet(ctx, config.KeyTokenHook, nil)
}

origAccessTokenClaims := testhelpers.IntrospectToken(t, oauthConfig, refreshedToken.AccessToken, ts)
Expand Down Expand Up @@ -1734,11 +1738,11 @@ func TestAuthCodeWithMockStrategy(t *testing.T) {
defer hs.Close()

if hookType == "legacy" {
conf.MustSet(ctx, config.KeyRefreshTokenHookURL, hs.URL)
defer conf.MustSet(ctx, config.KeyRefreshTokenHookURL, nil)
conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL)
defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil)
} else {
conf.MustSet(ctx, config.KeyTokenHookURL, hs.URL)
defer conf.MustSet(ctx, config.KeyTokenHookURL, nil)
conf.MustSet(ctx, config.KeyTokenHook, hs.URL)
defer conf.MustSet(ctx, config.KeyTokenHook, nil)
}

res, err := testRefresh(t, &refreshedToken, ts.URL, false)
Expand All @@ -1764,11 +1768,11 @@ func TestAuthCodeWithMockStrategy(t *testing.T) {
defer hs.Close()

if hookType == "legacy" {
conf.MustSet(ctx, config.KeyRefreshTokenHookURL, hs.URL)
defer conf.MustSet(ctx, config.KeyRefreshTokenHookURL, nil)
conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL)
defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil)
} else {
conf.MustSet(ctx, config.KeyTokenHookURL, hs.URL)
defer conf.MustSet(ctx, config.KeyTokenHookURL, nil)
conf.MustSet(ctx, config.KeyTokenHook, hs.URL)
defer conf.MustSet(ctx, config.KeyTokenHook, nil)
}

res, err := testRefresh(t, &refreshedToken, ts.URL, false)
Expand All @@ -1794,11 +1798,11 @@ func TestAuthCodeWithMockStrategy(t *testing.T) {
defer hs.Close()

if hookType == "legacy" {
conf.MustSet(ctx, config.KeyRefreshTokenHookURL, hs.URL)
defer conf.MustSet(ctx, config.KeyRefreshTokenHookURL, nil)
conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL)
defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil)
} else {
conf.MustSet(ctx, config.KeyTokenHookURL, hs.URL)
defer conf.MustSet(ctx, config.KeyTokenHookURL, nil)
conf.MustSet(ctx, config.KeyTokenHook, hs.URL)
defer conf.MustSet(ctx, config.KeyTokenHook, nil)
}

res, err := testRefresh(t, &refreshedToken, ts.URL, false)
Expand Down
20 changes: 12 additions & 8 deletions oauth2/oauth2_client_credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ func TestClientCredentials(t *testing.T) {

hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, r.Header.Get("Content-Type"), "application/json; charset=UTF-8")
assert.Equal(t, r.Header.Get("Bearer"), "secret value")

expectedGrantedScopes := []string{"foobar"}
expectedGrantedAudience := []string{"https://api.ory.sh/"}
Expand Down Expand Up @@ -286,9 +287,12 @@ func TestClientCredentials(t *testing.T) {
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL)
reg.Config().MustSet(ctx, config.KeyTokenHook, &config.HookConfig{
URL: hs.URL,
Headers: map[string]string{"Bearer": "secret value"},
})

defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil)
defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil)

secret := uuid.New().String()
cl, conf := newCustomClient(t, &hc.Client{
Expand Down Expand Up @@ -316,9 +320,9 @@ func TestClientCredentials(t *testing.T) {
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL)
reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL)

defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil)
defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil)

_, conf := newClient(t)

Expand All @@ -340,9 +344,9 @@ func TestClientCredentials(t *testing.T) {
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL)
reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL)

defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil)
defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil)

_, conf := newClient(t)

Expand All @@ -364,9 +368,9 @@ func TestClientCredentials(t *testing.T) {
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL)
reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL)

defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil)
defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil)

_, conf := newClient(t)

Expand Down
20 changes: 10 additions & 10 deletions oauth2/oauth2_jwt_bearer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,9 +354,9 @@ func TestJWTBearer(t *testing.T) {
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL)
reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL)

defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil)
defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil)

conf := newConf(client)
conf.EndpointParams = url.Values{"grant_type": {grantType}, "assertion": {token}}
Expand Down Expand Up @@ -429,9 +429,9 @@ func TestJWTBearer(t *testing.T) {
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL)
reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL)

defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil)
defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil)

conf := newConf(client)
conf.AuthStyle = goauth2.AuthStyleInParams
Expand All @@ -457,9 +457,9 @@ func TestJWTBearer(t *testing.T) {
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL)
reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL)

defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil)
defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil)

token, _, err := signer.Generate(ctx, jwt.MapClaims{
"jti": uuid.NewString(),
Expand Down Expand Up @@ -492,9 +492,9 @@ func TestJWTBearer(t *testing.T) {
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL)
reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL)

defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil)
defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil)

token, _, err := signer.Generate(ctx, jwt.MapClaims{
"jti": uuid.NewString(),
Expand Down Expand Up @@ -527,9 +527,9 @@ func TestJWTBearer(t *testing.T) {
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL)
reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL)

defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil)
defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil)

token, _, err := signer.Generate(ctx, jwt.MapClaims{
"jti": uuid.NewString(),
Expand Down
Loading