Skip to content

Commit

Permalink
fix IT failure
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <zaniu@amazon.com>
  • Loading branch information
zane-neo committed Dec 3, 2024
1 parent ab2d736 commit 04e1d8e
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,25 @@ public Map predictTextEmbeddingModel(String modelId, MLInput input) throws IOExc
return parseResponseToMap(response);
}

public Map predictTextEmbeddingModelIgnoreFunctionName(String modelId, MLInput mlInput) throws IOException {
Response response = null;
try {
response = TestHelper
.makeRequest(
client(),
"POST",
"/_plugins/_ml/models/" + modelId + "/_predict",
null,
TestHelper.toJsonString(mlInput),
null
);
} catch (ResponseException e) {
log.error(e.getMessage(), e);
response = e.getResponse();
}
return parseResponseToMap(response);
}

public Consumer<Map<String, Object>> verifyTextEmbeddingModelDeployed() {
return (modelProfile) -> {
if (modelProfile.containsKey("model_state")) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import org.junit.Before;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
Expand Down Expand Up @@ -327,9 +328,10 @@ public void test_bedrock_embedding_v2_model_with_request_dimensions() throws Exc
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet
.builder()
.parameters(Map.of("inputText", "Can you tell me a joke?", "dimensions", "512"))
.actionType(ConnectorAction.ActionType.PREDICT)
.build();
MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.REMOTE).build();
Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput);
Map inferenceResult = predictTextEmbeddingModelIgnoreFunctionName(modelId, mlInput);
String errorMsg = String
.format(Locale.ROOT, "Failing test case name: %s, inference result: %s", testCaseName, gson.toJson(inferenceResult));
assertTrue(errorMsg, inferenceResult.containsKey("inference_results"));
Expand Down

0 comments on commit 04e1d8e

Please sign in to comment.