diff --git a/cps/admin.py b/cps/admin.py index f104aa29b8..3192eb9ba0 100644 --- a/cps/admin.py +++ b/cps/admin.py @@ -1048,6 +1048,17 @@ def _configuration_oauth_helper(to_save): {"oauth_client_id": to_save["config_" + str(element['id']) + "_oauth_client_id"], "oauth_client_secret": to_save["config_" + str(element['id']) + "_oauth_client_secret"], "active": element["active"]}) + if element['id'] == 3: + ub.session.query(ub.OAuthProvider).filter(ub.OAuthProvider.id == element['id']).update({ + "oauth_base_url": to_save["config_" + str(element['id']) + "_oauth_base_url"], + "oauth_auth_url": to_save["config_" + str(element['id']) + "_oauth_auth_url"], + "oauth_token_url": to_save["config_" + str(element['id']) + "_oauth_token_url"], + "username_mapper": to_save["config_" + str(element['id']) + "_username_mapper"], + "email_mapper": to_save["config_" + str(element['id']) + "_email_mapper"], + "login_button": to_save["config_" + str(element['id']) + "_login_button"], + "scope": to_save["config_" + str(element['id']) + "_scope"], + }) + return reboot_required diff --git a/cps/oauth_bb.py b/cps/oauth_bb.py index d9efd41e41..6ebd504968 100644 --- a/cps/oauth_bb.py +++ b/cps/oauth_bb.py @@ -26,12 +26,13 @@ from flask import session, request, make_response, abort from flask import Blueprint, flash, redirect, url_for from flask_babel import gettext as _ -from flask_dance.consumer import oauth_authorized, oauth_error +from flask_dance.consumer import oauth_authorized, oauth_error, OAuth2ConsumerBlueprint from flask_dance.contrib.github import make_github_blueprint, github from flask_dance.contrib.google import make_google_blueprint, google from oauthlib.oauth2 import TokenExpiredError, InvalidGrantError from flask_login import login_user, current_user, login_required from sqlalchemy.orm.exc import NoResultFound +from sqlalchemy.sql.expression import func, and_ from . import constants, logger, config, app, ub @@ -45,6 +46,7 @@ oauthblueprints = [] oauth = Blueprint('oauth', __name__) log = logger.create() +generic = None def oauth_required(f): @@ -95,6 +97,7 @@ def logout_oauth_user(): for oauth_key in oauth_check.keys(): if str(oauth_key) + '_oauth_user_id' in session: session.pop(str(oauth_key) + '_oauth_user_id') + unlink_oauth(oauth_key) def oauth_update_token(provider_id, token, provider_user_id): @@ -206,8 +209,10 @@ def unlink_oauth(provider): return redirect(url_for('web.profile')) def generate_oauth_blueprints(): + global generic + if not ub.session.query(ub.OAuthProvider).count(): - for provider in ("github", "google"): + for provider in ("github", "google", "generic"): oauthProvider = ub.OAuthProvider() oauthProvider.provider_name = provider oauthProvider.active = False @@ -229,20 +234,52 @@ def generate_oauth_blueprints(): oauth_client_id=oauth_ids[1].oauth_client_id, oauth_client_secret=oauth_ids[1].oauth_client_secret, obtain_link='https://console.developers.google.com/apis/credentials') + ele3 = dict(provider_name='generic', + id=oauth_ids[2].id, + active=oauth_ids[2].active, + scope=oauth_ids[2].scope, + oauth_client_id=oauth_ids[2].oauth_client_id, + oauth_client_secret=oauth_ids[2].oauth_client_secret, + oauth_base_url=oauth_ids[2].oauth_base_url, + oauth_auth_url=oauth_ids[2].oauth_auth_url, + oauth_token_url=oauth_ids[2].oauth_token_url, + username_mapper=oauth_ids[2].username_mapper, + email_mapper=oauth_ids[2].email_mapper, + login_button=oauth_ids[2].login_button) oauthblueprints.append(ele1) oauthblueprints.append(ele2) + oauthblueprints.append(ele3) for element in oauthblueprints: if element['provider_name'] == 'github': blueprint_func = make_github_blueprint - else: + elif element['provider_name'] == 'google': blueprint_func = make_google_blueprint - blueprint = blueprint_func( - client_id=element['oauth_client_id'], - client_secret=element['oauth_client_secret'], - redirect_to="oauth."+element['provider_name']+"_login", - scope=element['scope'] - ) + else: + blueprint_func = OAuth2ConsumerBlueprint + + if element['provider_name'] in ('github', 'google'): + blueprint = blueprint_func( + client_id=element['oauth_client_id'], + client_secret=element['oauth_client_secret'], + redirect_url="oauth."+element['provider_name']+"_login", + scope=element['scope'] + ) + else: + base_url = element.get('oauth_base_url') or '' + token_url = element.get('oauth_token_url') or '' + auth_url = element.get('oauth_auth_url') or '' + blueprint = blueprint_func( + "generic", + __name__, + client_id=element['oauth_client_id'], + client_secret=element['oauth_client_secret'], + base_url=base_url, + authorization_url=base_url + auth_url, + token_url=base_url + token_url, + redirect_to='oauth.'+element['provider_name']+'_login', + ) + generic = blueprint element['blueprint'] = blueprint element['blueprint'].backend = OAuthBackend(ub.OAuth, ub.session, str(element['id']), user=current_user, user_required=True) @@ -291,6 +328,55 @@ def google_logged_in(blueprint, token): return oauth_update_token(str(oauthblueprints[1]['id']), token, google_user_id) + @oauth_authorized.connect_via(oauthblueprints[2]['blueprint']) + def generic_logged_in(blueprint, token): + global generic + + if not token: + flash(_(u"Failed to log in with generic OAuth provider."), category="error") + log.error("Failed to log in with generic OAuth2 provider") + return False + + resp = blueprint.session.get(blueprint.base_url + "/protocol/openid-connect/userinfo") + if not resp.ok: + flash(_(u"Failed to fetch user info from generic OAuth2 provider."), category="error") + log.error("Failed to fetch user info from generic OAuth2 provider") + return False + + username_mapper = oauthblueprints[2].get('username_mapper') or 'username' + email_mapper = oauthblueprints[2].get('email_mapper') or 'email' + + generic_info = resp.json() + generic_user_email = str(generic_info[email_mapper]) + generic_user_username = str(generic_info[username_mapper]) + + user = ( + ub.session.query(ub.User) + .filter(and_(func.lower(ub.User.name) == generic_user_username, + func.lower(ub.User.email) == generic_user_email)) + ).first() + + if user is None: + user = ub.User() + user.name = generic_user_username + user.email = generic_user_email + user.role = constants.ROLE_USER + ub.session.add(user) + ub.session_commit() + + result = oauth_update_token(str(oauthblueprints[2]['id']), token, user.id) + + query = ub.session.query(ub.OAuth).filter_by( + provider=str(oauthblueprints[2]['id']), + provider_user_id=user.id, + ) + oauth_entry = query.first() + oauth_entry.user = user + ub.session_commit() + + return result + + # notify on OAuth provider error @oauth_error.connect_via(oauthblueprints[0]['blueprint']) def github_error(blueprint, error, error_description=None, error_uri=None): @@ -319,6 +405,20 @@ def google_error(blueprint, error, error_description=None, error_uri=None): flash(msg, category="error") + @oauth_error.connect_via(oauthblueprints[2]['blueprint']) + def generic_error(blueprint, error, error_description=None, error_uri=None): + msg = ( + u"OAuth error from {name}! " + u"error={error} description={description} uri={uri}" + ).format( + name=blueprint.name, + error=error, + description=error_description, + uri=error_uri, + ) # ToDo: Translate + flash(msg, category="error") + + @oauth.route('/link/github') @oauth_required def github_login(): @@ -365,3 +465,41 @@ def google_login(): @login_required def google_login_unlink(): return unlink_oauth(oauthblueprints[1]['id']) + + +@oauth.route('/link/generic') +@oauth_required +def generic_login(): + global generic + + if not generic.session.authorized: + return redirect(url_for("generic.login")) + try: + resp = generic.session.get(generic.base_url + "/protocol/openid-connect/userinfo") + if resp.ok: + account_info_json = resp.json() + + username_mapper = oauthblueprints[2].get('username_mapper') or 'username' + email_mapper = oauthblueprints[2].get('email_mapper') or 'email' + + email = str(account_info_json[email_mapper]) + username = str(account_info_json[username_mapper]) + + user = ( + ub.session.query(ub.User) + .filter(and_(func.lower(ub.User.name) == username, + func.lower(ub.User.email) == email)) + ).first() + + return bind_oauth_or_register(oauthblueprints[2]['id'], user.id, 'generic.login', 'generic') + flash(_(u"generic OAuth2 error, please retry later."), category="error") + log.error("generic OAuth2 error, please retry later") + except (InvalidGrantError, TokenExpiredError) as e: + log.error(e) + return redirect(url_for("generic.login")) + + +@oauth.route('/unlink/generic', methods=["GET"]) +@login_required +def generic_login_unlink(): + return unlink_oauth(oauthblueprints[2]['id']) diff --git a/cps/templates/config_edit.html b/cps/templates/config_edit.html index f61ca9a5ad..966916055a 100644 --- a/cps/templates/config_edit.html +++ b/cps/templates/config_edit.html @@ -293,9 +293,12 @@

