Skip to content

Commit

Permalink
Use Oauth2AuthorizedClientProvider to implement authorize clients del…
Browse files Browse the repository at this point in the history
…egated by azure client. (#23366)
  • Loading branch information
Rujun Chen authored Aug 9, 2021
1 parent 5d1a6b7 commit a3507ca
Show file tree
Hide file tree
Showing 26 changed files with 1,357 additions and 1,401 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
import org.springframework.security.oauth2.core.AuthorizationGrantType;

/**
* Defines grant types: client_credentials, authorization_code, on_behalf_of.
* Defines grant types: client_credentials, authorization_code, on_behalf_of, azure_delegated.
*/
public enum AADAuthorizationGrantType {

CLIENT_CREDENTIALS("client_credentials"),
AUTHORIZATION_CODE("authorization_code"),
ON_BEHALF_OF("on_behalf_of");
ON_BEHALF_OF("on_behalf_of"),
AZURE_DELEGATED("azure_delegated");

private final String authorizationGrantType;

Expand Down

Large diffs are not rendered by default.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@
package com.azure.spring.aad;

import com.azure.spring.aad.webapi.AADOBOOAuth2AuthorizedClientProvider;
import com.azure.spring.aad.webapp.AADAzureDelegatedOAuth2AuthorizedClientProvider;
import com.azure.spring.autoconfigure.aad.AADAuthenticationProperties;
import com.azure.spring.autoconfigure.condition.aad.ClientRegistrationCondition;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Conditional;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder;
import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager;
Expand All @@ -28,35 +29,37 @@
@Conditional(ClientRegistrationCondition.class)
public class AADOAuth2ClientConfiguration {

@Autowired
private AADAuthenticationProperties properties;

@Bean
@ConditionalOnMissingBean
public AADClientRegistrationRepository clientRegistrationRepository() {
public AADClientRegistrationRepository clientRegistrationRepository(AADAuthenticationProperties properties) {
return new AADClientRegistrationRepository(properties);
}

@Bean
@ConditionalOnMissingBean
public AADOAuth2AuthorizedClientRepository authorizedClientRepository(AADClientRegistrationRepository repo) {
return new AADOAuth2AuthorizedClientRepository(repo);
public OAuth2AuthorizedClientRepository oAuth2AuthorizedClientRepository() {
return new JacksonHttpSessionOAuth2AuthorizedClientRepository();
}

@Bean
@ConditionalOnMissingBean
public OAuth2AuthorizedClientManager authorizedClientManager(ClientRegistrationRepository repository,
public OAuth2AuthorizedClientManager authorizedClientManager(ClientRegistrationRepository clientRegistrations,
OAuth2AuthorizedClientRepository authorizedClients) {

DefaultOAuth2AuthorizedClientManager manager =
new DefaultOAuth2AuthorizedClientManager(repository, authorizedClients);
new DefaultOAuth2AuthorizedClientManager(clientRegistrations, authorizedClients);
AADAzureDelegatedOAuth2AuthorizedClientProvider azureDelegatedProvider =
new AADAzureDelegatedOAuth2AuthorizedClientProvider(
new RefreshTokenOAuth2AuthorizedClientProvider(),
authorizedClients);
AADOBOOAuth2AuthorizedClientProvider oboProvider = new AADOBOOAuth2AuthorizedClientProvider();
OAuth2AuthorizedClientProvider authorizedClientProviders =
OAuth2AuthorizedClientProviderBuilder.builder()
.authorizationCode()
.refreshToken()
.clientCredentials()
.password()
.provider(new AADOBOOAuth2AuthorizedClientProvider())
.provider(azureDelegatedProvider)
.provider(oboProvider)
.build();
manager.setAuthorizedClientProvider(authorizedClientProviders);
return manager;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
public abstract class AbstractOAuth2AuthorizationCodeGrantRequestEntityConverter
extends OAuth2AuthorizationCodeGrantRequestEntityConverter {

protected String azureModule;
protected abstract String getApplicationId();

@Override
@SuppressWarnings("unchecked")
Expand All @@ -36,7 +36,7 @@ public RequestEntity<?> convert(OAuth2AuthorizationCodeGrantRequest request) {
.ifPresent(headers -> headers.forEach(httpHeaders::put));
MultiValueMap<String, String> body = (MultiValueMap<String, String>) requestEntity.getBody();
Assert.notNull(body, "body can not be null");
Optional.ofNullable(getHttpBody(request)).ifPresent(ext -> body.putAll(ext));
Optional.ofNullable(getHttpBody(request)).ifPresent(body::putAll);
return new RequestEntity<>(body, httpHeaders, requestEntity.getMethod(), requestEntity.getUrl());
}

Expand All @@ -46,7 +46,7 @@ public RequestEntity<?> convert(OAuth2AuthorizationCodeGrantRequest request) {
*/
public HttpHeaders getHttpHeaders() {
HttpHeaders httpHeaders = new HttpHeaders();
httpHeaders.put("x-client-SKU", Collections.singletonList(azureModule));
httpHeaders.put("x-client-SKU", Collections.singletonList(getApplicationId()));
httpHeaders.put("x-client-VER", Collections.singletonList(ApplicationId.VERSION));
httpHeaders.put("client-request-id", Collections.singletonList(UUID.randomUUID().toString()));
return httpHeaders;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,8 @@

package com.azure.spring.aad;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import org.springframework.security.core.Authentication;
import org.springframework.security.jackson2.CoreJackson2Module;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.jackson2.OAuth2ClientJackson2Module;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.util.Assert;

Expand All @@ -21,6 +15,9 @@
import java.util.Map;
import java.util.Optional;

import static com.azure.spring.aad.implementation.jackson.SerializerUtils.deserializeOAuth2AuthorizedClientMap;
import static com.azure.spring.aad.implementation.jackson.SerializerUtils.serializeOAuth2AuthorizedClientMap;

/**
* An implementation of an {@link OAuth2AuthorizedClientRepository} that stores {@link OAuth2AuthorizedClient}'s in the
* {@code HttpSession}. To make it compatible with different spring versions. Refs:
Expand All @@ -32,17 +29,6 @@
public class JacksonHttpSessionOAuth2AuthorizedClientRepository implements OAuth2AuthorizedClientRepository {
private static final String AUTHORIZED_CLIENTS_ATTR_NAME =
JacksonHttpSessionOAuth2AuthorizedClientRepository.class.getName() + ".AUTHORIZED_CLIENTS";
private final ObjectMapper objectMapper;
private static final TypeReference<Map<String, OAuth2AuthorizedClient>> TYPE_REFERENCE =
new TypeReference<Map<String, OAuth2AuthorizedClient>>() {
};

public JacksonHttpSessionOAuth2AuthorizedClientRepository() {
objectMapper = new ObjectMapper();
objectMapper.registerModule(new OAuth2ClientJackson2Module());
objectMapper.registerModule(new CoreJackson2Module());
objectMapper.registerModule(new JavaTimeModule());
}

@SuppressWarnings("unchecked")
@Override
Expand All @@ -62,7 +48,8 @@ public void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authen
Assert.notNull(response, "response cannot be null");
Map<String, OAuth2AuthorizedClient> authorizedClients = this.getAuthorizedClients(request);
authorizedClients.put(authorizedClient.getClientRegistration().getRegistrationId(), authorizedClient);
request.getSession().setAttribute(AUTHORIZED_CLIENTS_ATTR_NAME, toString(authorizedClients));
request.getSession().setAttribute(AUTHORIZED_CLIENTS_ATTR_NAME,
serializeOAuth2AuthorizedClientMap(authorizedClients));
}

@Override
Expand All @@ -74,24 +61,15 @@ public void removeAuthorizedClient(String clientRegistrationId, Authentication p
if (!authorizedClients.isEmpty()) {
if (authorizedClients.remove(clientRegistrationId) != null) {
if (!authorizedClients.isEmpty()) {
request.getSession().setAttribute(AUTHORIZED_CLIENTS_ATTR_NAME, toString(authorizedClients));
request.getSession().setAttribute(AUTHORIZED_CLIENTS_ATTR_NAME,
serializeOAuth2AuthorizedClientMap(authorizedClients));
} else {
request.getSession().removeAttribute(AUTHORIZED_CLIENTS_ATTR_NAME);
}
}
}
}

private String toString(Map<String, OAuth2AuthorizedClient> authorizedClients) {
String result;
try {
result = objectMapper.writeValueAsString(authorizedClients);
} catch (JsonProcessingException e) {
throw new IllegalStateException(e);
}
return result;
}

private Map<String, OAuth2AuthorizedClient> getAuthorizedClients(HttpServletRequest request) {
HttpSession session = request.getSession(false);
String authorizedClientsString = (String) Optional.ofNullable(session)
Expand All @@ -100,12 +78,6 @@ private Map<String, OAuth2AuthorizedClient> getAuthorizedClients(HttpServletRequ
if (authorizedClientsString == null) {
return new HashMap<>();
}
Map<String, OAuth2AuthorizedClient> authorizedClients;
try {
authorizedClients = objectMapper.readValue(authorizedClientsString, TYPE_REFERENCE);
} catch (JsonProcessingException e) {
throw new IllegalStateException(e);
}
return authorizedClients;
return deserializeOAuth2AuthorizedClientMap(authorizedClientsString);
}
}
Loading

0 comments on commit a3507ca

Please sign in to comment.