Skip to content

Commit

Permalink
Add OAuth2AuthorizedClientManager Registrar
Browse files Browse the repository at this point in the history
  • Loading branch information
jgrandja authored and sjohnr committed Aug 9, 2023
1 parent 779d472 commit f3d90b3
Show file tree
Hide file tree
Showing 2 changed files with 392 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,17 +18,40 @@

import java.util.List;

import org.springframework.beans.BeanMetadataElement;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanFactoryAware;
import org.springframework.beans.factory.BeanFactoryUtils;
import org.springframework.beans.factory.ListableBeanFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.config.RuntimeBeanReference;
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
import org.springframework.beans.factory.support.ManagedList;
import org.springframework.context.annotation.AnnotationBeanNameGenerator;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import;
import org.springframework.context.annotation.ImportSelector;
import org.springframework.core.ResolvableType;
import org.springframework.core.type.AnnotationMetadata;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.oauth2.client.AuthorizationCodeOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.DelegatingOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder;
import org.springframework.security.oauth2.client.PasswordOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
Expand All @@ -48,7 +71,8 @@
* @since 5.1
* @see OAuth2ImportSelector
*/
@Import(OAuth2ClientConfiguration.OAuth2ClientWebMvcImportSelector.class)
@Import({ OAuth2ClientConfiguration.OAuth2ClientWebMvcImportSelector.class,
OAuth2ClientConfiguration.OAuth2AuthorizedClientManagerConfiguration.class })
final class OAuth2ClientConfiguration {

private static final boolean webMvcPresent;
Expand All @@ -65,8 +89,22 @@ public String[] selectImports(AnnotationMetadata importingClassMetadata) {
if (!webMvcPresent) {
return new String[0];
}
return new String[] { "org.springframework.security.config.annotation.web.configuration."
+ "OAuth2ClientConfiguration.OAuth2ClientWebMvcSecurityConfiguration" };
return new String[] {
OAuth2ClientConfiguration.class.getName() + ".OAuth2ClientWebMvcSecurityConfiguration" };
}

}

/**
* @author Joe Grandja
* @since 6.2.0
*/
@Configuration(proxyBeanMethods = false)
static class OAuth2AuthorizedClientManagerConfiguration {

@Bean
OAuth2AuthorizedClientManagerRegistrar authorizedClientManagerRegistrar() {
return new OAuth2AuthorizedClientManagerRegistrar();
}

}
Expand Down Expand Up @@ -160,4 +198,136 @@ private OAuth2AuthorizedClientManager getAuthorizedClientManager() {

}

/**
* A registrar for registering the default {@link OAuth2AuthorizedClientManager} bean
* definition, if not already present.
*
* @author Joe Grandja
* @since 6.2.0
*/
static class OAuth2AuthorizedClientManagerRegistrar
implements BeanDefinitionRegistryPostProcessor, BeanFactoryAware {

private final AnnotationBeanNameGenerator beanNameGenerator = new AnnotationBeanNameGenerator();

private BeanFactory beanFactory;

@Override
public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException {
String[] authorizedClientManagerBeanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors(
(ListableBeanFactory) this.beanFactory, OAuth2AuthorizedClientManager.class, true, true);
if (authorizedClientManagerBeanNames.length != 0) {
return;
}

String[] clientRegistrationRepositoryBeanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors(
(ListableBeanFactory) this.beanFactory, ClientRegistrationRepository.class, true, true);
String[] authorizedClientRepositoryBeanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors(
(ListableBeanFactory) this.beanFactory, OAuth2AuthorizedClientRepository.class, true, true);
if (clientRegistrationRepositoryBeanNames.length != 1 || authorizedClientRepositoryBeanNames.length != 1) {
return;
}

BeanDefinition beanDefinition = BeanDefinitionBuilder
.genericBeanDefinition(DefaultOAuth2AuthorizedClientManager.class)
.addConstructorArgReference(clientRegistrationRepositoryBeanNames[0])
.addConstructorArgReference(authorizedClientRepositoryBeanNames[0])
.addPropertyValue("authorizedClientProvider", getAuthorizedClientProvider()).getBeanDefinition();

registry.registerBeanDefinition(this.beanNameGenerator.generateBeanName(beanDefinition, registry),
beanDefinition);
}

@Override
public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
}

