17
17
package org .springframework .security .oauth2 .client .web .function .client ;
18
18
19
19
import java .io .IOException ;
20
- import java .net .URI ;
21
20
import java .util .HashMap ;
22
21
import java .util .Map ;
23
22
import java .util .function .Consumer ;
26
25
import jakarta .servlet .http .HttpServletResponse ;
27
26
28
27
import org .springframework .http .HttpHeaders ;
29
- import org .springframework .http .HttpMethod ;
30
28
import org .springframework .http .HttpRequest ;
31
29
import org .springframework .http .HttpStatus ;
32
30
import org .springframework .http .HttpStatusCode ;
47
45
import org .springframework .security .oauth2 .client .OAuth2AuthorizedClientProvider ;
48
46
import org .springframework .security .oauth2 .client .OAuth2AuthorizedClientService ;
49
47
import org .springframework .security .oauth2 .client .RemoveAuthorizedClientOAuth2AuthorizationFailureHandler ;
48
+ import org .springframework .security .oauth2 .client .authentication .OAuth2AuthenticationToken ;
50
49
import org .springframework .security .oauth2 .client .registration .ClientRegistration ;
51
50
import org .springframework .security .oauth2 .client .web .OAuth2AuthorizedClientRepository ;
52
51
import org .springframework .security .oauth2 .core .OAuth2AuthorizationException ;
55
54
import org .springframework .security .oauth2 .core .endpoint .OAuth2ParameterNames ;
56
55
import org .springframework .util .Assert ;
57
56
import org .springframework .util .StringUtils ;
58
- import org .springframework .web .client .DefaultResponseErrorHandler ;
59
- import org .springframework .web .client .ResponseErrorHandler ;
57
+ import org .springframework .web .client .RestClient ;
60
58
import org .springframework .web .client .RestClientResponseException ;
61
59
import org .springframework .web .context .request .RequestContextHolder ;
62
60
import org .springframework .web .context .request .ServletRequestAttributes ;
@@ -116,9 +114,12 @@ public final class OAuth2ClientHttpRequestInterceptor implements ClientHttpReque
116
114
private static final Authentication ANONYMOUS_AUTHENTICATION = new AnonymousAuthenticationToken ("anonymous" ,
117
115
"anonymousUser" , AuthorityUtils .createAuthorityList ("ROLE_ANONYMOUS" ));
118
116
117
+ private static final String CLIENT_REGISTRATION_ID_ATTR_NAME = OAuth2ClientHttpRequestInterceptor .class .getName ()
118
+ .concat (".clientRegistrationId" );
119
+
119
120
private final OAuth2AuthorizedClientManager authorizedClientManager ;
120
121
121
- private final String clientRegistrationId ;
122
+ private String defaultClientRegistrationId ;
122
123
123
124
// @formatter:off
124
125
private OAuth2AuthorizationFailureHandler authorizationFailureHandler =
@@ -133,15 +134,27 @@ public final class OAuth2ClientHttpRequestInterceptor implements ClientHttpReque
133
134
* parameters.
134
135
* @param authorizedClientManager the {@link OAuth2AuthorizedClientManager} which
135
136
* manages the authorized client(s)
136
- * @param clientRegistrationId the {@link ClientRegistration#getRegistrationId()} to
137
- * be used to look up the {@link OAuth2AuthorizedClient}
138
137
*/
139
- public OAuth2ClientHttpRequestInterceptor (OAuth2AuthorizedClientManager authorizedClientManager ,
140
- String clientRegistrationId ) {
138
+ public OAuth2ClientHttpRequestInterceptor (OAuth2AuthorizedClientManager authorizedClientManager ) {
141
139
Assert .notNull (authorizedClientManager , "authorizedClientManager cannot be null" );
142
- Assert .hasText (clientRegistrationId , "clientRegistrationId cannot be empty" );
143
140
this .authorizedClientManager = authorizedClientManager ;
144
- this .clientRegistrationId = clientRegistrationId ;
141
+ }
142
+
143
+ /**
144
+ * Sets the default {@code clientRegistrationId} to be used for resolving an
145
+ * {@link OAuth2AuthorizedClient}.
146
+ *
147
+ * <p>
148
+ * By default, the {@code clientRegistrationId} is obtained from the current
149
+ * {@link Authentication principal}. Using this setter overrides the default, but can
150
+ * be overridden by providing an
151
+ * {@link RestClient.RequestHeadersSpec#attributes(Consumer) attribute} via
152
+ * {@link #clientRegistrationId(String)}.
153
+ * @param clientRegistrationId the default {@code clientRegistrationId}
154
+ */
155
+ public void setDefaultClientRegistrationId (String clientRegistrationId ) {
156
+ Assert .hasText (clientRegistrationId , "clientRegistrationId cannot be empty" );
157
+ this .defaultClientRegistrationId = clientRegistrationId ;
145
158
}
146
159
147
160
/**
@@ -237,33 +250,52 @@ public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy secur
237
250
this .securityContextHolderStrategy = securityContextHolderStrategy ;
238
251
}
239
252
253
+ /**
254
+ * Modifies the {@link ClientHttpRequest#getAttributes() attributes} to include the
255
+ * {@link ClientRegistration#getRegistrationId() clientRegistrationId} to be used to
256
+ * look up the {@link OAuth2AuthorizedClient}.
257
+ * @param clientRegistrationId the {@link ClientRegistration#getRegistrationId()
258
+ * clientRegistrationId} to be used to look up the {@link OAuth2AuthorizedClient}
259
+ * @return the {@link Consumer} to populate the attributes
260
+ */
261
+ public static Consumer <Map <String , Object >> clientRegistrationId (String clientRegistrationId ) {
262
+ Assert .hasText (clientRegistrationId , "clientRegistrationId cannot be empty" );
263
+ return (attributes ) -> attributes .put (CLIENT_REGISTRATION_ID_ATTR_NAME , clientRegistrationId );
264
+ }
265
+
240
266
@ Override
241
267
public ClientHttpResponse intercept (HttpRequest request , byte [] body , ClientHttpRequestExecution execution )
242
268
throws IOException {
243
- authorizeClient (request );
269
+ Authentication principal = this .securityContextHolderStrategy .getContext ().getAuthentication ();
270
+ if (principal == null ) {
271
+ principal = ANONYMOUS_AUTHENTICATION ;
272
+ }
273
+
274
+ authorizeClient (request , principal );
244
275
try {
245
276
ClientHttpResponse response = execution .execute (request , body );
246
- handleAuthorizationFailure (response .getHeaders (), response .getStatusCode ());
277
+ handleAuthorizationFailure (request , principal , response .getHeaders (), response .getStatusCode ());
247
278
return response ;
248
279
}
249
280
catch (RestClientResponseException ex ) {
250
- handleAuthorizationFailure (ex .getResponseHeaders (), ex .getStatusCode ());
281
+ handleAuthorizationFailure (request , principal , ex .getResponseHeaders (), ex .getStatusCode ());
251
282
throw ex ;
252
283
}
253
284
catch (OAuth2AuthorizationException ex ) {
254
- handleAuthorizationFailure (ex );
285
+ handleAuthorizationFailure (ex , principal );
255
286
throw ex ;
256
287
}
257
288
}
258
289
259
- private void authorizeClient (HttpRequest request ) {
260
- Authentication principal = this . securityContextHolderStrategy . getContext (). getAuthentication ( );
261
- if (principal == null ) {
262
- principal = ANONYMOUS_AUTHENTICATION ;
290
+ private void authorizeClient (HttpRequest request , Authentication principal ) {
291
+ String clientRegistrationId = clientRegistrationId ( request , principal );
292
+ if (clientRegistrationId == null ) {
293
+ return ;
263
294
}
295
+
264
296
// @formatter:off
265
297
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
266
- .withClientRegistrationId (this . clientRegistrationId )
298
+ .withClientRegistrationId (clientRegistrationId )
267
299
.principal (principal )
268
300
.build ();
269
301
// @formatter:on
@@ -273,15 +305,21 @@ private void authorizeClient(HttpRequest request) {
273
305
}
274
306
}
275
307
276
- private void handleAuthorizationFailure (HttpHeaders headers , HttpStatusCode httpStatus ) {
308
+ private void handleAuthorizationFailure (HttpRequest request , Authentication principal , HttpHeaders headers ,
309
+ HttpStatusCode httpStatus ) {
277
310
OAuth2Error error = resolveOAuth2ErrorIfPossible (headers , httpStatus );
278
311
if (error == null ) {
279
312
return ;
280
313
}
281
314
315
+ String clientRegistrationId = clientRegistrationId (request , principal );
316
+ if (clientRegistrationId == null ) {
317
+ return ;
318
+ }
319
+
282
320
ClientAuthorizationException authorizationException = new ClientAuthorizationException (error ,
283
- this . clientRegistrationId );
284
- handleAuthorizationFailure (authorizationException );
321
+ clientRegistrationId );
322
+ handleAuthorizationFailure (authorizationException , principal );
285
323
}
286
324
287
325
private static OAuth2Error resolveOAuth2ErrorIfPossible (HttpHeaders headers , HttpStatusCode httpStatus ) {
@@ -323,12 +361,20 @@ private static Map<String, String> parseWwwAuthenticateHeader(String wwwAuthenti
323
361
return parameters ;
324
362
}
325
363
326
- private void handleAuthorizationFailure (OAuth2AuthorizationException authorizationException ) {
327
- Authentication principal = this .securityContextHolderStrategy .getContext ().getAuthentication ();
328
- if (principal == null ) {
329
- principal = ANONYMOUS_AUTHENTICATION ;
364
+ private String clientRegistrationId (HttpRequest request , Authentication principal ) {
365
+ String clientRegistrationId = (String ) request .getAttributes ().get (CLIENT_REGISTRATION_ID_ATTR_NAME );
366
+ if (clientRegistrationId == null ) {
367
+ clientRegistrationId = this .defaultClientRegistrationId ;
368
+ }
369
+ if (clientRegistrationId == null && principal instanceof OAuth2AuthenticationToken authentication ) {
370
+ clientRegistrationId = authentication .getAuthorizedClientRegistrationId ();
330
371
}
331
372
373
+ return clientRegistrationId ;
374
+ }
375
+
376
+ private void handleAuthorizationFailure (OAuth2AuthorizationException authorizationException ,
377
+ Authentication principal ) {
332
378
ServletRequestAttributes requestAttributes = (ServletRequestAttributes ) RequestContextHolder
333
379
.getRequestAttributes ();
334
380
Map <String , Object > attributes = new HashMap <>();
0 commit comments