Skip to content

Commit

Permalink
Optimize logic and tests for blocking model deletion when in use (ope…
Browse files Browse the repository at this point in the history
…nsearch-project#1757)

* Optimize logic and tests for blocking model deletion when in use

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Change if statement

Co-authored-by: Heemin Kim <heemin@amazon.com>
Signed-off-by: Ryan Bogan <rbogan@amazon.com>

---------

Signed-off-by: Ryan Bogan <rbogan@amazon.com>
Co-authored-by: Heemin Kim <heemin@amazon.com>
  • Loading branch information
ryanbogan and heemin32 authored Jun 14, 2024
1 parent 9075fb5 commit 6e30a3b
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 125 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -197,31 +197,17 @@ private List<String> getIndicesUsingModel(ClusterState clusterState, UpdateModel
.filter(entry -> entry.getValue() != null)
.filter(entry -> {
Object properties = entry.getValue().getSourceAsMap().get("properties");
if (properties == null || properties instanceof Map == false) {
if ((properties instanceof Map) == false) {
return false;
}
Map propertiesMap = (Map<String, Object>) properties;
return propertiesMapContainsModel(propertiesMap, task.getModelId());
Map<String, Object> propertiesMap = (Map<String, Object>) properties;
return propertiesMap.values()
.stream()
.filter(obj -> obj instanceof Map)
.anyMatch(obj -> task.getModelId().equals(((Map<String, Object>) obj).get(MODEL_ID)));
})
.map(Map.Entry::getKey)
.collect(toList());
}

private boolean propertiesMapContainsModel(Map<String, Object> propertiesMap, String modelId) {
for (Map.Entry<String, Object> fieldsEntry : propertiesMap.entrySet()) {
if (fieldsEntry.getKey() != null && fieldsEntry.getValue() instanceof Map) {
Map<String, Object> innerMap = (Map<String, Object>) fieldsEntry.getValue();
for (Map.Entry<String, Object> innerEntry : innerMap.entrySet()) {
// If model is in use, fail delete model request
if (innerEntry.getKey().equals(MODEL_ID)
&& innerEntry.getValue() instanceof String
&& innerEntry.getValue().equals(modelId)) {
return true;
}
}
}
}
return false;
}
}
}
25 changes: 23 additions & 2 deletions src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.knn;

import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.cluster.ClusterName;
import org.opensearch.cluster.ClusterState;
Expand Down Expand Up @@ -37,7 +38,11 @@
import org.opensearch.test.hamcrest.OpenSearchAssertions;

import java.io.IOException;
import java.util.*;
import java.util.Base64;
import java.util.Collection;
import java.util.Collections;
import java.util.EnumSet;
import java.util.Map;
import java.util.concurrent.ExecutionException;

