Skip to content

Commit 394330f

Browse files
committed
Use context to store and populate origin
1 parent 6f6fedd commit 394330f

File tree

7 files changed

+143
-27
lines changed

7 files changed

+143
-27
lines changed

components/public-api-server/pkg/auth/context.go

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@ const (
2525
type Token struct {
2626
Type TokenType
2727
Value string
28-
// Only relevant for CookieTokenType
29-
OriginHeader string
3028
}
3129

3230
func NewAccessToken(token string) Token {
@@ -36,11 +34,10 @@ func NewAccessToken(token string) Token {
3634
}
3735
}
3836

39-
func NewCookieToken(cookie string, origin string) Token {
37+
func NewCookieToken(cookie string) Token {
4038
return Token{
41-
Type: CookieTokenType,
42-
Value: cookie,
43-
OriginHeader: origin,
39+
Type: CookieTokenType,
40+
Value: cookie,
4441
}
4542
}
4643

components/public-api-server/pkg/auth/middleware.go

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@ import (
1111
"github.com/bufbuild/connect-go"
1212
)
1313

14-
type AuthInterceptor struct {
14+
type Interceptor struct {
1515
accessToken string
1616
}
1717

18-
func (a *AuthInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
18+
func (a *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
1919
return connect.UnaryFunc(func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
2020
if req.Spec().IsClient {
2121
ctx = TokenToContext(ctx, NewAccessToken(a.accessToken))
@@ -33,7 +33,7 @@ func (a *AuthInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
3333
})
3434
}
3535

36-
func (a *AuthInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
36+
func (a *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
3737
return func(ctx context.Context, s connect.Spec) connect.StreamingClientConn {
3838
ctx = TokenToContext(ctx, NewAccessToken(a.accessToken))
3939
conn := next(ctx, s)
@@ -42,7 +42,7 @@ func (a *AuthInterceptor) WrapStreamingClient(next connect.StreamingClientFunc)
4242
}
4343
}
4444

45-
func (a *AuthInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
45+
func (a *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
4646
return func(ctx context.Context, conn connect.StreamingHandlerConn) error {
4747
token, err := tokenFromConn(ctx, conn)
4848
if err != nil {
@@ -54,7 +54,7 @@ func (a *AuthInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc
5454

5555
// NewServerInterceptor creates a server-side interceptor which validates that an incoming request contains a Bearer Authorization header
5656
func NewServerInterceptor() connect.Interceptor {
57-
return &AuthInterceptor{}
57+
return &Interceptor{}
5858
}
5959

6060
func tokenFromRequest(ctx context.Context, req connect.AnyRequest) (Token, error) {
@@ -66,9 +66,8 @@ func tokenFromRequest(ctx context.Context, req connect.AnyRequest) (Token, error
6666
}
6767

6868
cookie := req.Header().Get("Cookie")
69-
origin := req.Header().Get("Origin")
7069
if cookie != "" {
71-
return NewCookieToken(cookie, origin), nil
70+
return NewCookieToken(cookie), nil
7271
}
7372

7473
return Token{}, connect.NewError(connect.CodeUnauthenticated, fmt.Errorf("No access token or cookie credentials available on request."))
@@ -83,17 +82,16 @@ func tokenFromConn(ctx context.Context, conn connect.StreamingHandlerConn) (Toke
8382
}
8483

8584
cookie := conn.RequestHeader().Get("Cookie")
86-
origin := conn.RequestHeader().Get("Origin")
8785
if cookie != "" {
88-
return NewCookieToken(cookie, origin), nil
86+
return NewCookieToken(cookie), nil
8987
}
9088

9189
return Token{}, connect.NewError(connect.CodeUnauthenticated, fmt.Errorf("No access token or cookie credentials available on request."))
9290
}
9391

9492
// NewClientInterceptor creates a client-side interceptor which injects token as a Bearer Authorization header
9593
func NewClientInterceptor(accessToken string) connect.Interceptor {
96-
return &AuthInterceptor{
94+
return &Interceptor{
9795
accessToken: accessToken,
9896
}
9997
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// Copyright (c) 2023 Gitpod GmbH. All rights reserved.
2+
// Licensed under the GNU Affero General Public License (AGPL).
3+
// See License.AGPL.txt in the project root for license information.
4+
5+
package origin
6+
7+
import (
8+
"context"
9+
)
10+
11+
type contextKey int
12+
13+
const (
14+
originContextKey contextKey = iota
15+
)
16+
17+
func ToContext(ctx context.Context, origin string) context.Context {
18+
return context.WithValue(ctx, originContextKey, origin)
19+
}
20+
21+
func FromContext(ctx context.Context) string {
22+
if val, ok := ctx.Value(originContextKey).(string); ok {
23+
return val
24+
}
25+
26+
return ""
27+
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// Copyright (c) 2023 Gitpod GmbH. All rights reserved.
2+
// Licensed under the GNU Affero General Public License (AGPL).
3+
// See License.AGPL.txt in the project root for license information.
4+
5+
package origin
6+
7+
import (
8+
"context"
9+
10+
"github.com/bufbuild/connect-go"
11+
)
12+
13+
func NewInterceptor() *Interceptor {
14+
return &Interceptor{}
15+
}
16+
17+
type Interceptor struct{}
18+
19+
func (i *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
20+
return connect.UnaryFunc(func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
21+
if req.Spec().IsClient {
22+
req.Header().Add("Origin", FromContext(ctx))
23+
} else {
24+
origin := req.Header().Get("Origin")
25+
ctx = ToContext(ctx, origin)
26+
}
27+
28+
return next(ctx, req)
29+
})
30+
}
31+
32+
func (a *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
33+
return func(ctx context.Context, s connect.Spec) connect.StreamingClientConn {
34+
conn := next(ctx, s)
35+
conn.RequestHeader().Add("Origin", FromContext(ctx))
36+
37+
return conn
38+
}
39+
}
40+
41+
func (a *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
42+
return func(ctx context.Context, conn connect.StreamingHandlerConn) error {
43+
origin := conn.RequestHeader().Get("Origin")
44+
ctx = ToContext(ctx, origin)
45+
46+
return next(ctx, conn)
47+
}
48+
}

components/public-api-server/pkg/proxy/conn.go

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"github.com/gitpod-io/gitpod/common-go/log"
1515
gitpod "github.com/gitpod-io/gitpod/gitpod-protocol"
1616
"github.com/gitpod-io/gitpod/public-api-server/pkg/auth"
17+
"github.com/gitpod-io/gitpod/public-api-server/pkg/origin"
1718
"github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus"
1819

1920
lru "github.com/hashicorp/golang-lru"
@@ -41,14 +42,14 @@ func (p *NoConnectionPool) Get(ctx context.Context, token auth.Token) (gitpod.AP
4142
opts := gitpod.ConnectToServerOpts{
4243
Context: ctx,
4344
Log: logger,
45+
Origin: origin.FromContext(ctx),
4446
}
4547

4648
switch token.Type {
4749
case auth.AccessTokenType:
4850
opts.Token = token.Value
4951
case auth.CookieTokenType:
5052
opts.Cookie = token.Value
51-
opts.Origin = token.OriginHeader
5253
default:
5354
return nil, errors.New("unknown token type")
5455
}
@@ -83,7 +84,7 @@ func NewConnectionPool(address *url.URL, poolSize int) (*ConnectionPool, error)
8384

8485
return &ConnectionPool{
8586
cache: cache,
86-
connConstructor: func(token auth.Token) (gitpod.APIInterface, error) {
87+
connConstructor: func(ctx context.Context, token auth.Token) (gitpod.APIInterface, error) {
8788
opts := gitpod.ConnectToServerOpts{
8889
// We're using Background context as we want the connection to persist beyond the lifecycle of a single request
8990
Context: context.Background(),
@@ -99,7 +100,6 @@ func NewConnectionPool(address *url.URL, poolSize int) (*ConnectionPool, error)
99100
opts.Token = token.Value
100101
case auth.CookieTokenType:
101102
opts.Cookie = token.Value
102-
opts.Origin = token.OriginHeader
103103
default:
104104
return nil, errors.New("unknown token type")
105105
}
@@ -120,15 +120,23 @@ func NewConnectionPool(address *url.URL, poolSize int) (*ConnectionPool, error)
120120

121121
}
122122

123+
type conenctionPoolCacheKey struct {
124+
token auth.Token
125+
origin string
126+
}
127+
123128
type ConnectionPool struct {
124-
connConstructor func(token auth.Token) (gitpod.APIInterface, error)
129+
connConstructor func(context.Context, auth.Token) (gitpod.APIInterface, error)
125130

126131
// cache stores token to connection mapping
127132
cache *lru.Cache
128133
}
129134

130135
func (p *ConnectionPool) Get(ctx context.Context, token auth.Token) (gitpod.APIInterface, error) {
131-
cached, found := p.cache.Get(token)
136+
origin := origin.FromContext(ctx)
137+
138+
cacheKey := p.cacheKey(token, origin)
139+
cached, found := p.cache.Get(cacheKey)
132140
reportCacheOutcome(found)
133141
if found {
134142
conn, ok := cached.(*gitpod.APIoverJSONRPC)
@@ -137,17 +145,24 @@ func (p *ConnectionPool) Get(ctx context.Context, token auth.Token) (gitpod.APII
137145
}
138146
}
139147

140-
conn, err := p.connConstructor(token)
148+
conn, err := p.connConstructor(ctx, token)
141149
if err != nil {
142150
return nil, fmt.Errorf("failed to create new connection to server: %w", err)
143151
}
144152

145-
p.cache.Add(token, conn)
153+
p.cache.Add(cacheKey, conn)
146154
connectionPoolSize.Inc()
147155

148156
return conn, nil
149157
}
150158

159+
func (p *ConnectionPool) cacheKey(token auth.Token, origin string) conenctionPoolCacheKey {
160+
return conenctionPoolCacheKey{
161+
token: token,
162+
origin: origin,
163+
}
164+
}
165+
151166
func getEndpointBasedOnToken(t auth.Token, u *url.URL) (string, error) {
152167
switch t.Type {
153168
case auth.AccessTokenType:

components/public-api-server/pkg/proxy/conn_test.go

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111

1212
gitpod "github.com/gitpod-io/gitpod/gitpod-protocol"
1313
"github.com/gitpod-io/gitpod/public-api-server/pkg/auth"
14+
"github.com/gitpod-io/gitpod/public-api-server/pkg/origin"
1415
"github.com/golang/mock/gomock"
1516
lru "github.com/hashicorp/golang-lru"
1617
"github.com/stretchr/testify/require"
@@ -25,7 +26,7 @@ func TestConnectionPool(t *testing.T) {
2526
require.NoError(t, err)
2627
pool := &ConnectionPool{
2728
cache: cache,
28-
connConstructor: func(token auth.Token) (gitpod.APIInterface, error) {
29+
connConstructor: func(ctx context.Context, token auth.Token) (gitpod.APIInterface, error) {
2930
return srv, nil
3031
},
3132
}
@@ -45,8 +46,36 @@ func TestConnectionPool(t *testing.T) {
4546
_, err = pool.Get(context.Background(), bazToken)
4647
require.NoError(t, err)
4748
require.Equal(t, 2, pool.cache.Len(), "must keep only last two connectons")
48-
require.True(t, pool.cache.Contains(barToken))
49-
require.True(t, pool.cache.Contains(bazToken))
49+
require.True(t, pool.cache.Contains(pool.cacheKey(barToken, "")))
50+
require.True(t, pool.cache.Contains(pool.cacheKey(bazToken, "")))
51+
}
52+
53+
func TestConnectionPool_ByDistinctOrigins(t *testing.T) {
54+
ctrl := gomock.NewController(t)
55+
defer ctrl.Finish()
56+
srv := gitpod.NewMockAPIInterface(ctrl)
57+
58+
cache, err := lru.New(2)
59+
require.NoError(t, err)
60+
pool := &ConnectionPool{
61+
cache: cache,
62+
connConstructor: func(ctx context.Context, token auth.Token) (gitpod.APIInterface, error) {
63+
return srv, nil
64+
},
65+
}
66+
67+
token := auth.NewAccessToken("foo")
68+
69+
ctxWithOriginA := origin.ToContext(context.Background(), "originA")
70+
ctxWithOriginB := origin.ToContext(context.Background(), "originB")
71+
72+
_, err = pool.Get(ctxWithOriginA, token)
73+
require.NoError(t, err)
74+
require.Equal(t, 1, pool.cache.Len())
75+
76+
_, err = pool.Get(ctxWithOriginB, token)
77+
require.NoError(t, err)
78+
require.Equal(t, 2, pool.cache.Len())
5079
}
5180

5281
func TestEndpointBasedOnToken(t *testing.T) {
@@ -57,7 +86,7 @@ func TestEndpointBasedOnToken(t *testing.T) {
5786
require.NoError(t, err)
5887
require.Equal(t, "wss://server:3000/v1", endpointForAccessToken)
5988

60-
endpointForCookie, err := getEndpointBasedOnToken(auth.NewCookieToken("foo", "server"), u)
89+
endpointForCookie, err := getEndpointBasedOnToken(auth.NewCookieToken("foo"), u)
6190
require.NoError(t, err)
6291
require.Equal(t, "wss://server:3000/gitpod", endpointForCookie)
6392
}

components/public-api-server/pkg/server/server.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
"github.com/gitpod-io/gitpod/public-api-server/pkg/auth"
2929
"github.com/gitpod-io/gitpod/public-api-server/pkg/billingservice"
3030
"github.com/gitpod-io/gitpod/public-api-server/pkg/oidc"
31+
"github.com/gitpod-io/gitpod/public-api-server/pkg/origin"
3132
"github.com/gitpod-io/gitpod/public-api-server/pkg/proxy"
3233
"github.com/gitpod-io/gitpod/public-api-server/pkg/webhooks"
3334
"github.com/sirupsen/logrus"
@@ -154,6 +155,7 @@ func register(srv *baseserver.Server, deps *registerDependencies) error {
154155
NewMetricsInterceptor(connectMetrics),
155156
NewLogInterceptor(log.Log),
156157
auth.NewServerInterceptor(),
158+
origin.NewInterceptor(),
157159
),
158160
}
159161

0 commit comments

Comments
 (0)