Skip to content

Commit

Permalink
ISPN-15916 Certificate reloading
Browse files Browse the repository at this point in the history
  • Loading branch information
tristantarrant authored and ryanemerson committed Aug 7, 2024
1 parent 10a429d commit b91ab65
Show file tree
Hide file tree
Showing 62 changed files with 1,230 additions and 427 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import static org.infinispan.client.hotrod.impl.Util.wrapBytes;
import static org.infinispan.client.hotrod.logging.Log.HOTROD;

import java.io.File;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.security.Provider;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
Expand Down Expand Up @@ -35,6 +37,7 @@
import org.infinispan.client.hotrod.configuration.ClusterConfiguration;
import org.infinispan.client.hotrod.configuration.Configuration;
import org.infinispan.client.hotrod.configuration.ServerConfiguration;
import org.infinispan.client.hotrod.configuration.SslConfiguration;
import org.infinispan.client.hotrod.event.impl.ClientListenerNotifier;
import org.infinispan.client.hotrod.impl.ClientTopology;
import org.infinispan.client.hotrod.impl.ConfigurationProperties;
Expand All @@ -51,15 +54,23 @@
import org.infinispan.client.hotrod.impl.transport.netty.ChannelPool.ChannelEventType;
import org.infinispan.client.hotrod.logging.Log;
import org.infinispan.client.hotrod.logging.LogFactory;
import org.infinispan.commons.CacheConfigurationException;
import org.infinispan.commons.io.FileWatcher;
import org.infinispan.commons.marshall.Marshaller;
import org.infinispan.commons.marshall.WrappedByteArray;
import org.infinispan.commons.marshall.WrappedBytes;
import org.infinispan.commons.util.ProcessorInfo;
import org.infinispan.commons.util.SslContextFactory;

