Skip to content

Commit

Permalink
Implementing retry for remote connector to mitigate throttling issue (o…
Browse files Browse the repository at this point in the history
…pensearch-project#2462)

* use retryable action; execution context

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* change to groupedActionListener

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* fix group

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* retry policy

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* base time

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* retry option, cluster settings

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* nit

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* lint

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* change interface to class

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* fix ut due to code change

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* license header

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* add ut

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* add test

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* fix core interface

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* test

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* license header

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* use exception holder

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* add max retry times settings

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* fix typo

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* nit

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* change the order to avoid misleading log

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* license header

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* move settings to connector

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* remove settings

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* add test

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* add retry_backoff_policy setting

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* changes for comments

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* fix retry times

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* make the error handling more neat in MLSdkAsyncHttpResponseHandler

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* change to SageMakerThrottlingException

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* use enum for retry backoff policy

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* fix seconds to milliseconds in equal jitter policy

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* disable retry by default

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

---------

Signed-off-by: zhichao-aws <zhichaog@amazon.com>
  • Loading branch information
zhichao-aws authored Jun 6, 2024
1 parent 865a424 commit 399825f
Show file tree
Hide file tree
Showing 19 changed files with 940 additions and 441 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
Expand All @@ -16,6 +17,9 @@
import org.opensearch.core.xcontent.XContentParser;

import java.io.IOException;
import java.util.List;
import java.util.Locale;
import java.util.Objects;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

Expand All @@ -26,45 +30,88 @@ public class ConnectorClientConfig implements ToXContentObject, Writeable {
public static final String MAX_CONNECTION_FIELD = "max_connection";
public static final String CONNECTION_TIMEOUT_FIELD = "connection_timeout";
public static final String READ_TIMEOUT_FIELD = "read_timeout";
public static final String RETRY_BACKOFF_MILLIS_FIELD = "retry_backoff_millis";
public static final String RETRY_TIMEOUT_SECONDS_FIELD = "retry_timeout_seconds";
public static final String MAX_RETRY_TIMES_FIELD = "max_retry_times";
public static final String RETRY_BACKOFF_POLICY_FIELD = "retry_backoff_policy";

public static final Integer MAX_CONNECTION_DEFAULT_VALUE = Integer.valueOf(30);
public static final Integer CONNECTION_TIMEOUT_DEFAULT_VALUE = Integer.valueOf(30000);
public static final Integer READ_TIMEOUT_DEFAULT_VALUE = Integer.valueOf(30000);

public static final Integer RETRY_BACKOFF_MILLIS_DEFAULT_VALUE = 200;
public static final Integer RETRY_TIMEOUT_SECONDS_DEFAULT_VALUE = 30;
public static final Integer MAX_RETRY_TIMES_DEFAULT_VALUE = 0;
public static final RetryBackoffPolicy RETRY_BACKOFF_POLICY_DEFAULT_VALUE = RetryBackoffPolicy.CONSTANT;
public static final Version MINIMAL_SUPPORTED_VERSION_FOR_RETRY = Version.V_2_15_0;
private Integer maxConnections;
private Integer connectionTimeout;
private Integer readTimeout;
private Integer retryBackoffMillis;
private Integer retryTimeoutSeconds;
private Integer maxRetryTimes;
private RetryBackoffPolicy retryBackoffPolicy;

@Builder(toBuilder = true)
public ConnectorClientConfig(
Integer maxConnections,
Integer connectionTimeout,
Integer readTimeout
Integer readTimeout,
Integer retryBackoffMillis,
Integer retryTimeoutSeconds,
Integer maxRetryTimes,
RetryBackoffPolicy retryBackoffPolicy
) {
this.maxConnections = maxConnections;
this.connectionTimeout = connectionTimeout;
this.readTimeout = readTimeout;

this.retryBackoffMillis = retryBackoffMillis;
this.retryTimeoutSeconds = retryTimeoutSeconds;
this.maxRetryTimes = maxRetryTimes;
this.retryBackoffPolicy = retryBackoffPolicy;
}

public ConnectorClientConfig(StreamInput input) throws IOException {
Version streamInputVersion = input.getVersion();
this.maxConnections = input.readOptionalInt();
this.connectionTimeout = input.readOptionalInt();
this.readTimeout = input.readOptionalInt();
if(streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_RETRY)) {
this.retryBackoffMillis = input.readOptionalInt();
this.retryTimeoutSeconds = input.readOptionalInt();
this.maxRetryTimes = input.readOptionalInt();
if (input.readBoolean()) {
this.retryBackoffPolicy = RetryBackoffPolicy.from(input.readString());
}
}
}

public ConnectorClientConfig() {
this.maxConnections = MAX_CONNECTION_DEFAULT_VALUE;
this.connectionTimeout = CONNECTION_TIMEOUT_DEFAULT_VALUE;
this.readTimeout = READ_TIMEOUT_DEFAULT_VALUE;
this.retryBackoffMillis = RETRY_BACKOFF_MILLIS_DEFAULT_VALUE;
this.retryTimeoutSeconds = RETRY_TIMEOUT_SECONDS_DEFAULT_VALUE;
this.maxRetryTimes = MAX_RETRY_TIMES_DEFAULT_VALUE;
this.retryBackoffPolicy = RETRY_BACKOFF_POLICY_DEFAULT_VALUE;
}

@Override
public void writeTo(StreamOutput out) throws IOException {

Version streamOutputVersion = out.getVersion();
out.writeOptionalInt(maxConnections);
out.writeOptionalInt(connectionTimeout);
out.writeOptionalInt(readTimeout);
if(streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_RETRY)){
out.writeOptionalInt(retryBackoffMillis);
out.writeOptionalInt(retryTimeoutSeconds);
out.writeOptionalInt(maxRetryTimes);
if (Objects.nonNull(retryBackoffPolicy)) {
out.writeBoolean(true);
out.writeString(retryBackoffPolicy.name());
} else {
out.writeBoolean(false);
}
}
}

