Skip to content

Commit

Permalink
refactor delete model flow to make sure all dependent resources are d…
Browse files Browse the repository at this point in the history
…eleted together with model metadata

Signed-off-by: zane-neo <zaniu@amazon.com>
  • Loading branch information
zane-neo committed Feb 4, 2024
1 parent 067408c commit dd6fc1b
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.delete.DeleteRequest;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.action.get.GetRequest;
Expand All @@ -34,7 +34,6 @@
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.query.TermsQueryBuilder;
import org.opensearch.index.reindex.BulkByScrollResponse;
import org.opensearch.index.reindex.DeleteByQueryAction;
Expand All @@ -57,6 +56,9 @@
import lombok.experimental.FieldDefaults;
import lombok.extern.log4j.Log4j2;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;

@Log4j2
@FieldDefaults(level = AccessLevel.PRIVATE)
public class DeleteModelTransportAction extends HandledTransportAction<ActionRequest, DeleteResponse> {
Expand Down Expand Up @@ -177,25 +179,25 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
}

@VisibleForTesting
void deleteModelChunks(String modelId, DeleteResponse deleteResponse, ActionListener<DeleteResponse> actionListener) {
void deleteModelChunks(String modelId, ActionListener<Boolean> actionListener) {
DeleteByQueryRequest deleteModelsRequest = new DeleteByQueryRequest(ML_MODEL_INDEX);
deleteModelsRequest.setQuery(new TermsQueryBuilder(MODEL_ID_FIELD, modelId));

client.execute(DeleteByQueryAction.INSTANCE, deleteModelsRequest, ActionListener.wrap(r -> {
if ((r.getBulkFailures() == null || r.getBulkFailures().size() == 0)
&& (r.getSearchFailures() == null || r.getSearchFailures().size() == 0)) {
log.debug("All model chunks are deleted for model {}", modelId);
actionListener.onResponse(deleteResponse);
actionListener.onResponse(true);
} else {
returnFailure(r, modelId, actionListener);
}
}, e -> {
log.error("Failed to delete ML model for " + modelId, e);
log.error("Failed to delete model chunks for: " + modelId, e);
actionListener.onFailure(e);
}));
}

private void returnFailure(BulkByScrollResponse response, String modelId, ActionListener<DeleteResponse> actionListener) {
private void returnFailure(BulkByScrollResponse response, String modelId, ActionListener<Boolean> actionListener) {
String errorMessage = "";
if (response.isTimedOut()) {
errorMessage = OS_STATUS_EXCEPTION_MESSAGE + ", " + TIMEOUT_MSG + modelId;
Expand All @@ -209,21 +211,53 @@ private void returnFailure(BulkByScrollResponse response, String modelId, Action
}

private void deleteModel(String modelId, FunctionName functionName, ActionListener<DeleteResponse> actionListener) {
// Always delete model chunks and model controller first, because deleting metadata first user is not able clean up model chunks and model controller.
if (FunctionName.REMOTE == functionName) {
CountDownLatch countDownLatch = new CountDownLatch(2);
AtomicBoolean bothDeleted = new AtomicBoolean(true);
ActionListener<Boolean> countDownActionListener = ActionListener.wrap(b -> {
countDownLatch.countDown();
bothDeleted.compareAndSet(true, b);
if (countDownLatch.getCount() == 0) {
if (bothDeleted.get()) {
log.debug("model chunks and model controller for model {} deleted successfully, starting to delete model meta data", modelId);
deleteModelMetadata(modelId, actionListener);
} else {
actionListener.onFailure(new IllegalStateException("Failed to delete model chunks or model controller, please try again: " + modelId));
}
}
}, e -> {
countDownLatch.countDown();
bothDeleted.compareAndSet(true, false);
if (countDownLatch.getCount() == 0) {
actionListener.onFailure(new IllegalStateException("Failed to delete model chunks or model controller, please try again: " + modelId, e));
}
});
deleteModelChunks(modelId, countDownActionListener);
deleteController(modelId, countDownActionListener);
} else {
ActionListener<Boolean> deleteControllerListener = ActionListener.wrap(b -> {
log.debug("model controller for model {} deleted successfully, starting to delete model meta data", modelId);
deleteModelMetadata(modelId, actionListener);
}, e -> {
log.error("Failed to delete model chunks or model controller, please try again: " + modelId, e);
actionListener.onFailure(new IllegalStateException("Failed to delete model chunks or model controller, please try again: " + modelId, e));
});
deleteController(modelId, deleteControllerListener);
}
}

private void deleteModelMetadata(String modelId, ActionListener<DeleteResponse> actionListener) {
DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_INDEX, modelId).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
client.delete(deleteRequest, new ActionListener<>() {
@Override
public void onResponse(DeleteResponse deleteResponse) {
if (FunctionName.REMOTE != functionName) {
deleteModelChunks(modelId, deleteResponse, actionListener);
} else {
actionListener.onResponse(deleteResponse);
}
deleteController(modelId);
actionListener.onResponse(deleteResponse);
}

@Override
public void onFailure(Exception e) {
log.error("Failed to delete model meta data for model: " + modelId, e);
log.error("Failed to delete model meta data for model, please try again: " + modelId, e);
actionListener.onFailure(e);
}
});
Expand All @@ -235,20 +269,20 @@ public void onFailure(Exception e) {
*
* @param modelId model ID
*/
private void deleteController(String modelId, ActionListener<DeleteResponse> actionListener) {
private void deleteController(String modelId, ActionListener<Boolean> actionListener) {
DeleteRequest deleteRequest = new DeleteRequest(ML_CONTROLLER_INDEX, modelId);
client.delete(deleteRequest, new ActionListener<>() {
@Override
public void onResponse(DeleteResponse deleteResponse) {
log.info("Model controller for model {} successfully deleted from index, result: {}", modelId, deleteResponse.getResult());
actionListener.onResponse(deleteResponse);
actionListener.onResponse(true);
}

@Override
public void onFailure(Exception e) {
if (e instanceof IndexNotFoundException) {
if (e instanceof ResourceNotFoundException) {
log.info("Model controller not deleted due to no model controller found for model: " + modelId);
actionListener.onFailure(e);
actionListener.onResponse(true); // we consider this as success
} else {
log.error("Failed to delete model controller for model: " + modelId, e);
actionListener.onFailure(e);
Expand All @@ -257,27 +291,6 @@ public void onFailure(Exception e) {
});
}

/**
* Delete the model controller for a model after the model is deleted from the
* ML index with build-in listener.
*
* @param modelId model ID
*/
private void deleteController(String modelId) {
deleteController(modelId, ActionListener.wrap(deleteResponse -> {
if (deleteResponse.getResult() == DocWriteResponse.Result.DELETED) {
log.info("Model controller for model {} successfully deleted from index, result: {}", modelId, deleteResponse.getResult());
} else {
log.info("The deletion of model controller for model {} returned with result: {}", modelId, deleteResponse.getResult());
}
}, e -> {
if (e instanceof IndexNotFoundException) {
log.debug("Model controller not deleted due to no model controller found for model: " + modelId);
} else {
log.error("Failed to delete model controller for model: " + modelId, e);
}
}));
}

private Boolean isModelNotDeployed(MLModelState mlModelState) {
return !mlModelState.equals(MLModelState.LOADED)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
Expand Down Expand Up @@ -47,6 +48,7 @@
import org.opensearch.index.get.GetResult;
import org.opensearch.index.reindex.BulkByScrollResponse;
import org.opensearch.index.reindex.ScrollableHitSource;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.transport.model.MLModelDeleteRequest;
Expand Down Expand Up @@ -154,6 +156,84 @@ public void testDeleteModel_Success() throws IOException {
verify(actionListener).onResponse(deleteResponse);
}

public void testDeleteRemoteModel_Success() throws IOException {
doAnswer(invocation -> {
ActionListener<DeleteResponse> listener = invocation.getArgument(1);
listener.onResponse(deleteResponse);
return null;
}).when(client).delete(any(), any());

doAnswer(invocation -> {
ActionListener<BulkByScrollResponse> listener = invocation.getArgument(2);
BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null);
listener.onResponse(response);
return null;
}).when(client).execute(any(), any(), any());

GetResponse getResponse = prepareModelWithFunction(MLModelState.REGISTERED, null, false, FunctionName.REMOTE);
doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
return null;
}).when(client).get(any(), any());

deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener);
verify(actionListener).onResponse(deleteResponse);
}

public void testDeleteRemoteModel_deleteModelController_failed() throws IOException {
doAnswer(invocation -> {
ActionListener<DeleteResponse> listener = invocation.getArgument(1);
listener.onFailure(new RuntimeException("runtime exception"));
return null;
}).when(client).delete(any(), any());

doAnswer(invocation -> {
ActionListener<BulkByScrollResponse> listener = invocation.getArgument(2);
BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null);
listener.onResponse(response);
return null;
}).when(client).execute(any(), any(), any());

GetResponse getResponse = prepareModelWithFunction(MLModelState.REGISTERED, null, false, FunctionName.REMOTE);
doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
return null;
}).when(client).get(any(), any());

deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("Failed to delete model chunks or model controller, please try again: test_id", argumentCaptor.getValue().getMessage());
}

public void testDeleteRemoteModel_deleteModelChunks_failed() throws IOException {
doAnswer(invocation -> {
ActionListener<DeleteResponse> listener = invocation.getArgument(1);
listener.onResponse(deleteResponse);
return null;
}).when(client).delete(any(), any());

doAnswer(invocation -> {
ActionListener<BulkByScrollResponse> listener = invocation.getArgument(2);
listener.onFailure(new RuntimeException("runtime exception"));
return null;
}).when(client).execute(any(), any(), any());

GetResponse getResponse = prepareModelWithFunction(MLModelState.REGISTERED, null, false, FunctionName.REMOTE);
doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
return null;
}).when(client).get(any(), any());

deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("Failed to delete model chunks or model controller, please try again: test_id", argumentCaptor.getValue().getMessage());
}

public void testDeleteHiddenModel_Success() throws IOException {
doAnswer(invocation -> {
ActionListener<DeleteResponse> listener = invocation.getArgument(1);
Expand Down Expand Up @@ -349,9 +429,9 @@ public void testDeleteModelChunks_Success() {
listener.onResponse(bulkByScrollResponse);
return null;
}).when(client).execute(any(), any(), any());

deleteModelTransportAction.deleteModelChunks("test_id", deleteResponse, actionListener);
verify(actionListener).onResponse(deleteResponse);
ActionListener<Boolean> deleteChunksListener = mock(ActionListener.class);
deleteModelTransportAction.deleteModelChunks("test_id", deleteChunksListener);
verify(deleteChunksListener).onResponse(true);
}

public void testDeleteModel_RuntimeException() throws IOException {
Expand All @@ -371,7 +451,7 @@ public void testDeleteModel_RuntimeException() throws IOException {
deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("errorMessage", argumentCaptor.getValue().getMessage());
assertEquals("Failed to delete model chunks or model controller, please try again: test_id", argumentCaptor.getValue().getMessage());
}

@Ignore
Expand All @@ -389,10 +469,10 @@ public void test_FailToDeleteModel() {
listener.onFailure(new RuntimeException("errorMessage"));
return null;
}).when(client).execute(any(), any(), any());

deleteModelTransportAction.deleteModelChunks("test_id", deleteResponse, actionListener);
ActionListener<Boolean> deleteChunksListener = mock(ActionListener.class);
deleteModelTransportAction.deleteModelChunks("test_id", deleteChunksListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
verify(deleteChunksListener).onFailure(argumentCaptor.capture());
assertEquals("errorMessage", argumentCaptor.getValue().getMessage());
}

Expand All @@ -404,10 +484,10 @@ public void test_FailToDeleteAllModelChunks() {
listener.onResponse(bulkByScrollResponse);
return null;
}).when(client).execute(any(), any(), any());

deleteModelTransportAction.deleteModelChunks("test_id", deleteResponse, actionListener);
ActionListener<Boolean> deleteChunksListener = mock(ActionListener.class);
deleteModelTransportAction.deleteModelChunks("test_id", deleteChunksListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
verify(deleteChunksListener).onFailure(argumentCaptor.capture());
assertEquals(OS_STATUS_EXCEPTION_MESSAGE + ", " + BULK_FAILURE_MSG + "test_id", argumentCaptor.getValue().getMessage());
}

Expand All @@ -420,10 +500,10 @@ public void test_FailToDeleteAllModelChunks_TimeOut() {
listener.onResponse(bulkByScrollResponse);
return null;
}).when(client).execute(any(), any(), any());

deleteModelTransportAction.deleteModelChunks("test_id", deleteResponse, actionListener);
ActionListener<Boolean> deleteChunksListener = mock(ActionListener.class);
deleteModelTransportAction.deleteModelChunks("test_id", deleteChunksListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
verify(deleteChunksListener).onFailure(argumentCaptor.capture());
assertEquals(OS_STATUS_EXCEPTION_MESSAGE + ", " + TIMEOUT_MSG + "test_id", argumentCaptor.getValue().getMessage());
}

Expand All @@ -442,16 +522,24 @@ public void test_FailToDeleteAllModelChunks_SearchFailure() {
listener.onResponse(bulkByScrollResponse);
return null;
}).when(client).execute(any(), any(), any());

