From d8fd07c6e182208102a76aea0c8aa769149bf18f Mon Sep 17 00:00:00 2001
From: zane-neo <zaniu@amazon.com>
Date: Tue, 30 Apr 2024 12:57:21 +0800
Subject: [PATCH] Change httpclient to async (#1958)

* Change httpclient from sync to async

Signed-off-by: zane-neo <zaniu@amazon.com>

* Change from CRTAsyncHttpClient to NettyAsyncHttpClient

Signed-off-by: zane-neo <zaniu@amazon.com>

* Add publisher to request

Signed-off-by: zane-neo <zaniu@amazon.com>

* Change sync httpclient to async

Signed-off-by: zane-neo <zaniu@amazon.com>

* Handle error case and return error response in actionListener

Signed-off-by: zane-neo <zaniu@amazon.com>

* Fix no response when exception

Signed-off-by: zane-neo <zaniu@amazon.com>

* Add content type header

Signed-off-by: zane-neo <zaniu@amazon.com>

* Fix issues found in functional test

Signed-off-by: zane-neo <zaniu@amazon.com>

* Fix no response issue in functional test

Signed-off-by: zane-neo <zaniu@amazon.com>

* fix default step size error

Signed-off-by: zane-neo <zaniu@amazon.com>

* Add track inference duration for async httpclient

Signed-off-by: zane-neo <zaniu@amazon.com>

* Change client appsec highlight issues implementation for async httpclient

Signed-off-by: zane-neo <zaniu@amazon.com>

* Add UTs

Signed-off-by: zane-neo <zaniu@amazon.com>

* Add UTs

Signed-off-by: zane-neo <zaniu@amazon.com>

* Remove unused file

Signed-off-by: zane-neo <zaniu@amazon.com>

* Add UTs

Signed-off-by: zane-neo <zaniu@amazon.com>

* format code

Signed-off-by: zane-neo <zaniu@amazon.com>

* Change error code to honor remote service error code

Signed-off-by: zane-neo <zaniu@amazon.com>

* Add more UTs

Signed-off-by: zane-neo <zaniu@amazon.com>

* Change SSRF code to make it correct for return error status

Signed-off-by: zane-neo <zaniu@amazon.com>

* Fix failure UTs and add more UTs

Signed-off-by: zane-neo <zaniu@amazon.com>

* Fix failure ITs

Signed-off-by: zane-neo <zaniu@amazon.com>

* format code

Signed-off-by: zane-neo <zaniu@amazon.com>

* Fix partial success response not correct issue

Signed-off-by: zane-neo <zaniu@amazon.com>

* format code

Signed-off-by: zane-neo <zaniu@amazon.com>

* Fix failure ITs

Signed-off-by: zane-neo <zaniu@amazon.com>

* Add more UTs to increase code coverage

Signed-off-by: zane-neo <zaniu@amazon.com>

* Change url regex

Signed-off-by: zane-neo <zaniu@amazon.com>

* Address comments

Signed-off-by: zane-neo <zaniu@amazon.com>

* format code

Signed-off-by: zane-neo <zaniu@amazon.com>

* Fix failure UTs

Signed-off-by: zane-neo <zaniu@amazon.com>

* Add UT for httpclientFactory throw exception when creating httpclient

Signed-off-by: zane-neo <zaniu@amazon.com>

* format code

Signed-off-by: zane-neo <zaniu@amazon.com>

* Address comments and add modelTensor status code

Signed-off-by: zane-neo <zaniu@amazon.com>

* Address comments

Signed-off-by: zane-neo <zaniu@amazon.com>

* format code

Signed-off-by: zane-neo <zaniu@amazon.com>

* Add status code to process error response

Signed-off-by: zane-neo <zaniu@amazon.com>

* format code

Signed-off-by: zane-neo <zaniu@amazon.com>

* Rebase main after connector level http parameter support

Signed-off-by: zane-neo <zaniu@amazon.com>

* Fix UT

Signed-off-by: zane-neo <zaniu@amazon.com>

* Change error message when remote model return empty and change the behavior when one of the requests fails

Signed-off-by: zane-neo <zaniu@amazon.com>

* Add comments

Signed-off-by: zane-neo <zaniu@amazon.com>

* Remove redundant builder and change the error code check

Signed-off-by: zane-neo <zaniu@amazon.com>

* format code

Signed-off-by: zane-neo <zaniu@amazon.com>

* Add more UTs for throw exception cases

Signed-off-by: zane-neo <zaniu@amazon.com>

* fix failure UTs

Signed-off-by: zane-neo <zaniu@amazon.com>

* format code

Signed-off-by: zane-neo <zaniu@amazon.com>

* Fix test cases since the error message change

Signed-off-by: zane-neo <zaniu@amazon.com>

* Rebase code

Signed-off-by: zane-neo <zaniu@amazon.com>

* fix failure IT

Signed-off-by: zane-neo <zaniu@amazon.com>

* Add more UTs

Signed-off-by: zane-neo <zaniu@amazon.com>

* Fix duplicate response to client issue

Signed-off-by: zane-neo <zaniu@amazon.com>

* fix duplicate response in channel

Signed-off-by: zane-neo <zaniu@amazon.com>

* change code for all successful responses case

Signed-off-by: zane-neo <zaniu@amazon.com>

* Address comments

Signed-off-by: zane-neo <zaniu@amazon.com>

* format code

Signed-off-by: zane-neo <zaniu@amazon.com>

* Increase nio httpclient version to fix vulnerability

Signed-off-by: zane-neo <zaniu@amazon.com>

* Change validate localhost logic to same with existing code

Signed-off-by: zane-neo <zaniu@amazon.com>

* change method signature to private

Signed-off-by: zane-neo <zaniu@amazon.com>

* format code

Signed-off-by: zane-neo <zaniu@amazon.com>

---------

Signed-off-by: zane-neo <zaniu@amazon.com>
---
 ml-algorithms/build.gradle                    |   3 +-
 .../org/opensearch/ml/engine/Predictable.java |  10 +-
 .../remote/AwsConnectorExecutor.java          | 134 ++----
 .../algorithms/remote/ConnectorUtils.java     |  49 +-
 .../algorithms/remote/ExecutionContext.java   |  32 ++
 .../remote/HttpJsonConnectorExecutor.java     | 139 +++---
 .../remote/MLSdkAsyncHttpResponseHandler.java | 192 ++++++++
 .../remote/RemoteConnectorExecutor.java       | 119 +++--
 .../engine/algorithms/remote/RemoteModel.java |  16 +-
 .../httpclient/MLHttpClientFactory.java       | 107 ++---
 .../remote/AwsConnectorExecutorTest.java      | 182 +++-----
 .../algorithms/remote/ConnectorUtilsTest.java |   6 +-
 .../remote/HttpJsonConnectorExecutorTest.java | 399 ++++-------------
 .../MLSdkAsyncHttpResponseHandlerTest.java    | 422 ++++++++++++++++++
 .../algorithms/remote/RemoteModelTest.java    |  38 +-
 .../TextEmbeddingDenseModelTest.java          |  14 +
 .../httpclient/MLHttpClientFactoryTests.java  |  63 ++-
 .../opensearch/ml/model/MLModelManager.java   |   6 +
 .../ml/task/MLPredictTaskRunner.java          |  25 +-
 .../ml/model/MLModelManagerTests.java         |  30 ++
 .../ml/task/MLPredictTaskRunnerTests.java     |  71 +++
 21 files changed, 1298 insertions(+), 759 deletions(-)
 create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ExecutionContext.java
 create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java
 create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java

diff --git a/ml-algorithms/build.gradle b/ml-algorithms/build.gradle
index 8454f1b519..bb900fe99e 100644
--- a/ml-algorithms/build.gradle
+++ b/ml-algorithms/build.gradle
@@ -63,12 +63,13 @@ dependencies {
         }
     }
 
-    implementation platform('software.amazon.awssdk:bom:2.21.15')
+    implementation platform('software.amazon.awssdk:bom:2.25.40')
     implementation 'software.amazon.awssdk:auth'
     implementation 'software.amazon.awssdk:apache-client'
     implementation 'com.amazonaws:aws-encryption-sdk-java:2.4.1'
     implementation 'com.jayway.jsonpath:json-path:2.9.0'
     implementation group: 'org.json', name: 'json', version: '20231013'
