diff --git a/CHANGELOG.md b/CHANGELOG.md index 91744f50b..6d954ace3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Add 2.6.0 to BWC Version Matrix ([#810](https://github.com/opensearch-project/k-NN/pull/810)) * Update BWC Version with OpenSearch Version Bump ([#813](https://github.com/opensearch-project/k-NN/pull/813)) * Bump numpy version from 1.22.x to 1.24.2 ([#811](https://github.com/opensearch-project/k-NN/pull/811)) +* Support .opensearch-knn-model index as system index with security enabled ([#827](https://github.com/opensearch-project/k-NN/pull/827)) ### Documentation ### Maintenance ### Refactoring diff --git a/build.gradle b/build.gradle index f8a87db4e..7b7c54809 100644 --- a/build.gradle +++ b/build.gradle @@ -178,6 +178,7 @@ dependencies { testImplementation group: 'net.bytebuddy', name: 'byte-buddy', version: '1.14.3' testImplementation group: 'org.objenesis', name: 'objenesis', version: '3.2' testImplementation group: 'net.bytebuddy', name: 'byte-buddy-agent', version: '1.14.3' + api "org.opensearch:common-utils:${version}" } diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ModelIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ModelIT.java index f8e832c50..e052f6dcf 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ModelIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ModelIT.java @@ -182,11 +182,6 @@ public static void wipeAllModels() throws IOException { deleteKNNModel(TEST_MODEL_ID); deleteKNNModel(TEST_MODEL_ID_DEFAULT); deleteKNNModel(TEST_MODEL_ID_TRAINING); - - Request request = new Request("DELETE", "/" + MODEL_INDEX_NAME); - - Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); } } diff --git a/src/main/java/org/opensearch/knn/indices/ModelDao.java b/src/main/java/org/opensearch/knn/indices/ModelDao.java index 0d5d75d30..cf0dd1890 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelDao.java +++ b/src/main/java/org/opensearch/knn/indices/ModelDao.java @@ -13,6 +13,7 @@ import com.google.common.base.Charsets; import com.google.common.io.Resources; +import lombok.SneakyThrows; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.OpenSearchException; @@ -42,6 +43,7 @@ import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.index.IndexNotFoundException; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.plugin.transport.DeleteModelResponse; @@ -49,10 +51,10 @@ import org.opensearch.knn.plugin.transport.RemoveModelFromCacheAction; import org.opensearch.knn.plugin.transport.RemoveModelFromCacheRequest; import org.opensearch.knn.plugin.transport.RemoveModelFromCacheResponse; -import org.opensearch.knn.plugin.transport.UpdateModelMetadataAction; -import org.opensearch.knn.plugin.transport.UpdateModelMetadataRequest; import org.opensearch.knn.plugin.transport.UpdateModelGraveyardAction; import org.opensearch.knn.plugin.transport.UpdateModelGraveyardRequest; +import org.opensearch.knn.plugin.transport.UpdateModelMetadataAction; +import org.opensearch.knn.plugin.transport.UpdateModelMetadataRequest; import java.io.IOException; import java.net.URL; @@ -62,6 +64,7 @@ import java.util.Objects; import java.util.Optional; import java.util.concurrent.ExecutionException; +import java.util.function.Supplier; import static java.util.Objects.isNull; import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_MAPPING_PATH; @@ -216,14 +219,21 @@ public void create(ActionListener actionListener) throws IO if (isCreated()) { return; } - CreateIndexRequest request = new CreateIndexRequest(MODEL_INDEX_NAME).mapping(getMapping()) - .settings( - Settings.builder() - .put("index.hidden", true) - .put("index.number_of_shards", this.numberOfShards) - .put("index.number_of_replicas", this.numberOfReplicas) - ); - client.admin().indices().create(request, actionListener); + runWithStashedThreadContext(() -> { + CreateIndexRequest request; + try { + request = new CreateIndexRequest(MODEL_INDEX_NAME).mapping(getMapping()) + .settings( + Settings.builder() + .put("index.hidden", true) + .put("index.number_of_shards", this.numberOfShards) + .put("index.number_of_replicas", this.numberOfReplicas) + ); + } catch (IOException e) { + throw new RuntimeException(e); + } + client.admin().indices().create(request, actionListener); + }); } @Override @@ -293,8 +303,9 @@ private void putInternal(Model model, ActionListener listener, Do parameters.put(KNNConstants.MODEL_BLOB_PARAMETER, base64Model); } - IndexRequestBuilder indexRequestBuilder = client.prepareIndex(MODEL_INDEX_NAME); - + final IndexRequestBuilder indexRequestBuilder = ModelDao.runWithStashedThreadContext( + () -> client.prepareIndex(MODEL_INDEX_NAME) + ); indexRequestBuilder.setId(model.getModelID()); indexRequestBuilder.setSource(parameters); @@ -304,8 +315,8 @@ private void putInternal(Model model, ActionListener listener, Do // After metadata update finishes, remove item from every node's cache if necessary. If no model id is // passed then nothing needs to be removed from the cache ActionListener onMetaListener; - onMetaListener = ActionListener.wrap( - indexResponse -> client.execute( + onMetaListener = ActionListener.wrap(indexResponse -> { + client.execute( RemoveModelFromCacheAction.INSTANCE, new RemoveModelFromCacheRequest(model.getModelID()), ActionListener.wrap(removeModelFromCacheResponse -> { @@ -318,9 +329,8 @@ private void putInternal(Model model, ActionListener listener, Do listener.onFailure(new RuntimeException(failureMessage)); }, listener::onFailure) - ), - listener::onFailure - ); + ); + }, listener::onFailure); // After the model is indexed, update metadata only if the model is in CREATED state ActionListener onIndexListener; @@ -357,16 +367,30 @@ private ActionListener getUpdateModelMetadataListener( ); } + @SneakyThrows @Override - public Model get(String modelId) throws ExecutionException, InterruptedException { + public Model get(String modelId) { /* GET //?_local */ - GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId) - .setPreference("_local"); - GetResponse getResponse = getRequestBuilder.execute().get(); - Map responseMap = getResponse.getSourceAsMap(); - return Model.getModelFromSourceMap(responseMap); + try { + return ModelDao.runWithStashedThreadContext(() -> { + GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId) + .setPreference("_local"); + GetResponse getResponse; + try { + getResponse = getRequestBuilder.execute().get(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + Map responseMap = getResponse.getSourceAsMap(); + return Model.getModelFromSourceMap(responseMap); + }); + } catch (RuntimeException runtimeException) { + // we need to use RuntimeException as container for real exception to keep signature + // of runWithStashedThreadContext generic + throw runtimeException.getCause(); + } } /** @@ -380,20 +404,22 @@ public void get(String modelId, ActionListener actionListener) /* GET //?_local */ - GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId) - .setPreference("_local"); - - getRequestBuilder.execute(ActionListener.wrap(response -> { - if (response.isSourceEmpty()) { - String errorMessage = String.format("Model \" %s \" does not exist", modelId); - actionListener.onFailure(new ResourceNotFoundException(modelId, errorMessage)); - return; - } - final Map responseMap = response.getSourceAsMap(); - Model model = Model.getModelFromSourceMap(responseMap); - actionListener.onResponse(new GetModelResponse(model)); + ModelDao.runWithStashedThreadContext(() -> { + GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId) + .setPreference("_local"); + + getRequestBuilder.execute(ActionListener.wrap(response -> { + if (response.isSourceEmpty()) { + String errorMessage = String.format("Model \" %s \" does not exist", modelId); + actionListener.onFailure(new ResourceNotFoundException(modelId, errorMessage)); + return; + } + final Map responseMap = response.getSourceAsMap(); + Model model = Model.getModelFromSourceMap(responseMap); + actionListener.onResponse(new GetModelResponse(model)); - }, actionListener::onFailure)); + }, actionListener::onFailure)); + }); } /** @@ -404,8 +430,10 @@ public void get(String modelId, ActionListener actionListener) */ @Override public void search(SearchRequest request, ActionListener actionListener) { - request.indices(MODEL_INDEX_NAME); - client.search(request, actionListener); + ModelDao.runWithStashedThreadContext(() -> { + request.indices(MODEL_INDEX_NAME); + client.search(request, actionListener); + }); } @Override @@ -505,16 +533,17 @@ public void delete(String modelId, ActionListener listener) ); // Setup delete model request - DeleteRequestBuilder deleteRequestBuilder = new DeleteRequestBuilder(client, DeleteAction.INSTANCE, MODEL_INDEX_NAME); - deleteRequestBuilder.setId(modelId); - deleteRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - - // On model metadata removal, delete the model from the index - clearModelMetadataStep.whenComplete( - acknowledgedResponse -> deleteModelFromIndex(modelId, deleteModelFromIndexStep, deleteRequestBuilder), - listener::onFailure - ); - + ModelDao.runWithStashedThreadContext(() -> { + DeleteRequestBuilder deleteRequestBuilder = new DeleteRequestBuilder(client, DeleteAction.INSTANCE, MODEL_INDEX_NAME); + deleteRequestBuilder.setId(modelId); + deleteRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + // On model metadata removal, delete the model from the index + clearModelMetadataStep.whenComplete( + acknowledgedResponse -> deleteModelFromIndex(modelId, deleteModelFromIndexStep, deleteRequestBuilder), + listener::onFailure + ); + }); deleteModelFromIndexStep.whenComplete(deleteResponse -> { // If model is not deleted, remove modelId from model graveyard and return with error message if (deleteResponse.getResult() != DocWriteResponse.Result.DELETED) { @@ -653,4 +682,26 @@ private String buildRemoveModelErrorMessage(String modelId, RemoveModelFromCache return stringBuilder.toString(); } } + + /** + * Set the thread context to default, this is needed to allow actions on model system index + * when security plugin is enabled + * @param function runnable that needs to be executed after thread context has been stashed, accepts and returns nothing + */ + private static void runWithStashedThreadContext(Runnable function) { + try (ThreadContext.StoredContext context = OpenSearchKNNModelDao.client.threadPool().getThreadContext().stashContext()) { + function.run(); + } + } + + /** + * Set the thread context to default, this is needed to allow actions on model system index + * when security plugin is enabled + * @param function supplier function that needs to be executed after thread context has been stashed, return object + */ + private static T runWithStashedThreadContext(Supplier function) { + try (ThreadContext.StoredContext context = OpenSearchKNNModelDao.client.threadPool().getThreadContext().stashContext()) { + return function.get(); + } + } } diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestDeleteModelHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestDeleteModelHandlerIT.java index aaa64625e..12d45d8a3 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestDeleteModelHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestDeleteModelHandlerIT.java @@ -19,15 +19,11 @@ import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; import org.opensearch.knn.KNNRestTestCase; -import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.util.KNNEngine; -import org.opensearch.knn.indices.ModelMetadata; -import org.opensearch.knn.indices.ModelState; import org.opensearch.knn.plugin.KNNPlugin; import org.opensearch.knn.plugin.transport.DeleteModelResponse; import org.opensearch.rest.RestStatus; -import java.io.IOException; +import java.util.List; import java.util.Map; import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE; @@ -48,59 +44,92 @@ public class RestDeleteModelHandlerIT extends KNNRestTestCase { - private ModelMetadata getModelMetadata() { - return new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, 4, ModelState.CREATED, "2021-03-27", "test model", ""); - } - - public void testDeleteModelExists() throws IOException { + public void testDeleteModelExists() throws Exception { createModelSystemIndex(); - String testModelID = "test-model-id"; - byte[] testModelBlob = "hello".getBytes(); - ModelMetadata testModelMetadata = getModelMetadata(); - addModelToSystemIndex(testModelID, testModelMetadata, testModelBlob); - assertEquals(getDocCount(MODEL_INDEX_NAME), 1); + String modelId = "test-model-id"; + String trainingIndexName = "train-index"; + String trainingFieldName = "train-field"; + int dimension = 8; + String modelDescription = "dummy description"; - String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, testModelID); - Request request = new Request("DELETE", restURI); + createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension); + ingestDataAndTrainModel(modelId, trainingIndexName, trainingFieldName, dimension, modelDescription); + assertTrainingSucceeds(modelId, NUM_OF_ATTEMPTS, DELAY_MILLI_SEC); - Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + Response getModelResponse = getModel(modelId, List.of()); + assertEquals(RestStatus.OK, RestStatus.fromCode(getModelResponse.getStatusLine().getStatusCode())); - assertEquals(0, getDocCount(MODEL_INDEX_NAME)); + String responseBody = EntityUtils.toString(getModelResponse.getEntity()); + assertNotNull(responseBody); + + Map responseMap = createParser(XContentType.JSON.xContent(), responseBody).map(); + + assertEquals(modelId, responseMap.get(MODEL_ID)); + + String deleteModelRestURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, modelId); + Request deleteModelRequest = new Request("DELETE", deleteModelRestURI); + + Response deleteModelResponse = client().performRequest(deleteModelRequest); + assertEquals( + deleteModelRequest.getEndpoint() + ": failed", + RestStatus.OK, + RestStatus.fromCode(deleteModelResponse.getStatusLine().getStatusCode()) + ); + + ResponseException ex = expectThrows(ResponseException.class, () -> getModel(modelId, List.of())); + assertTrue(ex.getMessage().contains(modelId)); } - public void testDeleteTrainingModel() throws IOException { + public void testDeleteTrainingModel() throws Exception { createModelSystemIndex(); - String testModelID = "test-model-id"; - byte[] testModelBlob = "hello".getBytes(); - ModelMetadata testModelMetadata = getModelMetadata(); - testModelMetadata.setState(ModelState.TRAINING); - - addModelToSystemIndex(testModelID, testModelMetadata, testModelBlob); - assertEquals(1, getDocCount(MODEL_INDEX_NAME)); - String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, testModelID); - Request request = new Request("DELETE", restURI); + String modelId = "test-model-id"; + String trainingIndexName = "train-index"; + String trainingFieldName = "train-field"; + int dimension = 8; + String modelDescription = "dummy description"; - Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension); + // we do not wait for training to be completed + ingestDataAndTrainModel(modelId, trainingIndexName, trainingFieldName, dimension, modelDescription); - assertEquals(1, getDocCount(MODEL_INDEX_NAME)); + Response getModelResponse = getModel(modelId, List.of()); + assertEquals(RestStatus.OK, RestStatus.fromCode(getModelResponse.getStatusLine().getStatusCode())); - String responseBody = EntityUtils.toString(response.getEntity()); + String responseBody = EntityUtils.toString(getModelResponse.getEntity()); assertNotNull(responseBody); Map responseMap = createParser(XContentType.JSON.xContent(), responseBody).map(); - assertEquals(testModelID, responseMap.get(MODEL_ID)); + assertEquals(modelId, responseMap.get(MODEL_ID)); + + String deleteModelRestURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, modelId); + Request deleteModelRequest = new Request("DELETE", deleteModelRestURI); + + Response deleteModelResponse = client().performRequest(deleteModelRequest); + assertEquals( + deleteModelRequest.getEndpoint() + ": failed", + RestStatus.OK, + RestStatus.fromCode(deleteModelResponse.getStatusLine().getStatusCode()) + ); + + responseBody = EntityUtils.toString(deleteModelResponse.getEntity()); + assertNotNull(responseBody); + + responseMap = createParser(XContentType.JSON.xContent(), responseBody).map(); + + assertEquals(modelId, responseMap.get(MODEL_ID)); assertEquals("failed", responseMap.get(DeleteModelResponse.RESULT)); - String errorMessage = String.format("Cannot delete model \"%s\". Model is still in training", testModelID); + String errorMessage = String.format("Cannot delete model \"%s\". Model is still in training", modelId); assertEquals(errorMessage, responseMap.get(DeleteModelResponse.ERROR_MSG)); + + // need to wait for training operation as it's required for after test cleanup + assertTrainingSucceeds(modelId, NUM_OF_ATTEMPTS, DELAY_MILLI_SEC); } - public void testDeleteModelFailsInvalid() throws IOException { + public void testDeleteModelFailsInvalid() throws Exception { String modelId = "invalid-model-id"; createModelSystemIndex(); String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, modelId); @@ -111,7 +140,7 @@ public void testDeleteModelFailsInvalid() throws IOException { } // Test Train Model -> Delete Model -> Train Model with same modelId - public void testTrainingDeletedModel() throws IOException, InterruptedException { + public void testTrainingDeletedModel() throws Exception { String modelId = "test-model-id1"; String trainingIndexName1 = "train-index-1"; String trainingIndexName2 = "train-index-2"; @@ -134,8 +163,7 @@ public void testTrainingDeletedModel() throws IOException, InterruptedException trainModel(modelId, trainingIndexName2, trainingFieldName, dimension); } - private void trainModel(String modelId, String trainingIndexName, String trainingFieldName, int dimension) throws IOException, - InterruptedException { + private void trainModel(String modelId, String trainingIndexName, String trainingFieldName, int dimension) throws Exception { // Create a training index and randomly ingest data into it createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension); diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestGetModelHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestGetModelHandlerIT.java index b6853e8bb..092ca31e3 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestGetModelHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestGetModelHandlerIT.java @@ -18,10 +18,6 @@ import org.opensearch.client.ResponseException; import org.opensearch.common.xcontent.XContentType; import org.opensearch.knn.KNNRestTestCase; -import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.util.KNNEngine; -import org.opensearch.knn.indices.ModelMetadata; -import org.opensearch.knn.indices.ModelState; import org.opensearch.knn.plugin.KNNPlugin; import org.opensearch.rest.RestStatus; @@ -39,6 +35,8 @@ import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNConstants.MODEL_STATE; import static org.opensearch.knn.common.KNNConstants.MODEL_TIMESTAMP; +import static org.opensearch.knn.index.SpaceType.L2; +import static org.opensearch.knn.index.util.KNNEngine.FAISS; /** * Integration tests to check the correctness of {@link org.opensearch.knn.plugin.rest.RestGetModelHandler} @@ -46,19 +44,28 @@ public class RestGetModelHandlerIT extends KNNRestTestCase { - private ModelMetadata getModelMetadata() { - return new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, 4, ModelState.CREATED, "2021-03-27", "test model", ""); - } - - public void testGetModelExists() throws IOException { + public void testGetModelExists() throws Exception { createModelSystemIndex(); - String testModelID = "test-model-id"; - byte[] testModelBlob = "hello".getBytes(); - ModelMetadata testModelMetadata = getModelMetadata(); - - addModelToSystemIndex(testModelID, testModelMetadata, testModelBlob); - String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, testModelID); + String modelId = "test-model-id"; + String trainingIndexName = "train-index"; + String trainingFieldName = "train-field"; + int dimension = 8; + String modelDescription = "dummy description"; + + createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension); + + ingestDataAndTrainModel( + modelId, + trainingIndexName, + trainingFieldName, + dimension, + modelDescription, + xContentBuilderToMap(getModelMethodBuilder()) + ); + assertTrainingSucceeds(modelId, NUM_OF_ATTEMPTS, DELAY_MILLI_SEC); + + String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, modelId); Request request = new Request("GET", restURI); Response response = client().performRequest(request); @@ -68,30 +75,30 @@ public void testGetModelExists() throws IOException { assertNotNull(responseBody); Map responseMap = createParser(XContentType.JSON.xContent(), responseBody).map(); - - assertEquals(testModelID, responseMap.get(MODEL_ID)); - assertEquals(testModelMetadata.getDescription(), responseMap.get(MODEL_DESCRIPTION)); - assertEquals(testModelMetadata.getDimension(), responseMap.get(DIMENSION)); - assertEquals(testModelMetadata.getError(), responseMap.get(MODEL_ERROR)); - assertEquals(testModelMetadata.getKnnEngine().getName(), responseMap.get(KNN_ENGINE)); - assertEquals(testModelMetadata.getSpaceType().getValue(), responseMap.get(METHOD_PARAMETER_SPACE_TYPE)); - assertEquals(testModelMetadata.getState().getName(), responseMap.get(MODEL_STATE)); - assertEquals(testModelMetadata.getTimestamp(), responseMap.get(MODEL_TIMESTAMP)); + assertEquals(modelId, responseMap.get(MODEL_ID)); + assertEquals(modelDescription, responseMap.get(MODEL_DESCRIPTION)); + assertEquals(FAISS.getName(), responseMap.get(KNN_ENGINE)); + assertEquals(L2.getValue(), responseMap.get(METHOD_PARAMETER_SPACE_TYPE)); } - public void testGetModelExistsWithFilter() throws IOException { + public void testGetModelExistsWithFilter() throws Exception { createModelSystemIndex(); - String testModelID = "test-model-id"; - byte[] testModelBlob = "hello".getBytes(); - ModelMetadata testModelMetadata = getModelMetadata(); - - addModelToSystemIndex(testModelID, testModelMetadata, testModelBlob); - - String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, testModelID); + String modelId = "test-model-id"; + String trainingIndexName = "train-index"; + String trainingFieldName = "train-field"; + int dimension = 8; + String modelDescription = "dummy description"; + + createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension); + Map method = xContentBuilderToMap(getModelMethodBuilder()); + ingestDataAndTrainModel(modelId, trainingIndexName, trainingFieldName, dimension, modelDescription, method); + assertTrainingSucceeds(modelId, NUM_OF_ATTEMPTS, DELAY_MILLI_SEC); + + String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, modelId); Request request = new Request("GET", restURI); - List filterdPath = Arrays.asList(MODEL_ID, MODEL_DESCRIPTION, MODEL_TIMESTAMP, KNN_ENGINE); - request.addParameter("filter_path", Strings.join(filterdPath, ",")); + List filteredPath = Arrays.asList(MODEL_ID, MODEL_DESCRIPTION, MODEL_TIMESTAMP, KNN_ENGINE); + request.addParameter("filter_path", Strings.join(filteredPath, ",")); Response response = client().performRequest(request); assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); @@ -101,11 +108,10 @@ public void testGetModelExistsWithFilter() throws IOException { Map responseMap = createParser(XContentType.JSON.xContent(), responseBody).map(); - assertTrue(responseMap.size() == filterdPath.size()); - assertEquals(testModelID, responseMap.get(MODEL_ID)); - assertEquals(testModelMetadata.getDescription(), responseMap.get(MODEL_DESCRIPTION)); - assertEquals(testModelMetadata.getTimestamp(), responseMap.get(MODEL_TIMESTAMP)); - assertEquals(testModelMetadata.getKnnEngine().getName(), responseMap.get(KNN_ENGINE)); + assertTrue(responseMap.size() == filteredPath.size()); + assertEquals(modelId, responseMap.get(MODEL_ID)); + assertEquals(modelDescription, responseMap.get(MODEL_DESCRIPTION)); + assertEquals(FAISS.getName(), responseMap.get(KNN_ENGINE)); assertFalse(responseMap.containsKey(DIMENSION)); assertFalse(responseMap.containsKey(MODEL_ERROR)); assertFalse(responseMap.containsKey(METHOD_PARAMETER_SPACE_TYPE)); diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java index a1756cbf1..6ec699d87 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java @@ -48,6 +48,7 @@ import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; +import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME; import static org.opensearch.knn.common.KNNConstants.NAME; import static org.opensearch.knn.common.KNNConstants.NMSLIB_NAME; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; @@ -341,20 +342,23 @@ public void testScriptStats_multipleShards() throws Exception { public void testModelIndexHealthMetricsStats() throws IOException { // Create request that filters only model index String modelIndexStatusName = StatNames.MODEL_INDEX_STATUS.getName(); + // index can be created in one of previous tests, and as we do not delete it each test the check below became optional + if (!systemIndexExists(MODEL_INDEX_NAME)) { - Response response = getKnnStats(Collections.emptyList(), Arrays.asList(modelIndexStatusName)); - String responseBody = EntityUtils.toString(response.getEntity()); - Map statsMap = createParser(XContentType.JSON.xContent(), responseBody).map(); + final Response response = getKnnStats(Collections.emptyList(), Arrays.asList(modelIndexStatusName)); + final String responseBody = EntityUtils.toString(response.getEntity()); + final Map statsMap = createParser(XContentType.JSON.xContent(), responseBody).map(); - // Check that model health status is null since model index is not created to system yet - assertNull(statsMap.get(StatNames.MODEL_INDEX_STATUS.getName())); + // Check that model health status is null since model index is not created to system yet + assertNull(statsMap.get(StatNames.MODEL_INDEX_STATUS.getName())); - createModelSystemIndex(); + createModelSystemIndex(); + } - response = getKnnStats(Collections.emptyList(), Arrays.asList(modelIndexStatusName)); + Response response = getKnnStats(Collections.emptyList(), Arrays.asList(modelIndexStatusName)); - responseBody = EntityUtils.toString(response.getEntity()); - statsMap = createParser(XContentType.JSON.xContent(), responseBody).map(); + final String responseBody = EntityUtils.toString(response.getEntity()); + final Map statsMap = createParser(XContentType.JSON.xContent(), responseBody).map(); // Check that model health status is not null assertNotNull(statsMap.get(modelIndexStatusName)); diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestLegacyKNNStatsHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestLegacyKNNStatsHandlerIT.java index a4243537d..0d900cfbe 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestLegacyKNNStatsHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestLegacyKNNStatsHandlerIT.java @@ -319,10 +319,15 @@ public void testScriptStats_multipleShards() throws Exception { // Useful settings when debugging to prevent timeouts @Override protected Settings restClientSettings() { + final Settings.Builder builder = Settings.builder(); if (isDebuggingTest || isDebuggingRemoteCluster) { - return Settings.builder().put(CLIENT_SOCKET_TIMEOUT, TimeValue.timeValueMinutes(10)).build(); + builder.put(CLIENT_SOCKET_TIMEOUT, TimeValue.timeValueMinutes(10)); } else { - return super.restClientSettings(); + if (System.getProperty("tests.rest.client_path_prefix") != null) { + builder.put(CLIENT_PATH_PREFIX, System.getProperty("tests.rest.client_path_prefix")); + } } + builder.put("strictDeprecationMode", false); + return builder.build(); } } diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestSearchModelHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestSearchModelHandlerIT.java index 609fe7f09..92834217e 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestSearchModelHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestSearchModelHandlerIT.java @@ -16,7 +16,6 @@ import org.opensearch.client.Request; import org.opensearch.client.Response; import org.opensearch.client.ResponseException; -import org.opensearch.common.settings.Settings; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.common.xcontent.XContentType; import org.opensearch.knn.KNNRestTestCase; @@ -39,6 +38,8 @@ import static org.opensearch.knn.common.KNNConstants.PARAM_SIZE; import static org.opensearch.knn.common.KNNConstants.SEARCH_MODEL_MAX_SIZE; import static org.opensearch.knn.common.KNNConstants.SEARCH_MODEL_MIN_SIZE; +import static org.opensearch.knn.index.SpaceType.L2; +import static org.opensearch.knn.index.util.KNNEngine.FAISS; /** * Integration tests to check the correctness of {@link org.opensearch.knn.plugin.rest.RestSearchModelHandler} @@ -96,15 +97,25 @@ public void testSizeValidationFailsInvalidSize() throws IOException { } - public void testSearchModelExists() throws IOException { + public void testSearchModelExists() throws Exception { createModelSystemIndex(); - createIndex("irrelevant-index", Settings.EMPTY); - addDocWithBinaryField("irrelevant-index", "id1", "field-name", "value"); + String trainingIndex = "irrelevant-index"; + String trainingFieldName = "train-field"; + int dimension = 8; + String modelDescription = "dummy description"; + createBasicKnnIndex(trainingIndex, trainingFieldName, dimension); + List testModelID = Arrays.asList("test-modelid1", "test-modelid2"); - byte[] testModelBlob = "hello".getBytes(); - ModelMetadata testModelMetadata = getModelMetadata(); - for (String modelID : testModelID) { - addModelToSystemIndex(modelID, testModelMetadata, testModelBlob); + for (String modelId : testModelID) { + ingestDataAndTrainModel( + modelId, + trainingIndex, + trainingFieldName, + dimension, + modelDescription, + xContentBuilderToMap(getModelMethodBuilder()) + ); + assertTrainingSucceeds(modelId, NUM_OF_ATTEMPTS, DELAY_MILLI_SEC); } String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "_search"); @@ -128,21 +139,25 @@ public void testSearchModelExists() throws IOException { for (SearchHit hit : searchResponse.getHits().getHits()) { assertTrue(testModelID.contains(hit.getId())); Model model = Model.getModelFromSourceMap(hit.getSourceAsMap()); - assertEquals(getModelMetadata(), model.getModelMetadata()); - assertArrayEquals(testModelBlob, model.getModelBlob()); + assertEquals(modelDescription, model.getModelMetadata().getDescription()); + assertEquals(FAISS, model.getModelMetadata().getKnnEngine()); + assertEquals(L2, model.getModelMetadata().getSpaceType()); } } } - public void testSearchModelWithoutSource() throws IOException { + public void testSearchModelWithoutSource() throws Exception { createModelSystemIndex(); - createIndex("irrelevant-index", Settings.EMPTY); - addDocWithBinaryField("irrelevant-index", "id1", "field-name", "value"); - List testModelID = Arrays.asList("test-modelid1", "test-modelid2"); - byte[] testModelBlob = "hello".getBytes(); - ModelMetadata testModelMetadata = getModelMetadata(); - for (String modelID : testModelID) { - addModelToSystemIndex(modelID, testModelMetadata, testModelBlob); + String trainingIndex = "irrelevant-index"; + String trainingFieldName = "train-field"; + int dimension = 8; + createBasicKnnIndex(trainingIndex, trainingFieldName, dimension); + + List testModelIds = Arrays.asList("test-modelid1", "test-modelid2"); + for (String modelId : testModelIds) { + String modelDescription = "dummy description"; + ingestDataAndTrainModel(modelId, trainingIndex, trainingFieldName, dimension, modelDescription); + assertTrainingSucceeds(modelId, NUM_OF_ATTEMPTS, DELAY_MILLI_SEC); } String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "_search"); @@ -163,24 +178,27 @@ public void testSearchModelWithoutSource() throws IOException { assertNotNull(searchResponse); // returns only model from ModelIndex - assertEquals(searchResponse.getHits().getHits().length, testModelID.size()); + assertEquals(searchResponse.getHits().getHits().length, testModelIds.size()); for (SearchHit hit : searchResponse.getHits().getHits()) { - assertTrue(testModelID.contains(hit.getId())); + assertTrue(testModelIds.contains(hit.getId())); assertNull(hit.getSourceAsMap()); } } } - public void testSearchModelWithSourceFilteringIncludes() throws IOException { + public void testSearchModelWithSourceFilteringIncludes() throws Exception { createModelSystemIndex(); - createIndex("irrelevant-index", Settings.EMPTY); - addDocWithBinaryField("irrelevant-index", "id1", "field-name", "value"); - List testModelID = Arrays.asList("test-modelid1", "test-modelid2"); - byte[] testModelBlob = "hello".getBytes(); - ModelMetadata testModelMetadata = getModelMetadata(); - for (String modelID : testModelID) { - addModelToSystemIndex(modelID, testModelMetadata, testModelBlob); + String trainingIndex = "irrelevant-index"; + String trainingFieldName = "train-field"; + int dimension = 8; + createBasicKnnIndex(trainingIndex, trainingFieldName, dimension); + + List testModelIds = Arrays.asList("test-modelid1", "test-modelid2"); + for (String modelId : testModelIds) { + String modelDescription = "dummy description"; + ingestDataAndTrainModel(modelId, trainingIndex, trainingFieldName, dimension, modelDescription); + assertTrainingSucceeds(modelId, NUM_OF_ATTEMPTS, DELAY_MILLI_SEC); } String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "_search"); @@ -208,10 +226,10 @@ public void testSearchModelWithSourceFilteringIncludes() throws IOException { assertNotNull(searchResponse); // returns only model from ModelIndex - assertEquals(searchResponse.getHits().getHits().length, testModelID.size()); + assertEquals(searchResponse.getHits().getHits().length, testModelIds.size()); for (SearchHit hit : searchResponse.getHits().getHits()) { - assertTrue(testModelID.contains(hit.getId())); + assertTrue(testModelIds.contains(hit.getId())); Map sourceAsMap = hit.getSourceAsMap(); assertFalse(sourceAsMap.containsKey("model_blob")); assertTrue(sourceAsMap.containsKey("state")); @@ -221,15 +239,18 @@ public void testSearchModelWithSourceFilteringIncludes() throws IOException { } } - public void testSearchModelWithSourceFilteringExcludes() throws IOException { + public void testSearchModelWithSourceFilteringExcludes() throws Exception { createModelSystemIndex(); - createIndex("irrelevant-index", Settings.EMPTY); - addDocWithBinaryField("irrelevant-index", "id1", "field-name", "value"); - List testModelID = Arrays.asList("test-modelid1", "test-modelid2"); - byte[] testModelBlob = "hello".getBytes(); - ModelMetadata testModelMetadata = getModelMetadata(); - for (String modelID : testModelID) { - addModelToSystemIndex(modelID, testModelMetadata, testModelBlob); + String trainingIndex = "irrelevant-index"; + String trainingFieldName = "train-field"; + int dimension = 8; + createBasicKnnIndex(trainingIndex, trainingFieldName, dimension); + + List testModelIds = Arrays.asList("test-modelid1", "test-modelid2"); + for (String modelId : testModelIds) { + String modelDescription = "dummy description"; + ingestDataAndTrainModel(modelId, trainingIndex, trainingFieldName, dimension, modelDescription); + assertTrainingSucceeds(modelId, NUM_OF_ATTEMPTS, DELAY_MILLI_SEC); } String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "_search"); @@ -257,10 +278,10 @@ public void testSearchModelWithSourceFilteringExcludes() throws IOException { assertNotNull(searchResponse); // returns only model from ModelIndex - assertEquals(searchResponse.getHits().getHits().length, testModelID.size()); + assertEquals(searchResponse.getHits().getHits().length, testModelIds.size()); for (SearchHit hit : searchResponse.getHits().getHits()) { - assertTrue(testModelID.contains(hit.getId())); + assertTrue(testModelIds.contains(hit.getId())); Map sourceAsMap = hit.getSourceAsMap(); assertFalse(sourceAsMap.containsKey("model_blob")); assertTrue(sourceAsMap.containsKey("state")); diff --git a/src/test/resources/security/sample.pem b/src/test/resources/security/sample.pem new file mode 100644 index 000000000..fa785ca10 --- /dev/null +++ b/src/test/resources/security/sample.pem @@ -0,0 +1,28 @@ +-----BEGIN CERTIFICATE----- +MIIEyTCCA7GgAwIBAgIGAWLrc1O2MA0GCSqGSIb3DQEBCwUAMIGPMRMwEQYKCZIm +iZPyLGQBGRYDY29tMRcwFQYKCZImiZPyLGQBGRYHZXhhbXBsZTEZMBcGA1UECgwQ +RXhhbXBsZSBDb20gSW5jLjEhMB8GA1UECwwYRXhhbXBsZSBDb20gSW5jLiBSb290 +IENBMSEwHwYDVQQDDBhFeGFtcGxlIENvbSBJbmMuIFJvb3QgQ0EwHhcNMTgwNDIy +MDM0MzQ3WhcNMjgwNDE5MDM0MzQ3WjBeMRIwEAYKCZImiZPyLGQBGRYCZGUxDTAL +BgNVBAcMBHRlc3QxDTALBgNVBAoMBG5vZGUxDTALBgNVBAsMBG5vZGUxGzAZBgNV +BAMMEm5vZGUtMC5leGFtcGxlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCC +AQoCggEBAJa+f476vLB+AwK53biYByUwN+40D8jMIovGXm6wgT8+9Sbs899dDXgt +9CE1Beo65oP1+JUz4c7UHMrCY3ePiDt4cidHVzEQ2g0YoVrQWv0RedS/yx/DKhs8 +Pw1O715oftP53p/2ijD5DifFv1eKfkhFH+lwny/vMSNxellpl6NxJTiJVnQ9HYOL +gf2t971ITJHnAuuxUF48HcuNovW4rhtkXef8kaAN7cE3LU+A9T474ULNCKkEFPIl +ZAKN3iJNFdVsxrTU+CUBHzk73Do1cCkEvJZ0ZFjp0Z3y8wLY/gqWGfGVyA9l2CUq +eIZNf55PNPtGzOrvvONiui48vBKH1LsCAwEAAaOCAVkwggFVMIG8BgNVHSMEgbQw +gbGAFJI1DOAPHitF9k0583tfouYSl0BzoYGVpIGSMIGPMRMwEQYKCZImiZPyLGQB +GRYDY29tMRcwFQYKCZImiZPyLGQBGRYHZXhhbXBsZTEZMBcGA1UECgwQRXhhbXBs +ZSBDb20gSW5jLjEhMB8GA1UECwwYRXhhbXBsZSBDb20gSW5jLiBSb290IENBMSEw +HwYDVQQDDBhFeGFtcGxlIENvbSBJbmMuIFJvb3QgQ0GCAQEwHQYDVR0OBBYEFKyv +78ZmFjVKM9g7pMConYH7FVBHMAwGA1UdEwEB/wQCMAAwDgYDVR0PAQH/BAQDAgXg +MCAGA1UdJQEB/wQWMBQGCCsGAQUFBwMBBggrBgEFBQcDAjA1BgNVHREELjAsiAUq +AwQFBYISbm9kZS0wLmV4YW1wbGUuY29tgglsb2NhbGhvc3SHBH8AAAEwDQYJKoZI +hvcNAQELBQADggEBAIOKuyXsFfGv1hI/Lkpd/73QNqjqJdxQclX57GOMWNbOM5H0 +5/9AOIZ5JQsWULNKN77aHjLRr4owq2jGbpc/Z6kAd+eiatkcpnbtbGrhKpOtoEZy +8KuslwkeixpzLDNISSbkeLpXz4xJI1ETMN/VG8ZZP1bjzlHziHHDu0JNZ6TnNzKr +XzCGMCohFfem8vnKNnKUneMQMvXd3rzUaAgvtf7Hc2LTBlf4fZzZF1EkwdSXhaMA +1lkfHiqOBxtgeDLxCHESZ2fqgVqsWX+t3qHQfivcPW6txtDyrFPRdJOGhiMGzT/t +e/9kkAtQRgpTb3skYdIOOUOV0WGQ60kJlFhAzIs= +-----END CERTIFICATE----- \ No newline at end of file diff --git a/src/test/resources/security/test-kirk.jks b/src/test/resources/security/test-kirk.jks new file mode 100644 index 000000000..174dbda65 Binary files /dev/null and b/src/test/resources/security/test-kirk.jks differ diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 6ac7e63a1..89bc030eb 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -63,8 +63,12 @@ import java.util.stream.Collectors; import static org.opensearch.knn.common.KNNConstants.DIMENSION; +import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE; +import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; +import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; import static org.opensearch.knn.common.KNNConstants.MODEL_BLOB_PARAMETER; import static org.opensearch.knn.common.KNNConstants.MODEL_DESCRIPTION; @@ -92,7 +96,9 @@ import static org.opensearch.knn.TestUtils.QUERY_VALUE; import static org.opensearch.knn.TestUtils.computeGroundTruthValues; +import static org.opensearch.knn.index.SpaceType.L2; import static org.opensearch.knn.index.memory.NativeMemoryCacheManager.GRAPH_COUNT; +import static org.opensearch.knn.index.util.KNNEngine.FAISS; import static org.opensearch.knn.plugin.stats.StatNames.INDICES_IN_CACHE; /** @@ -103,6 +109,8 @@ public class KNNRestTestCase extends ODFERestTestCase { public static final String FIELD_NAME = "test_field"; private static final String DOCUMENT_FIELD_SOURCE = "_source"; private static final String DOCUMENT_FIELD_FOUND = "found"; + protected static final int DELAY_MILLI_SEC = 1000; + protected static final int NUM_OF_ATTEMPTS = 30; @AfterClass public static void dumpCoverage() throws IOException, MalformedObjectNameException { @@ -638,7 +646,9 @@ protected void createModelSystemIndex() throws IOException { String mapping = Resources.toString(url, Charsets.UTF_8); mapping = mapping.substring(1, mapping.length() - 1); - createIndex(MODEL_INDEX_NAME, Settings.builder().put("number_of_shards", 1).put("number_of_replicas", 0).build(), mapping); + if (!systemIndexExists(MODEL_INDEX_NAME)) { + createIndex(MODEL_INDEX_NAME, Settings.builder().put("number_of_shards", 1).put("number_of_replicas", 0).build(), mapping); + } } protected void addModelToSystemIndex(String modelId, ModelMetadata modelMetadata, byte[] model) throws IOException { @@ -1164,6 +1174,83 @@ public void assertTrainingFails(String modelId, int attempts, int delayInMillis) fail("Training did not succeed after " + attempts + " attempts with a delay of " + delayInMillis + " ms."); } + protected boolean systemIndexExists(final String indexName) throws IOException { + Response response = adminClient().performRequest(new Request("HEAD", "/" + indexName)); + return RestStatus.OK.getStatus() == response.getStatusLine().getStatusCode(); + } + + protected Settings.Builder noStrictDeprecationModeSettingsBuilder() { + Settings.Builder builder = Settings.builder().put("strictDeprecationMode", false); + if (System.getProperty("tests.rest.client_path_prefix") != null) { + builder.put(CLIENT_PATH_PREFIX, System.getProperty("tests.rest.client_path_prefix")); + } + return builder; + } + + protected void ingestDataAndTrainModel( + String modelId, + String trainingIndexName, + String trainingFieldName, + int dimension, + String modelDescription + ) throws Exception { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, "ivf") + .field(KNN_ENGINE, "faiss") + .field(METHOD_PARAMETER_SPACE_TYPE, "l2") + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_NLIST, 1) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, "pq") + .startObject(PARAMETERS) + .field(ENCODER_PARAMETER_PQ_CODE_SIZE, 2) + .field(ENCODER_PARAMETER_PQ_M, 2) + .endObject() + .endObject() + .endObject() + .endObject(); + + Map method = xContentBuilderToMap(builder); + ingestDataAndTrainModel(modelId, trainingIndexName, trainingFieldName, dimension, modelDescription, method); + } + + protected void ingestDataAndTrainModel( + String modelId, + String trainingIndexName, + String trainingFieldName, + int dimension, + String modelDescription, + Map method + ) throws Exception { + int trainingDataCount = 40; + bulkIngestRandomVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension); + + Response trainResponse = trainModel(modelId, trainingIndexName, trainingFieldName, dimension, method, modelDescription); + + assertEquals(RestStatus.OK, RestStatus.fromCode(trainResponse.getStatusLine().getStatusCode())); + } + + protected XContentBuilder getModelMethodBuilder() throws IOException { + XContentBuilder modelMethodBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, "ivf") + .field(KNN_ENGINE, FAISS.getName()) + .field(METHOD_PARAMETER_SPACE_TYPE, L2.getValue()) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_NLIST, 1) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, "pq") + .startObject(PARAMETERS) + .field(ENCODER_PARAMETER_PQ_CODE_SIZE, 2) + .field(ENCODER_PARAMETER_PQ_M, 2) + .endObject() + .endObject() + .endObject() + .endObject(); + return modelMethodBuilder; + } + /** * We need to be able to dump the jacoco coverage before cluster is shut down. * The new internal testing framework removed some of the gradle tasks we were listening to diff --git a/src/testFixtures/java/org/opensearch/knn/ODFERestTestCase.java b/src/testFixtures/java/org/opensearch/knn/ODFERestTestCase.java index 5f174b964..097fe014d 100644 --- a/src/testFixtures/java/org/opensearch/knn/ODFERestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/ODFERestTestCase.java @@ -5,13 +5,6 @@ package org.opensearch.knn; -import java.io.IOException; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.stream.Collectors; - import org.apache.http.Header; import org.apache.http.HttpHost; import org.apache.http.auth.AuthScope; @@ -21,23 +14,54 @@ import org.apache.http.impl.client.BasicCredentialsProvider; import org.apache.http.message.BasicHeader; import org.apache.http.ssl.SSLContextBuilder; +import org.apache.http.util.EntityUtils; +import org.junit.After; +import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Request; import org.opensearch.client.Response; import org.opensearch.client.RestClient; import org.opensearch.client.RestClientBuilder; +import org.opensearch.common.Strings; +import org.opensearch.common.io.PathUtils; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.xcontent.DeprecationHandler; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.commons.rest.SecureRestClientBuilder; +import org.opensearch.knn.plugin.KNNPlugin; +import org.opensearch.rest.RestStatus; +import org.opensearch.search.SearchHit; import org.opensearch.test.rest.OpenSearchRestTestCase; -import org.junit.After; +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_ENABLED; +import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_FILEPATH; +import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_KEYPASSWORD; +import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_PASSWORD; +import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_PEMCERT_FILEPATH; import static org.opensearch.knn.TestUtils.KNN_BWC_PREFIX; import static org.opensearch.knn.TestUtils.OPENDISTRO_SECURITY; +import static org.opensearch.knn.TestUtils.OPENSEARCH_SYSTEM_INDEX_PREFIX; +import static org.opensearch.knn.TestUtils.SECURITY_AUDITLOG_PREFIX; import static org.opensearch.knn.TestUtils.SKIP_DELETE_MODEL_INDEX; +import static org.opensearch.knn.common.KNNConstants.MODELS; import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME; /** @@ -45,6 +69,8 @@ */ public abstract class ODFERestTestCase extends OpenSearchRestTestCase { + private final Set IMMUTABLE_INDEX_PREFIXES = Set.of(KNN_BWC_PREFIX, SECURITY_AUDITLOG_PREFIX, OPENSEARCH_SYSTEM_INDEX_PREFIX); + protected boolean isHttps() { boolean isHttps = Optional.ofNullable(System.getProperty("https")).map("true"::equalsIgnoreCase).orElse(false); if (isHttps) { @@ -66,7 +92,22 @@ protected String getProtocol() { protected RestClient buildClient(Settings settings, HttpHost[] hosts) throws IOException { RestClientBuilder builder = RestClient.builder(hosts); if (isHttps()) { - configureHttpsClient(builder, settings); + String keystore = settings.get(OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_FILEPATH); + if (Objects.nonNull(keystore)) { + URI uri; + try { + uri = this.getClass().getClassLoader().getResource("security/sample.pem").toURI(); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + Path configPath = PathUtils.get(uri).getParent().toAbsolutePath(); + return new SecureRestClientBuilder(settings, configPath).build(); + } else { + configureHttpsClient(builder, settings); + boolean strictDeprecationMode = settings.getAsBoolean("strictDeprecationMode", true); + builder.setStrictDeprecationMode(strictDeprecationMode); + return builder.build(); + } } else { configureClient(builder, settings); } @@ -120,8 +161,8 @@ protected boolean preserveIndicesUponCompletion() { @SuppressWarnings("unchecked") @After - protected void wipeAllODFEIndices() throws IOException { - Response response = client().performRequest(new Request("GET", "/_cat/indices?format=json&expand_wildcards=all")); + protected void wipeAllODFEIndices() throws Exception { + Response response = adminClient().performRequest(new Request("GET", "/_cat/indices?format=json&expand_wildcards=all")); XContentType xContentType = XContentType.fromMediaType(response.getEntity().getContentType().getValue()); try ( XContentParser parser = xContentType.xContent() @@ -140,7 +181,11 @@ protected void wipeAllODFEIndices() throws IOException { } for (Map index : parserList) { - String indexName = (String) index.get("index"); + final String indexName = (String) index.get("index"); + if (isIndexCleanupRequired(indexName)) { + wipeIndexContent(indexName); + continue; + } if (!skipDeleteIndex(indexName)) { adminClient().performRequest(new Request("DELETE", "/" + indexName)); } @@ -148,6 +193,57 @@ protected void wipeAllODFEIndices() throws IOException { } } + private boolean isIndexCleanupRequired(final String index) { + return MODEL_INDEX_NAME.equals(index) && !getSkipDeleteModelIndexFlag(); + } + + private void wipeIndexContent(String indexName) throws IOException { + deleteModels(getModelIds()); + deleteAllDocs(indexName); + } + + private List getModelIds() throws IOException { + final String restURIGetModels = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "_search"); + final Response response = adminClient().performRequest(new Request("GET", restURIGetModels)); + + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + final String responseBody = EntityUtils.toString(response.getEntity()); + assertNotNull(responseBody); + + final XContentParser parser = createParser(XContentType.JSON.xContent(), responseBody); + final SearchResponse searchResponse = SearchResponse.fromXContent(parser); + + return Arrays.stream(searchResponse.getHits().getHits()).map(SearchHit::getId).collect(Collectors.toList()); + } + + private void deleteModels(final List modelIds) throws IOException { + for (final String testModelID : modelIds) { + final String restURIGetModel = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, testModelID); + final Response getModelResponse = adminClient().performRequest(new Request("GET", restURIGetModel)); + if (RestStatus.OK != RestStatus.fromCode(getModelResponse.getStatusLine().getStatusCode())) { + continue; + } + final String restURIDeleteModel = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, testModelID); + adminClient().performRequest(new Request("DELETE", restURIDeleteModel)); + } + } + + private void deleteAllDocs(final String indexName) throws IOException { + final String restURIDeleteByQuery = String.join("/", indexName, "_delete_by_query"); + final Request request = new Request("POST", restURIDeleteByQuery); + final XContentBuilder matchAllDocsQuery = XContentFactory.jsonBuilder() + .startObject() + .startObject("query") + .startObject("match_all") + .endObject() + .endObject() + .endObject(); + + request.setJsonEntity(Strings.toString(matchAllDocsQuery)); + adminClient().performRequest(request); + } + private boolean getSkipDeleteModelIndexFlag() { return Boolean.parseBoolean(System.getProperty(SKIP_DELETE_MODEL_INDEX, "false")); } @@ -159,11 +255,25 @@ private boolean skipDeleteModelIndex(String indexName) { private boolean skipDeleteIndex(String indexName) { if (indexName != null && !OPENDISTRO_SECURITY.equals(indexName) - && !indexName.startsWith(KNN_BWC_PREFIX) + && IMMUTABLE_INDEX_PREFIXES.stream().noneMatch(indexName::startsWith) && !skipDeleteModelIndex(indexName)) { return false; } return true; } + + @Override + protected Settings restAdminSettings() { + return Settings.builder() + // disable the warning exception for admin client since it's only used for cleanup. + .put("strictDeprecationMode", false) + .put("http.port", 9200) + .put(OPENSEARCH_SECURITY_SSL_HTTP_ENABLED, isHttps()) + .put(OPENSEARCH_SECURITY_SSL_HTTP_PEMCERT_FILEPATH, "sample.pem") + .put(OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_FILEPATH, "test-kirk.jks") + .put(OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_PASSWORD, "changeit") + .put(OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_KEYPASSWORD, "changeit") + .build(); + } } diff --git a/src/testFixtures/java/org/opensearch/knn/TestUtils.java b/src/testFixtures/java/org/opensearch/knn/TestUtils.java index f179eef36..0843176e7 100644 --- a/src/testFixtures/java/org/opensearch/knn/TestUtils.java +++ b/src/testFixtures/java/org/opensearch/knn/TestUtils.java @@ -95,6 +95,8 @@ public class TestUtils { public static final String ROLLING_UPGRADE_FIRST_ROUND = "tests.rest.first_round"; public static final String SKIP_DELETE_MODEL_INDEX = "tests.skip_delete_model_index"; public static final String UPGRADED_CLUSTER = "upgraded_cluster"; + public static final String SECURITY_AUDITLOG_PREFIX = "security-auditlog"; + public static final String OPENSEARCH_SYSTEM_INDEX_PREFIX = ".opensearch"; // Generating vectors using random function with a seed which makes these vectors standard and generate same vectors for each run. public static float[][] randomlyGenerateStandardVectors(int numVectors, int dimensions, int seed) {