Skip to content

Commit

Permalink
format code
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <zaniu@amazon.com>
  • Loading branch information
zane-neo committed Jun 5, 2024
1 parent 126eb07 commit 6c404b5
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.commons.rest.SecureRestClientBuilder;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.DeprecationHandler;
import org.opensearch.core.xcontent.MediaType;
Expand All @@ -88,7 +87,6 @@
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput;
import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
Expand Down Expand Up @@ -915,7 +913,6 @@ public Map predictTextEmbedding(String modelId) throws IOException {

public Map predictRemoteModel(String modelId, MLInput input) throws IOException {
String requestBody = TestHelper.toJsonString(input);
System.out.println("############################## request body is:" + requestBody);
Response response = TestHelper
.makeRequest(client(), "POST", "/_plugins/_ml/_predict/TEXT_EMBEDDING/" + modelId, null, requestBody, null);
return parseResponseToMap(response);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,21 @@

package org.opensearch.ml.rest;

import lombok.SneakyThrows;
import org.junit.Before;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.utils.StringUtils;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.Locale;
import java.util.Map;

import org.junit.Before;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.utils.StringUtils;

import lombok.SneakyThrows;

public class RestBedRockInferenceIT extends MLCommonsRestTestCase {
private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID");
private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY");
Expand All @@ -32,19 +33,38 @@ public void setup() throws IOException, InterruptedException {
Thread.sleep(20000);
}


public void test_bedrock_embedding_model() throws Exception {
// Skip test if key is null
if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) {
return;
}
String templates = Files.readString(Path.of(RestMLPredictionAction.class.getClassLoader().getResource("org/opensearch/ml/rest/templates/BedRockConnectorBodies.json").toURI()));
String templates = Files
.readString(
Path
.of(
RestMLPredictionAction.class
.getClassLoader()
.getResource("org/opensearch/ml/rest/templates/BedRockConnectorBodies.json")
.toURI()
)
);
Map<String, Object> templateMap = StringUtils.gson.fromJson(templates, Map.class);
for (Map.Entry<String, Object> templateEntry : templateMap.entrySet()) {
String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5);
String testCaseName = templateEntry.getKey();
String errorMsg = String.format(Locale.ROOT, "Failing test case name: %s", testCaseName);
String modelId = registerRemoteModel(String.format(StringUtils.gson.toJson(templateEntry.getValue()), GITHUB_CI_AWS_REGION, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN), bedrockEmbeddingModelName, true);
String modelId = registerRemoteModel(
String
.format(
StringUtils.gson.toJson(templateEntry.getValue()),
GITHUB_CI_AWS_REGION,
AWS_ACCESS_KEY_ID,
AWS_SECRET_ACCESS_KEY,
AWS_SESSION_TOKEN
),
bedrockEmbeddingModelName,
true
);

TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", "world")).build();
MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build();
Expand Down

0 comments on commit 6c404b5

Please sign in to comment.