diff --git a/ml-algorithms/build.gradle b/ml-algorithms/build.gradle index 8454f1b519..bb900fe99e 100644 --- a/ml-algorithms/build.gradle +++ b/ml-algorithms/build.gradle @@ -63,12 +63,13 @@ dependencies { } } - implementation platform('software.amazon.awssdk:bom:2.21.15') + implementation platform('software.amazon.awssdk:bom:2.25.40') implementation 'software.amazon.awssdk:auth' implementation 'software.amazon.awssdk:apache-client' implementation 'com.amazonaws:aws-encryption-sdk-java:2.4.1' implementation 'com.jayway.jsonpath:json-path:2.9.0' implementation group: 'org.json', name: 'json', version: '20231013' + implementation group: 'software.amazon.awssdk', name: 'netty-nio-client', version: '2.25.40' } lombok { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/Predictable.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/Predictable.java index 38c5889c78..5ee1e4d7a1 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/Predictable.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/Predictable.java @@ -7,9 +7,11 @@ import java.util.Map; +import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.MLOutput; +import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.engine.encryptor.Encryptor; /** @@ -31,7 +33,13 @@ public interface Predictable { * @param mlInput input data * @return predicted results */ - MLOutput predict(MLInput mlInput); + default MLOutput predict(MLInput mlInput) { + throw new IllegalStateException("Method is not implemented"); + } + + default void asyncPredict(MLInput mlInput, ActionListener actionListener) { + actionListener.onFailure(new IllegalStateException("Method is not implemented")); + } /** * Init model (load model into memory) with ML model content and params. diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java index ea53b2c08e..916e486654 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java @@ -5,25 +5,19 @@ package org.opensearch.ml.engine.algorithms.remote; -import static org.opensearch.ml.common.CommonValue.REMOTE_SERVICE_ERROR; import static org.opensearch.ml.common.connector.ConnectorProtocols.AWS_SIGV4; -import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processOutput; import static software.amazon.awssdk.http.SdkHttpMethod.POST; -import java.io.BufferedReader; -import java.io.InputStreamReader; -import java.net.URI; -import java.nio.charset.StandardCharsets; import java.security.AccessController; import java.security.PrivilegedExceptionAction; import java.time.Duration; import java.util.List; import java.util.Map; +import java.util.concurrent.CompletableFuture; -import org.opensearch.OpenSearchStatusException; import org.opensearch.client.Client; import org.opensearch.common.util.TokenBucket; -import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.connector.AwsConnector; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.exception.MLException; @@ -31,20 +25,16 @@ import org.opensearch.ml.common.model.MLGuard; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.engine.annotation.ConnectorExecutor; +import org.opensearch.ml.engine.httpclient.MLHttpClientFactory; import org.opensearch.script.ScriptService; import lombok.Getter; import lombok.Setter; import lombok.extern.log4j.Log4j2; -import software.amazon.awssdk.core.internal.http.loader.DefaultSdkHttpClientBuilder; -import software.amazon.awssdk.core.sync.RequestBody; -import software.amazon.awssdk.http.AbortableInputStream; -import software.amazon.awssdk.http.HttpExecuteRequest; -import software.amazon.awssdk.http.HttpExecuteResponse; -import software.amazon.awssdk.http.SdkHttpClient; -import software.amazon.awssdk.http.SdkHttpConfigurationOption; +import software.amazon.awssdk.core.internal.http.async.SimpleHttpContentPublisher; import software.amazon.awssdk.http.SdkHttpFullRequest; -import software.amazon.awssdk.utils.AttributeMap; +import software.amazon.awssdk.http.async.AsyncExecuteRequest; +import software.amazon.awssdk.http.async.SdkAsyncHttpClient; @Log4j2 @ConnectorExecutor(AWS_SIGV4) @@ -52,7 +42,6 @@ public class AwsConnectorExecutor extends AbstractConnectorExecutor { @Getter private AwsConnector connector; - private SdkHttpClient httpClient; @Setter @Getter private ScriptService scriptService; @@ -69,103 +58,52 @@ public class AwsConnectorExecutor extends AbstractConnectorExecutor { @Getter private MLGuard mlGuard; - public AwsConnectorExecutor(Connector connector, SdkHttpClient httpClient) { - this.connector = (AwsConnector) connector; - this.httpClient = httpClient; - } + private SdkAsyncHttpClient httpClient; public AwsConnectorExecutor(Connector connector) { super.initialize(connector); this.connector = (AwsConnector) connector; - Duration connectionTimeout = Duration.ofMillis(super.getConnectorClientConfig().getConnectionTimeout()); - Duration readTimeout = Duration.ofMillis(super.getConnectorClientConfig().getReadTimeout()); - try ( - AttributeMap attributeMap = AttributeMap - .builder() - .put(SdkHttpConfigurationOption.CONNECTION_TIMEOUT, connectionTimeout) - .put(SdkHttpConfigurationOption.READ_TIMEOUT, readTimeout) - .put(SdkHttpConfigurationOption.MAX_CONNECTIONS, super.getConnectorClientConfig().getMaxConnections()) - .build() - ) { - log - .info( - "Initializing aws connector http client with attributes: connectionTimeout={}, readTimeout={}, maxConnections={}", - connectionTimeout, - readTimeout, - super.getConnectorClientConfig().getMaxConnections() - ); - this.httpClient = new DefaultSdkHttpClientBuilder().buildWithDefaults(attributeMap); - } catch (RuntimeException e) { - log.error("Error initializing AWS connector HTTP client.", e); - throw e; - } catch (Throwable e) { - log.error("Error initializing AWS connector HTTP client.", e); - throw new MLException(e); - } + Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout()); + Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout()); + Integer maxConnection = super.getConnectorClientConfig().getMaxConnections(); + this.httpClient = MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection); } @SuppressWarnings("removal") @Override - public void invokeRemoteModel(MLInput mlInput, Map parameters, String payload, List tensorOutputs) { + public void invokeRemoteModel( + MLInput mlInput, + Map parameters, + String payload, + Map tensorOutputs, + ExecutionContext countDownLatch, + ActionListener> actionListener + ) { try { - String endpoint = connector.getPredictEndpoint(parameters); - RequestBody requestBody = RequestBody.fromString(payload); - - SdkHttpFullRequest.Builder builder = SdkHttpFullRequest - .builder() - .method(POST) - .uri(URI.create(endpoint)) - .contentStreamProvider(requestBody.contentStreamProvider()); - Map headers = connector.getDecryptedHeaders(); - if (headers != null) { - for (String key : headers.keySet()) { - builder.putHeader(key, headers.get(key)); - } - } - SdkHttpFullRequest request = builder.build(); - HttpExecuteRequest executeRequest = HttpExecuteRequest + SdkHttpFullRequest request = ConnectorUtils.buildSdkRequest(connector, parameters, payload, POST); + AsyncExecuteRequest executeRequest = AsyncExecuteRequest .builder() .request(signRequest(request)) - .contentStreamProvider(request.contentStreamProvider().orElse(null)) + .requestContentPublisher(new SimpleHttpContentPublisher(request)) + .responseHandler( + new MLSdkAsyncHttpResponseHandler( + countDownLatch, + actionListener, + parameters, + tensorOutputs, + connector, + scriptService, + mlGuard + ) + ) .build(); - - HttpExecuteResponse response = AccessController - .doPrivileged((PrivilegedExceptionAction) () -> httpClient.prepareRequest(executeRequest).call()); - int statusCode = response.httpResponse().statusCode(); - - AbortableInputStream body = null; - if (response.responseBody().isPresent()) { - body = response.responseBody().get(); - } - - StringBuilder responseBuilder = new StringBuilder(); - if (body != null) { - try (BufferedReader reader = new BufferedReader(new InputStreamReader(body, StandardCharsets.UTF_8))) { - String line; - while ((line = reader.readLine()) != null) { - responseBuilder.append(line); - } - } - } else { - throw new OpenSearchStatusException("No response from model", RestStatus.BAD_REQUEST); - } - String modelResponse = responseBuilder.toString(); - if (getMlGuard() != null && !getMlGuard().validate(modelResponse, MLGuard.Type.OUTPUT)) { - throw new IllegalArgumentException("guardrails triggered for LLM output"); - } - if (statusCode < 200 || statusCode >= 300) { - throw new OpenSearchStatusException(REMOTE_SERVICE_ERROR + modelResponse, RestStatus.fromCode(statusCode)); - } - - ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters); - tensors.setStatusCode(statusCode); - tensorOutputs.add(tensors); + AccessController.doPrivileged((PrivilegedExceptionAction>) () -> httpClient.execute(executeRequest)); } catch (RuntimeException exception) { log.error("Failed to execute predict in aws connector: " + exception.getMessage(), exception); - throw exception; + actionListener.onFailure(exception); } catch (Throwable e) { log.error("Failed to execute predict in aws connector", e); - throw new MLException("Fail to execute predict in aws connector", e); + actionListener.onFailure(new MLException("Fail to execute predict in aws connector", e)); } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index 49e6ef7d69..a6181e1b2f 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -15,6 +15,8 @@ import static org.opensearch.ml.engine.utils.ScriptUtils.executePostProcessFunction; import java.io.IOException; +import java.net.URI; +import java.nio.charset.Charset; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -34,6 +36,7 @@ import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.model.MLGuard; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.script.ScriptService; @@ -46,7 +49,9 @@ import software.amazon.awssdk.auth.credentials.AwsSessionCredentials; import software.amazon.awssdk.auth.signer.Aws4Signer; import software.amazon.awssdk.auth.signer.params.Aws4SignerParams; +import software.amazon.awssdk.core.sync.RequestBody; import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.SdkHttpMethod; import software.amazon.awssdk.regions.Region; @Log4j2 @@ -179,11 +184,15 @@ public static ModelTensors processOutput( String modelResponse, Connector connector, ScriptService scriptService, - Map parameters + Map parameters, + MLGuard mlGuard ) throws IOException { if (modelResponse == null) { throw new IllegalArgumentException("model response is null"); } + if (mlGuard != null && !mlGuard.validate(modelResponse, MLGuard.Type.OUTPUT)) { + throw new IllegalArgumentException("guardrails triggered for LLM output"); + } List modelTensors = new ArrayList<>(); Optional predictAction = connector.findPredictAction(); if (predictAction.isEmpty()) { @@ -252,4 +261,42 @@ public static SdkHttpFullRequest signRequest( return signer.sign(request, params); } + + public static SdkHttpFullRequest buildSdkRequest( + Connector connector, + Map parameters, + String payload, + SdkHttpMethod method + ) { + String charset = parameters.getOrDefault("charset", "UTF-8"); + RequestBody requestBody; + if (payload != null) { + requestBody = RequestBody.fromString(payload, Charset.forName(charset)); + } else { + requestBody = RequestBody.empty(); + } + if (SdkHttpMethod.POST == method && 0 == requestBody.optionalContentLength().get()) { + log.error("Content length is 0. Aborting request to remote model"); + throw new IllegalArgumentException("Content length is 0. Aborting request to remote model"); + } + String endpoint = connector.getPredictEndpoint(parameters); + SdkHttpFullRequest.Builder builder = SdkHttpFullRequest + .builder() + .method(method) + .uri(URI.create(endpoint)) + .contentStreamProvider(requestBody.contentStreamProvider()); + Map headers = connector.getDecryptedHeaders(); + if (headers != null) { + for (String key : headers.keySet()) { + builder.putHeader(key, headers.get(key)); + } + } + if (builder.matchingHeaders("Content-Type").isEmpty()) { + builder.putHeader("Content-Type", "application/json"); + } + if (builder.matchingHeaders("Content-Length").isEmpty()) { + builder.putHeader("Content-Length", requestBody.optionalContentLength().get().toString()); + } + return builder.build(); + } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ExecutionContext.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ExecutionContext.java new file mode 100644 index 0000000000..66c828bead --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ExecutionContext.java @@ -0,0 +1,32 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.ml.engine.algorithms.remote; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; + +import lombok.AllArgsConstructor; +import lombok.Data; + +/** + * This class encapsulates several parameters that are used in a split-batch request case. + * A batch request is that in neural-search side multiple fields are send in one request to ml-commons, + * but the remote model doesn't accept list of string inputs so in ml-commons the request needs split. + * sequence is used to identify the index of the split request. + * countDownLatch is used to wait for all the split requests to finish. + * exceptionHolder is used to hold any exception thrown in a split-batch request. + */ +@Data +@AllArgsConstructor +public class ExecutionContext { + // Should never be null + private int sequence; + private CountDownLatch countDownLatch; + // This is to hold any exception thrown in a split-batch request + private AtomicReference exceptionHolder; +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java index 080e8d7553..b92a8d57d4 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java @@ -5,29 +5,22 @@ package org.opensearch.ml.engine.algorithms.remote; -import static org.opensearch.ml.common.CommonValue.REMOTE_SERVICE_ERROR; import static org.opensearch.ml.common.connector.ConnectorProtocols.HTTP; -import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processOutput; +import static software.amazon.awssdk.http.SdkHttpMethod.GET; +import static software.amazon.awssdk.http.SdkHttpMethod.POST; +import java.net.URL; import java.security.AccessController; import java.security.PrivilegedExceptionAction; +import java.time.Duration; import java.util.List; import java.util.Locale; import java.util.Map; -import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.CompletableFuture; -import org.apache.http.HttpEntity; -import org.apache.http.client.methods.CloseableHttpResponse; -import org.apache.http.client.methods.HttpGet; -import org.apache.http.client.methods.HttpPost; -import org.apache.http.client.methods.HttpUriRequest; -import org.apache.http.entity.StringEntity; -import org.apache.http.impl.client.CloseableHttpClient; -import org.apache.http.util.EntityUtils; -import org.opensearch.OpenSearchStatusException; import org.opensearch.client.Client; import org.opensearch.common.util.TokenBucket; -import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.HttpConnector; import org.opensearch.ml.common.exception.MLException; @@ -41,6 +34,10 @@ import lombok.Getter; import lombok.Setter; import lombok.extern.log4j.Log4j2; +import software.amazon.awssdk.core.internal.http.async.SimpleHttpContentPublisher; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.async.AsyncExecuteRequest; +import software.amazon.awssdk.http.async.SdkAsyncHttpClient; @Log4j2 @ConnectorExecutor(HTTP) @@ -65,98 +62,74 @@ public class HttpJsonConnectorExecutor extends AbstractConnectorExecutor { @Getter private MLGuard mlGuard; - private CloseableHttpClient httpClient; + private SdkAsyncHttpClient httpClient; public HttpJsonConnectorExecutor(Connector connector) { super.initialize(connector); this.connector = (HttpConnector) connector; - this.httpClient = MLHttpClientFactory - .getCloseableHttpClient( - super.getConnectorClientConfig().getConnectionTimeout(), - super.getConnectorClientConfig().getReadTimeout(), - super.getConnectorClientConfig().getMaxConnections() - ); - } - - public HttpJsonConnectorExecutor(Connector connector, CloseableHttpClient httpClient) { - this(connector); - this.httpClient = httpClient; + Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout()); + Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout()); + Integer maxConnection = super.getConnectorClientConfig().getMaxConnections(); + this.httpClient = MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection); } @SuppressWarnings("removal") @Override - public void invokeRemoteModel(MLInput mlInput, Map parameters, String payload, List tensorOutputs) { + public void invokeRemoteModel( + MLInput mlInput, + Map parameters, + String payload, + Map tensorOutputs, + ExecutionContext countDownLatch, + ActionListener> actionListener + ) { try { - AtomicReference responseRef = new AtomicReference<>(""); - AtomicReference statusCodeRef = new AtomicReference<>(); - - HttpUriRequest request; + SdkHttpFullRequest request; switch (connector.getPredictHttpMethod().toUpperCase(Locale.ROOT)) { case "POST": - try { - String predictEndpoint = connector.getPredictEndpoint(parameters); - request = new HttpPost(predictEndpoint); - String charset = parameters.containsKey("charset") ? parameters.get("charset") : "UTF-8"; - HttpEntity entity = new StringEntity(payload, charset); - ((HttpPost) request).setEntity(entity); - } catch (Exception e) { - throw new MLException("Failed to create http request for remote model", e); - } + log.debug("original payload to remote model: " + payload); + validateHttpClientParameters(parameters); + request = ConnectorUtils.buildSdkRequest(connector, parameters, payload, POST); break; case "GET": - try { - request = new HttpGet(connector.getPredictEndpoint(parameters)); - } catch (Exception e) { - throw new MLException("Failed to create http request for remote model", e); - } + validateHttpClientParameters(parameters); + request = ConnectorUtils.buildSdkRequest(connector, parameters, null, GET); break; default: throw new IllegalArgumentException("unsupported http method"); } - - Map headers = connector.getDecryptedHeaders(); - boolean hasContentTypeHeader = false; - if (headers != null) { - for (String key : headers.keySet()) { - request.addHeader(key, (String) headers.get(key)); - if (key.toLowerCase().equals("Content-Type")) { - hasContentTypeHeader = true; - } - } - } - if (!hasContentTypeHeader) { - request.addHeader("Content-Type", "application/json"); - } - - AccessController.doPrivileged((PrivilegedExceptionAction) () -> { - try (CloseableHttpResponse response = httpClient.execute(request)) { - HttpEntity responseEntity = response.getEntity(); - String responseBody = EntityUtils.toString(responseEntity); - EntityUtils.consume(responseEntity); - responseRef.set(responseBody); - statusCodeRef.set(response.getStatusLine().getStatusCode()); - } - return null; - }); - String modelResponse = responseRef.get(); - if (getMlGuard() != null && !getMlGuard().validate(modelResponse, MLGuard.Type.OUTPUT)) { - throw new IllegalArgumentException("guardrails triggered for LLM output"); - } - Integer statusCode = statusCodeRef.get(); - if (statusCode < 200 || statusCode >= 300) { - throw new OpenSearchStatusException(REMOTE_SERVICE_ERROR + modelResponse, RestStatus.fromCode(statusCode)); - } - - ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters); - tensors.setStatusCode(statusCode); - tensorOutputs.add(tensors); + AsyncExecuteRequest executeRequest = AsyncExecuteRequest + .builder() + .request(request) + .requestContentPublisher(new SimpleHttpContentPublisher(request)) + .responseHandler( + new MLSdkAsyncHttpResponseHandler( + countDownLatch, + actionListener, + parameters, + tensorOutputs, + connector, + scriptService, + mlGuard + ) + ) + .build(); + AccessController.doPrivileged((PrivilegedExceptionAction>) () -> httpClient.execute(executeRequest)); } catch (RuntimeException e) { log.error("Fail to execute http connector", e); - throw e; + actionListener.onFailure(e); } catch (Throwable e) { log.error("Fail to execute http connector", e); - throw new MLException("Fail to execute http connector", e); + actionListener.onFailure(new MLException("Fail to execute http connector", e)); } } + private void validateHttpClientParameters(Map parameters) throws Exception { + String endpoint = connector.getPredictEndpoint(parameters); + URL url = new URL(endpoint); + String protocol = url.getProtocol(); + String host = url.getHost(); + int port = url.getPort(); + MLHttpClientFactory.validate(protocol, host, port); + } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java new file mode 100644 index 0000000000..28ed1bc200 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java @@ -0,0 +1,192 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.ml.engine.algorithms.remote; + +import static org.opensearch.ml.common.CommonValue.REMOTE_SERVICE_ERROR; +import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processOutput; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import org.apache.http.HttpStatus; +import org.apache.logging.log4j.util.Strings; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.exception.MLException; +import org.opensearch.ml.common.model.MLGuard; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.script.ScriptService; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +import lombok.Getter; +import lombok.extern.log4j.Log4j2; +import software.amazon.awssdk.http.SdkHttpFullResponse; +import software.amazon.awssdk.http.SdkHttpResponse; +import software.amazon.awssdk.http.async.SdkAsyncHttpResponseHandler; + +@Log4j2 +public class MLSdkAsyncHttpResponseHandler implements SdkAsyncHttpResponseHandler { + @Getter + private Integer statusCode; + @Getter + private final StringBuilder responseBody = new StringBuilder(); + + private final ExecutionContext executionContext; + + private final ActionListener> actionListener; + + private final Map parameters; + + private final Map tensorOutputs; + + private final Connector connector; + + private final ScriptService scriptService; + + private final MLGuard mlGuard; + + public MLSdkAsyncHttpResponseHandler( + ExecutionContext executionContext, + ActionListener> actionListener, + Map parameters, + Map tensorOutputs, + Connector connector, + ScriptService scriptService, + MLGuard mlGuard + ) { + this.executionContext = executionContext; + this.actionListener = actionListener; + this.parameters = parameters; + this.tensorOutputs = tensorOutputs; + this.connector = connector; + this.scriptService = scriptService; + this.mlGuard = mlGuard; + } + + @Override + public void onHeaders(SdkHttpResponse response) { + SdkHttpFullResponse sdkResponse = (SdkHttpFullResponse) response; + log.debug("received response headers: " + sdkResponse.headers()); + this.statusCode = sdkResponse.statusCode(); + } + + @Override + public void onStream(Publisher stream) { + stream.subscribe(new MLResponseSubscriber()); + } + + @Override + public void onError(Throwable error) { + log.error(error.getMessage(), error); + RestStatus status = (statusCode == null) ? RestStatus.INTERNAL_SERVER_ERROR : RestStatus.fromCode(statusCode); + String errorMessage = "Error communicating with remote model: " + error.getMessage(); + actionListener.onFailure(new OpenSearchStatusException(errorMessage, status)); + } + + private void processResponse( + Integer statusCode, + String body, + Map parameters, + Map tensorOutputs + ) { + if (Strings.isBlank(body)) { + log.error("Remote model response body is empty!"); + if (executionContext.getExceptionHolder().get() == null) { + executionContext + .getExceptionHolder() + .compareAndSet(null, new OpenSearchStatusException("No response from model", RestStatus.BAD_REQUEST)); + } + } else { + if (statusCode < HttpStatus.SC_OK || statusCode > HttpStatus.SC_MULTIPLE_CHOICES) { + log.error("Remote server returned error code: {}", statusCode); + if (executionContext.getExceptionHolder().get() == null) { + executionContext + .getExceptionHolder() + .compareAndSet(null, new OpenSearchStatusException(REMOTE_SERVICE_ERROR + body, RestStatus.fromCode(statusCode))); + } + } else { + try { + ModelTensors tensors = processOutput(body, connector, scriptService, parameters, mlGuard); + tensors.setStatusCode(statusCode); + tensorOutputs.put(executionContext.getSequence(), tensors); + } catch (Exception e) { + log.error("Failed to process response body: {}", body, e); + if (executionContext.getExceptionHolder().get() == null) { + executionContext + .getExceptionHolder() + .compareAndSet(null, new MLException("Fail to execute predict in aws connector", e)); + } + } + } + } + } + + // Only all requests successful case will be processed here. + private void reOrderTensorResponses(Map tensorOutputs) { + ModelTensors[] modelTensors = new ModelTensors[tensorOutputs.size()]; + log.debug("Reordered tensor outputs size is {}", tensorOutputs.size()); + for (Map.Entry entry : tensorOutputs.entrySet()) { + modelTensors[entry.getKey()] = entry.getValue(); + } + actionListener.onResponse(Arrays.asList(modelTensors)); + } + + protected class MLResponseSubscriber implements Subscriber { + private Subscription subscription; + + @Override + public void onSubscribe(Subscription s) { + this.subscription = s; + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(ByteBuffer byteBuffer) { + responseBody.append(StandardCharsets.UTF_8.decode(byteBuffer)); + subscription.request(Long.MAX_VALUE); + } + + @Override + public void onError(Throwable t) { + log + .error( + "Error on receiving response body from remote: {}", + t instanceof NullPointerException ? "NullPointerException" : t.getMessage(), + t + ); + response(tensorOutputs); + } + + @Override + public void onComplete() { + response(tensorOutputs); + } + } + + private void response(Map tensors) { + processResponse(statusCode, responseBody.toString(), parameters, tensorOutputs); + executionContext.getCountDownLatch().countDown(); + // when countdown's count equals to 0 means all responses are received. + if (executionContext.getCountDownLatch().getCount() == 0) { + if (executionContext.getExceptionHolder().get() != null) { + actionListener.onFailure(executionContext.getExceptionHolder().get()); + return; + } + reOrderTensorResponses(tensors); + } else { + log.debug("Not all responses received, left response count is: " + executionContext.getCountDownLatch().getCount()); + } + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index 02f4a777d9..c8af4935a7 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -8,17 +8,21 @@ import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.escapeRemoteInferenceInputData; import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processInput; -import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; import org.opensearch.OpenSearchStatusException; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.collect.Tuple; import org.opensearch.common.util.TokenBucket; import org.opensearch.commons.ConfigConstants; import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.FunctionName; @@ -30,52 +34,74 @@ import org.opensearch.ml.common.model.MLGuard; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.script.ScriptService; public interface RemoteConnectorExecutor { - default ModelTensorOutput executePredict(MLInput mlInput) { - List tensorOutputs = new ArrayList<>(); - - if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) { - TextDocsInputDataSet textDocsInputDataSet = (TextDocsInputDataSet) mlInput.getInputDataset(); - int processedDocs = 0; - while (processedDocs < textDocsInputDataSet.getDocs().size()) { - List textDocs = textDocsInputDataSet.getDocs().subList(processedDocs, textDocsInputDataSet.getDocs().size()); - List tempTensorOutputs = new ArrayList<>(); + default void executePredict(MLInput mlInput, ActionListener actionListener) { + ActionListener> tensorActionListener = ActionListener.wrap(r -> { + actionListener.onResponse(new MLTaskResponse(new ModelTensorOutput(r))); + }, actionListener::onFailure); + try { + Map modelTensors = new ConcurrentHashMap<>(); + AtomicReference exceptionHolder = new AtomicReference<>(); + if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) { + TextDocsInputDataSet textDocsInputDataSet = (TextDocsInputDataSet) mlInput.getInputDataset(); + Tuple calculatedChunkSize = calculateChunkSize(textDocsInputDataSet); + CountDownLatch countDownLatch = new CountDownLatch(calculatedChunkSize.v1()); + int sequence = 0; + for (int processedDocs = 0; processedDocs < textDocsInputDataSet.getDocs().size(); processedDocs += calculatedChunkSize + .v2()) { + List textDocs = textDocsInputDataSet.getDocs().subList(processedDocs, textDocsInputDataSet.getDocs().size()); + preparePayloadAndInvokeRemoteModel( + MLInput + .builder() + .algorithm(FunctionName.TEXT_EMBEDDING) + .inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build()) + .build(), + modelTensors, + new ExecutionContext(sequence++, countDownLatch, exceptionHolder), + tensorActionListener + ); + } + } else { preparePayloadAndInvokeRemoteModel( - MLInput - .builder() - .algorithm(FunctionName.TEXT_EMBEDDING) - .inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build()) - .build(), - tempTensorOutputs + mlInput, + modelTensors, + new ExecutionContext(0, new CountDownLatch(1), exceptionHolder), + tensorActionListener ); - int tensorCount = 0; - if (tempTensorOutputs.size() > 0 && tempTensorOutputs.get(0).getMlModelTensors() != null) { - tensorCount = tempTensorOutputs.get(0).getMlModelTensors().size(); - } - // This is to support some model which takes N text docs and embedding size is less than N. - // We need to tell executor what's the step size for each model run. - Map parameters = getConnector().getParameters(); - if (parameters != null && parameters.containsKey("input_docs_processed_step_size")) { - int stepSize = Integer.parseInt(parameters.get("input_docs_processed_step_size")); - // We need to check the parameter on runtime as parameter can be passed into predict request - if (stepSize <= 0) { - throw new IllegalArgumentException( - "Invalid parameter: input_docs_processed_step_size. It must be positive integer." - ); - } - processedDocs += stepSize; - } else { - processedDocs += Math.max(tensorCount, 1); + } + } catch (Exception e) { + actionListener.onFailure(e); + } + } + + /** + * Calculate the chunk size. + * @param textDocsInputDataSet + * @return Tuple of chunk size and step size. + */ + private Tuple calculateChunkSize(TextDocsInputDataSet textDocsInputDataSet) { + int textDocsLength = textDocsInputDataSet.getDocs().size(); + Map parameters = getConnector().getParameters(); + if (parameters != null && parameters.containsKey("input_docs_processed_step_size")) { + int stepSize = Integer.parseInt(parameters.get("input_docs_processed_step_size")); + // We need to check the parameter on runtime as parameter can be passed into predict request + if (stepSize <= 0) { + throw new IllegalArgumentException("Invalid parameter: input_docs_processed_step_size. It must be positive integer."); + } else { + boolean isDivisible = textDocsLength % stepSize == 0; + if (isDivisible) { + return Tuple.tuple(textDocsLength / stepSize, stepSize); } - tensorOutputs.addAll(tempTensorOutputs); + return Tuple.tuple(textDocsLength / stepSize + 1, stepSize); } } else { - preparePayloadAndInvokeRemoteModel(mlInput, tensorOutputs); + // consider as batch. + return Tuple.tuple(1, textDocsLength); } - return new ModelTensorOutput(tensorOutputs); } default void setScriptService(ScriptService scriptService) {} @@ -104,7 +130,12 @@ default void setUserRateLimiterMap(Map userRateLimiterMap) default void setMlGuard(MLGuard mlGuard) {} - default void preparePayloadAndInvokeRemoteModel(MLInput mlInput, List tensorOutputs) { + default void preparePayloadAndInvokeRemoteModel( + MLInput mlInput, + Map tensorOutputs, + ExecutionContext countDownLatch, + ActionListener> actionListener + ) { Connector connector = getConnector(); Map parameters = new HashMap<>(); @@ -145,10 +176,16 @@ && getUserRateLimiterMap().get(user.getName()) != null if (getMlGuard() != null && !getMlGuard().validate(payload, MLGuard.Type.INPUT)) { throw new IllegalArgumentException("guardrails triggered for user input"); } - invokeRemoteModel(mlInput, parameters, payload, tensorOutputs); + invokeRemoteModel(mlInput, parameters, payload, tensorOutputs, countDownLatch, actionListener); } } - void invokeRemoteModel(MLInput mlInput, Map parameters, String payload, List tensorOutputs); - + void invokeRemoteModel( + MLInput mlInput, + Map parameters, + String payload, + Map tensorOutputs, + ExecutionContext countDownLatch, + ActionListener> actionListener + ); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java index 8774bcc40c..5828395641 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java @@ -10,6 +10,7 @@ import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.util.TokenBucket; +import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; @@ -18,6 +19,7 @@ import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.model.MLGuard; import org.opensearch.ml.common.output.MLOutput; +import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.engine.MLEngineClassLoader; import org.opensearch.ml.engine.Predictable; import org.opensearch.ml.engine.annotation.Function; @@ -55,18 +57,22 @@ public MLOutput predict(MLInput mlInput, MLModel model) { } @Override - public MLOutput predict(MLInput mlInput) { + public void asyncPredict(MLInput mlInput, ActionListener actionListener) { if (!isModelReady()) { - throw new IllegalArgumentException("Model not ready yet. Please run this first: POST /_plugins/_ml/models//_deploy"); + actionListener + .onFailure( + new IllegalArgumentException("Model not ready yet. Please run this first: POST /_plugins/_ml/models//_deploy") + ); + return; } try { - return connectorExecutor.executePredict(mlInput); + connectorExecutor.executePredict(mlInput, actionListener); } catch (RuntimeException e) { log.error("Failed to call remote model.", e); - throw e; + actionListener.onFailure(e); } catch (Throwable e) { log.error("Failed to call remote model.", e); - throw new MLException(e); + actionListener.onFailure(new MLException(e)); } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java index c981ebc184..339523b313 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java @@ -8,84 +8,69 @@ import java.net.Inet4Address; import java.net.InetAddress; import java.net.UnknownHostException; +import java.security.AccessController; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; +import java.time.Duration; import java.util.Arrays; - -import org.apache.commons.lang3.math.NumberUtils; -import org.apache.http.HttpHost; -import org.apache.http.HttpRequest; -import org.apache.http.HttpResponse; -import org.apache.http.client.config.RequestConfig; -import org.apache.http.conn.UnsupportedSchemeException; -import org.apache.http.impl.client.CloseableHttpClient; -import org.apache.http.impl.client.HttpClientBuilder; -import org.apache.http.impl.client.LaxRedirectStrategy; -import org.apache.http.impl.conn.DefaultSchemePortResolver; -import org.apache.http.protocol.HttpContext; - -import com.google.common.annotations.VisibleForTesting; +import java.util.Locale; import lombok.extern.log4j.Log4j2; +import software.amazon.awssdk.http.async.SdkAsyncHttpClient; +import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient; @Log4j2 public class MLHttpClientFactory { - public static CloseableHttpClient getCloseableHttpClient(Integer connectionTimeout, Integer readTimeout, Integer maxConnections) { - return createHttpClient(connectionTimeout, readTimeout, maxConnections); - } - - private static CloseableHttpClient createHttpClient(Integer connectionTimeout, Integer readTimeout, Integer maxConnections) { - HttpClientBuilder builder = HttpClientBuilder.create(); - - // Only allow HTTP and HTTPS schemes - builder.setSchemePortResolver(new DefaultSchemePortResolver() { - @Override - public int resolve(HttpHost host) throws UnsupportedSchemeException { - validateSchemaAndPort(host); - return super.resolve(host); - } - }); - - builder.setDnsResolver(MLHttpClientFactory::validateIp); - - builder.setRedirectStrategy(new LaxRedirectStrategy() { - @Override - public boolean isRedirected(HttpRequest request, HttpResponse response, HttpContext context) { - // Do not follow redirects - return false; - } - }); - builder.setMaxConnTotal(maxConnections); - builder.setMaxConnPerRoute(maxConnections); - RequestConfig requestConfig = RequestConfig.custom().setConnectTimeout(connectionTimeout).setSocketTimeout(readTimeout).build(); - builder.setDefaultRequestConfig(requestConfig); - return builder.build(); + public static SdkAsyncHttpClient getAsyncHttpClient(Duration connectionTimeout, Duration readTimeout, int maxConnections) { + try { + return AccessController + .doPrivileged( + (PrivilegedExceptionAction) () -> NettyNioAsyncHttpClient + .builder() + .connectionTimeout(connectionTimeout) + .readTimeout(readTimeout) + .maxConcurrency(maxConnections) + .build() + ); + } catch (PrivilegedActionException e) { + return null; + } } - @VisibleForTesting - protected static void validateSchemaAndPort(HttpHost host) { - String scheme = host.getSchemeName(); - if ("http".equalsIgnoreCase(scheme) || "https".equalsIgnoreCase(scheme)) { - String[] hostNamePort = host.getHostName().split(":"); - if (hostNamePort.length > 1 && NumberUtils.isDigits(hostNamePort[1])) { - int port = Integer.parseInt(hostNamePort[1]); - if (port < 0 || port > 65536) { - log.error("Remote inference port out of range: " + port); - throw new IllegalArgumentException("Port out of range: " + port); - } + /** + * Validate the input parameters, such as protocol, host and port. + * @param protocol The protocol supported in remote inference, currently only http and https are supported. + * @param host The host name of the remote inference server, host must be a valid ip address or domain name and must not be localhost. + * @param port The port number of the remote inference server, port number must be in range [0, 65536]. + * @throws UnknownHostException + */ + public static void validate(String protocol, String host, int port) throws UnknownHostException { + if (protocol != null && !"http".equalsIgnoreCase(protocol) && !"https".equalsIgnoreCase(protocol)) { + log.error("Remote inference protocol is not http or https: " + protocol); + throw new IllegalArgumentException("Protocol is not http or https: " + protocol); + } + // When port is not specified, the default port is -1, and we need to set it to 80 or 443 based on protocol. + if (port == -1) { + if (protocol == null || "http".equals(protocol.toLowerCase(Locale.getDefault()))) { + port = 80; + } else { + port = 443; } - } else { - log.error("Remote inference scheme not supported: " + scheme); - throw new IllegalArgumentException("Unsupported scheme: " + scheme); } + if (port < 0 || port > 65536) { + log.error("Remote inference port out of range: " + port); + throw new IllegalArgumentException("Port out of range: " + port); + } + validateIp(host); } - protected static InetAddress[] validateIp(String hostName) throws UnknownHostException { + private static void validateIp(String hostName) throws UnknownHostException { InetAddress[] addresses = InetAddress.getAllByName(hostName); if (hasPrivateIpAddress(addresses)) { log.error("Remote inference host name has private ip address: " + hostName); - throw new IllegalArgumentException(hostName); + throw new IllegalArgumentException("Remote inference host name has private ip address: " + hostName); } - return addresses; } private static boolean hasPrivateIpAddress(InetAddress[] ipAddress) { diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java index 69df5b03ae..5e1a9dfacb 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java @@ -5,58 +5,48 @@ package org.opensearch.ml.engine.algorithms.remote; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; +import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.when; import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD; import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD; import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD; import static org.opensearch.ml.common.connector.HttpConnector.SERVICE_NAME_FIELD; -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.io.InputStream; +import java.lang.reflect.Field; import java.util.Arrays; import java.util.Map; -import java.util.Optional; -import org.junit.Assert; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; -import org.opensearch.OpenSearchStatusException; import org.opensearch.client.Client; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.connector.AwsConnector; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.ConnectorAction; -import org.opensearch.ml.common.connector.ConnectorClientConfig; import org.opensearch.ml.common.connector.MLPreProcessFunction; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; -import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; -import org.opensearch.script.ScriptService; import org.opensearch.threadpool.ThreadPool; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import software.amazon.awssdk.http.AbortableInputStream; -import software.amazon.awssdk.http.ExecutableHttpRequest; -import software.amazon.awssdk.http.HttpExecuteResponse; -import software.amazon.awssdk.http.SdkHttpClient; -import software.amazon.awssdk.http.SdkHttpResponse; - public class AwsConnectorExecutorTest { @Rule @@ -73,16 +63,7 @@ public class AwsConnectorExecutorTest { ThreadContext threadContext; @Mock - ScriptService scriptService; - - @Mock - SdkHttpClient httpClient; - - @Mock - ExecutableHttpRequest httpRequest; - - @Mock - HttpExecuteResponse response; + ActionListener actionListener; Encryptor encryptor; @@ -100,7 +81,7 @@ public void executePredict_RemoteInferenceInput_MissingCredential() { .builder() .actionType(ConnectorAction.ActionType.PREDICT) .method("POST") - .url("http://test.com/mock") + .url("http://openai.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") .build(); AwsConnector @@ -113,21 +94,12 @@ public void executePredict_RemoteInferenceInput_MissingCredential() { } @Test - public void executePredict_RemoteInferenceInput_NullResponse() throws IOException { - exceptionRule.expect(OpenSearchStatusException.class); - exceptionRule.expectMessage("No response from model"); - when(response.responseBody()).thenReturn(Optional.empty()); - when(httpRequest.call()).thenReturn(response); - SdkHttpResponse httpResponse = mock(SdkHttpResponse.class); - when(httpResponse.statusCode()).thenReturn(200); - when(response.httpResponse()).thenReturn(httpResponse); - when(httpClient.prepareRequest(any())).thenReturn(httpRequest); - + public void executePredict_RemoteInferenceInput_EmptyIpAddress() { ConnectorAction predictAction = ConnectorAction .builder() .actionType(ConnectorAction.ActionType.PREDICT) .method("POST") - .url("http://test.com/mock") + .url("http:///mock") .requestBody("{\"input\": \"${parameters.input}\"}") .build(); Map credential = ImmutableMap @@ -143,7 +115,7 @@ public void executePredict_RemoteInferenceInput_NullResponse() throws IOExceptio .actions(Arrays.asList(predictAction)) .build(); connector.decrypt((c) -> encryptor.decrypt(c)); - AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient)); + AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); when(executor.getClient()).thenReturn(client); @@ -151,29 +123,22 @@ public void executePredict_RemoteInferenceInput_NullResponse() throws IOExceptio when(threadPool.getThreadContext()).thenReturn(threadContext); MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build(); - executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); + executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(), actionListener); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + Mockito.verify(actionListener, times(1)).onFailure(exceptionCaptor.capture()); + assert exceptionCaptor.getValue() instanceof NullPointerException; + assertEquals("host must not be null.", exceptionCaptor.getValue().getMessage()); } @Test - public void executePredict_RemoteInferenceInput_InvalidToken() throws IOException { - exceptionRule.expect(OpenSearchStatusException.class); - exceptionRule.expectMessage("{\"message\":\"The security token included in the request is invalid\"}"); - String jsonString = "{\"message\":\"The security token included in the request is invalid\"}"; - InputStream inputStream = new ByteArrayInputStream(jsonString.getBytes()); - AbortableInputStream abortableInputStream = AbortableInputStream.create(inputStream); - when(response.responseBody()).thenReturn(Optional.of(abortableInputStream)); - when(httpRequest.call()).thenReturn(response); - SdkHttpResponse httpResponse = mock(SdkHttpResponse.class); - when(httpResponse.statusCode()).thenReturn(403); - when(response.httpResponse()).thenReturn(httpResponse); - when(httpClient.prepareRequest(any())).thenReturn(httpRequest); - + public void executePredict_TextDocsInferenceInput() { ConnectorAction predictAction = ConnectorAction .builder() .actionType(ConnectorAction.ActionType.PREDICT) .method("POST") - .url("http://test.com/mock") - .requestBody("{\"input\": \"${parameters.input}\"}") + .url("http://openai.com/mock") + .requestBody("{\"input\": ${parameters.input}}") + .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT) .build(); Map credential = ImmutableMap .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); @@ -188,39 +153,32 @@ public void executePredict_RemoteInferenceInput_InvalidToken() throws IOExceptio .actions(Arrays.asList(predictAction)) .build(); connector.decrypt((c) -> encryptor.decrypt(c)); - AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient)); + AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); when(executor.getClient()).thenReturn(client); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); - MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build(); - executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); + MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input")).build(); + executor + .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener); } @Test - public void executePredict_RemoteInferenceInput() throws IOException { - String jsonString = "{\"key\":\"value\"}"; - InputStream inputStream = new ByteArrayInputStream(jsonString.getBytes()); - AbortableInputStream abortableInputStream = AbortableInputStream.create(inputStream); - when(response.responseBody()).thenReturn(Optional.of(abortableInputStream)); - SdkHttpResponse httpResponse = mock(SdkHttpResponse.class); - when(httpResponse.statusCode()).thenReturn(200); - when(response.httpResponse()).thenReturn(httpResponse); - when(httpRequest.call()).thenReturn(response); - when(httpClient.prepareRequest(any())).thenReturn(httpRequest); - + public void executePredict_TextDocsInferenceInput_withStepSize() { ConnectorAction predictAction = ConnectorAction .builder() .actionType(ConnectorAction.ActionType.PREDICT) .method("POST") - .url("http://test.com/mock") - .requestBody("{\"input\": \"${parameters.input}\"}") + .url("http://openai.com/mock") + .requestBody("{\"input\": ${parameters.input}}") + .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT) .build(); Map credential = ImmutableMap .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); - Map parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker"); + Map parameters = ImmutableMap + .of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker", "input_docs_processed_step_size", "2"); Connector connector = AwsConnector .awsConnectorBuilder() .name("test connector") @@ -231,42 +189,30 @@ public void executePredict_RemoteInferenceInput() throws IOException { .actions(Arrays.asList(predictAction)) .build(); connector.decrypt((c) -> encryptor.decrypt(c)); - AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient)); + AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); when(executor.getClient()).thenReturn(client); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); - MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build(); - ModelTensorOutput modelTensorOutput = executor - .executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); - Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); - Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size()); - Assert.assertEquals("response", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName()); - Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size()); - Assert.assertEquals("value", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("key")); + MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build(); + executor + .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener); + + MLInputDataset inputDataSet1 = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2")).build(); + executor + .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet1).build(), actionListener); } @Test - public void executePredict_TextDocsInferenceInput() throws IOException { - String jsonString = "{\"key\":\"value\"}"; - InputStream inputStream = new ByteArrayInputStream(jsonString.getBytes()); - AbortableInputStream abortableInputStream = AbortableInputStream.create(inputStream); - when(response.responseBody()).thenReturn(Optional.of(abortableInputStream)); - when(httpRequest.call()).thenReturn(response); - SdkHttpResponse httpResponse = mock(SdkHttpResponse.class); - when(httpResponse.statusCode()).thenReturn(200); - when(response.httpResponse()).thenReturn(httpResponse); - when(httpClient.prepareRequest(any())).thenReturn(httpRequest); - + public void executePredict_RemoteInferenceInput_nullHttpClient_throwNPException() throws NoSuchFieldException, IllegalAccessException { ConnectorAction predictAction = ConnectorAction .builder() .actionType(ConnectorAction.ActionType.PREDICT) .method("POST") - .url("http://test.com/mock") - .requestBody("{\"input\": ${parameters.input}}") - .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT) + .url("http://openai.com/mock") + .requestBody("{\"input\": \"${parameters.input}\"}") .build(); Map credential = ImmutableMap .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); @@ -281,36 +227,37 @@ public void executePredict_TextDocsInferenceInput() throws IOException { .actions(Arrays.asList(predictAction)) .build(); connector.decrypt((c) -> encryptor.decrypt(c)); - AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient)); + AwsConnectorExecutor executor0 = new AwsConnectorExecutor(connector); + Field httpClientField = AwsConnectorExecutor.class.getDeclaredField("httpClient"); + httpClientField.setAccessible(true); + httpClientField.set(executor0, null); + AwsConnectorExecutor executor = spy(executor0); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); when(executor.getClient()).thenReturn(client); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); - MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input")).build(); - ModelTensorOutput modelTensorOutput = executor - .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build()); - Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); - Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size()); - Assert.assertEquals("response", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName()); - Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size()); - Assert.assertEquals("value", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("key")); + MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build(); + executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(), actionListener); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + Mockito.verify(actionListener, times(1)).onFailure(exceptionCaptor.capture()); + assert exceptionCaptor.getValue() instanceof NullPointerException; } @Test - public void test_initialize() { + public void executePredict_RemoteInferenceInput_negativeStepSize_throwIllegalArgumentException() { ConnectorAction predictAction = ConnectorAction .builder() .actionType(ConnectorAction.ActionType.PREDICT) .method("POST") - .url("http://test.com/mock") + .url("http://openai.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") .build(); Map credential = ImmutableMap .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); - Map parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker"); - ConnectorClientConfig httpClientConfig = new ConnectorClientConfig(20, 30000, 30000); + Map parameters = ImmutableMap + .of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker", "input_docs_processed_step_size", "-1"); Connector connector = AwsConnector .awsConnectorBuilder() .name("test connector") @@ -319,13 +266,20 @@ public void test_initialize() { .parameters(parameters) .credential(credential) .actions(Arrays.asList(predictAction)) - .connectorClientConfig(httpClientConfig) .build(); connector.decrypt((c) -> encryptor.decrypt(c)); - AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient)); - Assert.assertEquals(20, executor.getConnector().getConnectorClientConfig().getMaxConnections().intValue()); - Assert.assertEquals(30000, executor.getConnector().getConnectorClientConfig().getConnectionTimeout().intValue()); - Assert.assertEquals(30000, executor.getConnector().getConnectorClientConfig().getReadTimeout().intValue()); + AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(executor.getClient()).thenReturn(client); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build(); + executor + .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + Mockito.verify(actionListener, times(1)).onFailure(exceptionCaptor.capture()); + assert exceptionCaptor.getValue() instanceof IllegalArgumentException; } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java index 178d9aa722..cb7f8a4fe8 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java @@ -168,7 +168,7 @@ public void processInput_TextDocsInputDataSet_PreprocessFunction_MultiTextDoc() public void processOutput_NullResponse() throws IOException { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("model response is null"); - ConnectorUtils.processOutput(null, null, null, null); + ConnectorUtils.processOutput(null, null, null, null, null); } @Test @@ -192,7 +192,7 @@ public void processOutput_NoPostprocessFunction_jsonResponse() throws IOExceptio .build(); String modelResponse = "{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"index\":0,\"embedding\":[-0.014555434,-0.0002135904,0.0035105038]}],\"model\":\"text-embedding-ada-002-v2\",\"usage\":{\"prompt_tokens\":5,\"total_tokens\":5}}"; - ModelTensors tensors = ConnectorUtils.processOutput(modelResponse, connector, scriptService, ImmutableMap.of()); + ModelTensors tensors = ConnectorUtils.processOutput(modelResponse, connector, scriptService, ImmutableMap.of(), null); Assert.assertEquals(1, tensors.getMlModelTensors().size()); Assert.assertEquals("response", tensors.getMlModelTensors().get(0).getName()); Assert.assertEquals(4, tensors.getMlModelTensors().get(0).getDataAsMap().size()); @@ -224,7 +224,7 @@ public void processOutput_PostprocessFunction() throws IOException { .build(); String modelResponse = "{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"index\":0,\"embedding\":[-0.014555434,-0.0002135904,0.0035105038]}],\"model\":\"text-embedding-ada-002-v2\",\"usage\":{\"prompt_tokens\":5,\"total_tokens\":5}}"; - ModelTensors tensors = ConnectorUtils.processOutput(modelResponse, connector, scriptService, ImmutableMap.of()); + ModelTensors tensors = ConnectorUtils.processOutput(modelResponse, connector, scriptService, ImmutableMap.of(), null); Assert.assertEquals(1, tensors.getMlModelTensors().size()); Assert.assertEquals("sentence_embedding", tensors.getMlModelTensors().get(0).getName()); Assert.assertNull(tensors.getMlModelTensors().get(0).getDataAsMap()); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java index 122402fea3..f4cd6715cf 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java @@ -5,46 +5,34 @@ package org.opensearch.ml.engine.algorithms.remote; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.when; +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; -import java.io.IOException; +import java.lang.reflect.Field; import java.util.Arrays; -import java.util.Map; +import java.util.HashMap; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; -import org.apache.http.HttpEntity; -import org.apache.http.ProtocolVersion; -import org.apache.http.StatusLine; -import org.apache.http.client.methods.CloseableHttpResponse; -import org.apache.http.entity.StringEntity; -import org.apache.http.impl.client.CloseableHttpClient; -import org.apache.http.message.BasicStatusLine; -import org.junit.Assert; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; -import org.opensearch.OpenSearchStatusException; -import org.opensearch.client.Client; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.ingest.TestTemplateService; +import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.connector.HttpConnector; -import org.opensearch.ml.common.connector.MLPostProcessFunction; -import org.opensearch.ml.common.connector.MLPreProcessFunction; import org.opensearch.ml.common.dataset.MLInputDataset; -import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; -import org.opensearch.ml.common.output.model.ModelTensorOutput; -import org.opensearch.script.ScriptService; -import org.opensearch.threadpool.ThreadPool; +import org.opensearch.ml.common.output.model.ModelTensors; import com.google.common.collect.ImmutableMap; @@ -53,23 +41,7 @@ public class HttpJsonConnectorExecutorTest { public ExpectedException exceptionRule = ExpectedException.none(); @Mock - ThreadPool threadPool; - - @Mock - ScriptService scriptService; - - @Mock - CloseableHttpClient httpClient; - - @Mock - Client client; - - @Mock - CloseableHttpResponse response; - - Settings settings; - - ThreadContext threadContext; + private ActionListener> actionListener; @Before public void setUp() { @@ -78,13 +50,11 @@ public void setUp() { @Test public void invokeRemoteModel_WrongHttpMethod() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("unsupported http method"); ConnectorAction predictAction = ConnectorAction .builder() .actionType(ConnectorAction.ActionType.PREDICT) .method("wrong_method") - .url("http://test.com/mock") + .url("http://openai.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") .build(); Connector connector = HttpConnector @@ -94,17 +64,20 @@ public void invokeRemoteModel_WrongHttpMethod() { .protocol("http") .actions(Arrays.asList(predictAction)) .build(); - HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector, httpClient); - executor.invokeRemoteModel(null, null, null, null); + HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); + executor.invokeRemoteModel(null, null, null, null, null, actionListener); + ArgumentCaptor captor = ArgumentCaptor.forClass(IllegalArgumentException.class); + Mockito.verify(actionListener, times(1)).onFailure(captor.capture()); + assertEquals("unsupported http method", captor.getValue().getMessage()); } @Test - public void executePredict_RemoteInferenceInput() throws IOException { + public void invokeRemoteModel_invalidIpAddress() { ConnectorAction predictAction = ConnectorAction .builder() .actionType(ConnectorAction.ActionType.PREDICT) .method("POST") - .url("http://test.com/mock") + .url("http://127.0.0.1/mock") .requestBody("{\"input\": \"${parameters.input}\"}") .build(); Connector connector = HttpConnector @@ -114,44 +87,31 @@ public void executePredict_RemoteInferenceInput() throws IOException { .protocol("http") .actions(Arrays.asList(predictAction)) .build(); - HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector, httpClient)); - Settings settings = Settings.builder().build(); - threadContext = new ThreadContext(settings); - when(executor.getClient()).thenReturn(client); - when(client.threadPool()).thenReturn(threadPool); - when(threadPool.getThreadContext()).thenReturn(threadContext); - when(httpClient.execute(any())).thenReturn(response); - HttpEntity entity = new StringEntity("{\"response\": \"test result\"}"); - when(response.getEntity()).thenReturn(entity); - StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK"); - when(response.getStatusLine()).thenReturn(statusLine); - MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build(); - ModelTensorOutput modelTensorOutput = executor - .executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); - Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); - Assert.assertEquals("response", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName()); - Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size()); - Assert - .assertEquals( - "test result", - modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("response") + HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); + executor + .invokeRemoteModel( + createMLInput(), + new HashMap<>(), + "{\"input\": \"hello world\"}", + new HashMap<>(), + new ExecutionContext(0, new CountDownLatch(1), new AtomicReference<>()), + actionListener ); + ArgumentCaptor captor = ArgumentCaptor.forClass(IllegalArgumentException.class); + Mockito.verify(actionListener, times(1)).onFailure(captor.capture()); + assert captor.getValue() instanceof IllegalArgumentException; + assertEquals("Remote inference host name has private ip address: 127.0.0.1", captor.getValue().getMessage()); } @Test - public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOException { + public void invokeRemoteModel_Empty_payload() { ConnectorAction predictAction = ConnectorAction .builder() .actionType(ConnectorAction.ActionType.PREDICT) .method("POST") - .url("http://test.com/mock") - .requestBody("{\"input\": ${parameters.input}}") + .url("http://openai.com/mock") + .requestBody("") .build(); - when(httpClient.execute(any())).thenReturn(response); - HttpEntity entity = new StringEntity("{\"response\": \"test result\"}"); - when(response.getEntity()).thenReturn(entity); - StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK"); - when(response.getStatusLine()).thenReturn(statusLine); Connector connector = HttpConnector .builder() .name("test connector") @@ -159,41 +119,31 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOExcepti .protocol("http") .actions(Arrays.asList(predictAction)) .build(); - HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector, httpClient)); - Settings settings = Settings.builder().build(); - threadContext = new ThreadContext(settings); - when(executor.getClient()).thenReturn(client); - when(client.threadPool()).thenReturn(threadPool); - when(threadPool.getThreadContext()).thenReturn(threadContext); - MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); - ModelTensorOutput modelTensorOutput = executor - .executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); - Assert.assertEquals(2, modelTensorOutput.getMlModelOutputs().size()); - Assert.assertEquals("response", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName()); - Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size()); - Assert - .assertEquals( - "test result", - modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("response") + HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); + executor + .invokeRemoteModel( + createMLInput(), + new HashMap<>(), + null, + new HashMap<>(), + new ExecutionContext(0, new CountDownLatch(1), new AtomicReference<>()), + actionListener ); + ArgumentCaptor captor = ArgumentCaptor.forClass(IllegalArgumentException.class); + Mockito.verify(actionListener, times(1)).onFailure(captor.capture()); + assert captor.getValue() instanceof IllegalArgumentException; + assertEquals("Content length is 0. Aborting request to remote model", captor.getValue().getMessage()); } @Test - public void executePredict_TextDocsInput_LimitExceed() throws IOException { - exceptionRule.expect(OpenSearchStatusException.class); - exceptionRule.expectMessage("{\"message\": \"Too many requests\"}"); + public void invokeRemoteModel_get_request() { ConnectorAction predictAction = ConnectorAction .builder() .actionType(ConnectorAction.ActionType.PREDICT) - .method("POST") - .url("http://test.com/mock") - .requestBody("{\"input\": ${parameters.input}}") + .method("GET") + .url("http://openai.com/mock") + .requestBody("") .build(); - when(httpClient.execute(any())).thenReturn(response); - HttpEntity entity = new StringEntity("{\"message\": \"Too many requests\"}"); - when(response.getEntity()).thenReturn(entity); - StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 429, "OK"); - when(response.getStatusLine()).thenReturn(statusLine); Connector connector = HttpConnector .builder() .name("test connector") @@ -201,233 +151,82 @@ public void executePredict_TextDocsInput_LimitExceed() throws IOException { .protocol("http") .actions(Arrays.asList(predictAction)) .build(); - HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector, httpClient)); - Settings settings = Settings.builder().build(); - threadContext = new ThreadContext(settings); - when(executor.getClient()).thenReturn(client); - when(client.threadPool()).thenReturn(threadPool); - when(threadPool.getThreadContext()).thenReturn(threadContext); - MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); - executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); + HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); + executor + .invokeRemoteModel( + createMLInput(), + new HashMap<>(), + null, + new HashMap<>(), + new ExecutionContext(0, new CountDownLatch(1), new AtomicReference<>()), + actionListener + ); } @Test - public void executePredict_TextDocsInput() throws IOException { - String preprocessResult1 = "{\"parameters\": { \"input\": \"test doc1\" } }"; - String preprocessResult2 = "{\"parameters\": { \"input\": \"test doc2\" } }"; - when(scriptService.compile(any(), any())) - .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult1)) - .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult2)); - + public void invokeRemoteModel_post_request() { ConnectorAction predictAction = ConnectorAction .builder() .actionType(ConnectorAction.ActionType.PREDICT) .method("POST") - .url("http://test.com/mock") - .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT) - .postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING) - .requestBody("{\"input\": ${parameters.input}}") + .url("http://openai.com/mock") + .requestBody("hello world") .build(); - HttpConnector connector = HttpConnector + Connector connector = HttpConnector .builder() .name("test connector") .version("1") .protocol("http") .actions(Arrays.asList(predictAction)) .build(); - HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector, httpClient)); - Settings settings = Settings.builder().build(); - threadContext = new ThreadContext(settings); - when(executor.getClient()).thenReturn(client); - when(client.threadPool()).thenReturn(threadPool); - when(threadPool.getThreadContext()).thenReturn(threadContext); - executor.setScriptService(scriptService); - when(httpClient.execute(any())).thenReturn(response); - String modelResponse = "{\n" - + " \"object\": \"list\",\n" - + " \"data\": [\n" - + " {\n" - + " \"object\": \"embedding\",\n" - + " \"index\": 0,\n" - + " \"embedding\": [\n" - + " -0.014555434,\n" - + " -0.002135904,\n" - + " 0.0035105038\n" - + " ]\n" - + " },\n" - + " {\n" - + " \"object\": \"embedding\",\n" - + " \"index\": 1,\n" - + " \"embedding\": [\n" - + " -0.014555434,\n" - + " -0.002135904,\n" - + " 0.0035105038\n" - + " ]\n" - + " }\n" - + " ],\n" - + " \"model\": \"text-embedding-ada-002-v2\",\n" - + " \"usage\": {\n" - + " \"prompt_tokens\": 5,\n" - + " \"total_tokens\": 5\n" - + " }\n" - + "}"; - StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK"); - when(response.getStatusLine()).thenReturn(statusLine); - HttpEntity entity = new StringEntity(modelResponse); - when(response.getEntity()).thenReturn(entity); - when(executor.getConnector()).thenReturn(connector); - MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); - ModelTensorOutput modelTensorOutput = executor - .executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); - Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); - Assert.assertEquals(2, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size()); - Assert.assertEquals("sentence_embedding", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName()); - Assert - .assertArrayEquals( - new Number[] { -0.014555434, -0.002135904, 0.0035105038 }, - modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData() - ); - Assert - .assertArrayEquals( - new Number[] { -0.014555434, -0.002135904, 0.0035105038 }, - modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1).getData() + HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); + executor + .invokeRemoteModel( + createMLInput(), + new HashMap<>(), + "hello world", + new HashMap<>(), + new ExecutionContext(0, new CountDownLatch(1), new AtomicReference<>()), + actionListener ); } @Test - public void executePredict_TextDocsInput_LessEmbeddingThanInputDocs() throws IOException { - String preprocessResult1 = "{\"parameters\": { \"input\": \"test doc1\" } }"; - String preprocessResult2 = "{\"parameters\": { \"input\": \"test doc2\" } }"; - when(scriptService.compile(any(), any())) - .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult1)) - .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult2)); - + public void invokeRemoteModel_nullHttpClient_throwMLException() throws NoSuchFieldException, IllegalAccessException { ConnectorAction predictAction = ConnectorAction .builder() .actionType(ConnectorAction.ActionType.PREDICT) .method("POST") - .url("http://test.com/mock") - .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT) - .postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING) - .requestBody("{\"input\": ${parameters.input}}") + .url("http://openai.com/mock") + .requestBody("hello world") .build(); - Map parameters = ImmutableMap.of("input_docs_processed_step_size", "2"); - HttpConnector connector = HttpConnector + Connector connector = HttpConnector .builder() .name("test connector") .version("1") .protocol("http") - .parameters(parameters) .actions(Arrays.asList(predictAction)) .build(); - HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector, httpClient)); - Settings settings = Settings.builder().build(); - threadContext = new ThreadContext(settings); - when(executor.getClient()).thenReturn(client); - when(client.threadPool()).thenReturn(threadPool); - when(threadPool.getThreadContext()).thenReturn(threadContext); - executor.setScriptService(scriptService); - when(httpClient.execute(any())).thenReturn(response); - // model takes 2 input docs, but only output 1 embedding - String modelResponse = "{\n" - + " \"object\": \"list\",\n" - + " \"data\": [\n" - + " {\n" - + " \"object\": \"embedding\",\n" - + " \"index\": 0,\n" - + " \"embedding\": [\n" - + " -0.014555434,\n" - + " -0.002135904,\n" - + " 0.0035105038\n" - + " ]\n" - + " } ],\n" - + " \"model\": \"text-embedding-ada-002-v2\",\n" - + " \"usage\": {\n" - + " \"prompt_tokens\": 5,\n" - + " \"total_tokens\": 5\n" - + " }\n" - + "}"; - StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK"); - when(response.getStatusLine()).thenReturn(statusLine); - HttpEntity entity = new StringEntity(modelResponse); - when(response.getEntity()).thenReturn(entity); - when(executor.getConnector()).thenReturn(connector); - MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); - ModelTensorOutput modelTensorOutput = executor - .executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); - Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); - Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size()); - Assert.assertEquals("sentence_embedding", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName()); - Assert - .assertArrayEquals( - new Number[] { -0.014555434, -0.002135904, 0.0035105038 }, - modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData() + HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); + Field httpClientField = HttpJsonConnectorExecutor.class.getDeclaredField("httpClient"); + httpClientField.setAccessible(true); + httpClientField.set(executor, null); + executor + .invokeRemoteModel( + createMLInput(), + new HashMap<>(), + "hello world", + new HashMap<>(), + new ExecutionContext(0, new CountDownLatch(1), new AtomicReference<>()), + actionListener ); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener, times(1)).onFailure(argumentCaptor.capture()); + assert argumentCaptor.getValue() instanceof NullPointerException; } - @Test - public void executePredict_TextDocsInput_LessEmbeddingThanInputDocs_InvalidStepSize() throws IOException { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Invalid parameter: input_docs_processed_step_size. It must be positive integer."); - String preprocessResult1 = "{\"parameters\": { \"input\": \"test doc1\" } }"; - String preprocessResult2 = "{\"parameters\": { \"input\": \"test doc2\" } }"; - when(scriptService.compile(any(), any())) - .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult1)) - .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult2)); - - ConnectorAction predictAction = ConnectorAction - .builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method("POST") - .url("http://test.com/mock") - .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT) - .postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING) - .requestBody("{\"input\": ${parameters.input}}") - .build(); - // step size must be positive integer, here we set it as -1, should trigger IllegalArgumentException - Map parameters = ImmutableMap.of("input_docs_processed_step_size", "-1"); - HttpConnector connector = HttpConnector - .builder() - .name("test connector") - .version("1") - .protocol("http") - .parameters(parameters) - .actions(Arrays.asList(predictAction)) - .build(); - HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector, httpClient)); - Settings settings = Settings.builder().build(); - threadContext = new ThreadContext(settings); - when(executor.getClient()).thenReturn(client); - when(client.threadPool()).thenReturn(threadPool); - when(threadPool.getThreadContext()).thenReturn(threadContext); - executor.setScriptService(scriptService); - when(httpClient.execute(any())).thenReturn(response); - // model takes 2 input docs, but only output 1 embedding - String modelResponse = "{\n" - + " \"object\": \"list\",\n" - + " \"data\": [\n" - + " {\n" - + " \"object\": \"embedding\",\n" - + " \"index\": 0,\n" - + " \"embedding\": [\n" - + " -0.014555434,\n" - + " -0.002135904,\n" - + " 0.0035105038\n" - + " ]\n" - + " } ],\n" - + " \"model\": \"text-embedding-ada-002-v2\",\n" - + " \"usage\": {\n" - + " \"prompt_tokens\": 5,\n" - + " \"total_tokens\": 5\n" - + " }\n" - + "}"; - StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK"); - when(response.getStatusLine()).thenReturn(statusLine); - HttpEntity entity = new StringEntity(modelResponse); - when(response.getEntity()).thenReturn(entity); - when(executor.getConnector()).thenReturn(connector); - MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); - ModelTensorOutput modelTensorOutput = executor - .executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); + private MLInput createMLInput() { + MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build(); + return MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.REMOTE).build(); } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java new file mode 100644 index 0000000000..68b5cdeb5f --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java @@ -0,0 +1,422 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.ml.engine.algorithms.remote; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.connector.HttpConnector; +import org.opensearch.ml.common.connector.MLPostProcessFunction; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.script.ScriptService; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; + +import software.amazon.awssdk.http.HttpStatusCode; +import software.amazon.awssdk.http.SdkHttpFullResponse; +import software.amazon.awssdk.http.SdkHttpResponse; + +public class MLSdkAsyncHttpResponseHandlerTest { + private final ExecutionContext executionContext = new ExecutionContext(0, new CountDownLatch(1), new AtomicReference<>()); + @Mock + private ActionListener> actionListener; + @Mock + private Map parameters; + private Map tensorOutputs = new ConcurrentHashMap<>(); + private Connector connector; + + private Connector noProcessFunctionConnector; + + @Mock + private SdkHttpFullResponse sdkHttpResponse; + @Mock + private ScriptService scriptService; + + private MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler; + + MLSdkAsyncHttpResponseHandler.MLResponseSubscriber responseSubscriber; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + when(sdkHttpResponse.statusCode()).thenReturn(HttpStatusCode.OK); + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .postProcessFunction(MLPostProcessFunction.BEDROCK_EMBEDDING) + .url("http://test.com/mock") + .requestBody("{\"input\": \"${parameters.input}\"}") + .build(); + connector = HttpConnector + .builder() + .name("test connector") + .version("1") + .protocol("http") + .actions(Arrays.asList(predictAction)) + .build(); + + ConnectorAction noProcessFunctionPredictAction = ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://test.com/mock") + .requestBody("{\"input\": \"${parameters.input}\"}") + .build(); + noProcessFunctionConnector = HttpConnector + .builder() + .name("test connector") + .version("1") + .protocol("http") + .actions(Arrays.asList(noProcessFunctionPredictAction)) + .build(); + mlSdkAsyncHttpResponseHandler = new MLSdkAsyncHttpResponseHandler( + executionContext, + actionListener, + parameters, + tensorOutputs, + connector, + scriptService, + null + ); + responseSubscriber = mlSdkAsyncHttpResponseHandler.new MLResponseSubscriber(); + } + + @Test + public void test_OnHeaders() { + mlSdkAsyncHttpResponseHandler.onHeaders(sdkHttpResponse); + assert mlSdkAsyncHttpResponseHandler.getStatusCode() == 200; + } + + @Test + public void test_OnStream_with_postProcessFunction_bedRock() { + String response = "{\n" + + " \"embedding\": [\n" + + " 0.46484375,\n" + + " -0.017822266,\n" + + " 0.17382812,\n" + + " 0.10595703,\n" + + " 0.875,\n" + + " 0.19140625,\n" + + " -0.36914062,\n" + + " -0.0011978149\n" + + " ]\n" + + "}"; + Publisher stream = s -> { + try { + s.onSubscribe(mock(Subscription.class)); + s.onNext(ByteBuffer.wrap(response.getBytes())); + s.onComplete(); + } catch (Throwable e) { + s.onError(e); + } + }; + test_OnHeaders(); // set the status code to non-null + mlSdkAsyncHttpResponseHandler.onStream(stream); + ArgumentCaptor> captor = ArgumentCaptor.forClass(List.class); + verify(actionListener).onResponse(captor.capture()); + assert captor.getValue().size() == 1; + assert captor.getValue().get(0).getMlModelTensors().get(0).getData().length == 8; + } + + @Test + public void test_OnStream_without_postProcessFunction() { + Publisher stream = s -> { + try { + s.onSubscribe(mock(Subscription.class)); + s.onNext(ByteBuffer.wrap("{\"key\": \"hello world\"}".getBytes())); + s.onComplete(); + } catch (Throwable e) { + s.onError(e); + } + }; + MLSdkAsyncHttpResponseHandler noProcessFunctionMlSdkAsyncHttpResponseHandler = new MLSdkAsyncHttpResponseHandler( + executionContext, + actionListener, + parameters, + tensorOutputs, + noProcessFunctionConnector, + scriptService, + null + ); + noProcessFunctionMlSdkAsyncHttpResponseHandler.onHeaders(sdkHttpResponse); + noProcessFunctionMlSdkAsyncHttpResponseHandler.onStream(stream); + ArgumentCaptor> captor = ArgumentCaptor.forClass(List.class); + verify(actionListener).onResponse(captor.capture()); + assert captor.getValue().size() == 1; + assert captor.getValue().get(0).getMlModelTensors().get(0).getDataAsMap().get("key").equals("hello world"); + } + + @Test + public void test_onError() { + test_OnHeaders(); // set the status code to non-null + mlSdkAsyncHttpResponseHandler.onError(new RuntimeException("runtime exception")); + ArgumentCaptor captor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(captor.capture()); + assert captor.getValue() instanceof OpenSearchStatusException; + assert captor.getValue().getMessage().equals("Error communicating with remote model: runtime exception"); + } + + @Test + public void test_MLSdkAsyncHttpResponseHandler_onError() { + mlSdkAsyncHttpResponseHandler.onError(new Exception("error")); + } + + @Test + public void test_onSubscribe() { + Subscription subscription = mock(Subscription.class); + responseSubscriber.onSubscribe(subscription); + } + + @Test + public void test_onNext() { + test_onSubscribe();// set the subscription to non-null. + responseSubscriber.onNext(ByteBuffer.wrap("hello world".getBytes())); + assert mlSdkAsyncHttpResponseHandler.getResponseBody().toString().equals("hello world"); + } + + @Test + public void test_MLResponseSubscriber_onError() { + SdkHttpFullResponse response = mock(SdkHttpFullResponse.class); + when(response.statusCode()).thenReturn(500); + mlSdkAsyncHttpResponseHandler.onHeaders(response); + responseSubscriber.onError(new Exception("error")); + ArgumentCaptor captor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener, times(1)).onFailure(captor.capture()); + assert captor.getValue() instanceof OpenSearchStatusException; + assert captor.getValue().getMessage().equals("No response from model"); + } + + @Test + public void test_onComplete_success() { + String response = "{\n" + + " \"embedding\": [\n" + + " 0.46484375,\n" + + " -0.017822266,\n" + + " 0.17382812,\n" + + " 0.10595703,\n" + + " 0.875,\n" + + " 0.19140625,\n" + + " -0.36914062,\n" + + " -0.0011978149\n" + + " ]\n" + + "}"; + mlSdkAsyncHttpResponseHandler.onHeaders(sdkHttpResponse); + Publisher stream = s -> { + try { + s.onSubscribe(mock(Subscription.class)); + s.onNext(ByteBuffer.wrap(response.getBytes())); + s.onComplete(); + } catch (Throwable e) { + s.onError(e); + } + }; + mlSdkAsyncHttpResponseHandler.onStream(stream); + ArgumentCaptor> captor = ArgumentCaptor.forClass(List.class); + verify(actionListener).onResponse(captor.capture()); + assert captor.getValue().size() == 1; + assert captor.getValue().get(0).getMlModelTensors().get(0).getData().length == 8; + } + + @Test + public void test_onComplete_partial_success_exceptionSecond() { + AtomicReference exceptionHolder = new AtomicReference<>(); + String response1 = "{\n" + + " \"embedding\": [\n" + + " 0.46484375,\n" + + " -0.017822266,\n" + + " 0.17382812,\n" + + " 0.10595703,\n" + + " 0.875,\n" + + " 0.19140625,\n" + + " -0.36914062,\n" + + " -0.0011978149\n" + + " ]\n" + + "}"; + String response2 = "Model current status is: FAILED"; + CountDownLatch count = new CountDownLatch(2); + MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler1 = new MLSdkAsyncHttpResponseHandler( + new ExecutionContext(0, count, exceptionHolder), + actionListener, + parameters, + tensorOutputs, + connector, + scriptService, + null + ); + MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler2 = new MLSdkAsyncHttpResponseHandler( + new ExecutionContext(1, count, exceptionHolder), + actionListener, + parameters, + tensorOutputs, + connector, + scriptService, + null + ); + SdkHttpFullResponse sdkHttpResponse1 = mock(SdkHttpFullResponse.class); + when(sdkHttpResponse1.statusCode()).thenReturn(200); + mlSdkAsyncHttpResponseHandler1.onHeaders(sdkHttpResponse1); + Publisher stream1 = s -> { + try { + s.onSubscribe(mock(Subscription.class)); + s.onNext(ByteBuffer.wrap(response1.getBytes())); + s.onComplete(); + } catch (Throwable e) { + s.onError(e); + } + }; + mlSdkAsyncHttpResponseHandler1.onStream(stream1); + + SdkHttpFullResponse sdkHttpResponse2 = mock(SdkHttpFullResponse.class); + when(sdkHttpResponse2.statusCode()).thenReturn(500); + mlSdkAsyncHttpResponseHandler2.onHeaders(sdkHttpResponse2); + Publisher stream2 = s -> { + try { + s.onSubscribe(mock(Subscription.class)); + s.onNext(ByteBuffer.wrap(response2.getBytes())); + s.onComplete(); + } catch (Throwable e) { + s.onError(e); + } + }; + mlSdkAsyncHttpResponseHandler2.onStream(stream2); + ArgumentCaptor captor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener, times(1)).onFailure(captor.capture()); + assert captor.getValue().getMessage().equals("Error from remote service: Model current status is: FAILED"); + assert captor.getValue().status().getStatus() == 500; + } + + @Test + public void test_onComplete_partial_success_exceptionFirst() { + AtomicReference exceptionHolder = new AtomicReference<>(); + String response1 = "{\n" + + " \"embedding\": [\n" + + " 0.46484375,\n" + + " -0.017822266,\n" + + " 0.17382812,\n" + + " 0.10595703,\n" + + " 0.875,\n" + + " 0.19140625,\n" + + " -0.36914062,\n" + + " -0.0011978149\n" + + " ]\n" + + "}"; + String response2 = "Model current status is: FAILED"; + CountDownLatch count = new CountDownLatch(2); + MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler1 = new MLSdkAsyncHttpResponseHandler( + new ExecutionContext(0, count, exceptionHolder), + actionListener, + parameters, + tensorOutputs, + connector, + scriptService, + null + ); + MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler2 = new MLSdkAsyncHttpResponseHandler( + new ExecutionContext(1, count, exceptionHolder), + actionListener, + parameters, + tensorOutputs, + connector, + scriptService, + null + ); + + SdkHttpFullResponse sdkHttpResponse2 = mock(SdkHttpFullResponse.class); + when(sdkHttpResponse2.statusCode()).thenReturn(500); + mlSdkAsyncHttpResponseHandler2.onHeaders(sdkHttpResponse2); + Publisher stream2 = s -> { + try { + s.onSubscribe(mock(Subscription.class)); + s.onNext(ByteBuffer.wrap(response2.getBytes())); + s.onComplete(); + } catch (Throwable e) { + s.onError(e); + } + }; + mlSdkAsyncHttpResponseHandler2.onStream(stream2); + + SdkHttpFullResponse sdkHttpResponse1 = mock(SdkHttpFullResponse.class); + when(sdkHttpResponse1.statusCode()).thenReturn(200); + mlSdkAsyncHttpResponseHandler1.onHeaders(sdkHttpResponse1); + Publisher stream1 = s -> { + try { + s.onSubscribe(mock(Subscription.class)); + s.onNext(ByteBuffer.wrap(response1.getBytes())); + s.onComplete(); + } catch (Throwable e) { + s.onError(e); + } + }; + mlSdkAsyncHttpResponseHandler1.onStream(stream1); + ArgumentCaptor captor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener, times(1)).onFailure(captor.capture()); + assert captor.getValue().getMessage().equals("Error from remote service: Model current status is: FAILED"); + assert captor.getValue().status().getStatus() == 500; + } + + @Test + public void test_onComplete_empty_response_body() { + mlSdkAsyncHttpResponseHandler.onHeaders(sdkHttpResponse); + Publisher stream = s -> { + try { + s.onSubscribe(mock(Subscription.class)); + s.onNext(ByteBuffer.wrap("".getBytes())); + s.onComplete(); + } catch (Throwable e) { + s.onError(e); + } + }; + mlSdkAsyncHttpResponseHandler.onStream(stream); + ArgumentCaptor captor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener, times(1)).onFailure(captor.capture()); + assert captor.getValue().getMessage().equals("No response from model"); + } + + @Test + public void test_onComplete_error_http_status() { + String error = "{\"message\": \"runtime error\"}"; + SdkHttpResponse response = mock(SdkHttpFullResponse.class); + when(response.statusCode()).thenReturn(HttpStatusCode.INTERNAL_SERVER_ERROR); + mlSdkAsyncHttpResponseHandler.onHeaders(response); + Publisher stream = s -> { + try { + s.onSubscribe(mock(Subscription.class)); + s.onNext(ByteBuffer.wrap(error.getBytes())); + s.onComplete(); + } catch (Throwable e) { + s.onError(e); + } + }; + mlSdkAsyncHttpResponseHandler.onStream(stream); + ArgumentCaptor captor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener, times(1)).onFailure(captor.capture()); + assert captor.getValue() instanceof OpenSearchStatusException; + System.out.println(captor.getValue().getMessage()); + assert captor.getValue().getMessage().contains("runtime error"); + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java index 149848327b..c14b329586 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java @@ -5,9 +5,12 @@ package org.opensearch.ml.engine.algorithms.remote; +import static org.junit.Assert.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import java.util.Arrays; @@ -18,14 +21,17 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.connector.ConnectorProtocols; import org.opensearch.ml.common.connector.HttpConnector; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; @@ -60,20 +66,36 @@ public void predict_ModelNotDeployed() { } @Test - public void predict_NullConnectorExecutor() { - exceptionRule.expect(RuntimeException.class); - exceptionRule.expectMessage("Model not ready yet"); + public void test_predict_throw_IllegalStateException() { + exceptionRule.expect(IllegalStateException.class); + exceptionRule.expectMessage("Method is not implemented"); remoteModel.predict(mlInput); } + @Test + public void predict_NullConnectorExecutor() { + ActionListener actionListener = mock(ActionListener.class); + remoteModel.asyncPredict(mlInput, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assert argumentCaptor.getValue() instanceof RuntimeException; + assertEquals( + "Model not ready yet. Please run this first: POST /_plugins/_ml/models//_deploy", + argumentCaptor.getValue().getMessage() + ); + } + @Test public void predict_ModelDeployed_WrongInput() { - exceptionRule.expect(RuntimeException.class); - exceptionRule.expectMessage("pre_process_function not defined in connector"); Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}")); when(mlModel.getConnector()).thenReturn(connector); remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor); - remoteModel.predict(mlInput); + ActionListener actionListener = mock(ActionListener.class); + remoteModel.asyncPredict(mlInput, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assert argumentCaptor.getValue() instanceof RuntimeException; + assertEquals("pre_process_function not defined in connector", argumentCaptor.getValue().getMessage()); } @Test @@ -105,8 +127,8 @@ public void initModel_WithHeader() { Assert.assertNotNull(executor); Assert.assertNull(decryptedHeaders); Assert.assertNotNull(executor.getConnector().getDecryptedHeaders()); - Assert.assertEquals(1, executor.getConnector().getDecryptedHeaders().size()); - Assert.assertEquals("Bearer test_api_key", executor.getConnector().getDecryptedHeaders().get("Authorization")); + assertEquals(1, executor.getConnector().getDecryptedHeaders().size()); + assertEquals("Bearer test_api_key", executor.getConnector().getDecryptedHeaders().get("Authorization")); remoteModel.close(); Assert.assertNull(remoteModel.getConnectorExecutor()); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java index 179d11d2a1..138faf65e1 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java @@ -8,6 +8,8 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.FrameworkType.HUGGINGFACE_TRANSFORMERS; import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS; import static org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingDenseModel.ML_ENGINE; @@ -30,7 +32,9 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; import org.opensearch.ResourceNotFoundException; +import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; @@ -627,6 +631,16 @@ public void predict_BeforeInitingModel() { textEmbeddingDenseModel.predict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), model); } + @Test + public void test_async_inference() { + ArgumentCaptor captor = ArgumentCaptor.forClass(IllegalStateException.class); + ActionListener actionListener = mock(ActionListener.class); + textEmbeddingDenseModel + .asyncPredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener); + verify(actionListener).onFailure(captor.capture()); + assert captor.getValue().getMessage().equals("Method is not implemented"); + } + @After public void tearDown() { FileUtils.deleteFileQuietly(mlCachePath); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java index fa22837249..d1d9b42dcc 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java @@ -7,98 +7,91 @@ import static org.junit.Assert.assertNotNull; -import java.net.UnknownHostException; +import java.time.Duration; -import org.apache.http.HttpHost; -import org.apache.http.impl.client.CloseableHttpClient; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import software.amazon.awssdk.http.async.SdkAsyncHttpClient; + public class MLHttpClientFactoryTests { @Rule public ExpectedException expectedException = ExpectedException.none(); @Test - public void test_getCloseableHttpClient_success() { - CloseableHttpClient client = MLHttpClientFactory.getCloseableHttpClient(1000, 1000, 30); + public void test_getSdkAsyncHttpClient_success() { + SdkAsyncHttpClient client = MLHttpClientFactory.getAsyncHttpClient(Duration.ofSeconds(100), Duration.ofSeconds(100), 100); assertNotNull(client); } @Test - public void test_validateIp_validIp_noException() throws UnknownHostException { - MLHttpClientFactory.validateIp("api.openai.com"); - } - - @Test - public void test_validateIp_invalidIp_throwException() throws UnknownHostException { - expectedException.expect(UnknownHostException.class); - MLHttpClientFactory.validateIp("www.makesureitisaunknownhost.com"); + public void test_validateIp_validIp_noException() throws Exception { + MLHttpClientFactory.validate("http", "api.openai.com", 80); } @Test - public void test_validateIp_privateIp_throwException() throws UnknownHostException { - expectedException.expect(IllegalArgumentException.class); - MLHttpClientFactory.validateIp("localhost"); - } - - @Test - public void test_validateIp_rarePrivateIp_throwException() throws UnknownHostException { + public void test_validateIp_rarePrivateIp_throwException() throws Exception { try { - MLHttpClientFactory.validateIp("0177.1"); + MLHttpClientFactory.validate("http", "0254.020.00.01", 80); } catch (IllegalArgumentException e) { assertNotNull(e); } try { - MLHttpClientFactory.validateIp("172.1048577"); + MLHttpClientFactory.validate("http", "172.1048577", 80); + } catch (Exception e) { + assertNotNull(e); + } + + try { + MLHttpClientFactory.validate("http", "2886729729", 80); } catch (IllegalArgumentException e) { assertNotNull(e); } try { - MLHttpClientFactory.validateIp("2886729729"); + MLHttpClientFactory.validate("http", "192.11010049", 80); } catch (IllegalArgumentException e) { assertNotNull(e); } try { - MLHttpClientFactory.validateIp("192.11010049"); + MLHttpClientFactory.validate("http", "3232300545", 80); } catch (IllegalArgumentException e) { assertNotNull(e); } try { - MLHttpClientFactory.validateIp("3232300545"); + MLHttpClientFactory.validate("http", "0:0:0:0:0:ffff:127.0.0.1", 80); } catch (IllegalArgumentException e) { assertNotNull(e); } try { - MLHttpClientFactory.validateIp("0:0:0:0:0:ffff:127.0.0.1"); + MLHttpClientFactory.validate("http", "153.24.76.232", 80); } catch (IllegalArgumentException e) { assertNotNull(e); } } @Test - public void test_validateSchemaAndPort_success() { - HttpHost httpHost = new HttpHost("api.openai.com", 8080, "https"); - MLHttpClientFactory.validateSchemaAndPort(httpHost); + public void test_validateSchemaAndPort_success() throws Exception { + MLHttpClientFactory.validate("http", "api.openai.com", 80); } @Test - public void test_validateSchemaAndPort_notAllowedSchema_throwException() { + public void test_validateSchemaAndPort_notAllowedSchema_throwException() throws Exception { expectedException.expect(IllegalArgumentException.class); - HttpHost httpHost = new HttpHost("api.openai.com", 8080, "ftp"); - MLHttpClientFactory.validateSchemaAndPort(httpHost); + MLHttpClientFactory.validate("ftp", "api.openai.com", 80); } @Test - public void test_validateSchemaAndPort_portNotInRange_throwException() { + public void test_validateSchemaAndPort_portNotInRange_throwException() throws Exception { expectedException.expect(IllegalArgumentException.class); - HttpHost httpHost = new HttpHost("api.openai.com:65537", -1, "https"); - MLHttpClientFactory.validateSchemaAndPort(httpHost); + expectedException.expectMessage("Port out of range: 65537"); + MLHttpClientFactory.validate("https", "api.openai.com", 65537); } + } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 4d1f935e83..cb94668308 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -1891,6 +1891,12 @@ public T trackPredictDuration(String modelId, Supplier supplier) { return t; } + public void trackPredictDuration(String modelId, long startTime) { + long end = System.nanoTime(); + double durationInMs = (end - startTime) / 1e6; + modelCacheHelper.addModelInferenceDuration(modelId, durationInMs); + } + public FunctionName getModelFunctionName(String modelId) { return modelCacheHelper.getFunctionName(modelId); } diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index a19ffc4af3..16176f197e 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -321,17 +321,26 @@ private void runPredict( if (!predictor.isModelReady()) { throw new IllegalArgumentException("Model not ready: " + modelId); } - MLOutput output = mlModelManager.trackPredictDuration(modelId, () -> predictor.predict(mlInput)); - if (output instanceof MLPredictionOutput) { - ((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name()); + if (mlInput.getAlgorithm() == FunctionName.REMOTE) { + long startTime = System.nanoTime(); + ActionListener trackPredictDurationListener = ActionListener.wrap(output -> { + handleAsyncMLTaskComplete(mlTask); + mlModelManager.trackPredictDuration(modelId, startTime); + internalListener.onResponse(output); + }, e -> handlePredictFailure(mlTask, internalListener, e, false, modelId)); + predictor.asyncPredict(mlInput, trackPredictDurationListener); + } else { + MLOutput output = mlModelManager.trackPredictDuration(modelId, () -> predictor.predict(mlInput)); + if (output instanceof MLPredictionOutput) { + ((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name()); + } + // Once prediction complete, reduce ML_EXECUTING_TASK_COUNT and update task state + handleAsyncMLTaskComplete(mlTask); + internalListener.onResponse(new MLTaskResponse(output)); } - - // Once prediction complete, reduce ML_EXECUTING_TASK_COUNT and update task state - handleAsyncMLTaskComplete(mlTask); - MLTaskResponse response = MLTaskResponse.builder().output(output).build(); - internalListener.onResponse(response); return; } catch (Exception e) { + log.error("Failed to predict model " + modelId, e); handlePredictFailure(mlTask, internalListener, e, false, modelId); return; } diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index 569aaee3c9..d42fa9ca65 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -61,6 +61,7 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; +import java.util.function.Supplier; import org.junit.Before; import org.junit.Ignore; @@ -1140,6 +1141,35 @@ public void testRegisterModelMeta_FailedToInitIndexIfPresent() { verify(actionListener).onFailure(argumentCaptor.capture()); } + public void test_trackPredictDuration_sync() { + Supplier mockResult = () -> { + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + return "test"; + }; + String modelId = "test_model"; + modelManager.trackPredictDuration(modelId, mockResult); + ArgumentCaptor modelIdCaptor = ArgumentCaptor.forClass(String.class); + ArgumentCaptor durationCaptor = ArgumentCaptor.forClass(Double.class); + verify(modelCacheHelper).addModelInferenceDuration(modelIdCaptor.capture(), durationCaptor.capture()); + assert modelIdCaptor.getValue().equals(modelId); + assert durationCaptor.getValue() > 0; + } + + public void test_trackPredictDuration_async() { + String modelId = "test_model"; + long startTime = System.nanoTime(); + modelManager.trackPredictDuration(modelId, startTime); + ArgumentCaptor modelIdCaptor = ArgumentCaptor.forClass(String.class); + ArgumentCaptor durationCaptor = ArgumentCaptor.forClass(Double.class); + verify(modelCacheHelper).addModelInferenceDuration(modelIdCaptor.capture(), durationCaptor.capture()); + assert modelIdCaptor.getValue().equals(modelId); + assert durationCaptor.getValue() > 0; + } + private void setupForModelMeta() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java index c4a9f4d61a..0f90528fab 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java @@ -15,6 +15,7 @@ import java.nio.file.Path; import java.util.Arrays; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; @@ -51,12 +52,16 @@ import org.opensearch.ml.common.dataset.DataFrameInputDataset; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.SearchQueryInputDataset; +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.input.parameter.rcf.BatchRCFParams; +import org.opensearch.ml.common.output.MLPredictionOutput; +import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.engine.MLEngine; +import org.opensearch.ml.engine.Predictable; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; import org.opensearch.ml.engine.indices.MLInputDatasetHandler; @@ -306,6 +311,72 @@ public void testExecuteTask_OnLocalNode_NullModelIdException() { assertEquals("ModelId is invalid", argumentCaptor.getValue().getMessage()); } + public void testExecuteTask_OnLocalNode_remoteModel_success() { + setupMocks(true, false, false, false); + TextDocsInputDataSet textDocsInputDataSet = new TextDocsInputDataSet(List.of("hello", "world"), null); + MLPredictionTaskRequest textDocsInputRequest = MLPredictionTaskRequest + .builder() + .modelId("test_model") + .mlInput(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(textDocsInputDataSet).build()) + .build(); + Predictable predictor = mock(Predictable.class); + when(predictor.isModelReady()).thenReturn(true); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener + .onResponse(MLTaskResponse.builder().output(ModelTensorOutput.builder().mlModelOutputs(List.of()).build()).build()); + return null; + }).when(predictor).asyncPredict(any(), any()); + when(mlModelManager.getPredictor(anyString())).thenReturn(predictor); + when(mlModelManager.getWorkerNodes(anyString(), eq(FunctionName.REMOTE), eq(true))).thenReturn(new String[] { "node1" }); + taskRunner.dispatchTask(FunctionName.REMOTE, textDocsInputRequest, transportService, listener); + verify(client, never()).get(any(), any()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLTaskResponse.class); + verify(listener).onResponse(argumentCaptor.capture()); + assert argumentCaptor.getValue().getOutput() instanceof ModelTensorOutput; + } + + public void testExecuteTask_OnLocalNode_localModel_success() { + setupMocks(true, false, false, false); + TextDocsInputDataSet textDocsInputDataSet = new TextDocsInputDataSet(List.of("hello", "world"), null); + MLPredictionTaskRequest textDocsInputRequest = MLPredictionTaskRequest + .builder() + .modelId("test_model") + .mlInput(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build()) + .build(); + Predictable predictor = mock(Predictable.class); + when(predictor.isModelReady()).thenReturn(true); + when(mlModelManager.getPredictor(anyString())).thenReturn(predictor); + when(mlModelManager.getWorkerNodes(anyString(), eq(FunctionName.TEXT_EMBEDDING), eq(true))).thenReturn(new String[] { "node1" }); + when(mlModelManager.trackPredictDuration(anyString(), any())).thenReturn(mock(MLPredictionOutput.class)); + taskRunner.dispatchTask(FunctionName.TEXT_EMBEDDING, textDocsInputRequest, transportService, listener); + verify(client, never()).get(any(), any()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLTaskResponse.class); + verify(listener).onResponse(argumentCaptor.capture()); + assert argumentCaptor.getValue().getOutput() instanceof MLPredictionOutput; + } + + public void testExecuteTask_OnLocalNode_prediction_exception() { + setupMocks(true, false, false, false); + TextDocsInputDataSet textDocsInputDataSet = new TextDocsInputDataSet(List.of("hello", "world"), null); + MLPredictionTaskRequest textDocsInputRequest = MLPredictionTaskRequest + .builder() + .modelId("test_model") + .mlInput(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build()) + .build(); + Predictable predictable = mock(Predictable.class); + when(mlModelManager.getPredictor(anyString())).thenReturn(predictable); + when(predictable.isModelReady()).thenThrow(new RuntimeException("runtime exception")); + when(mlModelManager.getWorkerNodes(anyString(), eq(FunctionName.TEXT_EMBEDDING), eq(true))).thenReturn(new String[] { "node1" }); + when(mlModelManager.trackPredictDuration(anyString(), any())).thenReturn(mock(MLPredictionOutput.class)); + taskRunner.dispatchTask(FunctionName.TEXT_EMBEDDING, textDocsInputRequest, transportService, listener); + verify(client, never()).get(any(), any()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(argumentCaptor.capture()); + assert argumentCaptor.getValue() instanceof RuntimeException; + assertEquals("runtime exception", argumentCaptor.getValue().getMessage()); + } + public void testExecuteTask_OnLocalNode_NullGetResponse() { setupMocks(true, false, false, true);