Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support generic oauth2 #2211

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions cps/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1048,6 +1048,17 @@ def _configuration_oauth_helper(to_save):
{"oauth_client_id": to_save["config_" + str(element['id']) + "_oauth_client_id"],
"oauth_client_secret": to_save["config_" + str(element['id']) + "_oauth_client_secret"],
"active": element["active"]})
if element['id'] == 3:
ub.session.query(ub.OAuthProvider).filter(ub.OAuthProvider.id == element['id']).update({
"oauth_base_url": to_save["config_" + str(element['id']) + "_oauth_base_url"],
"oauth_auth_url": to_save["config_" + str(element['id']) + "_oauth_auth_url"],
"oauth_token_url": to_save["config_" + str(element['id']) + "_oauth_token_url"],
"username_mapper": to_save["config_" + str(element['id']) + "_username_mapper"],
"email_mapper": to_save["config_" + str(element['id']) + "_email_mapper"],
"login_button": to_save["config_" + str(element['id']) + "_login_button"],
"scope": to_save["config_" + str(element['id']) + "_scope"],
})

return reboot_required


Expand Down
156 changes: 147 additions & 9 deletions cps/oauth_bb.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@
from flask import session, request, make_response, abort
from flask import Blueprint, flash, redirect, url_for
from flask_babel import gettext as _
from flask_dance.consumer import oauth_authorized, oauth_error
from flask_dance.consumer import oauth_authorized, oauth_error, OAuth2ConsumerBlueprint
from flask_dance.contrib.github import make_github_blueprint, github
from flask_dance.contrib.google import make_google_blueprint, google
from oauthlib.oauth2 import TokenExpiredError, InvalidGrantError
from flask_login import login_user, current_user, login_required
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.sql.expression import func, and_

from . import constants, logger, config, app, ub

Expand All @@ -45,6 +46,7 @@
oauthblueprints = []
oauth = Blueprint('oauth', __name__)
log = logger.create()
generic = None


def oauth_required(f):
Expand Down Expand Up @@ -95,6 +97,7 @@ def logout_oauth_user():
for oauth_key in oauth_check.keys():
if str(oauth_key) + '_oauth_user_id' in session:
session.pop(str(oauth_key) + '_oauth_user_id')
unlink_oauth(oauth_key)


def oauth_update_token(provider_id, token, provider_user_id):
Expand Down Expand Up @@ -206,8 +209,10 @@ def unlink_oauth(provider):
return redirect(url_for('web.profile'))

def generate_oauth_blueprints():
global generic

if not ub.session.query(ub.OAuthProvider).count():
for provider in ("github", "google"):
for provider in ("github", "google", "generic"):
oauthProvider = ub.OAuthProvider()
oauthProvider.provider_name = provider
oauthProvider.active = False
Expand All @@ -229,20 +234,52 @@ def generate_oauth_blueprints():
oauth_client_id=oauth_ids[1].oauth_client_id,
oauth_client_secret=oauth_ids[1].oauth_client_secret,
obtain_link='https://console.developers.google.com/apis/credentials')
ele3 = dict(provider_name='generic',
id=oauth_ids[2].id,
active=oauth_ids[2].active,
scope=oauth_ids[2].scope,
oauth_client_id=oauth_ids[2].oauth_client_id,
oauth_client_secret=oauth_ids[2].oauth_client_secret,
oauth_base_url=oauth_ids[2].oauth_base_url,
oauth_auth_url=oauth_ids[2].oauth_auth_url,
oauth_token_url=oauth_ids[2].oauth_token_url,
username_mapper=oauth_ids[2].username_mapper,
email_mapper=oauth_ids[2].email_mapper,
login_button=oauth_ids[2].login_button)
oauthblueprints.append(ele1)
oauthblueprints.append(ele2)
oauthblueprints.append(ele3)

for element in oauthblueprints:
if element['provider_name'] == 'github':
blueprint_func = make_github_blueprint
else:
elif element['provider_name'] == 'google':
blueprint_func = make_google_blueprint
blueprint = blueprint_func(
client_id=element['oauth_client_id'],
client_secret=element['oauth_client_secret'],
redirect_to="oauth."+element['provider_name']+"_login",
scope=element['scope']
)
else:
blueprint_func = OAuth2ConsumerBlueprint

