format code
Signed-off-by: zane-neo <zaniu@amazon.com>
zane-neo committed Jan 30, 2024
1 parent 6e29e86 commit 248c670
Showing 14 changed files with 220 additions and 104 deletions.
@@ -34,9 +34,9 @@ public interface Predictable {
* @param mlInput input data
* @return predicted results
*/
default MLOutput predict(MLInput mlInput) {
throw new IllegalStateException("Method is not implemented");
}
default MLOutput predict(MLInput mlInput) {
throw new IllegalStateException("Method is not implemented");
}

default void predict(MLInput mlInput, MLTask mlTask, ActionListener<MLTaskResponse> actionListener) {

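
The Predictable interface (first hunk above) supplies a throwing default for the synchronous predict variant. A minimal, self-contained sketch of that default-method pattern, using illustrative stand-in types rather than the ml-commons interfaces:

// Minimal sketch of the default-method pattern in Predictable: the interface
// supplies a throwing default, so implementations only override the predict
// variants they actually support. SimplePredictable and EchoModel are
// illustrative stand-ins, not ml-commons classes.
interface SimplePredictable {
    default String predict(String input) {
        throw new IllegalStateException("Method is not implemented");
    }
}

class EchoModel implements SimplePredictable {
    @Override
    public String predict(String input) {
        return "echo: " + input;
    }

    public static void main(String[] args) {
        System.out.println(new EchoModel().predict("hello"));   // echo: hello
        try {
            new SimplePredictable() { }.predict("hi");           // default is not implemented
        } catch (IllegalStateException e) {
            System.out.println("unsupported variant: " + e.getMessage());
        }
    }
}
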
@@ -59,17 +59,25 @@ public AwsConnectorExecutor(Connector connector) {
this.httpClient = MLHttpClientFactory.getAsyncHttpClient();
}


@Override
public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, String payload, Map<Integer, ModelTensors> tensorOutputs, WrappedCountDownLatch countDownLatch, ActionListener<List<ModelTensors>> actionListener) {
public void invokeRemoteModel(
MLInput mlInput,
Map<String, String> parameters,
String payload,
Map<Integer, ModelTensors> tensorOutputs,
WrappedCountDownLatch countDownLatch,
ActionListener<List<ModelTensors>> actionListener
) {
try {
SdkHttpFullRequest request = ConnectorUtils.buildSdkRequest(connector, parameters, payload, POST, actionListener);
SdkHttpFullRequest request = ConnectorUtils.buildSdkRequest(connector, parameters, payload, POST);
MLHttpClientFactory.validateIp(request.getUri().getHost());
AsyncExecuteRequest executeRequest = AsyncExecuteRequest
.builder()
.request(signRequest(request))
.requestContentPublisher(new SimpleHttpContentPublisher(request))
.responseHandler(new MLSdkAsyncHttpResponseHandler(countDownLatch, actionListener, parameters, tensorOutputs, connector, scriptService))
.responseHandler(
new MLSdkAsyncHttpResponseHandler(countDownLatch, actionListener, parameters, tensorOutputs, connector, scriptService)
)
.build();
AccessController.doPrivileged((PrivilegedExceptionAction<CompletableFuture<Void>>) () -> httpClient.execute(executeRequest));
} catch (RuntimeException exception) {
@@ -81,7 +89,6 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
}
}


private SdkHttpFullRequest signRequest(SdkHttpFullRequest request) {
String accessKey = connector.getAccessKey();
String secretKey = connector.getSecretKey();
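
AwsConnectorExecutor signs the outgoing request with the connector's access and secret keys before handing it to the async client (the signRequest helper above is shown truncated). A hedged sketch of how SigV4 signing is typically wired with the AWS SDK v2 Aws4Signer; the service name, region, endpoint, and credentials below are placeholders, not values read from a connector:

import java.net.URI;

import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.signer.Aws4Signer;
import software.amazon.awssdk.auth.signer.params.Aws4SignerParams;
import software.amazon.awssdk.http.SdkHttpFullRequest;
import software.amazon.awssdk.http.SdkHttpMethod;
import software.amazon.awssdk.regions.Region;

// Hedged sketch of SigV4 signing with the AWS SDK v2 signer, in the spirit of
// the signRequest helpers above. Service name, region, endpoint, and
// credentials are placeholders.
public final class SigV4SigningSketch {

    static SdkHttpFullRequest sign(SdkHttpFullRequest request, String accessKey, String secretKey) {
        Aws4Signer signer = Aws4Signer.create();
        Aws4SignerParams params = Aws4SignerParams
            .builder()
            .awsCredentials(AwsBasicCredentials.create(accessKey, secretKey))
            .signingName("sagemaker")           // placeholder service name
            .signingRegion(Region.US_EAST_1)    // placeholder region
            .build();
        return signer.sign(request, params);    // adds Authorization and X-Amz-Date headers
    }

    public static void main(String[] args) {
        SdkHttpFullRequest unsigned = SdkHttpFullRequest
            .builder()
            .method(SdkHttpMethod.POST)
            .uri(URI.create("https://runtime.sagemaker.us-east-1.amazonaws.com/endpoints/demo/invocations"))
            .build();
        SdkHttpFullRequest signed = sign(unsigned, "access-key-placeholder", "secret-key-placeholder");
        System.out.println(signed.headers().keySet());
    }
}
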
@@ -26,7 +26,6 @@

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.text.StringSubstitutor;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.connector.MLPostProcessFunction;
@@ -225,7 +224,12 @@ public static SdkHttpFullRequest signRequest(
return signer.sign(request, params);
}

public static SdkHttpFullRequest buildSdkRequest(Connector connector, Map<String, String> parameters, String payload, SdkHttpMethod method, ActionListener<List<ModelTensors>> actionListener) {
public static SdkHttpFullRequest buildSdkRequest(
Connector connector,
Map<String, String> parameters,
String payload,
SdkHttpMethod method
) {
String endpoint = connector.getPredictEndpoint(parameters);
String charset = parameters.getOrDefault("charset", "UTF-8");
RequestBody requestBody;
@@ -234,6 +238,10 @@ public static SdkHttpFullRequest buildSdkRequest(Connector connector, Map<String
} else {
requestBody = RequestBody.empty();
}
if (requestBody.optionalContentLength().isEmpty()) {
log.error("Content length is empty. Aborting request to remote model");
throw new IllegalArgumentException("Content length is empty. Aborting request to remote model");
}
SdkHttpFullRequest.Builder builder = SdkHttpFullRequest
.builder()
.method(method)
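
buildSdkRequest no longer takes an ActionListener: the new content-length guard throws instead, and the callers' existing try/catch around invokeRemoteModel reports the failure. A small standalone sketch of that guard, assuming an AWS SDK version that exposes RequestBody.optionalContentLength() (as the hunk itself does); the buildBody helper is illustrative:

import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;

import org.apache.commons.lang3.StringUtils;

import software.amazon.awssdk.core.sync.RequestBody;

public final class RequestBodyGuard {

    // Illustrative helper mirroring the guard added to buildSdkRequest: build the
    // body from the payload, then fail fast if its content length is unknown,
    // rather than reporting the problem through an ActionListener.
    static RequestBody buildBody(String payload, String charset) {
        RequestBody requestBody;
        if (StringUtils.isNotBlank(payload)) {
            requestBody = RequestBody.fromString(payload, Charset.forName(charset));
        } else {
            requestBody = RequestBody.empty();
        }
        if (requestBody.optionalContentLength().isEmpty()) {
            throw new IllegalArgumentException("Content length is empty. Aborting request to remote model");
        }
        return requestBody;
    }

    public static void main(String[] args) {
        RequestBody body = buildBody("{\"inputs\":\"hello\"}", StandardCharsets.UTF_8.name());
        System.out.println("content length: " + body.optionalContentLength().orElse(-1L));
    }
}
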
@@ -15,6 +15,7 @@
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.CompletableFuture;

import org.opensearch.client.Client;
import org.opensearch.common.util.TokenBucket;
import org.opensearch.core.action.ActionListener;
@@ -63,17 +64,24 @@ public HttpJsonConnectorExecutor(Connector connector) {
}

@Override
public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, String payload, Map<Integer, ModelTensors> tensorOutputs, WrappedCountDownLatch countDownLatch, ActionListener<List<ModelTensors>> actionListener) {
public void invokeRemoteModel(
MLInput mlInput,
Map<String, String> parameters,
String payload,
Map<Integer, ModelTensors> tensorOutputs,
WrappedCountDownLatch countDownLatch,
ActionListener<List<ModelTensors>> actionListener
) {
try {
SdkHttpFullRequest request;
switch (connector.getPredictHttpMethod().toUpperCase(Locale.ROOT)) {
case "POST":
log.debug("original payload to remote model: " + payload);
request = ConnectorUtils.buildSdkRequest(connector, parameters, payload, POST, actionListener);
request = ConnectorUtils.buildSdkRequest(connector, parameters, payload, POST);
MLHttpClientFactory.validateIp(request.getUri().getHost());
break;
case "GET":
request = ConnectorUtils.buildSdkRequest(connector, parameters, null, GET, actionListener);
request = ConnectorUtils.buildSdkRequest(connector, parameters, null, GET);
MLHttpClientFactory.validateIp(request.getUri().getHost());
break;
default:
@@ -83,7 +91,9 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
.builder()
.request(request)
.requestContentPublisher(new SimpleHttpContentPublisher(request))
.responseHandler(new MLSdkAsyncHttpResponseHandler(countDownLatch, actionListener, parameters, tensorOutputs, connector, scriptService))
.responseHandler(
new MLSdkAsyncHttpResponseHandler(countDownLatch, actionListener, parameters, tensorOutputs, connector, scriptService)
)
.build();
AccessController.doPrivileged((PrivilegedExceptionAction<CompletableFuture<Void>>) () -> httpClient.execute(executeRequest));
} catch (RuntimeException e) {
@@ -7,8 +7,16 @@

package org.opensearch.ml.engine.algorithms.remote;

import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processOutput;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

import org.apache.logging.log4j.util.Strings;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.core.action.ActionListener;
@@ -19,20 +27,13 @@
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;

import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import software.amazon.awssdk.http.SdkHttpFullResponse;
import software.amazon.awssdk.http.SdkHttpResponse;
import software.amazon.awssdk.http.async.SdkAsyncHttpResponseHandler;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processOutput;

@Log4j2
public class MLSdkAsyncHttpResponseHandler implements SdkAsyncHttpResponseHandler {
@Getter
@@ -85,7 +86,13 @@ public void onStream(Publisher<ByteBuffer> stream) {
@Override
public void onError(Throwable error) {
log.error(error.getMessage(), error);
actionListener.onFailure(new OpenSearchStatusException("Error on communication with remote model: " + error.getMessage(), RestStatus.INTERNAL_SERVER_ERROR));
actionListener
.onFailure(
new OpenSearchStatusException(
"Error on communication with remote model: " + error.getMessage(),
RestStatus.INTERNAL_SERVER_ERROR
)
);
}

private void processResponse(Integer statusCode, String body, Map<String, String> parameters, Map<Integer, ModelTensors> tensorOutputs)
@@ -116,49 +123,82 @@ private List<ModelTensors> reOrderTensorResponses(Map<Integer, ModelTensors> ten
}

protected class MLResponseSubscriber implements Subscriber<ByteBuffer> {
private Subscription subscription;
@Override
public void onSubscribe(Subscription s) {
this.subscription = s;
s.request(Long.MAX_VALUE);
}
private Subscription subscription;

@Override
public void onSubscribe(Subscription s) {
this.subscription = s;
s.request(Long.MAX_VALUE);
}

@Override
public void onNext(ByteBuffer byteBuffer) {
responseBody.append(StandardCharsets.UTF_8.decode(byteBuffer));
subscription.request(Long.MAX_VALUE);
}

@Override
public void onNext(ByteBuffer byteBuffer) {
responseBody.append(StandardCharsets.UTF_8.decode(byteBuffer));
subscription.request(Long.MAX_VALUE);
public void onError(Throwable t) {
countDownLatch.getCountDownLatch().countDown();
log
.error(
"Error on receiving response body from remote: {}",
t instanceof NullPointerException ? "NullPointerException" : t.getMessage(),
t
);
errorMsg
.add(
"Error on receiving response body from remote: "
+ (t instanceof NullPointerException ? "NullPointerException" : t.getMessage())
);
if (countDownLatch.getCountDownLatch().getCount() == 0) {
actionListener
.onFailure(
new OpenSearchStatusException(
"Error on receiving response body from remote: " + String.join(",", errorMsg),
RestStatus.INTERNAL_SERVER_ERROR
)
);
} else {
log.debug("Not all responses received, left response count is: " + countDownLatch.getCountDownLatch().getCount());
}
@Override public void onError(Throwable t) {
}

@Override
public void onComplete() {
try {
String fullResponseBody = responseBody.toString();
processResponse(statusCode, fullResponseBody, parameters, tensorOutputs);
countDownLatch.getCountDownLatch().countDown();
if (countDownLatch.getCountDownLatch().getCount() == 0) {
log.debug("All responses received, calling action listener to return final results.");
actionListener.onResponse(reOrderTensorResponses(tensorOutputs));
}
} catch (Throwable e) {
countDownLatch.getCountDownLatch().countDown();
log.error("Error on receiving response body from remote: {}", t instanceof NullPointerException ? "NullPointerException" : t.getMessage(), t);
errorMsg.add("Error on receiving response body from remote: " + (t instanceof NullPointerException ? "NullPointerException" : t.getMessage()));
log
.error(
"Error on processing response from remote: {}",
e instanceof NullPointerException ? "NullPointerException" : e.getMessage(),
e
);
errorMsg
.add(
"Error on receiving response from remote: "
+ (e instanceof NullPointerException ? "NullPointerException" : e.getMessage())
);
if (countDownLatch.getCountDownLatch().getCount() == 0) {
actionListener.onFailure(new OpenSearchStatusException("Error on receiving response body from remote: " + String.join(",", errorMsg), RestStatus.INTERNAL_SERVER_ERROR));
actionListener
.onFailure(
new OpenSearchStatusException(
"Error on receiving response from remote: " + String.join(",", errorMsg),
RestStatus.INTERNAL_SERVER_ERROR
)
);
} else {
log.debug("Not all responses received, left response count is: " + countDownLatch.getCountDownLatch().getCount());
}
}

@Override
public void onComplete() {
try {
String fullResponseBody = responseBody.toString();
processResponse(statusCode, fullResponseBody, parameters, tensorOutputs);
countDownLatch.getCountDownLatch().countDown();
if (countDownLatch.getCountDownLatch().getCount() == 0) {
log.debug("All responses received, calling action listener to return final results.");
actionListener.onResponse(reOrderTensorResponses(tensorOutputs));
}
} catch (Throwable e) {
countDownLatch.getCountDownLatch().countDown();
log.error("Error on processing response from remote: {}", e instanceof NullPointerException ? "NullPointerException" : e.getMessage(), e);
errorMsg.add("Error on receiving response from remote: " + (e instanceof NullPointerException ? "NullPointerException" : e.getMessage()));
if (countDownLatch.getCountDownLatch().getCount() == 0) {
actionListener.onFailure(new OpenSearchStatusException("Error on receiving response from remote: " + String.join(",", errorMsg), RestStatus.INTERNAL_SERVER_ERROR));
} else {
log.debug("Not all responses received, left response count is: " + countDownLatch.getCountDownLatch().getCount());
}
}
}
}
}
}
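
In the handler above, every per-request subscriber records its tensors, counts down the shared WrappedCountDownLatch, and only the call that brings the count to zero re-orders the results and answers the single ActionListener. A simplified, standalone analogue of that aggregation; SequencedAggregator is illustrative, not an ml-commons class, and it synchronizes the zero-check so the callback cannot fire twice:

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.function.Consumer;

// Simplified, standalone analogue of the aggregation above: every per-request
// handler records its result under its sequence number and counts down a shared
// latch; the call that brings the count to zero re-orders the results and
// invokes the single downstream callback.
public final class SequencedAggregator<T> {

    private final CountDownLatch latch;
    private final Map<Integer, T> resultsBySequence = new ConcurrentHashMap<>();

    public SequencedAggregator(int expectedResponses) {
        this.latch = new CountDownLatch(expectedResponses);
    }

    public synchronized void onResponse(int sequence, T result, Consumer<List<T>> onAllDone) {
        resultsBySequence.put(sequence, result);
        latch.countDown();
        if (latch.getCount() == 0) {
            // TreeMap iterates in ascending key order, restoring the input order.
            onAllDone.accept(new ArrayList<>(new TreeMap<>(resultsBySequence).values()));
        }
    }

    public static void main(String[] args) {
        SequencedAggregator<String> aggregator = new SequencedAggregator<>(2);
        Consumer<List<String>> listener = ordered -> System.out.println("final results: " + ordered);
        aggregator.onResponse(1, "batch-1 tensors", listener);  // nothing emitted yet
        aggregator.onResponse(0, "batch-0 tensors", listener);  // prints both, in sequence order
    }
}
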
@@ -54,10 +54,18 @@ default void executePredict(MLInput mlInput, ActionListener<MLTaskResponse> acti
.algorithm(FunctionName.TEXT_EMBEDDING)
.inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build())
.build(),
modelTensors, new WrappedCountDownLatch(sequence++, countDownLatch) , tensorActionListener);
modelTensors,
new WrappedCountDownLatch(sequence++, countDownLatch),
tensorActionListener
);
}
} else {
preparePayloadAndInvokeRemoteModel(mlInput, modelTensors, new WrappedCountDownLatch(0, new CountDownLatch(1)), tensorActionListener);
preparePayloadAndInvokeRemoteModel(
mlInput,
modelTensors,
new WrappedCountDownLatch(0, new CountDownLatch(1)),
tensorActionListener
);
}
}

@@ -104,7 +112,12 @@ default void setRateLimiter(TokenBucket rateLimiter) {}

default void setUserRateLimiterMap(Map<String, TokenBucket> userRateLimiterMap) {}

default void preparePayloadAndInvokeRemoteModel(MLInput mlInput, Map<Integer, ModelTensors> tensorOutputs, WrappedCountDownLatch countDownLatch, ActionListener<List<ModelTensors>> actionListener) {
default void preparePayloadAndInvokeRemoteModel(
MLInput mlInput,
Map<Integer, ModelTensors> tensorOutputs,
WrappedCountDownLatch countDownLatch,
ActionListener<List<ModelTensors>> actionListener
) {
Connector connector = getConnector();

Map<String, String> parameters = new HashMap<>();
@@ -142,5 +155,12 @@ && getUserRateLimiterMap().get(user.getName()) != null
}
}

void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, String payload, Map<Integer, ModelTensors> tensorOutputs, WrappedCountDownLatch countDownLatch, ActionListener<List<ModelTensors>> actionListener);
void invokeRemoteModel(
MLInput mlInput,
Map<String, String> parameters,
String payload,
Map<Integer, ModelTensors> tensorOutputs,
WrappedCountDownLatch countDownLatch,
ActionListener<List<ModelTensors>> actionListener
);
}
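
executePredict fans text-embedding input out into per-batch remote calls, each tagged with a sequence number and sharing one CountDownLatch. A standalone sketch of that fan-out with plain java.util.concurrent types; the batch size, embed() stub, and thread pool are illustrative, since ml-commons derives the step size from the connector and uses its async HTTP client instead:

import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// Standalone sketch of the fan-out in executePredict: split the docs into
// fixed-size batches, dispatch one asynchronous call per batch with its
// sequence number, and use a shared CountDownLatch to know when every batch
// has answered.
public final class BatchFanOut {

    static String embed(List<String> docs) {  // stand-in for the remote model call
        return "embeddings(" + String.join(",", docs) + ")";
    }

    public static void main(String[] args) throws InterruptedException {
        List<String> docs = List.of("a", "b", "c", "d", "e");
        int batchSize = 2;
        int batches = (docs.size() + batchSize - 1) / batchSize;

        CountDownLatch latch = new CountDownLatch(batches);
        Map<Integer, String> resultsBySequence = new ConcurrentHashMap<>();
        ExecutorService pool = Executors.newFixedThreadPool(batches);

        for (int sequence = 0; sequence < batches; sequence++) {
            final int seq = sequence;
            final List<String> batch = docs.subList(seq * batchSize, Math.min((seq + 1) * batchSize, docs.size()));
            pool.submit(() -> {
                resultsBySequence.put(seq, embed(batch));  // result keyed by sequence number
                latch.countDown();
            });
        }

        latch.await();                                      // all batches answered
        pool.shutdown();
        // Reassemble in input order, as reOrderTensorResponses does for ModelTensors.
        System.out.println(new TreeMap<>(resultsBySequence).values());
    }
}
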
@@ -5,7 +5,6 @@

package org.opensearch.ml.engine.algorithms.remote;

import java.util.List;
import java.util.Map;

import org.opensearch.client.Client;
@@ -59,7 +58,10 @@ public MLOutput predict(MLInput mlInput, MLModel model) {
@Override
public void predict(MLInput mlInput, MLTask mlTask, ActionListener<MLTaskResponse> actionListener) {
if (!isModelReady()) {
actionListener.onFailure(new IllegalArgumentException("Model not ready yet. Please run this first: POST /_plugins/_ml/models/<model_id>/_deploy"));
actionListener
.onFailure(
new IllegalArgumentException("Model not ready yet. Please run this first: POST /_plugins/_ml/models/<model_id>/_deploy")
);
}
try {
connectorExecutor.executePredict(mlInput, actionListener);
@@ -7,11 +7,11 @@

package org.opensearch.ml.engine.algorithms.remote;

import java.util.concurrent.CountDownLatch;

import lombok.AllArgsConstructor;
import lombok.Data;

import java.util.concurrent.CountDownLatch;

@Data
@AllArgsConstructor
public class WrappedCountDownLatch {
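
WrappedCountDownLatch's body is collapsed in this view, but its call sites (new WrappedCountDownLatch(sequence++, countDownLatch) and getCountDownLatch()) suggest it simply pairs a sequence number with the shared latch. A plain-Java sketch of what the @Data and @AllArgsConstructor annotations would generate, assuming exactly those two fields (the field names are assumptions, not taken from the collapsed file):

import java.util.concurrent.CountDownLatch;

// Hand-written equivalent of what @Data and @AllArgsConstructor generate,
// assuming the two fields implied by the call sites above: a sequence number
// and the shared CountDownLatch.
public class WrappedCountDownLatchSketch {

    private int sequence;
    private CountDownLatch countDownLatch;

    public WrappedCountDownLatchSketch(int sequence, CountDownLatch countDownLatch) {
        this.sequence = sequence;
        this.countDownLatch = countDownLatch;
    }

    public int getSequence() {
        return sequence;
    }

    public CountDownLatch getCountDownLatch() {
        return countDownLatch;
    }

    public void setSequence(int sequence) {
        this.sequence = sequence;
    }

    public void setCountDownLatch(CountDownLatch countDownLatch) {
        this.countDownLatch = countDownLatch;
    }

    // @Data also generates equals, hashCode, and toString; omitted here for brevity.
}
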