+    implementation group: 'software.amazon.awssdk', name: 'netty-nio-client', version: '2.25.40'
 }
 
 lombok {
diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/Predictable.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/Predictable.java
index 38c5889c78..5ee1e4d7a1 100644
--- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/Predictable.java
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/Predictable.java
@@ -7,9 +7,11 @@
 
 import java.util.Map;
 
+import org.opensearch.core.action.ActionListener;
 import org.opensearch.ml.common.MLModel;
 import org.opensearch.ml.common.input.MLInput;
 import org.opensearch.ml.common.output.MLOutput;
+import org.opensearch.ml.common.transport.MLTaskResponse;
 import org.opensearch.ml.engine.encryptor.Encryptor;
 
 /**
@@ -31,7 +33,13 @@ public interface Predictable {
      * @param mlInput input data
      * @return predicted results
      */
-    MLOutput predict(MLInput mlInput);
+    default MLOutput predict(MLInput mlInput) {
+        throw new IllegalStateException("Method is not implemented");
+    }
+
+    default void asyncPredict(MLInput mlInput, ActionListener<MLTaskResponse> actionListener) {
+        actionListener.onFailure(new IllegalStateException("Method is not implemented"));
+    }
 
     /**
      * Init model (load model into memory) with ML model content and params.
diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java
index ea53b2c08e..916e486654 100644
--- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java
@@ -5,25 +5,19 @@
 
 package org.opensearch.ml.engine.algorithms.remote;
 
-import static org.opensearch.ml.common.CommonValue.REMOTE_SERVICE_ERROR;
 import static org.opensearch.ml.common.connector.ConnectorProtocols.AWS_SIGV4;
-import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processOutput;
 import static software.amazon.awssdk.http.SdkHttpMethod.POST;
 
-import java.io.BufferedReader;
-import java.io.InputStreamReader;
-import java.net.URI;
-import java.nio.charset.StandardCharsets;
 import java.security.AccessController;
 import java.security.PrivilegedExceptionAction;
 import java.time.Duration;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.CompletableFuture;
 
-import org.opensearch.OpenSearchStatusException;
 import org.opensearch.client.Client;
 import org.opensearch.common.util.TokenBucket;
-import org.opensearch.core.rest.RestStatus;
+import org.opensearch.core.action.ActionListener;
 import org.opensearch.ml.common.connector.AwsConnector;
 import org.opensearch.ml.common.connector.Connector;
 import org.opensearch.ml.common.exception.MLException;
@@ -31,20 +25,16 @@
 import org.opensearch.ml.common.model.MLGuard;
 import org.opensearch.ml.common.output.model.ModelTensors;
 import org.opensearch.ml.engine.annotation.ConnectorExecutor;
+import org.opensearch.ml.engine.httpclient.MLHttpClientFactory;
 import org.opensearch.script.ScriptService;
 
 import lombok.Getter;
 import lombok.Setter;
 import lombok.extern.log4j.Log4j2;
-import software.amazon.awssdk.core.internal.http.loader.DefaultSdkHttpClientBuilder;
-import software.amazon.awssdk.core.sync.RequestBody;
-import software.amazon.awssdk.http.AbortableInputStream;
-import software.amazon.awssdk.http.HttpExecuteRequest;
-import software.amazon.awssdk.http.HttpExecuteResponse;
-import software.amazon.awssdk.http.SdkHttpClient;
-import software.amazon.awssdk.http.SdkHttpConfigurationOption;
+import software.amazon.awssdk.core.internal.http.async.SimpleHttpContentPublisher;
 import software.amazon.awssdk.http.SdkHttpFullRequest;
-import software.amazon.awssdk.utils.AttributeMap;
+import software.amazon.awssdk.http.async.AsyncExecuteRequest;
+import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
 
 @Log4j2
 @ConnectorExecutor(AWS_SIGV4)
@@ -52,7 +42,6 @@ public class AwsConnectorExecutor extends AbstractConnectorExecutor {
 
     @Getter
     private AwsConnector connector;
-    private SdkHttpClient httpClient;
     @Setter
     @Getter
     private ScriptService scriptService;
@@ -69,103 +58,52 @@ public class AwsConnectorExecutor extends AbstractConnectorExecutor {
     @Getter
     private MLGuard mlGuard;
 
-    public AwsConnectorExecutor(Connector connector, SdkHttpClient httpClient) {
-        this.connector = (AwsConnector) connector;
-        this.httpClient = httpClient;
-    }
+    private SdkAsyncHttpClient httpClient;
 
     public AwsConnectorExecutor(Connector connector) {
         super.initialize(connector);
         this.connector = (AwsConnector) connector;
-        Duration connectionTimeout = Duration.ofMillis(super.getConnectorClientConfig().getConnectionTimeout());
-        Duration readTimeout = Duration.ofMillis(super.getConnectorClientConfig().getReadTimeout());
-        try (
-            AttributeMap attributeMap = AttributeMap
-                .builder()
-                .put(SdkHttpConfigurationOption.CONNECTION_TIMEOUT, connectionTimeout)
-                .put(SdkHttpConfigurationOption.READ_TIMEOUT, readTimeout)
-                .put(SdkHttpConfigurationOption.MAX_CONNECTIONS, super.getConnectorClientConfig().getMaxConnections())
-                .build()
-        ) {
-            log
-                .info(
-                    "Initializing aws connector http client with attributes: connectionTimeout={}, readTimeout={}, maxConnections={}",
-                    connectionTimeout,
-                    readTimeout,
-                    super.getConnectorClientConfig().getMaxConnections()
-                );
-            this.httpClient = new DefaultSdkHttpClientBuilder().buildWithDefaults(attributeMap);
-        } catch (RuntimeException e) {
-            log.error("Error initializing AWS connector HTTP client.", e);
-            throw e;
-        } catch (Throwable e) {
-            log.error("Error initializing AWS connector HTTP client.", e);
-            throw new MLException(e);
-        }
+        Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout());
+        Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout());
+        Integer maxConnection = super.getConnectorClientConfig().getMaxConnections();
+        this.httpClient = MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection);
     }
 
     @SuppressWarnings("removal")
     @Override
-    public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, String payload, List<ModelTensors> tensorOutputs) {
+    public void invokeRemoteModel(
+        MLInput mlInput,
+        Map<String, String> parameters,
+        String payload,
+        Map<Integer, ModelTensors> tensorOutputs,
+        ExecutionContext countDownLatch,
+        ActionListener<List<ModelTensors>> actionListener
+    ) {
         try {
-            String endpoint = connector.getPredictEndpoint(parameters);
-            RequestBody requestBody = RequestBody.fromString(payload);
-
-            SdkHttpFullRequest.Builder builder = SdkHttpFullRequest
-                .builder()
-                .method(POST)
-                .uri(URI.create(endpoint))
-                .contentStreamProvider(requestBody.contentStreamProvider());
-            Map<String, String> headers = connector.getDecryptedHeaders();
-            if (headers != null) {
-                for (String key : headers.keySet()) {
-                    builder.putHeader(key, headers.get(key));
-                }
-            }
-            SdkHttpFullRequest request = builder.build();
-            HttpExecuteRequest executeRequest = HttpExecuteRequest
+            SdkHttpFullRequest request = ConnectorUtils.buildSdkRequest(connector, parameters, payload, POST);
+            AsyncExecuteRequest executeRequest = AsyncExecuteRequest
                 .builder()
                 .request(signRequest(request))
-                .contentStreamProvider(request.contentStreamProvider().orElse(null))
+                .requestContentPublisher(new SimpleHttpContentPublisher(request))
+                .responseHandler(
+                    new MLSdkAsyncHttpResponseHandler(
+                        countDownLatch,
+                        actionListener,
+                        parameters,
+                        tensorOutputs,
+                        connector,
+                        scriptService,
+                        mlGuard
+                    )
+                )
                 .build();
-
-            HttpExecuteResponse response = AccessController
-                .doPrivileged((PrivilegedExceptionAction<HttpExecuteResponse>) () -> httpClient.prepareRequest(executeRequest).call());
-            int statusCode = response.httpResponse().statusCode();
-
-            AbortableInputStream body = null;
-            if (response.responseBody().isPresent()) {
-                body = response.responseBody().get();
-            }
-
-            StringBuilder responseBuilder = new StringBuilder();
-            if (body != null) {
-                try (BufferedReader reader = new BufferedReader(new InputStreamReader(body, StandardCharsets.UTF_8))) {
-                    String line;
-                    while ((line = reader.readLine()) != null) {
-                        responseBuilder.append(line);
-                    }
-                }
-            } else {
-                throw new OpenSearchStatusException("No response from model", RestStatus.BAD_REQUEST);
-            }
-            String modelResponse = responseBuilder.toString();
-            if (getMlGuard() != null && !getMlGuard().validate(modelResponse, MLGuard.Type.OUTPUT)) {
-                throw new IllegalArgumentException("guardrails triggered for LLM output");
-            }
-            if (statusCode < 200 || statusCode >= 300) {
-                throw new OpenSearchStatusException(REMOTE_SERVICE_ERROR + modelResponse, RestStatus.fromCode(statusCode));
-            }
-
-            ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters);
-            tensors.setStatusCode(statusCode);
-            tensorOutputs.add(tensors);
+            AccessController.doPrivileged((PrivilegedExceptionAction<CompletableFuture<Void>>) () -> httpClient.execute(executeRequest));
         } catch (RuntimeException exception) {
             log.error("Failed to execute predict in aws connector: " + exception.getMessage(), exception);
-            throw exception;
+            actionListener.onFailure(exception);
         } catch (Throwable e) {
             log.error("Failed to execute predict in aws connector", e);
-            throw new MLException("Fail to execute predict in aws connector", e);
+            actionListener.onFailure(new MLException("Fail to execute predict in aws connector", e));
         }
     }
 
diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java
index 49e6ef7d69..a6181e1b2f 100644
--- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java
@@ -15,6 +15,8 @@
 import static org.opensearch.ml.engine.utils.ScriptUtils.executePostProcessFunction;
 
 import java.io.IOException;
+import java.net.URI;
+import java.nio.charset.Charset;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
@@ -34,6 +36,7 @@
 import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
 import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
 import org.opensearch.ml.common.input.MLInput;
+import org.opensearch.ml.common.model.MLGuard;
 import org.opensearch.ml.common.output.model.ModelTensor;
 import org.opensearch.ml.common.output.model.ModelTensors;
 import org.opensearch.script.ScriptService;
@@ -46,7 +49,9 @@
 import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
 import software.amazon.awssdk.auth.signer.Aws4Signer;
 import software.amazon.awssdk.auth.signer.params.Aws4SignerParams;
+import software.amazon.awssdk.core.sync.RequestBody;
 import software.amazon.awssdk.http.SdkHttpFullRequest;
+import software.amazon.awssdk.http.SdkHttpMethod;
 import software.amazon.awssdk.regions.Region;
 
 @Log4j2
@@ -179,11 +184,15 @@ public static ModelTensors processOutput(
         String modelResponse,
         Connector connector,
         ScriptService scriptService,
-        Map<String, String> parameters
+        Map<String, String> parameters,
+        MLGuard mlGuard
     ) throws IOException {
         if (modelResponse == null) {
             throw new IllegalArgumentException("model response is null");
         }
+        if (mlGuard != null && !mlGuard.validate(modelResponse, MLGuard.Type.OUTPUT)) {
+            throw new IllegalArgumentException("guardrails triggered for LLM output");
+        }
         List<ModelTensor> modelTensors = new ArrayList<>();
         Optional<ConnectorAction> predictAction = connector.findPredictAction();
         if (predictAction.isEmpty()) {
@@ -252,4 +261,42 @@ public static SdkHttpFullRequest signRequest(
 
         return signer.sign(request, params);
     }
+
+    public static SdkHttpFullRequest buildSdkRequest(
+        Connector connector,
+        Map<String, String> parameters,
+        String payload,
+        SdkHttpMethod method
+    ) {
+        String charset = parameters.getOrDefault("charset", "UTF-8");
+        RequestBody requestBody;
+        if (payload != null) {
+            requestBody = RequestBody.fromString(payload, Charset.forName(charset));
+        } else {
+            requestBody = RequestBody.empty();
+        }
+        if (SdkHttpMethod.POST == method && 0 == requestBody.optionalContentLength().get()) {
+            log.error("Content length is 0. Aborting request to remote model");
+            throw new IllegalArgumentException("Content length is 0. Aborting request to remote model");
+        }
+        String endpoint = connector.getPredictEndpoint(parameters);
+        SdkHttpFullRequest.Builder builder = SdkHttpFullRequest
+            .builder()
+            .method(method)
+            .uri(URI.create(endpoint))
+            .contentStreamProvider(requestBody.contentStreamProvider());
+        Map<String, String> headers = connector.getDecryptedHeaders();
+        if (headers != null) {
+            for (String key : headers.keySet()) {
+                builder.putHeader(key, headers.get(key));
+            }
+        }
+        if (builder.matchingHeaders("Content-Type").isEmpty()) {
+            builder.putHeader("Content-Type", "application/json");
+        }
+        if (builder.matchingHeaders("Content-Length").isEmpty()) {
+            builder.putHeader("Content-Length", requestBody.optionalContentLength().get().toString());
+        }
+        return builder.build();
+    }
 }
diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ExecutionContext.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ExecutionContext.java
new file mode 100644
index 0000000000..66c828bead
--- /dev/null
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ExecutionContext.java
@@ -0,0 +1,32 @@
+/*
+ *
+ *  * Copyright OpenSearch Contributors
+ *  * SPDX-License-Identifier: Apache-2.0
+ *
+ */
+
+package org.opensearch.ml.engine.algorithms.remote;
+
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.atomic.AtomicReference;
+
+import lombok.AllArgsConstructor;
+import lombok.Data;
+
+/**
+ * This class encapsulates several parameters that are used in a split-batch request case.
+ * A batch request is that in neural-search side multiple fields are send in one request to ml-commons,
+ * but the remote model doesn't accept list of string inputs so in ml-commons the request needs split.
+ * sequence is used to identify the index of the split request.
+ * countDownLatch is used to wait for all the split requests to finish.
+ * exceptionHolder is used to hold any exception thrown in a split-batch request.
+ */
+@Data
+@AllArgsConstructor
+public class ExecutionContext {
+    // Should never be null
+    private int sequence;
+    private CountDownLatch countDownLatch;
+    // This is to hold any exception thrown in a split-batch request
+    private AtomicReference<Exception> exceptionHolder;
+}
diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java
index 080e8d7553..b92a8d57d4 100644
--- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java
@@ -5,29 +5,22 @@
 
 package org.opensearch.ml.engine.algorithms.remote;
 
-import static org.opensearch.ml.common.CommonValue.REMOTE_SERVICE_ERROR;
 import static org.opensearch.ml.common.connector.ConnectorProtocols.HTTP;
-import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processOutput;
+import static software.amazon.awssdk.http.SdkHttpMethod.GET;
+import static software.amazon.awssdk.http.SdkHttpMethod.POST;
 
+import java.net.URL;
 import java.security.AccessController;
 import java.security.PrivilegedExceptionAction;
+import java.time.Duration;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
-import java.util.concurrent.atomic.AtomicReference;
+import java.util.concurrent.CompletableFuture;
 
-import org.apache.http.HttpEntity;
-import org.apache.http.client.methods.CloseableHttpResponse;
-import org.apache.http.client.methods.HttpGet;
-import org.apache.http.client.methods.HttpPost;
-import org.apache.http.client.methods.HttpUriRequest;
-import org.apache.http.entity.StringEntity;
-import org.apache.http.impl.client.CloseableHttpClient;
-import org.apache.http.util.EntityUtils;
-import org.opensearch.OpenSearchStatusException;
 import org.opensearch.client.Client;
 import org.opensearch.common.util.TokenBucket;
-import org.opensearch.core.rest.RestStatus;
+import org.opensearch.core.action.ActionListener;
 import org.opensearch.ml.common.connector.Connector;
 import org.opensearch.ml.common.connector.HttpConnector;
 import org.opensearch.ml.common.exception.MLException;
@@ -41,6 +34,10 @@
 import lombok.Getter;
 import lombok.Setter;
 import lombok.extern.log4j.Log4j2;
+import software.amazon.awssdk.core.internal.http.async.SimpleHttpContentPublisher;
+import software.amazon.awssdk.http.SdkHttpFullRequest;
+import software.amazon.awssdk.http.async.AsyncExecuteRequest;
+import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
 
 @Log4j2
 @ConnectorExecutor(HTTP)
@@ -65,98 +62,74 @@ public class HttpJsonConnectorExecutor extends AbstractConnectorExecutor {
     @Getter
     private MLGuard mlGuard;
 
-    private CloseableHttpClient httpClient;
+    private SdkAsyncHttpClient httpClient;
 
     public HttpJsonConnectorExecutor(Connector connector) {
         super.initialize(connector);
         this.connector = (HttpConnector) connector;
-        this.httpClient = MLHttpClientFactory
-            .getCloseableHttpClient(
-                super.getConnectorClientConfig().getConnectionTimeout(),
-                super.getConnectorClientConfig().getReadTimeout(),
-                super.getConnectorClientConfig().getMaxConnections()
-            );
-    }
-
-    public HttpJsonConnectorExecutor(Connector connector, CloseableHttpClient httpClient) {
-        this(connector);
-        this.httpClient = httpClient;
+        Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout());
+        Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout());
+        Integer maxConnection = super.getConnectorClientConfig().getMaxConnections();
+        this.httpClient = MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection);
     }
 
     @SuppressWarnings("removal")
     @Override
-    public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, String payload, List<ModelTensors> tensorOutputs) {
+    public void invokeRemoteModel(
+        MLInput mlInput,
+        Map<String, String> parameters,
+        String payload,
+        Map<Integer, ModelTensors> tensorOutputs,
+        ExecutionContext countDownLatch,
+        ActionListener<List<ModelTensors>> actionListener
+    ) {
         try {
-            AtomicReference<String> responseRef = new AtomicReference<>("");
-            AtomicReference<Integer> statusCodeRef = new AtomicReference<>();
-
-            HttpUriRequest request;
+            SdkHttpFullRequest request;
             switch (connector.getPredictHttpMethod().toUpperCase(Locale.ROOT)) {
                 case "POST":
-                    try {
-                        String predictEndpoint = connector.getPredictEndpoint(parameters);
-                        request = new HttpPost(predictEndpoint);
-                        String charset = parameters.containsKey("charset") ? parameters.get("charset") : "UTF-8";
-                        HttpEntity entity = new StringEntity(payload, charset);
-                        ((HttpPost) request).setEntity(entity);
-                    } catch (Exception e) {
-                        throw new MLException("Failed to create http request for remote model", e);
-                    }
+                    log.debug("original payload to remote model: " + payload);
+                    validateHttpClientParameters(parameters);
+                    request = ConnectorUtils.buildSdkRequest(connector, parameters, payload, POST);
                     break;
                 case "GET":
-                    try {
-                        request = new HttpGet(connector.getPredictEndpoint(parameters));
-                    } catch (Exception e) {
-                        throw new MLException("Failed to create http request for remote model", e);
-                    }
+                    validateHttpClientParameters(parameters);
+                    request = ConnectorUtils.buildSdkRequest(connector, parameters, null, GET);
                     break;
                 default:
                     throw new IllegalArgumentException("unsupported http method");
             }
-
-            Map<String, ?> headers = connector.getDecryptedHeaders();
-            boolean hasContentTypeHeader = false;
-            if (headers != null) {
-                for (String key : headers.keySet()) {
-                    request.addHeader(key, (String) headers.get(key));
-                    if (key.toLowerCase().equals("Content-Type")) {
-                        hasContentTypeHeader = true;
-                    }
-                }
-            }
-            if (!hasContentTypeHeader) {
-                request.addHeader("Content-Type", "application/json");
-            }
-
-            AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
-                try (CloseableHttpResponse response = httpClient.execute(request)) {
-                    HttpEntity responseEntity = response.getEntity();
-                    String responseBody = EntityUtils.toString(responseEntity);
-                    EntityUtils.consume(responseEntity);
-                    responseRef.set(responseBody);
-                    statusCodeRef.set(response.getStatusLine().getStatusCode());
-                }
-                return null;
-            });
-            String modelResponse = responseRef.get();
-            if (getMlGuard() != null && !getMlGuard().validate(modelResponse, MLGuard.Type.OUTPUT)) {
-                throw new IllegalArgumentException("guardrails triggered for LLM output");
-            }
-            Integer statusCode = statusCodeRef.get();
-            if (statusCode < 200 || statusCode >= 300) {
-                throw new OpenSearchStatusException(REMOTE_SERVICE_ERROR + modelResponse, RestStatus.fromCode(statusCode));
-            }
-
-            ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters);
-            tensors.setStatusCode(statusCode);
-            tensorOutputs.add(tensors);
+            AsyncExecuteRequest executeRequest = AsyncExecuteRequest
+                .builder()
+                .request(request)
+                .requestContentPublisher(new SimpleHttpContentPublisher(request))
+                .responseHandler(
+                    new MLSdkAsyncHttpResponseHandler(
+                        countDownLatch,
+                        actionListener,
+                        parameters,
+                        tensorOutputs,
+                        connector,
+                        scriptService,
+                        mlGuard
+                    )
+                )
+                .build();
+            AccessController.doPrivileged((PrivilegedExceptionAction<CompletableFuture<Void>>) () -> httpClient.execute(executeRequest));
         } catch (RuntimeException e) {
             log.error("Fail to execute http connector", e);
-            throw e;
+            actionListener.onFailure(e);
         } catch (Throwable e) {
             log.error("Fail to execute http connector", e);
-            throw new MLException("Fail to execute http connector", e);
+            actionListener.onFailure(new MLException("Fail to execute http connector", e));
         }
     }
 
+    private void validateHttpClientParameters(Map<String, String> parameters) throws Exception {
+        String endpoint = connector.getPredictEndpoint(parameters);
+        URL url = new URL(endpoint);
+        String protocol = url.getProtocol();
+        String host = url.getHost();
+        int port = url.getPort();
+        MLHttpClientFactory.validate(protocol, host, port);
+    }
 }
diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java
new file mode 100644
index 0000000000..28ed1bc200
--- /dev/null
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java
@@ -0,0 +1,192 @@
+/*
+ *
+ *  * Copyright OpenSearch Contributors
+ *  * SPDX-License-Identifier: Apache-2.0
+ *
+ */
+
+package org.opensearch.ml.engine.algorithms.remote;
+
+import static org.opensearch.ml.common.CommonValue.REMOTE_SERVICE_ERROR;
+import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processOutput;
+
+import java.io.ByteArrayOutputStream;
+import java.nio.ByteBuffer;
+import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.http.HttpStatus;
+import org.apache.logging.log4j.util.Strings;
+import org.opensearch.OpenSearchStatusException;
+import org.opensearch.core.action.ActionListener;
+import org.opensearch.core.rest.RestStatus;
+import org.opensearch.ml.common.connector.Connector;
+import org.opensearch.ml.common.exception.MLException;
+import org.opensearch.ml.common.model.MLGuard;
+import org.opensearch.ml.common.output.model.ModelTensors;
+import org.opensearch.script.ScriptService;
+import org.reactivestreams.Publisher;
+import org.reactivestreams.Subscriber;
+import org.reactivestreams.Subscription;
+
+import lombok.Getter;
+import lombok.extern.log4j.Log4j2;
+import software.amazon.awssdk.http.SdkHttpFullResponse;
+import software.amazon.awssdk.http.SdkHttpResponse;
+import software.amazon.awssdk.http.async.SdkAsyncHttpResponseHandler;
+
+@Log4j2
+public class MLSdkAsyncHttpResponseHandler implements SdkAsyncHttpResponseHandler {
+    @Getter
+    private Integer statusCode;
+    @Getter
+    private final StringBuilder responseBody = new StringBuilder();
+
+    private final ExecutionContext executionContext;
+
+    private final ActionListener<List<ModelTensors>> actionListener;
+
+    private final Map<String, String> parameters;
+
+    private final Map<Integer, ModelTensors> tensorOutputs;
+
+    private final Connector connector;
+
+    private final ScriptService scriptService;
+
+    private final MLGuard mlGuard;
+
+    public MLSdkAsyncHttpResponseHandler(
+        ExecutionContext executionContext,
+        ActionListener<List<ModelTensors>> actionListener,
+        Map<String, String> parameters,
+        Map<Integer, ModelTensors> tensorOutputs,
+        Connector connector,
+        ScriptService scriptService,
+        MLGuard mlGuard
+    ) {
+        this.executionContext = executionContext;
+        this.actionListener = actionListener;
+        this.parameters = parameters;
+        this.tensorOutputs = tensorOutputs;
+        this.connector = connector;
+        this.scriptService = scriptService;
+        this.mlGuard = mlGuard;
+    }
+
+    @Override
+    public void onHeaders(SdkHttpResponse response) {
+        SdkHttpFullResponse sdkResponse = (SdkHttpFullResponse) response;
+        log.debug("received response headers: " + sdkResponse.headers());
+        this.statusCode = sdkResponse.statusCode();
+    }
+
+    @Override
+    public void onStream(Publisher<ByteBuffer> stream) {
+        stream.subscribe(new MLResponseSubscriber());
+    }
+
+    @Override
+    public void onError(Throwable error) {
+        log.error(error.getMessage(), error);
+        RestStatus status = (statusCode == null) ? RestStatus.INTERNAL_SERVER_ERROR : RestStatus.fromCode(statusCode);
+        String errorMessage = "Error communicating with remote model: " + error.getMessage();
+        actionListener.onFailure(new OpenSearchStatusException(errorMessage, status));
+    }
+
+    private void processResponse(
+        Integer statusCode,
+        String body,
+        Map<String, String> parameters,
+        Map<Integer, ModelTensors> tensorOutputs
+    ) {
+        if (Strings.isBlank(body)) {
+            log.error("Remote model response body is empty!");
+            if (executionContext.getExceptionHolder().get() == null) {
+                executionContext
+                    .getExceptionHolder()
+                    .compareAndSet(null, new OpenSearchStatusException("No response from model", RestStatus.BAD_REQUEST));
+            }
+        } else {
+            if (statusCode < HttpStatus.SC_OK || statusCode > HttpStatus.SC_MULTIPLE_CHOICES) {
+                log.error("Remote server returned error code: {}", statusCode);
+                if (executionContext.getExceptionHolder().get() == null) {
+                    executionContext
+                        .getExceptionHolder()
+                        .compareAndSet(null, new OpenSearchStatusException(REMOTE_SERVICE_ERROR + body, RestStatus.fromCode(statusCode)));
+                }
+            } else {
+                try {
+                    ModelTensors tensors = processOutput(body, connector, scriptService, parameters, mlGuard);
+                    tensors.setStatusCode(statusCode);
+                    tensorOutputs.put(executionContext.getSequence(), tensors);
+                } catch (Exception e) {
+                    log.error("Failed to process response body: {}", body, e);
+                    if (executionContext.getExceptionHolder().get() == null) {
+                        executionContext
+                            .getExceptionHolder()
+                            .compareAndSet(null, new MLException("Fail to execute predict in aws connector", e));
+                    }
+                }
+            }
+        }
+    }
+
+    // Only all requests successful case will be processed here.
+    private void reOrderTensorResponses(Map<Integer, ModelTensors> tensorOutputs) {
+        ModelTensors[] modelTensors = new ModelTensors[tensorOutputs.size()];
+        log.debug("Reordered tensor outputs size is {}", tensorOutputs.size());
+        for (Map.Entry<Integer, ModelTensors> entry : tensorOutputs.entrySet()) {
+            modelTensors[entry.getKey()] = entry.getValue();
+        }
+        actionListener.onResponse(Arrays.asList(modelTensors));
+    }
+
+    protected class MLResponseSubscriber implements Subscriber<ByteBuffer> {
+        private Subscription subscription;
+
+        @Override
+        public void onSubscribe(Subscription s) {
+            this.subscription = s;
+            s.request(Long.MAX_VALUE);
+        }
+
+        @Override
+        public void onNext(ByteBuffer byteBuffer) {
+            responseBody.append(StandardCharsets.UTF_8.decode(byteBuffer));
+            subscription.request(Long.MAX_VALUE);
+        }
+
+        @Override
+        public void onError(Throwable t) {
+            log
+                .error(
+                    "Error on receiving response body from remote: {}",
+                    t instanceof NullPointerException ? "NullPointerException" : t.getMessage(),
+                    t
+                );
+            response(tensorOutputs);
+        }
+
+        @Override
+        public void onComplete() {
+            response(tensorOutputs);
+        }
+    }
+
+    private void response(Map<Integer, ModelTensors> tensors) {
+        processResponse(statusCode, responseBody.toString(), parameters, tensorOutputs);
+        executionContext.getCountDownLatch().countDown();
+        // when countdown's count equals to 0 means all responses are received.
+        if (executionContext.getCountDownLatch().getCount() == 0) {
+            if (executionContext.getExceptionHolder().get() != null) {
+                actionListener.onFailure(executionContext.getExceptionHolder().get());
+                return;
+            }
+            reOrderTensorResponses(tensors);
+        } else {
+            log.debug("Not all responses received, left response count is: " + executionContext.getCountDownLatch().getCount());
+        }
+    }
+}
diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java
index 02f4a777d9..c8af4935a7 100644
--- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java
@@ -8,17 +8,21 @@
 import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.escapeRemoteInferenceInputData;
 import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processInput;
 
-import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.atomic.AtomicReference;
 
 import org.opensearch.OpenSearchStatusException;
 import org.opensearch.client.Client;
 import org.opensearch.cluster.service.ClusterService;
+import org.opensearch.common.collect.Tuple;
 import org.opensearch.common.util.TokenBucket;
 import org.opensearch.commons.ConfigConstants;
 import org.opensearch.commons.authuser.User;
+import org.opensearch.core.action.ActionListener;
 import org.opensearch.core.rest.RestStatus;
 import org.opensearch.core.xcontent.NamedXContentRegistry;
 import org.opensearch.ml.common.FunctionName;
@@ -30,52 +34,74 @@
 import org.opensearch.ml.common.model.MLGuard;
 import org.opensearch.ml.common.output.model.ModelTensorOutput;
 import org.opensearch.ml.common.output.model.ModelTensors;
+import org.opensearch.ml.common.transport.MLTaskResponse;
 import org.opensearch.script.ScriptService;
 
 public interface RemoteConnectorExecutor {
 
-    default ModelTensorOutput executePredict(MLInput mlInput) {
-        List<ModelTensors> tensorOutputs = new ArrayList<>();
-
-        if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) {
-            TextDocsInputDataSet textDocsInputDataSet = (TextDocsInputDataSet) mlInput.getInputDataset();
-            int processedDocs = 0;
-            while (processedDocs < textDocsInputDataSet.getDocs().size()) {
-                List<String> textDocs = textDocsInputDataSet.getDocs().subList(processedDocs, textDocsInputDataSet.getDocs().size());
-                List<ModelTensors> tempTensorOutputs = new ArrayList<>();
+    default void executePredict(MLInput mlInput, ActionListener<MLTaskResponse> actionListener) {
+        ActionListener<List<ModelTensors>> tensorActionListener = ActionListener.wrap(r -> {
+            actionListener.onResponse(new MLTaskResponse(new ModelTensorOutput(r)));
+        }, actionListener::onFailure);
+        try {
+            Map<Integer, ModelTensors> modelTensors = new ConcurrentHashMap<>();
+            AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
+            if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) {
+                TextDocsInputDataSet textDocsInputDataSet = (TextDocsInputDataSet) mlInput.getInputDataset();
+                Tuple<Integer, Integer> calculatedChunkSize = calculateChunkSize(textDocsInputDataSet);
+                CountDownLatch countDownLatch = new CountDownLatch(calculatedChunkSize.v1());
+                int sequence = 0;
+                for (int processedDocs = 0; processedDocs < textDocsInputDataSet.getDocs().size(); processedDocs += calculatedChunkSize
+                    .v2()) {
+                    List<String> textDocs = textDocsInputDataSet.getDocs().subList(processedDocs, textDocsInputDataSet.getDocs().size());
+                    preparePayloadAndInvokeRemoteModel(
+                        MLInput
+                            .builder()
+                            .algorithm(FunctionName.TEXT_EMBEDDING)
+                            .inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build())
+                            .build(),
+                        modelTensors,
+                        new ExecutionContext(sequence++, countDownLatch, exceptionHolder),
+                        tensorActionListener
+                    );
+                }
+            } else {
                 preparePayloadAndInvokeRemoteModel(
-                    MLInput
-                        .builder()
-                        .algorithm(FunctionName.TEXT_EMBEDDING)
-                        .inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build())
-                        .build(),
-                    tempTensorOutputs
+                    mlInput,
+                    modelTensors,
+                    new ExecutionContext(0, new CountDownLatch(1), exceptionHolder),
+                    tensorActionListener
                 );
-                int tensorCount = 0;
-                if (tempTensorOutputs.size() > 0 && tempTensorOutputs.get(0).getMlModelTensors() != null) {
-                    tensorCount = tempTensorOutputs.get(0).getMlModelTensors().size();
-                }
-                // This is to support some model which takes N text docs and embedding size is less than N.
-                // We need to tell executor what's the step size for each model run.
-                Map<String, String> parameters = getConnector().getParameters();
-                if (parameters != null && parameters.containsKey("input_docs_processed_step_size")) {
-                    int stepSize = Integer.parseInt(parameters.get("input_docs_processed_step_size"));
-                    // We need to check the parameter on runtime as parameter can be passed into predict request
-                    if (stepSize <= 0) {
-                        throw new IllegalArgumentException(
-                            "Invalid parameter: input_docs_processed_step_size. It must be positive integer."
-                        );
-                    }
-                    processedDocs += stepSize;
-                } else {
-                    processedDocs += Math.max(tensorCount, 1);
+            }
+        } catch (Exception e) {
+            actionListener.onFailure(e);
+        }
+    }
+
+    /**
+     * Calculate the chunk size.
+     * @param textDocsInputDataSet
+     * @return Tuple of chunk size and step size.
+     */
+    private Tuple<Integer, Integer> calculateChunkSize(TextDocsInputDataSet textDocsInputDataSet) {
+        int textDocsLength = textDocsInputDataSet.getDocs().size();
+        Map<String, String> parameters = getConnector().getParameters();
+        if (parameters != null && parameters.containsKey("input_docs_processed_step_size")) {
+            int stepSize = Integer.parseInt(parameters.get("input_docs_processed_step_size"));
+            // We need to check the parameter on runtime as parameter can be passed into predict request
+            if (stepSize <= 0) {
+                throw new IllegalArgumentException("Invalid parameter: input_docs_processed_step_size. It must be positive integer.");
+            } else {
+                boolean isDivisible = textDocsLength % stepSize == 0;
+                if (isDivisible) {
+                    return Tuple.tuple(textDocsLength / stepSize, stepSize);
                 }
-                tensorOutputs.addAll(tempTensorOutputs);
+                return Tuple.tuple(textDocsLength / stepSize + 1, stepSize);
             }
         } else {
-            preparePayloadAndInvokeRemoteModel(mlInput, tensorOutputs);
+            // consider as batch.
+            return Tuple.tuple(1, textDocsLength);
         }
-        return new ModelTensorOutput(tensorOutputs);
     }
 
     default void setScriptService(ScriptService scriptService) {}
@@ -104,7 +130,12 @@ default void setUserRateLimiterMap(Map<String, TokenBucket> userRateLimiterMap)
 
     default void setMlGuard(MLGuard mlGuard) {}
 
-    default void preparePayloadAndInvokeRemoteModel(MLInput mlInput, List<ModelTensors> tensorOutputs) {
+    default void preparePayloadAndInvokeRemoteModel(
+        MLInput mlInput,
+        Map<Integer, ModelTensors> tensorOutputs,
+        ExecutionContext countDownLatch,
+        ActionListener<List<ModelTensors>> actionListener
+    ) {
         Connector connector = getConnector();
 
         Map<String, String> parameters = new HashMap<>();
@@ -145,10 +176,16 @@ && getUserRateLimiterMap().get(user.getName()) != null
             if (getMlGuard() != null && !getMlGuard().validate(payload, MLGuard.Type.INPUT)) {
                 throw new IllegalArgumentException("guardrails triggered for user input");
             }
-            invokeRemoteModel(mlInput, parameters, payload, tensorOutputs);
+            invokeRemoteModel(mlInput, parameters, payload, tensorOutputs, countDownLatch, actionListener);
         }
     }
 
-    void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, String payload, List<ModelTensors> tensorOutputs);
-
+    void invokeRemoteModel(
+        MLInput mlInput,
+        Map<String, String> parameters,
+        String payload,
+        Map<Integer, ModelTensors> tensorOutputs,
+        ExecutionContext countDownLatch,
+        ActionListener<List<ModelTensors>> actionListener
+    );
 }
diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java
index 8774bcc40c..5828395641 100644
--- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java
@@ -10,6 +10,7 @@
 import org.opensearch.client.Client;
 import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.common.util.TokenBucket;
+import org.opensearch.core.action.ActionListener;
 import org.opensearch.core.xcontent.NamedXContentRegistry;
 import org.opensearch.ml.common.FunctionName;
 import org.opensearch.ml.common.MLModel;
@@ -18,6 +19,7 @@
 import org.opensearch.ml.common.input.MLInput;
 import org.opensearch.ml.common.model.MLGuard;
 import org.opensearch.ml.common.output.MLOutput;
+import org.opensearch.ml.common.transport.MLTaskResponse;
 import org.opensearch.ml.engine.MLEngineClassLoader;
 import org.opensearch.ml.engine.Predictable;
 import org.opensearch.ml.engine.annotation.Function;
@@ -55,18 +57,22 @@ public MLOutput predict(MLInput mlInput, MLModel model) {
     }
 
     @Override