if element['provider_name'] in ('github', 'google'):
blueprint = blueprint_func(
client_id=element['oauth_client_id'],
client_secret=element['oauth_client_secret'],
redirect_url="oauth."+element['provider_name']+"_login",
scope=element['scope']
)
else:
base_url = element.get('oauth_base_url') or ''
token_url = element.get('oauth_token_url') or ''
auth_url = element.get('oauth_auth_url') or ''
blueprint = blueprint_func(
"generic",
__name__,
client_id=element['oauth_client_id'],
client_secret=element['oauth_client_secret'],
base_url=base_url,
authorization_url=base_url + auth_url,
token_url=base_url + token_url,
redirect_to='oauth.'+element['provider_name']+'_login',
)
generic = blueprint
element['blueprint'] = blueprint
element['blueprint'].backend = OAuthBackend(ub.OAuth, ub.session, str(element['id']),
user=current_user, user_required=True)
Expand Down Expand Up @@ -291,6 +328,55 @@ def google_logged_in(blueprint, token):
return oauth_update_token(str(oauthblueprints[1]['id']), token, google_user_id)


@oauth_authorized.connect_via(oauthblueprints[2]['blueprint'])
def generic_logged_in(blueprint, token):
global generic

if not token:
flash(_(u"Failed to log in with generic OAuth provider."), category="error")
log.error("Failed to log in with generic OAuth2 provider")
return False

resp = blueprint.session.get(blueprint.base_url + "/protocol/openid-connect/userinfo")
if not resp.ok:
flash(_(u"Failed to fetch user info from generic OAuth2 provider."), category="error")
log.error("Failed to fetch user info from generic OAuth2 provider")
return False

username_mapper = oauthblueprints[2].get('username_mapper') or 'username'
email_mapper = oauthblueprints[2].get('email_mapper') or 'email'

generic_info = resp.json()
generic_user_email = str(generic_info[email_mapper])
generic_user_username = str(generic_info[username_mapper])

user = (
ub.session.query(ub.User)
.filter(and_(func.lower(ub.User.name) == generic_user_username,
func.lower(ub.User.email) == generic_user_email))
).first()

if user is None:
user = ub.User()
user.name = generic_user_username
user.email = generic_user_email
user.role = constants.ROLE_USER
ub.session.add(user)
ub.session_commit()

result = oauth_update_token(str(oauthblueprints[2]['id']), token, user.id)

query = ub.session.query(ub.OAuth).filter_by(
provider=str(oauthblueprints[2]['id']),
provider_user_id=user.id,
)
oauth_entry = query.first()
oauth_entry.user = user
ub.session_commit()

return result


# notify on OAuth provider error
@oauth_error.connect_via(oauthblueprints[0]['blueprint'])
def github_error(blueprint, error, error_description=None, error_uri=None):
Expand Down Expand Up @@ -319,6 +405,20 @@ def google_error(blueprint, error, error_description=None, error_uri=None):
flash(msg, category="error")


@oauth_error.connect_via(oauthblueprints[2]['blueprint'])
def generic_error(blueprint, error, error_description=None, error_uri=None):
msg = (
u"OAuth error from {name}! "
u"error={error} description={description} uri={uri}"
).format(
name=blueprint.name,
error=error,
description=error_description,
uri=error_uri,
) # ToDo: Translate
flash(msg, category="error")


@oauth.route('/link/github')
@oauth_required
def github_login():
Expand Down Expand Up @@ -365,3 +465,41 @@ def google_login():
@login_required
def google_login_unlink():
return unlink_oauth(oauthblueprints[1]['id'])


@oauth.route('/link/generic')
@oauth_required
def generic_login():
global generic

if not generic.session.authorized:
return redirect(url_for("generic.login"))
try:
resp = generic.session.get(generic.base_url + "/protocol/openid-connect/userinfo")
if resp.ok:
account_info_json = resp.json()

username_mapper = oauthblueprints[2].get('username_mapper') or 'username'
email_mapper = oauthblueprints[2].get('email_mapper') or 'email'

email = str(account_info_json[email_mapper])
username = str(account_info_json[username_mapper])

user = (
ub.session.query(ub.User)
.filter(and_(func.lower(ub.User.name) == username,
func.lower(ub.User.email) == email))
).first()

return bind_oauth_or_register(oauthblueprints[2]['id'], user.id, 'generic.login', 'generic')
flash(_(u"generic OAuth2 error, please retry later."), category="error")
log.error("generic OAuth2 error, please retry later")
except (InvalidGrantError, TokenExpiredError) as e:
log.error(e)
return redirect(url_for("generic.login"))


