Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Fix guest user registration with lots of client readers #7866

Merged
merged 5 commits into from
Jul 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/7866.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix 'Unable to find a suitable guest user ID' error when using multiple client_reader workers.
12 changes: 11 additions & 1 deletion scripts/synapse_port_db
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ from synapse.storage.data_stores.main.media_repository import (
)
from synapse.storage.data_stores.main.registration import (
RegistrationBackgroundUpdateStore,
find_max_generated_user_id_localpart,
)
from synapse.storage.data_stores.main.room import RoomBackgroundUpdateStore
from synapse.storage.data_stores.main.roommember import RoomMemberBackgroundUpdateStore
Expand Down Expand Up @@ -622,8 +623,10 @@ class Porter(object):
)
)

# Step 5. Do final post-processing
# Step 5. Set up sequences
self.progress.set_state("Setting up sequence generators")
await self._setup_state_group_id_seq()
await self._setup_user_id_seq()

self.progress.done()
except Exception as e:
Expand Down Expand Up @@ -793,6 +796,13 @@ class Porter(object):

return self.postgres_store.db.runInteraction("setup_state_group_id_seq", r)

def _setup_user_id_seq(self):
    """Restart the postgres `user_id_seq` sequence just past the largest
    generated user ID localpart.

    The current maximum is looked up with
    `find_max_generated_user_id_localpart`, and the sequence is restarted at
    that value plus one.

    Returns:
        The result of running the "setup_user_id_seq" interaction against
        `self.postgres_store.db`.
    """

    def _restart_sequence(txn):
        # The next ID handed out must be one past the largest already in use.
        restart_at = find_max_generated_user_id_localpart(txn) + 1
        txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (restart_at,))

    database = self.postgres_store.db
    return database.runInteraction("setup_user_id_seq", _restart_sequence)


##############################################
# The following is simply UI stuff
Expand Down
22 changes: 1 addition & 21 deletions synapse/handlers/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
)
from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, UserID, create_requester
from synapse.util.async_helpers import Linearizer

from ._base import BaseHandler

Expand All @@ -50,14 +49,7 @@ def __init__(self, hs):
self.user_directory_handler = hs.get_user_directory_handler()
self.identity_handler = self.hs.get_handlers().identity_handler
self.ratelimiter = hs.get_registration_ratelimiter()

self._next_generated_user_id = None

self.macaroon_gen = hs.get_macaroon_generator()

self._generate_user_id_linearizer = Linearizer(
name="_generate_user_id_linearizer"
)
self._server_notices_mxid = hs.config.server_notices_mxid

if hs.config.worker_app:
Expand Down Expand Up @@ -219,7 +211,7 @@ async def register_user(
if fail_count > 10:
raise SynapseError(500, "Unable to find a suitable guest user ID")

localpart = await self._generate_user_id()
localpart = await self.store.generate_user_id()
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
self.check_user_id_not_appservice_exclusive(user_id)
Expand Down Expand Up @@ -510,18 +502,6 @@ def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=Non
errcode=Codes.EXCLUSIVE,
)

async def _generate_user_id(self):
    """Hand out the next numeric localpart from an in-memory counter.

    The counter (`self._next_generated_user_id`) is seeded lazily from
    `self.store.find_next_generated_user_id_localpart()`. Seeding happens
    while holding `self._generate_user_id_linearizer`, and the counter is
    checked again once the linearizer is held so that it is only fetched
    while it is still unset.

    Returns:
        The allocated localpart, as a string.
    """
    if self._next_generated_user_id is None:
        with await self._generate_user_id_linearizer.queue(()):
            # Check again now that we hold the linearizer: the counter may
            # have been filled in while we were queueing.
            if self._next_generated_user_id is None:
                seed = await self.store.find_next_generated_user_id_localpart()
                self._next_generated_user_id = seed

    allocated = self._next_generated_user_id
    self._next_generated_user_id = allocated + 1
    return str(allocated)

