Skip to content

Commit

Permalink
Add support configuring OAuth2AuthorizationRequestResolver as bean
Browse files Browse the repository at this point in the history
Closes gh-15236
  • Loading branch information
Max Batischev authored and sjohnr committed Jun 13, 2024
1 parent 60a6b38 commit 4e52eda
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 12 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -58,7 +58,7 @@
* {@link ClientRegistrationRepository} {@code @Bean} may be registered instead.
*
* <h2>Security Filters</h2>
*
* <p>
* The following {@code Filter}'s are populated for {@link #authorizationCodeGrant()}:
*
* <ul>
Expand All @@ -67,7 +67,7 @@
* </ul>
*
* <h2>Shared Objects Created</h2>
*
* <p>
* The following shared objects are populated:
*
* <ul>
Expand All @@ -76,7 +76,7 @@
* </ul>
*
* <h2>Shared Objects Used</h2>
*
* <p>
* The following shared objects are used:
*
* <ul>
Expand Down Expand Up @@ -283,10 +283,12 @@ private OAuth2AuthorizationRequestResolver getAuthorizationRequestResolver() {
if (this.authorizationRequestResolver != null) {
return this.authorizationRequestResolver;
}
ClientRegistrationRepository clientRegistrationRepository = OAuth2ClientConfigurerUtils
.getClientRegistrationRepository(getBuilder());
return new DefaultOAuth2AuthorizationRequestResolver(clientRegistrationRepository,
OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI);
ResolvableType resolvableType = ResolvableType.forClass(OAuth2AuthorizationRequestResolver.class);
OAuth2AuthorizationRequestResolver bean = getBeanOrNull(resolvableType);
return (bean != null) ? bean
: new DefaultOAuth2AuthorizationRequestResolver(
OAuth2ClientConfigurerUtils.getClientRegistrationRepository(getBuilder()),
OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI);
}

private OAuth2AuthorizationCodeGrantFilter createAuthorizationCodeGrantFilter(B builder) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4532,9 +4532,12 @@ private ReactiveClientRegistrationRepository getClientRegistrationRepository() {
}

private OAuth2AuthorizationRequestRedirectWebFilter getRedirectWebFilter() {
OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter;
if (this.authorizationRequestResolver != null) {
return new OAuth2AuthorizationRequestRedirectWebFilter(this.authorizationRequestResolver);
ServerOAuth2AuthorizationRequestResolver result = this.authorizationRequestResolver;
if (result == null) {
result = getBeanOrNull(ServerOAuth2AuthorizationRequestResolver.class);
}
if (result != null) {
return new OAuth2AuthorizationRequestRedirectWebFilter(result);
}
return new OAuth2AuthorizationRequestRedirectWebFilter(getClientRegistrationRepository());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -285,6 +285,18 @@ public void configureWhenCustomAuthorizationRedirectStrategySetThenAuthorization
verify(authorizationRedirectStrategy).sendRedirect(any(), any(), anyString());
}

@Test
public void configureWhenCustomAuthorizationRequestResolverBeanPresentThenAuthorizationRequestIncludesCustomParameters()
throws Exception {
this.spring.register(OAuth2ClientBeanConfig.class).autowire();
// @formatter:off
this.mockMvc.perform(get("/oauth2/authorization/registration-1"))
.andExpect(status().is3xxRedirection())
.andReturn();
// @formatter:on
verify(authorizationRequestResolver).resolve(any());
}

@EnableWebSecurity
@Configuration
@EnableWebMvc
Expand Down Expand Up @@ -362,4 +374,59 @@ OAuth2AuthorizedClientRepository authorizedClientRepository() {

}

@EnableWebSecurity
@Configuration
@EnableWebMvc
static class OAuth2ClientBeanConfig {

@Bean
SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
// @formatter:off
http
.authorizeRequests()
.anyRequest().authenticated()
.and()
.requestCache()
.requestCache(requestCache)
.and()
.oauth2Client()
.authorizationCodeGrant()
.authorizationRedirectStrategy(authorizationRedirectStrategy)
.accessTokenResponseClient(accessTokenResponseClient);
return http.build();
// @formatter:on
}

@Bean
ClientRegistrationRepository clientRegistrationRepository() {
return clientRegistrationRepository;
}

@Bean
OAuth2AuthorizedClientRepository authorizedClientRepository() {
return authorizedClientRepository;
}

@Bean
OAuth2AuthorizationRequestResolver authorizationRequestResolver() {
OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver = authorizationRequestResolver;
authorizationRequestResolver = mock(OAuth2AuthorizationRequestResolver.class);
given(authorizationRequestResolver.resolve(any()))
.willAnswer((invocation) -> defaultAuthorizationRequestResolver.resolve(invocation.getArgument(0)));
return authorizationRequestResolver;
}

@RestController
class ResourceController {

@GetMapping("/resource1")
String resource1(
@RegisteredOAuth2AuthorizedClient("registration-1") OAuth2AuthorizedClient authorizedClient) {
return "resource1";
}

}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService;
import org.springframework.security.oauth2.client.web.server.DefaultServerOAuth2AuthorizationRequestResolver;
import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationRequestResolver;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
Expand Down Expand Up @@ -457,6 +458,7 @@ public void oauth2LoginWhenCustomBeansThenUsed() {
OidcUser user = TestOidcUsers.create();
ReactiveOAuth2UserService<OidcUserRequest, OidcUser> userService = config.userService;
given(userService.loadUser(any())).willReturn(Mono.just(user));
ServerOAuth2AuthorizationRequestResolver resolver = config.resolver;
// @formatter:off
webTestClient.get()
.uri("/login/oauth2/code/google")
Expand All @@ -466,6 +468,7 @@ public void oauth2LoginWhenCustomBeansThenUsed() {
verify(config.jwtDecoderFactory).createDecoder(any());
verify(tokenResponseClient).getTokenResponse(any());
verify(securityContextRepository).save(any(), any());
verify(resolver).resolve(any());
}

// gh-5562
Expand Down Expand Up @@ -837,6 +840,10 @@ static class OAuth2LoginWithCustomBeansConfig {

ServerSecurityContextRepository securityContextRepository = mock(ServerSecurityContextRepository.class);

ServerOAuth2AuthorizationRequestResolver resolver = spy(
new DefaultServerOAuth2AuthorizationRequestResolver(new InMemoryReactiveClientRegistrationRepository(
TestClientRegistrations.clientRegistration().build())));

@Bean
SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) {
// @formatter:off
Expand Down Expand Up @@ -864,6 +871,11 @@ ReactiveJwtDecoderFactory<ClientRegistration> jwtDecoderFactory() {
return this.jwtDecoderFactory;
}

@Bean
ServerOAuth2AuthorizationRequestResolver resolver() {
return this.resolver;
}

@Bean
ReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient() {
return this.tokenResponseClient;
Expand Down

0 comments on commit 4e52eda

Please sign in to comment.