Skip to content

Commit

Permalink
Disconnect banned users
Browse files Browse the repository at this point in the history
  • Loading branch information
sesposito committed Mar 20, 2023
1 parent 9fadf79 commit 3954113
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ The format is based on [keep a changelog](http://keepachangelog.com) and this pr
- Token and credentials as inputs on unlink operations are now optional.
- Improve runtime IAP operation errors to include provider payload in error message.
- Build with Go 1.19.2 release.
- Disconnect users when they are banned from the console or runtime functions.

### Fixed
- Observe the error if returned in storage list errors in JavaScript runtime.
Expand Down
2 changes: 1 addition & 1 deletion server/console_account.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func (s *ConsoleServer) BanAccount(ctx context.Context, in *console.AccountId) (
return nil, status.Error(codes.InvalidArgument, "Cannot ban the system user.")
}

if err := BanUsers(ctx, s.logger, s.db, s.sessionCache, []uuid.UUID{userID}); err != nil {
if err := BanUsers(ctx, s.logger, s.db, s.config, s.sessionCache, s.sessionRegistry, s.tracker, []uuid.UUID{userID}); err != nil {
// Error logged in the core function above.
return nil, status.Error(codes.Internal, "An error occurred while trying to ban the user.")
}
Expand Down
14 changes: 12 additions & 2 deletions server/core_user.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ func DeleteUser(ctx context.Context, tx *sql.Tx, userID uuid.UUID) (int64, error
return res.RowsAffected()
}

func BanUsers(ctx context.Context, logger *zap.Logger, db *sql.DB, sessionCache SessionCache, ids []uuid.UUID) error {
func BanUsers(ctx context.Context, logger *zap.Logger, db *sql.DB, config Config, sessionCache SessionCache, sessionRegistry SessionRegistry, tracker Tracker, ids []uuid.UUID) error {
statements := make([]string, 0, len(ids))
params := make([]interface{}, 0, len(ids))
for i, id := range ids {
Expand All @@ -194,7 +194,17 @@ func BanUsers(ctx context.Context, logger *zap.Logger, db *sql.DB, sessionCache
return err
}

sessionCache.Ban(ids)
for _, id := range ids {
// Logout and disconnect.
if err = SessionLogout(config, sessionCache, id, "", ""); err != nil {
return err
}
for _, presence := range tracker.ListPresenceIDByStream(PresenceStream{Mode: StreamModeNotifications, Subject: id}) {
if err = sessionRegistry.Disconnect(ctx, presence.SessionID); err != nil {
return err
}
}
}

return nil
}
Expand Down
2 changes: 1 addition & 1 deletion server/runtime_go_nakama.go
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@ func (n *RuntimeGoNakamaModule) UsersBanId(ctx context.Context, userIDs []string
ids = append(ids, id)
}

return BanUsers(ctx, n.logger, n.db, n.sessionCache, ids)
return BanUsers(ctx, n.logger, n.db, n.config, n.sessionCache, n.sessionRegistry, n.tracker, ids)
}

// @group users
Expand Down
2 changes: 1 addition & 1 deletion server/runtime_javascript_nakama.go
Original file line number Diff line number Diff line change
Expand Up @@ -2101,7 +2101,7 @@ func (n *runtimeJavascriptNakamaModule) usersBanId(r *goja.Runtime) func(goja.Fu
userIDs = append(userIDs, uid)
}

err := BanUsers(n.ctx, n.logger, n.db, n.sessionCache, userIDs)
err := BanUsers(n.ctx, n.logger, n.db, n.config, n.sessionCache, n.sessionRegistry, n.tracker, userIDs)
if err != nil {
panic(r.NewGoError(fmt.Errorf("failed to ban users: %s", err.Error())))
}
Expand Down
2 changes: 1 addition & 1 deletion server/runtime_lua_nakama.go
Original file line number Diff line number Diff line change
Expand Up @@ -2837,7 +2837,7 @@ func (n *RuntimeLuaNakamaModule) usersBanId(l *lua.LState) int {
}

// Ban the user accounts.
err := BanUsers(l.Context(), n.logger, n.db, n.sessionCache, uids)
err := BanUsers(l.Context(), n.logger, n.db, n.config, n.sessionCache, n.sessionRegistry, n.tracker, uids)
if err != nil {
l.RaiseError(fmt.Sprintf("failed to ban users: %s", err.Error()))
return 0
Expand Down

0 comments on commit 3954113

Please sign in to comment.