Skip to content

Commit ae22729

Browse files
committed
Always return current ClientRegistration in loadAuthorizedClient
This changes `InMemoryOAuth2AuthorizedClientService.loadAuthorizedClient` (and its reactive counterpart) to always return `OAuth2AuthorizedClient` instances containing the current `ClientRegistration` as obtained from the `ClientRegistrationRepository`. Before this change, the first `ClientRegistration` instance was cached, with the effect that any changes made in the `ClientRegistrationRepository` (such as a new client secret) would not have taken effect. Closes gh-15511
1 parent 30c9860 commit ae22729

File tree

4 files changed

+165
-42
lines changed

4 files changed

+165
-42
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientService.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2019 the original author or authors.
2+
* Copyright 2002-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -80,7 +80,13 @@ public <T extends OAuth2AuthorizedClient> T loadAuthorizedClient(String clientRe
8080
if (registration == null) {
8181
return null;
8282
}
83-
return (T) this.authorizedClients.get(new OAuth2AuthorizedClientId(clientRegistrationId, principalName));
83+
OAuth2AuthorizedClient cachedAuthorizedClient = this.authorizedClients
84+
.get(new OAuth2AuthorizedClientId(clientRegistrationId, principalName));
85+
if (cachedAuthorizedClient == null) {
86+
return null;
87+
}
88+
return (T) new OAuth2AuthorizedClient(registration, cachedAuthorizedClient.getPrincipalName(),
89+
cachedAuthorizedClient.getAccessToken(), cachedAuthorizedClient.getRefreshToken());
8490
}
8591

8692
@Override

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientService.java

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2019 the original author or authors.
2+
* Copyright 2002-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -62,8 +62,19 @@ public <T extends OAuth2AuthorizedClient> Mono<T> loadAuthorizedClient(String cl
6262
Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
6363
Assert.hasText(principalName, "principalName cannot be empty");
6464
return (Mono<T>) this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
65-
.map((clientRegistration) -> new OAuth2AuthorizedClientId(clientRegistrationId, principalName))
66-
.flatMap((identifier) -> Mono.justOrEmpty(this.authorizedClients.get(identifier)));
65+
.mapNotNull((clientRegistration) -> {
66+
OAuth2AuthorizedClientId id = new OAuth2AuthorizedClientId(clientRegistrationId, principalName);
67+
OAuth2AuthorizedClient cachedAuthorizedClient = this.authorizedClients.get(id);
68+
if (cachedAuthorizedClient == null) {
69+
return null;
70+
}
71+
// @formatter:off
72+
return new OAuth2AuthorizedClient(clientRegistration,
73+
cachedAuthorizedClient.getPrincipalName(),
74+
cachedAuthorizedClient.getAccessToken(),
75+
cachedAuthorizedClient.getRefreshToken());
76+
// @formatter:on
77+
});
6778
}
6879

6980
@Override

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientServiceTests.java

Lines changed: 75 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2019 the original author or authors.
2+
* Copyright 2002-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -28,12 +28,9 @@
2828
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
2929
import org.springframework.security.oauth2.core.OAuth2AccessToken;
3030

31-
import static org.assertj.core.api.Assertions.assertThat;
32-
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
33-
import static org.assertj.core.api.Assertions.assertThatObject;
34-
import static org.mockito.ArgumentMatchers.eq;
35-
import static org.mockito.BDDMockito.given;
36-
import static org.mockito.Mockito.mock;
31+
import static org.assertj.core.api.Assertions.*;
32+
import static org.mockito.ArgumentMatchers.*;
33+
import static org.mockito.BDDMockito.*;
3734

