Skip to content

Commit

Permalink
Enable customer httpclient parameter configurration
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <zaniu@amazon.com>
  • Loading branch information
zane-neo committed Oct 25, 2023
1 parent e15221d commit 0f30c69
Show file tree
Hide file tree
Showing 10 changed files with 213 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.annotation.ConnectorExecutor;
import org.opensearch.ml.engine.httpclient.MLHttpClientFactory;
import org.opensearch.script.ScriptService;

import java.security.AccessController;
Expand All @@ -43,12 +42,15 @@
public class HttpJsonConnectorExecutor implements RemoteConnectorExecutor {

@Getter
private HttpConnector connector;
private final HttpConnector connector;

private final CloseableHttpClient httpClient;
@Setter @Getter
private ScriptService scriptService;

public HttpJsonConnectorExecutor(Connector connector) {
public HttpJsonConnectorExecutor(Connector connector, CloseableHttpClient httpClient) {
this.connector = (HttpConnector)connector;
this.httpClient = httpClient;
}

@Override
Expand Down Expand Up @@ -95,8 +97,7 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
}

AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
try (CloseableHttpClient httpClient = getHttpClient();
CloseableHttpResponse response = httpClient.execute(request)) {
try (CloseableHttpResponse response = httpClient.execute(request)) {
HttpEntity responseEntity = response.getEntity();
String responseBody = EntityUtils.toString(responseEntity);
EntityUtils.consume(responseEntity);
Expand All @@ -122,8 +123,4 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
throw new MLException("Fail to execute http connector", e);
}
}

public CloseableHttpClient getHttpClient() {
return MLHttpClientFactory.getCloseableHttpClient();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package org.opensearch.ml.engine.algorithms.remote;


import org.opensearch.common.settings.Settings;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.engine.httpclient.MLHttpClientFactory;
import software.amazon.awssdk.core.internal.http.loader.DefaultSdkHttpClientBuilder;
import software.amazon.awssdk.http.SdkHttpConfigurationOption;
import software.amazon.awssdk.utils.AttributeMap;

import java.time.Duration;

import static org.opensearch.ml.common.connector.ConnectorProtocols.AWS_SIGV4;
import static org.opensearch.ml.common.connector.ConnectorProtocols.HTTP;
import static org.opensearch.ml.engine.settings.HttpClientCommonSettings.ML_COMMONS_HTTP_CLIENT_CONNECTION_TIMEOUT_IN_MILLI_SECOND;
import static org.opensearch.ml.engine.settings.HttpClientCommonSettings.ML_COMMONS_HTTP_CLIENT_MAX_TOTAL_CONNECTIONS;
import static org.opensearch.ml.engine.settings.HttpClientCommonSettings.ML_COMMONS_HTTP_CLIENT_READ_TIMEOUT_IN_MILLI_SECOND;

public class RemoteConnectorExecutorFactory {

private final Settings settings;

public RemoteConnectorExecutorFactory(Settings settings) {
this.settings = settings;
}

public RemoteConnectorExecutor create(Connector connector) {
switch (connector.getProtocol()) {
case AWS_SIGV4:
Duration connectionTimeout = Duration.ofMillis(ML_COMMONS_HTTP_CLIENT_CONNECTION_TIMEOUT_IN_MILLI_SECOND.get(settings));
Duration readTimeout = Duration.ofMillis(ML_COMMONS_HTTP_CLIENT_READ_TIMEOUT_IN_MILLI_SECOND.get(settings));
AttributeMap attributeMap = AttributeMap.builder()
.put(SdkHttpConfigurationOption.CONNECTION_TIMEOUT, connectionTimeout)
.put(SdkHttpConfigurationOption.READ_TIMEOUT, readTimeout)
.put(SdkHttpConfigurationOption.MAX_CONNECTIONS, ML_COMMONS_HTTP_CLIENT_MAX_TOTAL_CONNECTIONS.get(settings))
.build();
return new AwsConnectorExecutor(connector, new DefaultSdkHttpClientBuilder().buildWithDefaults(attributeMap));
case HTTP:
MLHttpClientFactory httpClientFactory = new MLHttpClientFactory(settings);
return new HttpJsonConnectorExecutor(connector, httpClientFactory.createHttpClient());
default:
throw new IllegalArgumentException("Unknown connector type: " + connector.getProtocol());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.engine.MLEngineClassLoader;
import org.opensearch.ml.engine.Predictable;
import org.opensearch.ml.engine.annotation.Function;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.script.ScriptService;

import java.rmi.Remote;
import java.util.Map;

@Log4j2
Expand All @@ -34,6 +34,11 @@ public class RemoteModel implements Predictable {
public static final String XCONTENT_REGISTRY = "xcontent_registry";

private RemoteConnectorExecutor connectorExecutor;
private final RemoteConnectorExecutorFactory remoteConnectorExecutorFactory;

public RemoteModel(RemoteConnectorExecutorFactory remoteConnectorExecutorFactory) {
this.remoteConnectorExecutorFactory = remoteConnectorExecutorFactory;
}

@VisibleForTesting
RemoteConnectorExecutor getConnectorExecutor() {
Expand Down Expand Up @@ -78,7 +83,7 @@ public void initModel(MLModel model, Map<String, Object> params, Encryptor encry
try {
Connector connector = model.getConnector().cloneConnector();
connector.decrypt((credential) -> encryptor.decrypt(credential));
this.connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class);
this.connectorExecutor = remoteConnectorExecutorFactory.create(connector);
this.connectorExecutor.setScriptService((ScriptService) params.get(SCRIPT_SERVICE));
this.connectorExecutor.setClusterService((ClusterService) params.get(CLUSTER_SERVICE));
this.connectorExecutor.setClient((Client) params.get(CLIENT));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,35 @@
import org.apache.http.HttpHost;
import org.apache.http.HttpRequest;
import org.apache.http.HttpResponse;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.conn.UnsupportedSchemeException;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.impl.client.LaxRedirectStrategy;
import org.apache.http.impl.conn.DefaultSchemePortResolver;
import org.apache.http.protocol.HttpContext;
import org.apache.logging.log4j.util.Strings;
import org.opensearch.common.settings.Settings;

import java.net.Inet4Address;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.Arrays;

import static org.opensearch.ml.engine.settings.HttpClientCommonSettings.ML_COMMONS_HTTP_CLIENT_CONNECTION_TIMEOUT_IN_MILLI_SECOND;
import static org.opensearch.ml.engine.settings.HttpClientCommonSettings.ML_COMMONS_HTTP_CLIENT_MAX_TOTAL_CONNECTIONS;
import static org.opensearch.ml.engine.settings.HttpClientCommonSettings.ML_COMMONS_HTTP_CLIENT_READ_TIMEOUT_IN_MILLI_SECOND;

@Log4j2
public class MLHttpClientFactory {

public static CloseableHttpClient getCloseableHttpClient() {
return createHttpClient();
private final Settings settings;

public MLHttpClientFactory(Settings settings) {
this.settings = settings;
}

private static CloseableHttpClient createHttpClient() {
public CloseableHttpClient createHttpClient() {
HttpClientBuilder builder = HttpClientBuilder.create();

// Only allow HTTP and HTTPS schemes
Expand All @@ -52,6 +60,13 @@ public boolean isRedirected(HttpRequest request, HttpResponse response, HttpCont
return false;
}
});
builder.setMaxConnTotal(ML_COMMONS_HTTP_CLIENT_MAX_TOTAL_CONNECTIONS.get(settings));
builder.setMaxConnPerRoute(ML_COMMONS_HTTP_CLIENT_MAX_TOTAL_CONNECTIONS.get(settings));
RequestConfig requestConfig = RequestConfig.custom()
.setConnectTimeout(ML_COMMONS_HTTP_CLIENT_CONNECTION_TIMEOUT_IN_MILLI_SECOND.get(settings))
.setSocketTimeout(ML_COMMONS_HTTP_CLIENT_READ_TIMEOUT_IN_MILLI_SECOND.get(settings))
.build();
builder.setDefaultRequestConfig(requestConfig);
return builder.build();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package org.opensearch.ml.engine.settings;

import org.opensearch.common.settings.Setting;

public class HttpClientCommonSettings {

public static final Setting<Integer> ML_COMMONS_HTTP_CLIENT_CONNECTION_TIMEOUT_IN_MILLI_SECOND =
Setting.intSetting("plugins.ml_commons.http_client.connection_timeout.in_millisecond", 1000, 1, Setting.Property.NodeScope, Setting.Property.Final);

public static final Setting<Integer> ML_COMMONS_HTTP_CLIENT_READ_TIMEOUT_IN_MILLI_SECOND =
Setting.intSetting("plugins.ml_commons.http_client.read_timeout.in_millisecond", 3000, 1, Setting.Property.NodeScope, Setting.Property.Final);

public static final Setting<Integer> ML_COMMONS_HTTP_CLIENT_MAX_TOTAL_CONNECTIONS =
Setting.intSetting("plugins.ml_commons.http_client.max_total_connections", 20, 20, Setting.Property.NodeScope, Setting.Property.Final);

}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.cluster.ClusterStateTaskConfig;
import org.opensearch.ingest.TestTemplateService;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.connector.Connector;
Expand All @@ -34,7 +33,6 @@
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.engine.httpclient.MLHttpClientFactory;
import org.opensearch.script.ScriptService;

import java.io.IOException;
Expand Down Expand Up @@ -75,7 +73,7 @@ public void invokeRemoteModel_WrongHttpMethod() {
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector);
HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector, httpClient);
executor.invokeRemoteModel(null, null, null, null);
}

Expand All @@ -88,13 +86,12 @@ public void executePredict_RemoteInferenceInput() throws IOException {
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector, httpClient));
when(httpClient.execute(any())).thenReturn(response);
HttpEntity entity = new StringEntity("{\"response\": \"test result\"}");
when(response.getEntity()).thenReturn(entity);
StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK");
when(response.getStatusLine()).thenReturn(statusLine);
when(executor.getHttpClient()).thenReturn(httpClient);
MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build();
ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
Expand All @@ -117,8 +114,7 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOExcepti
StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK");
when(response.getStatusLine()).thenReturn(statusLine);
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
when(executor.getHttpClient()).thenReturn(httpClient);
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector, httpClient));
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
Assert.assertEquals(2, modelTensorOutput.getMlModelOutputs().size());
Expand All @@ -143,8 +139,7 @@ public void executePredict_TextDocsInput_LimitExceed() throws IOException {
StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 429, "OK");
when(response.getStatusLine()).thenReturn(statusLine);
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
when(executor.getHttpClient()).thenReturn(httpClient);
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector, httpClient));
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
}
Expand All @@ -166,7 +161,7 @@ public void executePredict_TextDocsInput() throws IOException {
.requestBody("{\"input\": ${parameters.input}}")
.build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector, httpClient));
executor.setScriptService(scriptService);
when(httpClient.execute(any())).thenReturn(response);
String modelResponse = "{\n" + " \"object\": \"list\",\n" + " \"data\": [\n" + " {\n"
Expand All @@ -181,7 +176,6 @@ public void executePredict_TextDocsInput() throws IOException {
when(response.getStatusLine()).thenReturn(statusLine);
HttpEntity entity = new StringEntity(modelResponse);
when(response.getEntity()).thenReturn(entity);
when(executor.getHttpClient()).thenReturn(httpClient);
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package org.opensearch.ml.engine.algorithms.remote;

import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.ArgumentCaptor;
import org.opensearch.common.settings.Settings;
import org.opensearch.ml.common.connector.AwsConnector;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.ConnectorProtocols;
import org.opensearch.ml.common.connector.HttpConnector;

import java.util.HashMap;
import java.util.Map;

import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD;
import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD;
import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD;
import static org.opensearch.ml.common.connector.HttpConnector.SERVICE_NAME_FIELD;
import static org.opensearch.ml.engine.settings.HttpClientCommonSettings.ML_COMMONS_HTTP_CLIENT_CONNECTION_TIMEOUT_IN_MILLI_SECOND;
import static org.opensearch.ml.engine.settings.HttpClientCommonSettings.ML_COMMONS_HTTP_CLIENT_MAX_TOTAL_CONNECTIONS;
import static org.opensearch.ml.engine.settings.HttpClientCommonSettings.ML_COMMONS_HTTP_CLIENT_READ_TIMEOUT_IN_MILLI_SECOND;

public class RemoteConnectorExecutorFactoryTests {

@Rule
public ExpectedException exceptionRule = ExpectedException.none();

private Settings settings = Settings.builder()
.put(ML_COMMONS_HTTP_CLIENT_MAX_TOTAL_CONNECTIONS.getKey(), 30)
.put(ML_COMMONS_HTTP_CLIENT_CONNECTION_TIMEOUT_IN_MILLI_SECOND.getKey(), 1000)
.put(ML_COMMONS_HTTP_CLIENT_READ_TIMEOUT_IN_MILLI_SECOND.getKey(), 1000)
.build();

@Test
public void test_createAWSConnectorExecutor_success() {
RemoteConnectorExecutorFactory factory = new RemoteConnectorExecutorFactory(settings);
Map<String, String> credential = new HashMap<>();
credential.put(ACCESS_KEY_FIELD, "test_access_key");
credential.put(SECRET_KEY_FIELD, "test_secret_key");
credential.put(REGION_FIELD, "test_region");
Map<String, String> parameters = new HashMap<>();
parameters.put(SERVICE_NAME_FIELD, "test_service");
AwsConnector connector = AwsConnector.awsConnectorBuilder().protocol(ConnectorProtocols.AWS_SIGV4).credential(credential).parameters(parameters).build();
Assert.assertNotNull(connector);
RemoteConnectorExecutor executor = factory.create(connector);
Assert.assertNotNull(executor);
assertEquals(AwsConnectorExecutor.class, executor.getClass());
}

@Test
public void test_createHttpConnectorExecutor_success() {
RemoteConnectorExecutorFactory factory = new RemoteConnectorExecutorFactory(settings);
Map<String, String> credential = new HashMap<>();
credential.put(ACCESS_KEY_FIELD, "test_access_key");
credential.put(SECRET_KEY_FIELD, "test_secret_key");
Map<String, String> parameters = new HashMap<>();
parameters.put(SERVICE_NAME_FIELD, "test_service");
HttpConnector connector = HttpConnector.builder().protocol(ConnectorProtocols.HTTP).credential(credential).parameters(parameters).build();
Assert.assertNotNull(connector);
RemoteConnectorExecutor executor = factory.create(connector);
Assert.assertNotNull(executor);
assertEquals(HttpJsonConnectorExecutor.class, executor.getClass());
}

@Test
public void test_createConnectorExecutor_typeNotFound() {
exceptionRule.expect(IllegalArgumentException.class);
RemoteConnectorExecutorFactory factory = new RemoteConnectorExecutorFactory(settings);
Connector connector = mock(HttpConnector.class);
when(connector.getProtocol()).thenReturn("http2");
RemoteConnectorExecutor executor = factory.create(connector);
ArgumentCaptor<IllegalArgumentException> exceptionCaptor = ArgumentCaptor.forClass(IllegalArgumentException.class);
assertEquals("Unknown connector type: http2", exceptionCaptor.getValue().getMessage());
}
}
Loading

0 comments on commit 0f30c69

Please sign in to comment.