Skip to content

Commit

Permalink
Polish #7589
Browse files Browse the repository at this point in the history
Rename OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager to AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager.

Handle empty mono returned from contextAttributesMapper.

Handle empty map returned from contextAttributesMapper.

Fix DefaultContextAttributesMapper so that it doesn't access ServerWebExchange.

Fix unit tests so that they pass.

Use StepVerifier in unit tests, rather than .subscribe().

Fixes gh-7569
  • Loading branch information
philsttr authored and jgrandja committed Dec 10, 2019
1 parent 4c5c4f6 commit 840d3aa
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,82 +15,88 @@
*/
package org.springframework.security.oauth2.client;

import org.springframework.lang.Nullable;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;

/**
* An implementation of an {@link ReactiveOAuth2AuthorizedClientManager}
* that is capable of operating outside of a {@code ServerHttpRequest} context,
* e.g. in a scheduled/background thread and/or in the service-tier.
*
* <p>This is a reactive equivalent of {@link org.springframework.security.oauth2.client.AuthorizedClientServiceOAuth2AuthorizedClientManager}</p>
*
* @author Ankur Pathak
* @author Phil Clay
* @see ReactiveOAuth2AuthorizedClientManager
* @see ReactiveOAuth2AuthorizedClientProvider
* @see ReactiveOAuth2AuthorizedClientService
* @since 5.3
* @since 5.2.2
*/
public final class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager implements ReactiveOAuth2AuthorizedClientManager {
public final class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager
implements ReactiveOAuth2AuthorizedClientManager {

private final ReactiveClientRegistrationRepository clientRegistrationRepository;
private final ReactiveOAuth2AuthorizedClientService authorizedClientService;
private ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = context -> Mono.empty();
private Function<OAuth2AuthorizeRequest, Mono<Map<String, Object>>> contextAttributesMapper = new DefaultContextAttributesMapper();

/**
* Constructs an {@code OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager} using the provided parameters.
* Constructs an {@code AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager} using the provided parameters.
*
* @param clientRegistrationRepository the repository of client registrations
* @param authorizedClientService the authorized client service
*/
public OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager(ReactiveClientRegistrationRepository clientRegistrationRepository,
public AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager(
ReactiveClientRegistrationRepository clientRegistrationRepository,
ReactiveOAuth2AuthorizedClientService authorizedClientService) {
Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
Assert.notNull(authorizedClientService, "authorizedClientService cannot be null");
this.clientRegistrationRepository = clientRegistrationRepository;
this.authorizedClientService = authorizedClientService;
}

@Nullable
@Override
public Mono<OAuth2AuthorizedClient> authorize(OAuth2AuthorizeRequest authorizeRequest) {
Assert.notNull(authorizeRequest, "authorizeRequest cannot be null");

return createAuthorizationContext(authorizeRequest)
.flatMap(this::authorizeAndSave);
}

private Mono<OAuth2AuthorizationContext> createAuthorizationContext(OAuth2AuthorizeRequest authorizeRequest) {
String clientRegistrationId = authorizeRequest.getClientRegistrationId();
OAuth2AuthorizedClient authorizedClient = authorizeRequest.getAuthorizedClient();
Authentication principal = authorizeRequest.getPrincipal();
// @formatter:off
return Mono.justOrEmpty(authorizedClient)
return Mono.justOrEmpty(authorizeRequest.getAuthorizedClient())
.map(OAuth2AuthorizationContext::withAuthorizedClient)
.switchIfEmpty(Mono.defer(() -> this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
.flatMap(clientRegistration -> this.authorizedClientService.loadAuthorizedClient(clientRegistrationId, principal.getName())
.map(OAuth2AuthorizationContext::withAuthorizedClient)
.switchIfEmpty(Mono.fromSupplier(() -> OAuth2AuthorizationContext.withClientRegistration(clientRegistration)))
)
.switchIfEmpty(Mono.error(new IllegalArgumentException("Could not find ClientRegistration with id '" + clientRegistrationId + "'")))
)
)
.flatMap(clientRegistration -> this.authorizedClientService.loadAuthorizedClient(clientRegistrationId, principal.getName())
.map(OAuth2AuthorizationContext::withAuthorizedClient)
.switchIfEmpty(Mono.fromSupplier(() -> OAuth2AuthorizationContext.withClientRegistration(clientRegistration))))
.switchIfEmpty(Mono.error(() -> new IllegalArgumentException("Could not find ClientRegistration with id '" + clientRegistrationId + "'")))))
.flatMap(contextBuilder -> this.contextAttributesMapper.apply(authorizeRequest)
.filter(contextAttributes-> !CollectionUtils.isEmpty(contextAttributes))
.map(contextAttributes -> contextBuilder.principal(principal)
.attributes(attributes -> {
attributes.putAll(contextAttributes);
}).build())
).flatMap(authorizationContext -> this.authorizedClientProvider.authorize(authorizationContext)
.doOnNext(_authorizedClient -> authorizedClientService.saveAuthorizedClient(_authorizedClient, principal))
.switchIfEmpty(Mono.defer(()-> Mono.justOrEmpty(Optional.ofNullable(authorizationContext.getAuthorizedClient()))))
);
// @formatter:on
.defaultIfEmpty(Collections.emptyMap())
.map(contextAttributes -> {
OAuth2AuthorizationContext.Builder builder = contextBuilder.principal(principal);
if (!contextAttributes.isEmpty()) {
builder = builder.attributes(attributes -> attributes.putAll(contextAttributes));
}
return builder.build();
}));
}

