Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,11 @@ def auth_roles_mapping(self) -> dict[str, list[str]]:
"""The mapping of auth roles."""
return self.appbuilder.get_app.config["AUTH_ROLES_MAPPING"]

@property
def auth_user_registration_role_jmespath(self) -> str:
"""The JMESPATH role to use for user registration."""
return self.appbuilder.get_app.config["AUTH_USER_REGISTRATION_ROLE_JMESPATH"]

@property
def auth_username_ci(self):
"""Get the auth username for CI."""
Expand All @@ -729,6 +734,10 @@ def auth_role_admin(self):
"""Get the admin role."""
return self.appbuilder.get_app.config["AUTH_ROLE_ADMIN"]

@property
def oauth_whitelists(self):
return self.oauth_allow_list

def create_builtin_roles(self):
"""Return FAB builtin roles."""
return self.appbuilder.get_app.config.get("FAB_ROLES", {})
Expand Down Expand Up @@ -1933,6 +1942,100 @@ def auth_user_db(self, username, password):
log.info(LOGMSG_WAR_SEC_LOGIN_FAILED, username)
return None

def set_oauth_session(self, provider, oauth_response):
"""Set the current session with OAuth user secrets."""
# Get this provider key names for token_key and token_secret
token_key = self.get_oauth_token_key_name(provider)
token_secret = self.get_oauth_token_secret_name(provider)
# Save users token on encrypted session cookie
session["oauth"] = (
oauth_response[token_key],
oauth_response.get(token_secret, ""),
)
session["oauth_provider"] = provider

def get_oauth_token_key_name(self, provider):
"""
Return the token_key name for the oauth provider.

If none is configured defaults to oauth_token
this is configured using OAUTH_PROVIDERS and token_key key.
"""
for _provider in self.oauth_providers:
if _provider["name"] == provider:
return _provider.get("token_key", "oauth_token")

def get_oauth_token_secret_name(self, provider):
"""
Get the ``token_secret`` name for the oauth provider.

If none is configured, defaults to ``oauth_secret``. This is configured
using ``OAUTH_PROVIDERS`` and ``token_secret``.
"""
for _provider in self.oauth_providers:
if _provider["name"] == provider:
return _provider.get("token_secret", "oauth_token_secret")

def auth_user_oauth(self, userinfo):
"""
Authenticate user with OAuth.

:userinfo: dict with user information
(keys are the same as User model columns)
"""
# extract the username from `userinfo`
if "username" in userinfo:
username = userinfo["username"]
elif "email" in userinfo:
username = userinfo["email"]
else:
log.error("OAUTH userinfo does not have username or email %s", userinfo)
return None

# If username is empty, go away
if (username is None) or username == "":
return None

# Search the DB for this user
user = self.find_user(username=username)

# If user is not active, go away
if user and (not user.is_active):
return None

# If user is not registered, and not self-registration, go away
if (not user) and (not self.auth_user_registration):
return None

# Sync the user's roles
if user and self.auth_roles_sync_at_login:
user.roles = self._oauth_calculate_user_roles(userinfo)
log.debug("Calculated new roles for user=%r as: %s", username, user.roles)

# If the user is new, register them
if (not user) and self.auth_user_registration:
user = self.add_user(
username=username,
first_name=userinfo.get("first_name", ""),
last_name=userinfo.get("last_name", ""),
email=userinfo.get("email", "") or f"{username}@email.notfound",
role=self._oauth_calculate_user_roles(userinfo),
)
log.debug("New user registered: %s", user)

# If user registration failed, go away
if not user:
log.error("Error creating a new OAuth user %s", username)
return None

# LOGIN SUCCESS (only if user is now registered)
if user:
self._rotate_session_id()
self.update_user_auth_stat(user)
return user
else:
return None

def get_oauth_user_info(self, provider: str, resp: dict[str, Any]) -> dict[str, Any]:
"""
There are different OAuth APIs with different ways to retrieve user info.
Expand Down Expand Up @@ -2272,3 +2375,31 @@ def _cli_safe_flash(text: str, level: str) -> None:
flash(Markup(text), level)
else:
getattr(log, level)(text.replace("<br>", "\n").replace("<b>", "*").replace("</b>", "*"))

def _oauth_calculate_user_roles(self, userinfo) -> list[str]:
user_role_objects = set()

# apply AUTH_ROLES_MAPPING
if self.auth_roles_mapping:
user_role_keys = userinfo.get("role_keys", [])
user_role_objects.update(self.get_roles_from_keys(user_role_keys))

# apply AUTH_USER_REGISTRATION_ROLE
if self.auth_user_registration:
registration_role_name = self.auth_user_registration_role

# if AUTH_USER_REGISTRATION_ROLE_JMESPATH is set,
# use it for the registration role
if self.auth_user_registration_role_jmespath:
import jmespath

registration_role_name = jmespath.search(self.auth_user_registration_role_jmespath, userinfo)

# lookup registration role in flask db
fab_role = self.find_role(registration_role_name)
if fab_role:
user_role_objects.add(fab_role)
else:
log.warning("Can't find AUTH_USER_REGISTRATION role: %s", registration_role_name)

return list(user_role_objects)
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,10 @@ def get_url_for_index(self):
def get_url_for_login_with(self, next_url: str | None = None) -> str:
return get_auth_manager().get_url_login(next_url=next_url)

@property
def get_url_for_login(self):
return get_auth_manager().get_url_login()

def get_url_for_locale(self, lang):
return url_for(
f"{self.bm.locale_view.endpoint}.{self.bm.locale_view.default_view}",
Expand Down