Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Don't throw exception on m.login.jwt automatic user creation #7585

Merged
merged 4 commits into from
Jun 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/7585.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a bug in automatic user creation during first time login with `m.login.jwt`. Regression in v1.6.0. Contributed by @olof.
15 changes: 8 additions & 7 deletions synapse/rest/client/v1/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ async def _do_other_login(self, login_submission):
return result

async def _complete_login(
self, user_id, login_submission, callback=None, create_non_existant_users=False
self, user_id, login_submission, callback=None, create_non_existent_users=False
):
"""Called when we've successfully authed the user and now need to
actually login them in (e.g. create devices). This gets called on
Expand All @@ -312,7 +312,7 @@ async def _complete_login(
user_id (str): ID of the user to register.
login_submission (dict): Dictionary of login information.
callback (func|None): Callback function to run after registration.
create_non_existant_users (bool): Whether to create the user if
create_non_existent_users (bool): Whether to create the user if
they don't exist. Defaults to False.

Returns:
Expand All @@ -331,12 +331,13 @@ async def _complete_login(
update=True,
)

if create_non_existant_users:
user_id = await self.auth_handler.check_user_exists(user_id)
if not user_id:
user_id = await self.registration_handler.register_user(
if create_non_existent_users:
canonical_uid = await self.auth_handler.check_user_exists(user_id)
if not canonical_uid:
canonical_uid = await self.registration_handler.register_user(
localpart=UserID.from_string(user_id).localpart
)
user_id = canonical_uid

device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
Expand Down Expand Up @@ -391,7 +392,7 @@ async def do_jwt_login(self, login_submission):

user_id = UserID(user, self.hs.hostname).to_string()
result = await self._complete_login(
user_id, login_submission, create_non_existant_users=True
user_id, login_submission, create_non_existent_users=True
)
return result

Expand Down
153 changes: 153 additions & 0 deletions tests/rest/client/v1/test_login.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import json
import time
import urllib.parse

from mock import Mock

import jwt

import synapse.rest.admin
from synapse.rest.client.v1 import login, logout
from synapse.rest.client.v2_alpha import devices
Expand Down Expand Up @@ -473,3 +476,153 @@ def test_deactivated_user(self):
# Because the user is deactivated they are served an error template.
self.assertEqual(channel.code, 403)
self.assertIn(b"SSO account deactivated", channel.result["body"])


class JWTTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
]

jwt_secret = "secret"

def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver()
self.hs.config.jwt_enabled = True
self.hs.config.jwt_secret = self.jwt_secret
self.hs.config.jwt_algorithm = "HS256"
return self.hs

def jwt_encode(self, token, secret=jwt_secret):
return jwt.encode(token, secret, "HS256").decode("ascii")

def jwt_login(self, *args):
params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)})
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
return channel

def test_login_jwt_valid_registered(self):
self.register_user("kermit", "monkey")
channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")

def test_login_jwt_valid_unregistered(self):
channel = self.jwt_login({"sub": "frog"})
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.json_body["user_id"], "@frog:test")

def test_login_jwt_invalid_signature(self):
channel = self.jwt_login({"sub": "frog"}, "notsecret")
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
self.assertEqual(channel.json_body["error"], "Invalid JWT")

def test_login_jwt_expired(self):
channel = self.jwt_login({"sub": "frog", "exp": 864000})
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
self.assertEqual(channel.json_body["error"], "JWT expired")

def test_login_jwt_not_before(self):
now = int(time.time())
channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
self.assertEqual(channel.json_body["error"], "Invalid JWT")

def test_login_no_sub(self):
channel = self.jwt_login({"username": "root"})
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
self.assertEqual(channel.json_body["error"], "Invalid JWT")

def test_login_no_token(self):
params = json.dumps({"type": "m.login.jwt"})
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")


# The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use
# RSS256, with a public key configured in synapse as "jwt_secret", and tokens
# signed by the private key.
class JWTPubKeyTestCase(unittest.HomeserverTestCase):
servlets = [
login.register_servlets,
]

# This key's pubkey is used as the jwt_secret setting of synapse. Valid
# tokens are signed by this and validated using the pubkey. It is generated
# with `openssl genrsa 512` (not a secure way to generate real keys, but
# good enough for tests!)
jwt_privatekey = "\n".join(
[
"-----BEGIN RSA PRIVATE KEY-----",
"MIIBPAIBAAJBAM50f1Q5gsdmzifLstzLHb5NhfajiOt7TKO1vSEWdq7u9x8SMFiB",
"492RM9W/XFoh8WUfL9uL6Now6tPRDsWv3xsCAwEAAQJAUv7OOSOtiU+wzJq82rnk",
"yR4NHqt7XX8BvkZPM7/+EjBRanmZNSp5kYZzKVaZ/gTOM9+9MwlmhidrUOweKfB/",
"kQIhAPZwHazbjo7dYlJs7wPQz1vd+aHSEH+3uQKIysebkmm3AiEA1nc6mDdmgiUq",
"TpIN8A4MBKmfZMWTLq6z05y/qjKyxb0CIQDYJxCwTEenIaEa4PdoJl+qmXFasVDN",
"ZU0+XtNV7yul0wIhAMI9IhiStIjS2EppBa6RSlk+t1oxh2gUWlIh+YVQfZGRAiEA",
"tqBR7qLZGJ5CVKxWmNhJZGt1QHoUtOch8t9C4IdOZ2g=",
"-----END RSA PRIVATE KEY-----",
]
)

# Generated with `openssl rsa -in foo.key -pubout`, with the the above
# private key placed in foo.key (jwt_privatekey).
jwt_pubkey = "\n".join(
[
"-----BEGIN PUBLIC KEY-----",
"MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAM50f1Q5gsdmzifLstzLHb5NhfajiOt7",
"TKO1vSEWdq7u9x8SMFiB492RM9W/XFoh8WUfL9uL6Now6tPRDsWv3xsCAwEAAQ==",
"-----END PUBLIC KEY-----",
]
)

# This key is used to sign tokens that shouldn't be accepted by synapse.
# Generated just like jwt_privatekey.
bad_privatekey = "\n".join(
[
"-----BEGIN RSA PRIVATE KEY-----",
"MIIBOgIBAAJBAL//SQrKpKbjCCnv/FlasJCv+t3k/MPsZfniJe4DVFhsktF2lwQv",
"gLjmQD3jBUTz+/FndLSBvr3F4OHtGL9O/osCAwEAAQJAJqH0jZJW7Smzo9ShP02L",
"R6HRZcLExZuUrWI+5ZSP7TaZ1uwJzGFspDrunqaVoPobndw/8VsP8HFyKtceC7vY",
"uQIhAPdYInDDSJ8rFKGiy3Ajv5KWISBicjevWHF9dbotmNO9AiEAxrdRJVU+EI9I",
"eB4qRZpY6n4pnwyP0p8f/A3NBaQPG+cCIFlj08aW/PbxNdqYoBdeBA0xDrXKfmbb",
"iwYxBkwL0JCtAiBYmsi94sJn09u2Y4zpuCbJeDPKzWkbuwQh+W1fhIWQJQIhAKR0",
"KydN6cRLvphNQ9c/vBTdlzWxzcSxREpguC7F1J1m",
"-----END RSA PRIVATE KEY-----",
]
)

def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver()
self.hs.config.jwt_enabled = True
self.hs.config.jwt_secret = self.jwt_pubkey
self.hs.config.jwt_algorithm = "RS256"
return self.hs

def jwt_encode(self, token, secret=jwt_privatekey):
return jwt.encode(token, secret, "RS256").decode("ascii")

def jwt_login(self, *args):
params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)})
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
return channel

def test_login_jwt_valid(self):
channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")

def test_login_jwt_invalid_signature(self):
channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
self.assertEqual(channel.json_body["error"], "Invalid JWT")