-    public MLOutput predict(MLInput mlInput) {
+    public void asyncPredict(MLInput mlInput, ActionListener<MLTaskResponse> actionListener) {
         if (!isModelReady()) {
-            throw new IllegalArgumentException("Model not ready yet. Please run this first: POST /_plugins/_ml/models/<model_id>/_deploy");
+            actionListener
+                .onFailure(
+                    new IllegalArgumentException("Model not ready yet. Please run this first: POST /_plugins/_ml/models/<model_id>/_deploy")
+                );
+            return;
         }
         try {
-            return connectorExecutor.executePredict(mlInput);
+            connectorExecutor.executePredict(mlInput, actionListener);
         } catch (RuntimeException e) {
             log.error("Failed to call remote model.", e);
-            throw e;
+            actionListener.onFailure(e);
         } catch (Throwable e) {
             log.error("Failed to call remote model.", e);
-            throw new MLException(e);
+            actionListener.onFailure(new MLException(e));
         }
     }
 
diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java
index c981ebc184..339523b313 100644
--- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java
@@ -8,84 +8,69 @@
 import java.net.Inet4Address;
 import java.net.InetAddress;
 import java.net.UnknownHostException;
+import java.security.AccessController;
+import java.security.PrivilegedActionException;
+import java.security.PrivilegedExceptionAction;
+import java.time.Duration;
 import java.util.Arrays;
-
-import org.apache.commons.lang3.math.NumberUtils;
-import org.apache.http.HttpHost;
-import org.apache.http.HttpRequest;
-import org.apache.http.HttpResponse;
-import org.apache.http.client.config.RequestConfig;
-import org.apache.http.conn.UnsupportedSchemeException;
-import org.apache.http.impl.client.CloseableHttpClient;
-import org.apache.http.impl.client.HttpClientBuilder;
-import org.apache.http.impl.client.LaxRedirectStrategy;
-import org.apache.http.impl.conn.DefaultSchemePortResolver;
-import org.apache.http.protocol.HttpContext;
-
-import com.google.common.annotations.VisibleForTesting;
+import java.util.Locale;
 
 import lombok.extern.log4j.Log4j2;
+import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
+import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
 
 @Log4j2
 public class MLHttpClientFactory {
 
-    public static CloseableHttpClient getCloseableHttpClient(Integer connectionTimeout, Integer readTimeout, Integer maxConnections) {
-        return createHttpClient(connectionTimeout, readTimeout, maxConnections);
-    }
-
-    private static CloseableHttpClient createHttpClient(Integer connectionTimeout, Integer readTimeout, Integer maxConnections) {
-        HttpClientBuilder builder = HttpClientBuilder.create();
-
-        // Only allow HTTP and HTTPS schemes
-        builder.setSchemePortResolver(new DefaultSchemePortResolver() {
-            @Override
-            public int resolve(HttpHost host) throws UnsupportedSchemeException {
-                validateSchemaAndPort(host);
-                return super.resolve(host);
-            }
-        });
-
-        builder.setDnsResolver(MLHttpClientFactory::validateIp);
-
-        builder.setRedirectStrategy(new LaxRedirectStrategy() {
-            @Override
-            public boolean isRedirected(HttpRequest request, HttpResponse response, HttpContext context) {
-                // Do not follow redirects
-                return false;
-            }
-        });
-        builder.setMaxConnTotal(maxConnections);
-        builder.setMaxConnPerRoute(maxConnections);
-        RequestConfig requestConfig = RequestConfig.custom().setConnectTimeout(connectionTimeout).setSocketTimeout(readTimeout).build();
-        builder.setDefaultRequestConfig(requestConfig);
-        return builder.build();
+    /**
+     * Create a Netty NIO async http client inside a privileged block (plugin
+     * code runs under the OpenSearch security manager).
+     * @param connectionTimeout timeout for establishing a connection
+     * @param readTimeout timeout for reading the response
+     * @param maxConnections maximum number of concurrent connections
+     * @return the configured async http client
+     */
+    public static SdkAsyncHttpClient getAsyncHttpClient(Duration connectionTimeout, Duration readTimeout, int maxConnections) {
+        try {
+            return AccessController
+                .doPrivileged(
+                    (PrivilegedExceptionAction<SdkAsyncHttpClient>) () -> NettyNioAsyncHttpClient
+                        .builder()
+                        .connectionTimeout(connectionTimeout)
+                        .readTimeout(readTimeout)
+                        .maxConcurrency(maxConnections)
+                        .build()
+                );
+        } catch (PrivilegedActionException e) {
+            // Propagate instead of returning null: a null client would only
+            // surface later as an opaque NullPointerException at request time.
+            throw new RuntimeException("Failed to create the async http client", e);
+        }
+    }
 
-    @VisibleForTesting
-    protected static void validateSchemaAndPort(HttpHost host) {
-        String scheme = host.getSchemeName();
-        if ("http".equalsIgnoreCase(scheme) || "https".equalsIgnoreCase(scheme)) {
-            String[] hostNamePort = host.getHostName().split(":");
-            if (hostNamePort.length > 1 && NumberUtils.isDigits(hostNamePort[1])) {
-                int port = Integer.parseInt(hostNamePort[1]);
-                if (port < 0 || port > 65536) {
-                    log.error("Remote inference port out of range: " + port);
-                    throw new IllegalArgumentException("Port out of range: " + port);
-                }
+    /**
+     * Validate the input parameters, such as protocol, host and port.
+     * @param protocol The protocol supported in remote inference, currently only http and https are supported.
+     * @param host The host name of the remote inference server, host must be a valid ip address or domain name and must not be localhost.
+     * @param port The port number of the remote inference server, port number must be in range [0, 65535].
+     * @throws UnknownHostException if the host name cannot be resolved
+     */
+    public static void validate(String protocol, String host, int port) throws UnknownHostException {
+        if (protocol != null && !"http".equalsIgnoreCase(protocol) && !"https".equalsIgnoreCase(protocol)) {
+            log.error("Remote inference protocol is not http or https: " + protocol);
+            throw new IllegalArgumentException("Protocol is not http or https: " + protocol);
+        }
+        // When port is not specified, the default port is -1, and we need to set it to 80 or 443 based on protocol.
+        if (port == -1) {
+            if (protocol == null || "http".equals(protocol.toLowerCase(Locale.getDefault()))) {
+                port = 80;
+            } else {
+                port = 443;
             }
-        } else {
-            log.error("Remote inference scheme not supported: " + scheme);
-            throw new IllegalArgumentException("Unsupported scheme: " + scheme);
         }
+        // Valid TCP port numbers are 0-65535 (65536 was an off-by-one).
+        if (port < 0 || port > 65535) {
+            log.error("Remote inference port out of range: " + port);
+            throw new IllegalArgumentException("Port out of range: " + port);
+        }
+        validateIp(host);
     }
 
-    protected static InetAddress[] validateIp(String hostName) throws UnknownHostException {
+    private static void validateIp(String hostName) throws UnknownHostException {
         InetAddress[] addresses = InetAddress.getAllByName(hostName);
         if (hasPrivateIpAddress(addresses)) {
             log.error("Remote inference host name has private ip address: " + hostName);
-            throw new IllegalArgumentException(hostName);
+            throw new IllegalArgumentException("Remote inference host name has private ip address: " + hostName);
         }
-        return addresses;
     }
 
     private static boolean hasPrivateIpAddress(InetAddress[] ipAddress) {
diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java
index 69df5b03ae..5e1a9dfacb 100644
--- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java
+++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java
@@ -5,58 +5,48 @@
 
 package org.opensearch.ml.engine.algorithms.remote;
 
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.Mockito.mock;
+import static org.junit.Assert.assertEquals;
 import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.when;
 import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD;
 import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD;
 import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD;
 import static org.opensearch.ml.common.connector.HttpConnector.SERVICE_NAME_FIELD;
 
-import java.io.ByteArrayInputStream;
-import java.io.IOException;
-import java.io.InputStream;
+import java.lang.reflect.Field;
 import java.util.Arrays;
 import java.util.Map;
-import java.util.Optional;
 
-import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
+import org.mockito.ArgumentCaptor;
 import org.mockito.Mock;
+import org.mockito.Mockito;
 import org.mockito.MockitoAnnotations;
-import org.opensearch.OpenSearchStatusException;
 import org.opensearch.client.Client;
 import org.opensearch.common.settings.Settings;
 import org.opensearch.common.util.concurrent.ThreadContext;
+import org.opensearch.core.action.ActionListener;
 import org.opensearch.ml.common.FunctionName;
 import org.opensearch.ml.common.connector.AwsConnector;
 import org.opensearch.ml.common.connector.Connector;
 import org.opensearch.ml.common.connector.ConnectorAction;
-import org.opensearch.ml.common.connector.ConnectorClientConfig;
 import org.opensearch.ml.common.connector.MLPreProcessFunction;
 import org.opensearch.ml.common.dataset.MLInputDataset;
 import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
 import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
 import org.opensearch.ml.common.input.MLInput;
-import org.opensearch.ml.common.output.model.ModelTensorOutput;
+import org.opensearch.ml.common.transport.MLTaskResponse;
 import org.opensearch.ml.engine.encryptor.Encryptor;
 import org.opensearch.ml.engine.encryptor.EncryptorImpl;
-import org.opensearch.script.ScriptService;
 import org.opensearch.threadpool.ThreadPool;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 
-import software.amazon.awssdk.http.AbortableInputStream;
-import software.amazon.awssdk.http.ExecutableHttpRequest;
-import software.amazon.awssdk.http.HttpExecuteResponse;
-import software.amazon.awssdk.http.SdkHttpClient;
-import software.amazon.awssdk.http.SdkHttpResponse;
-
 public class AwsConnectorExecutorTest {
 
     @Rule
@@ -73,16 +63,7 @@ public class AwsConnectorExecutorTest {
     ThreadContext threadContext;
 
     @Mock
-    ScriptService scriptService;
-
-    @Mock
-    SdkHttpClient httpClient;
-
-    @Mock
-    ExecutableHttpRequest httpRequest;
-
-    @Mock
-    HttpExecuteResponse response;
+    ActionListener<MLTaskResponse> actionListener;
 
     Encryptor encryptor;
 
@@ -100,7 +81,7 @@ public void executePredict_RemoteInferenceInput_MissingCredential() {
             .builder()
             .actionType(ConnectorAction.ActionType.PREDICT)
             .method("POST")
-            .url("http://test.com/mock")
+            .url("http://openai.com/mock")
             .requestBody("{\"input\": \"${parameters.input}\"}")
             .build();
         AwsConnector
@@ -113,21 +94,12 @@ public void executePredict_RemoteInferenceInput_MissingCredential() {
     }
 
     @Test
-    public void executePredict_RemoteInferenceInput_NullResponse() throws IOException {
-        exceptionRule.expect(OpenSearchStatusException.class);
-        exceptionRule.expectMessage("No response from model");
-        when(response.responseBody()).thenReturn(Optional.empty());
-        when(httpRequest.call()).thenReturn(response);
-        SdkHttpResponse httpResponse = mock(SdkHttpResponse.class);
-        when(httpResponse.statusCode()).thenReturn(200);
-        when(response.httpResponse()).thenReturn(httpResponse);
-        when(httpClient.prepareRequest(any())).thenReturn(httpRequest);
-
+    public void executePredict_RemoteInferenceInput_EmptyIpAddress() {
         ConnectorAction predictAction = ConnectorAction
             .builder()
             .actionType(ConnectorAction.ActionType.PREDICT)
             .method("POST")
-            .url("http://test.com/mock")
+            .url("http:///mock")
             .requestBody("{\"input\": \"${parameters.input}\"}")
             .build();
         Map<String, String> credential = ImmutableMap
@@ -143,7 +115,7 @@ public void executePredict_RemoteInferenceInput_NullResponse() throws IOExceptio
             .actions(Arrays.asList(predictAction))
             .build();
         connector.decrypt((c) -> encryptor.decrypt(c));
-        AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient));
+        AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector));
         Settings settings = Settings.builder().build();
         threadContext = new ThreadContext(settings);
         when(executor.getClient()).thenReturn(client);
@@ -151,29 +123,22 @@ public void executePredict_RemoteInferenceInput_NullResponse() throws IOExceptio
         when(threadPool.getThreadContext()).thenReturn(threadContext);
 
         MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build();
-        executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
+        executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(), actionListener);
+        ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
+        Mockito.verify(actionListener, times(1)).onFailure(exceptionCaptor.capture());
+        assert exceptionCaptor.getValue() instanceof NullPointerException;
+        assertEquals("host must not be null.", exceptionCaptor.getValue().getMessage());
     }
 
     @Test
-    public void executePredict_RemoteInferenceInput_InvalidToken() throws IOException {
-        exceptionRule.expect(OpenSearchStatusException.class);
-        exceptionRule.expectMessage("{\"message\":\"The security token included in the request is invalid\"}");
-        String jsonString = "{\"message\":\"The security token included in the request is invalid\"}";
-        InputStream inputStream = new ByteArrayInputStream(jsonString.getBytes());
-        AbortableInputStream abortableInputStream = AbortableInputStream.create(inputStream);
-        when(response.responseBody()).thenReturn(Optional.of(abortableInputStream));
-        when(httpRequest.call()).thenReturn(response);
-        SdkHttpResponse httpResponse = mock(SdkHttpResponse.class);
-        when(httpResponse.statusCode()).thenReturn(403);
-        when(response.httpResponse()).thenReturn(httpResponse);
-        when(httpClient.prepareRequest(any())).thenReturn(httpRequest);
-
+    public void executePredict_TextDocsInferenceInput() {
         ConnectorAction predictAction = ConnectorAction
             .builder()
             .actionType(ConnectorAction.ActionType.PREDICT)
             .method("POST")
-            .url("http://test.com/mock")
-            .requestBody("{\"input\": \"${parameters.input}\"}")
+            .url("http://openai.com/mock")
+            .requestBody("{\"input\": ${parameters.input}}")
+            .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT)
             .build();
         Map<String, String> credential = ImmutableMap
             .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
@@ -188,39 +153,32 @@ public void executePredict_RemoteInferenceInput_InvalidToken() throws IOExceptio
             .actions(Arrays.asList(predictAction))
             .build();
         connector.decrypt((c) -> encryptor.decrypt(c));
-        AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient));
+        AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector));
         Settings settings = Settings.builder().build();
         threadContext = new ThreadContext(settings);
         when(executor.getClient()).thenReturn(client);
         when(client.threadPool()).thenReturn(threadPool);
         when(threadPool.getThreadContext()).thenReturn(threadContext);
 
-        MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build();
-        executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
+        MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input")).build();
+        executor
+            .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener);
     }
 
     @Test
-    public void executePredict_RemoteInferenceInput() throws IOException {
-        String jsonString = "{\"key\":\"value\"}";
-        InputStream inputStream = new ByteArrayInputStream(jsonString.getBytes());
-        AbortableInputStream abortableInputStream = AbortableInputStream.create(inputStream);
-        when(response.responseBody()).thenReturn(Optional.of(abortableInputStream));
-        SdkHttpResponse httpResponse = mock(SdkHttpResponse.class);
-        when(httpResponse.statusCode()).thenReturn(200);
-        when(response.httpResponse()).thenReturn(httpResponse);
-        when(httpRequest.call()).thenReturn(response);
-        when(httpClient.prepareRequest(any())).thenReturn(httpRequest);
-
+    public void executePredict_TextDocsInferenceInput_withStepSize() {
         ConnectorAction predictAction = ConnectorAction
             .builder()
             .actionType(ConnectorAction.ActionType.PREDICT)
             .method("POST")
-            .url("http://test.com/mock")
-            .requestBody("{\"input\": \"${parameters.input}\"}")
+            .url("http://openai.com/mock")
+            .requestBody("{\"input\": ${parameters.input}}")
+            .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT)
             .build();
         Map<String, String> credential = ImmutableMap
             .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
-        Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker");
+        Map<String, String> parameters = ImmutableMap
+            .of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker", "input_docs_processed_step_size", "2");
         Connector connector = AwsConnector
             .awsConnectorBuilder()
             .name("test connector")
@@ -231,42 +189,30 @@ public void executePredict_RemoteInferenceInput() throws IOException {
             .actions(Arrays.asList(predictAction))
             .build();
         connector.decrypt((c) -> encryptor.decrypt(c));
-        AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient));
+        AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector));
         Settings settings = Settings.builder().build();
         threadContext = new ThreadContext(settings);
         when(executor.getClient()).thenReturn(client);
         when(client.threadPool()).thenReturn(threadPool);
         when(threadPool.getThreadContext()).thenReturn(threadContext);
 
-        MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build();
-        ModelTensorOutput modelTensorOutput = executor
-            .executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
-        Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
-        Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size());
-        Assert.assertEquals("response", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName());
-        Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size());
-        Assert.assertEquals("value", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("key"));
+        MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build();
+        executor
+            .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener);
+
+        MLInputDataset inputDataSet1 = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2")).build();
+        executor
+            .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet1).build(), actionListener);
     }
 
     @Test