@Override
Expand All @@ -79,6 +126,18 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
if (readTimeout != null) {
builder.field(READ_TIMEOUT_FIELD, readTimeout);
}
if (retryBackoffMillis != null) {
builder.field(RETRY_BACKOFF_MILLIS_FIELD, retryBackoffMillis);
}
if (retryTimeoutSeconds != null) {
builder.field(RETRY_TIMEOUT_SECONDS_FIELD, retryTimeoutSeconds);
}
if (maxRetryTimes != null) {
builder.field(MAX_RETRY_TIMES_FIELD, maxRetryTimes);
}
if (retryBackoffPolicy != null) {
builder.field(RETRY_BACKOFF_POLICY_FIELD, retryBackoffPolicy.name().toLowerCase(Locale.ROOT));
}
return builder.endObject();
}

Expand All @@ -88,9 +147,13 @@ public static ConnectorClientConfig fromStream(StreamInput in) throws IOExceptio
}

public static ConnectorClientConfig parse(XContentParser parser) throws IOException {
Integer maxConnections = null;
Integer connectionTimeout = null;
Integer readTimeout = null;
Integer maxConnections = MAX_CONNECTION_DEFAULT_VALUE;
Integer connectionTimeout = CONNECTION_TIMEOUT_DEFAULT_VALUE;
Integer readTimeout = READ_TIMEOUT_DEFAULT_VALUE;
Integer retryBackoffMillis = RETRY_BACKOFF_MILLIS_DEFAULT_VALUE;
Integer retryTimeoutSeconds = RETRY_TIMEOUT_SECONDS_DEFAULT_VALUE;
Integer maxRetryTimes = MAX_RETRY_TIMES_DEFAULT_VALUE;
RetryBackoffPolicy retryBackoffPolicy = RETRY_BACKOFF_POLICY_DEFAULT_VALUE;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand All @@ -107,6 +170,18 @@ public static ConnectorClientConfig parse(XContentParser parser) throws IOExcept
case READ_TIMEOUT_FIELD:
readTimeout = parser.intValue();
break;
case RETRY_BACKOFF_MILLIS_FIELD:
retryBackoffMillis = parser.intValue();
break;
case RETRY_TIMEOUT_SECONDS_FIELD:
retryTimeoutSeconds = parser.intValue();
break;
case MAX_RETRY_TIMES_FIELD:
maxRetryTimes = parser.intValue();
break;
case RETRY_BACKOFF_POLICY_FIELD:
retryBackoffPolicy = RetryBackoffPolicy.from(parser.text());
break;
default:
parser.skipChildren();
break;
Expand All @@ -116,6 +191,10 @@ public static ConnectorClientConfig parse(XContentParser parser) throws IOExcept
.maxConnections(maxConnections)
.connectionTimeout(connectionTimeout)
.readTimeout(readTimeout)
.retryBackoffMillis(retryBackoffMillis)
.retryTimeoutSeconds(retryTimeoutSeconds)
.maxRetryTimes(maxRetryTimes)
.retryBackoffPolicy(retryBackoffPolicy)
.build();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector;

import java.util.Locale;

public enum RetryBackoffPolicy {
CONSTANT,
EXPONENTIAL_EQUAL_JITTER,
EXPONENTIAL_FULL_JITTER;

public static RetryBackoffPolicy from(String value) {
try {
return RetryBackoffPolicy.valueOf(value.toUpperCase(Locale.ROOT));
} catch (Exception e) {
throw new IllegalArgumentException("Unsupported retry backoff policy");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ public static Map<String, String> convertScriptStringToJsonString(Map<String, Ob
Map<String, String> parameterStringMap = new HashMap<>();
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
Map<String, Object> parametersMap = (Map<String, Object>) processedInput.get("parameters");
Map<String, Object> parametersMap = (Map<String, Object>) processedInput.getOrDefault("parameters", Map.of());
for (String key : parametersMap.keySet()) {
if (parametersMap.get(key) instanceof String) {
parameterStringMap.put(key, (String) parametersMap.get(key));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ public void toXContent_InternalConnector() throws IOException {
"\"pre_process_function\":\"connector.pre_process.openai.embedding\"," +
"\"post_process_function\":\"connector.post_process.openai.embedding\"}]," +
"\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\"," +
"\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000}}}",
"\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000," +
"\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}}",
mlModelContent);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import org.junit.Assert;
import org.junit.Test;
import org.opensearch.Version;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
Expand All @@ -23,6 +25,10 @@ public void writeTo_ReadFromStream() throws IOException {
.maxConnections(10)
.connectionTimeout(5000)
.readTimeout(3000)
.retryBackoffMillis(123)
.retryTimeoutSeconds(456)
.maxRetryTimes(789)
.retryBackoffPolicy(RetryBackoffPolicy.CONSTANT)
.build();

BytesStreamOutput output = new BytesStreamOutput();
Expand All @@ -32,25 +38,71 @@ public void writeTo_ReadFromStream() throws IOException {
Assert.assertEquals(config, readConfig);
}

@Test
public void writeTo_ReadFromStream_nullValues() throws IOException {
ConnectorClientConfig config = ConnectorClientConfig.builder()
.build();

BytesStreamOutput output = new BytesStreamOutput();
config.writeTo(output);
ConnectorClientConfig readConfig = new ConnectorClientConfig(output.bytes().streamInput());

Assert.assertEquals(config, readConfig);
}

@Test
public void writeTo_ReadFromStream_diffVersionThenNotProcessRetryOptions() throws IOException {
ConnectorClientConfig config = ConnectorClientConfig.builder()
.maxConnections(10)
.connectionTimeout(5000)
.readTimeout(3000)
.retryBackoffMillis(123)
.retryTimeoutSeconds(456)
.maxRetryTimes(789)
.retryBackoffPolicy(RetryBackoffPolicy.CONSTANT)
.build();

BytesStreamOutput output = new BytesStreamOutput();
output.setVersion(Version.V_2_14_0);
config.writeTo(output);
StreamInput input = output.bytes().streamInput();
input.setVersion(Version.V_2_14_0);
ConnectorClientConfig readConfig = ConnectorClientConfig.fromStream(input);

Assert.assertEquals(Integer.valueOf(10),readConfig.getMaxConnections());
Assert.assertEquals(Integer.valueOf(5000),readConfig.getConnectionTimeout());
Assert.assertEquals(Integer.valueOf(3000),readConfig.getReadTimeout());
Assert.assertNull(readConfig.getRetryBackoffMillis());
Assert.assertNull(readConfig.getRetryTimeoutSeconds());
Assert.assertNull(readConfig.getMaxRetryTimes());
Assert.assertNull(readConfig.getRetryBackoffPolicy());
}

@Test
public void toXContent() throws IOException {
ConnectorClientConfig config = ConnectorClientConfig.builder()
.maxConnections(10)
.connectionTimeout(5000)
.readTimeout(3000)
.retryBackoffMillis(123)
.retryTimeoutSeconds(456)
.maxRetryTimes(789)
.retryBackoffPolicy(RetryBackoffPolicy.CONSTANT)
.build();

XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
config.toXContent(builder, ToXContent.EMPTY_PARAMS);
String content = TestHelper.xContentBuilderToString(builder);

String expectedJson = "{\"max_connection\":10,\"connection_timeout\":5000,\"read_timeout\":3000}";
String expectedJson = "{\"max_connection\":10,\"connection_timeout\":5000,\"read_timeout\":3000," +
"\"retry_backoff_millis\":123,\"retry_timeout_seconds\":456,\"max_retry_times\":789,\"retry_backoff_policy\":\"constant\"}";
Assert.assertEquals(expectedJson, content);
}

@Test
public void parse() throws IOException {
String jsonStr = "{\"max_connection\":10,\"connection_timeout\":5000,\"read_timeout\":3000}";
String jsonStr = "{\"max_connection\":10,\"connection_timeout\":5000,\"read_timeout\":3000," +
"\"retry_backoff_millis\":123,\"retry_timeout_seconds\":456,\"max_retry_times\":789,\"retry_backoff_policy\":\"constant\"}";
XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
Collections.emptyList()).getNamedXContents()), null, jsonStr);
parser.nextToken();
Expand All @@ -60,6 +112,22 @@ public void parse() throws IOException {
Assert.assertEquals(Integer.valueOf(10), config.getMaxConnections());
Assert.assertEquals(Integer.valueOf(5000), config.getConnectionTimeout());
Assert.assertEquals(Integer.valueOf(3000), config.getReadTimeout());
Assert.assertEquals(Integer.valueOf(123), config.getRetryBackoffMillis());
Assert.assertEquals(Integer.valueOf(456), config.getRetryTimeoutSeconds());
Assert.assertEquals(Integer.valueOf(789), config.getMaxRetryTimes());
Assert.assertEquals(RetryBackoffPolicy.CONSTANT, config.getRetryBackoffPolicy());
}

@Test
public void parse_whenMalformedBackoffPolicy_thenFail() throws IOException {
String jsonStr = "{\"max_connection\":10,\"connection_timeout\":5000,\"read_timeout\":3000," +
"\"retry_backoff_millis\":123,\"retry_timeout_seconds\":456,\"max_retry_times\":789,\"retry_backoff_policy\":\"test\"}";
XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
Collections.emptyList()).getNamedXContents()), null, jsonStr);
parser.nextToken();

Exception exception = Assert.assertThrows(IllegalArgumentException.class, () -> ConnectorClientConfig.parse(parser));
Assert.assertEquals("Unsupported retry backoff policy", exception.getMessage());
}

@Test
Expand All @@ -69,6 +137,23 @@ public void testDefaultValues() {
Assert.assertNull(config.getMaxConnections());
Assert.assertNull(config.getConnectionTimeout());
Assert.assertNull(config.getReadTimeout());
Assert.assertNull(config.getRetryBackoffMillis());
Assert.assertNull(config.getRetryTimeoutSeconds());
Assert.assertNull(config.getMaxRetryTimes());
Assert.assertNull(config.getRetryBackoffPolicy());
}

@Test
public void testDefaultValuesInitByNewInstance() {
ConnectorClientConfig config = new ConnectorClientConfig();

Assert.assertEquals(Integer.valueOf(30),config.getMaxConnections());
Assert.assertEquals(Integer.valueOf(30000),config.getConnectionTimeout());
Assert.assertEquals(Integer.valueOf(30000),config.getReadTimeout());
Assert.assertEquals(Integer.valueOf(200),config.getRetryBackoffMillis());
Assert.assertEquals(Integer.valueOf(30),config.getRetryTimeoutSeconds());
Assert.assertEquals(Integer.valueOf(0),config.getMaxRetryTimes());
Assert.assertEquals(RetryBackoffPolicy.CONSTANT, config.getRetryBackoffPolicy());
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ public class HttpConnectorTest {
"\"pre_process_function\":\"connector.pre_process.openai.embedding\"," +
"\"post_process_function\":\"connector.post_process.openai.embedding\"}]," +
"\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\"," +
"\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000}}";
"\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000," +
"\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}";

@Before
public void setUp() {
Expand Down Expand Up @@ -293,7 +294,7 @@ public static HttpConnector createHttpConnectorWithRequestBody(String requestBod
Map<String, String> credential = new HashMap<>();
credential.put("key", "test_key_value");

ConnectorClientConfig httpClientConfig = new ConnectorClientConfig(30, 30000, 30000);
ConnectorClientConfig httpClientConfig = new ConnectorClientConfig(30, 30000, 30000, 10, 10, -1, RetryBackoffPolicy.CONSTANT);

HttpConnector connector = HttpConnector.builder()
.name("test_connector_name")
Expand Down
Loading

0 comments on commit 399825f

Please sign in to comment.