Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

MSC2918 Refresh tokens implementation #9450

Merged
merged 51 commits into from
Jun 24, 2021
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
fe80ef5
WIP: MSC2918
sandhose Feb 13, 2021
523d8cf
MSC2918: implement refresh tokens
sandhose Feb 19, 2021
358da22
MSC2918: Changelog
sandhose Mar 26, 2021
f53466e
MSC2918: fix mypy and lint errors
sandhose Mar 26, 2021
324d7bf
MSC2918: add PostgreSQL schema
sandhose Mar 26, 2021
450a962
MSC2918: do not invalidate refresh token immediately & fix tests
sandhose Apr 9, 2021
022485e
MSC2918: lint fixes
sandhose Apr 9, 2021
51ba1c3
MSC2918: also delete refresh tokens when logging out
sandhose Apr 22, 2021
d281f7e
MSC2918: fix field name in migrations
sandhose Apr 22, 2021
f499d63
MSC2918: merge SQLite and PostgreSQL schema deltas
sandhose May 5, 2021
e402a07
MSC2918: fix sample config
sandhose May 5, 2021
adc6eab
MSC2918: use parse_boolean to get query parameter value
sandhose May 5, 2021
6963fe0
MSC2918: use attr.s instead of TypedDict
sandhose May 5, 2021
318b74c
MSC2918: remove unused sequence in refresh_tokens
sandhose May 5, 2021
29806b4
MSC2918: try fixing port_db script when a table references itself
sandhose May 5, 2021
72e5c25
MSC2918: lint
sandhose May 5, 2021
eb9f680
Revert "MSC2918: use attr.s instead of TypedDict"
sandhose May 5, 2021
417a34a
MSC2918: random signed token instead of macaroons for refresh tokens
sandhose May 20, 2021
45177a6
MSC2918: some docstrings and minor changes
sandhose May 20, 2021
e37f53a
MSC2918: expires_in -> expires_in_ms
sandhose May 27, 2021
262d1ab
MSC2918: properly figure out whether an access token was already used…
sandhose May 27, 2021
75ce9e5
MSC2918: implement for registration endpoint
sandhose May 27, 2021
6f2cc61
MSC2918: properly replace old-next refresh token
sandhose May 27, 2021
b7b17ed
MSC2918: add tests
sandhose May 27, 2021
c7eab51
MSC2918: use secrets.token_bytes instead of random.randbytes
sandhose May 27, 2021
088e023
MSC2918: mark new column as boolean in port_db
sandhose May 27, 2021
6247228
MSC2918: fix existing auth test
sandhose May 27, 2021
67d4c9e
Merge remote-tracking branch 'upstream/develop' into sandhose/msc2918
sandhose May 28, 2021
2ec853c
MSC2918: use the same pattern as access tokens for refresh tokens
sandhose May 28, 2021
9e7ce1f
MSC2918: lint: remove unused import
sandhose May 28, 2021
45e2eaf
MSC2918: fix typing issue
sandhose May 28, 2021
c20f94a
MSC2918: properly check refresh_token parameter
sandhose May 28, 2021
790baac
MSC2918: cleanup old refresh token generation code
sandhose May 28, 2021
01b0740
MSC2918: add more docstrings
sandhose Jun 3, 2021
797e0d3
MSC2918: change refresh token API error codes
sandhose Jun 3, 2021
8f8f369
MSC2918: disable refresh tokens when session_lifetime is set
sandhose Jun 3, 2021
6024ed8
MSC2918: add missing docstring
sandhose Jun 3, 2021
908c279
MSC2918: temp: mark the access token as used only once
sandhose Jun 3, 2021
cdfd871
MSC2918: explicit cast on access_tokens.used
sandhose Jun 4, 2021
b169a62
Revert "MSC2918: explicit cast on access_tokens.used"
sandhose Jun 4, 2021
e07ef9b
MSC2918: properly fix access_tokens.used column on old SQLite
sandhose Jun 4, 2021
4cf49a6
Merge remote-tracking branch 'upstream/develop' into sandhose/msc2918
sandhose Jun 4, 2021
ef0e051
MSC2918: properly fix "mark_access_token_as_used" by caching it
sandhose Jun 4, 2021
7adfe0c
Merge remote-tracking branch 'upstream/develop' into sandhose/msc2918
sandhose Jun 10, 2021
ab443a3
MSC2918: add comments as suggested by richvdh
sandhose Jun 17, 2021
0060bc9
Merge remote-tracking branch 'upstream/develop' into sandhose/msc2918
sandhose Jun 17, 2021
18628fc
MSC2918: make access_tokens.used nullable
sandhose Jun 18, 2021
bcc33e2
MSC2918: 403 when using a refresh token twice
sandhose Jun 18, 2021
ddfc2a4
MSC2918: clarify comment about access_token_lifetime and session_life…
sandhose Jun 18, 2021
a013064
Merge remote-tracking branch 'upstream/develop' into sandhose/msc2918
sandhose Jun 18, 2021
9fe5556
MSC2918: fix refresh token invalidation test
sandhose Jun 18, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/9450.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Implement refresh tokens as specified by [MSC2918](https://github.com/matrix-org/matrix-doc/pull/2918).
4 changes: 4 additions & 0 deletions docs/sample_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1249,6 +1249,10 @@ account_validity:
#
#session_lifetime: 24h

# MSC2918
# TODO: docs
access_token_lifetime: 5m
richvdh marked this conversation as resolved.
Show resolved Hide resolved

# The user must provide all of the below types of 3PID when registering.
#
#registrations_require_3pid:
Expand Down
8 changes: 8 additions & 0 deletions synapse/config/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,10 @@ def read_config(self, config, **kwargs):
session_lifetime = self.parse_duration(session_lifetime)
self.session_lifetime = session_lifetime

access_token_lifetime = config.get("access_token_lifetime", "5m")
access_token_lifetime = self.parse_duration(access_token_lifetime)
self.access_token_lifetime = access_token_lifetime # type: int
richvdh marked this conversation as resolved.
Show resolved Hide resolved

# The success template used during fallback auth.
self.fallback_success_template = self.read_template("auth_success.html")

Expand Down Expand Up @@ -282,6 +286,10 @@ def generate_config_section(self, generate_secrets=False, **kwargs):
#
#session_lifetime: 24h

# MSC2918
# TODO: docs
access_token_lifetime: 5m

# The user must provide all of the below types of 3PID when registering.
#
#registrations_require_3pid:
Expand Down
75 changes: 70 additions & 5 deletions synapse/handlers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
Optional,
Tuple,
Union,
cast,
)