def check_registration_ratelimit(self, address):
"""A simple helper method to check whether the registration rate limit has been hit
for a given IP address
Expand Down
65 changes: 36 additions & 29 deletions synapse/storage/data_stores/main/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import Database
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import UserID
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks

Expand All @@ -42,6 +44,10 @@ def __init__(self, database: Database, db_conn, hs):
self.config = hs.config
self.clock = hs.get_clock()

self._user_id_seq = build_sequence_generator(
database.engine, find_max_generated_user_id_localpart, "user_id_seq",
)

@cached()
def get_user_by_id(self, user_id):
return self.db.simple_select_one(
Expand Down Expand Up @@ -481,39 +487,17 @@ def _count_users(txn):
ret = yield self.db.runInteraction("count_real_users", _count_users)
return ret

@defer.inlineCallbacks
def find_next_generated_user_id_localpart(self):
"""
Gets the localpart of the next generated user ID.
async def generate_user_id(self) -> str:
"""Generate a suitable localpart for a guest user

Generated user IDs are integers, so we find the largest integer user ID
already taken and return that plus one.
Returns: a (hopefully) free localpart
"""

def _find_next_generated_user_id(txn):
# We bound between '@0' and '@a' to avoid pulling the entire table
# out.
txn.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'")

regex = re.compile(r"^@(\d+):")

max_found = 0

for (user_id,) in txn:
match = regex.search(user_id)
if match:
max_found = max(int(match.group(1)), max_found)

return max_found + 1

return (
(
yield self.db.runInteraction(
"find_next_generated_user_id", _find_next_generated_user_id
)
)
next_id = await self.db.runInteraction(
"generate_user_id", self._user_id_seq.get_next_id_txn
)

return str(next_id)

async def get_user_id_by_threepid(self, medium: str, address: str) -> Optional[str]:
"""Returns user id from threepid

Expand Down Expand Up @@ -1573,3 +1557,26 @@ def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
keyvalues={"user_id": user_id},
values={"expiration_ts_ms": expiration_ts, "email_sent": False},
)


def find_max_generated_user_id_localpart(cur: Cursor) -> int:
    """Find the largest integer localpart among existing user IDs.

    Generated user IDs have integer localparts, so this scans the `users`
    table for names of the form `@<digits>:...` and picks out the biggest
    number seen.

    Args:
        cur: cursor used to query the `users` table

    Returns:
        The largest integer localpart currently in use, or 0 if there are
        none.
    """
    numeric_localpart = re.compile(r"^@(\d+):")

    # Restricting names to the range ['@0', '@a') saves us from pulling the
    # entire table out; the regex then discards anything non-numeric.
    cur.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'")

    matches = (numeric_localpart.search(name) for (name,) in cur)
    return max((int(m.group(1)) for m in matches if m), default=0)
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Adds a postgres SEQUENCE for generating guest user IDs.
"""

from synapse.storage.data_stores.main.registration import (
find_max_generated_user_id_localpart,
)
from synapse.storage.engines import PostgresEngine


def run_create(cur, database_engine, *args, **kwargs):
    """Create the `user_id_seq` sequence when running on postgres.

    The sequence starts one past the largest generated user ID localpart
    reported by `find_max_generated_user_id_localpart`. For any engine that
    is not a `PostgresEngine` this does nothing.

    Args:
        cur: cursor to run the queries on
        database_engine: the engine in use; only `PostgresEngine` gets a
            sequence
    """
    if isinstance(database_engine, PostgresEngine):
        start_at = find_max_generated_user_id_localpart(cur) + 1
        cur.execute("CREATE SEQUENCE user_id_seq START WITH %s", (start_at,))


def run_upgrade(*args, **kwargs):
    """Do nothing: this delta has no upgrade step. All arguments are ignored."""
    return None
12 changes: 11 additions & 1 deletion synapse/storage/data_stores/state/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore
from synapse.storage.database import Database
from synapse.storage.state import StateFilter
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import StateMap
from synapse.util.caches.descriptors import cached
from synapse.util.caches.dictionary_cache import DictionaryCache
Expand Down Expand Up @@ -92,6 +94,14 @@ def __init__(self, database: Database, db_conn, hs):
"*stateGroupMembersCache*", 500000,
)

def get_max_state_group_txn(txn: Cursor):
    """Fetch the largest `id` in `state_groups` (0 if the table is empty)."""
    txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
    row = txn.fetchone()
    return row[0]

self._state_group_seq_gen = build_sequence_generator(
self.database_engine, get_max_state_group_txn, "state_group_id_seq"
)

