Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a race when registering via email 3pid #16827

Draft
wants to merge 13 commits into
base: develop
Choose a base branch
from
Draft
1 change: 1 addition & 0 deletions changelog.d/16827.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a race when registering via email 3pid where 2 different user ids would be created.
9 changes: 9 additions & 0 deletions synapse/handlers/worker_lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#
#

import logging
import random
from types import TracebackType
from typing import (
Expand Down Expand Up @@ -48,6 +49,8 @@
from synapse.logging.opentracing import opentracing
from synapse.server import HomeServer

logger = logging.getLogger(__name__)


# This lock is used to avoid creating an event while we are purging the room.
# We take a read lock when creating an event, and a write one when purging a room.
Expand Down Expand Up @@ -244,9 +247,14 @@ async def __aenter__(self) -> None:
timeout=self._get_next_retry_interval(),
reactor=self.reactor,
)
# Let's reset retry interval since we got notified, we
# should only increase it if we hit the previous one
self._retry_interval = 0.1
except Exception:
pass

logger.warn(f"lock taken: {self.lock_name}, {self.lock_key}")

return await self._inner_lock.__aenter__()

async def __aexit__(
Expand All @@ -261,6 +269,7 @@ async def __aexit__(

try:
r = await self._inner_lock.__aexit__(exc_type, exc, tb)
logger.warn(f"lock released: {self.lock_name}, {self.lock_key}")
finally:
self._lock_span.__exit__(exc_type, exc, tb)

Expand Down
22 changes: 21 additions & 1 deletion synapse/rest/client/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@

logger = logging.getLogger(__name__)

USER_REGISTRATION_LOCK_NAME = "user_registration"


class EmailRegisterRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/register/email/requestToken$")
Expand Down Expand Up @@ -425,6 +427,7 @@ def __init__(self, hs: "HomeServer"):
self.macaroon_gen = hs.get_macaroon_generator()
self.ratelimiter = hs.get_registration_ratelimiter()
self.password_policy_handler = hs.get_password_policy_handler()
self._worker_lock_handler = hs.get_worker_locks_handler()
self.clock = hs.get_clock()
self.password_auth_provider = hs.get_password_auth_provider()
self._registration_enabled = self.hs.config.registration.enable_registration
Expand Down Expand Up @@ -516,6 +519,23 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"An access token should not be provided on requests to /register (except if type is m.login.application_service)",
)

# Take a global lock when doing user registration to avoid races,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am only protecting normal user registration with the lock because I think _do_appservice_registration and _do_guest_registration are safe but I am not sure to be honest, especially for the appservice one.

Out of safety we may protect the whole code in on_POST, opinions welcome.

# for example when doing 3pid email binding.
async with self._worker_lock_handler.acquire_lock(
USER_REGISTRATION_LOCK_NAME, ""
):
return await self._do_user_register(
desired_username, client_addr, body, should_issue_refresh_token, request
)

async def _do_user_register(
self,
desired_username: Optional[str],
address: str,
body: JsonDict,
should_issue_refresh_token: bool,
request: SynapseRequest,
) -> Tuple[int, JsonDict]:
# == Normal User Registration == (everyone else)
if not self._registration_enabled:
raise SynapseError(403, "Registration has been disabled", Codes.FORBIDDEN)
Expand Down Expand Up @@ -706,7 +726,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
guest_access_token=guest_access_token,
threepid=threepid,
default_display_name=display_name,
address=client_addr,
address=address,
user_agent_ips=entries,
)
# Necessary due to auth checks prior to the threepid being
Expand Down
91 changes: 90 additions & 1 deletion tests/rest/client/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
#
import datetime
import os
from typing import Any, Dict, List, Tuple
import re
from typing import Any, Dict, List, Optional, Tuple
from unittest.mock import AsyncMock

import pkg_resources
Expand Down Expand Up @@ -1283,3 +1284,91 @@ def test_GET_ratelimiting(self) -> None:
f"{self.url}?token={token}",
)
self.assertEqual(channel.code, 200, msg=channel.result)


class EmailRegisterRestServletTestCase(unittest.HomeserverTestCase):
servlets = [register.register_servlets]

def make_homeserver(
self, reactor: ThreadedMemoryReactorClock, clock: Clock
) -> HomeServer:
hs = super().make_homeserver(reactor, clock)

async def send_email(
email_address: str,
subject: str,
app_name: str,
html: str,
text: str,
additional_headers: Optional[Dict[str, str]] = None,
) -> None:
self.email_attempts.append(text)

self.email_attempts: List[str] = []
hs.get_send_email_handler().send_email = send_email # type: ignore[method-assign]
return hs

@unittest.override_config(
{
"public_baseurl": "https://test_server",
"registrations_require_3pid": ["email"],
"disable_msisdn_registration": True,
"email": {
"smtp_host": "mail_server",
"smtp_port": 2525,
"notif_from": "sender@host",
},
}
)
def test_email_3pid_registration_race(self) -> None:
channel = self.make_request("POST", b"register", {"password": "password"})
session = channel.json_body["session"]

# request a token to be sent by email for validation
channel = self.make_request(
"POST",
b"register/email/requestToken",
{
"client_secret": "client_secret",
"email": "email@email",
"send_attempt": 1,
},
)
sid = channel.json_body["sid"]

email_text = self.email_attempts[0]
match = re.search("https://test_server(.*)", email_text)
assert match is not None
validation_url = match.group(1)

# "Click" the link in the email to validate the adress
self.make_request("GET", validation_url.encode("utf-8"))

# launch 2 simultaneous register request, only one account
# should be created after that.
register_content = {
"auth": {
"session": session,
"threepid_creds": {
"client_secret": "client_secret",
"sid": sid,
},
"type": "m.login.email.identity",
},
"password": "password",
}
register1_channel = self.make_request(
"POST", b"register", register_content, await_result=False
)
register2_channel = self.make_request(
"POST", b"register", register_content, await_result=False
)
while (
not register1_channel.is_finished() or not register2_channel.is_finished()
):
self.pump()

self.assertEqual(
register1_channel.json_body["user_id"],
register2_channel.json_body["user_id"],
)
Loading