Skip to content

[public-api-server] Forward Origin header where provided #16405

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 2 additions & 18 deletions components/gitpod-protocol/go/gitpod-service.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,10 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"sync"
"time"

"github.com/sourcegraph/jsonrpc2"
"golang.org/x/xerrors"

"github.com/sirupsen/logrus"
)
Expand Down Expand Up @@ -262,7 +260,7 @@ type ConnectToServerOpts struct {
Context context.Context
Token string
Cookie string
NoOrigin bool
Origin string
Log *logrus.Entry
ReconnectionHandler func()
CloseHandler func(error)
Expand All @@ -275,22 +273,8 @@ func ConnectToServer(endpoint string, opts ConnectToServerOpts) (*APIoverJSONRPC
opts.Context = context.Background()
}

epURL, err := url.Parse(endpoint)
if err != nil {
return nil, xerrors.Errorf("invalid endpoint URL: %w", err)
}

reqHeader := http.Header{}
if !opts.NoOrigin {
var protocol string
if epURL.Scheme == "wss:" {
protocol = "https"
} else {
protocol = "http"
}
origin := fmt.Sprintf("%s://%s/", protocol, epURL.Hostname())
reqHeader.Set("Origin", origin)
}
reqHeader.Set("Origin", opts.Origin)

for k, v := range opts.ExtraHeaders {
reqHeader.Set(k, v)
Expand Down
12 changes: 6 additions & 6 deletions components/public-api-server/pkg/auth/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ import (
"github.com/bufbuild/connect-go"
)

type AuthInterceptor struct {
type Interceptor struct {
accessToken string
}

func (a *AuthInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
func (a *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return connect.UnaryFunc(func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
if req.Spec().IsClient {
ctx = TokenToContext(ctx, NewAccessToken(a.accessToken))
Expand All @@ -33,7 +33,7 @@ func (a *AuthInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
})
}

func (a *AuthInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
func (a *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
return func(ctx context.Context, s connect.Spec) connect.StreamingClientConn {
ctx = TokenToContext(ctx, NewAccessToken(a.accessToken))
conn := next(ctx, s)
Expand All @@ -42,7 +42,7 @@ func (a *AuthInterceptor) WrapStreamingClient(next connect.StreamingClientFunc)
}
}

func (a *AuthInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
func (a *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return func(ctx context.Context, conn connect.StreamingHandlerConn) error {
token, err := tokenFromConn(ctx, conn)
if err != nil {
Expand All @@ -54,7 +54,7 @@ func (a *AuthInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc

// NewServerInterceptor creates a server-side interceptor which validates that an incoming request contains a Bearer Authorization header
func NewServerInterceptor() connect.Interceptor {
return &AuthInterceptor{}
return &Interceptor{}
}

func tokenFromRequest(ctx context.Context, req connect.AnyRequest) (Token, error) {
Expand Down Expand Up @@ -91,7 +91,7 @@ func tokenFromConn(ctx context.Context, conn connect.StreamingHandlerConn) (Toke

// NewClientInterceptor creates a client-side interceptor which injects token as a Bearer Authorization header
func NewClientInterceptor(accessToken string) connect.Interceptor {
return &AuthInterceptor{
return &Interceptor{
accessToken: accessToken,
}
}
27 changes: 27 additions & 0 deletions components/public-api-server/pkg/origin/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (c) 2023 Gitpod GmbH. All rights reserved.
// Licensed under the GNU Affero General Public License (AGPL).
// See License.AGPL.txt in the project root for license information.

package origin

import (
"context"
)

type contextKey int

const (
originContextKey contextKey = iota
)

func ToContext(ctx context.Context, origin string) context.Context {
return context.WithValue(ctx, originContextKey, origin)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I like this!

}

func FromContext(ctx context.Context) string {
if val, ok := ctx.Value(originContextKey).(string); ok {
return val
}

return ""
}
17 changes: 17 additions & 0 deletions components/public-api-server/pkg/origin/context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) 2023 Gitpod GmbH. All rights reserved.
// Licensed under the GNU Affero General Public License (AGPL).
// See License.AGPL.txt in the project root for license information.

package origin

import (
"context"
"testing"

"github.com/stretchr/testify/require"
)

func TestToFromContext(t *testing.T) {
require.Equal(t, "some-origin", FromContext(ToContext(context.Background(), "some-origin")), "origin stored on context is extracted")
require.Equal(t, "", FromContext(context.Background()), "context without origin value returns empty")
}
48 changes: 48 additions & 0 deletions components/public-api-server/pkg/origin/middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright (c) 2023 Gitpod GmbH. All rights reserved.
// Licensed under the GNU Affero General Public License (AGPL).
// See License.AGPL.txt in the project root for license information.

package origin

import (
"context"

"github.com/bufbuild/connect-go"
)

func NewInterceptor() *Interceptor {
return &Interceptor{}
}

type Interceptor struct{}

func (i *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return connect.UnaryFunc(func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
if req.Spec().IsClient {
req.Header().Add("Origin", FromContext(ctx))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Trying to understand what this branch is for: Is this interceptor both usable for servers and clients (e.g., when forwarding upstream)? 🤔

Copy link
Member

@easyCZ easyCZ Feb 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct. The way the interceptor is implemented (as per the interceptor inteface) is that you could write a single interceptor for both clients, and servers.

Here, we don't focus on the client version too much, but we do "make it work" by honouring the value that was on the context.

} else {
origin := req.Header().Get("Origin")
ctx = ToContext(ctx, origin)
}

return next(ctx, req)
})
}

func (a *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
return func(ctx context.Context, s connect.Spec) connect.StreamingClientConn {
conn := next(ctx, s)
conn.RequestHeader().Add("Origin", FromContext(ctx))

return conn
}
}

func (a *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return func(ctx context.Context, conn connect.StreamingHandlerConn) error {
origin := conn.RequestHeader().Get("Origin")
ctx = ToContext(ctx, origin)

return next(ctx, conn)
}
}
36 changes: 36 additions & 0 deletions components/public-api-server/pkg/origin/middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (c) 2023 Gitpod GmbH. All rights reserved.
// Licensed under the GNU Affero General Public License (AGPL).
// See License.AGPL.txt in the project root for license information.

package origin

import (
"context"
"testing"

"github.com/bufbuild/connect-go"
"github.com/stretchr/testify/require"
)

func TestInterceptor_Unary(t *testing.T) {
requestPaylaod := "request"
origin := "my-origin"

type response struct {
origin string
}

handler := connect.UnaryFunc(func(ctx context.Context, ar connect.AnyRequest) (connect.AnyResponse, error) {
origin := FromContext(ctx)
return connect.NewResponse(&response{origin: origin}), nil
})

ctx := context.Background()
request := connect.NewRequest(&requestPaylaod)
request.Header().Add("Origin", origin)

interceptor := NewInterceptor()
resp, err := interceptor.WrapUnary(handler)(ctx, request)
require.NoError(t, err)
require.Equal(t, &response{origin: origin}, resp.Any())
}
33 changes: 25 additions & 8 deletions components/public-api-server/pkg/proxy/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/gitpod-io/gitpod/common-go/log"
gitpod "github.com/gitpod-io/gitpod/gitpod-protocol"
"github.com/gitpod-io/gitpod/public-api-server/pkg/auth"
"github.com/gitpod-io/gitpod/public-api-server/pkg/origin"
"github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus"

lru "github.com/hashicorp/golang-lru"
Expand Down Expand Up @@ -41,6 +42,7 @@ func (p *NoConnectionPool) Get(ctx context.Context, token auth.Token) (gitpod.AP
opts := gitpod.ConnectToServerOpts{
Context: ctx,
Log: logger,
Origin: origin.FromContext(ctx),
}

switch token.Type {
Expand Down Expand Up @@ -82,12 +84,12 @@ func NewConnectionPool(address *url.URL, poolSize int) (*ConnectionPool, error)

return &ConnectionPool{
cache: cache,
connConstructor: func(token auth.Token) (gitpod.APIInterface, error) {
connConstructor: func(ctx context.Context, token auth.Token) (gitpod.APIInterface, error) {
opts := gitpod.ConnectToServerOpts{
// We're using Background context as we want the connection to persist beyond the lifecycle of a single request
Context: context.Background(),
Log: log.Log,
NoOrigin: true,
Context: context.Background(),
Log: log.Log,
Origin: origin.FromContext(ctx),
CloseHandler: func(_ error) {
cache.Remove(token)
connectionPoolSize.Dec()
Expand Down Expand Up @@ -119,15 +121,23 @@ func NewConnectionPool(address *url.URL, poolSize int) (*ConnectionPool, error)

}

type conenctionPoolCacheKey struct {
token auth.Token
origin string
}

type ConnectionPool struct {
connConstructor func(token auth.Token) (gitpod.APIInterface, error)
connConstructor func(context.Context, auth.Token) (gitpod.APIInterface, error)

// cache stores token to connection mapping
cache *lru.Cache
}

func (p *ConnectionPool) Get(ctx context.Context, token auth.Token) (gitpod.APIInterface, error) {
cached, found := p.cache.Get(token)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💯

origin := origin.FromContext(ctx)

cacheKey := p.cacheKey(token, origin)
cached, found := p.cache.Get(cacheKey)
reportCacheOutcome(found)
if found {
conn, ok := cached.(*gitpod.APIoverJSONRPC)
Expand All @@ -136,17 +146,24 @@ func (p *ConnectionPool) Get(ctx context.Context, token auth.Token) (gitpod.APII
}
}

conn, err := p.connConstructor(token)
conn, err := p.connConstructor(ctx, token)
if err != nil {
return nil, fmt.Errorf("failed to create new connection to server: %w", err)
}

p.cache.Add(token, conn)
p.cache.Add(cacheKey, conn)
connectionPoolSize.Inc()

return conn, nil
}

func (p *ConnectionPool) cacheKey(token auth.Token, origin string) conenctionPoolCacheKey {
return conenctionPoolCacheKey{
token: token,
origin: origin,
}
}

func getEndpointBasedOnToken(t auth.Token, u *url.URL) (string, error) {
switch t.Type {
case auth.AccessTokenType:
Expand Down
35 changes: 32 additions & 3 deletions components/public-api-server/pkg/proxy/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

gitpod "github.com/gitpod-io/gitpod/gitpod-protocol"
"github.com/gitpod-io/gitpod/public-api-server/pkg/auth"
"github.com/gitpod-io/gitpod/public-api-server/pkg/origin"
"github.com/golang/mock/gomock"
lru "github.com/hashicorp/golang-lru"
"github.com/stretchr/testify/require"
Expand All @@ -25,7 +26,7 @@ func TestConnectionPool(t *testing.T) {
require.NoError(t, err)
pool := &ConnectionPool{
cache: cache,
connConstructor: func(token auth.Token) (gitpod.APIInterface, error) {
connConstructor: func(ctx context.Context, token auth.Token) (gitpod.APIInterface, error) {
return srv, nil
},
}
Expand All @@ -45,8 +46,36 @@ func TestConnectionPool(t *testing.T) {
_, err = pool.Get(context.Background(), bazToken)
require.NoError(t, err)
require.Equal(t, 2, pool.cache.Len(), "must keep only last two connectons")
require.True(t, pool.cache.Contains(barToken))
require.True(t, pool.cache.Contains(bazToken))
require.True(t, pool.cache.Contains(pool.cacheKey(barToken, "")))
require.True(t, pool.cache.Contains(pool.cacheKey(bazToken, "")))
}

func TestConnectionPool_ByDistinctOrigins(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
srv := gitpod.NewMockAPIInterface(ctrl)

cache, err := lru.New(2)
require.NoError(t, err)
pool := &ConnectionPool{
cache: cache,
connConstructor: func(ctx context.Context, token auth.Token) (gitpod.APIInterface, error) {
return srv, nil
},
}

token := auth.NewAccessToken("foo")

ctxWithOriginA := origin.ToContext(context.Background(), "originA")
ctxWithOriginB := origin.ToContext(context.Background(), "originB")

_, err = pool.Get(ctxWithOriginA, token)
require.NoError(t, err)
require.Equal(t, 1, pool.cache.Len())

_, err = pool.Get(ctxWithOriginB, token)
require.NoError(t, err)
require.Equal(t, 2, pool.cache.Len())
}

func TestEndpointBasedOnToken(t *testing.T) {
Expand Down
2 changes: 2 additions & 0 deletions components/public-api-server/pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/gitpod-io/gitpod/public-api-server/pkg/auth"
"github.com/gitpod-io/gitpod/public-api-server/pkg/billingservice"
"github.com/gitpod-io/gitpod/public-api-server/pkg/oidc"
"github.com/gitpod-io/gitpod/public-api-server/pkg/origin"
"github.com/gitpod-io/gitpod/public-api-server/pkg/proxy"
"github.com/gitpod-io/gitpod/public-api-server/pkg/webhooks"
"github.com/sirupsen/logrus"
Expand Down Expand Up @@ -154,6 +155,7 @@ func register(srv *baseserver.Server, deps *registerDependencies) error {
NewMetricsInterceptor(connectMetrics),
NewLogInterceptor(log.Log),
auth.NewServerInterceptor(),
origin.NewInterceptor(),
),
}

Expand Down