diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ReactiveOAuth2ClientConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ReactiveOAuth2ClientConfiguration.java new file mode 100644 index 00000000000..56e25bd70c9 --- /dev/null +++ b/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ReactiveOAuth2ClientConfiguration.java @@ -0,0 +1,413 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.config.annotation.web.reactive; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Set; +import java.util.function.Consumer; + +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.BeanFactoryAware; +import org.springframework.beans.factory.BeanFactoryUtils; +import org.springframework.beans.factory.BeanInitializationException; +import org.springframework.beans.factory.ListableBeanFactory; +import org.springframework.beans.factory.NoSuchBeanDefinitionException; +import org.springframework.beans.factory.ObjectProvider; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.support.BeanDefinitionBuilder; +import org.springframework.beans.factory.support.BeanDefinitionRegistry; +import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor; +import org.springframework.context.annotation.AnnotationBeanNameGenerator; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Import; +import org.springframework.core.ResolvableType; +import org.springframework.security.oauth2.client.AuthorizationCodeReactiveOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.ClientCredentialsReactiveOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.DelegatingReactiveOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.JwtBearerReactiveOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.PasswordReactiveOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientManager; +import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService; +import org.springframework.security.oauth2.client.RefreshTokenReactiveOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.TokenExchangeReactiveOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.endpoint.JwtBearerGrantRequest; +import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; +import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest; +import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; +import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.TokenExchangeGrantRequest; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.DefaultReactiveOAuth2AuthorizedClientManager; +import org.springframework.security.oauth2.client.web.reactive.result.method.annotation.OAuth2AuthorizedClientArgumentResolver; +import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; +import org.springframework.web.reactive.config.WebFluxConfigurer; +import org.springframework.web.reactive.result.method.annotation.ArgumentResolverConfigurer; + +/** + * {@link Configuration} for OAuth 2.0 Client support. + * + *

+ * This {@code Configuration} is conditionally imported by + * {@link ReactiveOAuth2ClientImportSelector} when the + * {@code spring-security-oauth2-client} module is present on the classpath. + * + * @author Steve Riesenberg + * @since 6.3 + * @see ReactiveOAuth2ClientImportSelector + */ +@Import({ ReactiveOAuth2ClientConfiguration.ReactiveOAuth2AuthorizedClientManagerConfiguration.class, + ReactiveOAuth2ClientConfiguration.OAuth2ClientWebFluxSecurityConfiguration.class }) +final class ReactiveOAuth2ClientConfiguration { + + @Configuration + static class ReactiveOAuth2AuthorizedClientManagerConfiguration { + + @Bean(name = ReactiveOAuth2AuthorizedClientManagerRegistrar.BEAN_NAME) + ReactiveOAuth2AuthorizedClientManagerRegistrar authorizedClientManagerRegistrar() { + return new ReactiveOAuth2AuthorizedClientManagerRegistrar(); + } + + } + + @Configuration(proxyBeanMethods = false) + static class OAuth2ClientWebFluxSecurityConfiguration implements WebFluxConfigurer { + + private ReactiveOAuth2AuthorizedClientManager authorizedClientManager; + + private ReactiveOAuth2AuthorizedClientManagerRegistrar authorizedClientManagerRegistrar; + + @Override + public void configureArgumentResolvers(ArgumentResolverConfigurer configurer) { + ReactiveOAuth2AuthorizedClientManager authorizedClientManager = getAuthorizedClientManager(); + if (authorizedClientManager != null) { + configurer.addCustomResolver(new OAuth2AuthorizedClientArgumentResolver(authorizedClientManager)); + } + } + + @Autowired(required = false) + void setAuthorizedClientManager(List authorizedClientManager) { + if (authorizedClientManager.size() == 1) { + this.authorizedClientManager = authorizedClientManager.get(0); + } + } + + @Autowired + void setAuthorizedClientManagerRegistrar( + ReactiveOAuth2AuthorizedClientManagerRegistrar authorizedClientManagerRegistrar) { + this.authorizedClientManagerRegistrar = authorizedClientManagerRegistrar; + } + + private ReactiveOAuth2AuthorizedClientManager getAuthorizedClientManager() { + if (this.authorizedClientManager != null) { + return this.authorizedClientManager; + } + return this.authorizedClientManagerRegistrar.getAuthorizedClientManagerIfAvailable(); + } + + } + + /** + * A registrar for registering the default + * {@link ReactiveOAuth2AuthorizedClientManager} bean definition, if not already + * present. + */ + static final class ReactiveOAuth2AuthorizedClientManagerRegistrar + implements BeanDefinitionRegistryPostProcessor, BeanFactoryAware { + + static final String BEAN_NAME = "authorizedClientManagerRegistrar"; + + static final String FACTORY_METHOD_NAME = "getAuthorizedClientManager"; + + // @formatter:off + private static final Set> KNOWN_AUTHORIZED_CLIENT_PROVIDERS = Set.of( + AuthorizationCodeReactiveOAuth2AuthorizedClientProvider.class, + RefreshTokenReactiveOAuth2AuthorizedClientProvider.class, + ClientCredentialsReactiveOAuth2AuthorizedClientProvider.class, + PasswordReactiveOAuth2AuthorizedClientProvider.class, + JwtBearerReactiveOAuth2AuthorizedClientProvider.class, + TokenExchangeReactiveOAuth2AuthorizedClientProvider.class + ); + // @formatter:on + + private final AnnotationBeanNameGenerator beanNameGenerator = new AnnotationBeanNameGenerator(); + + private ListableBeanFactory beanFactory; + + @Override + public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException { + if (getBeanNamesForType(ReactiveOAuth2AuthorizedClientManager.class).length != 0 + || getBeanNamesForType(ReactiveClientRegistrationRepository.class).length != 1 + || getBeanNamesForType(ServerOAuth2AuthorizedClientRepository.class).length != 1 + && getBeanNamesForType(ReactiveOAuth2AuthorizedClientService.class).length != 1) { + return; + } + + BeanDefinition beanDefinition = BeanDefinitionBuilder + .rootBeanDefinition(ReactiveOAuth2AuthorizedClientManager.class) + .setFactoryMethodOnBean(FACTORY_METHOD_NAME, BEAN_NAME) + .getBeanDefinition(); + + registry.registerBeanDefinition(this.beanNameGenerator.generateBeanName(beanDefinition, registry), + beanDefinition); + } + + @Override + public void setBeanFactory(BeanFactory beanFactory) throws BeansException { + this.beanFactory = (ListableBeanFactory) beanFactory; + } + + ReactiveOAuth2AuthorizedClientManager getAuthorizedClientManagerIfAvailable() { + if (getBeanNamesForType(ReactiveClientRegistrationRepository.class).length != 1 + || getBeanNamesForType(ServerOAuth2AuthorizedClientRepository.class).length != 1 + && getBeanNamesForType(ReactiveOAuth2AuthorizedClientService.class).length != 1) { + return null; + } + return getAuthorizedClientManager(); + } + + ReactiveOAuth2AuthorizedClientManager getAuthorizedClientManager() { + ReactiveClientRegistrationRepository clientRegistrationRepository = BeanFactoryUtils + .beanOfTypeIncludingAncestors(this.beanFactory, ReactiveClientRegistrationRepository.class, true, true); + + ServerOAuth2AuthorizedClientRepository authorizedClientRepository; + try { + authorizedClientRepository = BeanFactoryUtils.beanOfTypeIncludingAncestors(this.beanFactory, + ServerOAuth2AuthorizedClientRepository.class, true, true); + } + catch (NoSuchBeanDefinitionException ex) { + ReactiveOAuth2AuthorizedClientService authorizedClientService = BeanFactoryUtils + .beanOfTypeIncludingAncestors(this.beanFactory, ReactiveOAuth2AuthorizedClientService.class, true, + true); + authorizedClientRepository = new AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository( + authorizedClientService); + } + + Collection authorizedClientProviderBeans = BeanFactoryUtils + .beansOfTypeIncludingAncestors(this.beanFactory, ReactiveOAuth2AuthorizedClientProvider.class, true, + true) + .values(); + + ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider; + if (hasDelegatingAuthorizedClientProvider(authorizedClientProviderBeans)) { + authorizedClientProvider = authorizedClientProviderBeans.iterator().next(); + } + else { + List authorizedClientProviders = new ArrayList<>(); + authorizedClientProviders + .add(getAuthorizationCodeAuthorizedClientProvider(authorizedClientProviderBeans)); + authorizedClientProviders.add(getRefreshTokenAuthorizedClientProvider(authorizedClientProviderBeans)); + authorizedClientProviders + .add(getClientCredentialsAuthorizedClientProvider(authorizedClientProviderBeans)); + authorizedClientProviders.add(getPasswordAuthorizedClientProvider(authorizedClientProviderBeans)); + + ReactiveOAuth2AuthorizedClientProvider jwtBearerAuthorizedClientProvider = getJwtBearerAuthorizedClientProvider( + authorizedClientProviderBeans); + if (jwtBearerAuthorizedClientProvider != null) { + authorizedClientProviders.add(jwtBearerAuthorizedClientProvider); + } + + ReactiveOAuth2AuthorizedClientProvider tokenExchangeAuthorizedClientProvider = getTokenExchangeAuthorizedClientProvider( + authorizedClientProviderBeans); + if (tokenExchangeAuthorizedClientProvider != null) { + authorizedClientProviders.add(tokenExchangeAuthorizedClientProvider); + } + + authorizedClientProviders.addAll(getAdditionalAuthorizedClientProviders(authorizedClientProviderBeans)); + authorizedClientProvider = new DelegatingReactiveOAuth2AuthorizedClientProvider( + authorizedClientProviders); + } + + DefaultReactiveOAuth2AuthorizedClientManager authorizedClientManager = new DefaultReactiveOAuth2AuthorizedClientManager( + clientRegistrationRepository, authorizedClientRepository); + authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); + + Consumer authorizedClientManagerConsumer = getBeanOfType( + ResolvableType.forClassWithGenerics(Consumer.class, + DefaultReactiveOAuth2AuthorizedClientManager.class)); + if (authorizedClientManagerConsumer != null) { + authorizedClientManagerConsumer.accept(authorizedClientManager); + } + + return authorizedClientManager; + } + + private boolean hasDelegatingAuthorizedClientProvider( + Collection authorizedClientProviders) { + if (authorizedClientProviders.size() != 1) { + return false; + } + return authorizedClientProviders.iterator() + .next() instanceof DelegatingReactiveOAuth2AuthorizedClientProvider; + } + + private ReactiveOAuth2AuthorizedClientProvider getAuthorizationCodeAuthorizedClientProvider( + Collection authorizedClientProviders) { + AuthorizationCodeReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = getAuthorizedClientProviderByType( + authorizedClientProviders, AuthorizationCodeReactiveOAuth2AuthorizedClientProvider.class); + if (authorizedClientProvider == null) { + authorizedClientProvider = new AuthorizationCodeReactiveOAuth2AuthorizedClientProvider(); + } + + return authorizedClientProvider; + } + + private ReactiveOAuth2AuthorizedClientProvider getRefreshTokenAuthorizedClientProvider( + Collection authorizedClientProviders) { + RefreshTokenReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = getAuthorizedClientProviderByType( + authorizedClientProviders, RefreshTokenReactiveOAuth2AuthorizedClientProvider.class); + if (authorizedClientProvider == null) { + authorizedClientProvider = new RefreshTokenReactiveOAuth2AuthorizedClientProvider(); + } + + ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient = getBeanOfType( + ResolvableType.forClassWithGenerics(ReactiveOAuth2AccessTokenResponseClient.class, + OAuth2RefreshTokenGrantRequest.class)); + if (accessTokenResponseClient != null) { + authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient); + } + + return authorizedClientProvider; + } + + private ReactiveOAuth2AuthorizedClientProvider getClientCredentialsAuthorizedClientProvider( + Collection authorizedClientProviders) { + ClientCredentialsReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = getAuthorizedClientProviderByType( + authorizedClientProviders, ClientCredentialsReactiveOAuth2AuthorizedClientProvider.class); + if (authorizedClientProvider == null) { + authorizedClientProvider = new ClientCredentialsReactiveOAuth2AuthorizedClientProvider(); + } + + ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient = getBeanOfType( + ResolvableType.forClassWithGenerics(ReactiveOAuth2AccessTokenResponseClient.class, + OAuth2ClientCredentialsGrantRequest.class)); + if (accessTokenResponseClient != null) { + authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient); + } + + return authorizedClientProvider; + } + + private ReactiveOAuth2AuthorizedClientProvider getPasswordAuthorizedClientProvider( + Collection authorizedClientProviders) { + PasswordReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = getAuthorizedClientProviderByType( + authorizedClientProviders, PasswordReactiveOAuth2AuthorizedClientProvider.class); + if (authorizedClientProvider == null) { + authorizedClientProvider = new PasswordReactiveOAuth2AuthorizedClientProvider(); + } + + ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient = getBeanOfType( + ResolvableType.forClassWithGenerics(ReactiveOAuth2AccessTokenResponseClient.class, + OAuth2PasswordGrantRequest.class)); + if (accessTokenResponseClient != null) { + authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient); + } + + return authorizedClientProvider; + } + + private ReactiveOAuth2AuthorizedClientProvider getJwtBearerAuthorizedClientProvider( + Collection authorizedClientProviders) { + JwtBearerReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = getAuthorizedClientProviderByType( + authorizedClientProviders, JwtBearerReactiveOAuth2AuthorizedClientProvider.class); + + ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient = getBeanOfType( + ResolvableType.forClassWithGenerics(ReactiveOAuth2AccessTokenResponseClient.class, + JwtBearerGrantRequest.class)); + if (accessTokenResponseClient != null) { + if (authorizedClientProvider == null) { + authorizedClientProvider = new JwtBearerReactiveOAuth2AuthorizedClientProvider(); + } + + authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient); + } + + return authorizedClientProvider; + } + + private ReactiveOAuth2AuthorizedClientProvider getTokenExchangeAuthorizedClientProvider( + Collection authorizedClientProviders) { + TokenExchangeReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = getAuthorizedClientProviderByType( + authorizedClientProviders, TokenExchangeReactiveOAuth2AuthorizedClientProvider.class); + + ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient = getBeanOfType( + ResolvableType.forClassWithGenerics(ReactiveOAuth2AccessTokenResponseClient.class, + TokenExchangeGrantRequest.class)); + if (accessTokenResponseClient != null) { + if (authorizedClientProvider == null) { + authorizedClientProvider = new TokenExchangeReactiveOAuth2AuthorizedClientProvider(); + } + + authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient); + } + + return authorizedClientProvider; + } + + private List getAdditionalAuthorizedClientProviders( + Collection authorizedClientProviders) { + List additionalAuthorizedClientProviders = new ArrayList<>( + authorizedClientProviders); + additionalAuthorizedClientProviders + .removeIf((provider) -> KNOWN_AUTHORIZED_CLIENT_PROVIDERS.contains(provider.getClass())); + return additionalAuthorizedClientProviders; + } + + private T getAuthorizedClientProviderByType( + Collection authorizedClientProviders, Class providerClass) { + T authorizedClientProvider = null; + for (ReactiveOAuth2AuthorizedClientProvider current : authorizedClientProviders) { + if (providerClass.isInstance(current)) { + assertAuthorizedClientProviderIsNull(authorizedClientProvider); + authorizedClientProvider = providerClass.cast(current); + } + } + return authorizedClientProvider; + } + + private static void assertAuthorizedClientProviderIsNull( + ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider) { + if (authorizedClientProvider != null) { + // @formatter:off + throw new BeanInitializationException(String.format( + "Unable to create a %s bean. Expected one bean of type %s, but found multiple. " + + "Please consider defining only a single bean of this type, or define a %s bean yourself.", + ReactiveOAuth2AuthorizedClientManager.class.getName(), + authorizedClientProvider.getClass().getName(), + ReactiveOAuth2AuthorizedClientManager.class.getName())); + // @formatter:on + } + } + + private String[] getBeanNamesForType(Class beanClass) { + return BeanFactoryUtils.beanNamesForTypeIncludingAncestors(this.beanFactory, beanClass, true, true); + } + + private T getBeanOfType(ResolvableType resolvableType) { + ObjectProvider objectProvider = this.beanFactory.getBeanProvider(resolvableType, true); + return objectProvider.getIfAvailable(); + } + + } + +} diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ReactiveOAuth2ClientImportSelector.java b/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ReactiveOAuth2ClientImportSelector.java index 9a1781b93bf..5e73fb7c894 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ReactiveOAuth2ClientImportSelector.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ReactiveOAuth2ClientImportSelector.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,27 +16,13 @@ package org.springframework.security.config.annotation.web.reactive; -import java.util.List; - -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.ImportSelector; import org.springframework.core.type.AnnotationMetadata; -import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientManager; -import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider; -import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProviderBuilder; -import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService; -import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; -import org.springframework.security.oauth2.client.web.DefaultReactiveOAuth2AuthorizedClientManager; -import org.springframework.security.oauth2.client.web.reactive.result.method.annotation.OAuth2AuthorizedClientArgumentResolver; -import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository; -import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.util.ClassUtils; -import org.springframework.web.reactive.config.WebFluxConfigurer; -import org.springframework.web.reactive.result.method.annotation.ArgumentResolverConfigurer; /** - * {@link Configuration} for OAuth 2.0 Client support. + * Used by {@link EnableWebFluxSecurity} to conditionally import + * {@link ReactiveOAuth2ClientConfiguration}. * *

* This {@code Configuration} is imported by {@link EnableWebFluxSecurity} @@ -60,85 +46,8 @@ public String[] selectImports(AnnotationMetadata importingClassMetadata) { if (!oauth2ClientPresent) { return new String[0]; } - return new String[] { "org.springframework.security.config.annotation.web.reactive." - + "ReactiveOAuth2ClientImportSelector$OAuth2ClientWebFluxSecurityConfiguration" }; - } - - @Configuration(proxyBeanMethods = false) - static class OAuth2ClientWebFluxSecurityConfiguration implements WebFluxConfigurer { - - private ReactiveClientRegistrationRepository clientRegistrationRepository; - - private ServerOAuth2AuthorizedClientRepository authorizedClientRepository; - - private ReactiveOAuth2AuthorizedClientService authorizedClientService; - - private ReactiveOAuth2AuthorizedClientManager authorizedClientManager; - - @Override - public void configureArgumentResolvers(ArgumentResolverConfigurer configurer) { - ReactiveOAuth2AuthorizedClientManager authorizedClientManager = getAuthorizedClientManager(); - if (authorizedClientManager != null) { - configurer.addCustomResolver(new OAuth2AuthorizedClientArgumentResolver(authorizedClientManager)); - } - } - - @Autowired(required = false) - void setClientRegistrationRepository(ReactiveClientRegistrationRepository clientRegistrationRepository) { - this.clientRegistrationRepository = clientRegistrationRepository; - } - - @Autowired(required = false) - void setAuthorizedClientRepository(ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { - this.authorizedClientRepository = authorizedClientRepository; - } - - @Autowired(required = false) - void setAuthorizedClientService(List authorizedClientService) { - if (authorizedClientService.size() == 1) { - this.authorizedClientService = authorizedClientService.get(0); - } - } - - @Autowired(required = false) - void setAuthorizedClientManager(List authorizedClientManager) { - if (authorizedClientManager.size() == 1) { - this.authorizedClientManager = authorizedClientManager.get(0); - } - } - - private ServerOAuth2AuthorizedClientRepository getAuthorizedClientRepository() { - if (this.authorizedClientRepository != null) { - return this.authorizedClientRepository; - } - if (this.authorizedClientService != null) { - return new AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository(this.authorizedClientService); - } - return null; - } - - private ReactiveOAuth2AuthorizedClientManager getAuthorizedClientManager() { - if (this.authorizedClientManager != null) { - return this.authorizedClientManager; - } - ReactiveOAuth2AuthorizedClientManager authorizedClientManager = null; - if (this.authorizedClientRepository != null && this.clientRegistrationRepository != null) { - ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = ReactiveOAuth2AuthorizedClientProviderBuilder - .builder() - .authorizationCode() - .refreshToken() - .clientCredentials() - .password() - .build(); - DefaultReactiveOAuth2AuthorizedClientManager defaultReactiveOAuth2AuthorizedClientManager = new DefaultReactiveOAuth2AuthorizedClientManager( - this.clientRegistrationRepository, getAuthorizedClientRepository()); - defaultReactiveOAuth2AuthorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); - authorizedClientManager = defaultReactiveOAuth2AuthorizedClientManager; - } - - return authorizedClientManager; - } - + return new String[] { + "org.springframework.security.config.annotation.web.reactive.ReactiveOAuth2ClientConfiguration" }; } } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/reactive/ReactiveOAuth2AuthorizedClientManagerConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/reactive/ReactiveOAuth2AuthorizedClientManagerConfigurationTests.java new file mode 100644 index 00000000000..dd7698e98db --- /dev/null +++ b/config/src/test/java/org/springframework/security/config/annotation/web/reactive/ReactiveOAuth2AuthorizedClientManagerConfigurationTests.java @@ -0,0 +1,589 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.config.annotation.web.reactive; + +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Objects; +import java.util.function.Consumer; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import reactor.core.publisher.Mono; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.http.MediaType; +import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.mock.web.server.MockServerWebExchange; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.config.oauth2.client.CommonOAuth2Provider; +import org.springframework.security.config.test.SpringTestContext; +import org.springframework.security.oauth2.client.AuthorizationCodeReactiveOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; +import org.springframework.security.oauth2.client.ClientCredentialsReactiveOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.JwtBearerReactiveOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.OAuth2AuthorizationContext; +import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.PasswordReactiveOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientManager; +import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService; +import org.springframework.security.oauth2.client.RefreshTokenReactiveOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.TokenExchangeReactiveOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.endpoint.AbstractOAuth2AuthorizationGrantRequest; +import org.springframework.security.oauth2.client.endpoint.JwtBearerGrantRequest; +import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; +import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; +import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest; +import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; +import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.TokenExchangeGrantRequest; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.DefaultReactiveOAuth2AuthorizedClientManager; +import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.client.web.server.WebSessionServerOAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; +import org.springframework.security.oauth2.jwt.JoseHeaderNames; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtClaimNames; +import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken; +import org.springframework.util.StringUtils; +import org.springframework.web.server.ServerWebExchange; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; + +/** + * Tests for + * {@link ReactiveOAuth2ClientConfiguration.ReactiveOAuth2AuthorizedClientManagerConfiguration}. + * + * @author Steve Riesenberg + */ +public class ReactiveOAuth2AuthorizedClientManagerConfigurationTests { + + private static ReactiveOAuth2AccessTokenResponseClient MOCK_RESPONSE_CLIENT; + + public final SpringTestContext spring = new SpringTestContext(this); + + @Autowired + private ReactiveOAuth2AuthorizedClientManager authorizedClientManager; + + @Autowired + private ReactiveClientRegistrationRepository clientRegistrationRepository; + + @Autowired(required = false) + private ServerOAuth2AuthorizedClientRepository authorizedClientRepository; + + @Autowired(required = false) + private ReactiveOAuth2AuthorizedClientService authorizedClientService; + + @Autowired(required = false) + private AuthorizationCodeReactiveOAuth2AuthorizedClientProvider authorizationCodeAuthorizedClientProvider; + + private MockServerWebExchange exchange; + + @BeforeEach + @SuppressWarnings("unchecked") + public void setUp() { + MOCK_RESPONSE_CLIENT = mock(ReactiveOAuth2AccessTokenResponseClient.class); + MockServerHttpRequest request = MockServerHttpRequest.get("/").build(); + this.exchange = MockServerWebExchange.builder(request).build(); + } + + @Test + public void loadContextWhenOAuth2ClientEnabledThenConfigured() { + this.spring.register(MinimalOAuth2ClientConfig.class).autowire(); + assertThat(this.authorizedClientManager).isNotNull(); + } + + @Test + public void authorizeWhenAuthorizationCodeAuthorizedClientProviderBeanThenUsed() { + this.spring.register(CustomAuthorizedClientProvidersConfig.class).autowire(); + + TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", null, "ROLE_USER"); + // @formatter:off + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId("google") + .principal(authentication) + .attribute(ServerWebExchange.class.getName(), this.exchange) + .build(); + assertThatExceptionOfType(ClientAuthorizationRequiredException.class) + .isThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest).block()) + .extracting(OAuth2AuthorizationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo("client_authorization_required"); + // @formatter:on + + verify(this.authorizationCodeAuthorizedClientProvider).authorize(any(OAuth2AuthorizationContext.class)); + } + + @Test + public void authorizeWhenAuthorizedClientServiceBeanThenUsed() { + this.spring.register(CustomAuthorizedClientServiceConfig.class).autowire(); + + TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", null, "ROLE_USER"); + // @formatter:off + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId("google") + .principal(authentication) + .attribute(ServerWebExchange.class.getName(), this.exchange) + .build(); + assertThatExceptionOfType(ClientAuthorizationRequiredException.class) + .isThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest).block()) + .extracting(OAuth2AuthorizationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo("client_authorization_required"); + // @formatter:on + + verify(this.authorizedClientService).loadAuthorizedClient(authorizeRequest.getClientRegistrationId(), + authentication.getName()); + } + + @Test + public void authorizeWhenRefreshTokenAccessTokenResponseClientBeanThenUsed() { + this.spring.register(CustomAccessTokenResponseClientsConfig.class).autowire(); + testRefreshTokenGrant(); + } + + @Test + public void authorizeWhenRefreshTokenAuthorizedClientProviderBeanThenUsed() { + this.spring.register(CustomAuthorizedClientProvidersConfig.class).autowire(); + testRefreshTokenGrant(); + } + + private void testRefreshTokenGrant() { + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(MOCK_RESPONSE_CLIENT.getTokenResponse(any(OAuth2RefreshTokenGrantRequest.class))) + .willReturn(Mono.just(accessTokenResponse)); + + TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", null, "ROLE_USER"); + ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId("google") + .block(); + assertThat(clientRegistration).isNotNull(); + OAuth2AuthorizedClient existingAuthorizedClient = new OAuth2AuthorizedClient(clientRegistration, + authentication.getName(), getExpiredAccessToken(), TestOAuth2RefreshTokens.refreshToken()); + this.authorizedClientRepository.saveAuthorizedClient(existingAuthorizedClient, authentication, this.exchange) + .block(); + // @formatter:off + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(clientRegistration.getRegistrationId()) + .principal(authentication) + .attribute(ServerWebExchange.class.getName(), this.exchange) + .build(); + // @formatter:on + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block(); + assertThat(authorizedClient).isNotNull(); + + ArgumentCaptor grantRequestCaptor = ArgumentCaptor + .forClass(OAuth2RefreshTokenGrantRequest.class); + verify(MOCK_RESPONSE_CLIENT).getTokenResponse(grantRequestCaptor.capture()); + + OAuth2RefreshTokenGrantRequest grantRequest = grantRequestCaptor.getValue(); + assertThat(grantRequest.getClientRegistration().getRegistrationId()) + .isEqualTo(clientRegistration.getRegistrationId()); + assertThat(grantRequest.getGrantType()).isEqualTo(AuthorizationGrantType.REFRESH_TOKEN); + assertThat(grantRequest.getAccessToken()).isEqualTo(existingAuthorizedClient.getAccessToken()); + assertThat(grantRequest.getRefreshToken()).isEqualTo(existingAuthorizedClient.getRefreshToken()); + } + + @Test + public void authorizeWhenClientCredentialsAccessTokenResponseClientBeanThenUsed() { + this.spring.register(CustomAccessTokenResponseClientsConfig.class).autowire(); + testClientCredentialsGrant(); + } + + @Test + public void authorizeWhenClientCredentialsAuthorizedClientProviderBeanThenUsed() { + this.spring.register(CustomAuthorizedClientProvidersConfig.class).autowire(); + testClientCredentialsGrant(); + } + + private void testClientCredentialsGrant() { + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(MOCK_RESPONSE_CLIENT.getTokenResponse(any(OAuth2ClientCredentialsGrantRequest.class))) + .willReturn(Mono.just(accessTokenResponse)); + + TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", null, "ROLE_USER"); + ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId("github") + .block(); + assertThat(clientRegistration).isNotNull(); + // @formatter:off + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(clientRegistration.getRegistrationId()) + .principal(authentication) + .attribute(ServerWebExchange.class.getName(), this.exchange) + .build(); + // @formatter:on + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block(); + assertThat(authorizedClient).isNotNull(); + + ArgumentCaptor grantRequestCaptor = ArgumentCaptor + .forClass(OAuth2ClientCredentialsGrantRequest.class); + verify(MOCK_RESPONSE_CLIENT).getTokenResponse(grantRequestCaptor.capture()); + + OAuth2ClientCredentialsGrantRequest grantRequest = grantRequestCaptor.getValue(); + assertThat(grantRequest.getClientRegistration().getRegistrationId()) + .isEqualTo(clientRegistration.getRegistrationId()); + assertThat(grantRequest.getGrantType()).isEqualTo(AuthorizationGrantType.CLIENT_CREDENTIALS); + } + + @Test + public void authorizeWhenPasswordAccessTokenResponseClientBeanThenUsed() { + this.spring.register(CustomAccessTokenResponseClientsConfig.class).autowire(); + testPasswordGrant(); + } + + @Test + public void authorizeWhenPasswordAuthorizedClientProviderBeanThenUsed() { + this.spring.register(CustomAuthorizedClientProvidersConfig.class).autowire(); + testPasswordGrant(); + } + + private void testPasswordGrant() { + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(MOCK_RESPONSE_CLIENT.getTokenResponse(any(OAuth2PasswordGrantRequest.class))) + .willReturn(Mono.just(accessTokenResponse)); + + TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", "password", "ROLE_USER"); + ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId("facebook") + .block(); + assertThat(clientRegistration).isNotNull(); + MockServerHttpRequest request = MockServerHttpRequest.post("/") + .contentType(MediaType.APPLICATION_FORM_URLENCODED) + .body("username=user&password=password"); + this.exchange = MockServerWebExchange.builder(request).build(); + // @formatter:off + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(clientRegistration.getRegistrationId()) + .principal(authentication) + .attribute(ServerWebExchange.class.getName(), this.exchange) + .build(); + // @formatter:on + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block(); + assertThat(authorizedClient).isNotNull(); + + ArgumentCaptor grantRequestCaptor = ArgumentCaptor + .forClass(OAuth2PasswordGrantRequest.class); + verify(MOCK_RESPONSE_CLIENT).getTokenResponse(grantRequestCaptor.capture()); + + OAuth2PasswordGrantRequest grantRequest = grantRequestCaptor.getValue(); + assertThat(grantRequest.getClientRegistration().getRegistrationId()) + .isEqualTo(clientRegistration.getRegistrationId()); + assertThat(grantRequest.getGrantType()).isEqualTo(AuthorizationGrantType.PASSWORD); + assertThat(grantRequest.getUsername()).isEqualTo("user"); + assertThat(grantRequest.getPassword()).isEqualTo("password"); + } + + @Test + public void authorizeWhenJwtBearerAccessTokenResponseClientBeanThenUsed() { + this.spring.register(CustomAccessTokenResponseClientsConfig.class).autowire(); + testJwtBearerGrant(); + } + + @Test + public void authorizeWhenJwtBearerAuthorizedClientProviderBeanThenUsed() { + this.spring.register(CustomAuthorizedClientProvidersConfig.class).autowire(); + testJwtBearerGrant(); + } + + private void testJwtBearerGrant() { + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(MOCK_RESPONSE_CLIENT.getTokenResponse(any(JwtBearerGrantRequest.class))) + .willReturn(Mono.just(accessTokenResponse)); + + JwtAuthenticationToken authentication = new JwtAuthenticationToken(getJwt()); + ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId("okta").block(); + assertThat(clientRegistration).isNotNull(); + // @formatter:off + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(clientRegistration.getRegistrationId()) + .principal(authentication) + .attribute(ServerWebExchange.class.getName(), this.exchange) + .build(); + // @formatter:on + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block(); + assertThat(authorizedClient).isNotNull(); + + ArgumentCaptor grantRequestCaptor = ArgumentCaptor.forClass(JwtBearerGrantRequest.class); + verify(MOCK_RESPONSE_CLIENT).getTokenResponse(grantRequestCaptor.capture()); + + JwtBearerGrantRequest grantRequest = grantRequestCaptor.getValue(); + assertThat(grantRequest.getClientRegistration().getRegistrationId()) + .isEqualTo(clientRegistration.getRegistrationId()); + assertThat(grantRequest.getGrantType()).isEqualTo(AuthorizationGrantType.JWT_BEARER); + assertThat(grantRequest.getJwt().getSubject()).isEqualTo("user"); + } + + @Test + public void authorizeWhenTokenExchangeAccessTokenResponseClientBeanThenUsed() { + this.spring.register(CustomAccessTokenResponseClientsConfig.class).autowire(); + testTokenExchangeGrant(); + } + + @Test + public void authorizeWhenTokenExchangeAuthorizedClientProviderBeanThenUsed() { + this.spring.register(CustomAuthorizedClientProvidersConfig.class).autowire(); + testTokenExchangeGrant(); + } + + private void testTokenExchangeGrant() { + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(MOCK_RESPONSE_CLIENT.getTokenResponse(any(TokenExchangeGrantRequest.class))) + .willReturn(Mono.just(accessTokenResponse)); + + JwtAuthenticationToken authentication = new JwtAuthenticationToken(getJwt()); + ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId("auth0").block(); + assertThat(clientRegistration).isNotNull(); + // @formatter:off + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(clientRegistration.getRegistrationId()) + .principal(authentication) + .attribute(ServerWebExchange.class.getName(), this.exchange) + .build(); + // @formatter:on + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block(); + assertThat(authorizedClient).isNotNull(); + + ArgumentCaptor grantRequestCaptor = ArgumentCaptor + .forClass(TokenExchangeGrantRequest.class); + verify(MOCK_RESPONSE_CLIENT).getTokenResponse(grantRequestCaptor.capture()); + + TokenExchangeGrantRequest grantRequest = grantRequestCaptor.getValue(); + assertThat(grantRequest.getClientRegistration().getRegistrationId()) + .isEqualTo(clientRegistration.getRegistrationId()); + assertThat(grantRequest.getGrantType()).isEqualTo(AuthorizationGrantType.TOKEN_EXCHANGE); + assertThat(grantRequest.getSubjectToken()).isEqualTo(authentication.getToken()); + } + + private static OAuth2AccessToken getExpiredAccessToken() { + Instant expiresAt = Instant.now().minusSeconds(60); + Instant issuedAt = expiresAt.minus(Duration.ofDays(1)); + return new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "scopes", issuedAt, expiresAt, + new HashSet<>(Arrays.asList("read", "write"))); + } + + private static Jwt getJwt() { + Instant issuedAt = Instant.now(); + return new Jwt("token", issuedAt, issuedAt.plusSeconds(300), + Collections.singletonMap(JoseHeaderNames.ALG, "RS256"), + Collections.singletonMap(JwtClaimNames.SUB, "user")); + } + + @Configuration + @EnableWebFluxSecurity + static class MinimalOAuth2ClientConfig extends OAuth2ClientBaseConfig { + + @Bean + ServerOAuth2AuthorizedClientRepository authorizedClientRepository() { + return new WebSessionServerOAuth2AuthorizedClientRepository(); + } + + } + + @Configuration + @EnableWebFluxSecurity + static class CustomAuthorizedClientServiceConfig extends OAuth2ClientBaseConfig { + + @Bean + ReactiveOAuth2AuthorizedClientService authorizedClientService() { + ReactiveOAuth2AuthorizedClientService authorizedClientService = mock( + ReactiveOAuth2AuthorizedClientService.class); + given(authorizedClientService.loadAuthorizedClient(anyString(), anyString())).willReturn(Mono.empty()); + return authorizedClientService; + } + + } + + @Configuration + @EnableWebFluxSecurity + static class CustomAccessTokenResponseClientsConfig extends MinimalOAuth2ClientConfig { + + @Bean + ReactiveOAuth2AccessTokenResponseClient authorizationCodeAccessTokenResponseClient() { + return new MockAccessTokenResponseClient<>(); + } + + @Bean + ReactiveOAuth2AccessTokenResponseClient refreshTokenTokenAccessResponseClient() { + return new MockAccessTokenResponseClient<>(); + } + + @Bean + ReactiveOAuth2AccessTokenResponseClient clientCredentialsAccessTokenResponseClient() { + return new MockAccessTokenResponseClient<>(); + } + + @Bean + ReactiveOAuth2AccessTokenResponseClient passwordAccessTokenResponseClient() { + return new MockAccessTokenResponseClient<>(); + } + + @Bean + ReactiveOAuth2AccessTokenResponseClient jwtBearerAccessTokenResponseClient() { + return new MockAccessTokenResponseClient<>(); + } + + @Bean + ReactiveOAuth2AccessTokenResponseClient tokenExchangeAccessTokenResponseClient() { + return new MockAccessTokenResponseClient<>(); + } + + } + + @Configuration + @EnableWebFluxSecurity + static class CustomAuthorizedClientProvidersConfig extends MinimalOAuth2ClientConfig { + + @Bean + AuthorizationCodeReactiveOAuth2AuthorizedClientProvider authorizationCode() { + return spy(new AuthorizationCodeReactiveOAuth2AuthorizedClientProvider()); + } + + @Bean + RefreshTokenReactiveOAuth2AuthorizedClientProvider refreshToken() { + RefreshTokenReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = new RefreshTokenReactiveOAuth2AuthorizedClientProvider(); + authorizedClientProvider.setAccessTokenResponseClient(new MockAccessTokenResponseClient<>()); + return authorizedClientProvider; + } + + @Bean + ClientCredentialsReactiveOAuth2AuthorizedClientProvider clientCredentials() { + ClientCredentialsReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = new ClientCredentialsReactiveOAuth2AuthorizedClientProvider(); + authorizedClientProvider.setAccessTokenResponseClient(new MockAccessTokenResponseClient<>()); + return authorizedClientProvider; + } + + @Bean + PasswordReactiveOAuth2AuthorizedClientProvider password() { + PasswordReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = new PasswordReactiveOAuth2AuthorizedClientProvider(); + authorizedClientProvider.setAccessTokenResponseClient(new MockAccessTokenResponseClient<>()); + return authorizedClientProvider; + } + + @Bean + JwtBearerReactiveOAuth2AuthorizedClientProvider jwtBearer() { + JwtBearerReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = new JwtBearerReactiveOAuth2AuthorizedClientProvider(); + authorizedClientProvider.setAccessTokenResponseClient(new MockAccessTokenResponseClient<>()); + return authorizedClientProvider; + } + + @Bean + TokenExchangeReactiveOAuth2AuthorizedClientProvider tokenExchange() { + TokenExchangeReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = new TokenExchangeReactiveOAuth2AuthorizedClientProvider(); + authorizedClientProvider.setAccessTokenResponseClient(new MockAccessTokenResponseClient<>()); + return authorizedClientProvider; + } + + } + + abstract static class OAuth2ClientBaseConfig { + + @Bean + ReactiveClientRegistrationRepository clientRegistrationRepository() { + // @formatter:off + return new InMemoryReactiveClientRegistrationRepository( + CommonOAuth2Provider.GOOGLE.getBuilder("google") + .clientId("google-client-id") + .clientSecret("google-client-secret") + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .build(), + CommonOAuth2Provider.GITHUB.getBuilder("github") + .clientId("github-client-id") + .clientSecret("github-client-secret") + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) + .build(), + CommonOAuth2Provider.FACEBOOK.getBuilder("facebook") + .clientId("facebook-client-id") + .clientSecret("facebook-client-secret") + .authorizationGrantType(AuthorizationGrantType.PASSWORD) + .build(), + CommonOAuth2Provider.OKTA.getBuilder("okta") + .clientId("okta-client-id") + .clientSecret("okta-client-secret") + .authorizationGrantType(AuthorizationGrantType.JWT_BEARER) + .build(), + ClientRegistration.withRegistrationId("auth0") + .clientName("Auth0") + .clientId("auth0-client-id") + .clientSecret("auth0-client-secret") + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC) + .authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE) + .scope("user.read", "user.write") + .build()); + // @formatter:on + } + + @Bean + Consumer authorizedClientManagerConsumer() { + return (authorizedClientManager) -> authorizedClientManager + .setContextAttributesMapper((authorizeRequest) -> { + ServerWebExchange exchange = Objects + .requireNonNull(authorizeRequest.getAttribute(ServerWebExchange.class.getName())); + return exchange.getFormData().map((parameters) -> { + String username = parameters.getFirst(OAuth2ParameterNames.USERNAME); + String password = parameters.getFirst(OAuth2ParameterNames.PASSWORD); + + Map attributes = Collections.emptyMap(); + if (StringUtils.hasText(username) && StringUtils.hasText(password)) { + attributes = new HashMap<>(); + attributes.put(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, username); + attributes.put(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, password); + } + + return attributes; + }); + }); + + } + + } + + private static class MockAccessTokenResponseClient + implements ReactiveOAuth2AccessTokenResponseClient { + + @Override + public Mono getTokenResponse(T grantRequest) { + return MOCK_RESPONSE_CLIENT.getTokenResponse(grantRequest); + } + + } + +}