Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor implementation of AadOAuth2UserService #32595

Merged
merged 13 commits into from
Dec 16, 2022
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,14 @@
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest;
import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserService;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService;
import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
import org.springframework.security.oauth2.core.AbstractOAuth2Token;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
Expand All @@ -42,21 +38,17 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static com.azure.spring.cloud.autoconfigure.aad.implementation.AadRestTemplateCreator.createOAuth2ErrorResponseHandledRestTemplate;
import static com.azure.spring.cloud.autoconfigure.aad.implementation.constants.Constants.DEFAULT_AUTHORITY_SET;

/**
* This implementation will retrieve group info of user from Microsoft Graph. Then map group to {@link
* GrantedAuthority}.
*
* @see OidcUserService
* @see OAuth2UserService
*/
public class AadOAuth2UserService implements OAuth2UserService<OidcUserRequest, OidcUser> {

private static final Logger LOGGER = LoggerFactory.getLogger(AadOAuth2UserService.class);

private final OidcUserService oidcUserService;
private final List<String> allowedGroupNames;
private final Set<String> allowedGroupIds;
private final GraphClient graphClient;
Expand All @@ -70,7 +62,7 @@ public class AadOAuth2UserService implements OAuth2UserService<OidcUserRequest,
* @param restTemplateBuilder the restTemplateBuilder
*/
public AadOAuth2UserService(AadAuthenticationProperties properties, RestTemplateBuilder restTemplateBuilder) {
this(properties, new GraphClient(properties, restTemplateBuilder), restTemplateBuilder);
this(properties, new GraphClient(properties, restTemplateBuilder));
}

/**
Expand All @@ -83,46 +75,77 @@ public AadOAuth2UserService(AadAuthenticationProperties properties, RestTemplate
public AadOAuth2UserService(AadAuthenticationProperties properties,
GraphClient graphClient,
RestTemplateBuilder restTemplateBuilder) {
this(properties, graphClient);
}

private AadOAuth2UserService(AadAuthenticationProperties properties,
GraphClient graphClient) {
allowedGroupNames = Optional.ofNullable(properties)
.map(AadAuthenticationProperties::getUserGroup)
.map(AadAuthenticationProperties.UserGroupProperties::getAllowedGroupNames)
.orElseGet(Collections::emptyList);
.map(AadAuthenticationProperties::getUserGroup)
.map(AadAuthenticationProperties.UserGroupProperties::getAllowedGroupNames)
.orElseGet(Collections::emptyList);
allowedGroupIds = Optional.ofNullable(properties)
.map(AadAuthenticationProperties::getUserGroup)
.map(AadAuthenticationProperties.UserGroupProperties::getAllowedGroupIds)
.orElseGet(Collections::emptySet);
DefaultOAuth2UserService oAuth2UserService = new DefaultOAuth2UserService();
oAuth2UserService.setRestOperations(createOAuth2ErrorResponseHandledRestTemplate(restTemplateBuilder));
this.oidcUserService = new OidcUserService();
this.oidcUserService.setOauth2UserService(oAuth2UserService);
.map(AadAuthenticationProperties::getUserGroup)
.map(AadAuthenticationProperties.UserGroupProperties::getAllowedGroupIds)
.orElseGet(Collections::emptySet);
this.graphClient = graphClient;
}

/**
* Returns an {@link OAuth2User} after obtaining the user attributes of the End-User
* from the UserInfo Endpoint.
* Returns a {@link DefaultOidcUser} instance.
* <p/>
*
* The {@link DefaultOidcUser} instance is constructed with {@link GrantedAuthority}, {@link OidcIdToken} and nameAttributeKey.
* <a href="https://learn.microsoft.com/azure/active-directory/develop/userinfo#consider-using-an-id-token-instead">Azure AD</a> suggests get userinfo from idToken instead from the UserInfo Endpoint,
* this implementation will not get userinfo from the UserInfo Endpoint. Calling {@link org.springframework.security.oauth2.core.oidc.user.OidcUser#getUserInfo()} with the return instance will return null.
*
* <p/>
*
* @param userRequest the user request
* @return an {@link OAuth2User}
* @throws OAuth2AuthenticationException if an error occurs while attempting to obtain
* the user attributes from the UserInfo Endpoint
*
* @return a {@link DefaultOidcUser} instance.
*
* @throws OAuth2AuthenticationException if an error occurs.
*/
@Override
public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2AuthenticationException {
// Delegate to the default implementation for loading a user
OidcUser oidcUser = oidcUserService.loadUser(userRequest);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This service is an extension for OidcUserService, if we remove the dependency of oidcUserService, it will not be a full Open ID Connect process that Spring Security implements. Some OAuth2AuthenticationExceptions will be ignored the new implementation will not throw OAuth2AuthenticationException .

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This service is an extension for OidcUserService

I think it's an implementation of OAuth2UserService.

if we remove the dependency of oidcUserService, it will not be a full Open ID Connect process that Spring Security implements.

Yes, there are differences between our implementation Spring Security .
What dou mean the Open ID Connect process?

Some OAuth2AuthenticationExceptions will be ignored the new implementation will not throw OAuth2AuthenticationException .

So, is there any problem?

OidcIdToken idToken = oidcUser.getIdToken();
Set<String> authorityStrings = new HashSet<>();
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
Assert.notNull(userRequest, "userRequest cannot be null");

ServletRequestAttributes attr = (ServletRequestAttributes) RequestContextHolder.currentRequestAttributes();
HttpSession session = attr.getRequest().getSession(true);
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();

if (authentication != null) {
LOGGER.debug("User {}'s authorities saved from session: {}.", authentication.getName(), authentication.getAuthorities());
return (DefaultOidcUser) session.getAttribute(DEFAULT_OIDC_USER);
}
Comment on lines 118 to 121
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not move this checking up to the top?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you make it more clear?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean the checking for authentication != null can be done early, like the below code:

public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2AuthenticationException {
        Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
        if (authentication != null) {
            LOGGER.debug("User {}'s authorities saved from session: {}.", authentication.getName(), authentication.getAuthorities());
            return (DefaultOidcUser) session.getAttribute(DEFAULT_OIDC_USER);
        }

        ServletRequestAttributes attr = (ServletRequestAttributes) RequestContextHolder.currentRequestAttributes();
        HttpSession session = attr.getRequest().getSession(true);
        // ...
        return defaultOidcUser;
    }

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image

We should get the session object first.


authorityStrings.addAll(extractRolesFromIdToken(idToken));
DefaultOidcUser defaultOidcUser = getUser(userRequest);
session.setAttribute(DEFAULT_OIDC_USER, defaultOidcUser);
return defaultOidcUser;
}

DefaultOidcUser getUser(OidcUserRequest userRequest) {
Set<SimpleGrantedAuthority> authorities = buildAuthorities(userRequest);
String nameAttributeKey = getNameAttributeKey(userRequest);
OidcIdToken idToken = userRequest.getIdToken();
DefaultOidcUser defaultOidcUser = new DefaultOidcUser(authorities, idToken, nameAttributeKey);
return defaultOidcUser;
}

private String getNameAttributeKey(OidcUserRequest userRequest) {
return Optional.of(userRequest)
.map(u -> u.getClientRegistration())
.map(u -> u.getProviderDetails())
.map(u -> u.getUserInfoEndpoint())
.map(u -> u.getUserNameAttributeName())
.filter(StringUtils::hasText)
.orElse(AadJwtClaimNames.NAME);
}

private Set<SimpleGrantedAuthority> buildAuthorities(OidcUserRequest userRequest) {
Set<String> authorityStrings = new HashSet<>();
authorityStrings.addAll(extractRolesFromIdToken(userRequest.getIdToken()));
authorityStrings.addAll(extractGroupRolesFromAccessToken(userRequest.getAccessToken()));
Set<SimpleGrantedAuthority> authorities = authorityStrings.stream()
.map(SimpleGrantedAuthority::new)
Expand All @@ -131,20 +154,7 @@ public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2Authenticatio
if (authorities.isEmpty()) {
authorities = DEFAULT_AUTHORITY_SET;
}
String nameAttributeKey =
Optional.of(userRequest)
.map(OAuth2UserRequest::getClientRegistration)
.map(ClientRegistration::getProviderDetails)
.map(ClientRegistration.ProviderDetails::getUserInfoEndpoint)
.map(ClientRegistration.ProviderDetails.UserInfoEndpoint::getUserNameAttributeName)
.filter(StringUtils::hasText)
.orElse(AadJwtClaimNames.NAME);
LOGGER.debug("User {}'s authorities extracted by id token and access token: {}.", oidcUser.getClaim(nameAttributeKey), authorities);
// Create a copy of oidcUser but use the mappedAuthorities instead
DefaultOidcUser defaultOidcUser = new DefaultOidcUser(authorities, idToken, nameAttributeKey);

session.setAttribute(DEFAULT_OIDC_USER, defaultOidcUser);
return defaultOidcUser;
return authorities;
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.azure.spring.cloud.autoconfigure.aad.implementation.webapp;

import com.azure.spring.cloud.autoconfigure.aad.implementation.constants.AuthorityPrefix;
import com.azure.spring.cloud.autoconfigure.aad.implementation.graph.GraphClient;
import com.azure.spring.cloud.autoconfigure.aad.implementation.graph.GroupInformation;
import com.azure.spring.cloud.autoconfigure.aad.properties.AadAuthenticationProperties;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AuthenticationMethod;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.core.oidc.OidcScopes;
import org.springframework.security.oauth2.core.oidc.StandardClaimNames;
import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.http.HttpSession;
import java.time.Duration;
import java.time.Instant;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

/**
* Tests for {@link AadOAuth2UserService}.
*/
class AadOAuth2UserServiceTest {
private ClientRegistration.Builder clientRegistrationBuilder;
private OidcIdToken idToken;
private AadOAuth2UserService aadOAuth2UserService;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems can be local variable

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think loadUserWithCustomAuthorities() is the special one.
Other test cases could leverage the code in setup().

private OAuth2AccessToken accessToken;
private Map<String, Object> idTokenClaims = new HashMap<>();
private GraphClient graphClient;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same for this graph client

private AadAuthenticationProperties properties;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same for these two

Copy link
Contributor Author

@backwind1233 backwind1233 Dec 15, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made a little change.


private static final String DEFAULT_OIDC_USER = "defaultOidcUser";


@BeforeEach
void setup() {
saragluna marked this conversation as resolved.
Show resolved Hide resolved

clientRegistrationBuilder = ClientRegistration
.withRegistrationId("registrationId")
.clientName("registrationId")
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.redirectUri("redirectUri")
.userInfoUri(null)
.clientId("cliendId")
.clientSecret("clientSecret")
.userInfoAuthenticationMethod(AuthenticationMethod.HEADER)
.authorizationUri("authorizationUri")
.tokenUri("tokenUri");

this.accessToken = TestOAuth2AccessTokens.scopes(OidcScopes.OPENID, OidcScopes.PROFILE);

idTokenClaims.put(IdTokenClaimNames.ISS, "https://provider.com");
idTokenClaims.put(IdTokenClaimNames.SUB, "subject1");
idTokenClaims.put(StandardClaimNames.NAME, "user1");
idTokenClaims.put(StandardClaimNames.EMAIL, "user1@example.com");

this.idToken = new OidcIdToken("access-token", Instant.MIN, Instant.MAX, idTokenClaims);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems like this idToken isn't used by each test case, we can consider making this instantiated in each test method instead.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think loadUserWithCustomAuthorities() is the special one.
Other test cases could leverage the code in setup().

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If a variable is not fit for all cases, we should narrow the scope to method. It's okay to have some duplication in the UT. Which will make each test case easy to read.


}

@Test
void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() {
aadOAuth2UserService = new AadOAuth2UserService(properties, graphClient, null);
assertThatIllegalArgumentException().isThrownBy(() -> this.aadOAuth2UserService.loadUser(null));
}

@Test
void loadUserFromSession() {
//given
ServletRequestAttributes mockAttributes = mock(ServletRequestAttributes.class, RETURNS_DEEP_STUBS);
DefaultOidcUser mockDefaultOidcUser = mock(DefaultOidcUser.class);
HttpSession mockHttpSession = mock(HttpSession.class);
when(mockHttpSession.getAttribute(DEFAULT_OIDC_USER)).thenReturn(mockDefaultOidcUser);
Authentication mockAuthentication = mock(Authentication.class);

when(mockAttributes.getRequest().getSession(true)).thenReturn(mockHttpSession);

RequestContextHolder.setRequestAttributes(mockAttributes);
SecurityContextHolder.getContext().setAuthentication(mockAuthentication);

aadOAuth2UserService = new AadOAuth2UserService(properties, graphClient, null);
ClientRegistration clientRegistration = this.clientRegistrationBuilder
.build();

// when
OidcUser user = aadOAuth2UserService
.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken));

// then
assertThat(user).isEqualTo(mockDefaultOidcUser);
}

@Test
void loadUserWithDefaultAuthority() {
aadOAuth2UserService = new AadOAuth2UserService(properties, graphClient, null);

ClientRegistration clientRegistration = this.clientRegistrationBuilder
.build();
OidcUser user = aadOAuth2UserService
.getUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken));