import attr
Expand Down Expand Up @@ -70,6 +71,7 @@
from synapse.util.threepids import canonicalise_email

if TYPE_CHECKING:
from synapse.rest.client.v1.login import LoginResponse
from synapse.server import HomeServer

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -770,13 +772,59 @@ def _auth_dict_for_flows(
"params": params,
}

async def refresh_token(
self,
refresh_token: str,
valid_until_ms: Optional[int],
) -> Tuple[str, str]:
richvdh marked this conversation as resolved.
Show resolved Hide resolved
existing_token = await self.store.lookup_refresh_token(refresh_token)
if existing_token is None:
raise SynapseError(400, "refresh token does not exist")

if (
existing_token.has_next_access_token_been_used
or existing_token.has_next_refresh_token_been_refreshed
):
raise SynapseError(400, "refresh token isn't valid anymore")
richvdh marked this conversation as resolved.
Show resolved Hide resolved

(
new_refresh_token,
new_refresh_token_id,
) = await self.get_refresh_token_for_user_id(
user_id=existing_token.user_id, device_id=existing_token.device_id
)
access_token = await self.get_access_token_for_user_id(
user_id=existing_token.user_id,
device_id=existing_token.device_id,
valid_until_ms=valid_until_ms,
refresh_token_id=new_refresh_token_id,
)
await self.store.replace_refresh_token(
existing_token.token_id, new_refresh_token_id
)
return access_token, new_refresh_token

async def get_refresh_token_for_user_id(
self,
user_id: str,
device_id: Optional[str],
) -> Tuple[str, int]:
richvdh marked this conversation as resolved.
Show resolved Hide resolved
refresh_token = self.macaroon_gen.generate_refresh_token(user_id)
richvdh marked this conversation as resolved.
Show resolved Hide resolved
refresh_token_id = await self.store.add_refresh_token_to_user(
user_id=user_id,
token=refresh_token,
device_id=device_id,
)
return refresh_token, refresh_token_id