private Mono<OAuth2AuthorizedClient> authorizeAndSave(OAuth2AuthorizationContext authorizationContext) {
return this.authorizedClientProvider.authorize(authorizationContext)
.flatMap(authorizedClient -> this.authorizedClientService.saveAuthorizedClient(
authorizedClient,
authorizationContext.getPrincipal())
.thenReturn(authorizedClient))
.switchIfEmpty(Mono.defer(()-> Mono.justOrEmpty(authorizationContext.getAuthorizedClient())));
}

/**
Expand All @@ -115,33 +121,17 @@ public void setContextAttributesMapper(Function<OAuth2AuthorizeRequest, Mono<Map
this.contextAttributesMapper = contextAttributesMapper;
}

private static Mono<ServerWebExchange> currentServerWebExchange() {
return Mono.subscriberContext()
.filter(c -> c.hasKey(ServerWebExchange.class))
.map(c -> c.get(ServerWebExchange.class));
}

/**
* The default implementation of the {@link #setContextAttributesMapper(Function) contextAttributesMapper}.
*/
public static class DefaultContextAttributesMapper implements Function<OAuth2AuthorizeRequest, Mono<Map<String, Object>>> {

private final AuthorizedClientServiceOAuth2AuthorizedClientManager.DefaultContextAttributesMapper mapper =
new AuthorizedClientServiceOAuth2AuthorizedClientManager.DefaultContextAttributesMapper();

@Override
public Mono<Map<String, Object>> apply(OAuth2AuthorizeRequest authorizeRequest) {
ServerWebExchange serverWebExchange = authorizeRequest.getAttribute(ServerWebExchange.class.getName());
return Mono.justOrEmpty(serverWebExchange)
.switchIfEmpty(Mono.defer(() -> currentServerWebExchange()))
.flatMap(exchange -> {
Map<String, Object> contextAttributes = Collections.emptyMap();
String scope = exchange.getRequest().getQueryParams().getFirst(OAuth2ParameterNames.SCOPE);
if (StringUtils.hasText(scope)) {
contextAttributes = new HashMap<>();
contextAttributes.put(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME,
StringUtils.delimitedListToStringArray(scope, " "));
}
return Mono.just(contextAttributes);
})
.defaultIfEmpty(Collections.emptyMap());
return Mono.fromCallable(() -> mapper.apply(authorizeRequest));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,39 +28,49 @@
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import reactor.test.publisher.PublisherProbe;

import java.util.Map;
import java.util.function.Function;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.*;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

/**
* Tests for {@link OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager}.
* Tests for {@link AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager}.
*
* @author Ankur Pathak
* @author Phil Clay
*/
public class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests {
public class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests {
private ReactiveClientRegistrationRepository clientRegistrationRepository;
private ReactiveOAuth2AuthorizedClientService authorizedClientService;
private ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider;
private Function contextAttributesMapper;
private OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager authorizedClientManager;
private Function<OAuth2AuthorizeRequest, Mono<Map<String, Object>>> contextAttributesMapper;
private AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager authorizedClientManager;
private ClientRegistration clientRegistration;
private Authentication principal;
private OAuth2AuthorizedClient authorizedClient;
private ArgumentCaptor<OAuth2AuthorizationContext> authorizationContextCaptor;
private PublisherProbe<Void> saveAuthorizedClientProbe;

@SuppressWarnings("unchecked")
@Before
public void setup() {
this.clientRegistrationRepository = mock(ReactiveClientRegistrationRepository.class);
this.authorizedClientService = mock(ReactiveOAuth2AuthorizedClientService.class);
this.saveAuthorizedClientProbe = PublisherProbe.empty();
when(this.authorizedClientService.saveAuthorizedClient(any(), any())).thenReturn(this.saveAuthorizedClientProbe.mono());
this.authorizedClientProvider = mock(ReactiveOAuth2AuthorizedClientProvider.class);
this.contextAttributesMapper = mock(Function.class);
this.authorizedClientManager = new OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager(
when(this.contextAttributesMapper.apply(any())).thenReturn(Mono.empty());
this.authorizedClientManager = new AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager(
this.clientRegistrationRepository, this.authorizedClientService);
this.authorizedClientManager.setAuthorizedClientProvider(this.authorizedClientProvider);
this.authorizedClientManager.setContextAttributesMapper(this.contextAttributesMapper);
Expand All @@ -73,23 +83,23 @@ public void setup() {

@Test
public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager(null, this.authorizedClientService))
assertThatThrownBy(() -> new AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager(null, this.authorizedClientService))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("reactiveClientRegistrationRepository cannot be null");
.hasMessage("clientRegistrationRepository cannot be null");
}

@Test
public void constructorWhenOAuth2AuthorizedClientServiceIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager(this.clientRegistrationRepository, null))
assertThatThrownBy(() -> new AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager(this.clientRegistrationRepository, null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("reactiveAuthorizedClientService cannot be null");
.hasMessage("authorizedClientService cannot be null");
}

@Test
public void setAuthorizedClientProviderWhenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizedClientProvider(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("reactiveAuthorizedClientProvider cannot be null");
.hasMessage("authorizedClientProvider cannot be null");
}

@Test
Expand Down Expand Up @@ -132,7 +142,7 @@ public void authorizeWhenNotAuthorizedAndUnsupportedProviderThenNotAuthorized()
.build();
Mono<OAuth2AuthorizedClient> authorizedClient = this.authorizedClientManager.authorize(authorizeRequest);

authorizedClient.subscribe();
StepVerifier.create(authorizedClient).verifyComplete();

verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
Expand All @@ -142,7 +152,6 @@ public void authorizeWhenNotAuthorizedAndUnsupportedProviderThenNotAuthorized()
assertThat(authorizationContext.getAuthorizedClient()).isNull();
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);

StepVerifier.create(authorizedClient).expectComplete();
verify(this.authorizedClientService, never()).saveAuthorizedClient(
any(OAuth2AuthorizedClient.class), eq(this.principal));
}
Expand All @@ -163,7 +172,9 @@ public void authorizeWhenNotAuthorizedAndSupportedProviderThenAuthorized() {
.build();
Mono<OAuth2AuthorizedClient> authorizedClient = this.authorizedClientManager.authorize(authorizeRequest);

authorizedClient.subscribe();
StepVerifier.create(authorizedClient)
.expectNext(this.authorizedClient)
.verifyComplete();

verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
Expand All @@ -173,9 +184,9 @@ public void authorizeWhenNotAuthorizedAndSupportedProviderThenAuthorized() {
assertThat(authorizationContext.getAuthorizedClient()).isNull();
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);

StepVerifier.create(authorizedClient).expectNextCount(1).assertNext(x -> assertThat(x).isSameAs(this.authorizedClient));
verify(this.authorizedClientService).saveAuthorizedClient(
eq(this.authorizedClient), eq(this.principal));
this.saveAuthorizedClientProbe.assertWasSubscribed();
}

@SuppressWarnings("unchecked")
Expand All @@ -197,8 +208,9 @@ public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized() {
.build();
Mono<OAuth2AuthorizedClient> authorizedClient = this.authorizedClientManager.authorize(authorizeRequest);

authorizedClient.subscribe();

StepVerifier.create(authorizedClient)
.expectNext(reauthorizedClient)
.verifyComplete();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(authorizeRequest));

Expand All @@ -207,9 +219,9 @@ public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized() {
assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient);
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);

StepVerifier.create(authorizedClient).expectNextCount(1).assertNext(x -> assertThat(x).isSameAs(this.authorizedClient));
verify(this.authorizedClientService).saveAuthorizedClient(
eq(reauthorizedClient), eq(this.principal));
this.saveAuthorizedClientProbe.assertWasSubscribed();
}

@SuppressWarnings("unchecked")
Expand All @@ -221,8 +233,9 @@ public void reauthorizeWhenUnsupportedProviderThenNotReauthorized() {
.build();
Mono<OAuth2AuthorizedClient> authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest);

authorizedClient.subscribe();

StepVerifier.create(authorizedClient)
.expectNext(this.authorizedClient)
.verifyComplete();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest));

