diff --git a/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py b/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py index 48ea2d00614fd..0e855e57925d4 100644 --- a/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py +++ b/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py @@ -704,6 +704,11 @@ def auth_roles_mapping(self) -> dict[str, list[str]]: """The mapping of auth roles.""" return self.appbuilder.get_app.config["AUTH_ROLES_MAPPING"] + @property + def auth_user_registration_role_jmespath(self) -> str: + """The JMESPATH role to use for user registration.""" + return self.appbuilder.get_app.config["AUTH_USER_REGISTRATION_ROLE_JMESPATH"] + @property def auth_username_ci(self): """Get the auth username for CI.""" @@ -729,6 +734,10 @@ def auth_role_admin(self): """Get the admin role.""" return self.appbuilder.get_app.config["AUTH_ROLE_ADMIN"] + @property + def oauth_whitelists(self): + return self.oauth_allow_list + def create_builtin_roles(self): """Return FAB builtin roles.""" return self.appbuilder.get_app.config.get("FAB_ROLES", {}) @@ -1933,6 +1942,100 @@ def auth_user_db(self, username, password): log.info(LOGMSG_WAR_SEC_LOGIN_FAILED, username) return None + def set_oauth_session(self, provider, oauth_response): + """Set the current session with OAuth user secrets.""" + # Get this provider key names for token_key and token_secret + token_key = self.get_oauth_token_key_name(provider) + token_secret = self.get_oauth_token_secret_name(provider) + # Save users token on encrypted session cookie + session["oauth"] = ( + oauth_response[token_key], + oauth_response.get(token_secret, ""), + ) + session["oauth_provider"] = provider + + def get_oauth_token_key_name(self, provider): + """ + Return the token_key name for the oauth provider. + + If none is configured defaults to oauth_token + this is configured using OAUTH_PROVIDERS and token_key key. + """ + for _provider in self.oauth_providers: + if _provider["name"] == provider: + return _provider.get("token_key", "oauth_token") + + def get_oauth_token_secret_name(self, provider): + """ + Get the ``token_secret`` name for the oauth provider. + + If none is configured, defaults to ``oauth_secret``. This is configured + using ``OAUTH_PROVIDERS`` and ``token_secret``. + """ + for _provider in self.oauth_providers: + if _provider["name"] == provider: + return _provider.get("token_secret", "oauth_token_secret") + + def auth_user_oauth(self, userinfo): + """ + Authenticate user with OAuth. + + :userinfo: dict with user information + (keys are the same as User model columns) + """ + # extract the username from `userinfo` + if "username" in userinfo: + username = userinfo["username"] + elif "email" in userinfo: + username = userinfo["email"] + else: + log.error("OAUTH userinfo does not have username or email %s", userinfo) + return None + + # If username is empty, go away + if (username is None) or username == "": + return None + + # Search the DB for this user + user = self.find_user(username=username) + + # If user is not active, go away + if user and (not user.is_active): + return None + + # If user is not registered, and not self-registration, go away + if (not user) and (not self.auth_user_registration): + return None + + # Sync the user's roles + if user and self.auth_roles_sync_at_login: + user.roles = self._oauth_calculate_user_roles(userinfo) + log.debug("Calculated new roles for user=%r as: %s", username, user.roles) + + # If the user is new, register them + if (not user) and self.auth_user_registration: + user = self.add_user( + username=username, + first_name=userinfo.get("first_name", ""), + last_name=userinfo.get("last_name", ""), + email=userinfo.get("email", "") or f"{username}@email.notfound", + role=self._oauth_calculate_user_roles(userinfo), + ) + log.debug("New user registered: %s", user) + + # If user registration failed, go away + if not user: + log.error("Error creating a new OAuth user %s", username) + return None + + # LOGIN SUCCESS (only if user is now registered) + if user: + self._rotate_session_id() + self.update_user_auth_stat(user) + return user + else: + return None + def get_oauth_user_info(self, provider: str, resp: dict[str, Any]) -> dict[str, Any]: """ There are different OAuth APIs with different ways to retrieve user info. @@ -2272,3 +2375,31 @@ def _cli_safe_flash(text: str, level: str) -> None: flash(Markup(text), level) else: getattr(log, level)(text.replace("
", "\n").replace("", "*").replace("", "*")) + + def _oauth_calculate_user_roles(self, userinfo) -> list[str]: + user_role_objects = set() + + # apply AUTH_ROLES_MAPPING + if self.auth_roles_mapping: + user_role_keys = userinfo.get("role_keys", []) + user_role_objects.update(self.get_roles_from_keys(user_role_keys)) + + # apply AUTH_USER_REGISTRATION_ROLE + if self.auth_user_registration: + registration_role_name = self.auth_user_registration_role + + # if AUTH_USER_REGISTRATION_ROLE_JMESPATH is set, + # use it for the registration role + if self.auth_user_registration_role_jmespath: + import jmespath + + registration_role_name = jmespath.search(self.auth_user_registration_role_jmespath, userinfo) + + # lookup registration role in flask db + fab_role = self.find_role(registration_role_name) + if fab_role: + user_role_objects.add(fab_role) + else: + log.warning("Can't find AUTH_USER_REGISTRATION role: %s", registration_role_name) + + return list(user_role_objects) diff --git a/providers/fab/src/airflow/providers/fab/www/extensions/init_appbuilder.py b/providers/fab/src/airflow/providers/fab/www/extensions/init_appbuilder.py index c34018a643688..5c84f90e76153 100644 --- a/providers/fab/src/airflow/providers/fab/www/extensions/init_appbuilder.py +++ b/providers/fab/src/airflow/providers/fab/www/extensions/init_appbuilder.py @@ -533,6 +533,10 @@ def get_url_for_index(self): def get_url_for_login_with(self, next_url: str | None = None) -> str: return get_auth_manager().get_url_login(next_url=next_url) + @property + def get_url_for_login(self): + return get_auth_manager().get_url_login() + def get_url_for_locale(self, lang): return url_for( f"{self.bm.locale_view.endpoint}.{self.bm.locale_view.default_view}",