Skip to content

Always return current ClientRegistration in loadAuthorizedClient #16133

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -80,7 +80,13 @@ public <T extends OAuth2AuthorizedClient> T loadAuthorizedClient(String clientRe
if (registration == null) {
return null;
}
return (T) this.authorizedClients.get(new OAuth2AuthorizedClientId(clientRegistrationId, principalName));
OAuth2AuthorizedClient cachedAuthorizedClient = this.authorizedClients
.get(new OAuth2AuthorizedClientId(clientRegistrationId, principalName));
if (cachedAuthorizedClient == null) {
return null;
}
return (T) new OAuth2AuthorizedClient(registration, cachedAuthorizedClient.getPrincipalName(),
cachedAuthorizedClient.getAccessToken(), cachedAuthorizedClient.getRefreshToken());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -62,8 +62,19 @@ public <T extends OAuth2AuthorizedClient> Mono<T> loadAuthorizedClient(String cl
Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
Assert.hasText(principalName, "principalName cannot be empty");
return (Mono<T>) this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
.map((clientRegistration) -> new OAuth2AuthorizedClientId(clientRegistrationId, principalName))
.flatMap((identifier) -> Mono.justOrEmpty(this.authorizedClients.get(identifier)));
.mapNotNull((clientRegistration) -> {
OAuth2AuthorizedClientId id = new OAuth2AuthorizedClientId(clientRegistrationId, principalName);
OAuth2AuthorizedClient cachedAuthorizedClient = this.authorizedClients.get(id);
if (cachedAuthorizedClient == null) {
return null;
}
// @formatter:off
return new OAuth2AuthorizedClient(clientRegistration,
cachedAuthorizedClient.getPrincipalName(),
cachedAuthorizedClient.getAccessToken(),
cachedAuthorizedClient.getRefreshToken());
// @formatter:on
});
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -33,7 +33,7 @@
import static org.assertj.core.api.Assertions.assertThatObject;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.BDDMockito.mock;

/**
* Tests for {@link InMemoryOAuth2AuthorizedClientService}.
Expand Down Expand Up @@ -79,9 +79,11 @@ public void constructorWhenAuthorizedClientsIsNullThenThrowIllegalArgumentExcept
@Test
public void constructorWhenAuthorizedClientsProvidedThenUseProvidedAuthorizedClients() {
String registrationId = this.registration3.getRegistrationId();
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration3, this.principalName1,
mock(OAuth2AccessToken.class));
Map<OAuth2AuthorizedClientId, OAuth2AuthorizedClient> authorizedClients = Collections.singletonMap(
new OAuth2AuthorizedClientId(this.registration3.getRegistrationId(), this.principalName1),
mock(OAuth2AuthorizedClient.class));
authorizedClient);
ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class);
given(clientRegistrationRepository.findByRegistrationId(eq(registrationId))).willReturn(this.registration3);
InMemoryOAuth2AuthorizedClientService authorizedClientService = new InMemoryOAuth2AuthorizedClientService(
Expand Down Expand Up @@ -124,7 +126,35 @@ public void loadAuthorizedClientWhenClientRegistrationFoundAndAssociatedToPrinci
this.authorizedClientService.saveAuthorizedClient(authorizedClient, authentication);
OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientService
.loadAuthorizedClient(this.registration1.getRegistrationId(), this.principalName1);
assertThat(loadedAuthorizedClient).isEqualTo(authorizedClient);
assertAuthorizedClientEquals(authorizedClient, loadedAuthorizedClient);
}

@Test
public void loadAuthorizedClientWhenClientRegistrationIsUpdatedThenReturnAuthorizedClientWithUpdatedClientRegistration() {
ClientRegistration updatedRegistration = ClientRegistration.withClientRegistration(this.registration1)
.clientSecret("updated secret")
.build();
ClientRegistrationRepository repository = mock(ClientRegistrationRepository.class);
given(repository.findByRegistrationId(this.registration1.getRegistrationId())).willReturn(this.registration1,
updatedRegistration);

Authentication authentication = mock(Authentication.class);
given(authentication.getName()).willReturn(this.principalName1);

InMemoryOAuth2AuthorizedClientService service = new InMemoryOAuth2AuthorizedClientService(repository);

OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration1, this.principalName1,
mock(OAuth2AccessToken.class));
service.saveAuthorizedClient(authorizedClient, authentication);

OAuth2AuthorizedClient authorizedClientWithUpdatedRegistration = new OAuth2AuthorizedClient(updatedRegistration,
this.principalName1, mock(OAuth2AccessToken.class));
OAuth2AuthorizedClient firstLoadedClient = service.loadAuthorizedClient(this.registration1.getRegistrationId(),
this.principalName1);
OAuth2AuthorizedClient secondLoadedClient = service.loadAuthorizedClient(this.registration1.getRegistrationId(),
this.principalName1);
assertAuthorizedClientEquals(authorizedClient, firstLoadedClient);
assertAuthorizedClientEquals(authorizedClientWithUpdatedRegistration, secondLoadedClient);
}

@Test
Expand All @@ -148,7 +178,7 @@ public void saveAuthorizedClientWhenSavedThenCanLoad() {
this.authorizedClientService.saveAuthorizedClient(authorizedClient, authentication);
OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientService
.loadAuthorizedClient(this.registration3.getRegistrationId(), this.principalName2);
assertThat(loadedAuthorizedClient).isEqualTo(authorizedClient);
assertAuthorizedClientEquals(authorizedClient, loadedAuthorizedClient);
}

@Test
Expand Down Expand Up @@ -180,4 +210,29 @@ public void removeAuthorizedClientWhenSavedThenRemoved() {
assertThat(loadedAuthorizedClient).isNull();
}

private static void assertAuthorizedClientEquals(OAuth2AuthorizedClient expected, OAuth2AuthorizedClient actual) {
assertThat(actual).isNotNull();
assertThat(actual.getClientRegistration().getRegistrationId())
.isEqualTo(expected.getClientRegistration().getRegistrationId());
assertThat(actual.getClientRegistration().getClientName())
.isEqualTo(expected.getClientRegistration().getClientName());
assertThat(actual.getClientRegistration().getRedirectUri())
.isEqualTo(expected.getClientRegistration().getRedirectUri());
assertThat(actual.getClientRegistration().getAuthorizationGrantType())
.isEqualTo(expected.getClientRegistration().getAuthorizationGrantType());
assertThat(actual.getClientRegistration().getClientAuthenticationMethod())
.isEqualTo(expected.getClientRegistration().getClientAuthenticationMethod());
assertThat(actual.getClientRegistration().getClientId())
.isEqualTo(expected.getClientRegistration().getClientId());
assertThat(actual.getClientRegistration().getClientSecret())
.isEqualTo(expected.getClientRegistration().getClientSecret());
assertThat(actual.getPrincipalName()).isEqualTo(expected.getPrincipalName());
assertThat(actual.getAccessToken().getTokenType()).isEqualTo(expected.getAccessToken().getTokenType());
assertThat(actual.getAccessToken().getTokenValue()).isEqualTo(expected.getAccessToken().getTokenValue());
assertThat(actual.getAccessToken().getIssuedAt()).isEqualTo(expected.getAccessToken().getIssuedAt());
assertThat(actual.getAccessToken().getExpiresAt()).isEqualTo(expected.getAccessToken().getExpiresAt());
assertThat(actual.getAccessToken().getScopes()).isEqualTo(expected.getAccessToken().getScopes());
assertThat(actual.getRefreshToken()).isEqualTo(expected.getRefreshToken());
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,12 +18,14 @@

import java.time.Duration;
import java.time.Instant;
import java.util.function.Consumer;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;

Expand All @@ -35,6 +37,7 @@
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.BDDMockito.given;

Expand Down Expand Up @@ -153,11 +156,37 @@ public void loadAuthorizedClientWhenClientRegistrationFoundThenFound() {
.saveAuthorizedClient(authorizedClient, this.principal)
.then(this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName));
StepVerifier.create(saveAndLoad)
.expectNext(authorizedClient)
.assertNext(isEqualTo(authorizedClient))
.verifyComplete();
// @formatter:on
}

@Test
@SuppressWarnings("unchecked")
public void loadAuthorizedClientWhenClientRegistrationChangedThenCurrentVersionFound() {
ClientRegistration changedClientRegistration = ClientRegistration
.withClientRegistration(this.clientRegistration)
.clientSecret("updated secret")
.build();

given(this.clientRegistrationRepository.findByRegistrationId(this.clientRegistrationId))
.willReturn(Mono.just(this.clientRegistration), Mono.just(changedClientRegistration));
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration,
this.principalName, this.accessToken);
OAuth2AuthorizedClient authorizedClientWithChangedRegistration = new OAuth2AuthorizedClient(
changedClientRegistration, this.principalName, this.accessToken);

Flux<OAuth2AuthorizedClient> saveAndLoadTwice = this.authorizedClientService
.saveAuthorizedClient(authorizedClient, this.principal)
.then(this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName))
.concatWith(
this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName));
StepVerifier.create(saveAndLoadTwice)
.assertNext(isEqualTo(authorizedClient))
.assertNext(isEqualTo(authorizedClientWithChangedRegistration))
.verifyComplete();
}

@Test
public void saveAuthorizedClientWhenAuthorizedClientNullThenIllegalArgumentException() {
OAuth2AuthorizedClient authorizedClient = null;
Expand Down Expand Up @@ -246,4 +275,31 @@ public void removeAuthorizedClientWhenClientRegistrationFoundRemovedThenNotFound
// @formatter:on
}

private static Consumer<OAuth2AuthorizedClient> isEqualTo(OAuth2AuthorizedClient expected) {
return (actual) -> {
assertThat(actual).isNotNull();
assertThat(actual.getClientRegistration().getRegistrationId())
.isEqualTo(expected.getClientRegistration().getRegistrationId());
assertThat(actual.getClientRegistration().getClientName())
.isEqualTo(expected.getClientRegistration().getClientName());
assertThat(actual.getClientRegistration().getRedirectUri())
.isEqualTo(expected.getClientRegistration().getRedirectUri());
assertThat(actual.getClientRegistration().getAuthorizationGrantType())
.isEqualTo(expected.getClientRegistration().getAuthorizationGrantType());
assertThat(actual.getClientRegistration().getClientAuthenticationMethod())
.isEqualTo(expected.getClientRegistration().getClientAuthenticationMethod());
assertThat(actual.getClientRegistration().getClientId())
.isEqualTo(expected.getClientRegistration().getClientId());
assertThat(actual.getClientRegistration().getClientSecret())
.isEqualTo(expected.getClientRegistration().getClientSecret());
assertThat(actual.getPrincipalName()).isEqualTo(expected.getPrincipalName());
assertThat(actual.getAccessToken().getTokenType()).isEqualTo(expected.getAccessToken().getTokenType());
assertThat(actual.getAccessToken().getTokenValue()).isEqualTo(expected.getAccessToken().getTokenValue());
assertThat(actual.getAccessToken().getIssuedAt()).isEqualTo(expected.getAccessToken().getIssuedAt());
assertThat(actual.getAccessToken().getExpiresAt()).isEqualTo(expected.getAccessToken().getExpiresAt());
assertThat(actual.getAccessToken().getScopes()).isEqualTo(expected.getAccessToken().getScopes());
assertThat(actual.getRefreshToken()).isEqualTo(expected.getRefreshToken());
};
}

}