Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor implementation of AadOAuth2UserService #32595

Merged
merged 13 commits into from
Dec 16, 2022
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,14 @@
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest;
import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserService;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService;
import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
import org.springframework.security.oauth2.core.AbstractOAuth2Token;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
Expand All @@ -42,21 +38,17 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static com.azure.spring.cloud.autoconfigure.aad.implementation.AadRestTemplateCreator.createOAuth2ErrorResponseHandledRestTemplate;
import static com.azure.spring.cloud.autoconfigure.aad.implementation.constants.Constants.DEFAULT_AUTHORITY_SET;

/**
* This implementation will retrieve group info of user from Microsoft Graph. Then map group to {@link
* GrantedAuthority}.
*
* @see OidcUserService
* @see OAuth2UserService
*/
public class AadOAuth2UserService implements OAuth2UserService<OidcUserRequest, OidcUser> {

private static final Logger LOGGER = LoggerFactory.getLogger(AadOAuth2UserService.class);

private final OidcUserService oidcUserService;
private final List<String> allowedGroupNames;
private final Set<String> allowedGroupIds;
private final GraphClient graphClient;
Expand All @@ -70,7 +62,7 @@ public class AadOAuth2UserService implements OAuth2UserService<OidcUserRequest,
* @param restTemplateBuilder the restTemplateBuilder
*/
public AadOAuth2UserService(AadAuthenticationProperties properties, RestTemplateBuilder restTemplateBuilder) {
this(properties, new GraphClient(properties, restTemplateBuilder), restTemplateBuilder);
this(properties, new GraphClient(properties, restTemplateBuilder));
}

/**
Expand All @@ -83,46 +75,77 @@ public AadOAuth2UserService(AadAuthenticationProperties properties, RestTemplate
public AadOAuth2UserService(AadAuthenticationProperties properties,
GraphClient graphClient,
RestTemplateBuilder restTemplateBuilder) {
this(properties, graphClient);
}

private AadOAuth2UserService(AadAuthenticationProperties properties,
GraphClient graphClient) {
allowedGroupNames = Optional.ofNullable(properties)
.map(AadAuthenticationProperties::getUserGroup)
.map(AadAuthenticationProperties.UserGroupProperties::getAllowedGroupNames)
.orElseGet(Collections::emptyList);
.map(AadAuthenticationProperties::getUserGroup)
.map(AadAuthenticationProperties.UserGroupProperties::getAllowedGroupNames)
.orElseGet(Collections::emptyList);
allowedGroupIds = Optional.ofNullable(properties)
.map(AadAuthenticationProperties::getUserGroup)
.map(AadAuthenticationProperties.UserGroupProperties::getAllowedGroupIds)
.orElseGet(Collections::emptySet);
DefaultOAuth2UserService oAuth2UserService = new DefaultOAuth2UserService();
oAuth2UserService.setRestOperations(createOAuth2ErrorResponseHandledRestTemplate(restTemplateBuilder));
this.oidcUserService = new OidcUserService();
this.oidcUserService.setOauth2UserService(oAuth2UserService);
.map(AadAuthenticationProperties::getUserGroup)
.map(AadAuthenticationProperties.UserGroupProperties::getAllowedGroupIds)
.orElseGet(Collections::emptySet);
this.graphClient = graphClient;
}

/**
* Returns an {@link OAuth2User} after obtaining the user attributes of the End-User
* from the UserInfo Endpoint.
* Returns a {@link DefaultOidcUser} instance.
* <p/>
*
* The {@link DefaultOidcUser} instance is constructed with {@link GrantedAuthority}, {@link OidcIdToken} and nameAttributeKey.
* <a href="https://learn.microsoft.com/azure/active-directory/develop/userinfo#consider-using-an-id-token-instead">Azure AD</a> suggests get userinfo from idToken instead from the UserInfo Endpoint,
* this implementation will not get userinfo from the UserInfo Endpoint. Calling {@link org.springframework.security.oauth2.core.oidc.user.OidcUser#getUserInfo()} with the return instance will return null.
*
* <p/>
*
* @param userRequest the user request
* @return an {@link OAuth2User}
* @throws OAuth2AuthenticationException if an error occurs while attempting to obtain
* the user attributes from the UserInfo Endpoint
*
* @return a {@link DefaultOidcUser} instance.
*
* @throws OAuth2AuthenticationException if an error occurs.
*/
@Override
public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2AuthenticationException {
// Delegate to the default implementation for loading a user
OidcUser oidcUser = oidcUserService.loadUser(userRequest);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This service is an extension for OidcUserService, if we remove the dependency of oidcUserService, it will not be a full Open ID Connect process that Spring Security implements. Some OAuth2AuthenticationExceptions will be ignored the new implementation will not throw OAuth2AuthenticationException .

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This service is an extension for OidcUserService

I think it's an implementation of OAuth2UserService.

if we remove the dependency of oidcUserService, it will not be a full Open ID Connect process that Spring Security implements.

Yes, there are differences between our implementation Spring Security .
What dou mean the Open ID Connect process?

Some OAuth2AuthenticationExceptions will be ignored the new implementation will not throw OAuth2AuthenticationException .

So, is there any problem?

OidcIdToken idToken = oidcUser.getIdToken();
Set<String> authorityStrings = new HashSet<>();
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
Assert.notNull(userRequest, "userRequest cannot be null");

ServletRequestAttributes attr = (ServletRequestAttributes) RequestContextHolder.currentRequestAttributes();
HttpSession session = attr.getRequest().getSession(true);
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();

if (authentication != null) {
LOGGER.debug("User {}'s authorities saved from session: {}.", authentication.getName(), authentication.getAuthorities());
return (DefaultOidcUser) session.getAttribute(DEFAULT_OIDC_USER);
}
Comment on lines 118 to 121
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not move this checking up to the top?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you make it more clear?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean the checking for authentication != null can be done early, like the below code:

public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2AuthenticationException {
        Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
        if (authentication != null) {
            LOGGER.debug("User {}'s authorities saved from session: {}.", authentication.getName(), authentication.getAuthorities());
            return (DefaultOidcUser) session.getAttribute(DEFAULT_OIDC_USER);
        }

        ServletRequestAttributes attr = (ServletRequestAttributes) RequestContextHolder.currentRequestAttributes();
        HttpSession session = attr.getRequest().getSession(true);
        // ...
        return defaultOidcUser;
    }

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image

We should get the session object first.


authorityStrings.addAll(extractRolesFromIdToken(idToken));
DefaultOidcUser defaultOidcUser = getUser(userRequest);
session.setAttribute(DEFAULT_OIDC_USER, defaultOidcUser);
return defaultOidcUser;
}

DefaultOidcUser getUser(OidcUserRequest userRequest) {
Set<SimpleGrantedAuthority> authorities = buildAuthorities(userRequest);
String nameAttributeKey = getNameAttributeKey(userRequest);
OidcIdToken idToken = userRequest.getIdToken();
DefaultOidcUser defaultOidcUser = new DefaultOidcUser(authorities, idToken, nameAttributeKey);
return defaultOidcUser;
}

private String getNameAttributeKey(OidcUserRequest userRequest) {
return Optional.of(userRequest)
.map(u -> u.getClientRegistration())
.map(u -> u.getProviderDetails())
.map(u -> u.getUserInfoEndpoint())
.map(u -> u.getUserNameAttributeName())
.filter(StringUtils::hasText)
.orElse(AadJwtClaimNames.NAME);
}

private Set<SimpleGrantedAuthority> buildAuthorities(OidcUserRequest userRequest) {
Set<String> authorityStrings = new HashSet<>();
authorityStrings.addAll(extractRolesFromIdToken(userRequest.getIdToken()));
authorityStrings.addAll(extractGroupRolesFromAccessToken(userRequest.getAccessToken()));
Set<SimpleGrantedAuthority> authorities = authorityStrings.stream()
.map(SimpleGrantedAuthority::new)
Expand All @@ -131,20 +154,7 @@ public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2Authenticatio
if (authorities.isEmpty()) {
authorities = DEFAULT_AUTHORITY_SET;
}
String nameAttributeKey =
Optional.of(userRequest)
.map(OAuth2UserRequest::getClientRegistration)
.map(ClientRegistration::getProviderDetails)
.map(ClientRegistration.ProviderDetails::getUserInfoEndpoint)
.map(ClientRegistration.ProviderDetails.UserInfoEndpoint::getUserNameAttributeName)
.filter(StringUtils::hasText)
.orElse(AadJwtClaimNames.NAME);
LOGGER.debug("User {}'s authorities extracted by id token and access token: {}.", oidcUser.getClaim(nameAttributeKey), authorities);
// Create a copy of oidcUser but use the mappedAuthorities instead
DefaultOidcUser defaultOidcUser = new DefaultOidcUser(authorities, idToken, nameAttributeKey);

session.setAttribute(DEFAULT_OIDC_USER, defaultOidcUser);
return defaultOidcUser;
return authorities;
}

/**
Expand Down
Loading