Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Test OIDC login with multiple OIDC providers #9127

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/9127.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for multiple SSO Identity Providers.
2 changes: 1 addition & 1 deletion synapse/handlers/oidc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __init__(self, hs: "HomeServer"):
self._token_generator = OidcSessionTokenGenerator(hs)
self._providers = {
p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs
}
} # type: Dict[str, OidcProvider]

async def load_metadata(self) -> None:
"""Validate the config and load the metadata from the remote endpoint.
Expand Down
94 changes: 62 additions & 32 deletions tests/rest/client/v1/test_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@

import time
import urllib.parse
from html.parser import HTMLParser
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Dict, Union
from urllib.parse import parse_qs, urlencode, urlparse

from mock import Mock
Expand All @@ -38,6 +37,7 @@
from tests.handlers.test_oidc import HAS_OIDC
from tests.handlers.test_saml import has_saml2
from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
from tests.test_utils.html_parsers import TestHtmlParser
from tests.unittest import HomeserverTestCase, override_config, skip_unless

try:
Expand Down Expand Up @@ -389,13 +389,35 @@ def default_config(self) -> Dict[str, Any]:
},
}

# default OIDC provider
config["oidc_config"] = TEST_OIDC_CONFIG

# additional OIDC providers
config["oidc_providers"] = [
{
"idp_id": "idp1",
"idp_name": "IDP1",
"discover": False,
"issuer": "https://issuer1",
"client_id": "test-client-id",
"client_secret": "test-client-secret",
"scopes": ["profile"],
"authorization_endpoint": "https://issuer1/auth",
"token_endpoint": "https://issuer1/token",
"userinfo_endpoint": "https://issuer1/userinfo",
"user_mapping_provider": {
"config": {"localpart_template": "{{ user.sub }}"}
},
}
]
return config

def create_resource_dict(self) -> Dict[str, Resource]:
from synapse.rest.oidc import OIDCResource

d = super().create_resource_dict()
d["/_synapse/client/pick_idp"] = PickIdpResource(self.hs)
d["/_synapse/oidc"] = OIDCResource(self.hs)
return d

def test_multi_sso_redirect(self):
Expand All @@ -415,36 +437,11 @@ def test_multi_sso_redirect(self):
self.assertEqual(channel.code, 200, channel.result)

# parse the form to check it has fields assumed elsewhere in this class
class FormPageParser(HTMLParser):
def __init__(self):
super().__init__()

# the values of the hidden inputs: map from name to value
self.hiddens = {} # type: Dict[str, Optional[str]]

# the values of the radio buttons
self.radios = [] # type: List[Optional[str]]

def handle_starttag(
self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
) -> None:
attr_dict = dict(attrs)
if tag == "input":
if attr_dict["type"] == "radio" and attr_dict["name"] == "idp":
self.radios.append(attr_dict["value"])
elif attr_dict["type"] == "hidden":
input_name = attr_dict["name"]
assert input_name
self.hiddens[input_name] = attr_dict["value"]

def error(_, message):
self.fail(message)

p = FormPageParser()
p = TestHtmlParser()
p.feed(channel.result["body"].decode("utf-8"))
p.close()

self.assertCountEqual(p.radios, ["cas", "oidc", "saml"])
self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "idp1", "saml"])

self.assertEqual(p.hiddens["redirectUrl"], client_redirect_url)

Expand Down Expand Up @@ -494,14 +491,15 @@ def test_multi_sso_redirect_to_saml(self):
relay_state_param = saml_uri_params["RelayState"][0]
self.assertEqual(relay_state_param, client_redirect_url)

def test_multi_sso_redirect_to_oidc(self):
def test_login_via_oidc(self):
"""If OIDC is chosen, should redirect to the OIDC auth endpoint"""
client_redirect_url = "https://x?<abc>"
client_redirect_url = 'https://x?"q"="foo"'

# pick the default OIDC provider
channel = self.make_request(
"GET",
"/_synapse/client/pick_idp?redirectUrl="
+ client_redirect_url
+ urllib.parse.quote_plus(client_redirect_url)
+ "&idp=oidc",
)
self.assertEqual(channel.code, 302, channel.result)
Expand All @@ -524,6 +522,38 @@ def test_multi_sso_redirect_to_oidc(self):
client_redirect_url,
)

channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"})

# that should serve a confirmation page
self.assertEqual(channel.code, 200, channel.result)
self.assertTrue(
channel.headers.getRawHeaders("Content-Type")[-1].startswith("text/html")
)
p = TestHtmlParser()
p.feed(channel.text_body)
p.close()

# ... which should contain our redirect link
self.assertEqual(len(p.links), 1)
path, query = p.links[0].split("?", 1)
self.assertEqual(path, "https://x")

# it will have url-encoded the params properly, so we'll have to parse them
params = urllib.parse.parse_qsl(
query, keep_blank_values=True, strict_parsing=True, errors="strict"
)
self.assertEqual(params[0:1], [('"q"', '"foo"')])
self.assertEqual(params[1][0], "loginToken")

# finally, submit the matrix login token to the login API, which gives us our
# matrix access token, mxid, and device id.
login_token = params[1][1]
chan = self.make_request(
"POST", "/login", content={"type": "m.login.token", "token": login_token},
)
self.assertEqual(chan.code, 200, chan.result)
self.assertEqual(chan.json_body["user_id"], "@user1:test")

def test_multi_sso_redirect_to_unknown(self):
"""An unknown IdP should cause a 400"""
channel = self.make_request(
Expand Down
62 changes: 36 additions & 26 deletions tests/rest/client/v1/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
import re
import time
import urllib.parse
from html.parser import HTMLParser
from typing import Any, Dict, Iterable, List, MutableMapping, Optional, Tuple
from typing import Any, Dict, Mapping, MutableMapping, Optional

from mock import patch

Expand All @@ -35,6 +34,7 @@

from tests.server import FakeChannel, FakeSite, make_request
from tests.test_utils import FakeResponse
from tests.test_utils.html_parsers import TestHtmlParser


@attr.s
Expand Down Expand Up @@ -440,10 +440,36 @@ def auth_via_oidc(
# param that synapse passes to the IdP via query params, as well as the cookie
# that synapse passes to the client.

oauth_uri_path, oauth_uri_qs = oauth_uri.split("?", 1)
oauth_uri_path, _ = oauth_uri.split("?", 1)
assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, (
"unexpected SSO URI " + oauth_uri_path
)
return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict)

def complete_oidc_auth(
self, oauth_uri: str, cookies: Mapping[str, str], user_info_dict: JsonDict,
) -> FakeChannel:
"""Mock out an OIDC authentication flow

