diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java index 946b6dcd4e4..e86380184e3 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java @@ -16,22 +16,25 @@ package org.springframework.security.config.annotation.web.configuration; +import java.util.ArrayList; +import java.util.Collection; import java.util.List; +import java.util.Set; +import java.util.function.Consumer; -import org.springframework.beans.BeanMetadataElement; import org.springframework.beans.BeansException; import org.springframework.beans.factory.BeanFactory; import org.springframework.beans.factory.BeanFactoryAware; import org.springframework.beans.factory.BeanFactoryUtils; +import org.springframework.beans.factory.BeanInitializationException; import org.springframework.beans.factory.ListableBeanFactory; +import org.springframework.beans.factory.ObjectProvider; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; -import org.springframework.beans.factory.config.RuntimeBeanReference; import org.springframework.beans.factory.support.BeanDefinitionBuilder; import org.springframework.beans.factory.support.BeanDefinitionRegistry; import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor; -import org.springframework.beans.factory.support.ManagedList; import org.springframework.context.annotation.AnnotationBeanNameGenerator; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -43,11 +46,12 @@ import org.springframework.security.oauth2.client.AuthorizationCodeOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.DelegatingOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.JwtBearerOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; -import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder; import org.springframework.security.oauth2.client.PasswordOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.endpoint.JwtBearerGrantRequest; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest; @@ -112,16 +116,12 @@ OAuth2AuthorizedClientManagerRegistrar authorizedClientManagerRegistrar() { @Configuration(proxyBeanMethods = false) static class OAuth2ClientWebMvcSecurityConfiguration implements WebMvcConfigurer { - private ClientRegistrationRepository clientRegistrationRepository; - - private OAuth2AuthorizedClientRepository authorizedClientRepository; - - private OAuth2AccessTokenResponseClient accessTokenResponseClient; - private OAuth2AuthorizedClientManager authorizedClientManager; private SecurityContextHolderStrategy securityContextHolderStrategy; + private OAuth2AuthorizedClientManagerRegistrar authorizedClientManagerRegistrar; + @Override public void addArgumentResolvers(List argumentResolvers) { OAuth2AuthorizedClientManager authorizedClientManager = getAuthorizedClientManager(); @@ -135,26 +135,6 @@ public void addArgumentResolvers(List argumentRes } } - @Autowired(required = false) - void setClientRegistrationRepository(List clientRegistrationRepositories) { - if (clientRegistrationRepositories.size() == 1) { - this.clientRegistrationRepository = clientRegistrationRepositories.get(0); - } - } - - @Autowired(required = false) - void setAuthorizedClientRepository(List authorizedClientRepositories) { - if (authorizedClientRepositories.size() == 1) { - this.authorizedClientRepository = authorizedClientRepositories.get(0); - } - } - - @Autowired(required = false) - void setAccessTokenResponseClient( - OAuth2AccessTokenResponseClient accessTokenResponseClient) { - this.accessTokenResponseClient = accessTokenResponseClient; - } - @Autowired(required = false) void setAuthorizedClientManager(List authorizedClientManagers) { if (authorizedClientManagers.size() == 1) { @@ -167,33 +147,17 @@ void setSecurityContextHolderStrategy(SecurityContextHolderStrategy strategy) { this.securityContextHolderStrategy = strategy; } + @Autowired + void setAuthorizedClientManagerRegistrar( + OAuth2AuthorizedClientManagerRegistrar authorizedClientManagerRegistrar) { + this.authorizedClientManagerRegistrar = authorizedClientManagerRegistrar; + } + private OAuth2AuthorizedClientManager getAuthorizedClientManager() { if (this.authorizedClientManager != null) { return this.authorizedClientManager; } - OAuth2AuthorizedClientManager authorizedClientManager = null; - if (this.clientRegistrationRepository != null && this.authorizedClientRepository != null) { - if (this.accessTokenResponseClient != null) { - // @formatter:off - OAuth2AuthorizedClientProvider authorizedClientProvider = OAuth2AuthorizedClientProviderBuilder - .builder() - .authorizationCode() - .refreshToken() - .clientCredentials((configurer) -> configurer.accessTokenResponseClient(this.accessTokenResponseClient)) - .password() - .build(); - // @formatter:on - DefaultOAuth2AuthorizedClientManager defaultAuthorizedClientManager = new DefaultOAuth2AuthorizedClientManager( - this.clientRegistrationRepository, this.authorizedClientRepository); - defaultAuthorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); - authorizedClientManager = defaultAuthorizedClientManager; - } - else { - authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( - this.clientRegistrationRepository, this.authorizedClientRepository); - } - } - return authorizedClientManager; + return this.authorizedClientManagerRegistrar.getAuthorizedClientManagerIfAvailable(); } } @@ -203,36 +167,37 @@ private OAuth2AuthorizedClientManager getAuthorizedClientManager() { * definition, if not already present. * * @author Joe Grandja + * @author Steve Riesenberg * @since 6.2.0 */ - static class OAuth2AuthorizedClientManagerRegistrar + static final class OAuth2AuthorizedClientManagerRegistrar implements BeanDefinitionRegistryPostProcessor, BeanFactoryAware { + // @formatter:off + private static final Set> KNOWN_AUTHORIZED_CLIENT_PROVIDERS = Set.of( + AuthorizationCodeOAuth2AuthorizedClientProvider.class, + RefreshTokenOAuth2AuthorizedClientProvider.class, + ClientCredentialsOAuth2AuthorizedClientProvider.class, + PasswordOAuth2AuthorizedClientProvider.class, + JwtBearerOAuth2AuthorizedClientProvider.class + ); + // @formatter:on + private final AnnotationBeanNameGenerator beanNameGenerator = new AnnotationBeanNameGenerator(); - private BeanFactory beanFactory; + private ListableBeanFactory beanFactory; @Override public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException { - String[] authorizedClientManagerBeanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors( - (ListableBeanFactory) this.beanFactory, OAuth2AuthorizedClientManager.class, true, true); - if (authorizedClientManagerBeanNames.length != 0) { - return; - } - - String[] clientRegistrationRepositoryBeanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors( - (ListableBeanFactory) this.beanFactory, ClientRegistrationRepository.class, true, true); - String[] authorizedClientRepositoryBeanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors( - (ListableBeanFactory) this.beanFactory, OAuth2AuthorizedClientRepository.class, true, true); - if (clientRegistrationRepositoryBeanNames.length != 1 || authorizedClientRepositoryBeanNames.length != 1) { + if (getBeanNamesForType(OAuth2AuthorizedClientManager.class).length != 0 + || getBeanNamesForType(ClientRegistrationRepository.class).length != 1 + || getBeanNamesForType(OAuth2AuthorizedClientRepository.class).length != 1) { return; } BeanDefinition beanDefinition = BeanDefinitionBuilder - .genericBeanDefinition(DefaultOAuth2AuthorizedClientManager.class) - .addConstructorArgReference(clientRegistrationRepositoryBeanNames[0]) - .addConstructorArgReference(authorizedClientRepositoryBeanNames[0]) - .addPropertyValue("authorizedClientProvider", getAuthorizedClientProvider()).getBeanDefinition(); + .genericBeanDefinition(OAuth2AuthorizedClientManager.class, this::getAuthorizedClientManager) + .getBeanDefinition(); registry.registerBeanDefinition(this.beanNameGenerator.generateBeanName(beanDefinition, registry), beanDefinition); @@ -242,90 +207,200 @@ public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) t public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException { } - private BeanDefinition getAuthorizedClientProvider() { - ManagedList authorizedClientProviders = new ManagedList<>(); - authorizedClientProviders.add(getAuthorizationCodeAuthorizedClientProvider()); - authorizedClientProviders.add(getRefreshTokenAuthorizedClientProvider()); - authorizedClientProviders.add(getClientCredentialsAuthorizedClientProvider()); - authorizedClientProviders.add(getPasswordAuthorizedClientProvider()); - return BeanDefinitionBuilder.genericBeanDefinition(DelegatingOAuth2AuthorizedClientProvider.class) - .addConstructorArgValue(authorizedClientProviders).getBeanDefinition(); + @Override + public void setBeanFactory(BeanFactory beanFactory) throws BeansException { + this.beanFactory = (ListableBeanFactory) beanFactory; } - private BeanMetadataElement getAuthorizationCodeAuthorizedClientProvider() { - String[] beanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors( - (ListableBeanFactory) this.beanFactory, AuthorizationCodeOAuth2AuthorizedClientProvider.class, true, - true); - if (beanNames.length == 1) { - return new RuntimeBeanReference(beanNames[0]); + OAuth2AuthorizedClientManager getAuthorizedClientManagerIfAvailable() { + if (getBeanNamesForType(ClientRegistrationRepository.class).length != 1 + || getBeanNamesForType(OAuth2AuthorizedClientRepository.class).length != 1) { + return null; } + return getAuthorizedClientManager(); + } - return BeanDefinitionBuilder.genericBeanDefinition(AuthorizationCodeOAuth2AuthorizedClientProvider.class) - .getBeanDefinition(); + private OAuth2AuthorizedClientManager getAuthorizedClientManager() { + ClientRegistrationRepository clientRegistrationRepository = BeanFactoryUtils + .beanOfTypeIncludingAncestors(this.beanFactory, ClientRegistrationRepository.class, true, true); + + OAuth2AuthorizedClientRepository authorizedClientRepository = BeanFactoryUtils + .beanOfTypeIncludingAncestors(this.beanFactory, OAuth2AuthorizedClientRepository.class, true, true); + + Collection authorizedClientProviderBeans = BeanFactoryUtils + .beansOfTypeIncludingAncestors(this.beanFactory, OAuth2AuthorizedClientProvider.class, true, true) + .values(); + + OAuth2AuthorizedClientProvider authorizedClientProvider; + if (hasDelegatingAuthorizedClientProvider(authorizedClientProviderBeans)) { + authorizedClientProvider = authorizedClientProviderBeans.iterator().next(); + } + else { + List authorizedClientProviders = new ArrayList<>(); + authorizedClientProviders + .add(getAuthorizationCodeAuthorizedClientProvider(authorizedClientProviderBeans)); + authorizedClientProviders.add(getRefreshTokenAuthorizedClientProvider(authorizedClientProviderBeans)); + authorizedClientProviders + .add(getClientCredentialsAuthorizedClientProvider(authorizedClientProviderBeans)); + authorizedClientProviders.add(getPasswordAuthorizedClientProvider(authorizedClientProviderBeans)); + + OAuth2AuthorizedClientProvider jwtBearerAuthorizedClientProvider = getJwtBearerAuthorizedClientProvider( + authorizedClientProviderBeans); + if (jwtBearerAuthorizedClientProvider != null) { + authorizedClientProviders.add(jwtBearerAuthorizedClientProvider); + } + + authorizedClientProviders.addAll(getAdditionalAuthorizedClientProviders(authorizedClientProviderBeans)); + authorizedClientProvider = new DelegatingOAuth2AuthorizedClientProvider(authorizedClientProviders); + } + + DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( + clientRegistrationRepository, authorizedClientRepository); + authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); + + Consumer authorizedClientManagerConsumer = getBeanOfType( + ResolvableType.forClassWithGenerics(Consumer.class, DefaultOAuth2AuthorizedClientManager.class)); + if (authorizedClientManagerConsumer != null) { + authorizedClientManagerConsumer.accept(authorizedClientManager); + } + + return authorizedClientManager; } - private BeanMetadataElement getRefreshTokenAuthorizedClientProvider() { - String[] beanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors( - (ListableBeanFactory) this.beanFactory, RefreshTokenOAuth2AuthorizedClientProvider.class, true, - true); - if (beanNames.length == 1) { - return new RuntimeBeanReference(beanNames[0]); + private boolean hasDelegatingAuthorizedClientProvider( + Collection authorizedClientProviders) { + if (authorizedClientProviders.size() != 1) { + return false; } + return authorizedClientProviders.iterator().next() instanceof DelegatingOAuth2AuthorizedClientProvider; + } - BeanDefinitionBuilder beanDefinitionBuilder = BeanDefinitionBuilder - .genericBeanDefinition(RefreshTokenOAuth2AuthorizedClientProvider.class); - ResolvableType resolvableType = ResolvableType.forClassWithGenerics(OAuth2AccessTokenResponseClient.class, - OAuth2RefreshTokenGrantRequest.class); - beanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors((ListableBeanFactory) this.beanFactory, - resolvableType, true, true); - if (beanNames.length == 1) { - beanDefinitionBuilder.addPropertyReference("accessTokenResponseClient", beanNames[0]); + private OAuth2AuthorizedClientProvider getAuthorizationCodeAuthorizedClientProvider( + Collection authorizedClientProviders) { + AuthorizationCodeOAuth2AuthorizedClientProvider authorizedClientProvider = getAuthorizedClientProviderByType( + authorizedClientProviders, AuthorizationCodeOAuth2AuthorizedClientProvider.class); + if (authorizedClientProvider == null) { + authorizedClientProvider = new AuthorizationCodeOAuth2AuthorizedClientProvider(); } - return beanDefinitionBuilder.getBeanDefinition(); + + return authorizedClientProvider; } - private BeanMetadataElement getClientCredentialsAuthorizedClientProvider() { - String[] beanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors( - (ListableBeanFactory) this.beanFactory, ClientCredentialsOAuth2AuthorizedClientProvider.class, true, - true); - if (beanNames.length == 1) { - return new RuntimeBeanReference(beanNames[0]); + private OAuth2AuthorizedClientProvider getRefreshTokenAuthorizedClientProvider( + Collection authorizedClientProviders) { + RefreshTokenOAuth2AuthorizedClientProvider authorizedClientProvider = getAuthorizedClientProviderByType( + authorizedClientProviders, RefreshTokenOAuth2AuthorizedClientProvider.class); + if (authorizedClientProvider == null) { + authorizedClientProvider = new RefreshTokenOAuth2AuthorizedClientProvider(); } - BeanDefinitionBuilder beanDefinitionBuilder = BeanDefinitionBuilder - .genericBeanDefinition(ClientCredentialsOAuth2AuthorizedClientProvider.class); - ResolvableType resolvableType = ResolvableType.forClassWithGenerics(OAuth2AccessTokenResponseClient.class, - OAuth2ClientCredentialsGrantRequest.class); - beanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors((ListableBeanFactory) this.beanFactory, - resolvableType, true, true); - if (beanNames.length == 1) { - beanDefinitionBuilder.addPropertyReference("accessTokenResponseClient", beanNames[0]); + OAuth2AccessTokenResponseClient accessTokenResponseClient = getBeanOfType( + ResolvableType.forClassWithGenerics(OAuth2AccessTokenResponseClient.class, + OAuth2RefreshTokenGrantRequest.class)); + if (accessTokenResponseClient != null) { + authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient); } - return beanDefinitionBuilder.getBeanDefinition(); + + return authorizedClientProvider; } - private BeanMetadataElement getPasswordAuthorizedClientProvider() { - String[] beanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors( - (ListableBeanFactory) this.beanFactory, PasswordOAuth2AuthorizedClientProvider.class, true, true); - if (beanNames.length == 1) { - return new RuntimeBeanReference(beanNames[0]); + private OAuth2AuthorizedClientProvider getClientCredentialsAuthorizedClientProvider( + Collection authorizedClientProviders) { + ClientCredentialsOAuth2AuthorizedClientProvider authorizedClientProvider = getAuthorizedClientProviderByType( + authorizedClientProviders, ClientCredentialsOAuth2AuthorizedClientProvider.class); + if (authorizedClientProvider == null) { + authorizedClientProvider = new ClientCredentialsOAuth2AuthorizedClientProvider(); } - BeanDefinitionBuilder beanDefinitionBuilder = BeanDefinitionBuilder - .genericBeanDefinition(PasswordOAuth2AuthorizedClientProvider.class); - ResolvableType resolvableType = ResolvableType.forClassWithGenerics(OAuth2AccessTokenResponseClient.class, - OAuth2PasswordGrantRequest.class); - beanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors((ListableBeanFactory) this.beanFactory, - resolvableType, true, true); - if (beanNames.length == 1) { - beanDefinitionBuilder.addPropertyReference("accessTokenResponseClient", beanNames[0]); + OAuth2AccessTokenResponseClient accessTokenResponseClient = getBeanOfType( + ResolvableType.forClassWithGenerics(OAuth2AccessTokenResponseClient.class, + OAuth2ClientCredentialsGrantRequest.class)); + if (accessTokenResponseClient != null) { + authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient); } - return beanDefinitionBuilder.getBeanDefinition(); + + return authorizedClientProvider; } - @Override - public void setBeanFactory(BeanFactory beanFactory) throws BeansException { - this.beanFactory = beanFactory; + private OAuth2AuthorizedClientProvider getPasswordAuthorizedClientProvider( + Collection authorizedClientProviders) { + PasswordOAuth2AuthorizedClientProvider authorizedClientProvider = getAuthorizedClientProviderByType( + authorizedClientProviders, PasswordOAuth2AuthorizedClientProvider.class); + if (authorizedClientProvider == null) { + authorizedClientProvider = new PasswordOAuth2AuthorizedClientProvider(); + } + + OAuth2AccessTokenResponseClient accessTokenResponseClient = getBeanOfType( + ResolvableType.forClassWithGenerics(OAuth2AccessTokenResponseClient.class, + OAuth2PasswordGrantRequest.class)); + if (accessTokenResponseClient != null) { + authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient); + } + + return authorizedClientProvider; + } + + private OAuth2AuthorizedClientProvider getJwtBearerAuthorizedClientProvider( + Collection authorizedClientProviders) { + JwtBearerOAuth2AuthorizedClientProvider authorizedClientProvider = getAuthorizedClientProviderByType( + authorizedClientProviders, JwtBearerOAuth2AuthorizedClientProvider.class); + + OAuth2AccessTokenResponseClient accessTokenResponseClient = getBeanOfType( + ResolvableType.forClassWithGenerics(OAuth2AccessTokenResponseClient.class, + JwtBearerGrantRequest.class)); + if (accessTokenResponseClient != null) { + if (authorizedClientProvider == null) { + authorizedClientProvider = new JwtBearerOAuth2AuthorizedClientProvider(); + } + + authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient); + } + + return authorizedClientProvider; + } + + private List getAdditionalAuthorizedClientProviders( + Collection authorizedClientProviders) { + List additionalAuthorizedClientProviders = new ArrayList<>( + authorizedClientProviders); + additionalAuthorizedClientProviders + .removeIf((provider) -> KNOWN_AUTHORIZED_CLIENT_PROVIDERS.contains(provider.getClass())); + return additionalAuthorizedClientProviders; + } + + private T getAuthorizedClientProviderByType( + Collection authorizedClientProviders, Class providerClass) { + T authorizedClientProvider = null; + for (OAuth2AuthorizedClientProvider current : authorizedClientProviders) { + if (providerClass.isInstance(current)) { + assertAuthorizedClientProviderIsNull(authorizedClientProvider); + authorizedClientProvider = providerClass.cast(current); + } + } + return authorizedClientProvider; + } + + private static void assertAuthorizedClientProviderIsNull( + OAuth2AuthorizedClientProvider authorizedClientProvider) { + if (authorizedClientProvider != null) { + // @formatter:off + throw new BeanInitializationException(String.format( + "Unable to create an %s bean. Expected one bean of type %s, but found multiple. " + + "Please consider defining only a single bean of this type, or define an %s bean yourself.", + OAuth2AuthorizedClientManager.class.getName(), + authorizedClientProvider.getClass().getName(), + OAuth2AuthorizedClientManager.class.getName())); + // @formatter:on + } + } + + private String[] getBeanNamesForType(Class beanClass) { + return BeanFactoryUtils.beanNamesForTypeIncludingAncestors(this.beanFactory, beanClass, true, true); + } + + private T getBeanOfType(ResolvableType resolvableType) { + ObjectProvider objectProvider = this.beanFactory.getBeanProvider(resolvableType, true); + return objectProvider.getIfAvailable(); } } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java index dfd84bcab83..df26c12b663 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java @@ -16,6 +16,8 @@ package org.springframework.security.config.annotation.web.configurers.oauth2.client; +import org.springframework.context.ApplicationContext; +import org.springframework.core.ResolvableType; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.config.Customizer; import org.springframework.security.config.annotation.web.HttpSecurityBuilder; @@ -307,7 +309,22 @@ private OAuth2AccessTokenResponseClient get if (this.accessTokenResponseClient != null) { return this.accessTokenResponseClient; } - return new DefaultAuthorizationCodeTokenResponseClient(); + ResolvableType resolvableType = ResolvableType.forClassWithGenerics(OAuth2AccessTokenResponseClient.class, + OAuth2AuthorizationCodeGrantRequest.class); + OAuth2AccessTokenResponseClient bean = getBeanOrNull(resolvableType); + return (bean != null) ? bean : new DefaultAuthorizationCodeTokenResponseClient(); + } + + @SuppressWarnings("unchecked") + private T getBeanOrNull(ResolvableType type) { + ApplicationContext context = getBuilder().getSharedObject(ApplicationContext.class); + if (context != null) { + String[] names = context.getBeanNamesForType(type); + if (names.length == 1) { + return (T) context.getBean(names[0]); + } + } + return null; } } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java index b7a2ccc61fa..35288d69246 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java @@ -330,10 +330,7 @@ public void init(B http) throws Exception { super.init(http); } } - OAuth2AccessTokenResponseClient accessTokenResponseClient = this.tokenEndpointConfig.accessTokenResponseClient; - if (accessTokenResponseClient == null) { - accessTokenResponseClient = new DefaultAuthorizationCodeTokenResponseClient(); - } + OAuth2AccessTokenResponseClient accessTokenResponseClient = getAccessTokenResponseClient(); OAuth2UserService oauth2UserService = getOAuth2UserService(); OAuth2LoginAuthenticationProvider oauth2LoginAuthenticationProvider = new OAuth2LoginAuthenticationProvider( accessTokenResponseClient, oauth2UserService); @@ -441,6 +438,16 @@ private GrantedAuthoritiesMapper getGrantedAuthoritiesMapperBean() { return (!grantedAuthoritiesMapperMap.isEmpty() ? grantedAuthoritiesMapperMap.values().iterator().next() : null); } + private OAuth2AccessTokenResponseClient getAccessTokenResponseClient() { + if (this.tokenEndpointConfig.accessTokenResponseClient != null) { + return this.tokenEndpointConfig.accessTokenResponseClient; + } + ResolvableType resolvableType = ResolvableType.forClassWithGenerics(OAuth2AccessTokenResponseClient.class, + OAuth2AuthorizationCodeGrantRequest.class); + OAuth2AccessTokenResponseClient bean = getBeanOrNull(resolvableType); + return (bean != null) ? bean : new DefaultAuthorizationCodeTokenResponseClient(); + } + private OAuth2UserService getOidcUserService() { if (this.userInfoEndpointConfig.oidcUserService != null) { return this.userInfoEndpointConfig.oidcUserService; diff --git a/config/src/main/java/org/springframework/security/config/http/AuthenticationConfigBuilder.java b/config/src/main/java/org/springframework/security/config/http/AuthenticationConfigBuilder.java index ead9135d957..0df983ae6ff 100644 --- a/config/src/main/java/org/springframework/security/config/http/AuthenticationConfigBuilder.java +++ b/config/src/main/java/org/springframework/security/config/http/AuthenticationConfigBuilder.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -420,6 +420,8 @@ private void registerOAuth2ClientPostProcessors() { this.pc.getReaderContext() .registerWithGeneratedName(new RootBeanDefinition(OAuth2ClientWebMvcSecurityPostProcessor.class)); } + this.pc.getReaderContext() + .registerWithGeneratedName(new RootBeanDefinition(OAuth2AuthorizedClientManagerRegistrar.class)); } private void createSaml2LoginFilter(BeanReference authenticationManager, diff --git a/config/src/main/java/org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrar.java b/config/src/main/java/org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrar.java new file mode 100644 index 00000000000..1adf961dea6 --- /dev/null +++ b/config/src/main/java/org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrar.java @@ -0,0 +1,287 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.config.http; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Set; +import java.util.function.Consumer; + +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.BeanFactoryAware; +import org.springframework.beans.factory.BeanFactoryUtils; +import org.springframework.beans.factory.BeanInitializationException; +import org.springframework.beans.factory.ListableBeanFactory; +import org.springframework.beans.factory.ObjectProvider; +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.beans.factory.support.BeanDefinitionBuilder; +import org.springframework.beans.factory.support.BeanDefinitionRegistry; +import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor; +import org.springframework.context.annotation.AnnotationBeanNameGenerator; +import org.springframework.core.ResolvableType; +import org.springframework.security.oauth2.client.AuthorizationCodeOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.DelegatingOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.JwtBearerOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.PasswordOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.endpoint.JwtBearerGrantRequest; +import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; +import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest; +import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; + +/** + * A registrar for registering the default {@link OAuth2AuthorizedClientManager} bean + * definition, if not already present. + *

+ * Note: This class is a direct copy of + * {@link org.springframework.security.config.annotation.web.configuration.OAuth2ClientConfiguration.OAuth2AuthorizedClientManagerRegistrar}. + * + * @author Joe Grandja + * @author Steve Riesenberg + * @since 6.2.0 + */ +final class OAuth2AuthorizedClientManagerRegistrar implements BeanDefinitionRegistryPostProcessor, BeanFactoryAware { + + // @formatter:off + private static final Set> KNOWN_AUTHORIZED_CLIENT_PROVIDERS = Set.of( + AuthorizationCodeOAuth2AuthorizedClientProvider.class, + RefreshTokenOAuth2AuthorizedClientProvider.class, + ClientCredentialsOAuth2AuthorizedClientProvider.class, + PasswordOAuth2AuthorizedClientProvider.class, + JwtBearerOAuth2AuthorizedClientProvider.class + ); + // @formatter:on + + private final AnnotationBeanNameGenerator beanNameGenerator = new AnnotationBeanNameGenerator(); + + private ListableBeanFactory beanFactory; + + @Override + public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException { + if (getBeanNamesForType(OAuth2AuthorizedClientManager.class).length != 0 + || getBeanNamesForType(ClientRegistrationRepository.class).length != 1 + || getBeanNamesForType(OAuth2AuthorizedClientRepository.class).length != 1) { + return; + } + + BeanDefinition beanDefinition = BeanDefinitionBuilder + .genericBeanDefinition(OAuth2AuthorizedClientManager.class, this::getAuthorizedClientManager) + .getBeanDefinition(); + + registry.registerBeanDefinition(this.beanNameGenerator.generateBeanName(beanDefinition, registry), + beanDefinition); + } + + @Override + public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException { + } + + @Override + public void setBeanFactory(BeanFactory beanFactory) throws BeansException { + this.beanFactory = (ListableBeanFactory) beanFactory; + } + + private OAuth2AuthorizedClientManager getAuthorizedClientManager() { + ClientRegistrationRepository clientRegistrationRepository = BeanFactoryUtils + .beanOfTypeIncludingAncestors(this.beanFactory, ClientRegistrationRepository.class, true, true); + + OAuth2AuthorizedClientRepository authorizedClientRepository = BeanFactoryUtils + .beanOfTypeIncludingAncestors(this.beanFactory, OAuth2AuthorizedClientRepository.class, true, true); + + Collection authorizedClientProviderBeans = BeanFactoryUtils + .beansOfTypeIncludingAncestors(this.beanFactory, OAuth2AuthorizedClientProvider.class, true, true) + .values(); + + OAuth2AuthorizedClientProvider authorizedClientProvider; + if (hasDelegatingAuthorizedClientProvider(authorizedClientProviderBeans)) { + authorizedClientProvider = authorizedClientProviderBeans.iterator().next(); + } + else { + List authorizedClientProviders = new ArrayList<>(); + authorizedClientProviders.add(getAuthorizationCodeAuthorizedClientProvider(authorizedClientProviderBeans)); + authorizedClientProviders.add(getRefreshTokenAuthorizedClientProvider(authorizedClientProviderBeans)); + authorizedClientProviders.add(getClientCredentialsAuthorizedClientProvider(authorizedClientProviderBeans)); + authorizedClientProviders.add(getPasswordAuthorizedClientProvider(authorizedClientProviderBeans)); + + OAuth2AuthorizedClientProvider jwtBearerAuthorizedClientProvider = getJwtBearerAuthorizedClientProvider( + authorizedClientProviderBeans); + if (jwtBearerAuthorizedClientProvider != null) { + authorizedClientProviders.add(jwtBearerAuthorizedClientProvider); + } + + authorizedClientProviders.addAll(getAdditionalAuthorizedClientProviders(authorizedClientProviderBeans)); + authorizedClientProvider = new DelegatingOAuth2AuthorizedClientProvider(authorizedClientProviders); + } + + DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( + clientRegistrationRepository, authorizedClientRepository); + authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); + + Consumer authorizedClientManagerConsumer = getBeanOfType( + ResolvableType.forClassWithGenerics(Consumer.class, DefaultOAuth2AuthorizedClientManager.class)); + if (authorizedClientManagerConsumer != null) { + authorizedClientManagerConsumer.accept(authorizedClientManager); + } + + return authorizedClientManager; + } + + private boolean hasDelegatingAuthorizedClientProvider( + Collection authorizedClientProviders) { + if (authorizedClientProviders.size() != 1) { + return false; + } + return authorizedClientProviders.iterator().next() instanceof DelegatingOAuth2AuthorizedClientProvider; + } + + private OAuth2AuthorizedClientProvider getAuthorizationCodeAuthorizedClientProvider( + Collection authorizedClientProviders) { + AuthorizationCodeOAuth2AuthorizedClientProvider authorizedClientProvider = getAuthorizedClientProviderByType( + authorizedClientProviders, AuthorizationCodeOAuth2AuthorizedClientProvider.class); + if (authorizedClientProvider == null) { + authorizedClientProvider = new AuthorizationCodeOAuth2AuthorizedClientProvider(); + } + + return authorizedClientProvider; + } + + private OAuth2AuthorizedClientProvider getRefreshTokenAuthorizedClientProvider( + Collection authorizedClientProviders) { + RefreshTokenOAuth2AuthorizedClientProvider authorizedClientProvider = getAuthorizedClientProviderByType( + authorizedClientProviders, RefreshTokenOAuth2AuthorizedClientProvider.class); + if (authorizedClientProvider == null) { + authorizedClientProvider = new RefreshTokenOAuth2AuthorizedClientProvider(); + } + + OAuth2AccessTokenResponseClient accessTokenResponseClient = getBeanOfType( + ResolvableType.forClassWithGenerics(OAuth2AccessTokenResponseClient.class, + OAuth2RefreshTokenGrantRequest.class)); + if (accessTokenResponseClient != null) { + authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient); + } + + return authorizedClientProvider; + } + + private OAuth2AuthorizedClientProvider getClientCredentialsAuthorizedClientProvider( + Collection authorizedClientProviders) { + ClientCredentialsOAuth2AuthorizedClientProvider authorizedClientProvider = getAuthorizedClientProviderByType( + authorizedClientProviders, ClientCredentialsOAuth2AuthorizedClientProvider.class); + if (authorizedClientProvider == null) { + authorizedClientProvider = new ClientCredentialsOAuth2AuthorizedClientProvider(); + } + + OAuth2AccessTokenResponseClient accessTokenResponseClient = getBeanOfType( + ResolvableType.forClassWithGenerics(OAuth2AccessTokenResponseClient.class, + OAuth2ClientCredentialsGrantRequest.class)); + if (accessTokenResponseClient != null) { + authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient); + } + + return authorizedClientProvider; + } + + private OAuth2AuthorizedClientProvider getPasswordAuthorizedClientProvider( + Collection authorizedClientProviders) { + PasswordOAuth2AuthorizedClientProvider authorizedClientProvider = getAuthorizedClientProviderByType( + authorizedClientProviders, PasswordOAuth2AuthorizedClientProvider.class); + if (authorizedClientProvider == null) { + authorizedClientProvider = new PasswordOAuth2AuthorizedClientProvider(); + } + + OAuth2AccessTokenResponseClient accessTokenResponseClient = getBeanOfType( + ResolvableType.forClassWithGenerics(OAuth2AccessTokenResponseClient.class, + OAuth2PasswordGrantRequest.class)); + if (accessTokenResponseClient != null) { + authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient); + } + + return authorizedClientProvider; + } + + private OAuth2AuthorizedClientProvider getJwtBearerAuthorizedClientProvider( + Collection authorizedClientProviders) { + JwtBearerOAuth2AuthorizedClientProvider authorizedClientProvider = getAuthorizedClientProviderByType( + authorizedClientProviders, JwtBearerOAuth2AuthorizedClientProvider.class); + + OAuth2AccessTokenResponseClient accessTokenResponseClient = getBeanOfType(ResolvableType + .forClassWithGenerics(OAuth2AccessTokenResponseClient.class, JwtBearerGrantRequest.class)); + if (accessTokenResponseClient != null) { + if (authorizedClientProvider == null) { + authorizedClientProvider = new JwtBearerOAuth2AuthorizedClientProvider(); + } + + authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient); + } + + return authorizedClientProvider; + } + + private List getAdditionalAuthorizedClientProviders( + Collection authorizedClientProviders) { + List additionalAuthorizedClientProviders = new ArrayList<>( + authorizedClientProviders); + additionalAuthorizedClientProviders + .removeIf((provider) -> KNOWN_AUTHORIZED_CLIENT_PROVIDERS.contains(provider.getClass())); + return additionalAuthorizedClientProviders; + } + + private T getAuthorizedClientProviderByType( + Collection authorizedClientProviders, Class providerClass) { + T authorizedClientProvider = null; + for (OAuth2AuthorizedClientProvider current : authorizedClientProviders) { + if (providerClass.isInstance(current)) { + assertAuthorizedClientProviderIsNull(authorizedClientProvider); + authorizedClientProvider = providerClass.cast(current); + } + } + return authorizedClientProvider; + } + + private static void assertAuthorizedClientProviderIsNull(OAuth2AuthorizedClientProvider authorizedClientProvider) { + if (authorizedClientProvider != null) { + // @formatter:off + throw new BeanInitializationException(String.format( + "Unable to create an %s bean. Expected one bean of type %s, but found multiple. " + + "Please consider defining only a single bean of this type, or define an %s bean yourself.", + OAuth2AuthorizedClientManager.class.getName(), + authorizedClientProvider.getClass().getName(), + OAuth2AuthorizedClientManager.class.getName())); + // @formatter:on + } + } + + private String[] getBeanNamesForType(Class beanClass) { + return BeanFactoryUtils.beanNamesForTypeIncludingAncestors(this.beanFactory, beanClass, false, false); + } + + private T getBeanOfType(ResolvableType resolvableType) { + ObjectProvider objectProvider = this.beanFactory.getBeanProvider(resolvableType, true); + return objectProvider.getIfAvailable(); + } + +} diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2AuthorizedClientManagerConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2AuthorizedClientManagerConfigurationTests.java index 476b708610b..79112da4ce9 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2AuthorizedClientManagerConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2AuthorizedClientManagerConfigurationTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,149 +16,374 @@ package org.springframework.security.config.annotation.web.configuration; +import java.time.Duration; +import java.time.Instant; import java.util.Arrays; - +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Objects; +import java.util.function.Consumer; + +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import org.springframework.http.converter.FormHttpMessageConverter; -import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.config.Customizer; import org.springframework.security.config.annotation.web.builders.HttpSecurity; +import org.springframework.security.config.oauth2.client.CommonOAuth2Provider; import org.springframework.security.config.test.SpringTestContext; import org.springframework.security.oauth2.client.AuthorizationCodeOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.JwtBearerOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.OAuth2AuthorizationContext; +import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.PasswordOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider; -import org.springframework.security.oauth2.client.endpoint.DefaultAuthorizationCodeTokenResponseClient; -import org.springframework.security.oauth2.client.endpoint.DefaultClientCredentialsTokenResponseClient; -import org.springframework.security.oauth2.client.endpoint.DefaultPasswordTokenResponseClient; -import org.springframework.security.oauth2.client.endpoint.DefaultRefreshTokenTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.AbstractOAuth2AuthorizationGrantRequest; +import org.springframework.security.oauth2.client.endpoint.JwtBearerGrantRequest; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest; import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; -import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler; import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest; import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserService; +import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository; import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService; import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest; import org.springframework.security.oauth2.client.userinfo.OAuth2UserService; +import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; -import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.oauth2.core.user.OAuth2User; +import org.springframework.security.oauth2.jwt.JoseHeaderNames; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtClaimNames; +import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken; import org.springframework.security.web.SecurityFilterChain; -import org.springframework.web.client.RestOperations; -import org.springframework.web.client.RestTemplate; +import org.springframework.util.StringUtils; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; /** * Tests for {@link OAuth2ClientConfiguration.OAuth2AuthorizedClientManagerConfiguration}. * * @author Joe Grandja + * @author Steve Riesenberg */ public class OAuth2AuthorizedClientManagerConfigurationTests { + private static OAuth2AccessTokenResponseClient MOCK_RESPONSE_CLIENT; + public final SpringTestContext spring = new SpringTestContext(this); @Autowired private OAuth2AuthorizedClientManager authorizedClientManager; + @Autowired + private ClientRegistrationRepository clientRegistrationRepository; + + @Autowired + private OAuth2AuthorizedClientRepository authorizedClientRepository; + @Autowired(required = false) private AuthorizationCodeOAuth2AuthorizedClientProvider authorizationCodeAuthorizedClientProvider; - @Autowired(required = false) - private RefreshTokenOAuth2AuthorizedClientProvider refreshTokenAuthorizedClientProvider; + private MockHttpServletRequest request; - @Autowired(required = false) - private ClientCredentialsOAuth2AuthorizedClientProvider clientCredentialsAuthorizedClientProvider; + private MockHttpServletResponse response; - @Autowired(required = false) - private PasswordOAuth2AuthorizedClientProvider passwordAuthorizedClientProvider; + @BeforeEach + @SuppressWarnings("unchecked") + public void setUp() { + MOCK_RESPONSE_CLIENT = mock(OAuth2AccessTokenResponseClient.class); + this.request = new MockHttpServletRequest(); + this.response = new MockHttpServletResponse(); + } @Test - public void loadContextWhenCustomRestOperationsThenConfigured() { - this.spring.register(CustomRestOperationsConfig.class).autowire(); + public void loadContextWhenOAuth2ClientEnabledThenConfigured() { + this.spring.register(MinimalOAuth2ClientConfig.class).autowire(); assertThat(this.authorizedClientManager).isNotNull(); } @Test - public void loadContextWhenCustomAuthorizedClientProvidersThenConfigured() { + public void authorizeWhenAuthorizationCodeAuthorizedClientProviderBeanThenUsed() { this.spring.register(CustomAuthorizedClientProvidersConfig.class).autowire(); - assertThat(this.authorizedClientManager).isNotNull(); + + TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", null); + // @formatter:off + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId("google") + .principal(authentication) + .attribute(HttpServletRequest.class.getName(), this.request) + .attribute(HttpServletResponse.class.getName(), this.response) + .build(); + assertThatExceptionOfType(ClientAuthorizationRequiredException.class) + .isThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest)) + .extracting(OAuth2AuthorizationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo("client_authorization_required"); + // @formatter:on + + verify(this.authorizationCodeAuthorizedClientProvider).authorize(any(OAuth2AuthorizationContext.class)); + } + + @Test + public void authorizeWhenRefreshTokenAccessTokenResponseClientBeanThenUsed() { + this.spring.register(CustomAccessTokenResponseClientsConfig.class).autowire(); + testRefreshTokenGrant(); + } + + @Test + public void authorizeWhenRefreshTokenAuthorizedClientProviderBeanThenUsed() { + this.spring.register(CustomAuthorizedClientProvidersConfig.class).autowire(); + testRefreshTokenGrant(); + } + + private void testRefreshTokenGrant() { + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(MOCK_RESPONSE_CLIENT.getTokenResponse(any(OAuth2RefreshTokenGrantRequest.class))) + .willReturn(accessTokenResponse); + + TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", null); + ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId("google"); + OAuth2AuthorizedClient existingAuthorizedClient = new OAuth2AuthorizedClient(clientRegistration, + authentication.getName(), getExpiredAccessToken(), TestOAuth2RefreshTokens.refreshToken()); + this.authorizedClientRepository.saveAuthorizedClient(existingAuthorizedClient, authentication, this.request, + this.response); + // @formatter:off + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withAuthorizedClient(existingAuthorizedClient) + .principal(authentication) + .attribute(HttpServletRequest.class.getName(), this.request) + .attribute(HttpServletResponse.class.getName(), this.response) + .build(); + // @formatter:on + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); + assertThat(authorizedClient).isNotNull(); + + ArgumentCaptor grantRequestCaptor = ArgumentCaptor + .forClass(OAuth2RefreshTokenGrantRequest.class); + verify(MOCK_RESPONSE_CLIENT).getTokenResponse(grantRequestCaptor.capture()); + + OAuth2RefreshTokenGrantRequest grantRequest = grantRequestCaptor.getValue(); + assertThat(grantRequest.getClientRegistration().getRegistrationId()) + .isEqualTo(clientRegistration.getRegistrationId()); + assertThat(grantRequest.getGrantType()).isEqualTo(AuthorizationGrantType.REFRESH_TOKEN); + assertThat(grantRequest.getAccessToken()).isEqualTo(existingAuthorizedClient.getAccessToken()); + assertThat(grantRequest.getRefreshToken()).isEqualTo(existingAuthorizedClient.getRefreshToken()); + } + + @Test + public void authorizeWhenClientCredentialsAccessTokenResponseClientBeanThenUsed() { + this.spring.register(CustomAccessTokenResponseClientsConfig.class).autowire(); + testClientCredentialsGrant(); + } + + @Test + public void authorizeWhenClientCredentialsAuthorizedClientProviderBeanThenUsed() { + this.spring.register(CustomAuthorizedClientProvidersConfig.class).autowire(); + testClientCredentialsGrant(); + } + + private void testClientCredentialsGrant() { + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(MOCK_RESPONSE_CLIENT.getTokenResponse(any(OAuth2ClientCredentialsGrantRequest.class))) + .willReturn(accessTokenResponse); + + TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", null); + ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId("github"); + // @formatter:off + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(clientRegistration.getRegistrationId()) + .principal(authentication) + .attribute(HttpServletRequest.class.getName(), this.request) + .attribute(HttpServletResponse.class.getName(), this.response) + .build(); + // @formatter:on + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); + assertThat(authorizedClient).isNotNull(); + + ArgumentCaptor grantRequestCaptor = ArgumentCaptor + .forClass(OAuth2ClientCredentialsGrantRequest.class); + verify(MOCK_RESPONSE_CLIENT).getTokenResponse(grantRequestCaptor.capture()); + + OAuth2ClientCredentialsGrantRequest grantRequest = grantRequestCaptor.getValue(); + assertThat(grantRequest.getClientRegistration().getRegistrationId()) + .isEqualTo(clientRegistration.getRegistrationId()); + assertThat(grantRequest.getGrantType()).isEqualTo(AuthorizationGrantType.CLIENT_CREDENTIALS); + } + + @Test + public void authorizeWhenPasswordAccessTokenResponseClientBeanThenUsed() { + this.spring.register(CustomAccessTokenResponseClientsConfig.class).autowire(); + testPasswordGrant(); + } + + @Test + public void authorizeWhenPasswordAuthorizedClientProviderBeanThenUsed() { + this.spring.register(CustomAuthorizedClientProvidersConfig.class).autowire(); + testPasswordGrant(); + } + + private void testPasswordGrant() { + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(MOCK_RESPONSE_CLIENT.getTokenResponse(any(OAuth2PasswordGrantRequest.class))) + .willReturn(accessTokenResponse); + + TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", "password"); + ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId("facebook"); + // @formatter:off + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(clientRegistration.getRegistrationId()) + .principal(authentication) + .attribute(HttpServletRequest.class.getName(), this.request) + .attribute(HttpServletResponse.class.getName(), this.response) + .build(); + // @formatter:on + this.request.setParameter(OAuth2ParameterNames.USERNAME, "user"); + this.request.setParameter(OAuth2ParameterNames.PASSWORD, "password"); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); + assertThat(authorizedClient).isNotNull(); + + ArgumentCaptor grantRequestCaptor = ArgumentCaptor + .forClass(OAuth2PasswordGrantRequest.class); + verify(MOCK_RESPONSE_CLIENT).getTokenResponse(grantRequestCaptor.capture()); + + OAuth2PasswordGrantRequest grantRequest = grantRequestCaptor.getValue(); + assertThat(grantRequest.getClientRegistration().getRegistrationId()) + .isEqualTo(clientRegistration.getRegistrationId()); + assertThat(grantRequest.getGrantType()).isEqualTo(AuthorizationGrantType.PASSWORD); + assertThat(grantRequest.getUsername()).isEqualTo("user"); + assertThat(grantRequest.getPassword()).isEqualTo("password"); + } + + @Test + public void authorizeWhenJwtBearerAccessTokenResponseClientBeanThenUsed() { + this.spring.register(CustomAccessTokenResponseClientsConfig.class).autowire(); + testJwtBearerGrant(); + } + + @Test + public void authorizeWhenJwtBearerAuthorizedClientProviderBeanThenUsed() { + this.spring.register(CustomAuthorizedClientProvidersConfig.class).autowire(); + testJwtBearerGrant(); + } + + private void testJwtBearerGrant() { + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(MOCK_RESPONSE_CLIENT.getTokenResponse(any(JwtBearerGrantRequest.class))).willReturn(accessTokenResponse); + + JwtAuthenticationToken authentication = new JwtAuthenticationToken(getJwt()); + ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId("okta"); + // @formatter:off + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(clientRegistration.getRegistrationId()) + .principal(authentication) + .attribute(HttpServletRequest.class.getName(), this.request) + .attribute(HttpServletResponse.class.getName(), this.response) + .build(); + // @formatter:on + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); + assertThat(authorizedClient).isNotNull(); + + ArgumentCaptor grantRequestCaptor = ArgumentCaptor.forClass(JwtBearerGrantRequest.class); + verify(MOCK_RESPONSE_CLIENT).getTokenResponse(grantRequestCaptor.capture()); + + JwtBearerGrantRequest grantRequest = grantRequestCaptor.getValue(); + assertThat(grantRequest.getClientRegistration().getRegistrationId()) + .isEqualTo(clientRegistration.getRegistrationId()); + assertThat(grantRequest.getGrantType()).isEqualTo(AuthorizationGrantType.JWT_BEARER); + assertThat(grantRequest.getJwt().getSubject()).isEqualTo("user"); + } + + private static OAuth2AccessToken getExpiredAccessToken() { + Instant expiresAt = Instant.now().minusSeconds(60); + Instant issuedAt = expiresAt.minus(Duration.ofDays(1)); + return new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "scopes", issuedAt, expiresAt, + new HashSet<>(Arrays.asList("read", "write"))); + } + + private static Jwt getJwt() { + Instant issuedAt = Instant.now(); + return new Jwt("token", issuedAt, issuedAt.plusSeconds(300), + Collections.singletonMap(JoseHeaderNames.ALG, "RS256"), + Collections.singletonMap(JwtClaimNames.SUB, "user")); + } + + @Configuration + @EnableWebSecurity + static class MinimalOAuth2ClientConfig extends OAuth2ClientBaseConfig { + } @Configuration @EnableWebSecurity - static class CustomRestOperationsConfig extends OAuth2ClientBaseConfig { + static class CustomAccessTokenResponseClientsConfig extends OAuth2ClientBaseConfig { - // TODO This needs to be autoconfigured in OAuth2LoginConfigurer and - // OAuth2ClientConfigurer @Bean OAuth2AccessTokenResponseClient authorizationCodeTokenResponseClient() { - DefaultAuthorizationCodeTokenResponseClient tokenResponseClient = new DefaultAuthorizationCodeTokenResponseClient(); - tokenResponseClient.setRestOperations(restOperations()); - return spy(tokenResponseClient); + return new MockAuthorizationCodeClient(); } @Bean OAuth2AccessTokenResponseClient refreshTokenTokenResponseClient() { - DefaultRefreshTokenTokenResponseClient tokenResponseClient = new DefaultRefreshTokenTokenResponseClient(); - tokenResponseClient.setRestOperations(restOperations()); - return spy(tokenResponseClient); + return new MockRefreshTokenClient(); } @Bean OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient() { - DefaultClientCredentialsTokenResponseClient tokenResponseClient = new DefaultClientCredentialsTokenResponseClient(); - tokenResponseClient.setRestOperations(restOperations()); - return spy(tokenResponseClient); + return new MockClientCredentialsClient(); } @Bean OAuth2AccessTokenResponseClient passwordTokenResponseClient() { - DefaultPasswordTokenResponseClient tokenResponseClient = new DefaultPasswordTokenResponseClient(); - tokenResponseClient.setRestOperations(restOperations()); - return spy(tokenResponseClient); + return new MockPasswordClient(); } - // NOTE: This is autoconfigured in OAuth2LoginConfigurer and - // OAuth2ClientConfigurer @Bean - OAuth2UserService oauth2UserService() { - DefaultOAuth2UserService userService = new DefaultOAuth2UserService(); - userService.setRestOperations(restOperations()); - return spy(userService); + OAuth2AccessTokenResponseClient jwtBearerTokenResponseClient() { + return new MockJwtBearerClient(); } - // NOTE: This is autoconfigured in OAuth2LoginConfigurer and - // OAuth2ClientConfigurer @Bean - OAuth2UserService oidcUserService() { - OidcUserService userService = new OidcUserService(); - userService.setOauth2UserService(oauth2UserService()); - return spy(userService); + OAuth2UserService oauth2UserService() { + return mock(DefaultOAuth2UserService.class); } @Bean - RestOperations restOperations() { - // Minimum required configuration - RestTemplate restTemplate = new RestTemplate(Arrays.asList(new FormHttpMessageConverter(), - new OAuth2AccessTokenResponseHttpMessageConverter(), new MappingJackson2HttpMessageConverter())); - restTemplate.setErrorHandler(new OAuth2ErrorResponseErrorHandler()); - - // TODO Add custom configuration, eg. Proxy, TLS, etc - - return spy(restTemplate); + OAuth2UserService oidcUserService() { + return mock(OidcUserService.class); } } @@ -169,22 +394,35 @@ static class CustomAuthorizedClientProvidersConfig extends OAuth2ClientBaseConfi @Bean AuthorizationCodeOAuth2AuthorizedClientProvider authorizationCodeProvider() { - return mock(AuthorizationCodeOAuth2AuthorizedClientProvider.class); + return spy(new AuthorizationCodeOAuth2AuthorizedClientProvider()); } @Bean RefreshTokenOAuth2AuthorizedClientProvider refreshTokenProvider() { - return mock(RefreshTokenOAuth2AuthorizedClientProvider.class); + RefreshTokenOAuth2AuthorizedClientProvider authorizedClientProvider = new RefreshTokenOAuth2AuthorizedClientProvider(); + authorizedClientProvider.setAccessTokenResponseClient(new MockRefreshTokenClient()); + return authorizedClientProvider; } @Bean ClientCredentialsOAuth2AuthorizedClientProvider clientCredentialsProvider() { - return mock(ClientCredentialsOAuth2AuthorizedClientProvider.class); + ClientCredentialsOAuth2AuthorizedClientProvider authorizedClientProvider = new ClientCredentialsOAuth2AuthorizedClientProvider(); + authorizedClientProvider.setAccessTokenResponseClient(new MockClientCredentialsClient()); + return authorizedClientProvider; } @Bean PasswordOAuth2AuthorizedClientProvider passwordProvider() { - return mock(PasswordOAuth2AuthorizedClientProvider.class); + PasswordOAuth2AuthorizedClientProvider authorizedClientProvider = new PasswordOAuth2AuthorizedClientProvider(); + authorizedClientProvider.setAccessTokenResponseClient(new MockPasswordClient()); + return authorizedClientProvider; + } + + @Bean + JwtBearerOAuth2AuthorizedClientProvider jwtBearerAuthorizedClientProvider() { + JwtBearerOAuth2AuthorizedClientProvider authorizedClientProvider = new JwtBearerOAuth2AuthorizedClientProvider(); + authorizedClientProvider.setAccessTokenResponseClient(new MockJwtBearerClient()); + return authorizedClientProvider; } } @@ -195,8 +433,7 @@ abstract static class OAuth2ClientBaseConfig { SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception { // @formatter:off http - .authorizeHttpRequests(authorize -> - authorize.anyRequest().authenticated()) + .authorizeHttpRequests((authorize) -> authorize.anyRequest().authenticated()) .oauth2Login(Customizer.withDefaults()) .oauth2Client(Customizer.withDefaults()); return http.build(); @@ -205,7 +442,29 @@ SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception { @Bean ClientRegistrationRepository clientRegistrationRepository() { - return mock(ClientRegistrationRepository.class); + // @formatter:off + return new InMemoryClientRegistrationRepository(Arrays.asList( + CommonOAuth2Provider.GOOGLE.getBuilder("google") + .clientId("google-client-id") + .clientSecret("google-client-secret") + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .build(), + CommonOAuth2Provider.GITHUB.getBuilder("github") + .clientId("github-client-id") + .clientSecret("github-client-secret") + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) + .build(), + CommonOAuth2Provider.FACEBOOK.getBuilder("facebook") + .clientId("facebook-client-id") + .clientSecret("facebook-client-secret") + .authorizationGrantType(AuthorizationGrantType.PASSWORD) + .build(), + CommonOAuth2Provider.OKTA.getBuilder("okta") + .clientId("okta-client-id") + .clientSecret("okta-client-secret") + .authorizationGrantType(AuthorizationGrantType.JWT_BEARER) + .build())); + // @formatter:on } @Bean @@ -213,6 +472,76 @@ OAuth2AuthorizedClientRepository authorizedClientRepository() { return mock(OAuth2AuthorizedClientRepository.class); } + @Bean + Consumer authorizedClientManagerConsumer() { + return (authorizedClientManager) -> authorizedClientManager + .setContextAttributesMapper((authorizeRequest) -> { + HttpServletRequest request = Objects + .requireNonNull(authorizeRequest.getAttribute(HttpServletRequest.class.getName())); + String username = request.getParameter(OAuth2ParameterNames.USERNAME); + String password = request.getParameter(OAuth2ParameterNames.PASSWORD); + + Map attributes = Collections.emptyMap(); + if (StringUtils.hasText(username) && StringUtils.hasText(password)) { + attributes = new HashMap<>(); + attributes.put(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, username); + attributes.put(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, password); + } + + return attributes; + }); + } + + } + + private static class MockAuthorizationCodeClient + implements OAuth2AccessTokenResponseClient { + + @Override + public OAuth2AccessTokenResponse getTokenResponse( + OAuth2AuthorizationCodeGrantRequest authorizationGrantRequest) { + return MOCK_RESPONSE_CLIENT.getTokenResponse(authorizationGrantRequest); + } + + } + + private static class MockRefreshTokenClient + implements OAuth2AccessTokenResponseClient { + + @Override + public OAuth2AccessTokenResponse getTokenResponse(OAuth2RefreshTokenGrantRequest authorizationGrantRequest) { + return MOCK_RESPONSE_CLIENT.getTokenResponse(authorizationGrantRequest); + } + + } + + private static class MockClientCredentialsClient + implements OAuth2AccessTokenResponseClient { + + @Override + public OAuth2AccessTokenResponse getTokenResponse( + OAuth2ClientCredentialsGrantRequest authorizationGrantRequest) { + return MOCK_RESPONSE_CLIENT.getTokenResponse(authorizationGrantRequest); + } + + } + + private static class MockPasswordClient implements OAuth2AccessTokenResponseClient { + + @Override + public OAuth2AccessTokenResponse getTokenResponse(OAuth2PasswordGrantRequest authorizationGrantRequest) { + return MOCK_RESPONSE_CLIENT.getTokenResponse(authorizationGrantRequest); + } + + } + + private static class MockJwtBearerClient implements OAuth2AccessTokenResponseClient { + + @Override + public OAuth2AccessTokenResponse getTokenResponse(JwtBearerGrantRequest authorizationGrantRequest) { + return MOCK_RESPONSE_CLIENT.getTokenResponse(authorizationGrantRequest); + } + } } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java index 08df86cdb76..ea7ac5ac598 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java @@ -175,9 +175,10 @@ public void loadContextWhenClientRegistrationRepositoryRegisteredTwiceThenThrowN @Test public void loadContextWhenAccessTokenResponseClientRegisteredTwiceThenThrowNoUniqueBeanDefinitionException() { // @formatter:off - assertThatExceptionOfType(Exception.class) + assertThatExceptionOfType(BeanCreationException.class) .isThrownBy(() -> this.spring.register(AccessTokenResponseClientRegisteredTwiceConfig.class).autowire()) - .withRootCauseInstanceOf(NoUniqueBeanDefinitionException.class) + .havingRootCause() + .isInstanceOf(NoUniqueBeanDefinitionException.class) .withMessageContaining( "expected single matching bean but found 2: accessTokenResponseClient1,accessTokenResponseClient2"); // @formatter:on diff --git a/config/src/test/java/org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrarTests.java b/config/src/test/java/org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrarTests.java new file mode 100644 index 00000000000..d79c084c01d --- /dev/null +++ b/config/src/test/java/org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrarTests.java @@ -0,0 +1,475 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.config.http; + +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.function.Consumer; + +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.config.oauth2.client.CommonOAuth2Provider; +import org.springframework.security.config.test.SpringTestContext; +import org.springframework.security.oauth2.client.AuthorizationCodeOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; +import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.JwtBearerOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.OAuth2AuthorizationContext; +import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager; +import org.springframework.security.oauth2.client.PasswordOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.endpoint.AbstractOAuth2AuthorizationGrantRequest; +import org.springframework.security.oauth2.client.endpoint.JwtBearerGrantRequest; +import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; +import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; +import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest; +import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; +import org.springframework.security.oauth2.jwt.JoseHeaderNames; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtClaimNames; +import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken; +import org.springframework.util.StringUtils; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link OAuth2AuthorizedClientManagerRegistrar}. + * + * @author Steve Riesenberg + */ +public class OAuth2AuthorizedClientManagerRegistrarTests { + + private static final String CONFIG_LOCATION_PREFIX = "classpath:org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrarTests"; + + private static OAuth2AccessTokenResponseClient MOCK_RESPONSE_CLIENT; + + public final SpringTestContext spring = new SpringTestContext(this); + + @Autowired + private OAuth2AuthorizedClientManager authorizedClientManager; + + @Autowired + private ClientRegistrationRepository clientRegistrationRepository; + + @Autowired + private OAuth2AuthorizedClientRepository authorizedClientRepository; + + @Autowired(required = false) + private AuthorizationCodeOAuth2AuthorizedClientProvider authorizationCodeAuthorizedClientProvider; + + private MockHttpServletRequest request; + + private MockHttpServletResponse response; + + @BeforeEach + @SuppressWarnings("unchecked") + public void setUp() { + MOCK_RESPONSE_CLIENT = mock(OAuth2AccessTokenResponseClient.class); + this.request = new MockHttpServletRequest(); + this.response = new MockHttpServletResponse(); + } + + @Test + public void loadContextWhenOAuth2ClientEnabledThenConfigured() { + this.spring.configLocations(xml("minimal")).autowire(); + assertThat(this.authorizedClientManager).isNotNull(); + } + + @Test + public void authorizeWhenAuthorizationCodeAuthorizedClientProviderBeanThenUsed() { + this.spring.configLocations(xml("providers")).autowire(); + + TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", null); + // @formatter:off + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId("google") + .principal(authentication) + .attribute(HttpServletRequest.class.getName(), this.request) + .attribute(HttpServletResponse.class.getName(), this.response) + .build(); + assertThatExceptionOfType(ClientAuthorizationRequiredException.class) + .isThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest)) + .extracting(OAuth2AuthorizationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo("client_authorization_required"); + // @formatter:on + + verify(this.authorizationCodeAuthorizedClientProvider).authorize(any(OAuth2AuthorizationContext.class)); + } + + @Test + public void authorizeWhenRefreshTokenAccessTokenResponseClientBeanThenUsed() { + this.spring.configLocations(xml("clients")).autowire(); + testRefreshTokenGrant(); + } + + @Test + public void authorizeWhenRefreshTokenAuthorizedClientProviderBeanThenUsed() { + this.spring.configLocations(xml("providers")).autowire(); + testRefreshTokenGrant(); + } + + private void testRefreshTokenGrant() { + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(MOCK_RESPONSE_CLIENT.getTokenResponse(any(OAuth2RefreshTokenGrantRequest.class))) + .willReturn(accessTokenResponse); + + TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", null); + ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId("google"); + OAuth2AuthorizedClient existingAuthorizedClient = new OAuth2AuthorizedClient(clientRegistration, + authentication.getName(), getExpiredAccessToken(), TestOAuth2RefreshTokens.refreshToken()); + this.authorizedClientRepository.saveAuthorizedClient(existingAuthorizedClient, authentication, this.request, + this.response); + // @formatter:off + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withAuthorizedClient(existingAuthorizedClient) + .principal(authentication) + .attribute(HttpServletRequest.class.getName(), this.request) + .attribute(HttpServletResponse.class.getName(), this.response) + .build(); + // @formatter:on + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); + assertThat(authorizedClient).isNotNull(); + + ArgumentCaptor grantRequestCaptor = ArgumentCaptor + .forClass(OAuth2RefreshTokenGrantRequest.class); + verify(MOCK_RESPONSE_CLIENT).getTokenResponse(grantRequestCaptor.capture()); + + OAuth2RefreshTokenGrantRequest grantRequest = grantRequestCaptor.getValue(); + assertThat(grantRequest.getClientRegistration().getRegistrationId()) + .isEqualTo(clientRegistration.getRegistrationId()); + assertThat(grantRequest.getGrantType()).isEqualTo(AuthorizationGrantType.REFRESH_TOKEN); + assertThat(grantRequest.getAccessToken()).isEqualTo(existingAuthorizedClient.getAccessToken()); + assertThat(grantRequest.getRefreshToken()).isEqualTo(existingAuthorizedClient.getRefreshToken()); + } + + @Test + public void authorizeWhenClientCredentialsAccessTokenResponseClientBeanThenUsed() { + this.spring.configLocations(xml("clients")).autowire(); + testClientCredentialsGrant(); + } + + @Test + public void authorizeWhenClientCredentialsAuthorizedClientProviderBeanThenUsed() { + this.spring.configLocations(xml("providers")).autowire(); + testClientCredentialsGrant(); + } + + private void testClientCredentialsGrant() { + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(MOCK_RESPONSE_CLIENT.getTokenResponse(any(OAuth2ClientCredentialsGrantRequest.class))) + .willReturn(accessTokenResponse); + + TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", null); + ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId("github"); + // @formatter:off + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(clientRegistration.getRegistrationId()) + .principal(authentication) + .attribute(HttpServletRequest.class.getName(), this.request) + .attribute(HttpServletResponse.class.getName(), this.response) + .build(); + // @formatter:on + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); + assertThat(authorizedClient).isNotNull(); + + ArgumentCaptor grantRequestCaptor = ArgumentCaptor + .forClass(OAuth2ClientCredentialsGrantRequest.class); + verify(MOCK_RESPONSE_CLIENT).getTokenResponse(grantRequestCaptor.capture()); + + OAuth2ClientCredentialsGrantRequest grantRequest = grantRequestCaptor.getValue(); + assertThat(grantRequest.getClientRegistration().getRegistrationId()) + .isEqualTo(clientRegistration.getRegistrationId()); + assertThat(grantRequest.getGrantType()).isEqualTo(AuthorizationGrantType.CLIENT_CREDENTIALS); + } + + @Test + public void authorizeWhenPasswordAccessTokenResponseClientBeanThenUsed() { + this.spring.configLocations(xml("clients")).autowire(); + testPasswordGrant(); + } + + @Test + public void authorizeWhenPasswordAuthorizedClientProviderBeanThenUsed() { + this.spring.configLocations(xml("providers")).autowire(); + testPasswordGrant(); + } + + private void testPasswordGrant() { + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(MOCK_RESPONSE_CLIENT.getTokenResponse(any(OAuth2PasswordGrantRequest.class))) + .willReturn(accessTokenResponse); + + TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", "password"); + ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId("facebook"); + // @formatter:off + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(clientRegistration.getRegistrationId()) + .principal(authentication) + .attribute(HttpServletRequest.class.getName(), this.request) + .attribute(HttpServletResponse.class.getName(), this.response) + .build(); + // @formatter:on + this.request.setParameter(OAuth2ParameterNames.USERNAME, "user"); + this.request.setParameter(OAuth2ParameterNames.PASSWORD, "password"); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); + assertThat(authorizedClient).isNotNull(); + + ArgumentCaptor grantRequestCaptor = ArgumentCaptor + .forClass(OAuth2PasswordGrantRequest.class); + verify(MOCK_RESPONSE_CLIENT).getTokenResponse(grantRequestCaptor.capture()); + + OAuth2PasswordGrantRequest grantRequest = grantRequestCaptor.getValue(); + assertThat(grantRequest.getClientRegistration().getRegistrationId()) + .isEqualTo(clientRegistration.getRegistrationId()); + assertThat(grantRequest.getGrantType()).isEqualTo(AuthorizationGrantType.PASSWORD); + assertThat(grantRequest.getUsername()).isEqualTo("user"); + assertThat(grantRequest.getPassword()).isEqualTo("password"); + } + + @Test + public void authorizeWhenJwtBearerAccessTokenResponseClientBeanThenUsed() { + this.spring.configLocations(xml("clients")).autowire(); + testJwtBearerGrant(); + } + + @Test + public void authorizeWhenJwtBearerAuthorizedClientProviderBeanThenUsed() { + this.spring.configLocations(xml("providers")).autowire(); + testJwtBearerGrant(); + } + + private void testJwtBearerGrant() { + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(MOCK_RESPONSE_CLIENT.getTokenResponse(any(JwtBearerGrantRequest.class))).willReturn(accessTokenResponse); + + JwtAuthenticationToken authentication = new JwtAuthenticationToken(getJwt()); + ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId("okta"); + // @formatter:off + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(clientRegistration.getRegistrationId()) + .principal(authentication) + .attribute(HttpServletRequest.class.getName(), this.request) + .attribute(HttpServletResponse.class.getName(), this.response) + .build(); + // @formatter:on + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); + assertThat(authorizedClient).isNotNull(); + + ArgumentCaptor grantRequestCaptor = ArgumentCaptor.forClass(JwtBearerGrantRequest.class); + verify(MOCK_RESPONSE_CLIENT).getTokenResponse(grantRequestCaptor.capture()); + + JwtBearerGrantRequest grantRequest = grantRequestCaptor.getValue(); + assertThat(grantRequest.getClientRegistration().getRegistrationId()) + .isEqualTo(clientRegistration.getRegistrationId()); + assertThat(grantRequest.getGrantType()).isEqualTo(AuthorizationGrantType.JWT_BEARER); + assertThat(grantRequest.getJwt().getSubject()).isEqualTo("user"); + } + + private static OAuth2AccessToken getExpiredAccessToken() { + Instant expiresAt = Instant.now().minusSeconds(60); + Instant issuedAt = expiresAt.minus(Duration.ofDays(1)); + return new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "scopes", issuedAt, expiresAt, + new HashSet<>(Arrays.asList("read", "write"))); + } + + private static Jwt getJwt() { + Instant issuedAt = Instant.now(); + return new Jwt("token", issuedAt, issuedAt.plusSeconds(300), + Collections.singletonMap(JoseHeaderNames.ALG, "RS256"), + Collections.singletonMap(JwtClaimNames.SUB, "user")); + } + + private static String xml(String configName) { + return CONFIG_LOCATION_PREFIX + "-" + configName + ".xml"; + } + + public static List getClientRegistrations() { + // @formatter:off + return Arrays.asList( + CommonOAuth2Provider.GOOGLE.getBuilder("google") + .clientId("google-client-id") + .clientSecret("google-client-secret") + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .build(), + CommonOAuth2Provider.GITHUB.getBuilder("github") + .clientId("github-client-id") + .clientSecret("github-client-secret") + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) + .build(), + CommonOAuth2Provider.FACEBOOK.getBuilder("facebook") + .clientId("facebook-client-id") + .clientSecret("facebook-client-secret") + .authorizationGrantType(AuthorizationGrantType.PASSWORD) + .build(), + CommonOAuth2Provider.OKTA.getBuilder("okta") + .clientId("okta-client-id") + .clientSecret("okta-client-secret") + .authorizationGrantType(AuthorizationGrantType.JWT_BEARER) + .build()); + // @formatter:on + } + + public static Consumer authorizedClientManagerConsumer() { + return (authorizedClientManager) -> authorizedClientManager.setContextAttributesMapper((authorizeRequest) -> { + HttpServletRequest request = Objects + .requireNonNull(authorizeRequest.getAttribute(HttpServletRequest.class.getName())); + String username = request.getParameter(OAuth2ParameterNames.USERNAME); + String password = request.getParameter(OAuth2ParameterNames.PASSWORD); + + Map attributes = Collections.emptyMap(); + if (StringUtils.hasText(username) && StringUtils.hasText(password)) { + attributes = new HashMap<>(); + attributes.put(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, username); + attributes.put(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, password); + } + + return attributes; + }); + } + + public static AuthorizationCodeOAuth2AuthorizedClientProvider authorizationCodeAuthorizedClientProvider() { + return spy(new AuthorizationCodeOAuth2AuthorizedClientProvider()); + } + + public static RefreshTokenOAuth2AuthorizedClientProvider refreshTokenAuthorizedClientProvider() { + RefreshTokenOAuth2AuthorizedClientProvider authorizedClientProvider = new RefreshTokenOAuth2AuthorizedClientProvider(); + authorizedClientProvider.setAccessTokenResponseClient(refreshTokenAccessTokenResponseClient()); + return authorizedClientProvider; + } + + public static MockRefreshTokenClient refreshTokenAccessTokenResponseClient() { + return new MockRefreshTokenClient(); + } + + public static ClientCredentialsOAuth2AuthorizedClientProvider clientCredentialsAuthorizedClientProvider() { + ClientCredentialsOAuth2AuthorizedClientProvider authorizedClientProvider = new ClientCredentialsOAuth2AuthorizedClientProvider(); + authorizedClientProvider.setAccessTokenResponseClient(clientCredentialsAccessTokenResponseClient()); + return authorizedClientProvider; + } + + public static OAuth2AccessTokenResponseClient clientCredentialsAccessTokenResponseClient() { + return new MockClientCredentialsClient(); + } + + public static PasswordOAuth2AuthorizedClientProvider passwordAuthorizedClientProvider() { + PasswordOAuth2AuthorizedClientProvider authorizedClientProvider = new PasswordOAuth2AuthorizedClientProvider(); + authorizedClientProvider.setAccessTokenResponseClient(passwordAccessTokenResponseClient()); + return authorizedClientProvider; + } + + public static OAuth2AccessTokenResponseClient passwordAccessTokenResponseClient() { + return new MockPasswordClient(); + } + + public static JwtBearerOAuth2AuthorizedClientProvider jwtBearerAuthorizedClientProvider() { + JwtBearerOAuth2AuthorizedClientProvider authorizedClientProvider = new JwtBearerOAuth2AuthorizedClientProvider(); + authorizedClientProvider.setAccessTokenResponseClient(jwtBearerAccessTokenResponseClient()); + return authorizedClientProvider; + } + + public static OAuth2AccessTokenResponseClient jwtBearerAccessTokenResponseClient() { + return new MockJwtBearerClient(); + } + + private static class MockAuthorizationCodeClient + implements OAuth2AccessTokenResponseClient { + + @Override + public OAuth2AccessTokenResponse getTokenResponse( + OAuth2AuthorizationCodeGrantRequest authorizationGrantRequest) { + return MOCK_RESPONSE_CLIENT.getTokenResponse(authorizationGrantRequest); + } + + } + + private static class MockRefreshTokenClient + implements OAuth2AccessTokenResponseClient { + + @Override + public OAuth2AccessTokenResponse getTokenResponse(OAuth2RefreshTokenGrantRequest authorizationGrantRequest) { + return MOCK_RESPONSE_CLIENT.getTokenResponse(authorizationGrantRequest); + } + + } + + private static class MockClientCredentialsClient + implements OAuth2AccessTokenResponseClient { + + @Override + public OAuth2AccessTokenResponse getTokenResponse( + OAuth2ClientCredentialsGrantRequest authorizationGrantRequest) { + return MOCK_RESPONSE_CLIENT.getTokenResponse(authorizationGrantRequest); + } + + } + + private static class MockPasswordClient implements OAuth2AccessTokenResponseClient { + + @Override + public OAuth2AccessTokenResponse getTokenResponse(OAuth2PasswordGrantRequest authorizationGrantRequest) { + return MOCK_RESPONSE_CLIENT.getTokenResponse(authorizationGrantRequest); + } + + } + + private static class MockJwtBearerClient implements OAuth2AccessTokenResponseClient { + + @Override + public OAuth2AccessTokenResponse getTokenResponse(JwtBearerGrantRequest authorizationGrantRequest) { + return MOCK_RESPONSE_CLIENT.getTokenResponse(authorizationGrantRequest); + } + + } + +} diff --git a/config/src/test/resources/org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrarTests-clients.xml b/config/src/test/resources/org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrarTests-clients.xml new file mode 100644 index 00000000000..416520c6f7b --- /dev/null +++ b/config/src/test/resources/org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrarTests-clients.xml @@ -0,0 +1,56 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/config/src/test/resources/org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrarTests-minimal.xml b/config/src/test/resources/org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrarTests-minimal.xml new file mode 100644 index 00000000000..6efa77199a1 --- /dev/null +++ b/config/src/test/resources/org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrarTests-minimal.xml @@ -0,0 +1,41 @@ + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/config/src/test/resources/org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrarTests-providers.xml b/config/src/test/resources/org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrarTests-providers.xml new file mode 100644 index 00000000000..1966d46371d --- /dev/null +++ b/config/src/test/resources/org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrarTests-providers.xml @@ -0,0 +1,59 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file