From beb201174a5ac1dcfdf27f126155b4f9c7d70dbc Mon Sep 17 00:00:00 2001 From: zane-neo Date: Fri, 27 Oct 2023 11:09:00 +0800 Subject: [PATCH] Add more UTs Signed-off-by: zane-neo --- .../remote/AwsConnectorExecutorTest.java | 23 +++++++++++++++++++ .../remote/HttpJsonConnectorExecutorTest.java | 20 ++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java index e99fc54a2b..c4183ebf90 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java @@ -207,4 +207,27 @@ public void executePredict_TextDocsInferenceInput() throws IOException { Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size()); Assert.assertEquals("value", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("key")); } + + @Test + public void test_initialize() { + ConnectorAction predictAction = ConnectorAction.builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://test.com/mock") + .requestBody("{\"input\": \"${parameters.input}\"}") + .build(); + Map credential = ImmutableMap.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); + Map parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker"); + Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build(); + connector.decrypt((c) -> encryptor.decrypt(c)); + AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient)); + initializeExecutor(executor); + } + + private void initializeExecutor(RemoteConnectorExecutor executor) { + executor.setConnectionTimeoutInMillis(1000); + executor.setReadTimeoutInMillis(1000); + executor.setMaxConnections(30); + executor.initialize(); + } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java index 556e854a46..68b151f31c 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java @@ -182,4 +182,24 @@ public void executePredict_TextDocsInput() throws IOException { Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData()); Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1).getData()); } + + @Test + public void test_initialize() { + ConnectorAction predictAction = ConnectorAction.builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://test.com/mock") + .requestBody("{\"input\": ${parameters.input}}") + .build(); + Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); + HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector, httpClient)); + initializeExecutor(executor); + } + + private void initializeExecutor(RemoteConnectorExecutor executor) { + executor.setConnectionTimeoutInMillis(1000); + executor.setReadTimeoutInMillis(1000); + executor.setMaxConnections(30); + executor.initialize(); + } }