Skip to content

Commit

Permalink
Use context to store and populate origin
Browse files Browse the repository at this point in the history
  • Loading branch information
easyCZ authored and roboquat committed Feb 15, 2023
1 parent 1a90947 commit 0a7ca4c
Show file tree
Hide file tree
Showing 10 changed files with 198 additions and 28 deletions.
9 changes: 3 additions & 6 deletions components/public-api-server/pkg/auth/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ const (
type Token struct {
Type TokenType
Value string
// Only relevant for CookieTokenType
OriginHeader string
}

func NewAccessToken(token string) Token {
Expand All @@ -36,11 +34,10 @@ func NewAccessToken(token string) Token {
}
}

func NewCookieToken(cookie string, origin string) Token {
func NewCookieToken(cookie string) Token {
return Token{
Type: CookieTokenType,
Value: cookie,
OriginHeader: origin,
Type: CookieTokenType,
Value: cookie,
}
}

Expand Down
2 changes: 1 addition & 1 deletion components/public-api-server/pkg/auth/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func TestTokenToAndFromContext_AccessToken(t *testing.T) {
}

func TestTokenToAndFromContext_CookieToken(t *testing.T) {
token := NewCookieToken("my_token", "gitpod.io")
token := NewCookieToken("my_token")

extracted, err := TokenFromContext(TokenToContext(context.Background(), token))
require.NoError(t, err)
Expand Down
18 changes: 8 additions & 10 deletions components/public-api-server/pkg/auth/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ import (
"github.com/bufbuild/connect-go"
)

type AuthInterceptor struct {
type Interceptor struct {
accessToken string
}

func (a *AuthInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
func (a *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return connect.UnaryFunc(func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
if req.Spec().IsClient {
ctx = TokenToContext(ctx, NewAccessToken(a.accessToken))
Expand All @@ -33,7 +33,7 @@ func (a *AuthInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
})
}

func (a *AuthInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
func (a *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
return func(ctx context.Context, s connect.Spec) connect.StreamingClientConn {
ctx = TokenToContext(ctx, NewAccessToken(a.accessToken))
conn := next(ctx, s)
Expand All @@ -42,7 +42,7 @@ func (a *AuthInterceptor) WrapStreamingClient(next connect.StreamingClientFunc)
}
}

func (a *AuthInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
func (a *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return func(ctx context.Context, conn connect.StreamingHandlerConn) error {
token, err := tokenFromConn(ctx, conn)
if err != nil {
Expand All @@ -54,7 +54,7 @@ func (a *AuthInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc

// NewServerInterceptor creates a server-side interceptor which validates that an incoming request contains a Bearer Authorization header
func NewServerInterceptor() connect.Interceptor {
return &AuthInterceptor{}
return &Interceptor{}
}

func tokenFromRequest(ctx context.Context, req connect.AnyRequest) (Token, error) {
Expand All @@ -66,9 +66,8 @@ func tokenFromRequest(ctx context.Context, req connect.AnyRequest) (Token, error
}

cookie := req.Header().Get("Cookie")
origin := req.Header().Get("Origin")
if cookie != "" {
return NewCookieToken(cookie, origin), nil
return NewCookieToken(cookie), nil
}

return Token{}, connect.NewError(connect.CodeUnauthenticated, fmt.Errorf("No access token or cookie credentials available on request."))
Expand All @@ -83,17 +82,16 @@ func tokenFromConn(ctx context.Context, conn connect.StreamingHandlerConn) (Toke
}

cookie := conn.RequestHeader().Get("Cookie")
origin := conn.RequestHeader().Get("Origin")
if cookie != "" {
return NewCookieToken(cookie, origin), nil
return NewCookieToken(cookie), nil
}

return Token{}, connect.NewError(connect.CodeUnauthenticated, fmt.Errorf("No access token or cookie credentials available on request."))
}

// NewClientInterceptor creates a client-side interceptor which injects token as a Bearer Authorization header
func NewClientInterceptor(accessToken string) connect.Interceptor {
return &AuthInterceptor{
return &Interceptor{
accessToken: accessToken,
}
}
27 changes: 27 additions & 0 deletions components/public-api-server/pkg/origin/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (c) 2023 Gitpod GmbH. All rights reserved.
// Licensed under the GNU Affero General Public License (AGPL).
// See License.AGPL.txt in the project root for license information.

package origin

import (
"context"
)

type contextKey int

const (
originContextKey contextKey = iota
)

func ToContext(ctx context.Context, origin string) context.Context {
return context.WithValue(ctx, originContextKey, origin)
}

func FromContext(ctx context.Context) string {
if val, ok := ctx.Value(originContextKey).(string); ok {
return val
}

return ""
}
17 changes: 17 additions & 0 deletions components/public-api-server/pkg/origin/context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) 2023 Gitpod GmbH. All rights reserved.
// Licensed under the GNU Affero General Public License (AGPL).
// See License.AGPL.txt in the project root for license information.

package origin

import (
"context"
"testing"

"github.com/stretchr/testify/require"
)

func TestToFromContext(t *testing.T) {
require.Equal(t, "some-origin", FromContext(ToContext(context.Background(), "some-origin")), "origin stored on context is extracted")
require.Equal(t, "", FromContext(context.Background()), "context without origin value returns empty")
}
48 changes: 48 additions & 0 deletions components/public-api-server/pkg/origin/middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright (c) 2023 Gitpod GmbH. All rights reserved.
// Licensed under the GNU Affero General Public License (AGPL).
// See License.AGPL.txt in the project root for license information.

package origin

import (
"context"

"github.com/bufbuild/connect-go"
)

func NewInterceptor() *Interceptor {
return &Interceptor{}
}

type Interceptor struct{}

func (i *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return connect.UnaryFunc(func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
if req.Spec().IsClient {
req.Header().Add("Origin", FromContext(ctx))
} else {
origin := req.Header().Get("Origin")
ctx = ToContext(ctx, origin)
}

return next(ctx, req)
})
}

func (a *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
return func(ctx context.Context, s connect.Spec) connect.StreamingClientConn {
conn := next(ctx, s)
conn.RequestHeader().Add("Origin", FromContext(ctx))

return conn
}
}

func (a *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return func(ctx context.Context, conn connect.StreamingHandlerConn) error {
origin := conn.RequestHeader().Get("Origin")
ctx = ToContext(ctx, origin)

return next(ctx, conn)
}
}
36 changes: 36 additions & 0 deletions components/public-api-server/pkg/origin/middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (c) 2023 Gitpod GmbH. All rights reserved.
// Licensed under the GNU Affero General Public License (AGPL).
// See License.AGPL.txt in the project root for license information.

package origin

import (
"context"
"testing"

"github.com/bufbuild/connect-go"
"github.com/stretchr/testify/require"
)

func TestInterceptor_Unary(t *testing.T) {
requestPaylaod := "request"
origin := "my-origin"

type response struct {
origin string
}

handler := connect.UnaryFunc(func(ctx context.Context, ar connect.AnyRequest) (connect.AnyResponse, error) {
origin := FromContext(ctx)
return connect.NewResponse(&response{origin: origin}), nil
})

ctx := context.Background()
request := connect.NewRequest(&requestPaylaod)
request.Header().Add("Origin", origin)

interceptor := NewInterceptor()
resp, err := interceptor.WrapUnary(handler)(ctx, request)
require.NoError(t, err)
require.Equal(t, &response{origin: origin}, resp.Any())
}
30 changes: 23 additions & 7 deletions components/public-api-server/pkg/proxy/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/gitpod-io/gitpod/common-go/log"
gitpod "github.com/gitpod-io/gitpod/gitpod-protocol"
"github.com/gitpod-io/gitpod/public-api-server/pkg/auth"
"github.com/gitpod-io/gitpod/public-api-server/pkg/origin"
"github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus"

lru "github.com/hashicorp/golang-lru"
Expand Down Expand Up @@ -41,14 +42,14 @@ func (p *NoConnectionPool) Get(ctx context.Context, token auth.Token) (gitpod.AP
opts := gitpod.ConnectToServerOpts{
Context: ctx,
Log: logger,
Origin: origin.FromContext(ctx),
}

switch token.Type {
case auth.AccessTokenType:
opts.Token = token.Value
case auth.CookieTokenType:
opts.Cookie = token.Value
opts.Origin = token.OriginHeader
default:
return nil, errors.New("unknown token type")
}
Expand Down Expand Up @@ -83,11 +84,12 @@ func NewConnectionPool(address *url.URL, poolSize int) (*ConnectionPool, error)

return &ConnectionPool{
cache: cache,
connConstructor: func(token auth.Token) (gitpod.APIInterface, error) {
connConstructor: func(ctx context.Context, token auth.Token) (gitpod.APIInterface, error) {
opts := gitpod.ConnectToServerOpts{
// We're using Background context as we want the connection to persist beyond the lifecycle of a single request
Context: context.Background(),
Log: log.Log,
Origin: origin.FromContext(ctx),
CloseHandler: func(_ error) {
cache.Remove(token)
connectionPoolSize.Dec()
Expand All @@ -99,7 +101,6 @@ func NewConnectionPool(address *url.URL, poolSize int) (*ConnectionPool, error)
opts.Token = token.Value
case auth.CookieTokenType:
opts.Cookie = token.Value
opts.Origin = token.OriginHeader
default:
return nil, errors.New("unknown token type")
}
Expand All @@ -120,15 +121,23 @@ func NewConnectionPool(address *url.URL, poolSize int) (*ConnectionPool, error)

}

type conenctionPoolCacheKey struct {
token auth.Token
origin string
}

type ConnectionPool struct {
connConstructor func(token auth.Token) (gitpod.APIInterface, error)
connConstructor func(context.Context, auth.Token) (gitpod.APIInterface, error)

// cache stores token to connection mapping
cache *lru.Cache
}

func (p *ConnectionPool) Get(ctx context.Context, token auth.Token) (gitpod.APIInterface, error) {
cached, found := p.cache.Get(token)
origin := origin.FromContext(ctx)

cacheKey := p.cacheKey(token, origin)
cached, found := p.cache.Get(cacheKey)
reportCacheOutcome(found)
if found {
conn, ok := cached.(*gitpod.APIoverJSONRPC)
Expand All @@ -137,17 +146,24 @@ func (p *ConnectionPool) Get(ctx context.Context, token auth.Token) (gitpod.APII
}
}

conn, err := p.connConstructor(token)
conn, err := p.connConstructor(ctx, token)
if err != nil {
return nil, fmt.Errorf("failed to create new connection to server: %w", err)
}

p.cache.Add(token, conn)
p.cache.Add(cacheKey, conn)
connectionPoolSize.Inc()

return conn, nil
}

func (p *ConnectionPool) cacheKey(token auth.Token, origin string) conenctionPoolCacheKey {
return conenctionPoolCacheKey{
token: token,
origin: origin,
}
}

func getEndpointBasedOnToken(t auth.Token, u *url.URL) (string, error) {
switch t.Type {
case auth.AccessTokenType:
Expand Down
Loading

0 comments on commit 0a7ca4c

Please sign in to comment.