Skip to content

Commit

Permalink
fix ut
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <ylwu@amazon.com>
  • Loading branch information
ylwu-amzn committed Oct 10, 2023
1 parent 13fd300 commit c52b1bf
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ public void executePredict_TextDocsInferenceInput() throws IOException {
connector.decrypt((c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient));

MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input", "test input data")).build();
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input")).build();
ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOExcepti
when(executor.getHttpClient()).thenReturn(httpClient);
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
Assert.assertEquals(2, modelTensorOutput.getMlModelOutputs().size());
Assert.assertEquals("response", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size());
Assert.assertEquals("test result", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("response"));
Expand Down Expand Up @@ -184,7 +184,7 @@ public void executePredict_TextDocsInput() throws IOException {
when(executor.getHttpClient()).thenReturn(httpClient);
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
Assert.assertEquals(2, modelTensorOutput.getMlModelOutputs().size());
Assert.assertEquals("sentence_embedding", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName());
Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData());
Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1).getData());
Expand Down

0 comments on commit c52b1bf

Please sign in to comment.