{{_('Following Settings are Needed For User Import')}} {% for prov in provider %} +

{{prov['provider_name']}}

+ {% if prov.obtain_link %}
{{_('Obtain %(provider)s OAuth Credential', provider=prov['provider_name'])}}
+ {% endif %}
@@ -304,6 +307,48 @@

{{_('Following Settings are Needed For User Import')}}{{_('%(provider)s OAuth Client Secret', provider=prov['provider_name'])}}

+ {% if 'scope' in prov and 'generic' == prov['provider_name'] %} +
+ + +
+ {% endif %} + {% if 'oauth_base_url' in prov %} +
+ + +
+ {% endif %} + {% if 'oauth_auth_url' in prov %} +
+ + +
+ {% endif %} + {% if 'oauth_token_url' in prov %} +
+ + +
+ {% endif %} + {% if 'username_mapper' in prov %} +
+ + +
+ {% endif %} + {% if 'email_mapper' in prov %} +
+ + +
+ {% endif %} + {% if 'login_button' in prov %} +
+ + +
+ {% endif %} {% endfor %} {% endif %} diff --git a/cps/templates/login.html b/cps/templates/login.html index 708e63ea08..28e436e2c7 100644 --- a/cps/templates/login.html +++ b/cps/templates/login.html @@ -41,6 +41,9 @@

