From 09693a600d099ea1b0ad4c57905dee1d9e177867 Mon Sep 17 00:00:00 2001 From: Marceau Lecomte Date: Fri, 14 Jun 2024 19:47:12 +0200 Subject: [PATCH] Encrypt wallet shares with client stored key Using metadata --- go.mod | 1 - server/database/models.go | 14 ++- server/database/query.sql.go | 51 +++++----- server/handlers.go | 33 ++++--- server/main.go | 27 +++--- server/sqlc/query.sql | 4 +- server/sqlc/schema.sql | 4 +- server/vault/wallet.go | 164 ++++++++++++++++++++++++-------- test/integration/tss_test.go | 27 ++---- test/integration/wallet_test.go | 4 +- utils/types/context.go | 3 + 11 files changed, 207 insertions(+), 125 deletions(-) create mode 100644 utils/types/context.go diff --git a/go.mod b/go.mod index d72bed7..e1f0412 100644 --- a/go.mod +++ b/go.mod @@ -19,7 +19,6 @@ require ( github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pkg/errors v0.9.1 github.com/rs/cors v1.7.0 - github.com/tabbed/pqtype v0.0.0-00010101000000-000000000000 golang.org/x/crypto v0.23.0 golang.org/x/mobile v0.0.0-20230922142353-e2f452493d57 google.golang.org/protobuf v1.34.1 diff --git a/server/database/models.go b/server/database/models.go index e2b535d..16835e3 100644 --- a/server/database/models.go +++ b/server/database/models.go @@ -4,9 +4,7 @@ package database -import ( - "encoding/json" -) +import () type Device struct { ID int64 @@ -21,9 +19,9 @@ type User struct { } type Wallet struct { - ID int64 - UserID int64 - PublicAddress string - Share string - Params json.RawMessage + ID int64 + UserID int64 + PublicAddress string + EncryptedDkgResults []byte + Nonce []byte } diff --git a/server/database/query.sql.go b/server/database/query.sql.go index 1e2a724..c3ebe6e 100644 --- a/server/database/query.sql.go +++ b/server/database/query.sql.go @@ -8,9 +8,6 @@ package database import ( "context" "database/sql" - "encoding/json" - - "github.com/tabbed/pqtype" ) const addDevice = `-- name: AddDevice :one @@ -55,33 +52,33 @@ func (q *Queries) AddUser(ctx context.Context, foreignKey string) (User, error) } const addWallet = `-- name: AddWallet :one -INSERT INTO wallets (user_id, public_address, share, params) +INSERT INTO wallets (user_id, public_address, encrypted_dkg_results, nonce) VALUES ($1, $2, $3, $4) ON CONFLICT DO NOTHING -RETURNING id, user_id, public_address, share, params +RETURNING id, user_id, public_address, encrypted_dkg_results, nonce ` type AddWalletParams struct { - UserId int64 - PublicAddress string - Share string - Params json.RawMessage + UserId int64 + PublicAddress string + EncryptedDkgResults []byte + Nonce []byte } func (q *Queries) AddWallet(ctx context.Context, arg AddWalletParams) (Wallet, error) { row := q.db.QueryRowContext(ctx, addWallet, arg.UserId, arg.PublicAddress, - arg.Share, - arg.Params, + arg.EncryptedDkgResults, + arg.Nonce, ) var i Wallet err := row.Scan( &i.ID, &i.UserID, &i.PublicAddress, - &i.Share, - &i.Params, + &i.EncryptedDkgResults, + &i.Nonce, ) return i, err } @@ -161,18 +158,18 @@ func (q *Queries) GetUserDevices(ctx context.Context, userid int64) ([]Device, e } const getUserSigningParameters = `-- name: GetUserSigningParameters :one -SELECT wallets.id, wallets.user_id, wallets.public_address, wallets.share, wallets.params +SELECT wallets.id, wallets.user_id, wallets.public_address, wallets.encrypted_dkg_results, wallets.nonce FROM users LEFT JOIN wallets ON users.id = wallets.user_id WHERE users.foreign_key = $1 ` type GetUserSigningParametersRow struct { - ID sql.NullInt64 - UserID sql.NullInt64 - PublicAddress sql.NullString - Share sql.NullString - Params pqtype.NullRawMessage + ID sql.NullInt64 + UserID sql.NullInt64 + PublicAddress sql.NullString + EncryptedDkgResults []byte + Nonce []byte } func (q *Queries) GetUserSigningParameters(ctx context.Context, foreignkey string) (GetUserSigningParametersRow, error) { @@ -182,14 +179,14 @@ func (q *Queries) GetUserSigningParameters(ctx context.Context, foreignkey strin &i.ID, &i.UserID, &i.PublicAddress, - &i.Share, - &i.Params, + &i.EncryptedDkgResults, + &i.Nonce, ) return i, err } const getUserWallets = `-- name: GetUserWallets :many -SELECT id, user_id, public_address, share, params FROM wallets +SELECT id, user_id, public_address, encrypted_dkg_results, nonce FROM wallets WHERE user_id = $1 ` @@ -206,8 +203,8 @@ func (q *Queries) GetUserWallets(ctx context.Context, userid int64) ([]Wallet, e &i.ID, &i.UserID, &i.PublicAddress, - &i.Share, - &i.Params, + &i.EncryptedDkgResults, + &i.Nonce, ); err != nil { return nil, err } @@ -223,7 +220,7 @@ func (q *Queries) GetUserWallets(ctx context.Context, userid int64) ([]Wallet, e } const getWalletByAddress = `-- name: GetWalletByAddress :one -SELECT id, user_id, public_address, share, params FROM wallets +SELECT id, user_id, public_address, encrypted_dkg_results, nonce FROM wallets WHERE public_address = $1 ` @@ -234,8 +231,8 @@ func (q *Queries) GetWalletByAddress(ctx context.Context, publicaddress string) &i.ID, &i.UserID, &i.PublicAddress, - &i.Share, - &i.Params, + &i.EncryptedDkgResults, + &i.Nonce, ) return i, err } diff --git a/server/handlers.go b/server/handlers.go index e1e4b42..4e759ac 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -20,8 +20,6 @@ import ( "nhooyr.io/websocket" ) -type ContextKey string - // identityMiddleware is a middleware used to get the userId from auth provider based on a generic bearer token provided by the client // used by /identify and /authorize func (server *Server) identityMiddleware(next http.Handler) http.Handler { @@ -64,7 +62,7 @@ func (server *Server) identityMiddleware(next http.Handler) http.Handler { // Store userId in context for next request in the stack - ctx = context.WithValue(ctx, ContextKey("userId"), userId) + ctx = context.WithValue(ctx, types.ContextKey("userId"), userId) next.ServeHTTP(w, r.WithContext(ctx)) }) } @@ -72,6 +70,7 @@ func (server *Server) identityMiddleware(next http.Handler) http.Handler { // ServeWasm is responsible for serving the wasm module func (server *Server) ServeWasm(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/wasm") + w.Header().Set("Access-Control-Allow-Origin", server._config.ClientOrigin) w.Write(server._wasm) } @@ -79,7 +78,7 @@ func (server *Server) ServeWasm(w http.ResponseWriter, r *http.Request) { // It uses identityMiddleware to get the userId from auth provider based on a generic bearer token provided by the client, then returns it func (server *Server) IdentifyHandler(w http.ResponseWriter, r *http.Request) { // Get userId from context - userId, ok := r.Context().Value(ContextKey("userId")).(string) + userId, ok := r.Context().Value(types.ContextKey("userId")).(string) if !ok { log.Println("Authorization info not found") http.Error(w, "Authorization info not found", http.StatusUnauthorized) @@ -100,14 +99,14 @@ type tokenParameters struct { // It then creates an access token linked to that userId, stores it in cache and returns it func (server *Server) AuthorizeHandler(w http.ResponseWriter, r *http.Request) { // Get userId from context - userId, ok := r.Context().Value(ContextKey("userId")).(string) + userId, ok := r.Context().Value(types.ContextKey("userId")).(string) if !ok { http.Error(w, "Authorization info not found", http.StatusUnauthorized) return } // Get metadata from context - metadata, ok := r.Context().Value(ContextKey("metadata")).(string) + metadata, ok := r.Context().Value(types.ContextKey("metadata")).(string) if !ok { http.Error(w, "Authorization info not found", http.StatusUnauthorized) return @@ -161,9 +160,9 @@ func (server *Server) authMiddleware(next http.Handler) http.Handler { // Add the userId and token to the context ctx := r.Context() - ctx = context.WithValue(ctx, ContextKey("userId"), tokenParams.userId) - ctx = context.WithValue(ctx, ContextKey("metadata"), tokenParams.metadata) - ctx = context.WithValue(ctx, ContextKey("token"), tokenParam[0]) + ctx = context.WithValue(ctx, types.ContextKey("userId"), tokenParams.userId) + ctx = context.WithValue(ctx, types.ContextKey("metadata"), tokenParams.metadata) + ctx = context.WithValue(ctx, types.ContextKey("token"), tokenParam[0]) next.ServeHTTP(w, r.WithContext(ctx)) }) } @@ -173,13 +172,13 @@ func (server *Server) authMiddleware(next http.Handler) http.Handler { // stores the result of dkg in DB (new wallet) func (server *Server) DkgHandler(w http.ResponseWriter, r *http.Request) { // Get userId and access token from context - userId, ok := r.Context().Value(ContextKey("userId")).(string) + userId, ok := r.Context().Value(types.ContextKey("userId")).(string) if !ok { http.Error(w, "Authorization info not found", http.StatusUnauthorized) return } - token, ok := r.Context().Value(ContextKey("token")).(string) + token, ok := r.Context().Value(types.ContextKey("token")).(string) if !ok { http.Error(w, "Authorization info not found", http.StatusUnauthorized) return @@ -299,7 +298,7 @@ func (server *Server) DkgHandler(w http.ResponseWriter, r *http.Request) { // In the future: verifies that client has been able to store wallet; if not, remove in DB // Idea: "validated" status for the wallet, which becomes True after calling DkgTwo; if False, can be overwritten func (server *Server) DkgTwoHandler(w http.ResponseWriter, r *http.Request) { - token, ok := r.Context().Value(ContextKey("token")).(string) + token, ok := r.Context().Value(types.ContextKey("token")).(string) if !ok { http.Error(w, "Authorization info not found", http.StatusUnauthorized) return @@ -328,14 +327,14 @@ func (server *Server) DkgTwoHandler(w http.ResponseWriter, r *http.Request) { // requires a hex-encoded message to be signed (provided in URL parameter) func (server *Server) SignHandler(w http.ResponseWriter, r *http.Request) { // Get userId and access token from context - userId, ok := r.Context().Value(ContextKey("userId")).(string) + userId, ok := r.Context().Value(types.ContextKey("userId")).(string) if !ok { // If there's no userID in the context, report an error and return. http.Error(w, "Authorization info not found", http.StatusUnauthorized) return } - token, ok := r.Context().Value(ContextKey("token")).(string) + token, ok := r.Context().Value(types.ContextKey("token")).(string) if !ok { // If there's no token in the context, report an error and return. http.Error(w, "Authorization info not found", http.StatusUnauthorized) @@ -439,14 +438,14 @@ func (server *Server) RecoverHandler(w http.ResponseWriter, r *http.Request) { } // Get userId and access token from context - userId, ok := r.Context().Value(ContextKey("userId")).(string) + userId, ok := r.Context().Value(types.ContextKey("userId")).(string) if !ok { // If there's no userID in the context, report an error and return. http.Error(w, "Authorization info not found", http.StatusUnauthorized) return } - token, ok := r.Context().Value(ContextKey("token")).(string) + token, ok := r.Context().Value(types.ContextKey("token")).(string) if !ok { // If there's no token in the context, report an error and return. http.Error(w, "Authorization info not found", http.StatusUnauthorized) @@ -569,7 +568,7 @@ func (server *Server) headerMiddleware(next http.Handler) http.Handler { // Remove the prefix and use the remaining part as the context key key := strings.ToLower(strings.TrimPrefix(name, headerPrefix)) // Add the first header value to the context - ctx = context.WithValue(ctx, ContextKey(key), values[0]) + ctx = context.WithValue(ctx, types.ContextKey(key), values[0]) } } diff --git a/server/main.go b/server/main.go index e397489..3c9bdb8 100644 --- a/server/main.go +++ b/server/main.go @@ -15,7 +15,6 @@ import ( "github.com/go-chi/chi" "github.com/go-chi/chi/middleware" "github.com/patrickmn/go-cache" - "github.com/rs/cors" ) type Server struct { @@ -54,22 +53,13 @@ func NewServer(vault Vault, config *Config, wasmBinary []byte, logging bool) *Se // Router - _cors := cors.New(cors.Options{ - AllowedOrigins: []string{server._config.ClientOrigin}, - AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, - // AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"}, - ExposedHeaders: []string{"Link"}, - AllowCredentials: true, - MaxAge: 300, - }) - r := chi.NewRouter() // global middlewares if logging { r.Use(middleware.Logger) } - r.Use(_cors.Handler) + r.Use(server.corsMiddleware) // r.Use(cors.Default().Handler) r.Use(server.headerMiddleware) @@ -154,3 +144,18 @@ type Config struct { SupabaseUrl string SupabaseApiKey string } + +func (server *Server) corsMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", server._config.ClientOrigin) + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, M-METADATA") + + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + next.ServeHTTP(w, r) + }) +} diff --git a/server/sqlc/query.sql b/server/sqlc/query.sql index 7af3b95..67956ed 100644 --- a/server/sqlc/query.sql +++ b/server/sqlc/query.sql @@ -10,8 +10,8 @@ ON CONFLICT DO NOTHING RETURNING *; -- name: AddWallet :one -INSERT INTO wallets (user_id, public_address, share, params) -VALUES (sqlc.arg('UserId'), sqlc.arg('PublicAddress'), sqlc.arg('Share'), sqlc.arg('Params')) +INSERT INTO wallets (user_id, public_address, encrypted_dkg_results, nonce) +VALUES (sqlc.arg('UserId'), sqlc.arg('PublicAddress'), sqlc.arg('EncryptedDkgResults'), sqlc.arg('Nonce')) ON CONFLICT DO NOTHING RETURNING *; diff --git a/server/sqlc/schema.sql b/server/sqlc/schema.sql index d6da7d7..ef7c226 100644 --- a/server/sqlc/schema.sql +++ b/server/sqlc/schema.sql @@ -9,8 +9,8 @@ CREATE TABLE wallets ( id BIGSERIAL PRIMARY KEY, user_id bigint NOT NULL REFERENCES users(id) ON DELETE RESTRICT ON UPDATE RESTRICT, public_address text NOT NULL DEFAULT '', - share text NOT NULL DEFAULT '', - params jsonb NOT NULL + encrypted_dkg_results bytea NOT NULL DEFAULT E'\\x', + nonce bytea NOT NULL DEFAULT E'\\x' ); -- CREATE UNIQUE INDEX wallet_identifier ON public.wallets USING btree (user_id, public_address); diff --git a/server/vault/wallet.go b/server/vault/wallet.go index 54179f3..c21baa9 100644 --- a/server/vault/wallet.go +++ b/server/vault/wallet.go @@ -2,7 +2,13 @@ package vault import ( "context" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/hex" "encoding/json" + "errors" + "io" "log" "github.com/getmeemaw/meemaw/server/database" @@ -18,80 +24,160 @@ func NewVault(queries *database.Queries) *Vault { return &Vault{_queries: queries} } +// WalletExists verifies if a wallet already exists +func (vault *Vault) WalletExists(ctx context.Context, userId string) error { + _, err := vault._queries.GetUserByForeignKey(ctx, userId) + return err +} + +/////// + +// StoreWallet upserts a wallet (if it already exists, it does nothing, no error returned) +// Tested in integration tests (with throw away db) +func (vault *Vault) StoreWallet(ctx context.Context, userAgent string, foreignKey string, dkgResults *tss.DkgResult) (string, error) { + + // Encode dkgResults to json + jsonDkgResults, err := json.Marshal(dkgResults) + if err != nil { + log.Println("could not marshal dkgResults to json") + return "", err + } + + // Generate client key : + clientKey := make([]byte, 32) // return to client + if _, err := io.ReadFull(rand.Reader, clientKey); err != nil { + return "", err + } + + // Encrypt dkgResults with client key (so that server shares are not fully exposed in case of a breach) + nonceClient, ClientEncryptedDkgResults, err := encryptAES(jsonDkgResults, clientKey) + if err != nil { + log.Println("error while encrypting with client key:", err) + return "", err + } + + // Store in DB + user, err := vault._queries.AddUser(ctx, foreignKey) + if err != nil { + return "", err + } + + walletQueryParams := database.AddWalletParams{ + UserId: user.ID, + PublicAddress: dkgResults.Address, + EncryptedDkgResults: ClientEncryptedDkgResults, + Nonce: nonceClient, + } + + wallet, err := vault._queries.AddWallet(ctx, walletQueryParams) + if err != nil { + return "", err + } + + deviceQueryParams := database.AddDeviceParams{ + UserId: user.ID, + WalletId: wallet.ID, + UserAgent: userAgent, + } + + _, err = vault._queries.AddDevice(ctx, deviceQueryParams) + if err != nil { + return "", err + } + + return hex.EncodeToString(clientKey), nil +} + // RetrieveWallet retrieves a wallet from DB based on the userID of the user (which is a loose foreign key, the format will depend on the auth provider) // Tested in integration tests (with throw away db) func (vault *Vault) RetrieveWallet(ctx context.Context, foreignKey string) (*tss.DkgResult, error) { + + // get dkgResults res, err := vault._queries.GetUserSigningParameters(ctx, foreignKey) if err != nil { log.Println("error getting signing params:", err) return nil, &types.ErrNotFound{} } - // log.Println("GetUserSigningParameters res.Params.RawMessage:", string(res.Params.RawMessage)) + // get client key from context + clientKeyStr, ok := ctx.Value(types.ContextKey("metadata")).(string) + if !ok { + return nil, errors.New("could not find customer identifier") + } + + clientKey, err := hex.DecodeString(clientKeyStr) + if err != nil { + log.Println("error hex decoding clientKey") + return nil, err + } - var signingParams tss.SigningParameters - err = json.Unmarshal(res.Params.RawMessage, &signingParams) + // decrypt dkg results + jsonDkgResults, err := decryptAES(res.Nonce, res.EncryptedDkgResults, clientKey) if err != nil { - log.Println("error unmarshaling signing params:", err) + log.Println("could not decrypt AES using clientKey:", err) return nil, err } - dkgResult := tss.DkgResult{ - Pubkey: signingParams.Pubkey, - BKs: signingParams.BKs, - Share: res.Share.String, - Address: res.PublicAddress.String, + // decode json into *tss.DkgResult + dkgResult := &tss.DkgResult{} + err = json.Unmarshal(jsonDkgResults, dkgResult) + if err != nil { + log.Println("could not unmarshal jsonDkgResults") + return nil, err } - return &dkgResult, nil + return dkgResult, nil } -// StoreWallet upserts a wallet (if it already exists, it does nothing, no error returned) -// Tested in integration tests (with throw away db) -func (vault *Vault) StoreWallet(ctx context.Context, userAgent string, userId string, dkgResults *tss.DkgResult) (string, error) { - user, err := vault._queries.AddUser(ctx, userId) +// Encrypt a plaintext message using AES-GCM. +func encryptAES(plaintext, key []byte) ([]byte, []byte, error) { + block, err := aes.NewCipher(key) if err != nil { - return "", err + return nil, nil, err } - signingParams := tss.SigningParameters{ - Pubkey: dkgResults.Pubkey, - BKs: dkgResults.BKs, + aesGCM, err := cipher.NewGCM(block) + if err != nil { + return nil, nil, err } - params, err := json.Marshal(signingParams) + // Generate a random nonce. Ensure it is unique for each encryption with the same key. + nonce, err := generateRandomBytes(aesGCM.NonceSize()) if err != nil { - return "", err + return nil, nil, err } - walletQueryParams := database.AddWalletParams{ - UserId: user.ID, - PublicAddress: dkgResults.Address, - Share: dkgResults.Share, - Params: params, - } + // Encrypt the plaintext using the nonce. + ciphertext := aesGCM.Seal(nil, nonce, plaintext, nil) + return nonce, ciphertext, nil +} - wallet, err := vault._queries.AddWallet(ctx, walletQueryParams) +// Decrypt a ciphertext message using AES-GCM. +func decryptAES(nonce, ciphertext, key []byte) ([]byte, error) { + block, err := aes.NewCipher(key) if err != nil { - return "", err + return nil, err } - deviceQueryParams := database.AddDeviceParams{ - UserId: user.ID, - WalletId: wallet.ID, - UserAgent: userAgent, + aesGCM, err := cipher.NewGCM(block) + if err != nil { + return nil, err } - _, err = vault._queries.AddDevice(ctx, deviceQueryParams) + // Decrypt the ciphertext using the nonce. + plaintext, err := aesGCM.Open(nil, nonce, ciphertext, nil) if err != nil { - return "", err + return nil, err } - return "", nil + return plaintext, nil } -// WalletExists verifies if a wallet already exists -func (vault *Vault) WalletExists(ctx context.Context, userId string) error { - _, err := vault._queries.GetUserByForeignKey(ctx, userId) - return err +// Generate random bytes using crypto/rand, which is secure for cryptographic purposes. +func generateRandomBytes(size int) ([]byte, error) { + bytes := make([]byte, size) + if _, err := io.ReadFull(rand.Reader, bytes); err != nil { + return nil, err + } + return bytes, nil } diff --git a/test/integration/tss_test.go b/test/integration/tss_test.go index 4e68c47..3336985 100644 --- a/test/integration/tss_test.go +++ b/test/integration/tss_test.go @@ -256,37 +256,28 @@ func dkgTestProcess(parameters map[string]string) (*tss.DkgResult, *tss.DkgResul log.Println("client.Dkg with host:", host, " and authData:", authData) log.Printf("%q", host) - dkgResultClient, _, err := client.Dkg(host, authData) // update to test ret value + dkgResultClient, clientKeyClient, err := client.Dkg(host, authData) // update to test ret value if err != nil { log.Println("Error client.Dkg:", err) return nil, nil, err } + log.Println("clientKeyClient:", clientKeyClient) + // // debug : leave time to manually check db status // log.Println("wallet stored, check in db !") // time.Sleep(2 * time.Minute) time.Sleep(1 * time.Second) // Give it 1 second to make sure it's in DB. Sometimes the test fails because of race conditions. - res, err := queries.GetUserSigningParameters(context.Background(), parameters["userIdUsed"]) - if err != nil { - log.Println("Error getting user signing parameters:", err) - return nil, nil, err - } + ctx = context.WithValue(ctx, types.ContextKey("metadata"), clientKeyClient) - var signingParamsServer tss.SigningParameters - err = json.Unmarshal(res.Params.RawMessage, &signingParamsServer) + dkgResultServer, err := _server.Vault().RetrieveWallet(ctx, parameters["userIdStored"]) if err != nil { + log.Println("Error retrieveWallet:", err) return nil, nil, err } - dkgResultServer := &tss.DkgResult{ - Pubkey: signingParamsServer.Pubkey, - BKs: signingParamsServer.BKs, - Share: res.Share.String, - Address: res.PublicAddress.String, - } - return dkgResultClient, dkgResultServer, nil } @@ -316,6 +307,8 @@ func signingTestProcess(parameters map[string]string) (*tss.Signature, error) { _server := server.NewServer(vault, &config, nil, logging) // _server.Start() // No need to start, we test the handler directly + var metadata string + // Insert wallet in DB (if required) if len(parameters["dkgResultServerStr"]) > 0 { var dkgResultServer tss.DkgResult @@ -327,7 +320,7 @@ func signingTestProcess(parameters map[string]string) (*tss.Signature, error) { log.Printf("dkgResultServer: %+v\n", dkgResultServer) - _, err = _server.Vault().StoreWallet(ctx, parameters["userAgent"], parameters["userIdStored"], &dkgResultServer) + metadata, err = _server.Vault().StoreWallet(ctx, parameters["userAgent"], parameters["userIdStored"], &dkgResultServer) if err != nil { return nil, err } @@ -346,7 +339,7 @@ func signingTestProcess(parameters map[string]string) (*tss.Signature, error) { log.Println("client.Sign with host:", host, " and authData:", authData) log.Printf("%q", host) - signature, err := client.Sign(host, []byte("test"), parameters["dkgResultClientStr"], "", authData) + signature, err := client.Sign(host, []byte("test"), parameters["dkgResultClientStr"], metadata, authData) if err != nil { return nil, err } diff --git a/test/integration/wallet_test.go b/test/integration/wallet_test.go index c191e88..34cc3b3 100644 --- a/test/integration/wallet_test.go +++ b/test/integration/wallet_test.go @@ -50,12 +50,14 @@ func TestStoreAndRetrieveWallet(t *testing.T) { testDescription = "test 1 (happy case)" successful := true - _, err = _server.Vault().StoreWallet(ctx, "userAgent", "my-user-id-retrieve-one", &dkgResult) + metadata, err := _server.Vault().StoreWallet(ctx, "userAgent", "my-user-id-retrieve-one", &dkgResult) if err != nil { successful = false t.Errorf("Failed "+testDescription+": could not store dkgResult: %+v\n", dkgResult) } + ctx = context.WithValue(ctx, types.ContextKey("metadata"), metadata) + dkgResultRetrieved, err = _server.Vault().RetrieveWallet(ctx, "my-user-id-retrieve-one") if err != nil { successful = false diff --git a/utils/types/context.go b/utils/types/context.go new file mode 100644 index 0000000..74fd3a9 --- /dev/null +++ b/utils/types/context.go @@ -0,0 +1,3 @@ +package types + +type ContextKey string