-
Notifications
You must be signed in to change notification settings - Fork 1.3k
[public-api-server] Forward Origin header where provided #16405
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
// Copyright (c) 2023 Gitpod GmbH. All rights reserved. | ||
// Licensed under the GNU Affero General Public License (AGPL). | ||
// See License.AGPL.txt in the project root for license information. | ||
|
||
package origin | ||
|
||
import ( | ||
"context" | ||
) | ||
|
||
type contextKey int | ||
|
||
const ( | ||
originContextKey contextKey = iota | ||
) | ||
|
||
func ToContext(ctx context.Context, origin string) context.Context { | ||
return context.WithValue(ctx, originContextKey, origin) | ||
} | ||
|
||
func FromContext(ctx context.Context) string { | ||
if val, ok := ctx.Value(originContextKey).(string); ok { | ||
return val | ||
} | ||
|
||
return "" | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
// Copyright (c) 2023 Gitpod GmbH. All rights reserved. | ||
// Licensed under the GNU Affero General Public License (AGPL). | ||
// See License.AGPL.txt in the project root for license information. | ||
|
||
package origin | ||
|
||
import ( | ||
"context" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestToFromContext(t *testing.T) { | ||
require.Equal(t, "some-origin", FromContext(ToContext(context.Background(), "some-origin")), "origin stored on context is extracted") | ||
require.Equal(t, "", FromContext(context.Background()), "context without origin value returns empty") | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
// Copyright (c) 2023 Gitpod GmbH. All rights reserved. | ||
// Licensed under the GNU Affero General Public License (AGPL). | ||
// See License.AGPL.txt in the project root for license information. | ||
|
||
package origin | ||
|
||
import ( | ||
"context" | ||
|
||
"github.com/bufbuild/connect-go" | ||
) | ||
|
||
func NewInterceptor() *Interceptor { | ||
return &Interceptor{} | ||
} | ||
|
||
type Interceptor struct{} | ||
|
||
func (i *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { | ||
return connect.UnaryFunc(func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { | ||
if req.Spec().IsClient { | ||
req.Header().Add("Origin", FromContext(ctx)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Trying to understand what this branch is for: Is this interceptor both usable for servers and clients (e.g., when forwarding upstream)? 🤔 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Correct. The way the interceptor is implemented (as per the interceptor inteface) is that you could write a single interceptor for both clients, and servers. Here, we don't focus on the client version too much, but we do "make it work" by honouring the value that was on the context. |
||
} else { | ||
origin := req.Header().Get("Origin") | ||
ctx = ToContext(ctx, origin) | ||
} | ||
|
||
return next(ctx, req) | ||
}) | ||
} | ||
|
||
func (a *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { | ||
return func(ctx context.Context, s connect.Spec) connect.StreamingClientConn { | ||
conn := next(ctx, s) | ||
conn.RequestHeader().Add("Origin", FromContext(ctx)) | ||
|
||
return conn | ||
} | ||
} | ||
|
||
func (a *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { | ||
return func(ctx context.Context, conn connect.StreamingHandlerConn) error { | ||
origin := conn.RequestHeader().Get("Origin") | ||
ctx = ToContext(ctx, origin) | ||
|
||
return next(ctx, conn) | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
// Copyright (c) 2023 Gitpod GmbH. All rights reserved. | ||
// Licensed under the GNU Affero General Public License (AGPL). | ||
// See License.AGPL.txt in the project root for license information. | ||
|
||
package origin | ||
|
||
import ( | ||
"context" | ||
"testing" | ||
|
||
"github.com/bufbuild/connect-go" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestInterceptor_Unary(t *testing.T) { | ||
requestPaylaod := "request" | ||
origin := "my-origin" | ||
|
||
type response struct { | ||
origin string | ||
} | ||
|
||
handler := connect.UnaryFunc(func(ctx context.Context, ar connect.AnyRequest) (connect.AnyResponse, error) { | ||
origin := FromContext(ctx) | ||
return connect.NewResponse(&response{origin: origin}), nil | ||
}) | ||
|
||
ctx := context.Background() | ||
request := connect.NewRequest(&requestPaylaod) | ||
request.Header().Add("Origin", origin) | ||
|
||
interceptor := NewInterceptor() | ||
resp, err := interceptor.WrapUnary(handler)(ctx, request) | ||
require.NoError(t, err) | ||
require.Equal(t, &response{origin: origin}, resp.Any()) | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,7 @@ import ( | |
"github.com/gitpod-io/gitpod/common-go/log" | ||
gitpod "github.com/gitpod-io/gitpod/gitpod-protocol" | ||
"github.com/gitpod-io/gitpod/public-api-server/pkg/auth" | ||
"github.com/gitpod-io/gitpod/public-api-server/pkg/origin" | ||
"github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" | ||
|
||
lru "github.com/hashicorp/golang-lru" | ||
|
@@ -41,6 +42,7 @@ func (p *NoConnectionPool) Get(ctx context.Context, token auth.Token) (gitpod.AP | |
opts := gitpod.ConnectToServerOpts{ | ||
Context: ctx, | ||
Log: logger, | ||
Origin: origin.FromContext(ctx), | ||
} | ||
|
||
switch token.Type { | ||
|
@@ -82,12 +84,12 @@ func NewConnectionPool(address *url.URL, poolSize int) (*ConnectionPool, error) | |
|
||
return &ConnectionPool{ | ||
cache: cache, | ||
connConstructor: func(token auth.Token) (gitpod.APIInterface, error) { | ||
connConstructor: func(ctx context.Context, token auth.Token) (gitpod.APIInterface, error) { | ||
opts := gitpod.ConnectToServerOpts{ | ||
// We're using Background context as we want the connection to persist beyond the lifecycle of a single request | ||
Context: context.Background(), | ||
Log: log.Log, | ||
NoOrigin: true, | ||
Context: context.Background(), | ||
Log: log.Log, | ||
Origin: origin.FromContext(ctx), | ||
CloseHandler: func(_ error) { | ||
cache.Remove(token) | ||
connectionPoolSize.Dec() | ||
|
@@ -119,15 +121,23 @@ func NewConnectionPool(address *url.URL, poolSize int) (*ConnectionPool, error) | |
|
||
} | ||
|
||
type conenctionPoolCacheKey struct { | ||
token auth.Token | ||
origin string | ||
} | ||
|
||
type ConnectionPool struct { | ||
connConstructor func(token auth.Token) (gitpod.APIInterface, error) | ||
connConstructor func(context.Context, auth.Token) (gitpod.APIInterface, error) | ||
|
||
// cache stores token to connection mapping | ||
cache *lru.Cache | ||
} | ||
|
||
func (p *ConnectionPool) Get(ctx context.Context, token auth.Token) (gitpod.APIInterface, error) { | ||
cached, found := p.cache.Get(token) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 💯 |
||
origin := origin.FromContext(ctx) | ||
|
||
cacheKey := p.cacheKey(token, origin) | ||
cached, found := p.cache.Get(cacheKey) | ||
reportCacheOutcome(found) | ||
if found { | ||
conn, ok := cached.(*gitpod.APIoverJSONRPC) | ||
|
@@ -136,17 +146,24 @@ func (p *ConnectionPool) Get(ctx context.Context, token auth.Token) (gitpod.APII | |
} | ||
} | ||
|
||
conn, err := p.connConstructor(token) | ||
conn, err := p.connConstructor(ctx, token) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to create new connection to server: %w", err) | ||
} | ||
|
||
p.cache.Add(token, conn) | ||
p.cache.Add(cacheKey, conn) | ||
connectionPoolSize.Inc() | ||
|
||
return conn, nil | ||
} | ||
|
||
func (p *ConnectionPool) cacheKey(token auth.Token, origin string) conenctionPoolCacheKey { | ||
return conenctionPoolCacheKey{ | ||
token: token, | ||
origin: origin, | ||
} | ||
} | ||
|
||
func getEndpointBasedOnToken(t auth.Token, u *url.URL) (string, error) { | ||
switch t.Type { | ||
case auth.AccessTokenType: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, I like this!