Skip to content

Commit

Permalink
Add more UTs
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <zaniu@amazon.com>
  • Loading branch information
zane-neo committed Oct 27, 2023
1 parent 347fbfc commit beb2011
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -207,4 +207,27 @@ public void executePredict_TextDocsInferenceInput() throws IOException {
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size());
Assert.assertEquals("value", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("key"));
}

@Test
public void test_initialize() {
ConnectorAction predictAction = ConnectorAction.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://test.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Map<String, String> credential = ImmutableMap.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker");
Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build();
connector.decrypt((c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient));
initializeExecutor(executor);
}

private void initializeExecutor(RemoteConnectorExecutor executor) {
executor.setConnectionTimeoutInMillis(1000);
executor.setReadTimeoutInMillis(1000);
executor.setMaxConnections(30);
executor.initialize();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -182,4 +182,24 @@ public void executePredict_TextDocsInput() throws IOException {
Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData());
Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1).getData());
}

@Test
public void test_initialize() {
ConnectorAction predictAction = ConnectorAction.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://test.com/mock")
.requestBody("{\"input\": ${parameters.input}}")
.build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector, httpClient));
initializeExecutor(executor);
}

private void initializeExecutor(RemoteConnectorExecutor executor) {
executor.setConnectionTimeoutInMillis(1000);
executor.setReadTimeoutInMillis(1000);
executor.setMaxConnections(30);
executor.initialize();
}
}

0 comments on commit beb2011

Please sign in to comment.