Skip to content

Commit

Permalink
OAuth2 refresh token support
Browse files Browse the repository at this point in the history
  • Loading branch information
mesemus committed May 14, 2024
1 parent 4aaa9c5 commit 00a3663
Show file tree
Hide file tree
Showing 11 changed files with 179 additions and 19 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#
# This file is part of Invenio.
# Copyright (C) 2016-2018 CERN.
#
# Invenio is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.

"""Add expires_at and refresh_token to remote token."""

import sqlalchemy as sa
import sqlalchemy_utils
from alembic import op

# revision identifiers, used by Alembic.
revision = "7def990b852e"
down_revision = "aaa265b0afa6"
branch_labels = ()
depends_on = ("aaa265b0afa6",)


def upgrade():
"""Upgrade database."""
op.add_column(
"oauthclient_remotetoken",
sa.Column("refresh_token", sqlalchemy_utils.EncryptedType(), nullable=True),
)
op.add_column(
"oauthclient_remotetoken", sa.Column("expires_at", sa.DateTime(), nullable=True)
)


def downgrade():
"""Downgrade database."""
op.drop_column("oauthclient_remotetoken", "expires_at")
op.drop_column("oauthclient_remotetoken", "refresh_token")
9 changes: 9 additions & 0 deletions invenio_oauthclient/alembic/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-
#
# This file is part of Invenio.
# Copyright (C) 2016-2018 CERN.
#
# Invenio is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.

