Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Abstract shared SSO code #8765

Merged
merged 7 commits into from
Nov 17, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 29 additions & 49 deletions synapse/handlers/oidc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
from twisted.web.client import readBody

from synapse.config import ConfigError
from synapse.http.server import respond_with_html
from synapse.handlers._base import BaseHandler
from synapse.handlers.sso import MappingException
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
Expand Down Expand Up @@ -83,17 +84,12 @@ def __str__(self):
return self.error


class MappingException(Exception):
"""Used to catch errors when mapping the UserInfo object
"""


class OidcHandler:
class OidcHandler(BaseHandler):
"""Handles requests related to the OpenID Connect login flow.
"""

def __init__(self, hs: "HomeServer"):
self.hs = hs
super().__init__(hs)
self._callback_url = hs.config.oidc_callback_url # type: str
self._scopes = hs.config.oidc_scopes # type: List[str]
self._user_profile_method = hs.config.oidc_user_profile_method # type: str
Expand All @@ -120,36 +116,13 @@ def __init__(self, hs: "HomeServer"):
self._http_client = hs.get_proxied_http_client()
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
self._datastore = hs.get_datastore()
self._clock = hs.get_clock()
self._hostname = hs.hostname # type: str
self._server_name = hs.config.server_name # type: str
self._macaroon_secret_key = hs.config.macaroon_secret_key
self._error_template = hs.config.sso_error_template

# identifier for the external_ids table
self._auth_provider_id = "oidc"

def _render_error(
self, request, error: str, error_description: Optional[str] = None
) -> None:
"""Render the error template and respond to the request with it.

This is used to show errors to the user. The template of this page can
be found under `synapse/res/templates/sso_error.html`.

Args:
request: The incoming request from the browser.
We'll respond with an HTML page describing the error.
error: A technical identifier for this error. Those include
well-known OAuth2/OIDC error types like invalid_request or
access_denied.
error_description: A human-readable description of the error.
"""
html = self._error_template.render(
error=error, error_description=error_description
)
respond_with_html(request, 400, html)
self._sso_handler = hs.get_sso_handler()

def _validate_metadata(self):
"""Verifies the provider metadata.
Expand Down Expand Up @@ -571,7 +544,7 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:

Since we might want to display OIDC-related errors in a user-friendly
way, we don't raise SynapseError from here. Instead, we call
``self._render_error`` which displays an HTML page for the error.
``self._sso_handler.render_error`` which displays an HTML page for the error.

Most of the OpenID Connect logic happens here:

Expand Down Expand Up @@ -609,7 +582,7 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
if error != "access_denied":
logger.error("Error from the OIDC provider: %s %s", error, description)

self._render_error(request, error, description)
self._sso_handler.render_error(request, error, description)
return

# otherwise, it is presumably a successful response. see:
Expand All @@ -619,7 +592,9 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
if session is None:
logger.info("No session cookie found")
self._render_error(request, "missing_session", "No session cookie found")
self._sso_handler.render_error(
request, "missing_session", "No session cookie found"
)
return

# Remove the cookie. There is a good chance that if the callback failed
Expand All @@ -637,7 +612,9 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
# Check for the state query parameter
if b"state" not in request.args:
logger.info("State parameter is missing")
self._render_error(request, "invalid_request", "State parameter is missing")
self._sso_handler.render_error(
request, "invalid_request", "State parameter is missing"
)
return

state = request.args[b"state"][0].decode()
Expand All @@ -651,17 +628,19 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
) = self._verify_oidc_session_token(session, state)
except MacaroonDeserializationException as e:
logger.exception("Invalid session")
self._render_error(request, "invalid_session", str(e))
self._sso_handler.render_error(request, "invalid_session", str(e))
return
except MacaroonInvalidSignatureException as e:
logger.exception("Could not verify session")
self._render_error(request, "mismatching_session", str(e))
self._sso_handler.render_error(request, "mismatching_session", str(e))
return

# Exchange the code with the provider
if b"code" not in request.args:
logger.info("Code parameter is missing")
self._render_error(request, "invalid_request", "Code parameter is missing")
self._sso_handler.render_error(
request, "invalid_request", "Code parameter is missing"
)
return

logger.debug("Exchanging code")
Expand All @@ -670,7 +649,7 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
token = await self._exchange_code(code)
except OidcError as e:
logger.exception("Could not exchange code")
self._render_error(request, e.error, e.error_description)
self._sso_handler.render_error(request, e.error, e.error_description)
return

logger.debug("Successfully obtained OAuth2 access token")
Expand All @@ -683,15 +662,15 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
userinfo = await self._fetch_userinfo(token)
except Exception as e:
logger.exception("Could not fetch userinfo")
self._render_error(request, "fetch_error", str(e))
self._sso_handler.render_error(request, "fetch_error", str(e))
return
else:
logger.debug("Extracting userinfo from id_token")
try:
userinfo = await self._parse_id_token(token, nonce=nonce)
except Exception as e:
logger.exception("Invalid id_token")
self._render_error(request, "invalid_token", str(e))
self._sso_handler.render_error(request, "invalid_token", str(e))
return

# Pull out the user-agent and IP from the request.
Expand All @@ -705,7 +684,7 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
)
except MappingException as e:
logger.exception("Could not map user")
self._render_error(request, "mapping_error", str(e))
self._sso_handler.render_error(request, "mapping_error", str(e))
return

# Mapping providers might not have get_extra_attributes: only call this
Expand Down Expand Up @@ -770,7 +749,7 @@ def _generate_oidc_session_token(
macaroon.add_first_party_caveat(
"ui_auth_session_id = %s" % (ui_auth_session_id,)
)
now = self._clock.time_msec()
now = self.clock.time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))

Expand Down Expand Up @@ -845,7 +824,7 @@ def _verify_expiry(self, caveat: str) -> bool:
if not caveat.startswith(prefix):
return False
expiry = int(caveat[len(prefix) :])
now = self._clock.time_msec()
now = self.clock.time_msec()
return now < expiry

async def _map_userinfo_to_user(
Expand Down Expand Up @@ -891,7 +870,7 @@ async def _map_userinfo_to_user(
remote_user_id,
)

registered_user_id = await self._datastore.get_user_by_external_id(
registered_user_id = await self.store.get_user_by_external_id(
self._auth_provider_id, remote_user_id,
)

Expand All @@ -917,8 +896,8 @@ async def _map_userinfo_to_user(

localpart = map_username_to_mxid_localpart(attributes["localpart"])

user_id = UserID(localpart, self._hostname).to_string()
users = await self._datastore.get_users_by_id_case_insensitive(user_id)
user_id = UserID(localpart, self.server_name).to_string()
users = await self.store.get_users_by_id_case_insensitive(user_id)
if users:
if self._allow_existing_users:
if len(users) == 1:
Expand All @@ -942,7 +921,8 @@ async def _map_userinfo_to_user(
default_display_name=attributes["display_name"],
user_agent_ips=(user_agent, ip_address),
)
await self._datastore.record_user_external_id(

await self.store.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id,
)
return registered_user_id
Expand Down
Loading