Skip to content

Commit

Permalink
handle case where user accidentally sets doesVersionCreateModelGroup …
Browse files Browse the repository at this point in the history
…to true

Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com>
  • Loading branch information
rbhavna committed Oct 5, 2023
1 parent ab6c19e commit 8fa07f8
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ private void createModelGroup(MLRegisterModelInput registerModelInput, ActionLis
listener.onFailure(e);
}));
} else {
registerModelInput.setDoesVersionCreateModelGroup(false);
registerModel(registerModelInput, listener);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,42 @@ public void test_ModelNameAlreadyExists() throws IOException {
verify(actionListener).onResponse(argumentCaptor.capture());
}

public void test_DoesVersionCreateModelGroupFieldSetToTrueByUserByMistake() throws IOException {
when(node1.getId()).thenReturn("NodeId1");
when(node2.getId()).thenReturn("NodeId2");
MLForwardResponse forwardResponse = Mockito.mock(MLForwardResponse.class);
doAnswer(invocation -> {
ActionListenerResponseHandler<MLForwardResponse> handler = invocation.getArgument(3);
handler.handleResponse(forwardResponse);
return null;
}).when(transportService).sendRequest(any(), any(), any(), any());

MLRegisterModelInput registerModelInput = MLRegisterModelInput
.builder()
.functionName(FunctionName.BATCH_RCF)
.modelGroupId("model_group_ID")
.modelName("Test Model")
.modelConfig(
new TextEmbeddingModelConfig(
"CUSTOM",
123,
TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS,
"all config",
TextEmbeddingModelConfig.PoolingMode.MEAN,
true,
512
)
)
.modelFormat(MLModelFormat.TORCH_SCRIPT)
.url("http://test_url")
.doesVersionCreateModelGroup(true)
.build();

transportRegisterModelAction.doExecute(task, new MLRegisterModelRequest(registerModelInput), actionListener);
ArgumentCaptor<MLRegisterModelResponse> argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelResponse.class);
verify(actionListener).onResponse(argumentCaptor.capture());
}

public void test_FailureWhenPreBuildModelNameAlreadyExists() throws IOException {
SearchResponse searchResponse = createModelGroupSearchResponse(1);
doAnswer(invocation -> {
Expand Down

0 comments on commit 8fa07f8

Please sign in to comment.