From 0a7ca4c08afc83dc6e801198f8f979a6a7c67d1f Mon Sep 17 00:00:00 2001 From: Milan Pavlik Date: Wed, 15 Feb 2023 12:42:52 +0000 Subject: [PATCH] Use context to store and populate origin --- .../public-api-server/pkg/auth/context.go | 9 ++-- .../pkg/auth/context_test.go | 2 +- .../public-api-server/pkg/auth/middleware.go | 18 ++++--- .../public-api-server/pkg/origin/context.go | 27 +++++++++++ .../pkg/origin/context_test.go | 17 +++++++ .../pkg/origin/middleware.go | 48 +++++++++++++++++++ .../pkg/origin/middleware_test.go | 36 ++++++++++++++ .../public-api-server/pkg/proxy/conn.go | 30 +++++++++--- .../public-api-server/pkg/proxy/conn_test.go | 37 ++++++++++++-- .../public-api-server/pkg/server/server.go | 2 + 10 files changed, 198 insertions(+), 28 deletions(-) create mode 100644 components/public-api-server/pkg/origin/context.go create mode 100644 components/public-api-server/pkg/origin/context_test.go create mode 100644 components/public-api-server/pkg/origin/middleware.go create mode 100644 components/public-api-server/pkg/origin/middleware_test.go diff --git a/components/public-api-server/pkg/auth/context.go b/components/public-api-server/pkg/auth/context.go index 16f7c5b1d0093d..1328f08103c9e8 100644 --- a/components/public-api-server/pkg/auth/context.go +++ b/components/public-api-server/pkg/auth/context.go @@ -25,8 +25,6 @@ const ( type Token struct { Type TokenType Value string - // Only relevant for CookieTokenType - OriginHeader string } func NewAccessToken(token string) Token { @@ -36,11 +34,10 @@ func NewAccessToken(token string) Token { } } -func NewCookieToken(cookie string, origin string) Token { +func NewCookieToken(cookie string) Token { return Token{ - Type: CookieTokenType, - Value: cookie, - OriginHeader: origin, + Type: CookieTokenType, + Value: cookie, } } diff --git a/components/public-api-server/pkg/auth/context_test.go b/components/public-api-server/pkg/auth/context_test.go index ebb68acbf2dba0..159d852771744f 100644 --- a/components/public-api-server/pkg/auth/context_test.go +++ b/components/public-api-server/pkg/auth/context_test.go @@ -20,7 +20,7 @@ func TestTokenToAndFromContext_AccessToken(t *testing.T) { } func TestTokenToAndFromContext_CookieToken(t *testing.T) { - token := NewCookieToken("my_token", "gitpod.io") + token := NewCookieToken("my_token") extracted, err := TokenFromContext(TokenToContext(context.Background(), token)) require.NoError(t, err) diff --git a/components/public-api-server/pkg/auth/middleware.go b/components/public-api-server/pkg/auth/middleware.go index f5162ea6595a8c..1a953e73eb3010 100644 --- a/components/public-api-server/pkg/auth/middleware.go +++ b/components/public-api-server/pkg/auth/middleware.go @@ -11,11 +11,11 @@ import ( "github.com/bufbuild/connect-go" ) -type AuthInterceptor struct { +type Interceptor struct { accessToken string } -func (a *AuthInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { +func (a *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { return connect.UnaryFunc(func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { if req.Spec().IsClient { ctx = TokenToContext(ctx, NewAccessToken(a.accessToken)) @@ -33,7 +33,7 @@ func (a *AuthInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { }) } -func (a *AuthInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { +func (a *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { return func(ctx context.Context, s connect.Spec) connect.StreamingClientConn { ctx = TokenToContext(ctx, NewAccessToken(a.accessToken)) conn := next(ctx, s) @@ -42,7 +42,7 @@ func (a *AuthInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) } } -func (a *AuthInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { +func (a *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { return func(ctx context.Context, conn connect.StreamingHandlerConn) error { token, err := tokenFromConn(ctx, conn) if err != nil { @@ -54,7 +54,7 @@ func (a *AuthInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc // NewServerInterceptor creates a server-side interceptor which validates that an incoming request contains a Bearer Authorization header func NewServerInterceptor() connect.Interceptor { - return &AuthInterceptor{} + return &Interceptor{} } func tokenFromRequest(ctx context.Context, req connect.AnyRequest) (Token, error) { @@ -66,9 +66,8 @@ func tokenFromRequest(ctx context.Context, req connect.AnyRequest) (Token, error } cookie := req.Header().Get("Cookie") - origin := req.Header().Get("Origin") if cookie != "" { - return NewCookieToken(cookie, origin), nil + return NewCookieToken(cookie), nil } return Token{}, connect.NewError(connect.CodeUnauthenticated, fmt.Errorf("No access token or cookie credentials available on request.")) @@ -83,9 +82,8 @@ func tokenFromConn(ctx context.Context, conn connect.StreamingHandlerConn) (Toke } cookie := conn.RequestHeader().Get("Cookie") - origin := conn.RequestHeader().Get("Origin") if cookie != "" { - return NewCookieToken(cookie, origin), nil + return NewCookieToken(cookie), nil } return Token{}, connect.NewError(connect.CodeUnauthenticated, fmt.Errorf("No access token or cookie credentials available on request.")) @@ -93,7 +91,7 @@ func tokenFromConn(ctx context.Context, conn connect.StreamingHandlerConn) (Toke // NewClientInterceptor creates a client-side interceptor which injects token as a Bearer Authorization header func NewClientInterceptor(accessToken string) connect.Interceptor { - return &AuthInterceptor{ + return &Interceptor{ accessToken: accessToken, } } diff --git a/components/public-api-server/pkg/origin/context.go b/components/public-api-server/pkg/origin/context.go new file mode 100644 index 00000000000000..f4028cb8140b42 --- /dev/null +++ b/components/public-api-server/pkg/origin/context.go @@ -0,0 +1,27 @@ +// Copyright (c) 2023 Gitpod GmbH. All rights reserved. +// Licensed under the GNU Affero General Public License (AGPL). +// See License.AGPL.txt in the project root for license information. + +package origin + +import ( + "context" +) + +type contextKey int + +const ( + originContextKey contextKey = iota +) + +func ToContext(ctx context.Context, origin string) context.Context { + return context.WithValue(ctx, originContextKey, origin) +} + +func FromContext(ctx context.Context) string { + if val, ok := ctx.Value(originContextKey).(string); ok { + return val + } + + return "" +} diff --git a/components/public-api-server/pkg/origin/context_test.go b/components/public-api-server/pkg/origin/context_test.go new file mode 100644 index 00000000000000..358b7546fc4ae8 --- /dev/null +++ b/components/public-api-server/pkg/origin/context_test.go @@ -0,0 +1,17 @@ +// Copyright (c) 2023 Gitpod GmbH. All rights reserved. +// Licensed under the GNU Affero General Public License (AGPL). +// See License.AGPL.txt in the project root for license information. + +package origin + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestToFromContext(t *testing.T) { + require.Equal(t, "some-origin", FromContext(ToContext(context.Background(), "some-origin")), "origin stored on context is extracted") + require.Equal(t, "", FromContext(context.Background()), "context without origin value returns empty") +} diff --git a/components/public-api-server/pkg/origin/middleware.go b/components/public-api-server/pkg/origin/middleware.go new file mode 100644 index 00000000000000..2f0018074fa698 --- /dev/null +++ b/components/public-api-server/pkg/origin/middleware.go @@ -0,0 +1,48 @@ +// Copyright (c) 2023 Gitpod GmbH. All rights reserved. +// Licensed under the GNU Affero General Public License (AGPL). +// See License.AGPL.txt in the project root for license information. + +package origin + +import ( + "context" + + "github.com/bufbuild/connect-go" +) + +func NewInterceptor() *Interceptor { + return &Interceptor{} +} + +type Interceptor struct{} + +func (i *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { + return connect.UnaryFunc(func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { + if req.Spec().IsClient { + req.Header().Add("Origin", FromContext(ctx)) + } else { + origin := req.Header().Get("Origin") + ctx = ToContext(ctx, origin) + } + + return next(ctx, req) + }) +} + +func (a *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { + return func(ctx context.Context, s connect.Spec) connect.StreamingClientConn { + conn := next(ctx, s) + conn.RequestHeader().Add("Origin", FromContext(ctx)) + + return conn + } +} + +func (a *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { + return func(ctx context.Context, conn connect.StreamingHandlerConn) error { + origin := conn.RequestHeader().Get("Origin") + ctx = ToContext(ctx, origin) + + return next(ctx, conn) + } +} diff --git a/components/public-api-server/pkg/origin/middleware_test.go b/components/public-api-server/pkg/origin/middleware_test.go new file mode 100644 index 00000000000000..e8ce434cbeb99e --- /dev/null +++ b/components/public-api-server/pkg/origin/middleware_test.go @@ -0,0 +1,36 @@ +// Copyright (c) 2023 Gitpod GmbH. All rights reserved. +// Licensed under the GNU Affero General Public License (AGPL). +// See License.AGPL.txt in the project root for license information. + +package origin + +import ( + "context" + "testing" + + "github.com/bufbuild/connect-go" + "github.com/stretchr/testify/require" +) + +func TestInterceptor_Unary(t *testing.T) { + requestPaylaod := "request" + origin := "my-origin" + + type response struct { + origin string + } + + handler := connect.UnaryFunc(func(ctx context.Context, ar connect.AnyRequest) (connect.AnyResponse, error) { + origin := FromContext(ctx) + return connect.NewResponse(&response{origin: origin}), nil + }) + + ctx := context.Background() + request := connect.NewRequest(&requestPaylaod) + request.Header().Add("Origin", origin) + + interceptor := NewInterceptor() + resp, err := interceptor.WrapUnary(handler)(ctx, request) + require.NoError(t, err) + require.Equal(t, &response{origin: origin}, resp.Any()) +} diff --git a/components/public-api-server/pkg/proxy/conn.go b/components/public-api-server/pkg/proxy/conn.go index d59e5985770ae7..2dc5604445eb4e 100644 --- a/components/public-api-server/pkg/proxy/conn.go +++ b/components/public-api-server/pkg/proxy/conn.go @@ -14,6 +14,7 @@ import ( "github.com/gitpod-io/gitpod/common-go/log" gitpod "github.com/gitpod-io/gitpod/gitpod-protocol" "github.com/gitpod-io/gitpod/public-api-server/pkg/auth" + "github.com/gitpod-io/gitpod/public-api-server/pkg/origin" "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" lru "github.com/hashicorp/golang-lru" @@ -41,6 +42,7 @@ func (p *NoConnectionPool) Get(ctx context.Context, token auth.Token) (gitpod.AP opts := gitpod.ConnectToServerOpts{ Context: ctx, Log: logger, + Origin: origin.FromContext(ctx), } switch token.Type { @@ -48,7 +50,6 @@ func (p *NoConnectionPool) Get(ctx context.Context, token auth.Token) (gitpod.AP opts.Token = token.Value case auth.CookieTokenType: opts.Cookie = token.Value - opts.Origin = token.OriginHeader default: return nil, errors.New("unknown token type") } @@ -83,11 +84,12 @@ func NewConnectionPool(address *url.URL, poolSize int) (*ConnectionPool, error) return &ConnectionPool{ cache: cache, - connConstructor: func(token auth.Token) (gitpod.APIInterface, error) { + connConstructor: func(ctx context.Context, token auth.Token) (gitpod.APIInterface, error) { opts := gitpod.ConnectToServerOpts{ // We're using Background context as we want the connection to persist beyond the lifecycle of a single request Context: context.Background(), Log: log.Log, + Origin: origin.FromContext(ctx), CloseHandler: func(_ error) { cache.Remove(token) connectionPoolSize.Dec() @@ -99,7 +101,6 @@ func NewConnectionPool(address *url.URL, poolSize int) (*ConnectionPool, error) opts.Token = token.Value case auth.CookieTokenType: opts.Cookie = token.Value - opts.Origin = token.OriginHeader default: return nil, errors.New("unknown token type") } @@ -120,15 +121,23 @@ func NewConnectionPool(address *url.URL, poolSize int) (*ConnectionPool, error) } +type conenctionPoolCacheKey struct { + token auth.Token + origin string +} + type ConnectionPool struct { - connConstructor func(token auth.Token) (gitpod.APIInterface, error) + connConstructor func(context.Context, auth.Token) (gitpod.APIInterface, error) // cache stores token to connection mapping cache *lru.Cache } func (p *ConnectionPool) Get(ctx context.Context, token auth.Token) (gitpod.APIInterface, error) { - cached, found := p.cache.Get(token) + origin := origin.FromContext(ctx) + + cacheKey := p.cacheKey(token, origin) + cached, found := p.cache.Get(cacheKey) reportCacheOutcome(found) if found { conn, ok := cached.(*gitpod.APIoverJSONRPC) @@ -137,17 +146,24 @@ func (p *ConnectionPool) Get(ctx context.Context, token auth.Token) (gitpod.APII } } - conn, err := p.connConstructor(token) + conn, err := p.connConstructor(ctx, token) if err != nil { return nil, fmt.Errorf("failed to create new connection to server: %w", err) } - p.cache.Add(token, conn) + p.cache.Add(cacheKey, conn) connectionPoolSize.Inc() return conn, nil } +func (p *ConnectionPool) cacheKey(token auth.Token, origin string) conenctionPoolCacheKey { + return conenctionPoolCacheKey{ + token: token, + origin: origin, + } +} + func getEndpointBasedOnToken(t auth.Token, u *url.URL) (string, error) { switch t.Type { case auth.AccessTokenType: diff --git a/components/public-api-server/pkg/proxy/conn_test.go b/components/public-api-server/pkg/proxy/conn_test.go index 6eb488c698101c..0eaebb5032618e 100644 --- a/components/public-api-server/pkg/proxy/conn_test.go +++ b/components/public-api-server/pkg/proxy/conn_test.go @@ -11,6 +11,7 @@ import ( gitpod "github.com/gitpod-io/gitpod/gitpod-protocol" "github.com/gitpod-io/gitpod/public-api-server/pkg/auth" + "github.com/gitpod-io/gitpod/public-api-server/pkg/origin" "github.com/golang/mock/gomock" lru "github.com/hashicorp/golang-lru" "github.com/stretchr/testify/require" @@ -25,7 +26,7 @@ func TestConnectionPool(t *testing.T) { require.NoError(t, err) pool := &ConnectionPool{ cache: cache, - connConstructor: func(token auth.Token) (gitpod.APIInterface, error) { + connConstructor: func(ctx context.Context, token auth.Token) (gitpod.APIInterface, error) { return srv, nil }, } @@ -45,8 +46,36 @@ func TestConnectionPool(t *testing.T) { _, err = pool.Get(context.Background(), bazToken) require.NoError(t, err) require.Equal(t, 2, pool.cache.Len(), "must keep only last two connectons") - require.True(t, pool.cache.Contains(barToken)) - require.True(t, pool.cache.Contains(bazToken)) + require.True(t, pool.cache.Contains(pool.cacheKey(barToken, ""))) + require.True(t, pool.cache.Contains(pool.cacheKey(bazToken, ""))) +} + +func TestConnectionPool_ByDistinctOrigins(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + srv := gitpod.NewMockAPIInterface(ctrl) + + cache, err := lru.New(2) + require.NoError(t, err) + pool := &ConnectionPool{ + cache: cache, + connConstructor: func(ctx context.Context, token auth.Token) (gitpod.APIInterface, error) { + return srv, nil + }, + } + + token := auth.NewAccessToken("foo") + + ctxWithOriginA := origin.ToContext(context.Background(), "originA") + ctxWithOriginB := origin.ToContext(context.Background(), "originB") + + _, err = pool.Get(ctxWithOriginA, token) + require.NoError(t, err) + require.Equal(t, 1, pool.cache.Len()) + + _, err = pool.Get(ctxWithOriginB, token) + require.NoError(t, err) + require.Equal(t, 2, pool.cache.Len()) } func TestEndpointBasedOnToken(t *testing.T) { @@ -57,7 +86,7 @@ func TestEndpointBasedOnToken(t *testing.T) { require.NoError(t, err) require.Equal(t, "wss://server:3000/v1", endpointForAccessToken) - endpointForCookie, err := getEndpointBasedOnToken(auth.NewCookieToken("foo", "server"), u) + endpointForCookie, err := getEndpointBasedOnToken(auth.NewCookieToken("foo"), u) require.NoError(t, err) require.Equal(t, "wss://server:3000/gitpod", endpointForCookie) } diff --git a/components/public-api-server/pkg/server/server.go b/components/public-api-server/pkg/server/server.go index 18367800c4d8d1..3338088f5f8c8e 100644 --- a/components/public-api-server/pkg/server/server.go +++ b/components/public-api-server/pkg/server/server.go @@ -28,6 +28,7 @@ import ( "github.com/gitpod-io/gitpod/public-api-server/pkg/auth" "github.com/gitpod-io/gitpod/public-api-server/pkg/billingservice" "github.com/gitpod-io/gitpod/public-api-server/pkg/oidc" + "github.com/gitpod-io/gitpod/public-api-server/pkg/origin" "github.com/gitpod-io/gitpod/public-api-server/pkg/proxy" "github.com/gitpod-io/gitpod/public-api-server/pkg/webhooks" "github.com/sirupsen/logrus" @@ -154,6 +155,7 @@ func register(srv *baseserver.Server, deps *registerDependencies) error { NewMetricsInterceptor(connectMetrics), NewLogInterceptor(log.Log), auth.NewServerInterceptor(), + origin.NewInterceptor(), ), }