Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Store auth provider id in login tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
richvdh committed Feb 28, 2021
1 parent 3d4902b commit f2466b2
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 6 deletions.
17 changes: 15 additions & 2 deletions synapse/handlers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -1406,6 +1406,7 @@ async def start_sso_ui_auth(self, request: SynapseRequest, session_id: str) -> s
async def complete_sso_login(
self,
registered_user_id: str,
auth_provider_id: str,
request: Request,
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
Expand All @@ -1415,6 +1416,9 @@ async def complete_sso_login(
Args:
registered_user_id: The registered user ID to complete SSO login for.
auth_provider_id: The id of the SSO Identity provider that was used for
login. This will be stored in the login token for future tracking in
prometheus metrics.
request: The request to complete.
client_redirect_url: The URL to which to redirect the user at the end of the
process.
Expand All @@ -1436,6 +1440,7 @@ async def complete_sso_login(

self._complete_sso_login(
registered_user_id,
auth_provider_id,
request,
client_redirect_url,
extra_attributes,
Expand All @@ -1446,6 +1451,7 @@ async def complete_sso_login(
def _complete_sso_login(
self,
registered_user_id: str,
auth_provider_id: str,
request: Request,
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
Expand All @@ -1472,7 +1478,7 @@ def _complete_sso_login(

# Create a login token
login_token = self.macaroon_gen.generate_short_term_login_token(
registered_user_id
registered_user_id, auth_provider_id=auth_provider_id
)

# Append the login token to the original redirect URL (i.e. with its query
Expand Down Expand Up @@ -1578,13 +1584,20 @@ def generate_access_token(
return macaroon.serialize()

def generate_short_term_login_token(
self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
self,
user_id: str,
duration_in_ms: int = (2 * 60 * 1000),
auth_provider_id: Optional[str] = None,
) -> str:
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
now = self.hs.get_clock().time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
if auth_provider_id is not None:
macaroon.add_first_party_caveat(
"auth_provider_id = %s" % (auth_provider_id,)
)
return macaroon.serialize()

def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
Expand Down
2 changes: 2 additions & 0 deletions synapse/handlers/sso.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,7 @@ async def complete_sso_login_request(

await self._auth_handler.complete_sso_login(
user_id,
auth_provider_id,
request,
client_redirect_url,
extra_login_attributes,
Expand Down Expand Up @@ -886,6 +887,7 @@ async def register_sso_user(self, request: Request, session_id: str) -> None:

await self._auth_handler.complete_sso_login(
user_id,
session.auth_provider_id,
request,
session.client_redirect_url,
session.extra_login_attributes,
Expand Down
31 changes: 27 additions & 4 deletions synapse/module_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,11 +203,26 @@ def record_user_external_id(
)

def generate_short_term_login_token(
self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
self,
user_id: str,
duration_in_ms: int = (2 * 60 * 1000),
auth_provider_id: Optional[str] = None,
) -> str:
"""Generate a login token suitable for m.login.token authentication"""
"""Generate a login token suitable for m.login.token authentication
Args:
user_id: gives the ID of the user that the token is for
duration_in_ms: the time that the token will be valid for
auth_provider_id: the ID of the SSO IdP that the user used to authenticate
to get this token, if any. This is encoded in the token so that
/login can report stats on number of successful logins by IdP.
"""
return self._hs.get_macaroon_generator().generate_short_term_login_token(
user_id, duration_in_ms
user_id,
duration_in_ms,
auth_provider_id=auth_provider_id,
)

@defer.inlineCallbacks
Expand Down Expand Up @@ -276,6 +291,7 @@ def complete_sso_login(
"""
self._auth_handler._complete_sso_login(
registered_user_id,
"<unknown>",
request,
client_redirect_url,
)
Expand All @@ -286,6 +302,7 @@ async def complete_sso_login_async(
request: SynapseRequest,
client_redirect_url: str,
new_user: bool = False,
auth_provider_id: str = "<unknown>",
):
"""Complete a SSO login by redirecting the user to a page to confirm whether they
want their access token sent to `client_redirect_url`, or redirect them to that
Expand All @@ -299,9 +316,15 @@ async def complete_sso_login_async(
redirect them directly if whitelisted).
new_user: set to true to use wording for the consent appropriate to a user
who has just registered.
auth_provider_id: the ID of the SSO IdP which was used to log in. This
is used to track counts of sucessful logins by IdP.
"""
await self._auth_handler.complete_sso_login(
registered_user_id, request, client_redirect_url, new_user=new_user
registered_user_id,
auth_provider_id,
request,
client_redirect_url,
new_user=new_user,
)

@defer.inlineCallbacks
Expand Down
8 changes: 8 additions & 0 deletions tests/handlers/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,14 @@ def test_short_term_login_token_gives_user_id(self):
AuthError,
)

def test_short_term_login_token_gives_auth_provider(self):
token = self.macaroon_generator.generate_short_term_login_token(
"a_user", auth_provider_id="my_idp"
)
res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
self.assertEqual("a_user", res.user_id)
self.assertEqual("my_idp", res.auth_provider_id)

def test_short_term_login_token_cannot_replace_user_id(self):
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
macaroon = pymacaroons.Macaroon.deserialize(token)
Expand Down

0 comments on commit f2466b2

Please sign in to comment.