Expand All @@ -231,7 +244,6 @@ public void reauthorizeWhenUnsupportedProviderThenNotReauthorized() {
assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient);
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);

StepVerifier.create(authorizedClient).expectNextCount(1).assertNext(x -> assertThat(x).isSameAs(this.authorizedClient));
verify(this.authorizedClientService, never()).saveAuthorizedClient(
any(OAuth2AuthorizedClient.class), eq(this.principal));
}
Expand All @@ -250,7 +262,9 @@ public void reauthorizeWhenSupportedProviderThenReauthorized() {
.build();
Mono<OAuth2AuthorizedClient> authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest);

authorizedClient.subscribe();
StepVerifier.create(authorizedClient)
.expectNext(reauthorizedClient)
.verifyComplete();

verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest));
Expand All @@ -260,9 +274,9 @@ public void reauthorizeWhenSupportedProviderThenReauthorized() {
assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient);
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);

StepVerifier.create(authorizedClient).expectNextCount(1).assertNext(x -> assertThat(x).isSameAs(this.authorizedClient));
verify(this.authorizedClientService).saveAuthorizedClient(
eq(reauthorizedClient), eq(this.principal));
this.saveAuthorizedClientProbe.assertWasSubscribed();
}

@SuppressWarnings("unchecked")
Expand All @@ -274,14 +288,20 @@ public void reauthorizeWhenRequestAttributeScopeThenMappedToContext() {

when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(reauthorizedClient));


OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
.principal(this.principal)
.attribute(OAuth2ParameterNames.SCOPE, "read write")
.build();

this.authorizedClientManager.setContextAttributesMapper(new AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager.DefaultContextAttributesMapper());
Mono<OAuth2AuthorizedClient> authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest);

authorizedClient.subscribe();
StepVerifier.create(authorizedClient)
.expectNext(reauthorizedClient)
.verifyComplete();
verify(this.authorizedClientService).saveAuthorizedClient(
eq(reauthorizedClient), eq(this.principal));
this.saveAuthorizedClientProbe.assertWasSubscribed();

verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());

Expand All @@ -293,8 +313,5 @@ public void reauthorizeWhenRequestAttributeScopeThenMappedToContext() {
String[] requestScopeAttribute = authorizationContext.getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME);
assertThat(requestScopeAttribute).contains("read", "write");

StepVerifier.create(authorizedClient).expectNextCount(1).assertNext(x -> assertThat(x).isSameAs(this.authorizedClient));
verify(this.authorizedClientService).saveAuthorizedClient(
eq(reauthorizedClient), eq(this.principal));
}
}

0 comments on commit 840d3aa

Please sign in to comment.