Skip to content

Commit a3d61bd

Browse files
committed
Add support for clientRegistrationId attribute
1 parent 3feee0d commit a3d61bd

File tree

2 files changed

+216
-91
lines changed

2 files changed

+216
-91
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/function/client/OAuth2ClientHttpRequestInterceptor.java

+73-27
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
package org.springframework.security.oauth2.client.web.function.client;
1818

1919
import java.io.IOException;
20-
import java.net.URI;
2120
import java.util.HashMap;
2221
import java.util.Map;
2322
import java.util.function.Consumer;
@@ -26,7 +25,6 @@
2625
import jakarta.servlet.http.HttpServletResponse;
2726

2827
import org.springframework.http.HttpHeaders;
29-
import org.springframework.http.HttpMethod;
3028
import org.springframework.http.HttpRequest;
3129
import org.springframework.http.HttpStatus;
3230
import org.springframework.http.HttpStatusCode;
@@ -47,6 +45,7 @@
4745
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider;
4846
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
4947
import org.springframework.security.oauth2.client.RemoveAuthorizedClientOAuth2AuthorizationFailureHandler;
48+
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
5049
import org.springframework.security.oauth2.client.registration.ClientRegistration;
5150
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
5251
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
@@ -55,8 +54,7 @@
5554
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
5655
import org.springframework.util.Assert;
5756
import org.springframework.util.StringUtils;
58-
import org.springframework.web.client.DefaultResponseErrorHandler;
59-
import org.springframework.web.client.ResponseErrorHandler;
57+
import org.springframework.web.client.RestClient;
6058
import org.springframework.web.client.RestClientResponseException;
6159
import org.springframework.web.context.request.RequestContextHolder;
6260
import org.springframework.web.context.request.ServletRequestAttributes;
@@ -116,9 +114,12 @@ public final class OAuth2ClientHttpRequestInterceptor implements ClientHttpReque
116114
private static final Authentication ANONYMOUS_AUTHENTICATION = new AnonymousAuthenticationToken("anonymous",
117115
"anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"));
118116

117+
private static final String CLIENT_REGISTRATION_ID_ATTR_NAME = OAuth2ClientHttpRequestInterceptor.class.getName()
118+
.concat(".clientRegistrationId");
119+
119120
private final OAuth2AuthorizedClientManager authorizedClientManager;
120121

121-
private final String clientRegistrationId;
122+
private String defaultClientRegistrationId;
122123

