Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Split out a separate endpoint to complete SSO registration #9262

Merged
merged 5 commits into from
Feb 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/9262.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve the user experience of setting up an account via single-sign on.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

obviously, there is more to follow.

2 changes: 2 additions & 0 deletions synapse/app/homeserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.rest.synapse.client.pick_idp import PickIdpResource
from synapse.rest.synapse.client.pick_username import pick_username_resource
from synapse.rest.synapse.client.sso_register import SsoRegisterResource
from synapse.rest.well_known import WellKnownResource
from synapse.server import HomeServer
from synapse.storage import DataStore
Expand Down Expand Up @@ -192,6 +193,7 @@ def _configure_named_resource(self, name, compress=False):
"/_synapse/admin": AdminRestResource(self),
"/_synapse/client/pick_username": pick_username_resource(self),
"/_synapse/client/pick_idp": PickIdpResource(self),
"/_synapse/client/sso_register": SsoRegisterResource(self),
}
)

Expand Down
81 changes: 66 additions & 15 deletions synapse/handlers/sso.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@
from typing_extensions import NoReturn, Protocol

from twisted.web.http import Request
from twisted.web.iweb import IRequest

from synapse.api.constants import LoginType
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http import get_request_user_agent
from synapse.http.server import respond_with_html
from synapse.http.server import respond_with_html, respond_with_redirect
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
from synapse.util.async_helpers import Linearizer
Expand Down Expand Up @@ -141,6 +142,9 @@ class UsernameMappingSession:
# expiry time for the session, in milliseconds
expiry_time_ms = attr.ib(type=int)

# choices made by the user
chosen_localpart = attr.ib(type=Optional[str], default=None)


# the HTTP cookie used to track the mapping session id
USERNAME_MAPPING_SESSION_COOKIE_NAME = b"username_mapping_session"
Expand Down Expand Up @@ -647,6 +651,25 @@ async def complete_sso_ui_auth_request(
)
respond_with_html(request, 200, html)

def get_mapping_session(self, session_id: str) -> UsernameMappingSession:
"""Look up the given username mapping session

If it is not found, raises a SynapseError with an http code of 400

Args:
session_id: session to look up
Returns:
active mapping session
Raises:
SynapseError if the session is not found/has expired
"""
self._expire_old_sessions()
session = self._username_mapping_sessions.get(session_id)
if session:
return session
logger.info("Couldn't find session id %s", session_id)
raise SynapseError(400, "unknown session")

