Skip to content

Commit

Permalink
refactor to use apache http5 library
Browse files Browse the repository at this point in the history
Signed-off-by: Maciej Mierzwa <dev.maciej.mierzwa@gmail.com>
  • Loading branch information
MaciejMierzwa committed Jan 10, 2024
1 parent b996eb1 commit 67bd6c3
Show file tree
Hide file tree
Showing 5 changed files with 313 additions and 151 deletions.
157 changes: 157 additions & 0 deletions src/main/java/com/amazon/dlic/auth/http/saml/HTTPMetadataResolver.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package com.amazon.dlic.auth.http.saml;

import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Timer;

import org.apache.hc.client5.http.classic.HttpClient;
import org.apache.hc.client5.http.classic.methods.HttpGet;
import org.apache.hc.client5.http.protocol.HttpClientContext;
import org.apache.hc.core5.http.ClassicHttpResponse;
import org.apache.hc.core5.http.Header;
import org.apache.hc.core5.http.HttpException;
import org.apache.hc.core5.http.HttpHeaders;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.apache.hc.core5.http.HttpStatus;

import net.shibboleth.utilities.java.support.resolver.ResolverException;
import org.opensaml.saml.metadata.resolver.impl.AbstractReloadingMetadataResolver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class HTTPMetadataResolver extends AbstractReloadingMetadataResolver {
private final Logger log = LoggerFactory.getLogger(HTTPMetadataResolver.class);
private HttpClient httpClient;
private URI metadataURI;
private String cachedMetadataETag;
private String cachedMetadataLastModified;

public HTTPMetadataResolver(final HttpClient client, final String metadataURL) throws ResolverException {
this(null, client, metadataURL);
}

public HTTPMetadataResolver(final Timer backgroundTaskTimer, final HttpClient client, final String metadataURL)
throws ResolverException {
super(backgroundTaskTimer);

if (client == null) {
throw new ResolverException("HTTP client may not be null");
}
httpClient = client;

try {
metadataURI = new URI(metadataURL);
} catch (final URISyntaxException e) {
throw new ResolverException("Illegal URL syntax", e);
}
}

public String getMetadataURI() {
return metadataURI.toASCIIString();
}

@Override
protected void doDestroy() {
if (httpClient instanceof AutoCloseable) {
try {
((AutoCloseable) httpClient).close();
} catch (final Exception e) {
log.error("Error closing HTTP client", e);
}
}
httpClient = null;
metadataURI = null;
cachedMetadataETag = null;
cachedMetadataLastModified = null;

super.doDestroy();
}

@Override
protected String getMetadataIdentifier() {
return metadataURI.toString();
}

@Override
protected byte[] fetchMetadata() throws ResolverException {
final HttpGet httpGet = buildHttpGet();
final HttpClientContext context = HttpClientContext.create();

try {
log.debug("{} Attempting to fetch metadata document from '{}'", getLogPrefix(), metadataURI);
return httpClient.execute(httpGet, context, response -> {
final int httpStatusCode = response.getCode();
if (httpStatusCode == HttpStatus.SC_NOT_MODIFIED) {
log.debug("{} Metadata document from '{}' has not changed since last retrieval", getLogPrefix(), getMetadataURI());
return null;
}
if (httpStatusCode != HttpStatus.SC_OK) {
final String errMsg = "Non-ok status code " + httpStatusCode + " returned from remote metadata source " + metadataURI;
log.error("{} " + errMsg, getLogPrefix());
throw new HttpException(errMsg);
}

processConditionalRetrievalHeaders(response);
try {
return getMetadataBytesFromResponse(response);
} catch (ResolverException e) {
final String errMsg = "Error retrieving metadata from " + metadataURI;
throw new HttpException(errMsg, e);
}
});
} catch (final IOException e) {
final String errMsg = "Error retrieving metadata from " + metadataURI;
log.error("{} {}: {}", getLogPrefix(), errMsg, e.getMessage());
throw new ResolverException(errMsg, e);
}
}

protected HttpGet buildHttpGet() {
final HttpGet getMethod = new HttpGet(getMetadataURI());

if (cachedMetadataETag != null) {
getMethod.setHeader(HttpHeaders.IF_NONE_MATCH, cachedMetadataETag);
}
if (cachedMetadataLastModified != null) {
getMethod.setHeader(HttpHeaders.IF_MODIFIED_SINCE, cachedMetadataLastModified);
}

return getMethod;
}

protected void processConditionalRetrievalHeaders(final ClassicHttpResponse response) {
Header httpHeader = response.getFirstHeader(HttpHeaders.ETAG);
if (httpHeader != null) {
cachedMetadataETag = httpHeader.getValue();
}

httpHeader = response.getFirstHeader(HttpHeaders.LAST_MODIFIED);
if (httpHeader != null) {
cachedMetadataLastModified = httpHeader.getValue();
}
}

protected byte[] getMetadataBytesFromResponse(final ClassicHttpResponse response) throws ResolverException {
log.debug("{} Attempting to extract metadata from response to request for metadata from '{}'", getLogPrefix(), getMetadataURI());
try {
final InputStream ins = response.getEntity().getContent();
return inputstreamToByteArray(ins);
} catch (final IOException e) {
log.error("{} Unable to read response: {}", getLogPrefix(), e.getMessage());
throw new ResolverException("Unable to read response", e);
} finally {
// Make sure entity has been completely consumed.
EntityUtils.consumeQuietly(response.getEntity());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Strings;
import org.apache.commons.lang3.StringEscapeUtils;
import org.apache.http.HttpStatus;
import org.apache.hc.core5.http.HttpStatus;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,21 @@
import java.security.PrivilegedExceptionAction;
import java.time.Duration;

import org.apache.http.client.HttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.impl.client.HttpClients;
import org.apache.hc.client5.http.classic.HttpClient;
import org.apache.hc.client5.http.impl.classic.HttpClientBuilder;
import org.apache.hc.client5.http.impl.classic.HttpClients;
import org.apache.hc.client5.http.impl.io.BasicHttpClientConnectionManager;
import org.apache.hc.client5.http.socket.ConnectionSocketFactory;
import org.apache.hc.client5.http.ssl.SSLConnectionSocketFactory;
import org.apache.hc.core5.http.URIScheme;
import org.apache.hc.core5.http.config.Registry;
import org.apache.hc.core5.http.config.RegistryBuilder;

import org.opensearch.SpecialPermission;
import org.opensearch.common.settings.Settings;

import com.amazon.dlic.util.SettingsBasedSSLConfiguratorV4;
import com.amazon.dlic.util.SettingsBasedSSLConfiguratorV5;
import net.shibboleth.utilities.java.support.resolver.ResolverException;
import org.opensaml.saml.metadata.resolver.impl.HTTPMetadataResolver;

public class SamlHTTPMetadataResolver extends HTTPMetadataResolver {

Expand All @@ -38,10 +43,9 @@ public class SamlHTTPMetadataResolver extends HTTPMetadataResolver {
}

@Override
@SuppressWarnings("removal")
protected byte[] fetchMetadata() throws ResolverException {
try {
return AccessController.doPrivileged((PrivilegedExceptionAction<byte[]>) () -> SamlHTTPMetadataResolver.super.fetchMetadata());
return AccessController.doPrivileged((PrivilegedExceptionAction<byte[]>) SamlHTTPMetadataResolver.super::fetchMetadata);
} catch (PrivilegedActionException e) {

if (e.getCause() instanceof ResolverException) {
Expand All @@ -52,11 +56,10 @@ protected byte[] fetchMetadata() throws ResolverException {
}
}

private static SettingsBasedSSLConfiguratorV4.SSLConfig getSSLConfig(Settings settings, Path configPath) throws Exception {
return new SettingsBasedSSLConfiguratorV4(settings, configPath, "idp").buildSSLConfig();
private static SettingsBasedSSLConfiguratorV5.SSLConfig getSSLConfig(Settings settings, Path configPath) throws Exception {
return new SettingsBasedSSLConfiguratorV5(settings, configPath, "idp").buildSSLConfig();
}

@SuppressWarnings("removal")
private static HttpClient createHttpClient(Settings settings, Path configPath) throws Exception {
try {
final SecurityManager sm = System.getSecurityManager();
Expand Down Expand Up @@ -86,10 +89,16 @@ private static HttpClient createHttpClient0(Settings settings, Path configPath)

builder.useSystemProperties();

SettingsBasedSSLConfiguratorV4.SSLConfig sslConfig = getSSLConfig(settings, configPath);
SettingsBasedSSLConfiguratorV5.SSLConfig sslConfig = getSSLConfig(settings, configPath);

if (sslConfig != null) {
builder.setSSLSocketFactory(sslConfig.toSSLConnectionSocketFactory());
SSLConnectionSocketFactory sslConnectionSocketFactory = sslConfig.toSSLConnectionSocketFactory();
Registry<ConnectionSocketFactory> socketFactoryRegistry = RegistryBuilder.<ConnectionSocketFactory>create()
.register(URIScheme.HTTPS.id, sslConnectionSocketFactory)
.build();

BasicHttpClientConnectionManager connectionManager = new BasicHttpClientConnectionManager(socketFactoryRegistry);
builder.setConnectionManager(connectionManager);
}

return builder.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

package com.amazon.dlic.util;

import java.net.Socket;
import java.nio.file.Path;
import java.security.KeyManagementException;
import java.security.KeyStore;
Expand All @@ -25,22 +24,18 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.KeyManager;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;

import com.google.common.collect.ImmutableList;
import org.apache.http.conn.ssl.DefaultHostnameVerifier;
import org.apache.http.conn.ssl.NoopHostnameVerifier;
import org.apache.http.conn.ssl.SSLConnectionSocketFactory;
import org.apache.http.nio.conn.ssl.SSLIOSessionStrategy;
import org.apache.http.ssl.PrivateKeyDetails;
import org.apache.http.ssl.PrivateKeyStrategy;
import org.apache.http.ssl.SSLContextBuilder;
import org.apache.http.ssl.SSLContexts;
import org.apache.hc.client5.http.ssl.DefaultHostnameVerifier;
import org.apache.hc.client5.http.ssl.NoopHostnameVerifier;
import org.apache.hc.client5.http.ssl.SSLConnectionSocketFactory;
import org.apache.hc.core5.ssl.SSLContextBuilder;
import org.apache.hc.core5.ssl.SSLContexts;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

Expand All @@ -51,7 +46,7 @@
import static org.opensearch.security.ssl.SecureSSLSettings.SSLSetting.SECURITY_SSL_TRANSPORT_KEYSTORE_PASSWORD;
import static org.opensearch.security.ssl.SecureSSLSettings.SSLSetting.SECURITY_SSL_TRANSPORT_TRUSTSTORE_PASSWORD;

public class SettingsBasedSSLConfiguratorV4 {
public class SettingsBasedSSLConfiguratorV5 {
private static final Logger log = LogManager.getLogger(SettingsBasedSSLConfigurator.class);

public static final String CERT_ALIAS = "cert_alias";
Expand Down Expand Up @@ -95,14 +90,14 @@ public class SettingsBasedSSLConfiguratorV4 {
private String effectiveKeyAlias;
private List<String> effectiveTruststoreAliases;

public SettingsBasedSSLConfiguratorV4(Settings settings, Path configPath, String settingsKeyPrefix, String clientName) {
public SettingsBasedSSLConfiguratorV5(Settings settings, Path configPath, String settingsKeyPrefix, String clientName) {
this.settings = settings;
this.configPath = configPath;
this.settingsKeyPrefix = normalizeSettingsKeyPrefix(settingsKeyPrefix);
this.clientName = clientName != null ? clientName : this.settingsKeyPrefix;
}

public SettingsBasedSSLConfiguratorV4(Settings settings, Path configPath, String settingsKeyPrefix) {
public SettingsBasedSSLConfiguratorV5(Settings settings, Path configPath, String settingsKeyPrefix) {
this(settings, configPath, settingsKeyPrefix, null);
}

Expand Down Expand Up @@ -203,20 +198,16 @@ private void configureWithSettings() throws SSLConfigException, NoSuchAlgorithmE
if (enableSslClientAuth) {
if (effectiveKeystore != null) {
try {
sslContextBuilder.loadKeyMaterial(effectiveKeystore, effectiveKeyPassword, new PrivateKeyStrategy() {

@Override
public String chooseAlias(Map<String, PrivateKeyDetails> aliases, Socket socket) {
if (aliases == null || aliases.isEmpty()) {
return effectiveKeyAlias;
}

if (effectiveKeyAlias == null || effectiveKeyAlias.isEmpty()) {
return aliases.keySet().iterator().next();
}

sslContextBuilder.loadKeyMaterial(effectiveKeystore, effectiveKeyPassword, (aliases, socket) -> {
if (aliases == null || aliases.isEmpty()) {
return effectiveKeyAlias;
}

if (effectiveKeyAlias == null || effectiveKeyAlias.isEmpty()) {
return aliases.keySet().iterator().next();
}

return effectiveKeyAlias;
});
} catch (UnrecoverableKeyException e) {
throw new RuntimeException(e);
Expand Down Expand Up @@ -470,10 +461,6 @@ public HostnameVerifier getHostnameVerifier() {
return hostnameVerifier;
}

public SSLIOSessionStrategy toSSLIOSessionStrategy() {
return new SSLIOSessionStrategy(sslContext, supportedProtocols, supportedCipherSuites, hostnameVerifier);
}

public SSLConnectionSocketFactory toSSLConnectionSocketFactory() {
return new SSLConnectionSocketFactory(sslContext, supportedProtocols, supportedCipherSuites, hostnameVerifier);
}
Expand Down
Loading

0 comments on commit 67bd6c3

Please sign in to comment.