Skip to content

Commit

Permalink
vcs: Derive per tenant key (grafana#3293)
Browse files Browse the repository at this point in the history
Curretnly we use the same global session encryption secret, for each
teant. In order to ensure tenant isolation, this change will derive a
custom secret per tenant.

Note: This change will require all users to reauthenticate, as the the
previous secret won't be able to be decrypted anymore.
  • Loading branch information
simonswine authored May 13, 2024
1 parent 7aaf8e1 commit 9e2bb77
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 34 deletions.
22 changes: 17 additions & 5 deletions pkg/querier/vcs/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,19 @@ func (q *Service) GithubLogin(ctx context.Context, req *connect.Request[vcsv1.Gi
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("failed to authorize with GitHub"))
}

encryptionKey, err := deriveEncryptionKeyForContext(ctx)
if err != nil {
q.logger.Log("err", err, "msg", "failed to derive encryption key")
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("failed to authorize with GitHub"))
}

token, err := cfg.Exchange(ctx, req.Msg.AuthorizationCode)
if err != nil {
q.logger.Log("err", err, "msg", "failed to exchange authorization code with GitHub")
return nil, connect.NewError(connect.CodeUnauthenticated, fmt.Errorf("failed to authorize with GitHub"))
}

cookie, err := encodeToken(token)
cookie, err := encodeToken(token, encryptionKey)
if err != nil {
q.logger.Log("err", err, "msg", "failed to encode GitHub OAuth token")
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("failed to authorize with GitHub"))
Expand All @@ -63,7 +69,7 @@ func (q *Service) GithubLogin(ctx context.Context, req *connect.Request[vcsv1.Gi
}

func (q *Service) GithubRefresh(ctx context.Context, req *connect.Request[vcsv1.GithubRefreshRequest]) (*connect.Response[vcsv1.GithubRefreshResponse], error) {
token, err := tokenFromRequest(req)
token, err := tokenFromRequest(ctx, req)
if err != nil {
q.logger.Log("err", err, "msg", "failed to extract token from request")
return nil, connect.NewError(connect.CodeUnauthenticated, fmt.Errorf("invalid token"))
Expand All @@ -83,7 +89,13 @@ func (q *Service) GithubRefresh(ctx context.Context, req *connect.Request[vcsv1.

newToken := githubToken.toOAuthToken()

cookie, err := encodeToken(newToken)
derivedKey, err := deriveEncryptionKeyForContext(ctx)
if err != nil {
q.logger.Log("err", err, "msg", "failed to derive encryption key")
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("failed to process token"))
}

cookie, err := encodeToken(newToken, derivedKey)
if err != nil {
q.logger.Log("err", err, "msg", "failed to encode GitHub OAuth token")
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("failed to refresh token"))
Expand All @@ -96,7 +108,7 @@ func (q *Service) GithubRefresh(ctx context.Context, req *connect.Request[vcsv1.
}

func (q *Service) GetFile(ctx context.Context, req *connect.Request[vcsv1.GetFileRequest]) (*connect.Response[vcsv1.GetFileResponse], error) {
token, err := tokenFromRequest(req)
token, err := tokenFromRequest(ctx, req)
if err != nil {
q.logger.Log("err", err, "msg", "failed to extract token from request")
return nil, connect.NewError(connect.CodeUnauthenticated, fmt.Errorf("invalid token"))
Expand Down Expand Up @@ -141,7 +153,7 @@ func (q *Service) GetFile(ctx context.Context, req *connect.Request[vcsv1.GetFil
}

func (q *Service) GetCommit(ctx context.Context, req *connect.Request[vcsv1.GetCommitRequest]) (*connect.Response[vcsv1.GetCommitResponse], error) {
token, err := tokenFromRequest(req)
token, err := tokenFromRequest(ctx, req)
if err != nil {
q.logger.Log("err", err, "msg", "failed to extract token from request")
return nil, connect.NewError(connect.CodeUnauthenticated, fmt.Errorf("invalid token"))
Expand Down
52 changes: 41 additions & 11 deletions pkg/querier/vcs/token.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package vcs

import (
"context"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
Expand All @@ -12,21 +15,43 @@ import (

"connectrpc.com/connect"
"golang.org/x/oauth2"

"github.com/grafana/pyroscope/pkg/tenant"
)

const (
sessionCookieName = "GitSession"
)

var (
gitSessionSecret = []byte(os.Getenv("GITHUB_SESSION_SECRET"))
)

type gitSessionTokenCookie struct {
Metadata string `json:"metadata"`
ExpiryTimestamp int64 `json:"expiry"`
}

const envVarGithubSessionSecret = "GITHUB_SESSION_SECRET"

var githubSessionSecret = []byte(os.Getenv(envVarGithubSessionSecret))

// derives a per tenant key from the global session secret using sha256
func deriveEncryptionKeyForContext(ctx context.Context) ([]byte, error) {
tenantID, err := tenant.ExtractTenantIDFromContext(ctx)
if err != nil {
return nil, err
}
if len(tenantID) == 0 {
return nil, errors.New("tenantID is empty")
}

if len(githubSessionSecret) == 0 {
return nil, errors.New(envVarGithubSessionSecret + " is empty")
}
h := sha256.New()
h.Write(githubSessionSecret)
h.Write([]byte{':'})
h.Write([]byte(tenantID))
return h.Sum(nil), nil
}

// getStringValueFrom gets a string value from url.Values. It will fail if the
// key is missing or the key's value is an empty string.
func getStringValueFrom(values url.Values, key string) (string, error) {
Expand Down Expand Up @@ -59,22 +84,27 @@ func getDurationValueFrom(values url.Values, key string, scalar time.Duration) (
}

// tokenFromRequest decodes an OAuth token from a request.
func tokenFromRequest(req connect.AnyRequest) (*oauth2.Token, error) {
func tokenFromRequest(ctx context.Context, req connect.AnyRequest) (*oauth2.Token, error) {
cookie, err := (&http.Request{Header: req.Header()}).Cookie(sessionCookieName)
if err != nil {
return nil, fmt.Errorf("failed to read cookie %s: %w", sessionCookieName, err)
}

token, err := decodeToken(cookie.Value)
derivedKey, err := deriveEncryptionKeyForContext(ctx)
if err != nil {
return nil, err
}

token, err := decodeToken(cookie.Value, derivedKey)
if err != nil {
return nil, err
}
return token, nil
}

// encodeToken encrypts then base64 encodes an OAuth token.
func encodeToken(token *oauth2.Token) (*http.Cookie, error) {
encrypted, err := encryptToken(token, gitSessionSecret)
func encodeToken(token *oauth2.Token, key []byte) (*http.Cookie, error) {
encrypted, err := encryptToken(token, key)
if err != nil {
return nil, err
}
Expand All @@ -100,7 +130,7 @@ func encodeToken(token *oauth2.Token) (*http.Cookie, error) {
}

// decodeToken base64 decodes and decrypts a OAuth token.
func decodeToken(value string) (*oauth2.Token, error) {
func decodeToken(value string, key []byte) (*oauth2.Token, error) {
var token *oauth2.Token

decoded, err := base64.StdEncoding.DecodeString(value)
Expand All @@ -114,15 +144,15 @@ func decodeToken(value string) (*oauth2.Token, error) {
// This may be a legacy cookie. Legacy cookies aren't base64 encoded
// JSON objects, but rather a base64 encoded crypto hash.
var innerErr error
token, innerErr = decryptToken(value, gitSessionSecret)
token, innerErr = decryptToken(value, key)
if innerErr != nil {
// Legacy fallback failed, return the original error.
return nil, err
}
return token, nil
}

token, err = decryptToken(sessionToken.Metadata, gitSessionSecret)
token, err = decryptToken(sessionToken.Metadata, key)
if err != nil {
return nil, err
}
Expand Down
101 changes: 83 additions & 18 deletions pkg/querier/vcs/token_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package vcs

import (
"context"
"encoding/base64"
"net/http"
"net/url"
Expand All @@ -12,6 +13,7 @@ import (
"golang.org/x/oauth2"

vcsv1 "github.com/grafana/pyroscope/api/gen/proto/go/vcs/v1"
"github.com/grafana/pyroscope/pkg/tenant"
)

func Test_getStringValueFrom(t *testing.T) {
Expand Down Expand Up @@ -133,8 +135,14 @@ func Test_getDurationValueFrom(t *testing.T) {
}

func Test_tokenFromRequest(t *testing.T) {
ctx := newTestContext()

t.Run("token exists in request", func(t *testing.T) {
gitSessionSecret = []byte("16_byte_key_XXXX")
githubSessionSecret = []byte("16_byte_key_XXXX")

derivedKey, err := deriveEncryptionKeyForContext(ctx)
require.NoError(t, err)

wantToken := &oauth2.Token{
AccessToken: "my_access_token",
TokenType: "my_token_type",
Expand All @@ -144,36 +152,41 @@ func Test_tokenFromRequest(t *testing.T) {

// The type of request here doesn't matter.
req := connect.NewRequest(&vcsv1.GetFileRequest{})
req.Header().Add("Cookie", testEncodeCookie(t, wantToken).String())
req.Header().Add("Cookie", testEncodeCookie(t, derivedKey, wantToken).String())

gotToken, err := tokenFromRequest(req)
gotToken, err := tokenFromRequest(ctx, req)
require.NoError(t, err)
require.Equal(t, *wantToken, *gotToken)
})

t.Run("token does not exist in request", func(t *testing.T) {
gitSessionSecret = []byte("16_byte_key_XXXX")
githubSessionSecret = []byte("16_byte_key_XXXX")
wantErr := "failed to read cookie GitSession: http: named cookie not present"

// The type of request here doesn't matter.
req := connect.NewRequest(&vcsv1.GetFileRequest{})

_, err := tokenFromRequest(req)
_, err := tokenFromRequest(ctx, req)
require.Error(t, err)
require.EqualError(t, err, wantErr)
})
}

func Test_encodeToken(t *testing.T) {
gitSessionSecret = []byte("16_byte_key_XXXX")
githubSessionSecret = []byte("16_byte_key_XXXX")
ctx := newTestContext()

derivedKey, err := deriveEncryptionKeyForContext(ctx)
require.NoError(t, err)

token := &oauth2.Token{
AccessToken: "my_access_token",
TokenType: "my_token_type",
RefreshToken: "my_refresh_token",
Expiry: time.Unix(1713298947, 0).UTC(), // 2024-04-16T20:22:27.346Z
}

got, err := encodeToken(token)
got, err := encodeToken(token, derivedKey)
require.NoError(t, err)
require.Equal(t, sessionCookieName, got.Name)
require.NotEmpty(t, got.Value)
Expand All @@ -183,7 +196,11 @@ func Test_encodeToken(t *testing.T) {
}

func Test_decodeToken(t *testing.T) {
gitSessionSecret = []byte("16_byte_key_XXXX")
githubSessionSecret = []byte("16_byte_key_XXXX")

ctx := newTestContext()
derivedKey, err := deriveEncryptionKeyForContext(ctx)
require.NoError(t, err)

t.Run("valid token", func(t *testing.T) {
want := &oauth2.Token{
Expand All @@ -192,9 +209,9 @@ func Test_decodeToken(t *testing.T) {
RefreshToken: "my_refresh_token",
Expiry: time.Unix(1713298947, 0).UTC(), // 2024-04-16T20:22:27.346Z
}
cookie := testEncodeCookie(t, want)
cookie := testEncodeCookie(t, derivedKey, want)

got, err := decodeToken(cookie.Value)
got, err := decodeToken(cookie.Value, derivedKey)
require.NoError(t, err)
require.Equal(t, want, got)
})
Expand All @@ -206,43 +223,91 @@ func Test_decodeToken(t *testing.T) {
RefreshToken: "my_refresh_token",
Expiry: time.Unix(1713298947, 0).UTC(), // 2024-04-16T20:22:27.346Z
}
cookie := testEncodeLegacyCookie(t, want)
cookie := testEncodeLegacyCookie(t, derivedKey, want)

got, err := decodeToken(cookie.Value)
got, err := decodeToken(cookie.Value, derivedKey)
require.NoError(t, err)
require.Equal(t, want, got)
})

t.Run("invalid base64 encoding", func(t *testing.T) {
illegalBase64Encoding := "xx==="

_, err := decodeToken(illegalBase64Encoding)
_, err := decodeToken(illegalBase64Encoding, derivedKey)
require.Error(t, err)
require.EqualError(t, err, "illegal base64 data at input byte 4")
})

t.Run("invalid json encoding", func(t *testing.T) {
illegalJSON := base64.StdEncoding.EncodeToString([]byte("illegal json value"))

_, err := decodeToken(illegalJSON)
_, err := decodeToken(illegalJSON, derivedKey)
require.Error(t, err)
require.EqualError(t, err, "invalid character 'i' looking for beginning of value")
})
}

func testEncodeCookie(t *testing.T, token *oauth2.Token) *http.Cookie {
func Test_tenantIsolation(t *testing.T) {
githubSessionSecret = []byte("16_byte_key_XXXX")

var (
ctxA = newTestContextWithTenantID("tenant_a")
ctxB = newTestContextWithTenantID("tenant_b")
)

derivedKeyA, err := deriveEncryptionKeyForContext(ctxA)
require.NoError(t, err)

encodedTokenA, err := encodeToken(&oauth2.Token{
AccessToken: "so_secret",
}, derivedKeyA)
require.NoError(t, err)

req := connect.NewRequest(&vcsv1.GetFileRequest{})
req.Header().Add("Cookie", encodedTokenA.String())

tA, err := tokenFromRequest(ctxA, req)
require.NoError(t, err)
require.Equal(t, "so_secret", tA.AccessToken)

_, err = tokenFromRequest(ctxB, req)
require.ErrorContains(t, err, "message authentication failed")

}

func Test_StillCompatbile(t *testing.T) {
githubSessionSecret = []byte("16_byte_key_XXXX")

ctx := newTestContextWithTenantID("tenant_a")
req := connect.NewRequest(&vcsv1.GetFileRequest{})
req.Header().Add("Cookie", "GitSession=eyJtZXRhZGF0YSI6Im12N0d1OHlIanZxdWdQMmF5TnJaYXd1SXNyQXFmUUVIMVhGS1RkejVlZWtob1NRV1JUM3hVZGRuMndUemhQZ05oWktRVkpjcVh5SVJDSnFmTTV3WTJyNmR3R21rZkRhL2FORjhRZ0lJcU1oa1hPbGFEdXNwcFE9PSJ9Cg==")

realToken, err := tokenFromRequest(ctx, req)
require.NoError(t, err)
require.Equal(t, "so_secret", realToken.AccessToken)
}

func newTestContext() context.Context {
return newTestContextWithTenantID("test_tenant_id")
}

func newTestContextWithTenantID(tenantID string) context.Context {
return tenant.InjectTenantID(context.Background(), tenantID)
}

func testEncodeCookie(t *testing.T, key []byte, token *oauth2.Token) *http.Cookie {
t.Helper()

encoded, err := encodeToken(token)
encoded, err := encodeToken(token, key)
require.NoError(t, err)

return encoded
}

func testEncodeLegacyCookie(t *testing.T, token *oauth2.Token) *http.Cookie {
func testEncodeLegacyCookie(t *testing.T, key []byte, token *oauth2.Token) *http.Cookie {
t.Helper()

encrypted, err := encryptToken(token, gitSessionSecret)
encrypted, err := encryptToken(token, key)
require.NoError(t, err)

return &http.Cookie{
Expand Down

0 comments on commit 9e2bb77

Please sign in to comment.