Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[management] Refactor User JWT group sync #2690

Merged
merged 11 commits into from
Oct 4, 2024
284 changes: 184 additions & 100 deletions management/server/account.go

Large diffs are not rendered by default.

169 changes: 113 additions & 56 deletions management/server/account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")

accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
accountID, err := manager.GetAccountIDByUserID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.Domain)
require.NoError(t, err, "create init user failed")

initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
Expand Down Expand Up @@ -671,17 +671,16 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
userId := "user-id"
domain := "test.domain"

initAccount := newAccountWithId(context.Background(), "", userId, domain)
_ = newAccountWithId(context.Background(), "", userId, domain)
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")

accountID := initAccount.Id
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userId, accountID, domain)
accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain)
require.NoError(t, err, "create init user failed")
// as initAccount was created without account id we have to take the id after account initialization
// that happens inside the GetAccountIDByUserOrAccountID where the id is getting generated
// that happens inside the GetAccountIDByUserID where the id is getting generated
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
initAccount, err = manager.Store.GetAccount(context.Background(), accountID)
initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get init account failed")

claims := jwtclaims.AuthorizationClaims{
Expand Down Expand Up @@ -885,7 +884,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
}
}

func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
func TestAccountManager_GetAccountByUserID(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
Expand All @@ -894,7 +893,7 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {

userId := "test_user"

accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userId, "", "")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, "")
if err != nil {
t.Fatal(err)
}
Expand All @@ -903,14 +902,13 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
return
}

_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "")
if err != nil {
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountID)
}
exists, err := manager.Store.AccountExists(context.Background(), LockingStrengthShare, accountID)
assert.NoError(t, err)
assert.True(t, exists, "expected to get existing account after creation using userid")

_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", "", "")
_, err = manager.GetAccountIDByUserID(context.Background(), "", "")
if err == nil {
t.Errorf("expected an error when user and account IDs are empty")
t.Errorf("expected an error when user ID is empty")
}
}

Expand Down Expand Up @@ -1669,7 +1667,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")

accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account")

settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
Expand All @@ -1684,7 +1682,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")

_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
_, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account")

key, err := wgtypes.GenerateKey()
Expand All @@ -1696,7 +1694,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
})
require.NoError(t, err, "unable to add peer")

accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to get the account")

account, err := manager.Store.GetAccount(context.Background(), accountID)
Expand Down Expand Up @@ -1742,7 +1740,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")

accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account")

key, err := wgtypes.GenerateKey()
Expand Down Expand Up @@ -1770,7 +1768,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
},
}

accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
accountID, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to get the account")

account, err := manager.Store.GetAccount(context.Background(), accountID)
Expand All @@ -1790,7 +1788,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")

_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
_, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account")

key, err := wgtypes.GenerateKey()
Expand All @@ -1802,7 +1800,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
})
require.NoError(t, err, "unable to add peer")

accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to get the account")

account, err := manager.Store.GetAccount(context.Background(), accountID)
Expand Down Expand Up @@ -1850,7 +1848,7 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")

accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account")

updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
Expand All @@ -1861,9 +1859,6 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)

accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "")
require.NoError(t, err, "unable to get account by ID")

settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "unable to get account settings")

Expand Down Expand Up @@ -2199,8 +2194,12 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
}

func TestAccount_SetJWTGroups(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")

// create a new account
account := &Account{
Id: "accountID",
Peers: map[string]*nbpeer.Peer{
"peer1": {ID: "peer1", Key: "key1", UserID: "user1"},
"peer2": {ID: "peer2", Key: "key2", UserID: "user1"},
Expand All @@ -2211,62 +2210,120 @@ func TestAccount_SetJWTGroups(t *testing.T) {
Groups: map[string]*group.Group{
"group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{}},
},
Settings: &Settings{GroupsPropagationEnabled: true},
Settings: &Settings{GroupsPropagationEnabled: true, JWTGroupsEnabled: true, JWTGroupsClaimName: "groups"},
Users: map[string]*User{
"user1": {Id: "user1"},
"user2": {Id: "user2"},
"user1": {Id: "user1", AccountID: "accountID"},
"user2": {Id: "user2", AccountID: "accountID"},
},
}

assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account")

t.Run("empty jwt groups", func(t *testing.T) {
updated := account.SetJWTGroups("user1", []string{})
assert.False(t, updated, "account should not be updated")
assert.Empty(t, account.Users["user1"].AutoGroups, "auto groups must be empty")
claims := jwtclaims.AuthorizationClaims{
UserId: "user1",
Raw: jwt.MapClaims{"groups": []interface{}{}},
}
err := manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")

user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Empty(t, user.AutoGroups, "auto groups must be empty")
})

t.Run("jwt match existing api group", func(t *testing.T) {
updated := account.SetJWTGroups("user1", []string{"group1"})
assert.False(t, updated, "account should not be updated")
assert.Equal(t, 0, len(account.Users["user1"].AutoGroups))
assert.Equal(t, account.Groups["group1"].Issued, group.GroupIssuedAPI, "group should be api issued")
claims := jwtclaims.AuthorizationClaims{
UserId: "user1",
Raw: jwt.MapClaims{"groups": []interface{}{"group1"}},
}
err := manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")

user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0)

group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
})

t.Run("jwt match existing api group in user auto groups", func(t *testing.T) {
account.Users["user1"].AutoGroups = []string{"group1"}
assert.NoError(t, manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, account.Users["user1"]))

updated := account.SetJWTGroups("user1", []string{"group1"})
assert.False(t, updated, "account should not be updated")
assert.Equal(t, 1, len(account.Users["user1"].AutoGroups))
assert.Equal(t, account.Groups["group1"].Issued, group.GroupIssuedAPI, "group should be api issued")
claims := jwtclaims.AuthorizationClaims{
UserId: "user1",
Raw: jwt.MapClaims{"groups": []interface{}{"group1"}},
}
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")

user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1)

group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
})

