Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add map result support in neural search for non text embedding models #258

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;

import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
Expand Down Expand Up @@ -100,10 +102,38 @@ public void inferenceSentences(
@NonNull final List<String> inputText,
@NonNull final ActionListener<List<List<Float>>> listener
) {
inferenceSentencesWithRetry(targetResponseFilters, modelId, inputText, 0, listener);
retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, 0, listener);
}

private void inferenceSentencesWithRetry(
public void inferenceSentencesWithMapResult(
@NonNull final String modelId,
@NonNull final List<String> inputText,
@NonNull final ActionListener<Map<String, ?>> listener) {
retryableInferenceSentencesWithMapResult(modelId, inputText, 0, listener);
}

private void retryableInferenceSentencesWithMapResult(
final String modelId,
final List<String> inputText,
final int retryTime,
final ActionListener<Map<String, ?>> listener
) {
MLInput mlInput = createMLInput(null, inputText);
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
final Map<String, ?> result = buildMapResultFromResponse(mlOutput);
log.debug("Inference Response for input sentence {} is : {} ", inputText, result);
listener.onResponse(result);
}, e -> {
if (RetryUtil.shouldRetry(e, retryTime)) {
final int retryTimeAdd = retryTime + 1;
retryableInferenceSentencesWithMapResult(modelId, inputText, retryTimeAdd, listener);
} else {
listener.onFailure(e);
}
}));
}

private void retryableInferenceSentencesWithVectorResult(
final List<String> targetResponseFilters,
final String modelId,
final List<String> inputText,
Expand All @@ -118,7 +148,7 @@ private void inferenceSentencesWithRetry(
}, e -> {
if (RetryUtil.shouldRetry(e, retryTime)) {
final int retryTimeAdd = retryTime + 1;
inferenceSentencesWithRetry(targetResponseFilters, modelId, inputText, retryTimeAdd, listener);
retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, retryTimeAdd, listener);
} else {
listener.onFailure(e);
}
Expand All @@ -144,4 +174,19 @@ private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
return vector;
}

private Map<String, ?> buildMapResultFromResponse(MLOutput mlOutput) {
final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput;
Copy link
Member

@martin-gaievski martin-gaievski Aug 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we check the type of mlOutput before casting?

final List<ModelTensors> tensorOutputList = modelTensorOutput.getMlModelOutputs();
if (CollectionUtils.isEmpty(tensorOutputList)) {
log.error("No tensor output found!");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets make this error message more understandable and with actions what happened wrong which resulted in this error

return null;
}
List<ModelTensor> tensorList = tensorOutputList.get(0).getMlModelTensors();
if (CollectionUtils.isEmpty(tensorList)) {
log.error("No tensor found!");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets make this error message more understandable and with actions what happened wrong which resulted in this error

return null;
}
return tensorList.get(0).getDataAsMap();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import org.junit.Before;
import org.mockito.InjectMocks;
Expand Down Expand Up @@ -160,6 +161,98 @@ public void testInferenceSentences_whenNotConnectionException_thenNoRetry() {
Mockito.verify(resultListener).onFailure(illegalStateException);
}

public void test_inferenceSentencesWithMapResult_whenValidInput_thenSuccess() {
final Map<String, String> map = Map.of("key", "value");
final ActionListener<Map<String, ?>> resultListener = mock(ActionListener.class);
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onResponse(createModelTensorOutput(map));
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener);

Mockito.verify(client)
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(resultListener).onResponse(map);
Mockito.verifyNoMoreInteractions(resultListener);
}

public void test_inferenceSentencesWithMapResult_whenTensorOutputListEmpty_thenReturnNull() {
final ActionListener<Map<String, ?>> resultListener = mock(ActionListener.class);
final ModelTensorOutput modelTensorOutput = new ModelTensorOutput(Collections.emptyList());
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onResponse(modelTensorOutput);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener);

Mockito.verify(client)
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(resultListener).onResponse(null);
Mockito.verifyNoMoreInteractions(resultListener);
}

public void test_inferenceSentencesWithMapResult_whenModelTensorListEmpty_thenReturnNull() {
final ActionListener<Map<String, ?>> resultListener = mock(ActionListener.class);
final List<ModelTensors> tensorsList = new ArrayList<>();
final List<ModelTensor> mlModelTensorList = new ArrayList<>();
tensorsList.add(new ModelTensors(mlModelTensorList));
final ModelTensorOutput modelTensorOutput = new ModelTensorOutput(tensorsList);
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onResponse(modelTensorOutput);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener);

Mockito.verify(client)
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(resultListener).onResponse(null);
Mockito.verifyNoMoreInteractions(resultListener);
}

public void test_inferenceSentencesWithMapResult_whenRetryableException_retry3Times() {
final NodeNotConnectedException nodeNodeConnectedException = new NodeNotConnectedException(
mock(DiscoveryNode.class),
"Node not connected"
);
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onFailure(nodeNodeConnectedException);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
final ActionListener<Map<String, ?>> resultListener = mock(ActionListener.class);
accessor.inferenceSentencesWithMapResult(
TestCommonConstants.MODEL_ID,
TestCommonConstants.SENTENCES_LIST,
resultListener
);

Mockito.verify(client, times(4))
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(resultListener).onFailure(nodeNodeConnectedException);
}

public void test_inferenceSentencesWithMapResult_whenNotRetryableException_thenFail() {
final IllegalStateException illegalStateException = new IllegalStateException("Illegal state");
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onFailure(illegalStateException);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
final ActionListener<Map<String, ?>> resultListener = mock(ActionListener.class);
accessor.inferenceSentencesWithMapResult(
TestCommonConstants.MODEL_ID,
TestCommonConstants.SENTENCES_LIST,
resultListener
);

Mockito.verify(client, times(1))
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(resultListener).onFailure(illegalStateException);
}

private ModelTensorOutput createModelTensorOutput(final Float[] output) {
final List<ModelTensors> tensorsList = new ArrayList<>();
final List<ModelTensor> mlModelTensorList = new ArrayList<>();
Expand All @@ -168,7 +261,27 @@ private ModelTensorOutput createModelTensorOutput(final Float[] output) {
output,
new long[] { 1, 2 },
MLResultDataType.FLOAT64,
ByteBuffer.wrap(new byte[12])
ByteBuffer.wrap(new byte[12]),
"someValue",
Map.of()
);
mlModelTensorList.add(tensor);
final ModelTensors modelTensors = new ModelTensors(mlModelTensorList);
tensorsList.add(modelTensors);
return new ModelTensorOutput(tensorsList);
}

private ModelTensorOutput createModelTensorOutput(final Map<String, String> map) {
final List<ModelTensors> tensorsList = new ArrayList<>();
final List<ModelTensor> mlModelTensorList = new ArrayList<>();
final ModelTensor tensor = new ModelTensor(
"response",
null,
null,
null,
null,
null,
map
);
mlModelTensorList.add(tensor);
final ModelTensors modelTensors = new ModelTensors(mlModelTensorList);
Expand Down