async def get_access_token_for_user_id(
self,
user_id: str,
device_id: Optional[str],
valid_until_ms: Optional[int],
puppets_user_id: Optional[str] = None,
is_appservice_ghost: bool = False,
refresh_token_id: Optional[int] = None,
) -> str:
"""
Creates a new access token for the user with the given user ID.
Expand Down Expand Up @@ -827,6 +875,7 @@ async def get_access_token_for_user_id(
device_id=device_id,
valid_until_ms=valid_until_ms,
puppets_user_id=puppets_user_id,
refresh_token_id=refresh_token_id,
)

# the device *should* have been registered before we got here; however,
Expand Down Expand Up @@ -919,7 +968,7 @@ async def validate_login(
self,
login_submission: Dict[str, Any],
ratelimit: bool = False,
) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
) -> Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
"""Authenticates the user for the /login API

Also used by the user-interactive auth flow to validate auth types which don't
Expand Down Expand Up @@ -1064,7 +1113,7 @@ async def _validate_userid_login(
self,
username: str,
login_submission: Dict[str, Any],
) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
) -> Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
"""Helper for validate_login

Handles login, once we've mapped 3pids onto userids
Expand Down Expand Up @@ -1142,7 +1191,7 @@ async def _validate_userid_login(

async def check_password_provider_3pid(
self, medium: str, address: str, password: str
) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
) -> Tuple[Optional[str], Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
"""Check if a password provider is able to validate a thirdparty login

Args:
Expand Down Expand Up @@ -1541,7 +1590,7 @@ def _complete_sso_login(
)
respond_with_html(request, 200, html)

async def _sso_login_callback(self, login_result: JsonDict) -> None:
async def _sso_login_callback(self, login_result: "LoginResponse") -> None:
"""
A login callback which might add additional attributes to the login response.

Expand All @@ -1555,7 +1604,8 @@ async def _sso_login_callback(self, login_result: JsonDict) -> None:

extra_attributes = self._extra_attributes.get(login_result["user_id"])
if extra_attributes:
login_result.update(extra_attributes.extra_attributes)
login_result_dict = cast(Dict[str, Any], login_result)
login_result_dict.update(extra_attributes.extra_attributes)

def _expire_sso_extra_attributes(self) -> None:
"""
Expand Down Expand Up @@ -1586,6 +1636,21 @@ class MacaroonGenerator:

hs = attr.ib()

def generate_refresh_token(
self, user_id: str, extra_caveats: Optional[List[str]] = None
) -> str:
extra_caveats = extra_caveats or []
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = refresh")
# Include a nonce, to make sure that each login gets a different
# access token.
macaroon.add_first_party_caveat(
"nonce = %s" % (stringutils.random_string_with_symbols(16),)
)
for caveat in extra_caveats:
macaroon.add_first_party_caveat(caveat)
return macaroon.serialize()

def generate_access_token(
self, user_id: str, extra_caveats: Optional[List[str]] = None
) -> str:
Expand Down
52 changes: 46 additions & 6 deletions synapse/handlers/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
"""Contains functions for registering clients."""

import logging
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple

from prometheus_client import Counter
from typing_extensions import TypedDict

from synapse import types
from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
Expand Down Expand Up @@ -55,6 +56,16 @@
["guest", "auth_provider"],
)

LoginDict = TypedDict(
"LoginDict",
{
"device_id": str,
"access_token": str,
"valid_until_ms": Optional[int],
"refresh_token": Optional[str],
},
)
richvdh marked this conversation as resolved.
Show resolved Hide resolved


class RegistrationHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
Expand Down Expand Up @@ -86,6 +97,7 @@ def __init__(self, hs: "HomeServer"):
self.pusher_pool = hs.get_pusherpool()

self.session_lifetime = hs.config.session_lifetime
self.access_token_lifetime = hs.config.access_token_lifetime