async def check_username_availability(
self, localpart: str, session_id: str,
) -> bool:
Expand All @@ -663,12 +686,7 @@ async def check_username_availability(

# make sure that there is a valid mapping session, to stop people dictionary-
# scanning for accounts

self._expire_old_sessions()
session = self._username_mapping_sessions.get(session_id)
if not session:
logger.info("Couldn't find session id %s", session_id)
raise SynapseError(400, "unknown session")
self.get_mapping_session(session_id)

logger.info(
"[session %s] Checking for availability of username %s",
Expand Down Expand Up @@ -696,16 +714,33 @@ async def handle_submit_username_request(
localpart: localpart requested by the user
session_id: ID of the username mapping session, extracted from a cookie
"""
self._expire_old_sessions()
session = self._username_mapping_sessions.get(session_id)
if not session:
logger.info("Couldn't find session id %s", session_id)
raise SynapseError(400, "unknown session")
session = self.get_mapping_session(session_id)

# update the session with the user's choices
session.chosen_localpart = localpart

# we're done; now we can register the user
respond_with_redirect(request, b"/_synapse/client/sso_register")

async def register_sso_user(self, request: Request, session_id: str) -> None:
"""Called once we have all the info we need to register a new user.

logger.info("[session %s] Registering localpart %s", session_id, localpart)
Does so and serves an HTTP response

Args:
request: HTTP request
session_id: ID of the username mapping session, extracted from a cookie
"""
session = self.get_mapping_session(session_id)

logger.info(
"[session %s] Registering localpart %s",
session_id,
session.chosen_localpart,
)

attributes = UserAttributes(
localpart=localpart,
localpart=session.chosen_localpart,
display_name=session.display_name,
emails=session.emails,
)
Expand All @@ -720,7 +755,12 @@ async def handle_submit_username_request(
request.getClientIP(),
)

logger.info("[session %s] Registered userid %s", session_id, user_id)
logger.info(
"[session %s] Registered userid %s with attributes %s",
session_id,
user_id,
attributes,
)

# delete the mapping session and the cookie
del self._username_mapping_sessions[session_id]
Expand Down Expand Up @@ -751,3 +791,14 @@ def _expire_old_sessions(self):
for session_id in to_expire:
logger.info("Expiring mapping session %s", session_id)
del self._username_mapping_sessions[session_id]


def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
"""Extract the session ID from the cookie

Raises a SynapseError if the cookie isn't found
"""
session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME)
if not session_id:
raise SynapseError(code=400, msg="missing session_id")
return session_id.decode("ascii", errors="replace")
7 changes: 7 additions & 0 deletions synapse/http/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,13 @@ def set_clickjacking_protection_headers(request: Request):
request.setHeader(b"Content-Security-Policy", b"frame-ancestors 'none';")


def respond_with_redirect(request: Request, url: bytes) -> None:
"""Write a 302 response to the request, if it is still alive."""
logger.debug("Redirect to %s", url.decode("utf-8"))
request.redirect(url)
finish_request(request)


def finish_request(request: Request):
""" Finish writing the response to the request.

Expand Down
16 changes: 6 additions & 10 deletions synapse/rest/synapse/client/pick_username.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

import pkg_resources
Expand All @@ -20,8 +21,7 @@
from twisted.web.resource import Resource
from twisted.web.static import File

from synapse.api.errors import SynapseError
from synapse.handlers.sso import USERNAME_MAPPING_SESSION_COOKIE_NAME
from synapse.handlers.sso import get_username_mapping_session_cookie_from_request
from synapse.http.server import DirectServeHtmlResource, DirectServeJsonResource
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
Expand Down Expand Up @@ -61,12 +61,10 @@ def __init__(self, hs: "HomeServer"):
async def _async_render_GET(self, request: Request):
localpart = parse_string(request, "username", required=True)

session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME)
if not session_id:
raise SynapseError(code=400, msg="missing session_id")
session_id = get_username_mapping_session_cookie_from_request(request)

is_available = await self._sso_handler.check_username_availability(
localpart, session_id.decode("ascii", errors="replace")
localpart, session_id
)
return 200, {"available": is_available}

Expand All @@ -79,10 +77,8 @@ def __init__(self, hs: "HomeServer"):
async def _async_render_POST(self, request: SynapseRequest):
localpart = parse_string(request, "username", required=True)

session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME)
if not session_id:
raise SynapseError(code=400, msg="missing session_id")
session_id = get_username_mapping_session_cookie_from_request(request)

await self._sso_handler.handle_submit_username_request(
request, localpart, session_id.decode("ascii", errors="replace")
request, localpart, session_id
)
50 changes: 50 additions & 0 deletions synapse/rest/synapse/client/sso_register.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import TYPE_CHECKING

from twisted.web.http import Request

from synapse.api.errors import SynapseError
from synapse.handlers.sso import get_username_mapping_session_cookie_from_request
from synapse.http.server import DirectServeHtmlResource

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)


class SsoRegisterResource(DirectServeHtmlResource):
"""A resource which completes SSO registration

This resource gets mounted at /_synapse/client/sso_register, and is shown
after we collect username and/or consent for a new SSO user. It (finally) registers
the user, and confirms redirect to the client
"""

def __init__(self, hs: "HomeServer"):
super().__init__()
self._sso_handler = hs.get_sso_handler()

async def _async_render_GET(self, request: Request) -> None:
try:
session_id = get_username_mapping_session_cookie_from_request(request)
except SynapseError as e:
logger.warning("Error fetching session cookie: %s", e)
self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
richvdh marked this conversation as resolved.
Show resolved Hide resolved
return
await self._sso_handler.register_sso_user(request, session_id)
14 changes: 13 additions & 1 deletion tests/rest/client/v1/test_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
from synapse.rest.synapse.client.pick_idp import PickIdpResource
from synapse.rest.synapse.client.pick_username import pick_username_resource
from synapse.rest.synapse.client.sso_register import SsoRegisterResource
from synapse.types import create_requester

from tests import unittest
Expand Down Expand Up @@ -1215,6 +1216,7 @@ def create_resource_dict(self) -> Dict[str, Resource]:

d = super().create_resource_dict()
d["/_synapse/client/pick_username"] = pick_username_resource(self.hs)
d["/_synapse/client/sso_register"] = SsoRegisterResource(self.hs)
d["/_synapse/oidc"] = OIDCResource(self.hs)
return d

Expand Down Expand Up @@ -1253,7 +1255,7 @@ def test_username_picker(self):
self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000)

# Now, submit a username to the username picker, which should serve a redirect
# back to the client
# to the completion page
submit_path = picker_url + "/submit"
content = urlencode({b"username": b"bobby"}).encode("utf8")
chan = self.make_request(
Expand All @@ -1270,6 +1272,16 @@ def test_username_picker(self):
)
self.assertEqual(chan.code, 302, chan.result)
location_headers = chan.headers.getRawHeaders("Location")

# send a request to the completion page, which should 302 to the client redirectUrl
chan = self.make_request(
"GET",
path=location_headers[0],
custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
)
self.assertEqual(chan.code, 302, chan.result)
location_headers = chan.headers.getRawHeaders("Location")

# ensure that the returned location matches the requested redirect URL
path, query = location_headers[0].split("?", 1)
self.assertEqual(path, "https://x")
Expand Down