import static org.mockito.Mockito.when;
Expand Down Expand Up @@ -184,7 +189,7 @@ protected void addDoc(String index, String docId, String fieldName, String dummy
/**
* Index a new model
*/
protected void addDoc(Model model) throws IOException, ExecutionException, InterruptedException {
protected void writeModelToModelSystemIndex(Model model) throws IOException, ExecutionException, InterruptedException {
ModelMetadata modelMetadata = model.getModelMetadata();

XContentBuilder builder = XContentFactory.jsonBuilder()
Expand Down Expand Up @@ -213,6 +218,22 @@ protected void addDoc(Model model) throws IOException, ExecutionException, Inter
assertTrue(response.status() == RestStatus.CREATED || response.status() == RestStatus.OK);
}

// Add a new model to ModelDao
protected void addModel(Model model) throws IOException {
ModelDao modelDao = ModelDao.OpenSearchKNNModelDao.getInstance();
modelDao.put(model, new ActionListener<IndexResponse>() {
@Override
public void onResponse(IndexResponse indexResponse) {
assertTrue(indexResponse.status() == RestStatus.CREATED || indexResponse.status() == RestStatus.OK);
}

@Override
public void onFailure(Exception e) {
fail("Failed to add model: " + e);
}
});
}

/**
* Run a search against a k-NN index
*/
Expand Down
105 changes: 6 additions & 99 deletions src/test/java/org/opensearch/knn/indices/ModelDaoTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import org.opensearch.ExceptionsHelper;
import org.opensearch.ResourceAlreadyExistsException;
import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.admin.indices.create.CreateIndexRequestBuilder;
import org.opensearch.cluster.ClusterChangedEvent;
import org.opensearch.core.action.ActionListener;
import org.opensearch.action.DocWriteResponse;
Expand All @@ -30,12 +29,9 @@
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.action.support.master.AcknowledgedResponse;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.engine.VersionConflictEngineException;
import org.opensearch.knn.KNNSingleNodeTestCase;
import org.opensearch.knn.TestUtils;
import org.opensearch.knn.common.exception.DeleteModelException;
import org.opensearch.knn.index.MethodComponentContext;
import org.opensearch.knn.index.SpaceType;
Expand Down Expand Up @@ -65,11 +61,7 @@
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME;
import static org.opensearch.knn.common.KNNConstants.PROPERTIES;
import static org.opensearch.knn.common.KNNConstants.TYPE;
import static org.opensearch.knn.common.KNNConstants.TYPE_KNN_VECTOR;

public class ModelDaoTests extends KNNSingleNodeTestCase {

Expand Down Expand Up @@ -152,7 +144,7 @@ public void testModelIndexHealth() throws InterruptedException, ExecutionExcepti
modelBlob,
modelId
);
addDoc(model);
writeModelToModelSystemIndex(model);
assertEquals(model, modelDao.get(modelId));
assertNotNull(modelDao.getHealthStatus());

Expand All @@ -172,7 +164,7 @@ public void testModelIndexHealth() throws InterruptedException, ExecutionExcepti
modelBlob,
modelId
);
addDoc(model);
writeModelToModelSystemIndex(model);
assertEquals(model, modelDao.get(modelId));
assertNotNull(modelDao.getHealthStatus());
}
Expand Down Expand Up @@ -450,7 +442,7 @@ public void testGet() throws IOException, InterruptedException, ExecutionExcepti
modelBlob,
modelId
);
addDoc(model);
writeModelToModelSystemIndex(model);
assertEquals(model, modelDao.get(modelId));

// Get model during training
Expand All @@ -469,7 +461,7 @@ public void testGet() throws IOException, InterruptedException, ExecutionExcepti
null,
modelId
);
addDoc(model);
writeModelToModelSystemIndex(model);
assertEquals(model, modelDao.get(modelId));
}

Expand Down Expand Up @@ -629,91 +621,6 @@ public void testDelete() throws IOException, InterruptedException {
assertTrue(inProgressLatch3.await(100, TimeUnit.SECONDS));
}