@oauth.route('/unlink/generic', methods=["GET"])
@login_required
def generic_login_unlink():
return unlink_oauth(oauthblueprints[2]['id'])
45 changes: 45 additions & 0 deletions cps/templates/config_edit.html
Original file line number Diff line number Diff line change
Expand Up @@ -293,9 +293,12 @@ <h4 class="text-center">{{_('Following Settings are Needed For User Import')}}</
{% if feature_support['oauth'] %}
<div data-related="login-settings-2">
{% for prov in provider %}
<h4> {{prov['provider_name']}} </h4>
{% if prov.obtain_link %}
<div class="form-group">
<a href="{{prov['obtain_link']}}" target="_blank">{{_('Obtain %(provider)s OAuth Credential', provider=prov['provider_name'])}}</a>
</div>
{% endif %}
<div class="form-group">
<label for="config_{{ prov['id'] }}_oauth_client_id">{{_('%(provider)s OAuth Client Id', provider=prov['provider_name'])}}</label>
<input type="text" class="form-control" id="config_{{ prov['id'] }}_oauth_client_id" name="config_{{ prov['id'] }}_oauth_client_id" value="{% if prov['oauth_client_id']%}{{ prov['oauth_client_id'] }}{% endif %}" autocomplete="off">
Expand All @@ -304,6 +307,48 @@ <h4 class="text-center">{{_('Following Settings are Needed For User Import')}}</
<label for="config_{{ prov['id'] }}_oauth_client_secret">{{_('%(provider)s OAuth Client Secret', provider=prov['provider_name'])}}</label>
<input type="text" class="form-control" id="config_{{ prov['id'] }}_oauth_client_secret" name="config_{{ prov['id'] }}_oauth_client_secret" value="{% if prov['oauth_client_secret']%}{{ prov['oauth_client_secret'] }}{% endif %}" autocomplete="off">
</div>
{% if 'scope' in prov and 'generic' == prov['provider_name'] %}
<div class="form-group">
<label for="config_{{ prov['id'] }}_scope">{{_('%(provider)s OAuth scope', provider=prov['provider_name'])}}</label>
<input type="text" class="form-control" id="config_{{ prov['id'] }}_scope" name="config_{{ prov['id'] }}_scope" value="{% if prov['scope']%}{{ prov['scope'] }}{% endif %}" autocomplete="off">
</div>
{% endif %}
{% if 'oauth_base_url' in prov %}
<div class="form-group">
<label for="config_{{ prov['id'] }}_oauth_base_url">{{_('%(provider)s OAuth Base URL', provider=prov['provider_name'])}}</label>
<input type="text" class="form-control" id="config_{{ prov['id'] }}_oauth_base_url" name="config_{{ prov['id'] }}_oauth_base_url" value="{% if prov['oauth_base_url']%}{{ prov['oauth_base_url'] }}{% endif %}" autocomplete="off">
</div>
{% endif %}
{% if 'oauth_auth_url' in prov %}
<div class="form-group">
<label for="config_{{ prov['id'] }}_oauth_auth_url">{{_('%(provider)s OAuth Auth URL (relative)', provider=prov['provider_name'])}}</label>
<input type="text" class="form-control" id="config_{{ prov['id'] }}_oauth_auth_url" name="config_{{ prov['id'] }}_oauth_auth_url" value="{% if prov['oauth_auth_url']%}{{ prov['oauth_auth_url'] }}{% endif %}" autocomplete="off">
</div>
{% endif %}
{% if 'oauth_token_url' in prov %}
<div class="form-group">
<label for="config_{{ prov['id'] }}_oauth_token_url">{{_('%(provider)s OAuth Token URL (relative)', provider=prov['provider_name'])}}</label>
<input type="text" class="form-control" id="config_{{ prov['id'] }}_oauth_token_url" name="config_{{ prov['id'] }}_oauth_token_url" value="{% if prov['oauth_token_url']%}{{ prov['oauth_token_url'] }}{% endif %}" autocomplete="off">
</div>
{% endif %}
{% if 'username_mapper' in prov %}
<div class="form-group">
<label for="config_{{ prov['id'] }}_username_mapper">{{_('%(provider)s OAuth Username mapper', provider=prov['provider_name'])}}</label>
<input type="text" class="form-control" id="config_{{ prov['id'] }}_username_mapper" name="config_{{ prov['id'] }}_username_mapper" value="{% if prov['username_mapper']%}{{ prov['username_mapper'] }}{% endif %}" autocomplete="off">
</div>
{% endif %}
{% if 'email_mapper' in prov %}
<div class="form-group">
<label for="config_{{ prov['id'] }}_email_mapper">{{_('%(provider)s OAuth Email mapper', provider=prov['provider_name'])}}</label>
<input type="text" class="form-control" id="config_{{ prov['id'] }}_email_mapper" name="config_{{ prov['id'] }}_email_mapper" value="{% if prov['email_mapper']%}{{ prov['email_mapper'] }}{% endif %}" autocomplete="off">
</div>
{% endif %}
{% if 'login_button' in prov %}
<div class="form-group">
<label for="config_{{ prov['id'] }}_login_button">{{_('%(provider)s OAuth Login Button', provider=prov['provider_name'])}}</label>
<input type="text" class="form-control" id="config_{{ prov['id'] }}_login_button" name="config_{{ prov['id'] }}_login_button" value="{% if prov['login_button']%}{{ prov['login_button'] }}{% endif %}" autocomplete="off">
</div>
{% endif %}
{% endfor %}
</div>
{% endif %}
Expand Down
3 changes: 3 additions & 0 deletions cps/templates/login.html
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ <h2 style="margin-top: 0">{{_('Login')}}</h2>
style="fill:#000000;"><g id="surface1"><path style=" fill:#FFC107;" d="M 43.609375 20.082031 L 42 20.082031 L 42 20 L 24 20 L 24 28 L 35.304688 28 C 33.652344 32.65625 29.222656 36 24 36 C 17.371094 36 12 30.628906 12 24 C 12 17.371094 17.371094 12 24 12 C 27.058594 12 29.84375 13.152344 31.960938 15.039063 L 37.617188 9.382813 C 34.046875 6.054688 29.269531 4 24 4 C 12.953125 4 4 12.953125 4 24 C 4 35.046875 12.953125 44 24 44 C 35.046875 44 44 35.046875 44 24 C 44 22.660156 43.863281 21.351563 43.609375 20.082031 Z "></path><path style=" fill:#FF3D00;" d="M 6.304688 14.691406 L 12.878906 19.511719 C 14.65625 15.109375 18.960938 12 24 12 C 27.058594 12 29.84375 13.152344 31.960938 15.039063 L 37.617188 9.382813 C 34.046875 6.054688 29.269531 4 24 4 C 16.316406 4 9.65625 8.335938 6.304688 14.691406 Z "></path><path style=" fill:#4CAF50;" d="M 24 44 C 29.164063 44 33.859375 42.023438 37.410156 38.808594 L 31.21875 33.570313 C 29.210938 35.089844 26.714844 36 24 36 C 18.796875 36 14.382813 32.683594 12.71875 28.054688 L 6.195313 33.078125 C 9.503906 39.554688 16.226563 44 24 44 Z "></path><path style=" fill:#1976D2;" d="M 43.609375 20.082031 L 42 20.082031 L 42 20 L 24 20 L 24 28 L 35.304688 28 C 34.511719 30.238281 33.070313 32.164063 31.214844 33.570313 C 31.21875 33.570313 31.21875 33.570313 31.21875 33.570313 L 37.410156 38.808594 C 36.972656 39.203125 44 34 44 24 C 44 22.660156 43.863281 21.351563 43.609375 20.082031 Z "></path></g></svg>
</a>
{% endif %}
{% if 3 in oauth_check %}
<a href="{{url_for('oauth.generic_login')}}" class="pull-right generic">Log in with <b>{{ login_button }}</b></a>
{% endif %}
{% endif %}
</form>
</div>
Expand Down
11 changes: 9 additions & 2 deletions cps/ub.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,13 @@ class OAuthProvider(Base):
provider_name = Column(String)
oauth_client_id = Column(String)
oauth_client_secret = Column(String)
oauth_base_url = Column(String)
oauth_auth_url = Column(String, default="/protocol/openid-connect/auth")
oauth_token_url = Column(String, default="/protocol/openid-connect/token")
scope = Column(String, default="openid profile email")
username_mapper = Column(String, default="preferred_username")
email_mapper = Column(String, default="email")
login_button = Column(String)
active = Column(Boolean)


