diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py index e3aff93ae17a..eabdcd466d3a 100644 --- a/synapse/rest/admin/registration_tokens.py +++ b/synapse/rest/admin/registration_tokens.py @@ -176,7 +176,34 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: 400, "expiry_time must not be in the past", Codes.INVALID_PARAM ) - await self.store.create_registration_token(token, uses_allowed, expiry_time) + created = await self.store.create_registration_token( + token, uses_allowed, expiry_time + ) + + if "token" not in body: + # The token was generated. If it could not be created because + # that token already exists, then try a few more times before + # reporting a failure. + i = 0 + while not created and i < 3: + token = "".join(random.choices(self.allowed_chars, k=length)) + created = await self.store.create_registration_token( + token, uses_allowed, expiry_time + ) + i += 1 + if not created: + raise SynapseError( + 500, + "The generated token already exists. Try again with a greater length.", + Codes.UNKNOWN, + ) + + elif not created: + # The token was specified in the request, but it already exists + # so could not be created. + raise SynapseError( + 400, f"Token already exists: {token}", Codes.INVALID_PARAM + ) resp = { "token": token, diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 4f09ba649cb2..b7df9bce22b8 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1328,7 +1328,7 @@ async def get_one_registration_token(self, token: str) -> Optional[Dict[str, Any async def create_registration_token( self, token: str, uses_allowed: Optional[int], expiry_time: Optional[int] - ) -> None: + ) -> bool: """Create a new registration token. Used by the admin API. Args: @@ -1354,9 +1354,8 @@ def _create_registration_token_txn(txn): ) if row is not None: - raise SynapseError( - 400, f"Token already exists: {token}", Codes.INVALID_PARAM - ) + # Token already exists + return False self.db_pool.simple_insert_txn( txn, @@ -1370,6 +1369,8 @@ def _create_registration_token_txn(txn): }, ) + return True + return await self.db_pool.runInteraction( "create_registration_token", _create_registration_token_txn )