// Test Delete Model when the model is in use by an index
public void testDeleteModelInUse() throws IOException, ExecutionException, InterruptedException {
String modelId = "test-model-id-training";
ModelDao modelDao = ModelDao.OpenSearchKNNModelDao.getInstance();
byte[] modelBlob = "deleteModel".getBytes();
int dimension = 2;
createIndex(MODEL_INDEX_NAME);

Model model = new Model(
new ModelMetadata(
KNNEngine.DEFAULT,
SpaceType.DEFAULT,
dimension,
ModelState.CREATED,
ZonedDateTime.now(ZoneOffset.UTC).toString(),
"",
"",
"",
MethodComponentContext.EMPTY
),
modelBlob,
modelId
);

// created model and added it to index
addDoc(model);

String testIndex = "test-index";
String testField = "test-field";

/*
Constructs the following json:
{
"properties": {
"test-field": {
"type": "knn_vector",
"model_id": "test-model-id-training"
}
}
}
*/
XContentBuilder mappings = XContentFactory.jsonBuilder()
.startObject()
.startObject(PROPERTIES)
.startObject(testField)
.field(TYPE, TYPE_KNN_VECTOR)
.field(MODEL_ID, modelId)
.endObject()
.endObject()
.endObject();

XContentBuilder settings = XContentFactory.jsonBuilder().startObject().field(TestUtils.INDEX_KNN, "true").endObject();

// Create index using model
CreateIndexRequestBuilder createIndexRequestBuilder = client().admin()
.indices()
.prepareCreate(testIndex)
.setMapping(mappings)
.setSettings(settings);
createIndex(testIndex, createIndexRequestBuilder);

CountDownLatch latch = new CountDownLatch(1);
modelDao.delete(modelId, new ActionListener<DeleteModelResponse>() {
@Override
public void onResponse(DeleteModelResponse deleteModelResponse) {
fail("Received delete model response when the request should have failed.");
}

@Override
public void onFailure(Exception e) {
assertTrue(e instanceof DeleteModelException);
assertEquals(
String.format(
"Cannot delete model [%s]. Model is in use by the following indices [%s], which must be deleted first.",
modelId,
testIndex
),
e.getMessage()
);
latch.countDown();
}
});
assertTrue(latch.await(60, TimeUnit.SECONDS));
}

// Test Delete Model when modelId is in Model Graveyard (previous delete model request which failed to
// remove modelId from model graveyard). But, the model does not exist
public void testDeleteModelWithModelInGraveyardModelDoesNotExist() throws InterruptedException {
Expand Down Expand Up @@ -772,7 +679,7 @@ public void testDeleteModelInTrainingWithStepListeners() throws IOException, Exe
);

// created model and added it to index
addDoc(model);
writeModelToModelSystemIndex(model);

final CountDownLatch inProgressLatch = new CountDownLatch(1);

Expand Down Expand Up @@ -814,7 +721,7 @@ public void testDeleteWithStepListeners() throws IOException, InterruptedExcepti
);

// created model and added it to index
addDoc(model);
writeModelToModelSystemIndex(model);

final CountDownLatch inProgressLatch = new CountDownLatch(1);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ public void testCheckBlock() {
assertNull(updateModelGraveyardTransportAction.checkBlock(null, null));
}

public void testGetIndicesUsingModel() throws IOException, ExecutionException, InterruptedException {
public void testClusterManagerOperation_GetIndicesUsingModel() throws IOException, ExecutionException, InterruptedException {
// Get update transport action
UpdateModelGraveyardTransportAction updateModelGraveyardTransportAction = node().injector()
.getInstance(UpdateModelGraveyardTransportAction.class);
Expand Down Expand Up @@ -217,7 +217,7 @@ public void testGetIndicesUsingModel() throws IOException, ExecutionException, I
);

// created model and added it to index
addDoc(model);
addModel(model);

// Create basic index (not using k-NN)
String testIndex1 = "test-index1";
Expand Down Expand Up @@ -336,7 +336,7 @@ public void testGetIndicesUsingModel() throws IOException, ExecutionException, I
);
}

public void updateModelGraveyardAndAssertNoError(
private void updateModelGraveyardAndAssertNoError(
UpdateModelGraveyardTransportAction updateModelGraveyardTransportAction,
UpdateModelGraveyardRequest updateModelGraveyardRequest
) throws InterruptedException {
Expand All @@ -355,7 +355,7 @@ public void updateModelGraveyardAndAssertNoError(
assertTrue(countDownLatch.await(60, TimeUnit.SECONDS));
}

public void updateModelGraveyardAndAssertDeleteModelException(
private void updateModelGraveyardAndAssertDeleteModelException(
UpdateModelGraveyardTransportAction updateModelGraveyardTransportAction,
UpdateModelGraveyardRequest updateModelGraveyardRequest,
String indicesPresentInException
Expand Down

0 comments on commit 6e30a3b

Please sign in to comment.