Skip to content

Commit

Permalink
Refactor implementation of AadOAuth2UserService (#32595)
Browse files Browse the repository at this point in the history
  • Loading branch information
backwind1233 authored Dec 16, 2022
1 parent 3ebbd3b commit 71029d5
Show file tree
Hide file tree
Showing 2 changed files with 304 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,14 @@
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest;
import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserService;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService;
import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
import org.springframework.security.oauth2.core.AbstractOAuth2Token;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
Expand All @@ -42,21 +38,17 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static com.azure.spring.cloud.autoconfigure.aad.implementation.AadRestTemplateCreator.createOAuth2ErrorResponseHandledRestTemplate;
import static com.azure.spring.cloud.autoconfigure.aad.implementation.constants.Constants.DEFAULT_AUTHORITY_SET;

/**
* This implementation will retrieve group info of user from Microsoft Graph. Then map group to {@link
* GrantedAuthority}.
*
* @see OidcUserService
* @see OAuth2UserService
*/
public class AadOAuth2UserService implements OAuth2UserService<OidcUserRequest, OidcUser> {

private static final Logger LOGGER = LoggerFactory.getLogger(AadOAuth2UserService.class);

private final OidcUserService oidcUserService;
private final List<String> allowedGroupNames;
private final Set<String> allowedGroupIds;
private final GraphClient graphClient;
Expand All @@ -70,7 +62,7 @@ public class AadOAuth2UserService implements OAuth2UserService<OidcUserRequest,
* @param restTemplateBuilder the restTemplateBuilder
*/
public AadOAuth2UserService(AadAuthenticationProperties properties, RestTemplateBuilder restTemplateBuilder) {
this(properties, new GraphClient(properties, restTemplateBuilder), restTemplateBuilder);
this(properties, new GraphClient(properties, restTemplateBuilder));
}

/**
Expand All @@ -83,46 +75,77 @@ public AadOAuth2UserService(AadAuthenticationProperties properties, RestTemplate
public AadOAuth2UserService(AadAuthenticationProperties properties,
GraphClient graphClient,
RestTemplateBuilder restTemplateBuilder) {
this(properties, graphClient);
}

private AadOAuth2UserService(AadAuthenticationProperties properties,
GraphClient graphClient) {
allowedGroupNames = Optional.ofNullable(properties)
.map(AadAuthenticationProperties::getUserGroup)
.map(AadAuthenticationProperties.UserGroupProperties::getAllowedGroupNames)
.orElseGet(Collections::emptyList);
.map(AadAuthenticationProperties::getUserGroup)
.map(AadAuthenticationProperties.UserGroupProperties::getAllowedGroupNames)
.orElseGet(Collections::emptyList);
allowedGroupIds = Optional.ofNullable(properties)
.map(AadAuthenticationProperties::getUserGroup)
.map(AadAuthenticationProperties.UserGroupProperties::getAllowedGroupIds)
.orElseGet(Collections::emptySet);
DefaultOAuth2UserService oAuth2UserService = new DefaultOAuth2UserService();
oAuth2UserService.setRestOperations(createOAuth2ErrorResponseHandledRestTemplate(restTemplateBuilder));
this.oidcUserService = new OidcUserService();
this.oidcUserService.setOauth2UserService(oAuth2UserService);
.map(AadAuthenticationProperties::getUserGroup)
.map(AadAuthenticationProperties.UserGroupProperties::getAllowedGroupIds)
.orElseGet(Collections::emptySet);
this.graphClient = graphClient;
}

/**
* Returns an {@link OAuth2User} after obtaining the user attributes of the End-User
* from the UserInfo Endpoint.
* Returns a {@link DefaultOidcUser} instance.
* <p/>
*
* The {@link DefaultOidcUser} instance is constructed with {@link GrantedAuthority}, {@link OidcIdToken} and nameAttributeKey.
* <a href="https://learn.microsoft.com/azure/active-directory/develop/userinfo#consider-using-an-id-token-instead">Azure AD</a> suggests get userinfo from idToken instead from the UserInfo Endpoint,
* this implementation will not get userinfo from the UserInfo Endpoint. Calling {@link org.springframework.security.oauth2.core.oidc.user.OidcUser#getUserInfo()} with the return instance will return null.
*
* <p/>
*
* @param userRequest the user request
* @return an {@link OAuth2User}
* @throws OAuth2AuthenticationException if an error occurs while attempting to obtain
* the user attributes from the UserInfo Endpoint
*
* @return a {@link DefaultOidcUser} instance.
*
* @throws OAuth2AuthenticationException if an error occurs.
*/
@Override
public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2AuthenticationException {
// Delegate to the default implementation for loading a user
OidcUser oidcUser = oidcUserService.loadUser(userRequest);
OidcIdToken idToken = oidcUser.getIdToken();
Set<String> authorityStrings = new HashSet<>();
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
Assert.notNull(userRequest, "userRequest cannot be null");

ServletRequestAttributes attr = (ServletRequestAttributes) RequestContextHolder.currentRequestAttributes();
HttpSession session = attr.getRequest().getSession(true);
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();

if (authentication != null) {
LOGGER.debug("User {}'s authorities saved from session: {}.", authentication.getName(), authentication.getAuthorities());
return (DefaultOidcUser) session.getAttribute(DEFAULT_OIDC_USER);
}

authorityStrings.addAll(extractRolesFromIdToken(idToken));
DefaultOidcUser defaultOidcUser = getUser(userRequest);
session.setAttribute(DEFAULT_OIDC_USER, defaultOidcUser);
return defaultOidcUser;
}

DefaultOidcUser getUser(OidcUserRequest userRequest) {
Set<SimpleGrantedAuthority> authorities = buildAuthorities(userRequest);
String nameAttributeKey = getNameAttributeKey(userRequest);
OidcIdToken idToken = userRequest.getIdToken();
DefaultOidcUser defaultOidcUser = new DefaultOidcUser(authorities, idToken, nameAttributeKey);
return defaultOidcUser;
}

private String getNameAttributeKey(OidcUserRequest userRequest) {
return Optional.of(userRequest)
.map(u -> u.getClientRegistration())
.map(u -> u.getProviderDetails())
.map(u -> u.getUserInfoEndpoint())
.map(u -> u.getUserNameAttributeName())
.filter(StringUtils::hasText)
.orElse(AadJwtClaimNames.NAME);
}

private Set<SimpleGrantedAuthority> buildAuthorities(OidcUserRequest userRequest) {
Set<String> authorityStrings = new HashSet<>();
authorityStrings.addAll(extractRolesFromIdToken(userRequest.getIdToken()));
authorityStrings.addAll(extractGroupRolesFromAccessToken(userRequest.getAccessToken()));
Set<SimpleGrantedAuthority> authorities = authorityStrings.stream()
.map(SimpleGrantedAuthority::new)
Expand All @@ -131,20 +154,7 @@ public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2Authenticatio
if (authorities.isEmpty()) {
authorities = DEFAULT_AUTHORITY_SET;
}
String nameAttributeKey =
Optional.of(userRequest)
.map(OAuth2UserRequest::getClientRegistration)
.map(ClientRegistration::getProviderDetails)
.map(ClientRegistration.ProviderDetails::getUserInfoEndpoint)
.map(ClientRegistration.ProviderDetails.UserInfoEndpoint::getUserNameAttributeName)
.filter(StringUtils::hasText)
.orElse(AadJwtClaimNames.NAME);
LOGGER.debug("User {}'s authorities extracted by id token and access token: {}.", oidcUser.getClaim(nameAttributeKey), authorities);
// Create a copy of oidcUser but use the mappedAuthorities instead
DefaultOidcUser defaultOidcUser = new DefaultOidcUser(authorities, idToken, nameAttributeKey);

session.setAttribute(DEFAULT_OIDC_USER, defaultOidcUser);
return defaultOidcUser;
return authorities;
}

/**
Expand Down
Loading

0 comments on commit 71029d5

Please sign in to comment.