private BeanDefinition getAuthorizedClientProvider() {
ManagedList<Object> authorizedClientProviders = new ManagedList<>();
authorizedClientProviders.add(getAuthorizationCodeAuthorizedClientProvider());
authorizedClientProviders.add(getRefreshTokenAuthorizedClientProvider());
authorizedClientProviders.add(getClientCredentialsAuthorizedClientProvider());
authorizedClientProviders.add(getPasswordAuthorizedClientProvider());
return BeanDefinitionBuilder.genericBeanDefinition(DelegatingOAuth2AuthorizedClientProvider.class)
.addConstructorArgValue(authorizedClientProviders).getBeanDefinition();
}

private BeanMetadataElement getAuthorizationCodeAuthorizedClientProvider() {
String[] beanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors(
(ListableBeanFactory) this.beanFactory, AuthorizationCodeOAuth2AuthorizedClientProvider.class, true,
true);
if (beanNames.length == 1) {
return new RuntimeBeanReference(beanNames[0]);
}

return BeanDefinitionBuilder.genericBeanDefinition(AuthorizationCodeOAuth2AuthorizedClientProvider.class)
.getBeanDefinition();
}

private BeanMetadataElement getRefreshTokenAuthorizedClientProvider() {
String[] beanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors(
(ListableBeanFactory) this.beanFactory, RefreshTokenOAuth2AuthorizedClientProvider.class, true,
true);
if (beanNames.length == 1) {
return new RuntimeBeanReference(beanNames[0]);
}

BeanDefinitionBuilder beanDefinitionBuilder = BeanDefinitionBuilder
.genericBeanDefinition(RefreshTokenOAuth2AuthorizedClientProvider.class);
ResolvableType resolvableType = ResolvableType.forClassWithGenerics(OAuth2AccessTokenResponseClient.class,
OAuth2RefreshTokenGrantRequest.class);
beanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors((ListableBeanFactory) this.beanFactory,
resolvableType, true, true);
if (beanNames.length == 1) {
beanDefinitionBuilder.addPropertyReference("accessTokenResponseClient", beanNames[0]);
}
return beanDefinitionBuilder.getBeanDefinition();
}

private BeanMetadataElement getClientCredentialsAuthorizedClientProvider() {
String[] beanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors(
(ListableBeanFactory) this.beanFactory, ClientCredentialsOAuth2AuthorizedClientProvider.class, true,
true);
if (beanNames.length == 1) {
return new RuntimeBeanReference(beanNames[0]);
}

BeanDefinitionBuilder beanDefinitionBuilder = BeanDefinitionBuilder
.genericBeanDefinition(ClientCredentialsOAuth2AuthorizedClientProvider.class);
ResolvableType resolvableType = ResolvableType.forClassWithGenerics(OAuth2AccessTokenResponseClient.class,
OAuth2ClientCredentialsGrantRequest.class);
beanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors((ListableBeanFactory) this.beanFactory,
resolvableType, true, true);
if (beanNames.length == 1) {
beanDefinitionBuilder.addPropertyReference("accessTokenResponseClient", beanNames[0]);
}
return beanDefinitionBuilder.getBeanDefinition();
}

private BeanMetadataElement getPasswordAuthorizedClientProvider() {
String[] beanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors(
(ListableBeanFactory) this.beanFactory, PasswordOAuth2AuthorizedClientProvider.class, true, true);
if (beanNames.length == 1) {
return new RuntimeBeanReference(beanNames[0]);
}

BeanDefinitionBuilder beanDefinitionBuilder = BeanDefinitionBuilder
.genericBeanDefinition(PasswordOAuth2AuthorizedClientProvider.class);
ResolvableType resolvableType = ResolvableType.forClassWithGenerics(OAuth2AccessTokenResponseClient.class,
OAuth2PasswordGrantRequest.class);
beanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors((ListableBeanFactory) this.beanFactory,
resolvableType, true, true);
if (beanNames.length == 1) {
beanDefinitionBuilder.addPropertyReference("accessTokenResponseClient", beanNames[0]);
}
return beanDefinitionBuilder.getBeanDefinition();
}

@Override
public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
this.beanFactory = beanFactory;
}

}

}
Loading

0 comments on commit f3d90b3

Please sign in to comment.