-    public void executePredict_TextDocsInferenceInput() throws IOException {
-        String jsonString = "{\"key\":\"value\"}";
-        InputStream inputStream = new ByteArrayInputStream(jsonString.getBytes());
-        AbortableInputStream abortableInputStream = AbortableInputStream.create(inputStream);
-        when(response.responseBody()).thenReturn(Optional.of(abortableInputStream));
-        when(httpRequest.call()).thenReturn(response);
-        SdkHttpResponse httpResponse = mock(SdkHttpResponse.class);
-        when(httpResponse.statusCode()).thenReturn(200);
-        when(response.httpResponse()).thenReturn(httpResponse);
-        when(httpClient.prepareRequest(any())).thenReturn(httpRequest);
-
+    public void executePredict_RemoteInferenceInput_nullHttpClient_throwNPException() throws NoSuchFieldException, IllegalAccessException {
         ConnectorAction predictAction = ConnectorAction
             .builder()
             .actionType(ConnectorAction.ActionType.PREDICT)
             .method("POST")
-            .url("http://test.com/mock")
-            .requestBody("{\"input\": ${parameters.input}}")
-            .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT)
+            .url("http://openai.com/mock")
+            .requestBody("{\"input\": \"${parameters.input}\"}")
             .build();
         Map<String, String> credential = ImmutableMap
             .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
@@ -281,36 +227,37 @@ public void executePredict_TextDocsInferenceInput() throws IOException {
             .actions(Arrays.asList(predictAction))
             .build();
         connector.decrypt((c) -> encryptor.decrypt(c));
-        AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient));
+        AwsConnectorExecutor executor0 = new AwsConnectorExecutor(connector);
+        Field httpClientField = AwsConnectorExecutor.class.getDeclaredField("httpClient");
+        httpClientField.setAccessible(true);
+        httpClientField.set(executor0, null);
+        AwsConnectorExecutor executor = spy(executor0);
         Settings settings = Settings.builder().build();
         threadContext = new ThreadContext(settings);
         when(executor.getClient()).thenReturn(client);
         when(client.threadPool()).thenReturn(threadPool);
         when(threadPool.getThreadContext()).thenReturn(threadContext);
 
-        MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input")).build();
-        ModelTensorOutput modelTensorOutput = executor
-            .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build());
-        Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
-        Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size());
-        Assert.assertEquals("response", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName());
-        Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size());
-        Assert.assertEquals("value", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("key"));
+        MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build();
+        executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(), actionListener);
+        ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
+        Mockito.verify(actionListener, times(1)).onFailure(exceptionCaptor.capture());
+        assert exceptionCaptor.getValue() instanceof NullPointerException;
     }
 
     @Test
-    public void test_initialize() {
+    public void executePredict_RemoteInferenceInput_negativeStepSize_throwIllegalArgumentException() {
         ConnectorAction predictAction = ConnectorAction
             .builder()
             .actionType(ConnectorAction.ActionType.PREDICT)
             .method("POST")
-            .url("http://test.com/mock")
+            .url("http://openai.com/mock")
             .requestBody("{\"input\": \"${parameters.input}\"}")
             .build();
         Map<String, String> credential = ImmutableMap
             .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
-        Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker");
-        ConnectorClientConfig httpClientConfig = new ConnectorClientConfig(20, 30000, 30000);
+        Map<String, String> parameters = ImmutableMap
+            .of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker", "input_docs_processed_step_size", "-1");
         Connector connector = AwsConnector
             .awsConnectorBuilder()
             .name("test connector")
@@ -319,13 +266,20 @@ public void test_initialize() {
             .parameters(parameters)
             .credential(credential)
             .actions(Arrays.asList(predictAction))
-            .connectorClientConfig(httpClientConfig)
             .build();
         connector.decrypt((c) -> encryptor.decrypt(c));
-        AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient));
-        Assert.assertEquals(20, executor.getConnector().getConnectorClientConfig().getMaxConnections().intValue());
-        Assert.assertEquals(30000, executor.getConnector().getConnectorClientConfig().getConnectionTimeout().intValue());
-        Assert.assertEquals(30000, executor.getConnector().getConnectorClientConfig().getReadTimeout().intValue());
+        AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector));
+        Settings settings = Settings.builder().build();
+        threadContext = new ThreadContext(settings);
+        when(executor.getClient()).thenReturn(client);
+        when(client.threadPool()).thenReturn(threadPool);
+        when(threadPool.getThreadContext()).thenReturn(threadContext);
 
+        MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build();
+        executor
+            .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener);
+        ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
+        Mockito.verify(actionListener, times(1)).onFailure(exceptionCaptor.capture());
+        assert exceptionCaptor.getValue() instanceof IllegalArgumentException;
     }
 }
diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java
index 178d9aa722..cb7f8a4fe8 100644
--- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java
+++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java
@@ -168,7 +168,7 @@ public void processInput_TextDocsInputDataSet_PreprocessFunction_MultiTextDoc()
     public void processOutput_NullResponse() throws IOException {
         exceptionRule.expect(IllegalArgumentException.class);
         exceptionRule.expectMessage("model response is null");
-        ConnectorUtils.processOutput(null, null, null, null);
+        ConnectorUtils.processOutput(null, null, null, null, null);
     }
 
     @Test
@@ -192,7 +192,7 @@ public void processOutput_NoPostprocessFunction_jsonResponse() throws IOExceptio
             .build();
         String modelResponse =
             "{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"index\":0,\"embedding\":[-0.014555434,-0.0002135904,0.0035105038]}],\"model\":\"text-embedding-ada-002-v2\",\"usage\":{\"prompt_tokens\":5,\"total_tokens\":5}}";
-        ModelTensors tensors = ConnectorUtils.processOutput(modelResponse, connector, scriptService, ImmutableMap.of());
+        ModelTensors tensors = ConnectorUtils.processOutput(modelResponse, connector, scriptService, ImmutableMap.of(), null);
         Assert.assertEquals(1, tensors.getMlModelTensors().size());
         Assert.assertEquals("response", tensors.getMlModelTensors().get(0).getName());
         Assert.assertEquals(4, tensors.getMlModelTensors().get(0).getDataAsMap().size());
@@ -224,7 +224,7 @@ public void processOutput_PostprocessFunction() throws IOException {
             .build();
         String modelResponse =
             "{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"index\":0,\"embedding\":[-0.014555434,-0.0002135904,0.0035105038]}],\"model\":\"text-embedding-ada-002-v2\",\"usage\":{\"prompt_tokens\":5,\"total_tokens\":5}}";
-        ModelTensors tensors = ConnectorUtils.processOutput(modelResponse, connector, scriptService, ImmutableMap.of());
+        ModelTensors tensors = ConnectorUtils.processOutput(modelResponse, connector, scriptService, ImmutableMap.of(), null);
         Assert.assertEquals(1, tensors.getMlModelTensors().size());
         Assert.assertEquals("sentence_embedding", tensors.getMlModelTensors().get(0).getName());
         Assert.assertNull(tensors.getMlModelTensors().get(0).getDataAsMap());
diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java
index 122402fea3..f4cd6715cf 100644
--- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java
+++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java
@@ -5,46 +5,34 @@
 
 package org.opensearch.ml.engine.algorithms.remote;
 
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.Mockito.spy;
-import static org.mockito.Mockito.when;
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
 
-import java.io.IOException;
+import java.lang.reflect.Field;
 import java.util.Arrays;
-import java.util.Map;
+import java.util.HashMap;
+import java.util.List;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.atomic.AtomicReference;
 
-import org.apache.http.HttpEntity;
-import org.apache.http.ProtocolVersion;
-import org.apache.http.StatusLine;
-import org.apache.http.client.methods.CloseableHttpResponse;
-import org.apache.http.entity.StringEntity;
-import org.apache.http.impl.client.CloseableHttpClient;
-import org.apache.http.message.BasicStatusLine;
-import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
+import org.mockito.ArgumentCaptor;
 import org.mockito.Mock;
+import org.mockito.Mockito;
 import org.mockito.MockitoAnnotations;
-import org.opensearch.OpenSearchStatusException;
-import org.opensearch.client.Client;
-import org.opensearch.common.settings.Settings;
-import org.opensearch.common.util.concurrent.ThreadContext;
-import org.opensearch.ingest.TestTemplateService;
+import org.opensearch.core.action.ActionListener;
 import org.opensearch.ml.common.FunctionName;
 import org.opensearch.ml.common.connector.Connector;
 import org.opensearch.ml.common.connector.ConnectorAction;
 import org.opensearch.ml.common.connector.HttpConnector;
-import org.opensearch.ml.common.connector.MLPostProcessFunction;
-import org.opensearch.ml.common.connector.MLPreProcessFunction;
 import org.opensearch.ml.common.dataset.MLInputDataset;
-import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
 import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
 import org.opensearch.ml.common.input.MLInput;
-import org.opensearch.ml.common.output.model.ModelTensorOutput;
-import org.opensearch.script.ScriptService;
-import org.opensearch.threadpool.ThreadPool;
+import org.opensearch.ml.common.output.model.ModelTensors;
 
 import com.google.common.collect.ImmutableMap;
 
@@ -53,23 +41,7 @@ public class HttpJsonConnectorExecutorTest {
     public ExpectedException exceptionRule = ExpectedException.none();
 
     @Mock
-    ThreadPool threadPool;
-
-    @Mock
-    ScriptService scriptService;
-
-    @Mock
-    CloseableHttpClient httpClient;
-
-    @Mock
-    Client client;
-
-    @Mock
-    CloseableHttpResponse response;
-
-    Settings settings;
-
-    ThreadContext threadContext;
+    private ActionListener<List<ModelTensors>> actionListener;
 
     @Before
     public void setUp() {
@@ -78,13 +50,11 @@ public void setUp() {
 
     @Test
     public void invokeRemoteModel_WrongHttpMethod() {
-        exceptionRule.expect(IllegalArgumentException.class);
-        exceptionRule.expectMessage("unsupported http method");
         ConnectorAction predictAction = ConnectorAction
             .builder()
             .actionType(ConnectorAction.ActionType.PREDICT)
             .method("wrong_method")
-            .url("http://test.com/mock")
+            .url("http://openai.com/mock")
             .requestBody("{\"input\": \"${parameters.input}\"}")
             .build();
         Connector connector = HttpConnector
@@ -94,17 +64,20 @@ public void invokeRemoteModel_WrongHttpMethod() {
             .protocol("http")
             .actions(Arrays.asList(predictAction))
             .build();
-        HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector, httpClient);
-        executor.invokeRemoteModel(null, null, null, null);
+        HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector);
+        executor.invokeRemoteModel(null, null, null, null, null, actionListener);
+        ArgumentCaptor<IllegalArgumentException> captor = ArgumentCaptor.forClass(IllegalArgumentException.class);
+        Mockito.verify(actionListener, times(1)).onFailure(captor.capture());
+        assertEquals("unsupported http method", captor.getValue().getMessage());
     }
 
     @Test
-    public void executePredict_RemoteInferenceInput() throws IOException {
+    public void invokeRemoteModel_invalidIpAddress() {
         ConnectorAction predictAction = ConnectorAction
             .builder()
             .actionType(ConnectorAction.ActionType.PREDICT)
             .method("POST")
-            .url("http://test.com/mock")
+            .url("http://127.0.0.1/mock")
             .requestBody("{\"input\": \"${parameters.input}\"}")
             .build();
         Connector connector = HttpConnector
@@ -114,44 +87,31 @@ public void executePredict_RemoteInferenceInput() throws IOException {
             .protocol("http")
             .actions(Arrays.asList(predictAction))
             .build();
-        HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector, httpClient));
-        Settings settings = Settings.builder().build();
-        threadContext = new ThreadContext(settings);
-        when(executor.getClient()).thenReturn(client);
-        when(client.threadPool()).thenReturn(threadPool);
-        when(threadPool.getThreadContext()).thenReturn(threadContext);
-        when(httpClient.execute(any())).thenReturn(response);
-        HttpEntity entity = new StringEntity("{\"response\": \"test result\"}");
-        when(response.getEntity()).thenReturn(entity);
-        StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK");
-        when(response.getStatusLine()).thenReturn(statusLine);
-        MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build();
-        ModelTensorOutput modelTensorOutput = executor
-            .executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
-        Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
-        Assert.assertEquals("response", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName());
-        Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size());
-        Assert
-            .assertEquals(
-                "test result",
-                modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("response")
+        HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector);
+        executor
+            .invokeRemoteModel(
+                createMLInput(),
+                new HashMap<>(),
+                "{\"input\": \"hello world\"}",
+                new HashMap<>(),
+                new ExecutionContext(0, new CountDownLatch(1), new AtomicReference<>()),
+                actionListener
             );
+        ArgumentCaptor<IllegalArgumentException> captor = ArgumentCaptor.forClass(IllegalArgumentException.class);
+        Mockito.verify(actionListener, times(1)).onFailure(captor.capture());
+        assert captor.getValue() instanceof IllegalArgumentException;
+        assertEquals("Remote inference host name has private ip address: 127.0.0.1", captor.getValue().getMessage());
     }
 
     @Test
-    public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOException {
+    public void invokeRemoteModel_Empty_payload() {
         ConnectorAction predictAction = ConnectorAction
             .builder()
             .actionType(ConnectorAction.ActionType.PREDICT)
             .method("POST")
-            .url("http://test.com/mock")
-            .requestBody("{\"input\": ${parameters.input}}")
+            .url("http://openai.com/mock")
+            .requestBody("")
             .build();
-        when(httpClient.execute(any())).thenReturn(response);
-        HttpEntity entity = new StringEntity("{\"response\": \"test result\"}");
-        when(response.getEntity()).thenReturn(entity);
-        StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK");
-        when(response.getStatusLine()).thenReturn(statusLine);
         Connector connector = HttpConnector
             .builder()
             .name("test connector")
@@ -159,41 +119,31 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOExcepti
             .protocol("http")
             .actions(Arrays.asList(predictAction))
             .build();
-        HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector, httpClient));
-        Settings settings = Settings.builder().build();
-        threadContext = new ThreadContext(settings);
-        when(executor.getClient()).thenReturn(client);
-        when(client.threadPool()).thenReturn(threadPool);
-        when(threadPool.getThreadContext()).thenReturn(threadContext);
-        MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
-        ModelTensorOutput modelTensorOutput = executor
-            .executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
-        Assert.assertEquals(2, modelTensorOutput.getMlModelOutputs().size());
-        Assert.assertEquals("response", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName());
-        Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size());
-        Assert
-            .assertEquals(
-                "test result",
-                modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("response")
+        HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector);
+        executor
+            .invokeRemoteModel(
+                createMLInput(),
+                new HashMap<>(),
+                null,
+                new HashMap<>(),
+                new ExecutionContext(0, new CountDownLatch(1), new AtomicReference<>()),
+                actionListener
             );
+        ArgumentCaptor<IllegalArgumentException> captor = ArgumentCaptor.forClass(IllegalArgumentException.class);
+        Mockito.verify(actionListener, times(1)).onFailure(captor.capture());
+        assert captor.getValue() instanceof IllegalArgumentException;
+        assertEquals("Content length is 0. Aborting request to remote model", captor.getValue().getMessage());
     }
 
    @Test
    public void invokeRemoteModel_get_request() {
        // Predict action that uses GET with an empty request body.
        ConnectorAction predictAction = ConnectorAction
            .builder()
            .actionType(ConnectorAction.ActionType.PREDICT)
            .method("GET")
            .url("http://openai.com/mock")
            .requestBody("")
            .build();
        Connector connector = HttpConnector
            .builder()
            .name("test connector")
            .version("1")
            .protocol("http")
            .actions(Arrays.asList(predictAction))
            .build();
        HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector);
        // Smoke test: only verifies that a GET invocation does not throw synchronously.
        // NOTE(review): no interaction with actionListener is asserted here — consider
        // verifying the built request (method/URL) or listener behavior to strengthen it.
        executor
            .invokeRemoteModel(
                createMLInput(),
                new HashMap<>(),
                null,
                new HashMap<>(),
                new ExecutionContext(0, new CountDownLatch(1), new AtomicReference<>()),
                actionListener
            );
    }
 
     @Test
