Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Refactor the CAS code to move the logic to a handler #7136

Merged
merged 7 commits into from
Mar 26, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/7136.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactored the CAS authentication logic to a separate class.
177 changes: 177 additions & 0 deletions synapse/handlers/cas_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import xml.etree.ElementTree as ET
from typing import Awaitable, Dict, Optional, Tuple

from six.moves import urllib

from twisted.web.client import PartialDownloadError

from synapse.api.errors import Codes, LoginError
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
from synapse.types import UserID, map_username_to_mxid_localpart

logger = logging.getLogger(__name__)


class CasHandler:
"""
Utility class for to handle the response from a CAS SSO service.

Args:
hs (synapse.server.HomeServer)
"""

def __init__(self, hs):
self._hostname = hs.hostname
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()

self._cas_server_url = hs.config.cas_server_url
self._cas_service_url = hs.config.cas_service_url
self._cas_displayname_attribute = hs.config.cas_displayname_attribute
self._cas_required_attributes = hs.config.cas_required_attributes

self._http_client = hs.get_proxied_http_client()

# cast to tuple for use with str.startswith
self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)

def _build_service_param(self, client_redirect_url: str) -> str:
return "%s%s?redirectUrl=%s" % (
self._cas_service_url,
"/_matrix/client/r0/login/cas/ticket",
urllib.parse.quote(client_redirect_url, safe=""),
)

def _handle_cas_response(
self, request: SynapseRequest, cas_response_body: str, client_redirect_url: str
) -> Awaitable:
clokep marked this conversation as resolved.
Show resolved Hide resolved
user, attributes = self._parse_cas_response(cas_response_body)
displayname = attributes.pop(self._cas_displayname_attribute, None)

for required_attribute, required_value in self._cas_required_attributes.items():
# If required attribute was not in CAS Response - Forbidden
if required_attribute not in attributes:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)

# Also need to check value
if required_value is not None:
actual_value = attributes[required_attribute]
# If required attribute value does not match expected - Forbidden
if required_value != actual_value:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)

return self._on_successful_auth(user, request, client_redirect_url, displayname)

def _parse_cas_response(
self, cas_response_body: str
) -> Tuple[str, Dict[str, Optional[str]]]:
user = None
attributes = {}
try:
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise Exception("root of CAS response is not serviceResponse")
success = root[0].tag.endswith("authenticationSuccess")
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
if child.tag.endswith("attributes"):
for attribute in child:
# ElementTree library expands the namespace in
# attribute tags to the full URL of the namespace.
# We don't care about namespace here and it will always
# be encased in curly braces, so we remove them.
tag = attribute.tag
if "}" in tag:
tag = tag.split("}")[1]
attributes[tag] = attribute.text
if user is None:
raise Exception("CAS response does not contain user")
except Exception:
logger.exception("Error parsing CAS response")
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
if not success:
raise LoginError(
401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
)
return user, attributes

async def _on_successful_auth(
self,
username: str,
request: SynapseRequest,
client_redirect_url: str,
user_display_name: Optional[str] = None,
) -> None:
"""Called once the user has successfully authenticated with the SSO.

Registers the user if necessary, and then returns a redirect (with
a login token) to the client.

Args:
username: the remote user id. We'll map this onto
something sane for a MXID localpath.

request: the incoming request from the browser. We'll
respond to it with a redirect.

client_redirect_url: the redirect_url the client gave us when
it first started the process.

user_display_name: if set, and we have to register a new user,
we will set their displayname to this.

Returns:
Completes once we have handled the request.
clokep marked this conversation as resolved.
Show resolved Hide resolved
"""
localpart = map_username_to_mxid_localpart(username)
user_id = UserID(localpart, self._hostname).to_string()
registered_user_id = await self._auth_handler.check_user_exists(user_id)
if not registered_user_id:
registered_user_id = await self._registration_handler.register_user(
localpart=localpart, default_display_name=user_display_name
)

self._auth_handler.complete_sso_login(
registered_user_id, request, client_redirect_url
)

def handle_redirect_request(self, client_redirect_url: str) -> str:
args = urllib.parse.urlencode(
{"service": self._build_service_param(client_redirect_url)}
)

return "%s/login?%s" % (self._cas_server_url, args)

async def handle_ticket_request(self, request: SynapseRequest):
client_redirect_url = parse_string(request, "redirectUrl", required=True)
uri = self._cas_server_url + "/proxyValidate"
args = {
"ticket": parse_string(request, "ticket", required=True),
"service": self._build_service_param(client_redirect_url),
}
clokep marked this conversation as resolved.
Show resolved Hide resolved
try:
body = await self._http_client.get_raw(uri, args)
except PartialDownloadError as pde:
# Twisted raises this error if the connection is closed,
# even if that's being used old-http style to signal end-of-data
body = pde.response

await self._handle_cas_response(request, body, client_redirect_url)
clokep marked this conversation as resolved.
Show resolved Hide resolved
160 changes: 6 additions & 154 deletions synapse/rest/client/v1/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,14 @@
# limitations under the License.

import logging
import xml.etree.ElementTree as ET

from six.moves import urllib

from twisted.web.client import PartialDownloadError

from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.http.server import finish_request
from synapse.http.servlet import (
RestServlet,
parse_json_object_from_request,
parse_string,
)
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import UserID, map_username_to_mxid_localpart
from synapse.types import UserID
from synapse.util.msisdn import phone_number_to_msisdn

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -72,14 +63,6 @@ def login_id_thirdparty_from_phone(identifier):
return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn}