t.Run("add jwt group", func(t *testing.T) {
updated := account.SetJWTGroups("user1", []string{"group1", "group2"})
assert.True(t, updated, "account should be updated")
assert.Len(t, account.Groups, 2, "new group should be added")
assert.Len(t, account.Users["user1"].AutoGroups, 2, "new group should be added")
assert.Contains(t, account.Groups, account.Users["user1"].AutoGroups[0], "groups must contain group2 from user groups")
claims := jwtclaims.AuthorizationClaims{
UserId: "user1",
Raw: jwt.MapClaims{"groups": []interface{}{"group1", "group2"}},
}
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")

user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
})

t.Run("existed group not update", func(t *testing.T) {
updated := account.SetJWTGroups("user1", []string{"group2"})
assert.False(t, updated, "account should not be updated")
assert.Len(t, account.Groups, 2, "groups count should not be changed")
claims := jwtclaims.AuthorizationClaims{
UserId: "user1",
Raw: jwt.MapClaims{"groups": []interface{}{"group2"}},
}
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")

user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
})

t.Run("add new group", func(t *testing.T) {
updated := account.SetJWTGroups("user2", []string{"group1", "group3"})
assert.True(t, updated, "account should be updated")
assert.Len(t, account.Groups, 3, "new group should be added")
assert.Len(t, account.Users["user2"].AutoGroups, 1, "new group should be added")
assert.Contains(t, account.Groups, account.Users["user2"].AutoGroups[0], "groups must contain group3 from user groups")
claims := jwtclaims.AuthorizationClaims{
UserId: "user2",
Raw: jwt.MapClaims{"groups": []interface{}{"group1", "group3"}},
}
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")

groups, err := manager.Store.GetAccountGroups(context.Background(), "accountID")
assert.NoError(t, err)
assert.Len(t, groups, 3, "new group3 should be added")

user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user2")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1, "new group should be added")
})

t.Run("remove all JWT groups", func(t *testing.T) {
updated := account.SetJWTGroups("user1", []string{})
assert.True(t, updated, "account should be updated")
assert.Len(t, account.Users["user1"].AutoGroups, 1, "only non-JWT groups should remain")
assert.Contains(t, account.Users["user1"].AutoGroups, "group1", " group1 should still be present")
claims := jwtclaims.AuthorizationClaims{
UserId: "user1",
Raw: jwt.MapClaims{"groups": []interface{}{}},
}
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")

user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain")
assert.Contains(t, user.AutoGroups, "group1", " group1 should still be present")
})
}

Expand Down
12 changes: 6 additions & 6 deletions management/server/mock_server/account_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ type MockAccountManager struct {
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error)
GetAccountIDByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (string, error)
GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error)
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
Expand Down Expand Up @@ -194,14 +194,14 @@ func (am *MockAccountManager) CreateSetupKey(
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
}

// GetAccountIDByUserOrAccountID mock implementation of GetAccountIDByUserOrAccountID from server.AccountManager interface
func (am *MockAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userId, accountId, domain string) (string, error) {
if am.GetAccountIDByUserOrAccountIdFunc != nil {
return am.GetAccountIDByUserOrAccountIdFunc(ctx, userId, accountId, domain)
// GetAccountIDByUserID mock implementation of GetAccountIDByUserID from server.AccountManager interface
func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userId, domain string) (string, error) {
if am.GetAccountIDByUserIdFunc != nil {
return am.GetAccountIDByUserIdFunc(ctx, userId, domain)
}
return "", status.Errorf(
codes.Unimplemented,
"method GetAccountIDByUserOrAccountID is not implemented",
"method GetAccountIDByUserID is not implemented",
)
}

Expand Down
Loading
Loading