deleteModelTransportAction.deleteModelChunks("test_id", deleteResponse, actionListener);
ActionListener<Boolean> deleteChunksListener = mock(ActionListener.class);
deleteModelTransportAction.deleteModelChunks("test_id", deleteChunksListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
verify(deleteChunksListener).onFailure(argumentCaptor.capture());
assertEquals(OS_STATUS_EXCEPTION_MESSAGE + ", " + SEARCH_FAILURE_MSG + "test_id", argumentCaptor.getValue().getMessage());
}

public GetResponse prepareMLModel(MLModelState mlModelState, String modelGroupID, boolean isHidden) throws IOException {
MLModel mlModel;
mlModel = MLModel.builder().modelId("test_id").modelState(mlModelState).modelGroupId(modelGroupID).isHidden(isHidden).build();
MLModel mlModel = MLModel.builder().modelId("test_id").modelState(mlModelState).modelGroupId(modelGroupID).isHidden(isHidden).build();
return buildResponse(mlModel);
}

public GetResponse prepareModelWithFunction(MLModelState mlModelState, String modelGroupID, boolean isHidden, FunctionName functionName) throws IOException {
MLModel mlModel = MLModel.builder().modelId("test_id").algorithm(functionName).modelState(mlModelState).modelGroupId(modelGroupID).isHidden(isHidden).build();
return buildResponse(mlModel);
}

private GetResponse buildResponse(MLModel mlModel) throws IOException {
XContentBuilder content = mlModel.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
BytesReference bytesReference = BytesReference.bytes(content);
GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null);
Expand Down

0 comments on commit dd6fc1b

Please sign in to comment.