diff --git a/components/gitpod-protocol/go/gitpod-service.go b/components/gitpod-protocol/go/gitpod-service.go index 441f07197b0163..09b7f01a12643d 100644 --- a/components/gitpod-protocol/go/gitpod-service.go +++ b/components/gitpod-protocol/go/gitpod-service.go @@ -13,12 +13,10 @@ import ( "fmt" "io" "net/http" - "net/url" "sync" "time" "github.com/sourcegraph/jsonrpc2" - "golang.org/x/xerrors" "github.com/sirupsen/logrus" ) @@ -262,7 +260,7 @@ type ConnectToServerOpts struct { Context context.Context Token string Cookie string - NoOrigin bool + Origin string Log *logrus.Entry ReconnectionHandler func() CloseHandler func(error) @@ -275,22 +273,8 @@ func ConnectToServer(endpoint string, opts ConnectToServerOpts) (*APIoverJSONRPC opts.Context = context.Background() } - epURL, err := url.Parse(endpoint) - if err != nil { - return nil, xerrors.Errorf("invalid endpoint URL: %w", err) - } - reqHeader := http.Header{} - if !opts.NoOrigin { - var protocol string - if epURL.Scheme == "wss:" { - protocol = "https" - } else { - protocol = "http" - } - origin := fmt.Sprintf("%s://%s/", protocol, epURL.Hostname()) - reqHeader.Set("Origin", origin) - } + reqHeader.Set("Origin", opts.Origin) for k, v := range opts.ExtraHeaders { reqHeader.Set(k, v) diff --git a/components/public-api-server/pkg/auth/middleware.go b/components/public-api-server/pkg/auth/middleware.go index 5a037c7d3e7cdc..1a953e73eb3010 100644 --- a/components/public-api-server/pkg/auth/middleware.go +++ b/components/public-api-server/pkg/auth/middleware.go @@ -11,11 +11,11 @@ import ( "github.com/bufbuild/connect-go" ) -type AuthInterceptor struct { +type Interceptor struct { accessToken string } -func (a *AuthInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { +func (a *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { return connect.UnaryFunc(func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { if req.Spec().IsClient { ctx = TokenToContext(ctx, NewAccessToken(a.accessToken)) @@ -33,7 +33,7 @@ func (a *AuthInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { }) } -func (a *AuthInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { +func (a *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { return func(ctx context.Context, s connect.Spec) connect.StreamingClientConn { ctx = TokenToContext(ctx, NewAccessToken(a.accessToken)) conn := next(ctx, s) @@ -42,7 +42,7 @@ func (a *AuthInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) } } -func (a *AuthInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { +func (a *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { return func(ctx context.Context, conn connect.StreamingHandlerConn) error { token, err := tokenFromConn(ctx, conn) if err != nil { @@ -54,7 +54,7 @@ func (a *AuthInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc // NewServerInterceptor creates a server-side interceptor which validates that an incoming request contains a Bearer Authorization header func NewServerInterceptor() connect.Interceptor { - return &AuthInterceptor{} + return &Interceptor{} } func tokenFromRequest(ctx context.Context, req connect.AnyRequest) (Token, error) { @@ -91,7 +91,7 @@ func tokenFromConn(ctx context.Context, conn connect.StreamingHandlerConn) (Toke // NewClientInterceptor creates a client-side interceptor which injects token as a Bearer Authorization header func NewClientInterceptor(accessToken string) connect.Interceptor { - return &AuthInterceptor{ + return &Interceptor{ accessToken: accessToken, } } diff --git a/components/public-api-server/pkg/origin/context.go b/components/public-api-server/pkg/origin/context.go new file mode 100644 index 00000000000000..f4028cb8140b42 --- /dev/null +++ b/components/public-api-server/pkg/origin/context.go @@ -0,0 +1,27 @@ +// Copyright (c) 2023 Gitpod GmbH. All rights reserved. +// Licensed under the GNU Affero General Public License (AGPL). +// See License.AGPL.txt in the project root for license information. + +package origin + +import ( + "context" +) + +type contextKey int + +const ( + originContextKey contextKey = iota +) + +func ToContext(ctx context.Context, origin string) context.Context { + return context.WithValue(ctx, originContextKey, origin) +} + +func FromContext(ctx context.Context) string { + if val, ok := ctx.Value(originContextKey).(string); ok { + return val + } + + return "" +} diff --git a/components/public-api-server/pkg/origin/context_test.go b/components/public-api-server/pkg/origin/context_test.go new file mode 100644 index 00000000000000..358b7546fc4ae8 --- /dev/null +++ b/components/public-api-server/pkg/origin/context_test.go @@ -0,0 +1,17 @@ +// Copyright (c) 2023 Gitpod GmbH. All rights reserved. +// Licensed under the GNU Affero General Public License (AGPL). +// See License.AGPL.txt in the project root for license information. + +package origin + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestToFromContext(t *testing.T) { + require.Equal(t, "some-origin", FromContext(ToContext(context.Background(), "some-origin")), "origin stored on context is extracted") + require.Equal(t, "", FromContext(context.Background()), "context without origin value returns empty") +} diff --git a/components/public-api-server/pkg/origin/middleware.go b/components/public-api-server/pkg/origin/middleware.go new file mode 100644 index 00000000000000..2f0018074fa698 --- /dev/null +++ b/components/public-api-server/pkg/origin/middleware.go @@ -0,0 +1,48 @@ +// Copyright (c) 2023 Gitpod GmbH. All rights reserved. +// Licensed under the GNU Affero General Public License (AGPL). +// See License.AGPL.txt in the project root for license information. + +package origin + +import ( + "context" + + "github.com/bufbuild/connect-go" +) + +func NewInterceptor() *Interceptor { + return &Interceptor{} +} + +type Interceptor struct{} + +func (i *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { + return connect.UnaryFunc(func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { + if req.Spec().IsClient { + req.Header().Add("Origin", FromContext(ctx)) + } else { + origin := req.Header().Get("Origin") + ctx = ToContext(ctx, origin) + } + + return next(ctx, req) + }) +} + +func (a *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { + return func(ctx context.Context, s connect.Spec) connect.StreamingClientConn { + conn := next(ctx, s) + conn.RequestHeader().Add("Origin", FromContext(ctx)) + + return conn + } +} + +func (a *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { + return func(ctx context.Context, conn connect.StreamingHandlerConn) error { + origin := conn.RequestHeader().Get("Origin") + ctx = ToContext(ctx, origin) + + return next(ctx, conn) + } +} diff --git a/components/public-api-server/pkg/origin/middleware_test.go b/components/public-api-server/pkg/origin/middleware_test.go new file mode 100644 index 00000000000000..e8ce434cbeb99e --- /dev/null +++ b/components/public-api-server/pkg/origin/middleware_test.go @@ -0,0 +1,36 @@ +// Copyright (c) 2023 Gitpod GmbH. All rights reserved. +// Licensed under the GNU Affero General Public License (AGPL). +// See License.AGPL.txt in the project root for license information. + +package origin + +import ( + "context" + "testing" + + "github.com/bufbuild/connect-go" + "github.com/stretchr/testify/require" +) + +func TestInterceptor_Unary(t *testing.T) { + requestPaylaod := "request" + origin := "my-origin" + + type response struct { + origin string + } + + handler := connect.UnaryFunc(func(ctx context.Context, ar connect.AnyRequest) (connect.AnyResponse, error) { + origin := FromContext(ctx) + return connect.NewResponse(&response{origin: origin}), nil + }) + + ctx := context.Background() + request := connect.NewRequest(&requestPaylaod) + request.Header().Add("Origin", origin) + + interceptor := NewInterceptor() + resp, err := interceptor.WrapUnary(handler)(ctx, request) + require.NoError(t, err) + require.Equal(t, &response{origin: origin}, resp.Any()) +} diff --git a/components/public-api-server/pkg/proxy/conn.go b/components/public-api-server/pkg/proxy/conn.go index 1248d637da5ebd..2dc5604445eb4e 100644 --- a/components/public-api-server/pkg/proxy/conn.go +++ b/components/public-api-server/pkg/proxy/conn.go @@ -14,6 +14,7 @@ import ( "github.com/gitpod-io/gitpod/common-go/log" gitpod "github.com/gitpod-io/gitpod/gitpod-protocol" "github.com/gitpod-io/gitpod/public-api-server/pkg/auth" + "github.com/gitpod-io/gitpod/public-api-server/pkg/origin" "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" lru "github.com/hashicorp/golang-lru" @@ -41,6 +42,7 @@ func (p *NoConnectionPool) Get(ctx context.Context, token auth.Token) (gitpod.AP opts := gitpod.ConnectToServerOpts{ Context: ctx, Log: logger, + Origin: origin.FromContext(ctx), } switch token.Type { @@ -82,12 +84,12 @@ func NewConnectionPool(address *url.URL, poolSize int) (*ConnectionPool, error) return &ConnectionPool{ cache: cache, - connConstructor: func(token auth.Token) (gitpod.APIInterface, error) { + connConstructor: func(ctx context.Context, token auth.Token) (gitpod.APIInterface, error) { opts := gitpod.ConnectToServerOpts{ // We're using Background context as we want the connection to persist beyond the lifecycle of a single request - Context: context.Background(), - Log: log.Log, - NoOrigin: true, + Context: context.Background(), + Log: log.Log, + Origin: origin.FromContext(ctx), CloseHandler: func(_ error) { cache.Remove(token) connectionPoolSize.Dec() @@ -119,15 +121,23 @@ func NewConnectionPool(address *url.URL, poolSize int) (*ConnectionPool, error) } +type conenctionPoolCacheKey struct { + token auth.Token + origin string +} + type ConnectionPool struct { - connConstructor func(token auth.Token) (gitpod.APIInterface, error) + connConstructor func(context.Context, auth.Token) (gitpod.APIInterface, error) // cache stores token to connection mapping cache *lru.Cache } func (p *ConnectionPool) Get(ctx context.Context, token auth.Token) (gitpod.APIInterface, error) { - cached, found := p.cache.Get(token) + origin := origin.FromContext(ctx) + + cacheKey := p.cacheKey(token, origin) + cached, found := p.cache.Get(cacheKey) reportCacheOutcome(found) if found { conn, ok := cached.(*gitpod.APIoverJSONRPC) @@ -136,17 +146,24 @@ func (p *ConnectionPool) Get(ctx context.Context, token auth.Token) (gitpod.APII } } - conn, err := p.connConstructor(token) + conn, err := p.connConstructor(ctx, token) if err != nil { return nil, fmt.Errorf("failed to create new connection to server: %w", err) } - p.cache.Add(token, conn) + p.cache.Add(cacheKey, conn) connectionPoolSize.Inc() return conn, nil } +func (p *ConnectionPool) cacheKey(token auth.Token, origin string) conenctionPoolCacheKey { + return conenctionPoolCacheKey{ + token: token, + origin: origin, + } +} + func getEndpointBasedOnToken(t auth.Token, u *url.URL) (string, error) { switch t.Type { case auth.AccessTokenType: diff --git a/components/public-api-server/pkg/proxy/conn_test.go b/components/public-api-server/pkg/proxy/conn_test.go index 85825a02955c63..0eaebb5032618e 100644 --- a/components/public-api-server/pkg/proxy/conn_test.go +++ b/components/public-api-server/pkg/proxy/conn_test.go @@ -11,6 +11,7 @@ import ( gitpod "github.com/gitpod-io/gitpod/gitpod-protocol" "github.com/gitpod-io/gitpod/public-api-server/pkg/auth" + "github.com/gitpod-io/gitpod/public-api-server/pkg/origin" "github.com/golang/mock/gomock" lru "github.com/hashicorp/golang-lru" "github.com/stretchr/testify/require" @@ -25,7 +26,7 @@ func TestConnectionPool(t *testing.T) { require.NoError(t, err) pool := &ConnectionPool{ cache: cache, - connConstructor: func(token auth.Token) (gitpod.APIInterface, error) { + connConstructor: func(ctx context.Context, token auth.Token) (gitpod.APIInterface, error) { return srv, nil }, } @@ -45,8 +46,36 @@ func TestConnectionPool(t *testing.T) { _, err = pool.Get(context.Background(), bazToken) require.NoError(t, err) require.Equal(t, 2, pool.cache.Len(), "must keep only last two connectons") - require.True(t, pool.cache.Contains(barToken)) - require.True(t, pool.cache.Contains(bazToken)) + require.True(t, pool.cache.Contains(pool.cacheKey(barToken, ""))) + require.True(t, pool.cache.Contains(pool.cacheKey(bazToken, ""))) +} + +func TestConnectionPool_ByDistinctOrigins(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + srv := gitpod.NewMockAPIInterface(ctrl) + + cache, err := lru.New(2) + require.NoError(t, err) + pool := &ConnectionPool{ + cache: cache, + connConstructor: func(ctx context.Context, token auth.Token) (gitpod.APIInterface, error) { + return srv, nil + }, + } + + token := auth.NewAccessToken("foo") + + ctxWithOriginA := origin.ToContext(context.Background(), "originA") + ctxWithOriginB := origin.ToContext(context.Background(), "originB") + + _, err = pool.Get(ctxWithOriginA, token) + require.NoError(t, err) + require.Equal(t, 1, pool.cache.Len()) + + _, err = pool.Get(ctxWithOriginB, token) + require.NoError(t, err) + require.Equal(t, 2, pool.cache.Len()) } func TestEndpointBasedOnToken(t *testing.T) { diff --git a/components/public-api-server/pkg/server/server.go b/components/public-api-server/pkg/server/server.go index 18367800c4d8d1..3338088f5f8c8e 100644 --- a/components/public-api-server/pkg/server/server.go +++ b/components/public-api-server/pkg/server/server.go @@ -28,6 +28,7 @@ import ( "github.com/gitpod-io/gitpod/public-api-server/pkg/auth" "github.com/gitpod-io/gitpod/public-api-server/pkg/billingservice" "github.com/gitpod-io/gitpod/public-api-server/pkg/oidc" + "github.com/gitpod-io/gitpod/public-api-server/pkg/origin" "github.com/gitpod-io/gitpod/public-api-server/pkg/proxy" "github.com/gitpod-io/gitpod/public-api-server/pkg/webhooks" "github.com/sirupsen/logrus" @@ -154,6 +155,7 @@ func register(srv *baseserver.Server, deps *registerDependencies) error { NewMetricsInterceptor(connectMetrics), NewLogInterceptor(log.Log), auth.NewServerInterceptor(), + origin.NewInterceptor(), ), }