Skip to content

Commit

Permalink
In Oauth2UserService, append authorities instead of override authorit…
Browse files Browse the repository at this point in the history
…ies (#17838)

* In XxxOAuth2UserService, append authorities instead of override authorities.
  • Loading branch information
Rujun Chen authored Nov 26, 2020
1 parent 3986906 commit edf5460
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 177 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,6 @@

import com.azure.spring.autoconfigure.aad.AADAuthenticationProperties;
import com.azure.spring.autoconfigure.aad.AADTokenClaim;
import com.azure.spring.autoconfigure.aad.JacksonObjectMapperFactory;
import com.azure.spring.autoconfigure.aad.Membership;
import com.azure.spring.autoconfigure.aad.Memberships;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.nimbusds.oauth2.sdk.http.HTTPResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest;
Expand All @@ -26,131 +16,63 @@
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
import org.springframework.util.StringUtils;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.HttpURLConnection;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.LinkedHashSet;
import java.util.Collections;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

import static com.azure.spring.autoconfigure.aad.Constants.DEFAULT_AUTHORITY_SET;
import static com.azure.spring.autoconfigure.aad.Constants.ROLE_PREFIX;

/**
* This implementation will retrieve group info of user from Microsoft Graph and map groups to {@link
* GrantedAuthority}.
*/
public class AzureActiveDirectoryOAuth2UserService implements OAuth2UserService<OidcUserRequest, OidcUser> {
private static final Logger LOGGER = LoggerFactory.getLogger(AzureActiveDirectoryOAuth2UserService.class);

private final OidcUserService oidcUserService;
private final AADAuthenticationProperties properties;
private final GraphClient graphClient;

public AzureActiveDirectoryOAuth2UserService(
AADAuthenticationProperties properties
) {
this.properties = properties;
this.oidcUserService = new OidcUserService();
this.graphClient = new GraphClient(properties);
}

@Override
public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2AuthenticationException {
// Delegate to the default implementation for loading a user
OidcUser oidcUser = oidcUserService.loadUser(userRequest);
Set<SimpleGrantedAuthority> authorities =
Optional.of(userRequest)
.map(OAuth2UserRequest::getAccessToken)
.map(AbstractOAuth2Token::getTokenValue)
.map(this::getGroups)
.map(this::toGrantedAuthoritySet)
.filter(g -> !g.isEmpty())
.orElse(DEFAULT_AUTHORITY_SET);
Set<String> groups = Optional.of(userRequest)
.map(OAuth2UserRequest::getAccessToken)
.map(AbstractOAuth2Token::getTokenValue)
.map(graphClient::getGroupsFromGraph)
.orElseGet(Collections::emptySet);
Set<String> groupRoles = groups.stream()
.filter(properties::isAllowedGroup)
.map(group -> ROLE_PREFIX + group)
.collect(Collectors.toSet());
Set<String> allRoles = oidcUser.getAuthorities()
.stream()
.map(GrantedAuthority::getAuthority)
.collect(Collectors.toSet());
allRoles.addAll(groupRoles);
Set<SimpleGrantedAuthority> authorities = allRoles.stream()
.map(SimpleGrantedAuthority::new)
.collect(Collectors.toSet());
String nameAttributeKey =
Optional.of(userRequest)
.map(OAuth2UserRequest::getClientRegistration)
.map(ClientRegistration::getProviderDetails)
.map(ClientRegistration.ProviderDetails::getUserInfoEndpoint)
.map(ClientRegistration.ProviderDetails.UserInfoEndpoint::getUserNameAttributeName)
.filter(s -> !s.isEmpty())
.filter(StringUtils::hasText)
.orElse(AADTokenClaim.NAME);
// Create a copy of oidcUser but use the mappedAuthorities instead
return new DefaultOidcUser(authorities, oidcUser.getIdToken(), nameAttributeKey);
}

public Set<SimpleGrantedAuthority> toGrantedAuthoritySet(final Set<String> groups) {
Set<SimpleGrantedAuthority> grantedAuthoritySet =
groups.stream()
.filter(properties::isAllowedGroup)
.map(group -> new SimpleGrantedAuthority(ROLE_PREFIX + group))
.collect(Collectors.toSet());
return Optional.of(grantedAuthoritySet)
.filter(g -> !g.isEmpty())
.orElse(DEFAULT_AUTHORITY_SET);
}

public Set<String> getGroups(String accessToken) {
final Set<String> groups = new LinkedHashSet<>();
final ObjectMapper objectMapper = JacksonObjectMapperFactory.getInstance();
String aadMembershipRestUri = properties.getGraphMembershipUri();
while (aadMembershipRestUri != null) {
Memberships memberships;
try {
String membershipsJson = getUserMemberships(accessToken, aadMembershipRestUri);
memberships = objectMapper.readValue(membershipsJson, Memberships.class);
} catch (IOException ioException) {
LOGGER.error("Can not get group information from graph server.", ioException);
break;
}
memberships.getValue()
.stream()
.filter(this::isGroupObject)
.map(Membership::getDisplayName)
.forEach(groups::add);
aadMembershipRestUri = Optional.of(memberships)
.map(Memberships::getOdataNextLink)
.orElse(null);
}
return groups;
}

private String getUserMemberships(String accessToken, String urlString) throws IOException {
URL url = new URL(urlString);
final HttpURLConnection connection = (HttpURLConnection) url.openConnection();
connection.setRequestMethod(HttpMethod.GET.toString());
connection.setRequestProperty(HttpHeaders.AUTHORIZATION, String.format("Bearer %s", accessToken));
connection.setRequestProperty(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE);
connection.setRequestProperty(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE);
final String responseInJson = getResponseString(connection);
final int responseCode = connection.getResponseCode();
if (responseCode == HTTPResponse.SC_OK) {
return responseInJson;
} else {
throw new IllegalStateException(
"Response is not " + HTTPResponse.SC_OK + ", response json: " + responseInJson);
}
}

private String getResponseString(HttpURLConnection connection) throws IOException {
try (BufferedReader reader =
new BufferedReader(
new InputStreamReader(connection.getInputStream(),
StandardCharsets.UTF_8))
) {
final StringBuilder stringBuffer = new StringBuilder();
String line;
while ((line = reader.readLine()) != null) {
stringBuffer.append(line);
}
return stringBuffer.toString();
}
}

private boolean isGroupObject(final Membership membership) {
return membership.getObjectType().equals(properties.getUserGroup().getValue());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.azure.spring.aad.implementation;

import com.azure.spring.autoconfigure.aad.AADAuthenticationProperties;
import com.azure.spring.autoconfigure.aad.JacksonObjectMapperFactory;
import com.azure.spring.autoconfigure.aad.Membership;
import com.azure.spring.autoconfigure.aad.Memberships;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.nimbusds.oauth2.sdk.http.HTTPResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.HttpURLConnection;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.LinkedHashSet;
import java.util.Optional;
import java.util.Set;

public class GraphClient {
private static final Logger LOGGER = LoggerFactory.getLogger(GraphClient.class);

private final AADAuthenticationProperties properties;

public GraphClient(AADAuthenticationProperties properties) {
this.properties = properties;
}

public Set<String> getGroupsFromGraph(String accessToken) {
final Set<String> groups = new LinkedHashSet<>();
final ObjectMapper objectMapper = JacksonObjectMapperFactory.getInstance();
String aadMembershipRestUri = properties.getGraphMembershipUri();
while (aadMembershipRestUri != null) {
Memberships memberships;
try {
String membershipsJson = getUserMemberships(accessToken, aadMembershipRestUri);
memberships = objectMapper.readValue(membershipsJson, Memberships.class);
} catch (IOException ioException) {
LOGGER.error("Can not get group information from graph server.", ioException);
break;
}
memberships.getValue()
.stream()
.filter(this::isGroupObject)
.map(Membership::getDisplayName)
.forEach(groups::add);
aadMembershipRestUri = Optional.of(memberships)
.map(Memberships::getOdataNextLink)
.orElse(null);
}
return groups;
}

private String getUserMemberships(String accessToken, String urlString) throws IOException {
URL url = new URL(urlString);
final HttpURLConnection connection = (HttpURLConnection) url.openConnection();
connection.setRequestMethod(HttpMethod.GET.toString());
connection.setRequestProperty(HttpHeaders.AUTHORIZATION, String.format("Bearer %s", accessToken));
connection.setRequestProperty(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE);
connection.setRequestProperty(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE);
final String responseInJson = getResponseString(connection);
final int responseCode = connection.getResponseCode();
if (responseCode == HTTPResponse.SC_OK) {
return responseInJson;
} else {
throw new IllegalStateException(
"Response is not " + HTTPResponse.SC_OK + ", response json: " + responseInJson);
}
}

private String getResponseString(HttpURLConnection connection) throws IOException {
try (BufferedReader reader =
new BufferedReader(
new InputStreamReader(connection.getInputStream(), StandardCharsets.UTF_8)
)
) {
final StringBuilder stringBuffer = new StringBuilder();
String line;
while ((line = reader.readLine()) != null) {
stringBuffer.append(line);
}
return stringBuffer.toString();
}
}

private boolean isGroupObject(final Membership membership) {
return membership.getObjectType().equals(properties.getUserGroup().getValue());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,19 @@
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
import org.springframework.util.StringUtils;

import javax.naming.ServiceUnavailableException;
import java.io.IOException;
import java.net.MalformedURLException;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

import static com.azure.spring.autoconfigure.aad.AADOAuth2ErrorCode.CONDITIONAL_ACCESS_POLICY;
import static com.azure.spring.autoconfigure.aad.AADOAuth2ErrorCode.INVALID_REQUEST;
import static com.azure.spring.autoconfigure.aad.AADOAuth2ErrorCode.SERVER_SERVER;
import static com.azure.spring.autoconfigure.aad.Constants.ROLE_PREFIX;

/**
* This implementation will retrieve group info of user from Microsoft Graph and map groups to {@link
Expand All @@ -46,7 +49,7 @@ public AADOAuth2UserService(AADAuthenticationProperties aadAuthenticationPropert
public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2AuthenticationException {
// Delegate to the default implementation for loading a user
OidcUser oidcUser = oidcUserService.loadUser(userRequest);
final Set<SimpleGrantedAuthority> mappedAuthorities;
final Set<SimpleGrantedAuthority> authorities;
try {
// https://github.com/MicrosoftDocs/azure-docs/issues/8121#issuecomment-387090099
// In AAD App Registration configure oauth2AllowImplicitFlow to true
Expand All @@ -63,7 +66,19 @@ public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2Authenticatio
aadAuthenticationProperties.getTenantId()
)
.accessToken();
mappedAuthorities = azureADGraphClient.getGrantedAuthorities(graphApiToken);
Set<String> groups = azureADGraphClient.getGroups(graphApiToken);
Set<String> groupRoles = groups.stream()
.filter(aadAuthenticationProperties::isAllowedGroup)
.map(group -> ROLE_PREFIX + group)
.collect(Collectors.toSet());
Set<String> allRoles = oidcUser.getAuthorities()
.stream()
.map(GrantedAuthority::getAuthority)
.collect(Collectors.toSet());
allRoles.addAll(groupRoles);
authorities = allRoles.stream()
.map(SimpleGrantedAuthority::new)
.collect(Collectors.toSet());
} catch (MalformedURLException e) {
throw toOAuth2AuthenticationException(INVALID_REQUEST, "Failed to acquire token for Graph API.", e);
} catch (ServiceUnavailableException e) {
Expand All @@ -85,10 +100,10 @@ public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2Authenticatio
.map(ClientRegistration::getProviderDetails)
.map(ClientRegistration.ProviderDetails::getUserInfoEndpoint)
.map(ClientRegistration.ProviderDetails.UserInfoEndpoint::getUserNameAttributeName)
.filter(s -> !s.isEmpty())
.filter(StringUtils::hasText)
.orElse(AADTokenClaim.NAME);
// Create a copy of oidcUser but use the mappedAuthorities instead
return new DefaultOidcUser(mappedAuthorities, oidcUser.getIdToken(), nameAttributeKey);
return new DefaultOidcUser(authorities, oidcUser.getIdToken(), nameAttributeKey);
}

private OAuth2AuthenticationException toOAuth2AuthenticationException(String errorCode,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,15 +146,6 @@ private boolean isGroupObject(final Membership membership) {
return membership.getObjectType().equals(aadAuthenticationProperties.getUserGroup().getValue());
}

/**
* @param graphApiToken token of graph api.
* @return set of SimpleGrantedAuthority
* @throws IOException throw exception if get groups failed by IOException.
*/
public Set<SimpleGrantedAuthority> getGrantedAuthorities(String graphApiToken) throws IOException {
return toGrantedAuthoritySet(getGroups(graphApiToken));
}

public Set<SimpleGrantedAuthority> toGrantedAuthoritySet(final Set<String> groups) {
Set<SimpleGrantedAuthority> grantedAuthoritySet =
groups.stream()
Expand Down
Loading

0 comments on commit edf5460

Please sign in to comment.