-    public void executePredict_TextDocsInput() throws IOException {
-        String preprocessResult1 = "{\"parameters\": { \"input\": \"test doc1\" } }";
-        String preprocessResult2 = "{\"parameters\": { \"input\": \"test doc2\" } }";
-        when(scriptService.compile(any(), any()))
-            .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult1))
-            .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult2));
-
+    public void invokeRemoteModel_post_request() {
         ConnectorAction predictAction = ConnectorAction
             .builder()
             .actionType(ConnectorAction.ActionType.PREDICT)
             .method("POST")
-            .url("http://test.com/mock")
-            .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT)
-            .postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING)
-            .requestBody("{\"input\": ${parameters.input}}")
+            .url("http://openai.com/mock")
+            .requestBody("hello world")
             .build();
-        HttpConnector connector = HttpConnector
+        Connector connector = HttpConnector
             .builder()
             .name("test connector")
             .version("1")
             .protocol("http")
             .actions(Arrays.asList(predictAction))
             .build();
-        HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector, httpClient));
-        Settings settings = Settings.builder().build();
-        threadContext = new ThreadContext(settings);
-        when(executor.getClient()).thenReturn(client);
-        when(client.threadPool()).thenReturn(threadPool);
-        when(threadPool.getThreadContext()).thenReturn(threadContext);
-        executor.setScriptService(scriptService);
-        when(httpClient.execute(any())).thenReturn(response);
-        String modelResponse = "{\n"
-            + "    \"object\": \"list\",\n"
-            + "    \"data\": [\n"
-            + "        {\n"
-            + "            \"object\": \"embedding\",\n"
-            + "            \"index\": 0,\n"
-            + "            \"embedding\": [\n"
-            + "                -0.014555434,\n"
-            + "                -0.002135904,\n"
-            + "                0.0035105038\n"
-            + "            ]\n"
-            + "        },\n"
-            + "        {\n"
-            + "            \"object\": \"embedding\",\n"
-            + "            \"index\": 1,\n"
-            + "            \"embedding\": [\n"
-            + "                -0.014555434,\n"
-            + "                -0.002135904,\n"
-            + "                0.0035105038\n"
-            + "            ]\n"
-            + "        }\n"
-            + "    ],\n"
-            + "    \"model\": \"text-embedding-ada-002-v2\",\n"
-            + "    \"usage\": {\n"
-            + "        \"prompt_tokens\": 5,\n"
-            + "        \"total_tokens\": 5\n"
-            + "    }\n"
-            + "}";
-        StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK");
-        when(response.getStatusLine()).thenReturn(statusLine);
-        HttpEntity entity = new StringEntity(modelResponse);
-        when(response.getEntity()).thenReturn(entity);
-        when(executor.getConnector()).thenReturn(connector);
-        MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
-        ModelTensorOutput modelTensorOutput = executor
-            .executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
-        Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
-        Assert.assertEquals(2, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size());
-        Assert.assertEquals("sentence_embedding", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName());
-        Assert
-            .assertArrayEquals(
-                new Number[] { -0.014555434, -0.002135904, 0.0035105038 },
-                modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData()
-            );
-        Assert
-            .assertArrayEquals(
-                new Number[] { -0.014555434, -0.002135904, 0.0035105038 },
-                modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1).getData()
+        HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector);
+        executor
+            .invokeRemoteModel(
+                createMLInput(),
+                new HashMap<>(),
+                "hello world",
+                new HashMap<>(),
+                new ExecutionContext(0, new CountDownLatch(1), new AtomicReference<>()),
+                actionListener
             );
     }
 
     @Test
-    public void executePredict_TextDocsInput_LessEmbeddingThanInputDocs() throws IOException {
-        String preprocessResult1 = "{\"parameters\": { \"input\": \"test doc1\" } }";
-        String preprocessResult2 = "{\"parameters\": { \"input\": \"test doc2\" } }";
-        when(scriptService.compile(any(), any()))
-            .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult1))
-            .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult2));
-
+    public void invokeRemoteModel_nullHttpClient_throwMLException() throws NoSuchFieldException, IllegalAccessException {
         ConnectorAction predictAction = ConnectorAction
             .builder()
             .actionType(ConnectorAction.ActionType.PREDICT)
             .method("POST")
-            .url("http://test.com/mock")
-            .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT)
-            .postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING)
-            .requestBody("{\"input\": ${parameters.input}}")
+            .url("http://openai.com/mock")
+            .requestBody("hello world")
             .build();
-        Map<String, String> parameters = ImmutableMap.of("input_docs_processed_step_size", "2");
-        HttpConnector connector = HttpConnector
+        Connector connector = HttpConnector
             .builder()
             .name("test connector")
             .version("1")
             .protocol("http")
-            .parameters(parameters)
             .actions(Arrays.asList(predictAction))
             .build();
-        HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector, httpClient));
-        Settings settings = Settings.builder().build();
-        threadContext = new ThreadContext(settings);
-        when(executor.getClient()).thenReturn(client);
-        when(client.threadPool()).thenReturn(threadPool);
-        when(threadPool.getThreadContext()).thenReturn(threadContext);
-        executor.setScriptService(scriptService);
-        when(httpClient.execute(any())).thenReturn(response);
-        // model takes 2 input docs, but only output 1 embedding
-        String modelResponse = "{\n"
-            + "    \"object\": \"list\",\n"
-            + "    \"data\": [\n"
-            + "        {\n"
-            + "            \"object\": \"embedding\",\n"
-            + "            \"index\": 0,\n"
-            + "            \"embedding\": [\n"
-            + "                -0.014555434,\n"
-            + "                -0.002135904,\n"
-            + "                0.0035105038\n"
-            + "            ]\n"
-            + "        }    ],\n"
-            + "    \"model\": \"text-embedding-ada-002-v2\",\n"
-            + "    \"usage\": {\n"
-            + "        \"prompt_tokens\": 5,\n"
-            + "        \"total_tokens\": 5\n"
-            + "    }\n"
-            + "}";
-        StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK");
-        when(response.getStatusLine()).thenReturn(statusLine);
-        HttpEntity entity = new StringEntity(modelResponse);
-        when(response.getEntity()).thenReturn(entity);
-        when(executor.getConnector()).thenReturn(connector);
-        MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
-        ModelTensorOutput modelTensorOutput = executor
-            .executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
-        Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
-        Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size());
-        Assert.assertEquals("sentence_embedding", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName());
-        Assert
-            .assertArrayEquals(
-                new Number[] { -0.014555434, -0.002135904, 0.0035105038 },
-                modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData()
+        HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector);
+        Field httpClientField = HttpJsonConnectorExecutor.class.getDeclaredField("httpClient");
+        httpClientField.setAccessible(true);
+        httpClientField.set(executor, null);
+        executor
+            .invokeRemoteModel(
+                createMLInput(),
+                new HashMap<>(),
+                "hello world",
+                new HashMap<>(),
+                new ExecutionContext(0, new CountDownLatch(1), new AtomicReference<>()),
+                actionListener
             );
+        ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
+        verify(actionListener, times(1)).onFailure(argumentCaptor.capture());
+        assert argumentCaptor.getValue() instanceof NullPointerException;
     }
 
-    @Test
-    public void executePredict_TextDocsInput_LessEmbeddingThanInputDocs_InvalidStepSize() throws IOException {
-        exceptionRule.expect(IllegalArgumentException.class);
-        exceptionRule.expectMessage("Invalid parameter: input_docs_processed_step_size. It must be positive integer.");
-        String preprocessResult1 = "{\"parameters\": { \"input\": \"test doc1\" } }";
-        String preprocessResult2 = "{\"parameters\": { \"input\": \"test doc2\" } }";
-        when(scriptService.compile(any(), any()))
-            .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult1))
-            .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult2));
-
-        ConnectorAction predictAction = ConnectorAction
-            .builder()
-            .actionType(ConnectorAction.ActionType.PREDICT)
-            .method("POST")
-            .url("http://test.com/mock")
-            .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT)
-            .postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING)
-            .requestBody("{\"input\": ${parameters.input}}")
-            .build();
-        // step size must be positive integer, here we set it as -1, should trigger IllegalArgumentException
-        Map<String, String> parameters = ImmutableMap.of("input_docs_processed_step_size", "-1");
-        HttpConnector connector = HttpConnector
-            .builder()
-            .name("test connector")
-            .version("1")
-            .protocol("http")
-            .parameters(parameters)
-            .actions(Arrays.asList(predictAction))
-            .build();
-        HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector, httpClient));
-        Settings settings = Settings.builder().build();
-        threadContext = new ThreadContext(settings);
-        when(executor.getClient()).thenReturn(client);
-        when(client.threadPool()).thenReturn(threadPool);
-        when(threadPool.getThreadContext()).thenReturn(threadContext);
-        executor.setScriptService(scriptService);
-        when(httpClient.execute(any())).thenReturn(response);
-        // model takes 2 input docs, but only output 1 embedding
-        String modelResponse = "{\n"
-            + "    \"object\": \"list\",\n"
-            + "    \"data\": [\n"
-            + "        {\n"
-            + "            \"object\": \"embedding\",\n"
-            + "            \"index\": 0,\n"
-            + "            \"embedding\": [\n"
-            + "                -0.014555434,\n"
-            + "                -0.002135904,\n"
-            + "                0.0035105038\n"
-            + "            ]\n"
-            + "        }    ],\n"
-            + "    \"model\": \"text-embedding-ada-002-v2\",\n"
-            + "    \"usage\": {\n"
-            + "        \"prompt_tokens\": 5,\n"
-            + "        \"total_tokens\": 5\n"
-            + "    }\n"
-            + "}";
-        StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK");
-        when(response.getStatusLine()).thenReturn(statusLine);
-        HttpEntity entity = new StringEntity(modelResponse);
-        when(response.getEntity()).thenReturn(entity);
-        when(executor.getConnector()).thenReturn(connector);
-        MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
-        ModelTensorOutput modelTensorOutput = executor
-            .executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
+    private MLInput createMLInput() {
+        MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build();
+        return MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.REMOTE).build();
     }
 }
diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java
new file mode 100644
index 0000000000..68b5cdeb5f
--- /dev/null
+++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java
@@ -0,0 +1,422 @@
+/*
+ *
+ *  * Copyright OpenSearch Contributors
+ *  * SPDX-License-Identifier: Apache-2.0
+ *
+ */
+
+package org.opensearch.ml.engine.algorithms.remote;
+
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.atomic.AtomicReference;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.opensearch.OpenSearchStatusException;
+import org.opensearch.core.action.ActionListener;
+import org.opensearch.ml.common.connector.Connector;
+import org.opensearch.ml.common.connector.ConnectorAction;
+import org.opensearch.ml.common.connector.HttpConnector;
+import org.opensearch.ml.common.connector.MLPostProcessFunction;
+import org.opensearch.ml.common.output.model.ModelTensors;
+import org.opensearch.script.ScriptService;
+import org.reactivestreams.Publisher;
+import org.reactivestreams.Subscription;
+
+import software.amazon.awssdk.http.HttpStatusCode;
+import software.amazon.awssdk.http.SdkHttpFullResponse;
+import software.amazon.awssdk.http.SdkHttpResponse;
+
+public class MLSdkAsyncHttpResponseHandlerTest {
+    private final ExecutionContext executionContext = new ExecutionContext(0, new CountDownLatch(1), new AtomicReference<>());
+    @Mock
+    private ActionListener<List<ModelTensors>> actionListener;
+    @Mock
+    private Map<String, String> parameters;
+    private Map<Integer, ModelTensors> tensorOutputs = new ConcurrentHashMap<>();
+    private Connector connector;
+
+    private Connector noProcessFunctionConnector;
+
+    @Mock
+    private SdkHttpFullResponse sdkHttpResponse;
+    @Mock
+    private ScriptService scriptService;
+
+    private MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler;
+
+    MLSdkAsyncHttpResponseHandler.MLResponseSubscriber responseSubscriber;
+
    @Before
    public void setup() {
        MockitoAnnotations.openMocks(this);
        // Default mocked HTTP response is 200 OK.
        when(sdkHttpResponse.statusCode()).thenReturn(HttpStatusCode.OK);
        // Connector whose predict action post-processes with the Bedrock embedding function.
        ConnectorAction predictAction = ConnectorAction
            .builder()
            .actionType(ConnectorAction.ActionType.PREDICT)
            .method("POST")
            .postProcessFunction(MLPostProcessFunction.BEDROCK_EMBEDDING)
            .url("http://test.com/mock")
            .requestBody("{\"input\": \"${parameters.input}\"}")
            .build();
        connector = HttpConnector
            .builder()
            .name("test connector")
            .version("1")
            .protocol("http")
            .actions(Arrays.asList(predictAction))
            .build();

        // Same action without any post-process function; used by the raw-response tests.
        ConnectorAction noProcessFunctionPredictAction = ConnectorAction
            .builder()
            .actionType(ConnectorAction.ActionType.PREDICT)
            .method("POST")
            .url("http://test.com/mock")
            .requestBody("{\"input\": \"${parameters.input}\"}")
            .build();
        noProcessFunctionConnector = HttpConnector
            .builder()
            .name("test connector")
            .version("1")
            .protocol("http")
            .actions(Arrays.asList(noProcessFunctionPredictAction))
            .build();
        // Handler under test. NOTE(review): the last constructor argument is null
        // here — confirm its role against the production constructor.
        mlSdkAsyncHttpResponseHandler = new MLSdkAsyncHttpResponseHandler(
            executionContext,
            actionListener,
            parameters,
            tensorOutputs,
            connector,
            scriptService,
            null
        );
        responseSubscriber = mlSdkAsyncHttpResponseHandler.new MLResponseSubscriber();
    }
+
+    @Test
+    public void test_OnHeaders() {
+        mlSdkAsyncHttpResponseHandler.onHeaders(sdkHttpResponse);
+        assert mlSdkAsyncHttpResponseHandler.getStatusCode() == 200;
+    }
+
+    @Test
+    public void test_OnStream_with_postProcessFunction_bedRock() {
+        String response = "{\n"
+            + "    \"embedding\": [\n"
+            + "        0.46484375,\n"
+            + "        -0.017822266,\n"
+            + "        0.17382812,\n"
+            + "        0.10595703,\n"
+            + "        0.875,\n"
+            + "        0.19140625,\n"
+            + "        -0.36914062,\n"
+            + "        -0.0011978149\n"
+            + "    ]\n"
+            + "}";
+        Publisher<ByteBuffer> stream = s -> {
+            try {
+                s.onSubscribe(mock(Subscription.class));
+                s.onNext(ByteBuffer.wrap(response.getBytes()));
+                s.onComplete();
+            } catch (Throwable e) {
+                s.onError(e);
+            }
+        };
+        test_OnHeaders(); // set the status code to non-null
+        mlSdkAsyncHttpResponseHandler.onStream(stream);
+        ArgumentCaptor<List<ModelTensors>> captor = ArgumentCaptor.forClass(List.class);
+        verify(actionListener).onResponse(captor.capture());
+        assert captor.getValue().size() == 1;
+        assert captor.getValue().get(0).getMlModelTensors().get(0).getData().length == 8;
+    }
+
+    @Test
+    public void test_OnStream_without_postProcessFunction() {
+        Publisher<ByteBuffer> stream = s -> {
+            try {
+                s.onSubscribe(mock(Subscription.class));
+                s.onNext(ByteBuffer.wrap("{\"key\": \"hello world\"}".getBytes()));
+                s.onComplete();
+            } catch (Throwable e) {
+                s.onError(e);
+            }
+        };
+        MLSdkAsyncHttpResponseHandler noProcessFunctionMlSdkAsyncHttpResponseHandler = new MLSdkAsyncHttpResponseHandler(
+            executionContext,
+            actionListener,
+            parameters,
+            tensorOutputs,
+            noProcessFunctionConnector,
+            scriptService,
+            null
+        );
+        noProcessFunctionMlSdkAsyncHttpResponseHandler.onHeaders(sdkHttpResponse);
+        noProcessFunctionMlSdkAsyncHttpResponseHandler.onStream(stream);
+        ArgumentCaptor<List<ModelTensors>> captor = ArgumentCaptor.forClass(List.class);
+        verify(actionListener).onResponse(captor.capture());
+        assert captor.getValue().size() == 1;
+        assert captor.getValue().get(0).getMlModelTensors().get(0).getDataAsMap().get("key").equals("hello world");
+    }
+
+    @Test
+    public void test_onError() {
+        test_OnHeaders(); // set the status code to non-null
+        mlSdkAsyncHttpResponseHandler.onError(new RuntimeException("runtime exception"));
+        ArgumentCaptor<Exception> captor = ArgumentCaptor.forClass(Exception.class);
+        verify(actionListener).onFailure(captor.capture());
+        assert captor.getValue() instanceof OpenSearchStatusException;
+        assert captor.getValue().getMessage().equals("Error communicating with remote model: runtime exception");
+    }
+
    @Test
    public void test_MLSdkAsyncHttpResponseHandler_onError() {
        // NOTE(review): only checks that onError does not throw before any status
        // code has been recorded; no listener interaction is verified here.
        mlSdkAsyncHttpResponseHandler.onError(new Exception("error"));
    }
+
    @Test
    public void test_onSubscribe() {
        // Also used as a helper by test_onNext to put the subscriber into a
        // subscribed state; intentionally has no assertions of its own.
        Subscription subscription = mock(Subscription.class);
        responseSubscriber.onSubscribe(subscription);
    }
+
+    @Test
+    public void test_onNext() {
+        test_onSubscribe();// set the subscription to non-null.
+        responseSubscriber.onNext(ByteBuffer.wrap("hello world".getBytes()));
+        assert mlSdkAsyncHttpResponseHandler.getResponseBody().toString().equals("hello world");
+    }
+
+    @Test
+    public void test_MLResponseSubscriber_onError() {
+        SdkHttpFullResponse response = mock(SdkHttpFullResponse.class);
+        when(response.statusCode()).thenReturn(500);
+        mlSdkAsyncHttpResponseHandler.onHeaders(response);
+        responseSubscriber.onError(new Exception("error"));
+        ArgumentCaptor<Exception> captor = ArgumentCaptor.forClass(Exception.class);
+        verify(actionListener, times(1)).onFailure(captor.capture());
+        assert captor.getValue() instanceof OpenSearchStatusException;
+        assert captor.getValue().getMessage().equals("No response from model");
+    }
+
+    @Test
+    public void test_onComplete_success() {
+        String response = "{\n"
+            + "    \"embedding\": [\n"
+            + "        0.46484375,\n"
+            + "        -0.017822266,\n"
+            + "        0.17382812,\n"
+            + "        0.10595703,\n"
+            + "        0.875,\n"
+            + "        0.19140625,\n"
+            + "        -0.36914062,\n"
+            + "        -0.0011978149\n"
+            + "    ]\n"
+            + "}";
+        mlSdkAsyncHttpResponseHandler.onHeaders(sdkHttpResponse);
+        Publisher<ByteBuffer> stream = s -> {
+            try {
+                s.onSubscribe(mock(Subscription.class));
+                s.onNext(ByteBuffer.wrap(response.getBytes()));
+                s.onComplete();
+            } catch (Throwable e) {
+                s.onError(e);
+            }
+        };
+        mlSdkAsyncHttpResponseHandler.onStream(stream);
+        ArgumentCaptor<List<ModelTensors>> captor = ArgumentCaptor.forClass(List.class);
+        verify(actionListener).onResponse(captor.capture());
+        assert captor.getValue().size() == 1;
+        assert captor.getValue().get(0).getMlModelTensors().get(0).getData().length == 8;
+    }
+
+    @Test
+    public void test_onComplete_partial_success_exceptionSecond() {
+        AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
+        String response1 = "{\n"
+            + "    \"embedding\": [\n"
+            + "        0.46484375,\n"
+            + "        -0.017822266,\n"
+            + "        0.17382812,\n"
+            + "        0.10595703,\n"
+            + "        0.875,\n"
+            + "        0.19140625,\n"
+            + "        -0.36914062,\n"
+            + "        -0.0011978149\n"
+            + "    ]\n"
+            + "}";
+        String response2 = "Model current status is: FAILED";
+        CountDownLatch count = new CountDownLatch(2);
+        MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler1 = new MLSdkAsyncHttpResponseHandler(
+            new ExecutionContext(0, count, exceptionHolder),
+            actionListener,
+            parameters,
+            tensorOutputs,
+            connector,
+            scriptService,
+            null
+        );
+        MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler2 = new MLSdkAsyncHttpResponseHandler(
+            new ExecutionContext(1, count, exceptionHolder),
+            actionListener,
+            parameters,
+            tensorOutputs,
+            connector,
+            scriptService,
+            null
+        );
+        SdkHttpFullResponse sdkHttpResponse1 = mock(SdkHttpFullResponse.class);
+        when(sdkHttpResponse1.statusCode()).thenReturn(200);
+        mlSdkAsyncHttpResponseHandler1.onHeaders(sdkHttpResponse1);
+        Publisher<ByteBuffer> stream1 = s -> {
+            try {
+                s.onSubscribe(mock(Subscription.class));
+                s.onNext(ByteBuffer.wrap(response1.getBytes()));
+                s.onComplete();
+            } catch (Throwable e) {
+                s.onError(e);
+            }
+        };
+        mlSdkAsyncHttpResponseHandler1.onStream(stream1);
+
+        SdkHttpFullResponse sdkHttpResponse2 = mock(SdkHttpFullResponse.class);
+        when(sdkHttpResponse2.statusCode()).thenReturn(500);
+        mlSdkAsyncHttpResponseHandler2.onHeaders(sdkHttpResponse2);
+        Publisher<ByteBuffer> stream2 = s -> {
+            try {
+                s.onSubscribe(mock(Subscription.class));
+                s.onNext(ByteBuffer.wrap(response2.getBytes()));
+                s.onComplete();
+            } catch (Throwable e) {
+                s.onError(e);
+            }
+        };
+        mlSdkAsyncHttpResponseHandler2.onStream(stream2);
+        ArgumentCaptor<OpenSearchStatusException> captor = ArgumentCaptor.forClass(OpenSearchStatusException.class);
+        verify(actionListener, times(1)).onFailure(captor.capture());
+        assert captor.getValue().getMessage().equals("Error from remote service: Model current status is: FAILED");
+        assert captor.getValue().status().getStatus() == 500;
+    }
+
+    @Test
+    public void test_onComplete_partial_success_exceptionFirst() {
+        AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
+        String response1 = "{\n"
+            + "    \"embedding\": [\n"
+            + "        0.46484375,\n"
+            + "        -0.017822266,\n"
+            + "        0.17382812,\n"
+            + "        0.10595703,\n"
+            + "        0.875,\n"
+            + "        0.19140625,\n"
+            + "        -0.36914062,\n"
+            + "        -0.0011978149\n"
+            + "    ]\n"
+            + "}";
+        String response2 = "Model current status is: FAILED";
+        CountDownLatch count = new CountDownLatch(2);
+        MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler1 = new MLSdkAsyncHttpResponseHandler(
+            new ExecutionContext(0, count, exceptionHolder),
+            actionListener,
+            parameters,
+            tensorOutputs,
+            connector,
+            scriptService,
+            null
+        );
+        MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler2 = new MLSdkAsyncHttpResponseHandler(
+            new ExecutionContext(1, count, exceptionHolder),
+            actionListener,
+            parameters,
+            tensorOutputs,
+            connector,
+            scriptService,
+            null
+        );
+
+        SdkHttpFullResponse sdkHttpResponse2 = mock(SdkHttpFullResponse.class);
+        when(sdkHttpResponse2.statusCode()).thenReturn(500);
+        mlSdkAsyncHttpResponseHandler2.onHeaders(sdkHttpResponse2);
+        Publisher<ByteBuffer> stream2 = s -> {
+            try {
+                s.onSubscribe(mock(Subscription.class));
+                s.onNext(ByteBuffer.wrap(response2.getBytes()));
+                s.onComplete();
+            } catch (Throwable e) {
+                s.onError(e);
+            }
+        };
+        mlSdkAsyncHttpResponseHandler2.onStream(stream2);
+
+        SdkHttpFullResponse sdkHttpResponse1 = mock(SdkHttpFullResponse.class);
+        when(sdkHttpResponse1.statusCode()).thenReturn(200);
+        mlSdkAsyncHttpResponseHandler1.onHeaders(sdkHttpResponse1);
+        Publisher<ByteBuffer> stream1 = s -> {
+            try {
+                s.onSubscribe(mock(Subscription.class));
+                s.onNext(ByteBuffer.wrap(response1.getBytes()));
+                s.onComplete();
+            } catch (Throwable e) {
+                s.onError(e);
+            }
+        };
+        mlSdkAsyncHttpResponseHandler1.onStream(stream1);
+        ArgumentCaptor<OpenSearchStatusException> captor = ArgumentCaptor.forClass(OpenSearchStatusException.class);
+        verify(actionListener, times(1)).onFailure(captor.capture());
+        assert captor.getValue().getMessage().equals("Error from remote service: Model current status is: FAILED");
+        assert captor.getValue().status().getStatus() == 500;
+    }
+
+    @Test
+    public void test_onComplete_empty_response_body() {
+        mlSdkAsyncHttpResponseHandler.onHeaders(sdkHttpResponse);
+        Publisher<ByteBuffer> stream = s -> {
+            try {
+                s.onSubscribe(mock(Subscription.class));
+                s.onNext(ByteBuffer.wrap("".getBytes()));
+                s.onComplete();
+            } catch (Throwable e) {
+                s.onError(e);
+            }
+        };
+        mlSdkAsyncHttpResponseHandler.onStream(stream);
+        ArgumentCaptor<OpenSearchStatusException> captor = ArgumentCaptor.forClass(OpenSearchStatusException.class);
+        verify(actionListener, times(1)).onFailure(captor.capture());
+        assert captor.getValue().getMessage().equals("No response from model");
+    }
+
+    @Test
+    public void test_onComplete_error_http_status() {
+        String error = "{\"message\": \"runtime error\"}";
+        SdkHttpResponse response = mock(SdkHttpFullResponse.class);
+        when(response.statusCode()).thenReturn(HttpStatusCode.INTERNAL_SERVER_ERROR);
+        mlSdkAsyncHttpResponseHandler.onHeaders(response);
+        Publisher<ByteBuffer> stream = s -> {
+            try {
+                s.onSubscribe(mock(Subscription.class));
+                s.onNext(ByteBuffer.wrap(error.getBytes()));
+                s.onComplete();
+            } catch (Throwable e) {
+                s.onError(e);
+            }
+        };
+        mlSdkAsyncHttpResponseHandler.onStream(stream);
+        ArgumentCaptor<Exception> captor = ArgumentCaptor.forClass(Exception.class);
+        verify(actionListener, times(1)).onFailure(captor.capture());
+        assert captor.getValue() instanceof OpenSearchStatusException;
+        System.out.println(captor.getValue().getMessage());
+        assert captor.getValue().getMessage().contains("runtime error");
+    }
+}
diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java
index 149848327b..c14b329586 100644
--- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java
+++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java
@@ -5,9 +5,12 @@
 
 package org.opensearch.ml.engine.algorithms.remote;
 
+import static org.junit.Assert.assertEquals;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
 import java.util.Arrays;
@@ -18,14 +21,17 @@
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
+import org.mockito.ArgumentCaptor;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
+import org.opensearch.core.action.ActionListener;
 import org.opensearch.ml.common.MLModel;
 import org.opensearch.ml.common.connector.Connector;
 import org.opensearch.ml.common.connector.ConnectorAction;
 import org.opensearch.ml.common.connector.ConnectorProtocols;
 import org.opensearch.ml.common.connector.HttpConnector;
 import org.opensearch.ml.common.input.MLInput;
+import org.opensearch.ml.common.transport.MLTaskResponse;
 import org.opensearch.ml.engine.encryptor.Encryptor;
 import org.opensearch.ml.engine.encryptor.EncryptorImpl;
 
@@ -60,20 +66,36 @@ public void predict_ModelNotDeployed() {
     }
 
     @Test
-    public void predict_NullConnectorExecutor() {
-        exceptionRule.expect(RuntimeException.class);
-        exceptionRule.expectMessage("Model not ready yet");
+    public void test_predict_throw_IllegalStateException() {
+        exceptionRule.expect(IllegalStateException.class);
+        exceptionRule.expectMessage("Method is not implemented");
         remoteModel.predict(mlInput);
     }
 
+    @Test
+    public void predict_NullConnectorExecutor() {
+        ActionListener<MLTaskResponse> actionListener = mock(ActionListener.class);
+        remoteModel.asyncPredict(mlInput, actionListener);
+        ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
+        verify(actionListener).onFailure(argumentCaptor.capture());
+        assert argumentCaptor.getValue() instanceof RuntimeException;
+        assertEquals(
+            "Model not ready yet. Please run this first: POST /_plugins/_ml/models/<model_id>/_deploy",
+            argumentCaptor.getValue().getMessage()
+        );
+    }
+
     @Test
     public void predict_ModelDeployed_WrongInput() {
-        exceptionRule.expect(RuntimeException.class);
-        exceptionRule.expectMessage("pre_process_function not defined in connector");
         Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}"));
         when(mlModel.getConnector()).thenReturn(connector);
         remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor);
-        remoteModel.predict(mlInput);
+        ActionListener<MLTaskResponse> actionListener = mock(ActionListener.class);
+        remoteModel.asyncPredict(mlInput, actionListener);
+        ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
+        verify(actionListener).onFailure(argumentCaptor.capture());
+        assert argumentCaptor.getValue() instanceof RuntimeException;
+        assertEquals("pre_process_function not defined in connector", argumentCaptor.getValue().getMessage());
     }
 
     @Test
@@ -105,8 +127,8 @@ public void initModel_WithHeader() {
         Assert.assertNotNull(executor);
         Assert.assertNull(decryptedHeaders);
         Assert.assertNotNull(executor.getConnector().getDecryptedHeaders());
-        Assert.assertEquals(1, executor.getConnector().getDecryptedHeaders().size());
-        Assert.assertEquals("Bearer test_api_key", executor.getConnector().getDecryptedHeaders().get("Authorization"));
+        assertEquals(1, executor.getConnector().getDecryptedHeaders().size());
+        assertEquals("Bearer test_api_key", executor.getConnector().getDecryptedHeaders().get("Authorization"));
 
         remoteModel.close();
         Assert.assertNull(remoteModel.getConnectorExecutor());
diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java
index 179d11d2a1..138faf65e1 100644
--- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java
+++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java
@@ -8,6 +8,8 @@
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
 import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.FrameworkType.HUGGINGFACE_TRANSFORMERS;
 import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS;
 import static org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingDenseModel.ML_ENGINE;
@@ -30,7 +32,9 @@
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
+import org.mockito.ArgumentCaptor;
 import org.opensearch.ResourceNotFoundException;
+import org.opensearch.core.action.ActionListener;
 import org.opensearch.ml.common.FunctionName;
 import org.opensearch.ml.common.MLModel;
 import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
@@ -627,6 +631,16 @@ public void predict_BeforeInitingModel() {
         textEmbeddingDenseModel.predict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), model);
     }
 
