diff --git a/cps/admin.py b/cps/admin.py index f104aa29b8..3192eb9ba0 100644 --- a/cps/admin.py +++ b/cps/admin.py @@ -1048,6 +1048,17 @@ def _configuration_oauth_helper(to_save): {"oauth_client_id": to_save["config_" + str(element['id']) + "_oauth_client_id"], "oauth_client_secret": to_save["config_" + str(element['id']) + "_oauth_client_secret"], "active": element["active"]}) + if element['id'] == 3: + ub.session.query(ub.OAuthProvider).filter(ub.OAuthProvider.id == element['id']).update({ + "oauth_base_url": to_save["config_" + str(element['id']) + "_oauth_base_url"], + "oauth_auth_url": to_save["config_" + str(element['id']) + "_oauth_auth_url"], + "oauth_token_url": to_save["config_" + str(element['id']) + "_oauth_token_url"], + "username_mapper": to_save["config_" + str(element['id']) + "_username_mapper"], + "email_mapper": to_save["config_" + str(element['id']) + "_email_mapper"], + "login_button": to_save["config_" + str(element['id']) + "_login_button"], + "scope": to_save["config_" + str(element['id']) + "_scope"], + }) + return reboot_required diff --git a/cps/oauth_bb.py b/cps/oauth_bb.py index d9efd41e41..6ebd504968 100644 --- a/cps/oauth_bb.py +++ b/cps/oauth_bb.py @@ -26,12 +26,13 @@ from flask import session, request, make_response, abort from flask import Blueprint, flash, redirect, url_for from flask_babel import gettext as _ -from flask_dance.consumer import oauth_authorized, oauth_error +from flask_dance.consumer import oauth_authorized, oauth_error, OAuth2ConsumerBlueprint from flask_dance.contrib.github import make_github_blueprint, github from flask_dance.contrib.google import make_google_blueprint, google from oauthlib.oauth2 import TokenExpiredError, InvalidGrantError from flask_login import login_user, current_user, login_required from sqlalchemy.orm.exc import NoResultFound +from sqlalchemy.sql.expression import func, and_ from . import constants, logger, config, app, ub @@ -45,6 +46,7 @@ oauthblueprints = [] oauth = Blueprint('oauth', __name__) log = logger.create() +generic = None def oauth_required(f): @@ -95,6 +97,7 @@ def logout_oauth_user(): for oauth_key in oauth_check.keys(): if str(oauth_key) + '_oauth_user_id' in session: session.pop(str(oauth_key) + '_oauth_user_id') + unlink_oauth(oauth_key) def oauth_update_token(provider_id, token, provider_user_id): @@ -206,8 +209,10 @@ def unlink_oauth(provider): return redirect(url_for('web.profile')) def generate_oauth_blueprints(): + global generic + if not ub.session.query(ub.OAuthProvider).count(): - for provider in ("github", "google"): + for provider in ("github", "google", "generic"): oauthProvider = ub.OAuthProvider() oauthProvider.provider_name = provider oauthProvider.active = False @@ -229,20 +234,52 @@ def generate_oauth_blueprints(): oauth_client_id=oauth_ids[1].oauth_client_id, oauth_client_secret=oauth_ids[1].oauth_client_secret, obtain_link='https://console.developers.google.com/apis/credentials') + ele3 = dict(provider_name='generic', + id=oauth_ids[2].id, + active=oauth_ids[2].active, + scope=oauth_ids[2].scope, + oauth_client_id=oauth_ids[2].oauth_client_id, + oauth_client_secret=oauth_ids[2].oauth_client_secret, + oauth_base_url=oauth_ids[2].oauth_base_url, + oauth_auth_url=oauth_ids[2].oauth_auth_url, + oauth_token_url=oauth_ids[2].oauth_token_url, + username_mapper=oauth_ids[2].username_mapper, + email_mapper=oauth_ids[2].email_mapper, + login_button=oauth_ids[2].login_button) oauthblueprints.append(ele1) oauthblueprints.append(ele2) + oauthblueprints.append(ele3) for element in oauthblueprints: if element['provider_name'] == 'github': blueprint_func = make_github_blueprint - else: + elif element['provider_name'] == 'google': blueprint_func = make_google_blueprint - blueprint = blueprint_func( - client_id=element['oauth_client_id'], - client_secret=element['oauth_client_secret'], - redirect_to="oauth."+element['provider_name']+"_login", - scope=element['scope'] - ) + else: + blueprint_func = OAuth2ConsumerBlueprint + + if element['provider_name'] in ('github', 'google'): + blueprint = blueprint_func( + client_id=element['oauth_client_id'], + client_secret=element['oauth_client_secret'], + redirect_url="oauth."+element['provider_name']+"_login", + scope=element['scope'] + ) + else: + base_url = element.get('oauth_base_url') or '' + token_url = element.get('oauth_token_url') or '' + auth_url = element.get('oauth_auth_url') or '' + blueprint = blueprint_func( + "generic", + __name__, + client_id=element['oauth_client_id'], + client_secret=element['oauth_client_secret'], + base_url=base_url, + authorization_url=base_url + auth_url, + token_url=base_url + token_url, + redirect_to='oauth.'+element['provider_name']+'_login', + ) + generic = blueprint element['blueprint'] = blueprint element['blueprint'].backend = OAuthBackend(ub.OAuth, ub.session, str(element['id']), user=current_user, user_required=True) @@ -291,6 +328,55 @@ def google_logged_in(blueprint, token): return oauth_update_token(str(oauthblueprints[1]['id']), token, google_user_id) + @oauth_authorized.connect_via(oauthblueprints[2]['blueprint']) + def generic_logged_in(blueprint, token): + global generic + + if not token: + flash(_(u"Failed to log in with generic OAuth provider."), category="error") + log.error("Failed to log in with generic OAuth2 provider") + return False + + resp = blueprint.session.get(blueprint.base_url + "/protocol/openid-connect/userinfo") + if not resp.ok: + flash(_(u"Failed to fetch user info from generic OAuth2 provider."), category="error") + log.error("Failed to fetch user info from generic OAuth2 provider") + return False + + username_mapper = oauthblueprints[2].get('username_mapper') or 'username' + email_mapper = oauthblueprints[2].get('email_mapper') or 'email' + + generic_info = resp.json() + generic_user_email = str(generic_info[email_mapper]) + generic_user_username = str(generic_info[username_mapper]) + + user = ( + ub.session.query(ub.User) + .filter(and_(func.lower(ub.User.name) == generic_user_username, + func.lower(ub.User.email) == generic_user_email)) + ).first() + + if user is None: + user = ub.User() + user.name = generic_user_username + user.email = generic_user_email + user.role = constants.ROLE_USER + ub.session.add(user) + ub.session_commit() + + result = oauth_update_token(str(oauthblueprints[2]['id']), token, user.id) + + query = ub.session.query(ub.OAuth).filter_by( + provider=str(oauthblueprints[2]['id']), + provider_user_id=user.id, + ) + oauth_entry = query.first() + oauth_entry.user = user + ub.session_commit() + + return result + + # notify on OAuth provider error @oauth_error.connect_via(oauthblueprints[0]['blueprint']) def github_error(blueprint, error, error_description=None, error_uri=None): @@ -319,6 +405,20 @@ def google_error(blueprint, error, error_description=None, error_uri=None): flash(msg, category="error") + @oauth_error.connect_via(oauthblueprints[2]['blueprint']) + def generic_error(blueprint, error, error_description=None, error_uri=None): + msg = ( + u"OAuth error from {name}! " + u"error={error} description={description} uri={uri}" + ).format( + name=blueprint.name, + error=error, + description=error_description, + uri=error_uri, + ) # ToDo: Translate + flash(msg, category="error") + + @oauth.route('/link/github') @oauth_required def github_login(): @@ -365,3 +465,41 @@ def google_login(): @login_required def google_login_unlink(): return unlink_oauth(oauthblueprints[1]['id']) + + +@oauth.route('/link/generic') +@oauth_required +def generic_login(): + global generic + + if not generic.session.authorized: + return redirect(url_for("generic.login")) + try: + resp = generic.session.get(generic.base_url + "/protocol/openid-connect/userinfo") + if resp.ok: + account_info_json = resp.json() + + username_mapper = oauthblueprints[2].get('username_mapper') or 'username' + email_mapper = oauthblueprints[2].get('email_mapper') or 'email' + + email = str(account_info_json[email_mapper]) + username = str(account_info_json[username_mapper]) + + user = ( + ub.session.query(ub.User) + .filter(and_(func.lower(ub.User.name) == username, + func.lower(ub.User.email) == email)) + ).first() + + return bind_oauth_or_register(oauthblueprints[2]['id'], user.id, 'generic.login', 'generic') + flash(_(u"generic OAuth2 error, please retry later."), category="error") + log.error("generic OAuth2 error, please retry later") + except (InvalidGrantError, TokenExpiredError) as e: + log.error(e) + return redirect(url_for("generic.login")) + + +@oauth.route('/unlink/generic', methods=["GET"]) +@login_required +def generic_login_unlink(): + return unlink_oauth(oauthblueprints[2]['id']) diff --git a/cps/templates/config_edit.html b/cps/templates/config_edit.html index f61ca9a5ad..966916055a 100644 --- a/cps/templates/config_edit.html +++ b/cps/templates/config_edit.html @@ -293,9 +293,12 @@