Skip to content

Commit

Permalink
delete model chunks for other models except remote model (opensearch-…
Browse files Browse the repository at this point in the history
…project#2827)

* delete model chunks for other models except remote model

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

* fixing tests

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

* addressing comment

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

---------

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>
  • Loading branch information
dhrubo-os authored Aug 15, 2024
1 parent d2cb8a2 commit 9f4d2ce
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
import static org.opensearch.ml.common.CommonValue.ML_CONTROLLER_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.common.MLModel.ALGORITHM_FIELD;
import static org.opensearch.ml.common.MLModel.FUNCTION_NAME_FIELD;
import static org.opensearch.ml.common.MLModel.IS_HIDDEN_FIELD;
import static org.opensearch.ml.common.MLModel.MODEL_ID_FIELD;
import static org.opensearch.ml.common.utils.StringUtils.getErrorMessage;
import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL;
import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext;

import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;

Expand All @@ -42,6 +44,7 @@
import org.opensearch.index.reindex.BulkByScrollResponse;
import org.opensearch.index.reindex.DeleteByQueryAction;
import org.opensearch.index.reindex.DeleteByQueryRequest;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.transport.model.MLModelDeleteAction;
Expand Down Expand Up @@ -73,14 +76,14 @@ public class DeleteModelTransportAction extends HandledTransportAction<ActionReq
static final String BULK_FAILURE_MSG = "Bulk failure while deleting model of ";
static final String SEARCH_FAILURE_MSG = "Search failure while deleting model of ";
static final String OS_STATUS_EXCEPTION_MESSAGE = "Failed to delete all model chunks";
Client client;
SdkClient sdkClient;
NamedXContentRegistry xContentRegistry;
ClusterService clusterService;
final Client client;
final SdkClient sdkClient;
final NamedXContentRegistry xContentRegistry;
final ClusterService clusterService;

Settings settings;

ModelAccessControlHelper modelAccessControlHelper;
final ModelAccessControlHelper modelAccessControlHelper;
private final MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Inject
Expand Down Expand Up @@ -124,7 +127,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
boolean isSuperAdmin = isSuperAdminUserWrapper(clusterService, client);

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<DeleteResponse> wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore());
ActionListener<DeleteResponse> wrappedListener = ActionListener.runBefore(actionListener, context::restore);
sdkClient
.getDataObjectAsync(getDataObjectRequest, client.threadPool().executor(GENERAL_THREAD_POOL))
.whenComplete((r, throwable) -> {
Expand All @@ -143,8 +146,12 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
String algorithmName = "";
Map<String, Object> source = r.source();
if (source != null && source.get(ALGORITHM_FIELD) != null) {
algorithmName = source.get(ALGORITHM_FIELD).toString();
if (source != null) {
if (source.get(FUNCTION_NAME_FIELD) != null) {
algorithmName = source.get(FUNCTION_NAME_FIELD).toString();
} else if (source.get(ALGORITHM_FIELD) != null) {
algorithmName = source.get(ALGORITHM_FIELD).toString();
}
}
MLModel mlModel = MLModel.parse(parser, algorithmName);
if (!TenantAwareHelper
Expand All @@ -164,7 +171,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
);
} else {
if (isModelNotDeployed(mlModelState)) {
deleteModel(modelId, isHidden, actionListener);
deleteModel(modelId, algorithmName, isHidden, actionListener);
} else {
wrappedListener
.onFailure(
Expand All @@ -191,7 +198,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
)
);
} else if (isModelNotDeployed(mlModelState)) {
deleteModel(modelId, isHidden, actionListener);
deleteModel(modelId, mlModel.getAlgorithm().name(), isHidden, actionListener);
} else {
wrappedListener
.onFailure(
Expand All @@ -208,15 +215,18 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
);
}
} catch (Exception e) {
log.error("Failed to parse ml model " + r.id(), e);
log.error("Failed to parse ml model {}", r.id(), e);
wrappedListener.onFailure(e);
}
} else {
// when model metadata is not found, model chunk and controller might still there, delete them here and
// return
// success
// response
deleteModelChunksAndController(wrappedListener, modelId, false, null);
// as we can't see the metadata we are providing functionName as null. In this way,
// code will try to remove model chunks for any models other than remote. As remote
// model doesn't have any model chunks.
deleteModelChunksAndController(wrappedListener, modelId, null, false, null);
}
} catch (Exception e) {
wrappedListener.onFailure(e);
Expand All @@ -226,7 +236,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
}
});
} catch (Exception e) {
log.error("Failed to delete ML model " + modelId, e);
log.error("Failed to delete ML model {}", modelId, e);
actionListener.onFailure(e);
}
}
Expand All @@ -237,8 +247,8 @@ void deleteModelChunks(String modelId, Boolean isHidden, ActionListener<Boolean>
deleteModelsRequest.setQuery(new TermsQueryBuilder(MODEL_ID_FIELD, modelId));

client.execute(DeleteByQueryAction.INSTANCE, deleteModelsRequest, ActionListener.wrap(r -> {
if ((r.getBulkFailures() == null || r.getBulkFailures().size() == 0)
&& (r.getSearchFailures() == null || r.getSearchFailures().size() == 0)) {
if ((r.getBulkFailures() == null || r.getBulkFailures().isEmpty())
&& (r.getSearchFailures() == null || r.getSearchFailures().isEmpty())) {
log.debug(getErrorMessage("All model chunks are deleted for the provided model.", modelId, isHidden));
actionListener.onResponse(true);
} else {
Expand All @@ -251,7 +261,7 @@ void deleteModelChunks(String modelId, Boolean isHidden, ActionListener<Boolean>
}

private void returnFailure(BulkByScrollResponse response, String modelId, ActionListener<Boolean> actionListener) {
String errorMessage = "";
String errorMessage;
if (response.isTimedOut()) {
errorMessage = OS_STATUS_EXCEPTION_MESSAGE + ", " + TIMEOUT_MSG + modelId;
} else if (!response.getBulkFailures().isEmpty()) {
Expand All @@ -263,22 +273,22 @@ private void returnFailure(BulkByScrollResponse response, String modelId, Action
actionListener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.INTERNAL_SERVER_ERROR));
}

private void deleteModel(String modelId, Boolean isHidden, ActionListener<DeleteResponse> actionListener) {
private void deleteModel(String modelId, String functionName, Boolean isHidden, ActionListener<DeleteResponse> actionListener) {
DeleteDataObjectRequest deleteDataObjectRequest = DeleteDataObjectRequest.builder().index(ML_MODEL_INDEX).id(modelId).build();
sdkClient
.deleteDataObjectAsync(deleteDataObjectRequest, client.threadPool().executor(GENERAL_THREAD_POOL))
.whenComplete((r, throwable) -> {
if (throwable == null) {
try {
DeleteResponse deleteResponse = DeleteResponse.fromXContent(r.parser());
deleteModelChunksAndController(actionListener, modelId, isHidden, deleteResponse);
deleteModelChunksAndController(actionListener, modelId, functionName, isHidden, deleteResponse);
} catch (Exception e) {
actionListener.onFailure(e);
}
} else {
Exception e = SdkClientUtils.unwrapAndConvertToException(throwable);
if (e instanceof ResourceNotFoundException) {
deleteModelChunksAndController(actionListener, modelId, isHidden, null);
deleteModelChunksAndController(actionListener, modelId, functionName, isHidden, null);
} else {
log.error(getErrorMessage("Model is not all cleaned up, please try again.", modelId, isHidden), e);
actionListener.onFailure(e);
Expand All @@ -290,6 +300,7 @@ private void deleteModel(String modelId, Boolean isHidden, ActionListener<Delete
private void deleteModelChunksAndController(
ActionListener<DeleteResponse> actionListener,
String modelId,
String functionName,
Boolean isHidden,
DeleteResponse deleteResponse
) {
Expand Down Expand Up @@ -332,7 +343,12 @@ private void deleteModelChunksAndController(
});
// TODO: this uses DeleteByQuery which isn't on SdkClient
// evaluate if it's safe to leave as is
deleteModelChunks(modelId, isHidden, countDownActionListener);
if (!Objects.equals(functionName, FunctionName.REMOTE.name())) {
deleteModelChunks(modelId, isHidden, countDownActionListener);
} else {
// for remote model we don't need to delete model chunks so reducing one latch countdown.
countDownLatch.countDown();
}
deleteController(modelId, isHidden, countDownActionListener);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,13 +211,6 @@ public void testDeleteRemoteModel_Success() throws IOException, InterruptedExcep
future.onResponse(deleteResponse);
when(client.delete(any())).thenReturn(future);

doAnswer(invocation -> {
ActionListener<BulkByScrollResponse> listener = invocation.getArgument(2);
BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null);
listener.onResponse(response);
return null;
}).when(client).execute(any(), any(), any());

GetResponse getResponse = prepareModelWithFunction(MLModelState.REGISTERED, null, false, FunctionName.REMOTE);
PlainActionFuture<GetResponse> getFuture = PlainActionFuture.newFuture();
getFuture.onResponse(getResponse);
Expand Down Expand Up @@ -689,6 +682,7 @@ public GetResponse prepareMLModel(MLModelState mlModelState, String modelGroupID
.modelState(mlModelState)
.modelGroupId(modelGroupID)
.isHidden(isHidden)
.algorithm(FunctionName.TEXT_EMBEDDING)
.build();
return buildResponse(mlModel);
}
Expand Down

0 comments on commit 9f4d2ce

Please sign in to comment.