diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java index 0fc647de670..946b6dcd4e4 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,17 +18,40 @@ import java.util.List; +import org.springframework.beans.BeanMetadataElement; +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.BeanFactoryAware; +import org.springframework.beans.factory.BeanFactoryUtils; +import org.springframework.beans.factory.ListableBeanFactory; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.beans.factory.config.RuntimeBeanReference; +import org.springframework.beans.factory.support.BeanDefinitionBuilder; +import org.springframework.beans.factory.support.BeanDefinitionRegistry; +import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor; +import org.springframework.beans.factory.support.ManagedList; +import org.springframework.context.annotation.AnnotationBeanNameGenerator; +import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Import; import org.springframework.context.annotation.ImportSelector; +import org.springframework.core.ResolvableType; import org.springframework.core.type.AnnotationMetadata; import org.springframework.security.core.context.SecurityContextHolderStrategy; +import org.springframework.security.oauth2.client.AuthorizationCodeOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.DelegatingOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder; +import org.springframework.security.oauth2.client.PasswordOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; +import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest; +import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; @@ -48,7 +71,8 @@ * @since 5.1 * @see OAuth2ImportSelector */ -@Import(OAuth2ClientConfiguration.OAuth2ClientWebMvcImportSelector.class) +@Import({ OAuth2ClientConfiguration.OAuth2ClientWebMvcImportSelector.class, + OAuth2ClientConfiguration.OAuth2AuthorizedClientManagerConfiguration.class }) final class OAuth2ClientConfiguration { private static final boolean webMvcPresent; @@ -65,8 +89,22 @@ public String[] selectImports(AnnotationMetadata importingClassMetadata) { if (!webMvcPresent) { return new String[0]; } - return new String[] { "org.springframework.security.config.annotation.web.configuration." - + "OAuth2ClientConfiguration.OAuth2ClientWebMvcSecurityConfiguration" }; + return new String[] { + OAuth2ClientConfiguration.class.getName() + ".OAuth2ClientWebMvcSecurityConfiguration" }; + } + + } + + /** + * @author Joe Grandja + * @since 6.2.0 + */ + @Configuration(proxyBeanMethods = false) + static class OAuth2AuthorizedClientManagerConfiguration { + + @Bean + OAuth2AuthorizedClientManagerRegistrar authorizedClientManagerRegistrar() { + return new OAuth2AuthorizedClientManagerRegistrar(); } } @@ -160,4 +198,136 @@ private OAuth2AuthorizedClientManager getAuthorizedClientManager() { } + /** + * A registrar for registering the default {@link OAuth2AuthorizedClientManager} bean + * definition, if not already present. + * + * @author Joe Grandja + * @since 6.2.0 + */ + static class OAuth2AuthorizedClientManagerRegistrar + implements BeanDefinitionRegistryPostProcessor, BeanFactoryAware { + + private final AnnotationBeanNameGenerator beanNameGenerator = new AnnotationBeanNameGenerator(); + + private BeanFactory beanFactory; + + @Override + public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException { + String[] authorizedClientManagerBeanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors( + (ListableBeanFactory) this.beanFactory, OAuth2AuthorizedClientManager.class, true, true); + if (authorizedClientManagerBeanNames.length != 0) { + return; + } + + String[] clientRegistrationRepositoryBeanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors( + (ListableBeanFactory) this.beanFactory, ClientRegistrationRepository.class, true, true); + String[] authorizedClientRepositoryBeanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors( + (ListableBeanFactory) this.beanFactory, OAuth2AuthorizedClientRepository.class, true, true); + if (clientRegistrationRepositoryBeanNames.length != 1 || authorizedClientRepositoryBeanNames.length != 1) { + return; + } + + BeanDefinition beanDefinition = BeanDefinitionBuilder + .genericBeanDefinition(DefaultOAuth2AuthorizedClientManager.class) + .addConstructorArgReference(clientRegistrationRepositoryBeanNames[0]) + .addConstructorArgReference(authorizedClientRepositoryBeanNames[0]) + .addPropertyValue("authorizedClientProvider", getAuthorizedClientProvider()).getBeanDefinition(); + + registry.registerBeanDefinition(this.beanNameGenerator.generateBeanName(beanDefinition, registry), + beanDefinition); + } + + @Override + public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException { + } + + private BeanDefinition getAuthorizedClientProvider() { + ManagedList authorizedClientProviders = new ManagedList<>(); + authorizedClientProviders.add(getAuthorizationCodeAuthorizedClientProvider()); + authorizedClientProviders.add(getRefreshTokenAuthorizedClientProvider()); + authorizedClientProviders.add(getClientCredentialsAuthorizedClientProvider()); + authorizedClientProviders.add(getPasswordAuthorizedClientProvider()); + return BeanDefinitionBuilder.genericBeanDefinition(DelegatingOAuth2AuthorizedClientProvider.class) + .addConstructorArgValue(authorizedClientProviders).getBeanDefinition(); + } + + private BeanMetadataElement getAuthorizationCodeAuthorizedClientProvider() { + String[] beanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors( + (ListableBeanFactory) this.beanFactory, AuthorizationCodeOAuth2AuthorizedClientProvider.class, true, + true); + if (beanNames.length == 1) { + return new RuntimeBeanReference(beanNames[0]); + } + + return BeanDefinitionBuilder.genericBeanDefinition(AuthorizationCodeOAuth2AuthorizedClientProvider.class) + .getBeanDefinition(); + } + + private BeanMetadataElement getRefreshTokenAuthorizedClientProvider() { + String[] beanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors( + (ListableBeanFactory) this.beanFactory, RefreshTokenOAuth2AuthorizedClientProvider.class, true, + true); + if (beanNames.length == 1) { + return new RuntimeBeanReference(beanNames[0]); + } + + BeanDefinitionBuilder beanDefinitionBuilder = BeanDefinitionBuilder + .genericBeanDefinition(RefreshTokenOAuth2AuthorizedClientProvider.class); + ResolvableType resolvableType = ResolvableType.forClassWithGenerics(OAuth2AccessTokenResponseClient.class, + OAuth2RefreshTokenGrantRequest.class); + beanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors((ListableBeanFactory) this.beanFactory, + resolvableType, true, true); + if (beanNames.length == 1) { + beanDefinitionBuilder.addPropertyReference("accessTokenResponseClient", beanNames[0]); + } + return beanDefinitionBuilder.getBeanDefinition(); + } + + private BeanMetadataElement getClientCredentialsAuthorizedClientProvider() { + String[] beanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors( + (ListableBeanFactory) this.beanFactory, ClientCredentialsOAuth2AuthorizedClientProvider.class, true, + true); + if (beanNames.length == 1) { + return new RuntimeBeanReference(beanNames[0]); + } + + BeanDefinitionBuilder beanDefinitionBuilder = BeanDefinitionBuilder + .genericBeanDefinition(ClientCredentialsOAuth2AuthorizedClientProvider.class); + ResolvableType resolvableType = ResolvableType.forClassWithGenerics(OAuth2AccessTokenResponseClient.class, + OAuth2ClientCredentialsGrantRequest.class); + beanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors((ListableBeanFactory) this.beanFactory, + resolvableType, true, true); + if (beanNames.length == 1) { + beanDefinitionBuilder.addPropertyReference("accessTokenResponseClient", beanNames[0]); + } + return beanDefinitionBuilder.getBeanDefinition(); + } + + private BeanMetadataElement getPasswordAuthorizedClientProvider() { + String[] beanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors( + (ListableBeanFactory) this.beanFactory, PasswordOAuth2AuthorizedClientProvider.class, true, true); + if (beanNames.length == 1) { + return new RuntimeBeanReference(beanNames[0]); + } + + BeanDefinitionBuilder beanDefinitionBuilder = BeanDefinitionBuilder + .genericBeanDefinition(PasswordOAuth2AuthorizedClientProvider.class); + ResolvableType resolvableType = ResolvableType.forClassWithGenerics(OAuth2AccessTokenResponseClient.class, + OAuth2PasswordGrantRequest.class); + beanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors((ListableBeanFactory) this.beanFactory, + resolvableType, true, true); + if (beanNames.length == 1) { + beanDefinitionBuilder.addPropertyReference("accessTokenResponseClient", beanNames[0]); + } + return beanDefinitionBuilder.getBeanDefinition(); + } + + @Override + public void setBeanFactory(BeanFactory beanFactory) throws BeansException { + this.beanFactory = beanFactory; + } + + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2AuthorizedClientManagerConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2AuthorizedClientManagerConfigurationTests.java new file mode 100644 index 00000000000..476b708610b --- /dev/null +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2AuthorizedClientManagerConfigurationTests.java @@ -0,0 +1,218 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.config.annotation.web.configuration; + +import java.util.Arrays; + +import org.junit.jupiter.api.Test; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.http.converter.FormHttpMessageConverter; +import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter; +import org.springframework.security.config.Customizer; +import org.springframework.security.config.annotation.web.builders.HttpSecurity; +import org.springframework.security.config.test.SpringTestContext; +import org.springframework.security.oauth2.client.AuthorizationCodeOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager; +import org.springframework.security.oauth2.client.PasswordOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.endpoint.DefaultAuthorizationCodeTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.DefaultClientCredentialsTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.DefaultPasswordTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.DefaultRefreshTokenTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; +import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; +import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest; +import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; +import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler; +import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest; +import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserService; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService; +import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest; +import org.springframework.security.oauth2.client.userinfo.OAuth2UserService; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter; +import org.springframework.security.oauth2.core.oidc.user.OidcUser; +import org.springframework.security.oauth2.core.user.OAuth2User; +import org.springframework.security.web.SecurityFilterChain; +import org.springframework.web.client.RestOperations; +import org.springframework.web.client.RestTemplate; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; + +/** + * Tests for {@link OAuth2ClientConfiguration.OAuth2AuthorizedClientManagerConfiguration}. + * + * @author Joe Grandja + */ +public class OAuth2AuthorizedClientManagerConfigurationTests { + + public final SpringTestContext spring = new SpringTestContext(this); + + @Autowired + private OAuth2AuthorizedClientManager authorizedClientManager; + + @Autowired(required = false) + private AuthorizationCodeOAuth2AuthorizedClientProvider authorizationCodeAuthorizedClientProvider; + + @Autowired(required = false) + private RefreshTokenOAuth2AuthorizedClientProvider refreshTokenAuthorizedClientProvider; + + @Autowired(required = false) + private ClientCredentialsOAuth2AuthorizedClientProvider clientCredentialsAuthorizedClientProvider; + + @Autowired(required = false) + private PasswordOAuth2AuthorizedClientProvider passwordAuthorizedClientProvider; + + @Test + public void loadContextWhenCustomRestOperationsThenConfigured() { + this.spring.register(CustomRestOperationsConfig.class).autowire(); + assertThat(this.authorizedClientManager).isNotNull(); + } + + @Test + public void loadContextWhenCustomAuthorizedClientProvidersThenConfigured() { + this.spring.register(CustomAuthorizedClientProvidersConfig.class).autowire(); + assertThat(this.authorizedClientManager).isNotNull(); + } + + @Configuration + @EnableWebSecurity + static class CustomRestOperationsConfig extends OAuth2ClientBaseConfig { + + // TODO This needs to be autoconfigured in OAuth2LoginConfigurer and + // OAuth2ClientConfigurer + @Bean + OAuth2AccessTokenResponseClient authorizationCodeTokenResponseClient() { + DefaultAuthorizationCodeTokenResponseClient tokenResponseClient = new DefaultAuthorizationCodeTokenResponseClient(); + tokenResponseClient.setRestOperations(restOperations()); + return spy(tokenResponseClient); + } + + @Bean + OAuth2AccessTokenResponseClient refreshTokenTokenResponseClient() { + DefaultRefreshTokenTokenResponseClient tokenResponseClient = new DefaultRefreshTokenTokenResponseClient(); + tokenResponseClient.setRestOperations(restOperations()); + return spy(tokenResponseClient); + } + + @Bean + OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient() { + DefaultClientCredentialsTokenResponseClient tokenResponseClient = new DefaultClientCredentialsTokenResponseClient(); + tokenResponseClient.setRestOperations(restOperations()); + return spy(tokenResponseClient); + } + + @Bean + OAuth2AccessTokenResponseClient passwordTokenResponseClient() { + DefaultPasswordTokenResponseClient tokenResponseClient = new DefaultPasswordTokenResponseClient(); + tokenResponseClient.setRestOperations(restOperations()); + return spy(tokenResponseClient); + } + + // NOTE: This is autoconfigured in OAuth2LoginConfigurer and + // OAuth2ClientConfigurer + @Bean + OAuth2UserService oauth2UserService() { + DefaultOAuth2UserService userService = new DefaultOAuth2UserService(); + userService.setRestOperations(restOperations()); + return spy(userService); + } + + // NOTE: This is autoconfigured in OAuth2LoginConfigurer and + // OAuth2ClientConfigurer + @Bean + OAuth2UserService oidcUserService() { + OidcUserService userService = new OidcUserService(); + userService.setOauth2UserService(oauth2UserService()); + return spy(userService); + } + + @Bean + RestOperations restOperations() { + // Minimum required configuration + RestTemplate restTemplate = new RestTemplate(Arrays.asList(new FormHttpMessageConverter(), + new OAuth2AccessTokenResponseHttpMessageConverter(), new MappingJackson2HttpMessageConverter())); + restTemplate.setErrorHandler(new OAuth2ErrorResponseErrorHandler()); + + // TODO Add custom configuration, eg. Proxy, TLS, etc + + return spy(restTemplate); + } + + } + + @Configuration + @EnableWebSecurity + static class CustomAuthorizedClientProvidersConfig extends OAuth2ClientBaseConfig { + + @Bean + AuthorizationCodeOAuth2AuthorizedClientProvider authorizationCodeProvider() { + return mock(AuthorizationCodeOAuth2AuthorizedClientProvider.class); + } + + @Bean + RefreshTokenOAuth2AuthorizedClientProvider refreshTokenProvider() { + return mock(RefreshTokenOAuth2AuthorizedClientProvider.class); + } + + @Bean + ClientCredentialsOAuth2AuthorizedClientProvider clientCredentialsProvider() { + return mock(ClientCredentialsOAuth2AuthorizedClientProvider.class); + } + + @Bean + PasswordOAuth2AuthorizedClientProvider passwordProvider() { + return mock(PasswordOAuth2AuthorizedClientProvider.class); + } + + } + + abstract static class OAuth2ClientBaseConfig { + + @Bean + SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeHttpRequests(authorize -> + authorize.anyRequest().authenticated()) + .oauth2Login(Customizer.withDefaults()) + .oauth2Client(Customizer.withDefaults()); + return http.build(); + // @formatter:on + } + + @Bean + ClientRegistrationRepository clientRegistrationRepository() { + return mock(ClientRegistrationRepository.class); + } + + @Bean + OAuth2AuthorizedClientRepository authorizedClientRepository() { + return mock(OAuth2AuthorizedClientRepository.class); + } + + } + +}