async def check_username(
self,
Expand Down Expand Up @@ -666,7 +678,8 @@ async def register_device(
is_guest: bool = False,
is_appservice_ghost: bool = False,
auth_provider_id: Optional[str] = None,
) -> Tuple[str, str]:
should_issue_refresh_token: bool = False,
) -> Tuple[str, str, Optional[int], Optional[str]]:
sandhose marked this conversation as resolved.
Show resolved Hide resolved
"""Register a device for a user and generate an access token.

The access token will be limited by the homeserver's session_lifetime config.
Expand All @@ -678,23 +691,30 @@ async def register_device(
is_guest: Whether this is a guest account
auth_provider_id: The SSO IdP the user used, if any (just used for the
prometheus metrics).
should_issue_refresh_token: Whether it should also issue a refresh token
Returns:
Tuple of device ID and access token
Tuple of device ID, access token, access token expiration time and refresh token
"""
res = await self._register_device_client(
user_id=user_id,
device_id=device_id,
initial_display_name=initial_display_name,
is_guest=is_guest,
is_appservice_ghost=is_appservice_ghost,
should_issue_refresh_token=should_issue_refresh_token,
)

login_counter.labels(
guest=is_guest,
auth_provider=(auth_provider_id or ""),
).inc()

return res["device_id"], res["access_token"]
return (
res["device_id"],
res["access_token"],
res["valid_until_ms"],
res["refresh_token"],
)

async def register_device_inner(
self,
Expand All @@ -703,7 +723,8 @@ async def register_device_inner(
initial_display_name: Optional[str],
is_guest: bool = False,
is_appservice_ghost: bool = False,
) -> Dict[str, str]:
should_issue_refresh_token: bool = False,
) -> LoginDict:
"""Helper for register_device

Does the bits that need doing on the main process. Not for use outside this
Expand All @@ -718,6 +739,9 @@ class and RegisterDeviceReplicationServlet.
)
valid_until_ms = self.clock.time_msec() + self.session_lifetime

refresh_token = None
refresh_token_id = None

registered_device_id = await self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
Expand All @@ -727,14 +751,30 @@ class and RegisterDeviceReplicationServlet.
user_id, ["guest = true"]
)
else:
if should_issue_refresh_token:
(
refresh_token,
refresh_token_id,
) = await self._auth_handler.get_refresh_token_for_user_id(
user_id,
device_id=registered_device_id,
)
valid_until_ms = self.clock.time_msec() + self.access_token_lifetime
richvdh marked this conversation as resolved.
Show resolved Hide resolved

access_token = await self._auth_handler.get_access_token_for_user_id(
user_id,
device_id=registered_device_id,
valid_until_ms=valid_until_ms,
is_appservice_ghost=is_appservice_ghost,
refresh_token_id=refresh_token_id,
)

return {"device_id": registered_device_id, "access_token": access_token}
return {
"device_id": registered_device_id,
"access_token": access_token,
"valid_until_ms": valid_until_ms,
"refresh_token": refresh_token,
}

async def post_registration_actions(
self, user_id: str, auth_result: dict, access_token: Optional[str]
Expand Down
2 changes: 1 addition & 1 deletion synapse/module_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def register(self, localpart, displayname=None, emails: Optional[List[str]] = No
"Using deprecated ModuleApi.register which creates a dummy user device."
)
user_id = yield self.register_user(localpart, displayname, emails or [])
_, access_token = yield self.register_device(user_id)
_, access_token, _, _ = yield self.register_device(user_id)
return user_id, access_token

def register_user(
Expand Down
10 changes: 9 additions & 1 deletion synapse/replication/http/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,12 @@ def __init__(self, hs):

@staticmethod
async def _serialize_payload(
user_id, device_id, initial_display_name, is_guest, is_appservice_ghost
user_id,
device_id,
initial_display_name,
is_guest,
is_appservice_ghost,
should_issue_refresh_token,
):
"""
Args:
Expand All @@ -51,6 +56,7 @@ async def _serialize_payload(
"initial_display_name": initial_display_name,
"is_guest": is_guest,
"is_appservice_ghost": is_appservice_ghost,
"should_issue_refresh_token": should_issue_refresh_token,
}

async def _handle_request(self, request, user_id):
Expand All @@ -60,13 +66,15 @@ async def _handle_request(self, request, user_id):
initial_display_name = content["initial_display_name"]
is_guest = content["is_guest"]
is_appservice_ghost = content["is_appservice_ghost"]
should_issue_refresh_token = content["should_issue_refresh_token"]

res = await self.registration_handler.register_device_inner(
user_id,
device_id,
initial_display_name,
is_guest,
is_appservice_ghost=is_appservice_ghost,
should_issue_refresh_token=should_issue_refresh_token,
)

return 200, res
Expand Down
Loading