@cached(max_entries=10000, iterable=True)
def get_state_group_delta(self, state_group):
"""Given a state group try to return a previous group and a delta between
Expand Down Expand Up @@ -386,7 +396,7 @@ def _store_state_group_txn(txn):
# AFAIK, this can never happen
raise Exception("current_state_ids cannot be None")

state_group = self.database_engine.get_next_state_group_id(txn)
state_group = self._state_group_seq_gen.get_next_id_txn(txn)

self.db.simple_insert_txn(
txn,
Expand Down
6 changes: 0 additions & 6 deletions synapse/storage/engines/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,6 @@ def is_connection_closed(self, conn: ConnectionType) -> bool:
def lock_table(self, txn, table: str) -> None:
...

@abc.abstractmethod
def get_next_state_group_id(self, txn) -> int:
    """Allocate an integer suitable for use as the ID of a new state group.

    Abstract: the body here is empty.

    Returns:
        An int that can be used as a new state_group ID.
    """

@property
@abc.abstractmethod
def server_version(self) -> str:
Expand Down
6 changes: 0 additions & 6 deletions synapse/storage/engines/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,6 @@ def is_connection_closed(self, conn):
def lock_table(self, txn, table):
    """Take an EXCLUSIVE MODE lock on `table` by running `LOCK TABLE` on `txn`.

    Args:
        txn: the transaction to execute the statement on
        table: name of the table to lock; it is interpolated directly into
            the SQL text
    """
    statement = "LOCK TABLE %s in EXCLUSIVE MODE" % (table,)
    txn.execute(statement)

def get_next_state_group_id(self, txn):
    """Allocate a new state_group ID from the `state_group_id_seq` sequence.

    Args:
        txn: the transaction to run `nextval` on

    Returns:
        The first column of the row returned by the `nextval` query.
    """
    txn.execute("SELECT nextval('state_group_id_seq')")
    row = txn.fetchone()
    return row[0]

@property
def server_version(self):
"""Returns a string giving the server version. For example: '8.1.5'
Expand Down
13 changes: 0 additions & 13 deletions synapse/storage/engines/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,19 +96,6 @@ def is_connection_closed(self, conn):
def lock_table(self, txn, table):
    """Do nothing: this engine's `lock_table` issues no statement on `txn`."""
    return None

def get_next_state_group_id(self, txn):
    """Allocate a new state_group ID from an in-memory counter.

    The counter is seeded on first use from the largest `id` in
    `state_groups` (0 if the table is empty), then incremented on every call.

    Args:
        txn: the transaction used to read the current maximum, if the
            counter has not been seeded yet

    Returns:
        An int that can be used as a new state_group ID.
    """
    # Application-level locking is used here: if we're using sqlite then we
    # are a single process synapse.
    with self._current_state_group_id_lock:
        if self._current_state_group_id is None:
            txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
            row = txn.fetchone()
            self._current_state_group_id = row[0]

        next_group_id = self._current_state_group_id + 1
        self._current_state_group_id = next_group_id
        return next_group_id

@property
def server_version(self):
"""Gets a string giving the server version. For example: '3.22.0'
Expand Down
8 changes: 4 additions & 4 deletions synapse/storage/util/id_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from typing_extensions import Deque

from synapse.storage.database import Database, LoggingTransaction
from synapse.storage.util.sequence import PostgresSequenceGenerator


class IdGenerator(object):
Expand Down Expand Up @@ -247,7 +248,6 @@ def __init__(
):
self._db = db
self._instance_name = instance_name
self._sequence_name = sequence_name

# We lock as some functions may be called from DB threads.
self._lock = threading.Lock()
Expand All @@ -260,6 +260,8 @@ def __init__(
# should be less than the minimum of this set (if not empty).
self._unfinished_ids = set() # type: Set[int]

self._sequence_gen = PostgresSequenceGenerator(sequence_name)

def _load_current_ids(
self, db_conn, table: str, instance_column: str, id_column: str
) -> Dict[str, int]:
Expand All @@ -283,9 +285,7 @@ def _load_current_ids(
return current_positions

def _load_next_id_txn(self, txn):
txn.execute("SELECT nextval(?)", (self._sequence_name,))
(next_id,) = txn.fetchone()
return next_id
return self._sequence_gen.get_next_id_txn(txn)

async def get_next(self):
"""
Expand Down
Loading