Expand Down Expand Up @@ -688,13 +695,13 @@ def migrate_Database(session):
"kindle_mail VARCHAR(120),"
"locale VARCHAR(2),"
"sidebar_view INTEGER,"
"default_language VARCHAR(3),"
"default_language VARCHAR(3),"
"denied_tags VARCHAR,"
"allowed_tags VARCHAR,"
"denied_column_value VARCHAR,"
"allowed_column_value VARCHAR,"
"view_settings JSON,"
"kobo_only_shelves_sync SMALLINT,"
"kobo_only_shelves_sync SMALLINT,"
"UNIQUE (name),"
"UNIQUE (email))"))
conn.execute(text("INSERT INTO user_id(id, name, email, role, password, kindle_mail,locale,"
Expand Down
7 changes: 7 additions & 0 deletions cps/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -1599,11 +1599,18 @@ def login():
next_url = request.args.get('next', default=url_for("web.index"), type=str)
if url_for("web.logout") == next_url:
next_url = url_for("web.index")

login_button = "generic oauth2 provider"
if 3 in oauth_check:
from .oauth_bb import oauthblueprints
login_button = oauthblueprints[2].get('login_button') or login_button

return render_title_template('login.html',
title=_(u"Login"),
next_url=next_url,
config=config,
oauth_check=oauth_check,
login_button=login_button,
mail=config.get_mail_server_configured(), page="login")


Expand Down