Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: enable server-side config from context #3954

Merged
merged 2 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 29 additions & 10 deletions driver/config/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ package config_test
import (
"context"
"io"
"net/http/httptest"
"net/http"
"testing"

"github.com/julienschmidt/httprouter"
Expand All @@ -17,35 +17,54 @@ import (
"github.com/ory/kratos/internal"
)

type configProvider struct {
cfg *config.Config
}

func (c *configProvider) Config() *config.Config {
return c.cfg
}

func TestNewConfigHashHandler(t *testing.T) {
ctx := context.Background()
conf, reg := internal.NewFastRegistryWithMocks(t)
cfg := internal.NewConfigurationWithDefaults(t)
router := httprouter.New()
config.NewConfigHashHandler(reg, router)
ts := httptest.NewServer(router)
config.NewConfigHashHandler(&configProvider{cfg: cfg}, router)
ts := config.NewConfigurableTestServer(router)
t.Cleanup(ts.Close)
res, err := ts.Client().Get(ts.URL + "/health/config")

// first request, get baseline hash
res, err := ts.Client(ctx).Get(ts.URL + "/health/config")
require.NoError(t, err)
defer res.Body.Close()
require.Equal(t, 200, res.StatusCode)
first, err := io.ReadAll(res.Body)
require.NoError(t, err)

res, err = ts.Client().Get(ts.URL + "/health/config")
// second request, no config change
res, err = ts.Client(ctx).Get(ts.URL + "/health/config")
require.NoError(t, err)
defer res.Body.Close()
require.Equal(t, 200, res.StatusCode)
second, err := io.ReadAll(res.Body)
require.NoError(t, err)
assert.Equal(t, first, second)

require.NoError(t, conf.Set(ctx, config.ViperKeySessionDomain, "foobar"))
// third request, with config change
res, err = ts.Client(config.WithConfigValue(ctx, config.ViperKeySessionDomain, "foobar")).Get(ts.URL + "/health/config")
require.NoError(t, err)
defer res.Body.Close()
require.Equal(t, 200, res.StatusCode)
third, err := io.ReadAll(res.Body)
require.NoError(t, err)
assert.NotEqual(t, first, third)

res, err = ts.Client().Get(ts.URL + "/health/config")
// fourth request, no config change
res, err = http.Get(ts.URL + "/health/config")
require.NoError(t, err)
defer res.Body.Close()
require.Equal(t, 200, res.StatusCode)
second, err = io.ReadAll(res.Body)
fourth, err := io.ReadAll(res.Body)
require.NoError(t, err)
assert.NotEqual(t, first, second)
assert.Equal(t, first, fourth)
}
94 changes: 91 additions & 3 deletions driver/config/test_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ package config

import (
"context"
"net/http"
"net/http/httptest"

"github.com/gofrs/uuid"

"github.com/ory/kratos/embedx"
"github.com/ory/x/configx"
Expand Down Expand Up @@ -51,14 +55,98 @@ func WithConfigValue(ctx context.Context, key string, value any) context.Context
return WithConfigValues(ctx, map[string]any{key: value})
}

func WithConfigValues(ctx context.Context, setValues map[string]any) context.Context {
func WithConfigValues(ctx context.Context, setValues ...map[string]any) context.Context {
values, ok := ctx.Value(contextConfigKey).([]map[string]any)
if !ok {
values = make([]map[string]any, 0)
}
newValues := make([]map[string]any, len(values), len(values)+1)
newValues := make([]map[string]any, len(values), len(values)+len(setValues))
copy(newValues, values)
newValues = append(newValues, setValues)
newValues = append(newValues, setValues...)

return context.WithValue(ctx, contextConfigKey, newValues)
}

type ConfigurableTestHandler struct {
configs map[uuid.UUID][]map[string]any
handler http.Handler
}

func NewConfigurableTestHandler(h http.Handler) *ConfigurableTestHandler {
return &ConfigurableTestHandler{
configs: make(map[uuid.UUID][]map[string]any),
handler: h,
}
}

func (t *ConfigurableTestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
cID := r.Header.Get("Test-Config-Id")
if config, ok := t.configs[uuid.FromStringOrNil(cID)]; ok {
r = r.WithContext(WithConfigValues(r.Context(), config...))
}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The handler looks up the config ID from the request header and adds the config to the request context server-side if it found it.

t.handler.ServeHTTP(w, r)
}

func (t *ConfigurableTestHandler) RegisterConfig(config ...map[string]any) uuid.UUID {
id := uuid.Must(uuid.NewV4())
t.configs[id] = config
return id
}

func (t *ConfigurableTestHandler) UseConfig(r *http.Request, id uuid.UUID) *http.Request {
r.Header.Set("Test-Config-Id", id.String())
return r
}

func (t *ConfigurableTestHandler) UseConfigValues(r *http.Request, values ...map[string]any) *http.Request {
return t.UseConfig(r, t.RegisterConfig(values...))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registering the config values and setting the config header on the request (client side).

}

type ConfigurableTestServer struct {
*httptest.Server
handler *ConfigurableTestHandler
transport http.RoundTripper
}

func NewConfigurableTestServer(h http.Handler) *ConfigurableTestServer {
handler := NewConfigurableTestHandler(h)
server := httptest.NewServer(handler)

t := server.Client().Transport
cts := &ConfigurableTestServer{
handler: handler,
Server: server,
transport: t,
}
server.Client().Transport = cts
return cts
}

func (t *ConfigurableTestServer) RoundTrip(r *http.Request) (*http.Response, error) {
config, ok := r.Context().Value(contextConfigKey).([]map[string]any)
if ok && config != nil {
r = t.handler.UseConfigValues(r, config...)
}
return t.transport.RoundTrip(r)
}

type AutoContextClient struct {
*http.Client
transport http.RoundTripper
ctx context.Context
}

func (t *ConfigurableTestServer) Client(ctx context.Context) *AutoContextClient {
baseClient := *t.Server.Client()
autoClient := &AutoContextClient{
Client: &baseClient,
transport: t,
ctx: ctx,
}
baseClient.Transport = autoClient
return autoClient
}

func (c *AutoContextClient) RoundTrip(r *http.Request) (*http.Response, error) {
return c.transport.RoundTrip(r.WithContext(c.ctx))
}
Loading