forked from opensearch-project/ml-commons
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Enable customer httpclient parameter configurration
Signed-off-by: zane-neo <zaniu@amazon.com>
- Loading branch information
Showing
10 changed files
with
213 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
45 changes: 45 additions & 0 deletions
45
.../main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutorFactory.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
package org.opensearch.ml.engine.algorithms.remote; | ||
|
||
|
||
import org.opensearch.common.settings.Settings; | ||
import org.opensearch.ml.common.connector.Connector; | ||
import org.opensearch.ml.engine.httpclient.MLHttpClientFactory; | ||
import software.amazon.awssdk.core.internal.http.loader.DefaultSdkHttpClientBuilder; | ||
import software.amazon.awssdk.http.SdkHttpConfigurationOption; | ||
import software.amazon.awssdk.utils.AttributeMap; | ||
|
||
import java.time.Duration; | ||
|
||
import static org.opensearch.ml.common.connector.ConnectorProtocols.AWS_SIGV4; | ||
import static org.opensearch.ml.common.connector.ConnectorProtocols.HTTP; | ||
import static org.opensearch.ml.engine.settings.HttpClientCommonSettings.ML_COMMONS_HTTP_CLIENT_CONNECTION_TIMEOUT_IN_MILLI_SECOND; | ||
import static org.opensearch.ml.engine.settings.HttpClientCommonSettings.ML_COMMONS_HTTP_CLIENT_MAX_TOTAL_CONNECTIONS; | ||
import static org.opensearch.ml.engine.settings.HttpClientCommonSettings.ML_COMMONS_HTTP_CLIENT_READ_TIMEOUT_IN_MILLI_SECOND; | ||
|
||
public class RemoteConnectorExecutorFactory { | ||
|
||
private final Settings settings; | ||
|
||
public RemoteConnectorExecutorFactory(Settings settings) { | ||
this.settings = settings; | ||
} | ||
|
||
public RemoteConnectorExecutor create(Connector connector) { | ||
switch (connector.getProtocol()) { | ||
case AWS_SIGV4: | ||
Duration connectionTimeout = Duration.ofMillis(ML_COMMONS_HTTP_CLIENT_CONNECTION_TIMEOUT_IN_MILLI_SECOND.get(settings)); | ||
Duration readTimeout = Duration.ofMillis(ML_COMMONS_HTTP_CLIENT_READ_TIMEOUT_IN_MILLI_SECOND.get(settings)); | ||
AttributeMap attributeMap = AttributeMap.builder() | ||
.put(SdkHttpConfigurationOption.CONNECTION_TIMEOUT, connectionTimeout) | ||
.put(SdkHttpConfigurationOption.READ_TIMEOUT, readTimeout) | ||
.put(SdkHttpConfigurationOption.MAX_CONNECTIONS, ML_COMMONS_HTTP_CLIENT_MAX_TOTAL_CONNECTIONS.get(settings)) | ||
.build(); | ||
return new AwsConnectorExecutor(connector, new DefaultSdkHttpClientBuilder().buildWithDefaults(attributeMap)); | ||
case HTTP: | ||
MLHttpClientFactory httpClientFactory = new MLHttpClientFactory(settings); | ||
return new HttpJsonConnectorExecutor(connector, httpClientFactory.createHttpClient()); | ||
default: | ||
throw new IllegalArgumentException("Unknown connector type: " + connector.getProtocol()); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
16 changes: 16 additions & 0 deletions
16
ml-algorithms/src/main/java/org/opensearch/ml/engine/settings/HttpClientCommonSettings.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
package org.opensearch.ml.engine.settings; | ||
|
||
import org.opensearch.common.settings.Setting; | ||
|
||
public class HttpClientCommonSettings { | ||
|
||
public static final Setting<Integer> ML_COMMONS_HTTP_CLIENT_CONNECTION_TIMEOUT_IN_MILLI_SECOND = | ||
Setting.intSetting("plugins.ml_commons.http_client.connection_timeout.in_millisecond", 1000, 1, Setting.Property.NodeScope, Setting.Property.Final); | ||
|
||
public static final Setting<Integer> ML_COMMONS_HTTP_CLIENT_READ_TIMEOUT_IN_MILLI_SECOND = | ||
Setting.intSetting("plugins.ml_commons.http_client.read_timeout.in_millisecond", 3000, 1, Setting.Property.NodeScope, Setting.Property.Final); | ||
|
||
public static final Setting<Integer> ML_COMMONS_HTTP_CLIENT_MAX_TOTAL_CONNECTIONS = | ||
Setting.intSetting("plugins.ml_commons.http_client.max_total_connections", 20, 20, Setting.Property.NodeScope, Setting.Property.Final); | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
80 changes: 80 additions & 0 deletions
80
.../java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutorFactoryTests.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
package org.opensearch.ml.engine.algorithms.remote; | ||
|
||
import org.junit.Assert; | ||
import org.junit.Rule; | ||
import org.junit.Test; | ||
import org.junit.rules.ExpectedException; | ||
import org.mockito.ArgumentCaptor; | ||
import org.opensearch.common.settings.Settings; | ||
import org.opensearch.ml.common.connector.AwsConnector; | ||
import org.opensearch.ml.common.connector.Connector; | ||
import org.opensearch.ml.common.connector.ConnectorProtocols; | ||
import org.opensearch.ml.common.connector.HttpConnector; | ||
|
||
import java.util.HashMap; | ||
import java.util.Map; | ||
|
||
import static org.junit.Assert.assertEquals; | ||
import static org.mockito.Mockito.mock; | ||
import static org.mockito.Mockito.when; | ||
import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD; | ||
import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD; | ||
import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD; | ||
import static org.opensearch.ml.common.connector.HttpConnector.SERVICE_NAME_FIELD; | ||
import static org.opensearch.ml.engine.settings.HttpClientCommonSettings.ML_COMMONS_HTTP_CLIENT_CONNECTION_TIMEOUT_IN_MILLI_SECOND; | ||
import static org.opensearch.ml.engine.settings.HttpClientCommonSettings.ML_COMMONS_HTTP_CLIENT_MAX_TOTAL_CONNECTIONS; | ||
import static org.opensearch.ml.engine.settings.HttpClientCommonSettings.ML_COMMONS_HTTP_CLIENT_READ_TIMEOUT_IN_MILLI_SECOND; | ||
|
||
public class RemoteConnectorExecutorFactoryTests { | ||
|
||
@Rule | ||
public ExpectedException exceptionRule = ExpectedException.none(); | ||
|
||
private Settings settings = Settings.builder() | ||
.put(ML_COMMONS_HTTP_CLIENT_MAX_TOTAL_CONNECTIONS.getKey(), 30) | ||
.put(ML_COMMONS_HTTP_CLIENT_CONNECTION_TIMEOUT_IN_MILLI_SECOND.getKey(), 1000) | ||
.put(ML_COMMONS_HTTP_CLIENT_READ_TIMEOUT_IN_MILLI_SECOND.getKey(), 1000) | ||
.build(); | ||
|
||
@Test | ||
public void test_createAWSConnectorExecutor_success() { | ||
RemoteConnectorExecutorFactory factory = new RemoteConnectorExecutorFactory(settings); | ||
Map<String, String> credential = new HashMap<>(); | ||
credential.put(ACCESS_KEY_FIELD, "test_access_key"); | ||
credential.put(SECRET_KEY_FIELD, "test_secret_key"); | ||
credential.put(REGION_FIELD, "test_region"); | ||
Map<String, String> parameters = new HashMap<>(); | ||
parameters.put(SERVICE_NAME_FIELD, "test_service"); | ||
AwsConnector connector = AwsConnector.awsConnectorBuilder().protocol(ConnectorProtocols.AWS_SIGV4).credential(credential).parameters(parameters).build(); | ||
Assert.assertNotNull(connector); | ||
RemoteConnectorExecutor executor = factory.create(connector); | ||
Assert.assertNotNull(executor); | ||
assertEquals(AwsConnectorExecutor.class, executor.getClass()); | ||
} | ||
|
||
@Test | ||
public void test_createHttpConnectorExecutor_success() { | ||
RemoteConnectorExecutorFactory factory = new RemoteConnectorExecutorFactory(settings); | ||
Map<String, String> credential = new HashMap<>(); | ||
credential.put(ACCESS_KEY_FIELD, "test_access_key"); | ||
credential.put(SECRET_KEY_FIELD, "test_secret_key"); | ||
Map<String, String> parameters = new HashMap<>(); | ||
parameters.put(SERVICE_NAME_FIELD, "test_service"); | ||
HttpConnector connector = HttpConnector.builder().protocol(ConnectorProtocols.HTTP).credential(credential).parameters(parameters).build(); | ||
Assert.assertNotNull(connector); | ||
RemoteConnectorExecutor executor = factory.create(connector); | ||
Assert.assertNotNull(executor); | ||
assertEquals(HttpJsonConnectorExecutor.class, executor.getClass()); | ||
} | ||
|
||
@Test | ||
public void test_createConnectorExecutor_typeNotFound() { | ||
exceptionRule.expect(IllegalArgumentException.class); | ||
RemoteConnectorExecutorFactory factory = new RemoteConnectorExecutorFactory(settings); | ||
Connector connector = mock(HttpConnector.class); | ||
when(connector.getProtocol()).thenReturn("http2"); | ||
RemoteConnectorExecutor executor = factory.create(connector); | ||
ArgumentCaptor<IllegalArgumentException> exceptionCaptor = ArgumentCaptor.forClass(IllegalArgumentException.class); | ||
assertEquals("Unknown connector type: http2", exceptionCaptor.getValue().getMessage()); | ||
} | ||
} |
Oops, something went wrong.