+    @Test
+    public void test_async_inference() {
+        ArgumentCaptor<IllegalStateException> captor = ArgumentCaptor.forClass(IllegalStateException.class);
+        ActionListener actionListener = mock(ActionListener.class);
+        textEmbeddingDenseModel
+            .asyncPredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener);
+        verify(actionListener).onFailure(captor.capture());
+        assert captor.getValue().getMessage().equals("Method is not implemented");
+    }
+
     @After
     public void tearDown() {
         FileUtils.deleteFileQuietly(mlCachePath);
diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java
index fa22837249..d1d9b42dcc 100644
--- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java
+++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java
@@ -7,98 +7,91 @@
 
 import static org.junit.Assert.assertNotNull;
 
-import java.net.UnknownHostException;
+import java.time.Duration;
 
-import org.apache.http.HttpHost;
-import org.apache.http.impl.client.CloseableHttpClient;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
 
+import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
+
 public class MLHttpClientFactoryTests {
 
     @Rule
     public ExpectedException expectedException = ExpectedException.none();
 
     @Test
-    public void test_getCloseableHttpClient_success() {
-        CloseableHttpClient client = MLHttpClientFactory.getCloseableHttpClient(1000, 1000, 30);
+    public void test_getSdkAsyncHttpClient_success() {
+        SdkAsyncHttpClient client = MLHttpClientFactory.getAsyncHttpClient(Duration.ofSeconds(100), Duration.ofSeconds(100), 100);
         assertNotNull(client);
     }
 
     @Test
-    public void test_validateIp_validIp_noException() throws UnknownHostException {
-        MLHttpClientFactory.validateIp("api.openai.com");
-    }
-
-    @Test
-    public void test_validateIp_invalidIp_throwException() throws UnknownHostException {
-        expectedException.expect(UnknownHostException.class);
-        MLHttpClientFactory.validateIp("www.makesureitisaunknownhost.com");
+    public void test_validateIp_validIp_noException() throws Exception {
+        MLHttpClientFactory.validate("http", "api.openai.com", 80);
     }
 
     @Test
-    public void test_validateIp_privateIp_throwException() throws UnknownHostException {
-        expectedException.expect(IllegalArgumentException.class);
-        MLHttpClientFactory.validateIp("localhost");
-    }
-
-    @Test
-    public void test_validateIp_rarePrivateIp_throwException() throws UnknownHostException {
+    public void test_validateIp_rarePrivateIp_throwException() throws Exception {
         try {
-            MLHttpClientFactory.validateIp("0177.1");
+            MLHttpClientFactory.validate("http", "0254.020.00.01", 80);
         } catch (IllegalArgumentException e) {
             assertNotNull(e);
         }
 
         try {
-            MLHttpClientFactory.validateIp("172.1048577");
+            MLHttpClientFactory.validate("http", "172.1048577", 80);
+        } catch (Exception e) {
+            assertNotNull(e);
+        }
+
+        try {
+            MLHttpClientFactory.validate("http", "2886729729", 80);
         } catch (IllegalArgumentException e) {
             assertNotNull(e);
         }
 
         try {
-            MLHttpClientFactory.validateIp("2886729729");
+            MLHttpClientFactory.validate("http", "192.11010049", 80);
         } catch (IllegalArgumentException e) {
             assertNotNull(e);
         }
 
         try {
-            MLHttpClientFactory.validateIp("192.11010049");
+            MLHttpClientFactory.validate("http", "3232300545", 80);
         } catch (IllegalArgumentException e) {
             assertNotNull(e);
         }
 
         try {
-            MLHttpClientFactory.validateIp("3232300545");
+            MLHttpClientFactory.validate("http", "0:0:0:0:0:ffff:127.0.0.1", 80);
         } catch (IllegalArgumentException e) {
             assertNotNull(e);
         }
 
         try {
-            MLHttpClientFactory.validateIp("0:0:0:0:0:ffff:127.0.0.1");
+            MLHttpClientFactory.validate("http", "153.24.76.232", 80);
         } catch (IllegalArgumentException e) {
             assertNotNull(e);
         }
     }
 
     @Test
-    public void test_validateSchemaAndPort_success() {
-        HttpHost httpHost = new HttpHost("api.openai.com", 8080, "https");
-        MLHttpClientFactory.validateSchemaAndPort(httpHost);
+    public void test_validateSchemaAndPort_success() throws Exception {
+        MLHttpClientFactory.validate("https", "api.openai.com", 443);
     }
 
     @Test
-    public void test_validateSchemaAndPort_notAllowedSchema_throwException() {
+    public void test_validateSchemaAndPort_notAllowedSchema_throwException() throws Exception {
         expectedException.expect(IllegalArgumentException.class);
-        HttpHost httpHost = new HttpHost("api.openai.com", 8080, "ftp");
-        MLHttpClientFactory.validateSchemaAndPort(httpHost);
+        MLHttpClientFactory.validate("ftp", "api.openai.com", 80);
     }
 
     @Test
-    public void test_validateSchemaAndPort_portNotInRange_throwException() {
+    public void test_validateSchemaAndPort_portNotInRange_throwException() throws Exception {
         expectedException.expect(IllegalArgumentException.class);
-        HttpHost httpHost = new HttpHost("api.openai.com:65537", -1, "https");
-        MLHttpClientFactory.validateSchemaAndPort(httpHost);
+        expectedException.expectMessage("Port out of range: 65537");
+        MLHttpClientFactory.validate("https", "api.openai.com", 65537);
     }
+
 }
diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java
index 4d1f935e83..cb94668308 100644
--- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java
+++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java
@@ -1891,6 +1891,12 @@ public <T> T trackPredictDuration(String modelId, Supplier<T> supplier) {
         return t;
     }
 
+    public void trackPredictDuration(String modelId, long startTime) {
+        long end = System.nanoTime();
+        double durationInMs = (end - startTime) / 1e6;
+        modelCacheHelper.addModelInferenceDuration(modelId, durationInMs);
+    }
+
     public FunctionName getModelFunctionName(String modelId) {
         return modelCacheHelper.getFunctionName(modelId);
     }
diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java
index a19ffc4af3..16176f197e 100644
--- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java
+++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java
@@ -321,17 +321,26 @@ private void runPredict(
                     if (!predictor.isModelReady()) {
                         throw new IllegalArgumentException("Model not ready: " + modelId);
                     }
-                    MLOutput output = mlModelManager.trackPredictDuration(modelId, () -> predictor.predict(mlInput));
-                    if (output instanceof MLPredictionOutput) {
-                        ((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name());
+                    if (mlInput.getAlgorithm() == FunctionName.REMOTE) {
+                        long startTime = System.nanoTime();
+                        ActionListener<MLTaskResponse> trackPredictDurationListener = ActionListener.wrap(output -> {
+                            handleAsyncMLTaskComplete(mlTask);
+                            mlModelManager.trackPredictDuration(modelId, startTime);
+                            internalListener.onResponse(output);
+                        }, e -> handlePredictFailure(mlTask, internalListener, e, false, modelId));
+                        predictor.asyncPredict(mlInput, trackPredictDurationListener);
+                    } else {
+                        MLOutput output = mlModelManager.trackPredictDuration(modelId, () -> predictor.predict(mlInput));
+                        if (output instanceof MLPredictionOutput) {
+                            ((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name());
+                        }
+                        // Once prediction complete, reduce ML_EXECUTING_TASK_COUNT and update task state
+                        handleAsyncMLTaskComplete(mlTask);
+                        internalListener.onResponse(new MLTaskResponse(output));
                     }
-
-                    // Once prediction complete, reduce ML_EXECUTING_TASK_COUNT and update task state
-                    handleAsyncMLTaskComplete(mlTask);
-                    MLTaskResponse response = MLTaskResponse.builder().output(output).build();
-                    internalListener.onResponse(response);
                     return;
                 } catch (Exception e) {
+                    log.error("Failed to predict model " + modelId, e);
                     handlePredictFailure(mlTask, internalListener, e, false, modelId);
                     return;
                 }
diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java
index 569aaee3c9..d42fa9ca65 100644
--- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java
+++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java
@@ -61,6 +61,7 @@
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ExecutorService;
+import java.util.function.Supplier;
 
 import org.junit.Before;
 import org.junit.Ignore;
@@ -1140,6 +1141,35 @@ public void testRegisterModelMeta_FailedToInitIndexIfPresent() {
         verify(actionListener).onFailure(argumentCaptor.capture());
     }
 
+    public void test_trackPredictDuration_sync() {
+        Supplier<String> mockResult = () -> {
+            try {
+                Thread.sleep(1000);
+            } catch (InterruptedException e) {
+                throw new RuntimeException(e);
+            }
+            return "test";
+        };
+        String modelId = "test_model";
+        modelManager.trackPredictDuration(modelId, mockResult);
+        ArgumentCaptor<String> modelIdCaptor = ArgumentCaptor.forClass(String.class);
+        ArgumentCaptor<Double> durationCaptor = ArgumentCaptor.forClass(Double.class);
+        verify(modelCacheHelper).addModelInferenceDuration(modelIdCaptor.capture(), durationCaptor.capture());
+        assert modelIdCaptor.getValue().equals(modelId);
+        assert durationCaptor.getValue() > 0;
+    }
+
+    public void test_trackPredictDuration_async() {
+        String modelId = "test_model";
+        long startTime = System.nanoTime();
+        modelManager.trackPredictDuration(modelId, startTime);
+        ArgumentCaptor<String> modelIdCaptor = ArgumentCaptor.forClass(String.class);
+        ArgumentCaptor<Double> durationCaptor = ArgumentCaptor.forClass(Double.class);
+        verify(modelCacheHelper).addModelInferenceDuration(modelIdCaptor.capture(), durationCaptor.capture());
+        assert modelIdCaptor.getValue().equals(modelId);
+        assert durationCaptor.getValue() > 0;
+    }
+
     private void setupForModelMeta() {
         doAnswer(invocation -> {
             ActionListener<IndexResponse> listener = invocation.getArgument(1);
diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java
index c4a9f4d61a..0f90528fab 100644
--- a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java
+++ b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java
@@ -15,6 +15,7 @@
 import java.nio.file.Path;
 import java.util.Arrays;
 import java.util.HashSet;
+import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ExecutorService;
@@ -51,12 +52,16 @@
 import org.opensearch.ml.common.dataset.DataFrameInputDataset;
 import org.opensearch.ml.common.dataset.MLInputDataset;
 import org.opensearch.ml.common.dataset.SearchQueryInputDataset;
+import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
 import org.opensearch.ml.common.input.MLInput;
 import org.opensearch.ml.common.input.parameter.rcf.BatchRCFParams;
+import org.opensearch.ml.common.output.MLPredictionOutput;
+import org.opensearch.ml.common.output.model.ModelTensorOutput;
 import org.opensearch.ml.common.transport.MLTaskResponse;
 import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
 import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
 import org.opensearch.ml.engine.MLEngine;
+import org.opensearch.ml.engine.Predictable;
 import org.opensearch.ml.engine.encryptor.Encryptor;
 import org.opensearch.ml.engine.encryptor.EncryptorImpl;
 import org.opensearch.ml.engine.indices.MLInputDatasetHandler;
@@ -306,6 +311,72 @@ public void testExecuteTask_OnLocalNode_NullModelIdException() {
         assertEquals("ModelId is invalid", argumentCaptor.getValue().getMessage());
     }
 
+    public void testExecuteTask_OnLocalNode_remoteModel_success() {
+        setupMocks(true, false, false, false);
+        TextDocsInputDataSet textDocsInputDataSet = new TextDocsInputDataSet(List.of("hello", "world"), null);
+        MLPredictionTaskRequest textDocsInputRequest = MLPredictionTaskRequest
+            .builder()
+            .modelId("test_model")
+            .mlInput(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(textDocsInputDataSet).build())
+            .build();
+        Predictable predictor = mock(Predictable.class);
+        when(predictor.isModelReady()).thenReturn(true);
+        doAnswer(invocation -> {
+            ActionListener<MLTaskResponse> actionListener = invocation.getArgument(1);
+            actionListener
+                .onResponse(MLTaskResponse.builder().output(ModelTensorOutput.builder().mlModelOutputs(List.of()).build()).build());
+            return null;
+        }).when(predictor).asyncPredict(any(), any());
+        when(mlModelManager.getPredictor(anyString())).thenReturn(predictor);
+        when(mlModelManager.getWorkerNodes(anyString(), eq(FunctionName.REMOTE), eq(true))).thenReturn(new String[] { "node1" });
+        taskRunner.dispatchTask(FunctionName.REMOTE, textDocsInputRequest, transportService, listener);
+        verify(client, never()).get(any(), any());
+        ArgumentCaptor<MLTaskResponse> argumentCaptor = ArgumentCaptor.forClass(MLTaskResponse.class);
+        verify(listener).onResponse(argumentCaptor.capture());
+        assert argumentCaptor.getValue().getOutput() instanceof ModelTensorOutput;
+    }
+
+    public void testExecuteTask_OnLocalNode_localModel_success() {
+        setupMocks(true, false, false, false);
+        TextDocsInputDataSet textDocsInputDataSet = new TextDocsInputDataSet(List.of("hello", "world"), null);
+        MLPredictionTaskRequest textDocsInputRequest = MLPredictionTaskRequest
+            .builder()
+            .modelId("test_model")
+            .mlInput(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build())
+            .build();
+        Predictable predictor = mock(Predictable.class);
+        when(predictor.isModelReady()).thenReturn(true);
+        when(mlModelManager.getPredictor(anyString())).thenReturn(predictor);
+        when(mlModelManager.getWorkerNodes(anyString(), eq(FunctionName.TEXT_EMBEDDING), eq(true))).thenReturn(new String[] { "node1" });
+        when(mlModelManager.trackPredictDuration(anyString(), any())).thenReturn(mock(MLPredictionOutput.class));
+        taskRunner.dispatchTask(FunctionName.TEXT_EMBEDDING, textDocsInputRequest, transportService, listener);
+        verify(client, never()).get(any(), any());
+        ArgumentCaptor<MLTaskResponse> argumentCaptor = ArgumentCaptor.forClass(MLTaskResponse.class);
+        verify(listener).onResponse(argumentCaptor.capture());
+        assert argumentCaptor.getValue().getOutput() instanceof MLPredictionOutput;
+    }
+
+    public void testExecuteTask_OnLocalNode_prediction_exception() {
+        setupMocks(true, false, false, false);
+        TextDocsInputDataSet textDocsInputDataSet = new TextDocsInputDataSet(List.of("hello", "world"), null);
+        MLPredictionTaskRequest textDocsInputRequest = MLPredictionTaskRequest
+            .builder()
+            .modelId("test_model")
+            .mlInput(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build())
+            .build();
+        Predictable predictable = mock(Predictable.class);
+        when(mlModelManager.getPredictor(anyString())).thenReturn(predictable);
+        when(predictable.isModelReady()).thenThrow(new RuntimeException("runtime exception"));
+        when(mlModelManager.getWorkerNodes(anyString(), eq(FunctionName.TEXT_EMBEDDING), eq(true))).thenReturn(new String[] { "node1" });
+        when(mlModelManager.trackPredictDuration(anyString(), any())).thenReturn(mock(MLPredictionOutput.class));
+        taskRunner.dispatchTask(FunctionName.TEXT_EMBEDDING, textDocsInputRequest, transportService, listener);
+        verify(client, never()).get(any(), any());
+        ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
+        verify(listener).onFailure(argumentCaptor.capture());
+        assert argumentCaptor.getValue() instanceof RuntimeException;
+        assertEquals("runtime exception", argumentCaptor.getValue().getMessage());
+    }
+
     public void testExecuteTask_OnLocalNode_NullGetResponse() {
         setupMocks(true, false, false, true);