"""Alembic migrations for Invenio-OAuthClient."""
4 changes: 2 additions & 2 deletions invenio_oauthclient/contrib/keycloak/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(
app_key=None,
icon=None,
scopes="openid",
**kwargs
**kwargs,
):
"""The constructor takes two arguments.
Expand All @@ -64,7 +64,7 @@ def __init__(
request_token_params={"scope": scopes},
access_token_url=access_token_url,
authorize_url=authorize_url,
**kwargs
**kwargs,
)

self._handlers = dict(
Expand Down
9 changes: 8 additions & 1 deletion invenio_oauthclient/handlers/authorized.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,14 @@ def extra_signup_handler(remote, form, *args, **kwargs):
user = _register_user(response, remote, account_info, form)

# Link account and set session data
token = token_setter(remote, oauth_token[0], secret=oauth_token[1], user=user)
token = token_setter(
remote,
oauth_token[0],
secret=oauth_token[1],
user=user,
refresh_token=oauth_token[2] if len(oauth_token) > 2 else None,
expires_at=oauth_token[3] if len(oauth_token) > 3 else None,
)
if token is None:
raise OAuthClientTokenNotSet()

Expand Down
53 changes: 47 additions & 6 deletions invenio_oauthclient/handlers/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# under the terms of the MIT License; see LICENSE file for more details.

"""Funcs to manage tokens."""

import datetime
from functools import partial

from flask import current_app, session
Expand Down Expand Up @@ -112,10 +112,31 @@ def oauth2_token_setter(remote, resp, token_type="", extra_data=None):
secret="",
token_type=token_type,
extra_data=extra_data,
refresh_token=resp.get("refresh_token"),
expires_at=make_expiration_time(resp.get("expires_in")),
)


def token_setter(remote, token, secret="", token_type="", extra_data=None, user=None):
def make_expiration_time(expires_in):
"""Make expiration time from expires_in.
:param expires_in: The time in seconds.
"""
if expires_in is None:
return None
return datetime.datetime.now() + datetime.timedelta(seconds=expires_in)


def token_setter(
remote,
token,
secret="",
token_type="",
extra_data=None,
user=None,
refresh_token=None,
expires_at=None,
):
"""Set token for user.
:param remote: The remote application.
Expand All @@ -127,7 +148,12 @@ def token_setter(remote, token, secret="", token_type="", extra_data=None, user=
:returns: A :class:`invenio_oauthclient.models.RemoteToken` instance or
``None``.
"""
session[token_session_key(remote.name)] = (token, secret)
session[token_session_key(remote.name)] = (
token,
secret,
refresh_token,
expires_at.isoformat() if expires_at else None,
)
user = user or current_user

# Save token if user is not anonymous (user exists but can be not active at
Expand All @@ -140,10 +166,17 @@ def token_setter(remote, token, secret="", token_type="", extra_data=None, user=
t = RemoteToken.get(uid, cid, token_type=token_type)

if t:
t.update_token(token, secret)
t.update_token(token, secret, refresh_token, expires_at)
else:
t = RemoteToken.create(
uid, cid, token, secret, token_type=token_type, extra_data=extra_data
uid,
cid,
token,
secret,
token_type=token_type,
extra_data=extra_data,
refresh_token=refresh_token,
expires_at=expires_at,
)
return t
return None
Expand Down Expand Up @@ -176,7 +209,15 @@ def token_getter(remote, token=""):
# Store token and secret in session
session[session_key] = remote_token.token()

return session.get(session_key, None)
ret = session.get(session_key, None)
if ret:
if len(ret) == 2:
# no refresh token nor expiration time
return ret[0], ret[1], None, None
if ret[3] is not None:
# refresh token and expiration time
return ret[0], ret[1], ret[2], datetime.datetime.fromisoformat(ret[3])
return ret


def token_delete(remote, token=""):
Expand Down
60 changes: 57 additions & 3 deletions invenio_oauthclient/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

"""Models for storing access tokens and links between users and remote apps."""

import datetime

from flask import current_app

# UserIdentity imported for backward compatibility. UserIdentity was originally
Expand Down Expand Up @@ -119,6 +121,14 @@ class RemoteToken(db.Model, Timestamp):
)
"""Access token to remote application."""

refresh_token = db.Column(
EncryptedType(type_in=db.Text, key=_secret_key), nullable=True
)
"""Refresh token to remote application."""

expires_at = db.Column(db.DateTime, nullable=True)
"""Access token expiration date."""

secret = db.Column(db.Text(), default="", nullable=False)
"""Used only by OAuth 1."""

Expand All @@ -130,6 +140,17 @@ class RemoteToken(db.Model, Timestamp):
)
"""SQLAlchemy relationship to RemoteAccount objects."""

@property
def is_expired(self):
"""Check if access token has expired."""
if not self.expires_at:
return False

leeway = current_app.config.get("OAUTHCLIENT_TOKEN_EXPIRES_LEEWAY", 10)
expiration_with_leeway = self.expires_at - datetime.timedelta(seconds=leeway)

return expiration_with_leeway < datetime.datetime.now()

def __repr__(self):
"""String representation for model."""
return (
Expand All @@ -141,18 +162,37 @@ def token(self):
"""Get token as expected by Flask-OAuthlib."""
return (self.access_token, self.secret)

def update_token(self, token, secret):
def update_token(self, token, secret, refresh_token=None, expires_at=None):
"""Update token with new values.
:param token: The token value.
:param secret: The secret key.
:param refresh_token: The refresh token
:param expires_at: Time when the access token expires
"""
if self.access_token != token or self.secret != secret:
if (
self.access_token != token
or self.secret != secret
or self.refresh_token != refresh_token
or self.expiration != expires_at
):
with db.session.begin_nested():
self.access_token = token
self.secret = secret
self.refresh_token = refresh_token
self.expires_at = expires_at
db.session.add(self)

def refresh_access_token(self):
"""Refresh the access token."""
if not self.refresh_token:
raise ValueError("No refresh token available")
from .handlers.refresh import refresh_access_token

access_token, refresh_token, secret, expires_at = refresh_access_token(self)
self.update_token(access_token, refresh_token, secret, expires_at)
db.session.commit()

@classmethod
def get(cls, user_id, client_id, token_type="", access_token=None):
"""Get RemoteToken for user.
Expand Down Expand Up @@ -197,7 +237,17 @@ def get_by_token(cls, client_id, access_token, token_type=""):
)