def build_service_param(cas_service_url, client_redirect_url):
return "%s%s?redirectUrl=%s" % (
cas_service_url,
"/_matrix/client/r0/login/cas/ticket",
urllib.parse.quote(client_redirect_url, safe=""),
)


class LoginRestServlet(RestServlet):
PATTERNS = client_patterns("/login$", v1=True)
CAS_TYPE = "m.login.cas"
Expand Down Expand Up @@ -434,97 +417,21 @@ def get_sso_url(self, client_redirect_url):

class CasRedirectServlet(BaseSSORedirectServlet):
def __init__(self, hs):
super(CasRedirectServlet, self).__init__()
self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url
self._cas_handler = hs.get_cas_handler()

def get_sso_url(self, client_redirect_url):
args = urllib.parse.urlencode(
{"service": build_service_param(self.cas_service_url, client_redirect_url)}
)

return "%s/login?%s" % (self.cas_server_url, args)
return self._cas_handler.handle_redirect_request(client_redirect_url)


class CasTicketServlet(RestServlet):
PATTERNS = client_patterns("/login/cas/ticket", v1=True)

def __init__(self, hs):
super(CasTicketServlet, self).__init__()
self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url
self.cas_displayname_attribute = hs.config.cas_displayname_attribute
self.cas_required_attributes = hs.config.cas_required_attributes
self._sso_auth_handler = SSOAuthHandler(hs)
self._http_client = hs.get_proxied_http_client()
self._cas_handler = hs.get_cas_handler()

async def on_GET(self, request):
client_redirect_url = parse_string(request, "redirectUrl", required=True)
uri = self.cas_server_url + "/proxyValidate"
args = {
"ticket": parse_string(request, "ticket", required=True),
"service": build_service_param(self.cas_service_url, client_redirect_url),
}
try:
body = await self._http_client.get_raw(uri, args)
except PartialDownloadError as pde:
# Twisted raises this error if the connection is closed,
# even if that's being used old-http style to signal end-of-data
body = pde.response
result = await self.handle_cas_response(request, body, client_redirect_url)
return result

def handle_cas_response(self, request, cas_response_body, client_redirect_url):
user, attributes = self.parse_cas_response(cas_response_body)
displayname = attributes.pop(self.cas_displayname_attribute, None)

for required_attribute, required_value in self.cas_required_attributes.items():
# If required attribute was not in CAS Response - Forbidden
if required_attribute not in attributes:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)

# Also need to check value
if required_value is not None:
actual_value = attributes[required_attribute]
# If required attribute value does not match expected - Forbidden
if required_value != actual_value:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)

return self._sso_auth_handler.on_successful_auth(
user, request, client_redirect_url, displayname
)

def parse_cas_response(self, cas_response_body):
user = None
attributes = {}
try:
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise Exception("root of CAS response is not serviceResponse")
success = root[0].tag.endswith("authenticationSuccess")
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
if child.tag.endswith("attributes"):
for attribute in child:
# ElementTree library expands the namespace in
# attribute tags to the full URL of the namespace.
# We don't care about namespace here and it will always
# be encased in curly braces, so we remove them.
tag = attribute.tag
if "}" in tag:
tag = tag.split("}")[1]
attributes[tag] = attribute.text
if user is None:
raise Exception("CAS response does not contain user")
except Exception:
logger.exception("Error parsing CAS response")
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
if not success:
raise LoginError(
401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
)
return user, attributes
return await self._cas_handler.handle_ticket_request(request)


class SAMLRedirectServlet(BaseSSORedirectServlet):
Expand All @@ -537,61 +444,6 @@ def get_sso_url(self, client_redirect_url):
return self._saml_handler.handle_redirect_request(client_redirect_url)


class SSOAuthHandler(object):
"""
Utility class for Resources and Servlets which handle the response from a SSO
service

Args:
hs (synapse.server.HomeServer)
"""

def __init__(self, hs):
self._hostname = hs.hostname
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
self._macaroon_gen = hs.get_macaroon_generator()

# cast to tuple for use with str.startswith
self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)

async def on_successful_auth(
self, username, request, client_redirect_url, user_display_name=None
):
"""Called once the user has successfully authenticated with the SSO.

Registers the user if necessary, and then returns a redirect (with
a login token) to the client.

Args:
username (unicode|bytes): the remote user id. We'll map this onto
something sane for a MXID localpath.

request (SynapseRequest): the incoming request from the browser. We'll
respond to it with a redirect.

client_redirect_url (unicode): the redirect_url the client gave us when
it first started the process.

user_display_name (unicode|None): if set, and we have to register a new user,
we will set their displayname to this.

Returns:
Deferred[none]: Completes once we have handled the request.
"""
localpart = map_username_to_mxid_localpart(username)
user_id = UserID(localpart, self._hostname).to_string()
registered_user_id = await self._auth_handler.check_user_exists(user_id)
if not registered_user_id:
registered_user_id = await self._registration_handler.register_user(
localpart=localpart, default_display_name=user_display_name
)

self._auth_handler.complete_sso_login(
registered_user_id, request, client_redirect_url
)


def register_servlets(hs, http_server):
LoginRestServlet(hs).register(http_server)
if hs.config.cas_enabled:
Expand Down
Loading