Skip to content

Commit

Permalink
add tests and error message tweaks
Browse files Browse the repository at this point in the history
Signed-off-by: HenryL27 <hmlindeman@yahoo.com>
  • Loading branch information
HenryL27 committed Dec 6, 2023
1 parent 97e10b1 commit 09e0f91
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public TextSimilarityInputDataSet(String queryText, List<String> textDocs) {
Objects.requireNonNull(textDocs);
Objects.requireNonNull(queryText);
if(textDocs.isEmpty()) {
throw new IllegalArgumentException("pairs must be nonempty");
throw new IllegalArgumentException("No text documents provided");
}
this.textDocs = textDocs;
this.queryText = queryText;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,10 @@ public TextSimilarityMLInput(XContentParser parser, FunctionName functionName) t
}
}
if(docs.isEmpty()) {
throw new IllegalArgumentException("no text docs");
throw new IllegalArgumentException("No text documents were provided");
}
if(queryText == null) {
throw new IllegalArgumentException("no query text");
throw new IllegalArgumentException("No query text was provided");
}
inputDataset = new TextSimilarityInputDataSet(queryText, docs);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ public void noPairs_ThenFail() {
String queryText = "today is sunny";
IllegalArgumentException e = assertThrows(IllegalArgumentException.class,
() -> TextSimilarityInputDataSet.builder().textDocs(docs).queryText(queryText).build());
assert (e.getMessage().equals("pairs must be nonempty"));
assert (e.getMessage().equals("No text documents provided"));
}

@Test
public void noQuery_ThenFail() {
List<String> docs = List.of("That is a happy dog", "it's summer");
String queryText = null;
assertThrows(NullPointerException.class,
() -> TextSimilarityInputDataSet.builder().textDocs(docs).queryText(queryText).build());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ public void testParseJson_NoPairs_ThenFail() throws IOException {

IllegalArgumentException e = assertThrows(IllegalArgumentException.class,
() -> MLInput.parse(parser, input.getFunctionName().name()));
assert (e.getMessage().equals("no text docs"));
assert (e.getMessage().equals("No text documents were provided"));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,34 @@ public void initModel_predict_TorchScript_CrossEncoder() throws URISyntaxExcepti
textSimilarityCrossEncoderModel.close();
}

@Test
public void initModel_predict_ONNX_CrossEncoder() throws URISyntaxException {
model = MLModel
.builder()
.modelFormat(MLModelFormat.ONNX)
.name("test_model_name")
.modelId("test_model_id")
.algorithm(FunctionName.TEXT_SIMILARITY)
.version("1.0.0")
.modelState(MLModelState.TRAINED)
.build();
modelZipFile = new File(getClass().getResource("TinyBERT-CE-onnx.zip").toURI());
params.put(MODEL_ZIP_FILE, modelZipFile);

textSimilarityCrossEncoderModel.initModel(model, params, encryptor);
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(inputDataSet).build();
ModelTensorOutput output = (ModelTensorOutput) textSimilarityCrossEncoderModel.predict(mlInput);
List<ModelTensors> mlModelOutputs = output.getMlModelOutputs();
assertEquals(2, mlModelOutputs.size());
for (int i = 0; i < mlModelOutputs.size(); i++) {
ModelTensors tensors = mlModelOutputs.get(i);
List<ModelTensor> mlModelTensors = tensors.getMlModelTensors();
assertEquals(1, mlModelTensors.size());
assertEquals(1, mlModelTensors.get(0).getData().length);
}
textSimilarityCrossEncoderModel.close();
}

@Test
public void initModel_NullModelHelper() throws URISyntaxException {
Map<String, Object> params = new HashMap<>();
Expand Down
Binary file not shown.

0 comments on commit 09e0f91

Please sign in to comment.