Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: enable server-side config from context #3954

Merged
merged 2 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions cipher/cipher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"fmt"
"testing"

confighelpers "github.com/ory/kratos/driver/config/testhelpers"

"github.com/ory/x/configx"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -44,7 +46,7 @@ func TestCipher(t *testing.T) {
t.Run("case=encryption_failed", func(t *testing.T) {
t.Parallel()

ctx := config.WithConfigValue(ctx, config.ViperKeySecretsCipher, []string{""})
ctx := confighelpers.WithConfigValue(ctx, config.ViperKeySecretsCipher, []string{""})

// secret have to be set
_, err := c.Encrypt(ctx, []byte("not-empty"))
Expand All @@ -53,7 +55,7 @@ func TestCipher(t *testing.T) {
require.ErrorAs(t, err, &hErr)
assert.Equal(t, "Unable to encrypt message because no cipher secrets were configured.", hErr.Reason())

ctx = config.WithConfigValue(ctx, config.ViperKeySecretsCipher, []string{"bad-length"})
ctx = confighelpers.WithConfigValue(ctx, config.ViperKeySecretsCipher, []string{"bad-length"})

// bad secret length
_, err = c.Encrypt(ctx, []byte("not-empty"))
Expand All @@ -70,7 +72,7 @@ func TestCipher(t *testing.T) {
_, err = c.Decrypt(ctx, "not-empty")
require.Error(t, err)

_, err = c.Decrypt(config.WithConfigValue(ctx, config.ViperKeySecretsCipher, []string{""}), "not-empty")
_, err = c.Decrypt(confighelpers.WithConfigValue(ctx, config.ViperKeySecretsCipher, []string{""}), "not-empty")
require.Error(t, err)
})
})
Expand Down
40 changes: 30 additions & 10 deletions driver/config/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ package config_test
import (
"context"
"io"
"net/http/httptest"
"testing"

confighelpers "github.com/ory/kratos/driver/config/testhelpers"

"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand All @@ -17,35 +18,54 @@ import (
"github.com/ory/kratos/internal"
)

type configProvider struct {
cfg *config.Config
}

func (c *configProvider) Config() *config.Config {
return c.cfg
}

func TestNewConfigHashHandler(t *testing.T) {
ctx := context.Background()
conf, reg := internal.NewFastRegistryWithMocks(t)
cfg := internal.NewConfigurationWithDefaults(t)
router := httprouter.New()
config.NewConfigHashHandler(reg, router)
ts := httptest.NewServer(router)
config.NewConfigHashHandler(&configProvider{cfg: cfg}, router)
ts := confighelpers.NewConfigurableTestServer(router)
t.Cleanup(ts.Close)
res, err := ts.Client().Get(ts.URL + "/health/config")

// first request, get baseline hash
res, err := ts.Client(ctx).Get(ts.URL + "/health/config")
require.NoError(t, err)
defer res.Body.Close()
require.Equal(t, 200, res.StatusCode)
first, err := io.ReadAll(res.Body)
require.NoError(t, err)

res, err = ts.Client().Get(ts.URL + "/health/config")
// second request, no config change
res, err = ts.Client(ctx).Get(ts.URL + "/health/config")
require.NoError(t, err)
defer res.Body.Close()
require.Equal(t, 200, res.StatusCode)
second, err := io.ReadAll(res.Body)
require.NoError(t, err)
assert.Equal(t, first, second)

require.NoError(t, conf.Set(ctx, config.ViperKeySessionDomain, "foobar"))
// third request, with config change
res, err = ts.Client(confighelpers.WithConfigValue(ctx, config.ViperKeySessionDomain, "foobar")).Get(ts.URL + "/health/config")
require.NoError(t, err)
defer res.Body.Close()
require.Equal(t, 200, res.StatusCode)
third, err := io.ReadAll(res.Body)
require.NoError(t, err)
assert.NotEqual(t, first, third)

res, err = ts.Client().Get(ts.URL + "/health/config")
// fourth request, no config change
res, err = ts.Client(ctx).Get(ts.URL + "/health/config")
require.NoError(t, err)
defer res.Body.Close()
require.Equal(t, 200, res.StatusCode)
second, err = io.ReadAll(res.Body)
fourth, err := io.ReadAll(res.Body)
require.NoError(t, err)
assert.NotEqual(t, first, second)
assert.Equal(t, first, fourth)
}
64 changes: 0 additions & 64 deletions driver/config/test_config.go

This file was deleted.

152 changes: 152 additions & 0 deletions driver/config/testhelpers/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
// Copyright © 2024 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package testhelpers

import (
"context"
"net/http"
"net/http/httptest"

"github.com/gofrs/uuid"

"github.com/ory/kratos/embedx"
"github.com/ory/x/configx"
"github.com/ory/x/contextx"
)

type (
TestConfigProvider struct {
contextx.Contextualizer
Options []configx.OptionModifier
}
contextKey int
)

func (t *TestConfigProvider) NewProvider(ctx context.Context, opts ...configx.OptionModifier) (*configx.Provider, error) {
return configx.New(ctx, []byte(embedx.ConfigSchema), append(t.Options, opts...)...)
}

func (t *TestConfigProvider) Config(ctx context.Context, config *configx.Provider) *configx.Provider {
config = t.Contextualizer.Config(ctx, config)
values, ok := ctx.Value(contextConfigKey).([]map[string]any)
if !ok {
return config
}
opts := make([]configx.OptionModifier, 0, len(values))
for _, v := range values {
opts = append(opts, configx.WithValues(v))
}
config, err := t.NewProvider(ctx, opts...)
if err != nil {
// This is not production code. The provider is only used in tests.
panic(err)

Check warning on line 43 in driver/config/testhelpers/config.go

View check run for this annotation

Codecov / codecov/patch

driver/config/testhelpers/config.go#L43

Added line #L43 was not covered by tests
}
return config
}

const contextConfigKey contextKey = 1

var (
_ contextx.Contextualizer = (*TestConfigProvider)(nil)
)

func WithConfigValue(ctx context.Context, key string, value any) context.Context {
return WithConfigValues(ctx, map[string]any{key: value})
}

func WithConfigValues(ctx context.Context, setValues ...map[string]any) context.Context {
values, ok := ctx.Value(contextConfigKey).([]map[string]any)
if !ok {
values = make([]map[string]any, 0)
}
newValues := make([]map[string]any, len(values), len(values)+len(setValues))
copy(newValues, values)
newValues = append(newValues, setValues...)

return context.WithValue(ctx, contextConfigKey, newValues)
}

type ConfigurableTestHandler struct {
configs map[uuid.UUID][]map[string]any
handler http.Handler
}

func NewConfigurableTestHandler(h http.Handler) *ConfigurableTestHandler {
return &ConfigurableTestHandler{
configs: make(map[uuid.UUID][]map[string]any),
handler: h,
}
}

func (t *ConfigurableTestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
cID := r.Header.Get("Test-Config-Id")
if config, ok := t.configs[uuid.FromStringOrNil(cID)]; ok {
r = r.WithContext(WithConfigValues(r.Context(), config...))
}
t.handler.ServeHTTP(w, r)
}

func (t *ConfigurableTestHandler) RegisterConfig(config ...map[string]any) uuid.UUID {
id := uuid.Must(uuid.NewV4())
t.configs[id] = config
return id
}

func (t *ConfigurableTestHandler) UseConfig(r *http.Request, id uuid.UUID) *http.Request {
r.Header.Set("Test-Config-Id", id.String())
return r
}

func (t *ConfigurableTestHandler) UseConfigValues(r *http.Request, values ...map[string]any) *http.Request {
return t.UseConfig(r, t.RegisterConfig(values...))
}

type ConfigurableTestServer struct {
*httptest.Server
handler *ConfigurableTestHandler
transport http.RoundTripper
}

func NewConfigurableTestServer(h http.Handler) *ConfigurableTestServer {
handler := NewConfigurableTestHandler(h)
server := httptest.NewServer(handler)

t := server.Client().Transport
cts := &ConfigurableTestServer{
handler: handler,
Server: server,
transport: t,
}
server.Client().Transport = cts
return cts
}

func (t *ConfigurableTestServer) RoundTrip(r *http.Request) (*http.Response, error) {
config, ok := r.Context().Value(contextConfigKey).([]map[string]any)
if ok && config != nil {
r = t.handler.UseConfigValues(r, config...)
}
return t.transport.RoundTrip(r)
}

type AutoContextClient struct {
*http.Client
transport http.RoundTripper
ctx context.Context
}

func (t *ConfigurableTestServer) Client(ctx context.Context) *AutoContextClient {
baseClient := *t.Server.Client()
autoClient := &AutoContextClient{
Client: &baseClient,
transport: t,
ctx: ctx,
}
baseClient.Transport = autoClient
return autoClient
}

func (c *AutoContextClient) RoundTrip(r *http.Request) (*http.Response, error) {
return c.transport.RoundTrip(r.WithContext(c.ctx))
}
Loading
Loading