-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use context to store and populate origin
- Loading branch information
Showing
10 changed files
with
198 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
// Copyright (c) 2023 Gitpod GmbH. All rights reserved. | ||
// Licensed under the GNU Affero General Public License (AGPL). | ||
// See License.AGPL.txt in the project root for license information. | ||
|
||
package origin | ||
|
||
import ( | ||
"context" | ||
) | ||
|
||
type contextKey int | ||
|
||
const ( | ||
originContextKey contextKey = iota | ||
) | ||
|
||
func ToContext(ctx context.Context, origin string) context.Context { | ||
return context.WithValue(ctx, originContextKey, origin) | ||
} | ||
|
||
func FromContext(ctx context.Context) string { | ||
if val, ok := ctx.Value(originContextKey).(string); ok { | ||
return val | ||
} | ||
|
||
return "" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
// Copyright (c) 2023 Gitpod GmbH. All rights reserved. | ||
// Licensed under the GNU Affero General Public License (AGPL). | ||
// See License.AGPL.txt in the project root for license information. | ||
|
||
package origin | ||
|
||
import ( | ||
"context" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestToFromContext(t *testing.T) { | ||
require.Equal(t, "some-origin", FromContext(ToContext(context.Background(), "some-origin")), "origin stored on context is extracted") | ||
require.Equal(t, "", FromContext(context.Background()), "context without origin value returns empty") | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
// Copyright (c) 2023 Gitpod GmbH. All rights reserved. | ||
// Licensed under the GNU Affero General Public License (AGPL). | ||
// See License.AGPL.txt in the project root for license information. | ||
|
||
package origin | ||
|
||
import ( | ||
"context" | ||
|
||
"github.com/bufbuild/connect-go" | ||
) | ||
|
||
func NewInterceptor() *Interceptor { | ||
return &Interceptor{} | ||
} | ||
|
||
type Interceptor struct{} | ||
|
||
func (i *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { | ||
return connect.UnaryFunc(func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { | ||
if req.Spec().IsClient { | ||
req.Header().Add("Origin", FromContext(ctx)) | ||
} else { | ||
origin := req.Header().Get("Origin") | ||
ctx = ToContext(ctx, origin) | ||
} | ||
|
||
return next(ctx, req) | ||
}) | ||
} | ||
|
||
func (a *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { | ||
return func(ctx context.Context, s connect.Spec) connect.StreamingClientConn { | ||
conn := next(ctx, s) | ||
conn.RequestHeader().Add("Origin", FromContext(ctx)) | ||
|
||
return conn | ||
} | ||
} | ||
|
||
func (a *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { | ||
return func(ctx context.Context, conn connect.StreamingHandlerConn) error { | ||
origin := conn.RequestHeader().Get("Origin") | ||
ctx = ToContext(ctx, origin) | ||
|
||
return next(ctx, conn) | ||
} | ||
} |
36 changes: 36 additions & 0 deletions
36
components/public-api-server/pkg/origin/middleware_test.go
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
// Copyright (c) 2023 Gitpod GmbH. All rights reserved. | ||
// Licensed under the GNU Affero General Public License (AGPL). | ||
// See License.AGPL.txt in the project root for license information. | ||
|
||
package origin | ||
|
||
import ( | ||
"context" | ||
"testing" | ||
|
||
"github.com/bufbuild/connect-go" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestInterceptor_Unary(t *testing.T) { | ||
requestPaylaod := "request" | ||
origin := "my-origin" | ||
|
||
type response struct { | ||
origin string | ||
} | ||
|
||
handler := connect.UnaryFunc(func(ctx context.Context, ar connect.AnyRequest) (connect.AnyResponse, error) { | ||
origin := FromContext(ctx) | ||
return connect.NewResponse(&response{origin: origin}), nil | ||
}) | ||
|
||
ctx := context.Background() | ||
request := connect.NewRequest(&requestPaylaod) | ||
request.Header().Add("Origin", origin) | ||
|
||
interceptor := NewInterceptor() | ||
resp, err := interceptor.WrapUnary(handler)(ctx, request) | ||
require.NoError(t, err) | ||
require.Equal(t, &response{origin: origin}, resp.Any()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.