import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.handler.ssl.ClientAuth;
import io.netty.handler.ssl.IdentityCipherSuiteFilter;
import io.netty.handler.ssl.JdkSslContext;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.resolver.AddressResolverGroup;
import io.netty.resolver.dns.DnsNameResolverBuilder;
import io.netty.resolver.dns.RoundRobinDnsAddressResolverGroup;
Expand Down Expand Up @@ -102,6 +113,8 @@ public class ChannelFactory {
private final Set<SocketAddress> failedServers = new HashSet<>();
private final CodecHolder codecHolder;
private AddressResolverGroup<?> dnsResolver;
private SslContext sslContext;
private FileWatcher watcher;

public ChannelFactory(CodecHolder codecHolder) {
this.codecHolder = codecHolder;
Expand All @@ -128,6 +141,7 @@ public void start(Configuration configuration, Marshaller marshaller, ExecutorSe
.ttl(configuration.dnsResolverMinTTL(), configuration.dnsResolverMaxTTL())
.negativeTtl(configuration.dnsResolverNegativeTTL());
this.dnsResolver = new RoundRobinDnsAddressResolverGroup(builder);
this.sslContext = initSslContext();

List<InetSocketAddress> initialServers = new ArrayList<>();
for (ServerConfiguration server : configuration.servers()) {
Expand Down Expand Up @@ -172,6 +186,64 @@ public void start(Configuration configuration, Marshaller marshaller, ExecutorSe
pingServersIgnoreException();
}


private SslContext initSslContext() {
SslConfiguration ssl = configuration.security().ssl();
if (!ssl.enabled()) {
return null;
} else if (ssl.sslContext() == null) {
this.watcher = new FileWatcher();
SslContextBuilder builder = SslContextBuilder.forClient();
try {
if (ssl.keyStoreFileName() != null) {
builder.keyManager(new SslContextFactory()
.keyStoreFileName(ssl.keyStoreFileName())
.keyStoreType(ssl.keyStoreType())
.keyStorePassword(ssl.keyStorePassword())
.keyAlias(ssl.keyAlias())
.classLoader(configuration.classLoader())
.provider(ssl.provider())
.watcher(watcher)
.build().keyManager());
}
if (ssl.trustStoreFileName() != null) {
if ("pem".equalsIgnoreCase(ssl.trustStoreType())) {
builder.trustManager(new File(ssl.trustStoreFileName()));
} else {
builder.trustManager(new SslContextFactory()
.trustStoreFileName(ssl.trustStoreFileName())
.trustStoreType(ssl.trustStoreType())
.trustStorePassword(ssl.trustStorePassword())
.classLoader(configuration.classLoader())
.provider(ssl.provider())
.watcher(watcher)
.build()
.trustManager());
}
}
if (ssl.trustStorePath() != null) {
builder.trustManager(new File(ssl.trustStorePath()));
}
if (ssl.protocol() != null) {
builder.protocols(ssl.protocol());
}
if (ssl.ciphers() != null) {
builder.ciphers(ssl.ciphers());
}
if (ssl.provider() != null) {
Provider provider = SslContextFactory.findProvider(ssl.provider(), SslContext.class.getSimpleName(), "TLS");
builder.sslContextProvider(provider);
}
return builder.build();
} catch (Exception e) {
throw new CacheConfigurationException(e);
}
} else {
return new JdkSslContext(ssl.sslContext(), true, null, IdentityCipherSuiteFilter.INSTANCE,
null, ClientAuth.NONE, null, false);
}
}

public Codec getNegotiatedCodec() {
return codecHolder.getCodec();
}
Expand Down Expand Up @@ -212,7 +284,7 @@ private ChannelPool newPool(SocketAddress address) {
}

public ChannelInitializer createChannelInitializer(SocketAddress address, Bootstrap bootstrap) {
return new ChannelInitializer(bootstrap, address, operationsFactory, configuration, this, topologyInfo.getCluster());
return new ChannelInitializer(bootstrap, address, operationsFactory, configuration, this, topologyInfo.getCluster(), sslContext);
}

protected ChannelPool createChannelPool(Bootstrap bootstrap, ChannelInitializer channelInitializer, SocketAddress address) {
Expand Down Expand Up @@ -250,6 +322,9 @@ private void pingServersIgnoreException() {

public void destroy() {
try {
if (watcher != null) {
watcher.stop();
}
channelPoolMap.values().forEach(ChannelPool::close);
eventLoopGroup.shutdownGracefully(0, 0, TimeUnit.MILLISECONDS).get();
executorService.shutdownNow();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package org.infinispan.client.hotrod.impl.transport.netty;

import java.io.File;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.security.Principal;
Expand Down Expand Up @@ -31,20 +30,14 @@
import org.infinispan.client.hotrod.impl.topology.ClusterInfo;
import org.infinispan.client.hotrod.logging.Log;
import org.infinispan.client.hotrod.logging.LogFactory;
import org.infinispan.commons.CacheConfigurationException;
import org.infinispan.commons.util.SaslUtils;
import org.infinispan.commons.util.SslContextFactory;
import org.infinispan.commons.util.Util;

import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.handler.ssl.ClientAuth;
import io.netty.handler.ssl.IdentityCipherSuiteFilter;
import io.netty.handler.ssl.JdkSslContext;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.timeout.IdleStateHandler;

Expand All @@ -59,6 +52,7 @@ class ChannelInitializer extends io.netty.channel.ChannelInitializer<Channel> {
private final ClusterInfo cluster;
private ChannelPool channelPool;
private volatile boolean isFirstPing = true;
private final SslContext sslContext;

private static final Provider[] SECURITY_PROVIDERS;

Expand All @@ -80,13 +74,14 @@ class ChannelInitializer extends io.netty.channel.ChannelInitializer<Channel> {
SECURITY_PROVIDERS = providers.toArray(new Provider[0]);
}

ChannelInitializer(Bootstrap bootstrap, SocketAddress unresolvedAddress, OperationsFactory operationsFactory, Configuration configuration, ChannelFactory channelFactory, ClusterInfo cluster) {
ChannelInitializer(Bootstrap bootstrap, SocketAddress unresolvedAddress, OperationsFactory operationsFactory, Configuration configuration, ChannelFactory channelFactory, ClusterInfo cluster, SslContext sslContext) {
this.bootstrap = bootstrap;
this.unresolvedAddress = unresolvedAddress;
this.operationsFactory = operationsFactory;
this.configuration = configuration;
this.channelFactory = channelFactory;
this.cluster = cluster;
this.sslContext = sslContext;
}

CompletableFuture<Channel> createChannel() {
Expand Down Expand Up @@ -132,55 +127,6 @@ protected void initChannel(Channel channel) throws Exception {

private void initSsl(Channel channel) {
SslConfiguration ssl = configuration.security().ssl();
SslContext sslContext;
if (ssl.sslContext() == null) {
SslContextBuilder builder = SslContextBuilder.forClient();
try {
if (ssl.keyStoreFileName() != null) {
builder.keyManager(new SslContextFactory()
.keyStoreFileName(ssl.keyStoreFileName())
.keyStoreType(ssl.keyStoreType())
.keyStorePassword(ssl.keyStorePassword())
.keyAlias(ssl.keyAlias())
.classLoader(configuration.classLoader())
.provider(ssl.provider())
.getKeyManagerFactory());
}
if (ssl.trustStoreFileName() != null) {
if ("pem".equalsIgnoreCase(ssl.trustStoreType())) {
builder.trustManager(new File(ssl.trustStoreFileName()));
} else {
builder.trustManager(new SslContextFactory()
.trustStoreFileName(ssl.trustStoreFileName())
.trustStoreType(ssl.trustStoreType())
.trustStorePassword(ssl.trustStorePassword())
.classLoader(configuration.classLoader())
.provider(ssl.provider())
.getTrustManagerFactory());
}
}
if (ssl.trustStorePath() != null) {
builder.trustManager(new File(ssl.trustStorePath()));
}
if (ssl.protocol() != null) {
builder.protocols(ssl.protocol());
}
if (ssl.ciphers() != null) {
builder.ciphers(ssl.ciphers());
}
if (ssl.provider() != null) {
Provider provider = SslContextFactory.findProvider(ssl.provider(), SslContext.class.getSimpleName(), "TLS");
builder.sslContextProvider(provider);
}
sslContext = builder.build();
} catch (Exception e) {
throw new CacheConfigurationException(e);
}
} else {
sslContext = new JdkSslContext(ssl.sslContext(), true, null, IdentityCipherSuiteFilter.INSTANCE,
null, ClientAuth.NONE, null, false);
}

SslHandler sslHandler = sslContext.newHandler(channel.alloc(), ssl.sniHostName(), -1);
String sniHostName;
if (cluster != null && cluster.getSniHostName() != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public void testLoadTrustStore() {
.keyStorePassword(TestCertificates.KEY_PASSWORD)
.trustStoreFileName(truststoreFileName)
.trustStoreType(TestCertificates.KEYSTORE_TYPE)
.trustStorePassword(TestCertificates.KEY_PASSWORD).getContext();
.trustStorePassword(TestCertificates.KEY_PASSWORD).build().sslContext();

assertNotNull(context);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public TestChannelFactory(CodecHolder codecHolder) {

@Override
public ChannelInitializer createChannelInitializer(SocketAddress address, Bootstrap bootstrap) {
return new ChannelInitializer(bootstrap, address, getOperationsFactory(), getConfiguration(), this, null) {
return new ChannelInitializer(bootstrap, address, getOperationsFactory(), getConfiguration(), this, null, null) {
@Override
protected void initChannel(Channel channel) throws Exception {
super.initChannel(channel);
Expand Down
Loading

0 comments on commit b91ab65

Please sign in to comment.