diff --git a/go.mod b/go.mod index 21bd0e9ab..41f241729 100644 --- a/go.mod +++ b/go.mod @@ -60,7 +60,7 @@ require ( require ( filippo.io/edwards25519 v1.1.0 // indirect github.com/Azure/azure-sdk-for-go v68.0.0+incompatible // indirect - github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.16.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 // indirect github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect github.com/DataDog/zstd v1.5.6 // indirect diff --git a/go.sum b/go.sum index 05953d790..41524fd3a 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,8 @@ github.com/Azure/azure-sdk-for-go v68.0.0+incompatible h1:fcYLmCpyNYRnvJbPerq7U0 github.com/Azure/azure-sdk-for-go v68.0.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0 h1:nyQWyZvwGTvunIMxi1Y9uXkcyr+I7TeNrr/foo4Kpk8= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0/go.mod h1:l38EPgmsp71HHLq9j7De57JcKOWPyhrsW1Awm1JS6K0= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.16.0 h1:JZg6HRh6W6U4OLl6lk7BZ7BLisIzM9dG1R50zUk9C/M= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.16.0/go.mod h1:YL1xnZ6QejvQHWJrX/AvhFl4WW4rqHVoKspWNVwFk0M= github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0 h1:tfLQ34V6F7tVSwoTf/4lH5sE0o6eCJuNDTmH09nDpbc= github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0/go.mod h1:9kIvujWAA58nmPmWB1m23fyWic1kYZMxD9CxaWn4Qpg= github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos v1.1.0 h1:c726lgbwpwFBuj+Fyrwuh/vUilqFo+hUAOUNjsKj5DI= diff --git a/tools/walletextension/storage/database/cosmosdb/cosmosdb.go b/tools/walletextension/storage/database/cosmosdb/cosmosdb.go index 3672e7cb9..3260ab95c 100644 --- a/tools/walletextension/storage/database/cosmosdb/cosmosdb.go +++ b/tools/walletextension/storage/database/cosmosdb/cosmosdb.go @@ -13,7 +13,9 @@ import ( "github.com/ten-protocol/go-ten/go/common/viewingkey" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" + "github.com/ten-protocol/go-ten/tools/walletextension/common" "github.com/ten-protocol/go-ten/tools/walletextension/encryption" ) @@ -55,6 +57,12 @@ const ( USERS_CONTAINER_NAME = "users" ) +// userWithETag struct is used to store the user data along with its ETag +type userWithETag struct { + user dbcommon.GWUserDB + etag azcore.ETag +} + func NewCosmosDB(connectionString string, encryptionKey []byte) (*CosmosDB, error) { // Create encryptor encryptor, err := encryption.NewEncryptor(encryptionKey) @@ -126,7 +134,7 @@ func (c *CosmosDB) AddSessionKey(userID []byte, key common.GWSessionKey) error { if err != nil { return fmt.Errorf("failed to get user: %w", err) } - user.SessionKey = &dbcommon.GWSessionKeyDB{ + user.user.SessionKey = &dbcommon.GWSessionKeyDB{ PrivateKey: crypto.FromECDSA(key.PrivateKey.ExportECDSA()), Account: dbcommon.GWAccountDB{ AccountAddress: key.Account.Address.Bytes(), @@ -134,7 +142,7 @@ func (c *CosmosDB) AddSessionKey(userID []byte, key common.GWSessionKey) error { SignatureType: int(key.Account.SignatureType), }, } - return c.updateUser(ctx, user) + return c.updateUser(ctx, user.user) } func (c *CosmosDB) ActivateSessionKey(userID []byte, active bool) error { @@ -144,8 +152,8 @@ func (c *CosmosDB) ActivateSessionKey(userID []byte, active bool) error { if err != nil { return fmt.Errorf("failed to get user: %w", err) } - user.ActiveSK = active - return c.updateUser(ctx, user) + user.user.ActiveSK = active + return c.updateUser(ctx, user.user) } func (c *CosmosDB) RemoveSessionKey(userID []byte) error { @@ -155,8 +163,8 @@ func (c *CosmosDB) RemoveSessionKey(userID []byte) error { if err != nil { return fmt.Errorf("failed to get user: %w", err) } - user.SessionKey = nil - return c.updateUser(ctx, user) + user.user.SessionKey = nil + return c.updateUser(ctx, user.user) } func (c *CosmosDB) AddAccount(userID []byte, accountAddress []byte, signature []byte, signatureType viewingkey.SignatureType) error { @@ -173,9 +181,9 @@ func (c *CosmosDB) AddAccount(userID []byte, accountAddress []byte, signature [] Signature: signature, SignatureType: int(signatureType), } - user.Accounts = append(user.Accounts, newAccount) + user.user.Accounts = append(user.user.Accounts, newAccount) - return c.updateUser(ctx, user) + return c.updateUser(ctx, user.user) } func (c *CosmosDB) GetUser(userID []byte) (*common.GWUser, error) { @@ -183,51 +191,63 @@ func (c *CosmosDB) GetUser(userID []byte) (*common.GWUser, error) { if err != nil { return nil, err } - return user.ToGWUser() + return user.user.ToGWUser() } -func (c *CosmosDB) getUserDB(userID []byte) (dbcommon.GWUserDB, error) { +func (c *CosmosDB) getUserDB(userID []byte) (userWithETag, error) { keyString, partitionKey := c.dbKey(userID) ctx := context.Background() itemResponse, err := c.usersContainer.ReadItem(ctx, partitionKey, keyString, nil) if err != nil { - return dbcommon.GWUserDB{}, err + return userWithETag{}, err } var doc EncryptedDocument err = json.Unmarshal(itemResponse.Value, &doc) if err != nil { - return dbcommon.GWUserDB{}, fmt.Errorf("failed to unmarshal document: %w", err) + return userWithETag{}, fmt.Errorf("failed to unmarshal document: %w", err) } data, err := c.encryptor.Decrypt(doc.Data) if err != nil { - return dbcommon.GWUserDB{}, fmt.Errorf("failed to decrypt data: %w", err) + return userWithETag{}, fmt.Errorf("failed to decrypt data: %w", err) } var user dbcommon.GWUserDB err = json.Unmarshal(data, &user) if err != nil { - return dbcommon.GWUserDB{}, fmt.Errorf("failed to unmarshal user data: %w", err) + return userWithETag{}, fmt.Errorf("failed to unmarshal user data: %w", err) } - return user, nil + return userWithETag{user: user, etag: itemResponse.ETag}, nil } func (c *CosmosDB) updateUser(ctx context.Context, user dbcommon.GWUserDB) error { - keyString, partitionKey := c.dbKey(user.UserId) + // Attempt to update without retries + currentUser, err := c.getUserDB(user.UserId) + if err != nil { + return fmt.Errorf("failed to get current user state: %w", err) + } + keyString, partitionKey := c.dbKey(user.UserId) encryptedDoc, err := c.createEncryptedDoc(user, keyString) if err != nil { return fmt.Errorf("failed to marshal updated document: %w", err) } - // Replace the item in the container - _, err = c.usersContainer.ReplaceItem(ctx, partitionKey, keyString, encryptedDoc, nil) + options := &azcosmos.ItemOptions{ + IfMatchEtag: ¤tUser.etag, + } + + _, err = c.usersContainer.ReplaceItem(ctx, partitionKey, keyString, encryptedDoc, options) if err != nil { - return fmt.Errorf("failed to update user with new account: %w", err) + if strings.Contains(err.Error(), "Precondition Failed") { + return fmt.Errorf("ETag mismatch: the user document was modified by another process") + } + return fmt.Errorf("failed to update user: %w", err) } + return nil }