assertThat(user.getUserInfo()).isNull();
assertThat(user.getClaims()).isEqualTo(idTokenClaims);
assertThat(user.getAuthorities().size()).isEqualTo(1);
SimpleGrantedAuthority defaultGrantedAuthority = new SimpleGrantedAuthority(AuthorityPrefix.ROLE + "USER");
assertThat(user.getAuthorities().stream().findFirst().get()).isEqualTo(defaultGrantedAuthority);
}

@Test
void loadUserWithCustomAuthorities() {

idTokenClaims.put("roles", Stream.of("role1", "role2")
.collect(Collectors.toList()));

GroupInformation groupInformation = new GroupInformation();
groupInformation.setGroupsIds(Stream.of("groupId1", "groupId2")
.collect(Collectors.toSet()));
groupInformation.setGroupsNames(Stream.of("groupName1", "groupName2")
.collect(Collectors.toSet()));
graphClient = mock(GraphClient.class);
when(graphClient.getGroupInformation(anyString())).thenReturn(groupInformation);

properties = new AadAuthenticationProperties();
properties.getUserGroup().setAllowedGroupNames(Stream.of("groupName1", "groupName2")
.collect(Collectors.toList()));
properties.getUserGroup().setAllowedGroupIds(Stream.of("groupId1", "groupId2")
.collect(Collectors.toSet()));

aadOAuth2UserService = new AadOAuth2UserService(properties, graphClient, null);

ClientRegistration clientRegistration = this.clientRegistrationBuilder
.build();

OidcUser user = this.aadOAuth2UserService
.getUser(new OidcUserRequest(clientRegistration, this.accessToken,
new OidcIdToken("access-token", Instant.MIN, Instant.MAX, idTokenClaims)));

assertThat(user.getUserInfo()).isNull();
assertThat(user.getClaims()).isEqualTo(idTokenClaims);
assertThat(user.getAuthorities().size()).isEqualTo(6);
Set<SimpleGrantedAuthority> simpleGrantedAuthorities
= Stream.of(new SimpleGrantedAuthority("APPROLE_role1"),
new SimpleGrantedAuthority("APPROLE_role2"),
new SimpleGrantedAuthority("ROLE_groupId1"),
new SimpleGrantedAuthority("ROLE_groupId2"),
new SimpleGrantedAuthority("ROLE_groupName1"),
new SimpleGrantedAuthority("ROLE_groupName2"))
.collect(Collectors.toSet());
assertThat(user.getAuthorities()).isEqualTo(simpleGrantedAuthorities);
}