123124
// @formatter:off
124125
private OAuth2AuthorizationFailureHandler authorizationFailureHandler =
@@ -133,15 +134,27 @@ public final class OAuth2ClientHttpRequestInterceptor implements ClientHttpReque
133134
* parameters.
134135
* @param authorizedClientManager the {@link OAuth2AuthorizedClientManager} which
135136
* manages the authorized client(s)
136-
* @param clientRegistrationId the {@link ClientRegistration#getRegistrationId()} to
137-
* be used to look up the {@link OAuth2AuthorizedClient}
138137
*/
139-
public OAuth2ClientHttpRequestInterceptor(OAuth2AuthorizedClientManager authorizedClientManager,
140-
String clientRegistrationId) {
138+
public OAuth2ClientHttpRequestInterceptor(OAuth2AuthorizedClientManager authorizedClientManager) {
141139
Assert.notNull(authorizedClientManager, "authorizedClientManager cannot be null");
142-
Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
143140
this.authorizedClientManager = authorizedClientManager;
144-
this.clientRegistrationId = clientRegistrationId;
141+
}
142+
143+
/**
144+
* Sets the default {@code clientRegistrationId} to be used for resolving an
145+
* {@link OAuth2AuthorizedClient}.
146+
*
147+
* <p>
148+
* By default, the {@code clientRegistrationId} is obtained from the current
149+
* {@link Authentication principal}. Using this setter overrides the default, but can
150+
* be overridden by providing an
151+
* {@link RestClient.RequestHeadersSpec#attributes(Consumer) attribute} via
152+
* {@link #clientRegistrationId(String)}.
153+
* @param clientRegistrationId the default {@code clientRegistrationId}
154+
*/
155+
public void setDefaultClientRegistrationId(String clientRegistrationId) {
156+
Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
157+
this.defaultClientRegistrationId = clientRegistrationId;
145158
}
146159

147160
/**
@@ -237,33 +250,52 @@ public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy secur
237250
this.securityContextHolderStrategy = securityContextHolderStrategy;
238251
}
239252

253+
/**
254+
* Modifies the {@link ClientHttpRequest#getAttributes() attributes} to include the
255+
* {@link ClientRegistration#getRegistrationId() clientRegistrationId} to be used to
256+
* look up the {@link OAuth2AuthorizedClient}.
257+
* @param clientRegistrationId the {@link ClientRegistration#getRegistrationId()
258+
* clientRegistrationId} to be used to look up the {@link OAuth2AuthorizedClient}
259+
* @return the {@link Consumer} to populate the attributes
260+
*/
261+
public static Consumer<Map<String, Object>> clientRegistrationId(String clientRegistrationId) {
262+
Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
263+
return (attributes) -> attributes.put(CLIENT_REGISTRATION_ID_ATTR_NAME, clientRegistrationId);
264+
}
265+
240266
@Override
241267
public ClientHttpResponse intercept(HttpRequest request, byte[] body, ClientHttpRequestExecution execution)
242268
throws IOException {
243-
authorizeClient(request);
269+
Authentication principal = this.securityContextHolderStrategy.getContext().getAuthentication();
270+
if (principal == null) {
271+
principal = ANONYMOUS_AUTHENTICATION;
272+
}
273+
274+
authorizeClient(request, principal);
244275
try {
245276
ClientHttpResponse response = execution.execute(request, body);
246-
handleAuthorizationFailure(response.getHeaders(), response.getStatusCode());
277+
handleAuthorizationFailure(request, principal, response.getHeaders(), response.getStatusCode());
247278
return response;
248279
}
249280
catch (RestClientResponseException ex) {
250-
handleAuthorizationFailure(ex.getResponseHeaders(), ex.getStatusCode());
281+
handleAuthorizationFailure(request, principal, ex.getResponseHeaders(), ex.getStatusCode());
251282
throw ex;
252283
}
253284
catch (OAuth2AuthorizationException ex) {
254-
handleAuthorizationFailure(ex);
285+
handleAuthorizationFailure(ex, principal);
255286
throw ex;
256287
}
257288
}
258289

259-
private void authorizeClient(HttpRequest request) {
260-
Authentication principal = this.securityContextHolderStrategy.getContext().getAuthentication();
261-
if (principal == null) {
262-
principal = ANONYMOUS_AUTHENTICATION;
290+
private void authorizeClient(HttpRequest request, Authentication principal) {
291+
String clientRegistrationId = clientRegistrationId(request, principal);
292+
if (clientRegistrationId == null) {
293+
return;
263294
}
295+
264296
// @formatter:off
265297
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
266-
.withClientRegistrationId(this.clientRegistrationId)
298+
.withClientRegistrationId(clientRegistrationId)
267299
.principal(principal)
268300
.build();
269301
// @formatter:on
@@ -273,15 +305,21 @@ private void authorizeClient(HttpRequest request) {
273305
}
274306
}
275307

276-
private void handleAuthorizationFailure(HttpHeaders headers, HttpStatusCode httpStatus) {
308+
private void handleAuthorizationFailure(HttpRequest request, Authentication principal, HttpHeaders headers,
309+
HttpStatusCode httpStatus) {
277310
OAuth2Error error = resolveOAuth2ErrorIfPossible(headers, httpStatus);
278311
if (error == null) {
279312
return;
280313
}
281314

315+
String clientRegistrationId = clientRegistrationId(request, principal);
316+
if (clientRegistrationId == null) {
317+
return;
318+
}
319+
282320
ClientAuthorizationException authorizationException = new ClientAuthorizationException(error,
283-
this.clientRegistrationId);
284-
handleAuthorizationFailure(authorizationException);
321+
clientRegistrationId);
322+
handleAuthorizationFailure(authorizationException, principal);
285323
}
286324

287325
private static OAuth2Error resolveOAuth2ErrorIfPossible(HttpHeaders headers, HttpStatusCode httpStatus) {
@@ -323,12 +361,20 @@ private static Map<String, String> parseWwwAuthenticateHeader(String wwwAuthenti
323361
return parameters;
324362
}
325363

326-
private void handleAuthorizationFailure(OAuth2AuthorizationException authorizationException) {
327-
Authentication principal = this.securityContextHolderStrategy.getContext().getAuthentication();
328-
if (principal == null) {
329-
principal = ANONYMOUS_AUTHENTICATION;
364+
private String clientRegistrationId(HttpRequest request, Authentication principal) {
365+
String clientRegistrationId = (String) request.getAttributes().get(CLIENT_REGISTRATION_ID_ATTR_NAME);
366+
if (clientRegistrationId == null) {
367+
clientRegistrationId = this.defaultClientRegistrationId;
368+
}
369+
if (clientRegistrationId == null && principal instanceof OAuth2AuthenticationToken authentication) {
370+
clientRegistrationId = authentication.getAuthorizedClientRegistrationId();
330371
}
331372

373+
return clientRegistrationId;
374+
}
375+
376+
private void handleAuthorizationFailure(OAuth2AuthorizationException authorizationException,
377+
Authentication principal) {
332378
ServletRequestAttributes requestAttributes = (ServletRequestAttributes) RequestContextHolder
333379
.getRequestAttributes();
334380
Map<String, Object> attributes = new HashMap<>();

0 commit comments

Comments
 (0)