diff --git a/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py b/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py
index 48ea2d00614fd..0e855e57925d4 100644
--- a/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py
+++ b/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py
@@ -704,6 +704,11 @@ def auth_roles_mapping(self) -> dict[str, list[str]]:
"""The mapping of auth roles."""
return self.appbuilder.get_app.config["AUTH_ROLES_MAPPING"]
+ @property
+ def auth_user_registration_role_jmespath(self) -> str:
+ """The JMESPATH role to use for user registration."""
+ return self.appbuilder.get_app.config["AUTH_USER_REGISTRATION_ROLE_JMESPATH"]
+
@property
def auth_username_ci(self):
"""Get the auth username for CI."""
@@ -729,6 +734,10 @@ def auth_role_admin(self):
"""Get the admin role."""
return self.appbuilder.get_app.config["AUTH_ROLE_ADMIN"]
+ @property
+ def oauth_whitelists(self):
+ return self.oauth_allow_list
+
def create_builtin_roles(self):
"""Return FAB builtin roles."""
return self.appbuilder.get_app.config.get("FAB_ROLES", {})
@@ -1933,6 +1942,100 @@ def auth_user_db(self, username, password):
log.info(LOGMSG_WAR_SEC_LOGIN_FAILED, username)
return None
+ def set_oauth_session(self, provider, oauth_response):
+ """Set the current session with OAuth user secrets."""
+ # Get this provider key names for token_key and token_secret
+ token_key = self.get_oauth_token_key_name(provider)
+ token_secret = self.get_oauth_token_secret_name(provider)
+ # Save users token on encrypted session cookie
+ session["oauth"] = (
+ oauth_response[token_key],
+ oauth_response.get(token_secret, ""),
+ )
+ session["oauth_provider"] = provider
+
+ def get_oauth_token_key_name(self, provider):
+ """
+ Return the token_key name for the oauth provider.
+
+ If none is configured defaults to oauth_token
+ this is configured using OAUTH_PROVIDERS and token_key key.
+ """
+ for _provider in self.oauth_providers:
+ if _provider["name"] == provider:
+ return _provider.get("token_key", "oauth_token")
+
+ def get_oauth_token_secret_name(self, provider):
+ """
+ Get the ``token_secret`` name for the oauth provider.
+
+ If none is configured, defaults to ``oauth_secret``. This is configured
+ using ``OAUTH_PROVIDERS`` and ``token_secret``.
+ """
+ for _provider in self.oauth_providers:
+ if _provider["name"] == provider:
+ return _provider.get("token_secret", "oauth_token_secret")
+
+ def auth_user_oauth(self, userinfo):
+ """
+ Authenticate user with OAuth.
+
+ :userinfo: dict with user information
+ (keys are the same as User model columns)
+ """
+ # extract the username from `userinfo`
+ if "username" in userinfo:
+ username = userinfo["username"]
+ elif "email" in userinfo:
+ username = userinfo["email"]
+ else:
+ log.error("OAUTH userinfo does not have username or email %s", userinfo)
+ return None
+
+ # If username is empty, go away
+ if (username is None) or username == "":
+ return None
+
+ # Search the DB for this user
+ user = self.find_user(username=username)
+
+ # If user is not active, go away
+ if user and (not user.is_active):
+ return None
+
+ # If user is not registered, and not self-registration, go away
+ if (not user) and (not self.auth_user_registration):
+ return None
+
+ # Sync the user's roles
+ if user and self.auth_roles_sync_at_login:
+ user.roles = self._oauth_calculate_user_roles(userinfo)
+ log.debug("Calculated new roles for user=%r as: %s", username, user.roles)
+
+ # If the user is new, register them
+ if (not user) and self.auth_user_registration:
+ user = self.add_user(
+ username=username,
+ first_name=userinfo.get("first_name", ""),
+ last_name=userinfo.get("last_name", ""),
+ email=userinfo.get("email", "") or f"{username}@email.notfound",
+ role=self._oauth_calculate_user_roles(userinfo),
+ )
+ log.debug("New user registered: %s", user)
+
+ # If user registration failed, go away
+ if not user:
+ log.error("Error creating a new OAuth user %s", username)
+ return None
+
+ # LOGIN SUCCESS (only if user is now registered)
+ if user:
+ self._rotate_session_id()
+ self.update_user_auth_stat(user)
+ return user
+ else:
+ return None
+
def get_oauth_user_info(self, provider: str, resp: dict[str, Any]) -> dict[str, Any]:
"""
There are different OAuth APIs with different ways to retrieve user info.
@@ -2272,3 +2375,31 @@ def _cli_safe_flash(text: str, level: str) -> None:
flash(Markup(text), level)
else:
getattr(log, level)(text.replace("
", "\n").replace("", "*").replace("", "*"))
+
+ def _oauth_calculate_user_roles(self, userinfo) -> list[str]:
+ user_role_objects = set()
+
+ # apply AUTH_ROLES_MAPPING
+ if self.auth_roles_mapping:
+ user_role_keys = userinfo.get("role_keys", [])
+ user_role_objects.update(self.get_roles_from_keys(user_role_keys))
+
+ # apply AUTH_USER_REGISTRATION_ROLE
+ if self.auth_user_registration:
+ registration_role_name = self.auth_user_registration_role
+
+ # if AUTH_USER_REGISTRATION_ROLE_JMESPATH is set,
+ # use it for the registration role
+ if self.auth_user_registration_role_jmespath:
+ import jmespath
+
+ registration_role_name = jmespath.search(self.auth_user_registration_role_jmespath, userinfo)
+
+ # lookup registration role in flask db
+ fab_role = self.find_role(registration_role_name)
+ if fab_role:
+ user_role_objects.add(fab_role)
+ else:
+ log.warning("Can't find AUTH_USER_REGISTRATION role: %s", registration_role_name)
+
+ return list(user_role_objects)
diff --git a/providers/fab/src/airflow/providers/fab/www/extensions/init_appbuilder.py b/providers/fab/src/airflow/providers/fab/www/extensions/init_appbuilder.py
index c34018a643688..5c84f90e76153 100644
--- a/providers/fab/src/airflow/providers/fab/www/extensions/init_appbuilder.py
+++ b/providers/fab/src/airflow/providers/fab/www/extensions/init_appbuilder.py
@@ -533,6 +533,10 @@ def get_url_for_index(self):
def get_url_for_login_with(self, next_url: str | None = None) -> str:
return get_auth_manager().get_url_login(next_url=next_url)
+ @property
+ def get_url_for_login(self):
+ return get_auth_manager().get_url_login()
+
def get_url_for_locale(self, lang):
return url_for(
f"{self.bm.locale_view.endpoint}.{self.bm.locale_view.default_view}",