@classmethod
def create(cls, user_id, client_id, token, secret, token_type="", extra_data=None):
def create(
cls,
user_id,
client_id,
token,
secret,
token_type="",
extra_data=None,
refresh_token=None,
expires_at=None,
):
"""Create a new access token.
.. note:: Creates RemoteAccount as well if it does not exists.
Expand All @@ -209,6 +259,8 @@ def create(cls, user_id, client_id, token, secret, token_type="", extra_data=Non
:param token_type: The token type. (Default: ``''``)
:param extra_data: Extra data to set in the remote account if the
remote account doesn't exists. (Default: ``None``)
:param refresh_token: The refresh token.
:param expires_at: Expiration of the token
:returns: A :class:`invenio_oauthclient.models.RemoteToken` instance.
"""
Expand All @@ -228,6 +280,8 @@ def create(cls, user_id, client_id, token, secret, token_type="", extra_data=Non
remote_account=account,
access_token=token,
secret=secret,
refresh_token=refresh_token,
expires_at=expires_at,
)
db.session.add(token)
return token
Expand Down
3 changes: 1 addition & 2 deletions invenio_oauthclient/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
"""Utility methods."""

from flask import current_app, request, session
from flask_login import current_user
from flask_principal import RoleNeed, UserNeed
from flask_principal import RoleNeed
from invenio_db.utils import rebuild_encrypted_properties
from itsdangerous import TimedJSONWebSignatureSerializer
from uritools import uricompose, urisplit
Expand Down
5 changes: 5 additions & 0 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ def mock_remote_get(oauth, remote_app="test", data=None):
oauth.remote_apps[remote_app].get = MagicMock(return_value=data)


def mock_remote_http_request(oauth, remote_app="test", data=None):
"""Mock the oauth remote get response."""
oauth.remote_apps[remote_app].http_request = MagicMock(return_value=data)


def check_redirect_location(resp, loc):
"""Check response redirect location."""
assert resp._status_code == 302
Expand Down
2 changes: 1 addition & 1 deletion tests/test_base_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,4 @@ def test_token_getter(remote, models_fixture, app):
# Populated RemoteToken
RemoteToken.create(user.id, "testkey", "mytoken", "mysecret")
oauth_authenticate("dev", user)
assert token_getter(remote) == ("mytoken", "mysecret")
assert token_getter(remote) == ("mytoken", "mysecret", None, None)
9 changes: 7 additions & 2 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,12 @@ def test_token_getter_setter(views_fixture, monkeypatch):
# Assert if everything is as it should be.
from flask import session as flask_session

assert flask_session["oauth_token_full"] == ("test_access_token", "")
assert flask_session["oauth_token_full"] == (
"test_access_token",
"",
None,
None,
)

t = RemoteToken.get(1, "fullid")
assert t.remote_account.client_id == "fullid"
Expand Down Expand Up @@ -423,7 +428,7 @@ def test_token_getter_setter(views_fixture, monkeypatch):
assert RemoteToken.query.count() == 1

val = token_getter(app.extensions["oauthlib.client"].remote_apps["full"])
assert val == ("new_access_token", "")
assert val == ("new_access_token", "", None, None)

# Disconnect account
res = c.get(
Expand Down
9 changes: 7 additions & 2 deletions tests/test_views_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,12 @@ def test_token_getter_setter(app_rest, monkeypatch):
# Assert if everything is as it should be.
from flask import session as flask_session

assert flask_session["oauth_token_full"] == ("test_access_token", "")
assert flask_session["oauth_token_full"] == (
"test_access_token",
"",
None,
None,
)

t = RemoteToken.get(1, "fullid")
assert t.remote_account.client_id == "fullid"
Expand Down Expand Up @@ -417,7 +422,7 @@ def test_token_getter_setter(app_rest, monkeypatch):
assert RemoteToken.query.count() == 1

val = token_getter(app_rest.extensions["oauthlib.client"].remote_apps["full"])
assert val == ("new_access_token", "")
assert val == ("new_access_token", "", None, None)

# Disconnect account
res = c.get(
Expand Down

0 comments on commit 00a3663

Please sign in to comment.