Add map result support in neural search for non text embedding models #258
Changes from all commits
@@ -8,13 +8,15 @@
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
+import java.util.Map;
 import java.util.stream.Collectors;

 import lombok.NonNull;
 import lombok.RequiredArgsConstructor;
 import lombok.extern.log4j.Log4j2;

 import org.opensearch.core.action.ActionListener;
+import org.opensearch.core.common.util.CollectionUtils;
 import org.opensearch.ml.client.MachineLearningNodeClient;
 import org.opensearch.ml.common.FunctionName;
 import org.opensearch.ml.common.dataset.MLInputDataset;
@@ -100,10 +102,38 @@ public void inferenceSentences(
         @NonNull final List<String> inputText,
         @NonNull final ActionListener<List<List<Float>>> listener
     ) {
-        inferenceSentencesWithRetry(targetResponseFilters, modelId, inputText, 0, listener);
+        retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, 0, listener);
     }

-    private void inferenceSentencesWithRetry(
+    public void inferenceSentencesWithMapResult(
+        @NonNull final String modelId,
+        @NonNull final List<String> inputText,
+        @NonNull final ActionListener<Map<String, ?>> listener) {
+        retryableInferenceSentencesWithMapResult(modelId, inputText, 0, listener);
+    }
+
+    private void retryableInferenceSentencesWithMapResult(
+        final String modelId,
+        final List<String> inputText,
+        final int retryTime,
+        final ActionListener<Map<String, ?>> listener
+    ) {
+        MLInput mlInput = createMLInput(null, inputText);
+        mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
+            final Map<String, ?> result = buildMapResultFromResponse(mlOutput);
+            log.debug("Inference Response for input sentence {} is : {} ", inputText, result);
+            listener.onResponse(result);
+        }, e -> {
+            if (RetryUtil.shouldRetry(e, retryTime)) {
+                final int retryTimeAdd = retryTime + 1;
+                retryableInferenceSentencesWithMapResult(modelId, inputText, retryTimeAdd, listener);
+            } else {
+                listener.onFailure(e);
+            }
+        }));
+    }
+
+    private void retryableInferenceSentencesWithVectorResult(
         final List<String> targetResponseFilters,
         final String modelId,
         final List<String> inputText,
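For context, a minimal caller of the new inferenceSentencesWithMapResult entry point could look like the sketch below. It is illustrative only: the clientAccessor variable and the model id are placeholder assumptions, not part of this PR.

    // Hypothetical usage of the new map-result API (not part of this PR's diff).
    // "clientAccessor" and the model id below are assumed names for illustration.
    List<String> inputs = List.of("hello world");
    clientAccessor.inferenceSentencesWithMapResult(
        "my-non-text-embedding-model-id",
        inputs,
        ActionListener.wrap(
            resultMap -> log.info("Map result keys: {}", resultMap.keySet()),
            e -> log.error("Inference failed for inputs {}", inputs, e)
        )
    );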
@@ -118,7 +148,7 @@ private void inferenceSentencesWithRetry(
         }, e -> {
             if (RetryUtil.shouldRetry(e, retryTime)) {
                 final int retryTimeAdd = retryTime + 1;
-                inferenceSentencesWithRetry(targetResponseFilters, modelId, inputText, retryTimeAdd, listener);
+                retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, retryTimeAdd, listener);
             } else {
                 listener.onFailure(e);
             }
@@ -144,4 +174,19 @@ private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
         return vector;
     }

+    private Map<String, ?> buildMapResultFromResponse(MLOutput mlOutput) {
+        final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput;
+        final List<ModelTensors> tensorOutputList = modelTensorOutput.getMlModelOutputs();
+        if (CollectionUtils.isEmpty(tensorOutputList)) {
+            log.error("No tensor output found!");
+            return null;
+        }
+        List<ModelTensor> tensorList = tensorOutputList.get(0).getMlModelTensors();
+        if (CollectionUtils.isEmpty(tensorList)) {
+            log.error("No tensor found!");
+            return null;
+        }
+        return tensorList.get(0).getDataAsMap();
+    }
+
 }

Review comment (left on both log.error lines above): Let's make these error messages more understandable and actionable; they should state what went wrong and which action led to the error.
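One way to address this comment, sketched below purely as an illustration (the exact wording is an assumption, not necessarily what the PR was updated to): say what was expected, what was actually received, and what the user should check.

    // Illustrative only: more actionable messages for the two empty-result cases.
    if (CollectionUtils.isEmpty(tensorOutputList)) {
        log.error(
            "Empty model tensor outputs in the ML inference response; expected at least one ModelTensors entry. "
                + "Check that the model is deployed and produces map-style (non-embedding) output."
        );
        return null;
    }
    List<ModelTensor> tensorList = tensorOutputList.get(0).getMlModelTensors();
    if (CollectionUtils.isEmpty(tensorList)) {
        log.error(
            "Empty model tensors in the first output of the ML inference response; "
                + "expected at least one ModelTensor carrying the map result."
        );
        return null;
    }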
Review comment (on the cast in buildMapResultFromResponse): Should we check the type of mlOutput before casting?
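A possible way to address this, shown only as a sketch (the PR's actual follow-up may differ): guard the cast with an instanceof check so an unexpected MLOutput type fails with a clear message instead of a raw ClassCastException.

    // Sketch only: validate the runtime type before casting.
    if (!(mlOutput instanceof ModelTensorOutput)) {
        throw new IllegalStateException(
            "Expected ML output of type ModelTensorOutput, but got " + mlOutput.getClass().getName()
        );
    }
    final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput;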