Assumes that an OIDC auth has been initiated by one of initiate_sso_login or
initiate_sso_ui_auth; completes the OIDC bits of the flow by making a request to
Synapse's OIDC callback endpoint, intercepting the HTTP requests that will get
sent back to the OIDC provider.

Requires the OIDC callback resource to be mounted at the normal place.

Args:
oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie,
from initiate_sso_login or initiate_sso_ui_auth).
cookies: the cookies set by synapse's redirect endpoint, which will be
sent back to the callback endpoint.
user_info_dict: the remote userinfo that the OIDC provider should present.
Typically this should be '{"sub": "<remote user id>"}'.

Returns:
A FakeChannel containing the result of calling the OIDC callback endpoint.
"""
_, oauth_uri_qs = oauth_uri.split("?", 1)
params = urllib.parse.parse_qs(oauth_uri_qs)
callback_uri = "%s?%s" % (
urllib.parse.urlparse(params["redirect_uri"][0]).path,
Expand All @@ -456,9 +482,9 @@ def auth_via_oidc(
expected_requests = [
# first we get a hit to the token endpoint, which we tell to return
# a dummy OIDC access token
("https://issuer.test/token", {"access_token": "TEST"}),
(TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}),
# and then one to the user_info endpoint, which returns our remote user id.
("https://issuer.test/userinfo", user_info_dict),
(TEST_OIDC_USERINFO_ENDPOINT, user_info_dict),
]

async def mock_req(method: str, uri: str, data=None, headers=None):
Expand Down Expand Up @@ -542,25 +568,7 @@ def initiate_sso_ui_auth(
channel.extract_cookies(cookies)

# parse the confirmation page to fish out the link.
class ConfirmationPageParser(HTMLParser):
def __init__(self):
super().__init__()

self.links = [] # type: List[str]

def handle_starttag(
self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
) -> None:
attr_dict = dict(attrs)
if tag == "a":
href = attr_dict["href"]
if href:
self.links.append(href)

def error(_, message):
raise AssertionError(message)

p = ConfirmationPageParser()
p = TestHtmlParser()
p.feed(channel.text_body)
p.close()
assert len(p.links) == 1, "not exactly one link in confirmation page"
Expand All @@ -570,6 +578,8 @@ def error(_, message):

# an 'oidc_config' suitable for login_via_oidc.
TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth"
TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token"
TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo"
TEST_OIDC_CONFIG = {
"enabled": True,
"discover": False,
Expand All @@ -578,7 +588,7 @@ def error(_, message):
"client_secret": "test-client-secret",
"scopes": ["profile"],
"authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT,
"token_endpoint": "https://issuer.test/token",
"userinfo_endpoint": "https://issuer.test/userinfo",
"token_endpoint": TEST_OIDC_TOKEN_ENDPOINT,
"userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT,
"user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
}
53 changes: 53 additions & 0 deletions tests/test_utils/html_parsers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# -*- coding: utf-8 -*-
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from html.parser import HTMLParser
from typing import Dict, Iterable, List, Optional, Tuple


class TestHtmlParser(HTMLParser):
"""A generic HTML page parser which extracts useful things from the HTML"""

def __init__(self):
super().__init__()

# a list of links found in the doc
self.links = [] # type: List[str]

# the values of any hidden <input>s: map from name to value
self.hiddens = {} # type: Dict[str, Optional[str]]

# the values of any radio buttons: map from name to list of values
self.radios = {} # type: Dict[str, List[Optional[str]]]

def handle_starttag(
self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
) -> None:
attr_dict = dict(attrs)
if tag == "a":
href = attr_dict["href"]
if href:
self.links.append(href)
elif tag == "input":
input_name = attr_dict.get("name")
if attr_dict["type"] == "radio":
assert input_name
self.radios.setdefault(input_name, []).append(attr_dict["value"])
elif attr_dict["type"] == "hidden":
assert input_name
self.hiddens[input_name] = attr_dict["value"]

def error(_, message):
raise AssertionError(message)