@Test
void loadUserWhenCustomUserNameAttributeNameThenGetNameReturnsCustomUserName() {
aadOAuth2UserService = new AadOAuth2UserService(properties, graphClient, null);

ClientRegistration clientRegistration = this.clientRegistrationBuilder
.userNameAttributeName(StandardClaimNames.EMAIL)
.build();

OidcUser user = this.aadOAuth2UserService
.getUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken));
assertThat(user.getName()).isEqualTo("user1@example.com");
}

@Test
void loadUserWithDefaultUserNameAttributeName() {
aadOAuth2UserService = new AadOAuth2UserService(properties, graphClient, null);

ClientRegistration clientRegistration = this.clientRegistrationBuilder
.build();

OidcUser user = this.aadOAuth2UserService
.getUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken));
assertThat(user.getName()).isEqualTo("user1");
}

static class TestOAuth2AccessTokens {

private TestOAuth2AccessTokens() {
}

public static OAuth2AccessToken noScopes() {
return new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "no-scopes", Instant.now(),
Instant.now().plus(Duration.ofDays(1)));
}

public static OAuth2AccessToken scopes(String... scopes) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need for public

return new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "scopes", Instant.now(),
Instant.now().plus(Duration.ofDays(1)), new HashSet<>(Arrays.asList(scopes)));
}

}

}