3835
/**
3936
* Tests for {@link InMemoryOAuth2AuthorizedClientService}.
@@ -52,9 +49,9 @@ public class InMemoryOAuth2AuthorizedClientServiceTests {
5249
private ClientRegistration registration2 = TestClientRegistrations.clientRegistration2().build();
5350

5451
private ClientRegistration registration3 = TestClientRegistrations.clientRegistration()
55-
.clientId("client-3")
56-
.registrationId("registration-3")
57-
.build();
52+
.clientId("client-3")
53+
.registrationId("registration-3")
54+
.build();
5855

5956
private ClientRegistrationRepository clientRegistrationRepository = new InMemoryClientRegistrationRepository(
6057
this.registration1, this.registration2, this.registration3);
@@ -79,9 +76,11 @@ public void constructorWhenAuthorizedClientsIsNullThenThrowIllegalArgumentExcept
7976
@Test
8077
public void constructorWhenAuthorizedClientsProvidedThenUseProvidedAuthorizedClients() {
8178
String registrationId = this.registration3.getRegistrationId();
79+
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration3, this.principalName1,
80+
mock(OAuth2AccessToken.class));
8281
Map<OAuth2AuthorizedClientId, OAuth2AuthorizedClient> authorizedClients = Collections.singletonMap(
8382
new OAuth2AuthorizedClientId(this.registration3.getRegistrationId(), this.principalName1),
84-
mock(OAuth2AuthorizedClient.class));
83+
authorizedClient);
8584
ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class);
8685
given(clientRegistrationRepository.findByRegistrationId(eq(registrationId))).willReturn(this.registration3);
8786
InMemoryOAuth2AuthorizedClientService authorizedClientService = new InMemoryOAuth2AuthorizedClientService(
@@ -92,7 +91,7 @@ public void constructorWhenAuthorizedClientsProvidedThenUseProvidedAuthorizedCli
9291
@Test
9392
public void loadAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() {
9493
assertThatIllegalArgumentException()
95-
.isThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(null, this.principalName1));
94+
.isThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(null, this.principalName1));
9695
}
9796

9897
@Test
@@ -104,14 +103,14 @@ public void loadAuthorizedClientWhenPrincipalNameIsNullThenThrowIllegalArgumentE
104103
@Test
105104
public void loadAuthorizedClientWhenClientRegistrationNotFoundThenReturnNull() {
106105
OAuth2AuthorizedClient authorizedClient = this.authorizedClientService
107-
.loadAuthorizedClient("registration-not-found", this.principalName1);
106+
.loadAuthorizedClient("registration-not-found", this.principalName1);
108107
assertThat(authorizedClient).isNull();
109108
}
110109

111110
@Test
112111
public void loadAuthorizedClientWhenClientRegistrationFoundButNotAssociatedToPrincipalThenReturnNull() {
113112
OAuth2AuthorizedClient authorizedClient = this.authorizedClientService
114-
.loadAuthorizedClient(this.registration1.getRegistrationId(), "principal-not-found");
113+
.loadAuthorizedClient(this.registration1.getRegistrationId(), "principal-not-found");
115114
assertThat(authorizedClient).isNull();
116115
}
117116

@@ -123,14 +122,42 @@ public void loadAuthorizedClientWhenClientRegistrationFoundAndAssociatedToPrinci
123122
mock(OAuth2AccessToken.class));
124123
this.authorizedClientService.saveAuthorizedClient(authorizedClient, authentication);
125124
OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientService
126-
.loadAuthorizedClient(this.registration1.getRegistrationId(), this.principalName1);
127-
assertThat(loadedAuthorizedClient).isEqualTo(authorizedClient);
125+
.loadAuthorizedClient(this.registration1.getRegistrationId(), this.principalName1);
126+
assertAuthorizedClientEquals(authorizedClient, loadedAuthorizedClient);
127+
}
128+
129+
@Test
130+
public void loadAuthorizedClientWhenClientRegistrationIsUpdatedThenReturnAuthorizedClientWithUpdatedClientRegistration() {
131+
ClientRegistration updatedRegistration = ClientRegistration.withClientRegistration(this.registration1)
132+
.clientSecret("updated secret")
133+
.build();
134+
ClientRegistrationRepository repository = mock(ClientRegistrationRepository.class);
135+
given(repository.findByRegistrationId(this.registration1.getRegistrationId())).willReturn(this.registration1,
136+
updatedRegistration);
137+
138+
Authentication authentication = mock(Authentication.class);
139+
given(authentication.getName()).willReturn(this.principalName1);
140+
141+
InMemoryOAuth2AuthorizedClientService service = new InMemoryOAuth2AuthorizedClientService(repository);
142+
143+
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration1, this.principalName1,
144+
mock(OAuth2AccessToken.class));
145+
service.saveAuthorizedClient(authorizedClient, authentication);
146+
147+
OAuth2AuthorizedClient authorizedClientWithUpdatedRegistration = new OAuth2AuthorizedClient(updatedRegistration,
148+
this.principalName1, mock(OAuth2AccessToken.class));
149+
OAuth2AuthorizedClient firstLoadedClient = service.loadAuthorizedClient(this.registration1.getRegistrationId(),
150+
this.principalName1);
151+
OAuth2AuthorizedClient secondLoadedClient = service.loadAuthorizedClient(this.registration1.getRegistrationId(),
152+
this.principalName1);
153+
assertAuthorizedClientEquals(authorizedClient, firstLoadedClient);
154+
assertAuthorizedClientEquals(authorizedClientWithUpdatedRegistration, secondLoadedClient);
128155
}
129156

130157
@Test
131158
public void saveAuthorizedClientWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() {
132159
assertThatIllegalArgumentException()
133-
.isThrownBy(() -> this.authorizedClientService.saveAuthorizedClient(null, mock(Authentication.class)));
160+
.isThrownBy(() -> this.authorizedClientService.saveAuthorizedClient(null, mock(Authentication.class)));
134161
}
135162

136163
@Test
@@ -147,20 +174,20 @@ public void saveAuthorizedClientWhenSavedThenCanLoad() {
147174
mock(OAuth2AccessToken.class));
148175
this.authorizedClientService.saveAuthorizedClient(authorizedClient, authentication);
149176
OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientService
150-
.loadAuthorizedClient(this.registration3.getRegistrationId(), this.principalName2);
151-
assertThat(loadedAuthorizedClient).isEqualTo(authorizedClient);
177+
.loadAuthorizedClient(this.registration3.getRegistrationId(), this.principalName2);
178+
assertAuthorizedClientEquals(authorizedClient, loadedAuthorizedClient);
152179
}
153180

154181
@Test
155182
public void removeAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() {
156183
assertThatIllegalArgumentException()
157-
.isThrownBy(() -> this.authorizedClientService.removeAuthorizedClient(null, this.principalName2));
184+
.isThrownBy(() -> this.authorizedClientService.removeAuthorizedClient(null, this.principalName2));
158185
}
159186

160187
@Test
161188
public void removeAuthorizedClientWhenPrincipalNameIsNullThenThrowIllegalArgumentException() {
162189
assertThatIllegalArgumentException().isThrownBy(() -> this.authorizedClientService
163-
.removeAuthorizedClient(this.registration3.getRegistrationId(), null));
190+
.removeAuthorizedClient(this.registration3.getRegistrationId(), null));
164191
}
165192

166193
@Test
@@ -171,13 +198,38 @@ public void removeAuthorizedClientWhenSavedThenRemoved() {
171198
mock(OAuth2AccessToken.class));
172199
this.authorizedClientService.saveAuthorizedClient(authorizedClient, authentication);
173200
OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientService
174-
.loadAuthorizedClient(this.registration2.getRegistrationId(), this.principalName2);
201+
.loadAuthorizedClient(this.registration2.getRegistrationId(), this.principalName2);
175202
assertThat(loadedAuthorizedClient).isNotNull();
176203
this.authorizedClientService.removeAuthorizedClient(this.registration2.getRegistrationId(),
177204
this.principalName2);
178205
loadedAuthorizedClient = this.authorizedClientService
179-
.loadAuthorizedClient(this.registration2.getRegistrationId(), this.principalName2);
206+
.loadAuthorizedClient(this.registration2.getRegistrationId(), this.principalName2);
180207
assertThat(loadedAuthorizedClient).isNull();
181208
}
182209

210+
private static void assertAuthorizedClientEquals(OAuth2AuthorizedClient expected, OAuth2AuthorizedClient actual) {
211+
assertThat(actual).isNotNull();
212+
assertThat(actual.getClientRegistration().getRegistrationId())
213+
.isEqualTo(expected.getClientRegistration().getRegistrationId());
214+
assertThat(actual.getClientRegistration().getClientName())
215+
.isEqualTo(expected.getClientRegistration().getClientName());
216+
assertThat(actual.getClientRegistration().getRedirectUri())
217+
.isEqualTo(expected.getClientRegistration().getRedirectUri());
218+
assertThat(actual.getClientRegistration().getAuthorizationGrantType())
219+
.isEqualTo(expected.getClientRegistration().getAuthorizationGrantType());
220+
assertThat(actual.getClientRegistration().getClientAuthenticationMethod())
221+
.isEqualTo(expected.getClientRegistration().getClientAuthenticationMethod());
222+
assertThat(actual.getClientRegistration().getClientId())
223+
.isEqualTo(expected.getClientRegistration().getClientId());
224+
assertThat(actual.getClientRegistration().getClientSecret())
225+
.isEqualTo(expected.getClientRegistration().getClientSecret());
226+
assertThat(actual.getPrincipalName()).isEqualTo(expected.getPrincipalName());
227+
assertThat(actual.getAccessToken().getTokenType()).isEqualTo(expected.getAccessToken().getTokenType());
228+
assertThat(actual.getAccessToken().getTokenValue()).isEqualTo(expected.getAccessToken().getTokenValue());
229+
assertThat(actual.getAccessToken().getIssuedAt()).isEqualTo(expected.getAccessToken().getIssuedAt());
230+
assertThat(actual.getAccessToken().getExpiresAt()).isEqualTo(expected.getAccessToken().getExpiresAt());
231+
assertThat(actual.getAccessToken().getScopes()).isEqualTo(expected.getAccessToken().getScopes());
232+
assertThat(actual.getRefreshToken()).isEqualTo(expected.getRefreshToken());
233+
}
234+
183235
}

0 commit comments

Comments
 (0)