{{_('Login')}}

style="fill:#000000;"> {% endif %} + {% if 3 in oauth_check %} + Log in with {{ login_button }} + {% endif %} {% endif %} diff --git a/cps/ub.py b/cps/ub.py index 6b3c27ca5d..547c004175 100644 --- a/cps/ub.py +++ b/cps/ub.py @@ -250,6 +250,13 @@ class OAuthProvider(Base): provider_name = Column(String) oauth_client_id = Column(String) oauth_client_secret = Column(String) + oauth_base_url = Column(String) + oauth_auth_url = Column(String, default="/protocol/openid-connect/auth") + oauth_token_url = Column(String, default="/protocol/openid-connect/token") + scope = Column(String, default="openid profile email") + username_mapper = Column(String, default="preferred_username") + email_mapper = Column(String, default="email") + login_button = Column(String) active = Column(Boolean) @@ -688,13 +695,13 @@ def migrate_Database(session): "kindle_mail VARCHAR(120)," "locale VARCHAR(2)," "sidebar_view INTEGER," - "default_language VARCHAR(3)," + "default_language VARCHAR(3)," "denied_tags VARCHAR," "allowed_tags VARCHAR," "denied_column_value VARCHAR," "allowed_column_value VARCHAR," "view_settings JSON," - "kobo_only_shelves_sync SMALLINT," + "kobo_only_shelves_sync SMALLINT," "UNIQUE (name)," "UNIQUE (email))")) conn.execute(text("INSERT INTO user_id(id, name, email, role, password, kindle_mail,locale," diff --git a/cps/web.py b/cps/web.py index 82f864886b..5d1b1965f0 100644 --- a/cps/web.py +++ b/cps/web.py @@ -1599,11 +1599,18 @@ def login(): next_url = request.args.get('next', default=url_for("web.index"), type=str) if url_for("web.logout") == next_url: next_url = url_for("web.index") + + login_button = "generic oauth2 provider" + if 3 in oauth_check: + from .oauth_bb import oauthblueprints + login_button = oauthblueprints[2].get('login_button') or login_button + return render_title_template('login.html', title=_(u"Login"), next_url=next_url, config=config, oauth_check=oauth_check, + login_button=login_button, mail=config.get_mail_server_configured(), page="login")