From dd6fc1be0b59b83683033166509d4deba4558bfc Mon Sep 17 00:00:00 2001 From: zane-neo Date: Sun, 4 Feb 2024 14:30:14 +0800 Subject: [PATCH] refactor delete model flow to make sure all dependent resources are deleted together with model metadata Signed-off-by: zane-neo --- .../models/DeleteModelTransportAction.java | 89 +++++++------ .../DeleteModelTransportActionTests.java | 124 +++++++++++++++--- 2 files changed, 157 insertions(+), 56 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java index e2e84682b2..fb216dc621 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java @@ -15,8 +15,8 @@ import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext; import org.opensearch.OpenSearchStatusException; +import org.opensearch.ResourceNotFoundException; import org.opensearch.action.ActionRequest; -import org.opensearch.action.DocWriteResponse; import org.opensearch.action.delete.DeleteRequest; import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.get.GetRequest; @@ -34,7 +34,6 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.TermsQueryBuilder; import org.opensearch.index.reindex.BulkByScrollResponse; import org.opensearch.index.reindex.DeleteByQueryAction; @@ -57,6 +56,9 @@ import lombok.experimental.FieldDefaults; import lombok.extern.log4j.Log4j2; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; + @Log4j2 @FieldDefaults(level = AccessLevel.PRIVATE) public class DeleteModelTransportAction extends HandledTransportAction { @@ -177,7 +179,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { + void deleteModelChunks(String modelId, ActionListener actionListener) { DeleteByQueryRequest deleteModelsRequest = new DeleteByQueryRequest(ML_MODEL_INDEX); deleteModelsRequest.setQuery(new TermsQueryBuilder(MODEL_ID_FIELD, modelId)); @@ -185,17 +187,17 @@ void deleteModelChunks(String modelId, DeleteResponse deleteResponse, ActionList if ((r.getBulkFailures() == null || r.getBulkFailures().size() == 0) && (r.getSearchFailures() == null || r.getSearchFailures().size() == 0)) { log.debug("All model chunks are deleted for model {}", modelId); - actionListener.onResponse(deleteResponse); + actionListener.onResponse(true); } else { returnFailure(r, modelId, actionListener); } }, e -> { - log.error("Failed to delete ML model for " + modelId, e); + log.error("Failed to delete model chunks for: " + modelId, e); actionListener.onFailure(e); })); } - private void returnFailure(BulkByScrollResponse response, String modelId, ActionListener actionListener) { + private void returnFailure(BulkByScrollResponse response, String modelId, ActionListener actionListener) { String errorMessage = ""; if (response.isTimedOut()) { errorMessage = OS_STATUS_EXCEPTION_MESSAGE + ", " + TIMEOUT_MSG + modelId; @@ -209,21 +211,53 @@ private void returnFailure(BulkByScrollResponse response, String modelId, Action } private void deleteModel(String modelId, FunctionName functionName, ActionListener actionListener) { + // Always delete model chunks and model controller first, because deleting metadata first user is not able clean up model chunks and model controller. + if (FunctionName.REMOTE == functionName) { + CountDownLatch countDownLatch = new CountDownLatch(2); + AtomicBoolean bothDeleted = new AtomicBoolean(true); + ActionListener countDownActionListener = ActionListener.wrap(b -> { + countDownLatch.countDown(); + bothDeleted.compareAndSet(true, b); + if (countDownLatch.getCount() == 0) { + if (bothDeleted.get()) { + log.debug("model chunks and model controller for model {} deleted successfully, starting to delete model meta data", modelId); + deleteModelMetadata(modelId, actionListener); + } else { + actionListener.onFailure(new IllegalStateException("Failed to delete model chunks or model controller, please try again: " + modelId)); + } + } + }, e -> { + countDownLatch.countDown(); + bothDeleted.compareAndSet(true, false); + if (countDownLatch.getCount() == 0) { + actionListener.onFailure(new IllegalStateException("Failed to delete model chunks or model controller, please try again: " + modelId, e)); + } + }); + deleteModelChunks(modelId, countDownActionListener); + deleteController(modelId, countDownActionListener); + } else { + ActionListener deleteControllerListener = ActionListener.wrap(b -> { + log.debug("model controller for model {} deleted successfully, starting to delete model meta data", modelId); + deleteModelMetadata(modelId, actionListener); + }, e -> { + log.error("Failed to delete model chunks or model controller, please try again: " + modelId, e); + actionListener.onFailure(new IllegalStateException("Failed to delete model chunks or model controller, please try again: " + modelId, e)); + }); + deleteController(modelId, deleteControllerListener); + } + } + + private void deleteModelMetadata(String modelId, ActionListener actionListener) { DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_INDEX, modelId).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); client.delete(deleteRequest, new ActionListener<>() { @Override public void onResponse(DeleteResponse deleteResponse) { - if (FunctionName.REMOTE != functionName) { - deleteModelChunks(modelId, deleteResponse, actionListener); - } else { - actionListener.onResponse(deleteResponse); - } - deleteController(modelId); + actionListener.onResponse(deleteResponse); } @Override public void onFailure(Exception e) { - log.error("Failed to delete model meta data for model: " + modelId, e); + log.error("Failed to delete model meta data for model, please try again: " + modelId, e); actionListener.onFailure(e); } }); @@ -235,20 +269,20 @@ public void onFailure(Exception e) { * * @param modelId model ID */ - private void deleteController(String modelId, ActionListener actionListener) { + private void deleteController(String modelId, ActionListener actionListener) { DeleteRequest deleteRequest = new DeleteRequest(ML_CONTROLLER_INDEX, modelId); client.delete(deleteRequest, new ActionListener<>() { @Override public void onResponse(DeleteResponse deleteResponse) { log.info("Model controller for model {} successfully deleted from index, result: {}", modelId, deleteResponse.getResult()); - actionListener.onResponse(deleteResponse); + actionListener.onResponse(true); } @Override public void onFailure(Exception e) { - if (e instanceof IndexNotFoundException) { + if (e instanceof ResourceNotFoundException) { log.info("Model controller not deleted due to no model controller found for model: " + modelId); - actionListener.onFailure(e); + actionListener.onResponse(true); // we consider this as success } else { log.error("Failed to delete model controller for model: " + modelId, e); actionListener.onFailure(e); @@ -257,27 +291,6 @@ public void onFailure(Exception e) { }); } - /** - * Delete the model controller for a model after the model is deleted from the - * ML index with build-in listener. - * - * @param modelId model ID - */ - private void deleteController(String modelId) { - deleteController(modelId, ActionListener.wrap(deleteResponse -> { - if (deleteResponse.getResult() == DocWriteResponse.Result.DELETED) { - log.info("Model controller for model {} successfully deleted from index, result: {}", modelId, deleteResponse.getResult()); - } else { - log.info("The deletion of model controller for model {} returned with result: {}", modelId, deleteResponse.getResult()); - } - }, e -> { - if (e instanceof IndexNotFoundException) { - log.debug("Model controller not deleted due to no model controller found for model: " + modelId); - } else { - log.error("Failed to delete model controller for model: " + modelId, e); - } - })); - } private Boolean isModelNotDeployed(MLModelState mlModelState) { return !mlModelState.equals(MLModelState.LOADED) diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java index b01d28e20f..f924c13325 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java @@ -8,6 +8,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -47,6 +48,7 @@ import org.opensearch.index.get.GetResult; import org.opensearch.index.reindex.BulkByScrollResponse; import org.opensearch.index.reindex.ScrollableHitSource; +import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.transport.model.MLModelDeleteRequest; @@ -154,6 +156,84 @@ public void testDeleteModel_Success() throws IOException { verify(actionListener).onResponse(deleteResponse); } + public void testDeleteRemoteModel_Success() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(deleteResponse); + return null; + }).when(client).delete(any(), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null); + listener.onResponse(response); + return null; + }).when(client).execute(any(), any(), any()); + + GetResponse getResponse = prepareModelWithFunction(MLModelState.REGISTERED, null, false, FunctionName.REMOTE); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); + verify(actionListener).onResponse(deleteResponse); + } + + public void testDeleteRemoteModel_deleteModelController_failed() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("runtime exception")); + return null; + }).when(client).delete(any(), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null); + listener.onResponse(response); + return null; + }).when(client).execute(any(), any(), any()); + + GetResponse getResponse = prepareModelWithFunction(MLModelState.REGISTERED, null, false, FunctionName.REMOTE); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to delete model chunks or model controller, please try again: test_id", argumentCaptor.getValue().getMessage()); + } + + public void testDeleteRemoteModel_deleteModelChunks_failed() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(deleteResponse); + return null; + }).when(client).delete(any(), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new RuntimeException("runtime exception")); + return null; + }).when(client).execute(any(), any(), any()); + + GetResponse getResponse = prepareModelWithFunction(MLModelState.REGISTERED, null, false, FunctionName.REMOTE); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to delete model chunks or model controller, please try again: test_id", argumentCaptor.getValue().getMessage()); + } + public void testDeleteHiddenModel_Success() throws IOException { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -349,9 +429,9 @@ public void testDeleteModelChunks_Success() { listener.onResponse(bulkByScrollResponse); return null; }).when(client).execute(any(), any(), any()); - - deleteModelTransportAction.deleteModelChunks("test_id", deleteResponse, actionListener); - verify(actionListener).onResponse(deleteResponse); + ActionListener deleteChunksListener = mock(ActionListener.class); + deleteModelTransportAction.deleteModelChunks("test_id", deleteChunksListener); + verify(deleteChunksListener).onResponse(true); } public void testDeleteModel_RuntimeException() throws IOException { @@ -371,7 +451,7 @@ public void testDeleteModel_RuntimeException() throws IOException { deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("errorMessage", argumentCaptor.getValue().getMessage()); + assertEquals("Failed to delete model chunks or model controller, please try again: test_id", argumentCaptor.getValue().getMessage()); } @Ignore @@ -389,10 +469,10 @@ public void test_FailToDeleteModel() { listener.onFailure(new RuntimeException("errorMessage")); return null; }).when(client).execute(any(), any(), any()); - - deleteModelTransportAction.deleteModelChunks("test_id", deleteResponse, actionListener); + ActionListener deleteChunksListener = mock(ActionListener.class); + deleteModelTransportAction.deleteModelChunks("test_id", deleteChunksListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); - verify(actionListener).onFailure(argumentCaptor.capture()); + verify(deleteChunksListener).onFailure(argumentCaptor.capture()); assertEquals("errorMessage", argumentCaptor.getValue().getMessage()); } @@ -404,10 +484,10 @@ public void test_FailToDeleteAllModelChunks() { listener.onResponse(bulkByScrollResponse); return null; }).when(client).execute(any(), any(), any()); - - deleteModelTransportAction.deleteModelChunks("test_id", deleteResponse, actionListener); + ActionListener deleteChunksListener = mock(ActionListener.class); + deleteModelTransportAction.deleteModelChunks("test_id", deleteChunksListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); - verify(actionListener).onFailure(argumentCaptor.capture()); + verify(deleteChunksListener).onFailure(argumentCaptor.capture()); assertEquals(OS_STATUS_EXCEPTION_MESSAGE + ", " + BULK_FAILURE_MSG + "test_id", argumentCaptor.getValue().getMessage()); } @@ -420,10 +500,10 @@ public void test_FailToDeleteAllModelChunks_TimeOut() { listener.onResponse(bulkByScrollResponse); return null; }).when(client).execute(any(), any(), any()); - - deleteModelTransportAction.deleteModelChunks("test_id", deleteResponse, actionListener); + ActionListener deleteChunksListener = mock(ActionListener.class); + deleteModelTransportAction.deleteModelChunks("test_id", deleteChunksListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); - verify(actionListener).onFailure(argumentCaptor.capture()); + verify(deleteChunksListener).onFailure(argumentCaptor.capture()); assertEquals(OS_STATUS_EXCEPTION_MESSAGE + ", " + TIMEOUT_MSG + "test_id", argumentCaptor.getValue().getMessage()); } @@ -442,16 +522,24 @@ public void test_FailToDeleteAllModelChunks_SearchFailure() { listener.onResponse(bulkByScrollResponse); return null; }).when(client).execute(any(), any(), any()); - - deleteModelTransportAction.deleteModelChunks("test_id", deleteResponse, actionListener); + ActionListener deleteChunksListener = mock(ActionListener.class); + deleteModelTransportAction.deleteModelChunks("test_id", deleteChunksListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); - verify(actionListener).onFailure(argumentCaptor.capture()); + verify(deleteChunksListener).onFailure(argumentCaptor.capture()); assertEquals(OS_STATUS_EXCEPTION_MESSAGE + ", " + SEARCH_FAILURE_MSG + "test_id", argumentCaptor.getValue().getMessage()); } public GetResponse prepareMLModel(MLModelState mlModelState, String modelGroupID, boolean isHidden) throws IOException { - MLModel mlModel; - mlModel = MLModel.builder().modelId("test_id").modelState(mlModelState).modelGroupId(modelGroupID).isHidden(isHidden).build(); + MLModel mlModel = MLModel.builder().modelId("test_id").modelState(mlModelState).modelGroupId(modelGroupID).isHidden(isHidden).build(); + return buildResponse(mlModel); + } + + public GetResponse prepareModelWithFunction(MLModelState mlModelState, String modelGroupID, boolean isHidden, FunctionName functionName) throws IOException { + MLModel mlModel = MLModel.builder().modelId("test_id").algorithm(functionName).modelState(mlModelState).modelGroupId(modelGroupID).isHidden(isHidden).build(); + return buildResponse(mlModel); + } + + private GetResponse buildResponse(MLModel mlModel) throws IOException { XContentBuilder content = mlModel.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); BytesReference bytesReference = BytesReference.bytes(content); GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null);