Skip to content

Commit

Permalink
support multiple docs for remote embedding model
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <ylwu@amazon.com>
  • Loading branch information
ylwu-amzn committed Oct 10, 2023
1 parent d9c920c commit 13fd300
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ private static RemoteInferenceInputDataSet processTextDocsInput(TextDocsInputDat
docs.add(null);
}
}
if (preProcessFunction.contains("${parameters")) {
if (preProcessFunction.contains("${parameters.")) {
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
preProcessFunction = substitutor.replace(preProcessFunction);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,15 @@ default ModelTensorOutput executePredict(MLInput mlInput) {

if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) {
TextDocsInputDataSet textDocsInputDataSet = (TextDocsInputDataSet) mlInput.getInputDataset();
List<String> textDocs = new ArrayList<>(textDocsInputDataSet.getDocs());
preparePayloadAndInvokeRemoteModel(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build()).build(), tensorOutputs);
int processedDocs = 0;
while(processedDocs < textDocsInputDataSet.getDocs().size()) {
List<String> textDocs = textDocsInputDataSet.getDocs().subList(processedDocs, textDocsInputDataSet.getDocs().size());
List<ModelTensors> tempTensorOutputs = new ArrayList<>();
preparePayloadAndInvokeRemoteModel(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build()).build(), tempTensorOutputs);
processedDocs += Math.max(tempTensorOutputs.size(), 1);
tensorOutputs.addAll(tempTensorOutputs);
}

} else {
preparePayloadAndInvokeRemoteModel(mlInput, tensorOutputs);
}
Expand Down

0 comments on commit 13fd300

Please sign in to comment.