Skip to content

Commit

Permalink
Support .opensearch-knn-model index as system index with security ena…
Browse files Browse the repository at this point in the history
…bled (opensearch-project#827)

* Add support for integ tests on secured cluster

Signed-off-by: Martin Gaievski <gaievski@amazon.com>
(cherry picked from commit b94b030)
  • Loading branch information
martin-gaievski committed Apr 5, 2023
1 parent 66664f9 commit 80d2061
Show file tree
Hide file tree
Showing 13 changed files with 516 additions and 195 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Add 2.6.0 to BWC Version Matrix ([#810](https://github.com/opensearch-project/k-NN/pull/810))
* Update BWC Version with OpenSearch Version Bump ([#813](https://github.com/opensearch-project/k-NN/pull/813))
* Bump numpy version from 1.22.x to 1.24.2 ([#811](https://github.com/opensearch-project/k-NN/pull/811))
* Support .opensearch-knn-model index as system index with security enabled ([#827](https://github.com/opensearch-project/k-NN/pull/827))
### Documentation
### Maintenance
### Refactoring
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,6 @@ public static void wipeAllModels() throws IOException {
deleteKNNModel(TEST_MODEL_ID);
deleteKNNModel(TEST_MODEL_ID_DEFAULT);
deleteKNNModel(TEST_MODEL_ID_TRAINING);

Request request = new Request("DELETE", "/" + MODEL_INDEX_NAME);

Response response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
}
}

Expand Down
147 changes: 99 additions & 48 deletions src/main/java/org/opensearch/knn/indices/ModelDao.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import com.google.common.base.Charsets;
import com.google.common.io.Resources;
import lombok.SneakyThrows;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchException;
Expand Down Expand Up @@ -42,17 +43,18 @@
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.plugin.transport.DeleteModelResponse;
import org.opensearch.knn.plugin.transport.GetModelResponse;
import org.opensearch.knn.plugin.transport.RemoveModelFromCacheAction;
import org.opensearch.knn.plugin.transport.RemoveModelFromCacheRequest;
import org.opensearch.knn.plugin.transport.RemoveModelFromCacheResponse;
import org.opensearch.knn.plugin.transport.UpdateModelMetadataAction;
import org.opensearch.knn.plugin.transport.UpdateModelMetadataRequest;
import org.opensearch.knn.plugin.transport.UpdateModelGraveyardAction;
import org.opensearch.knn.plugin.transport.UpdateModelGraveyardRequest;
import org.opensearch.knn.plugin.transport.UpdateModelMetadataAction;
import org.opensearch.knn.plugin.transport.UpdateModelMetadataRequest;

import java.io.IOException;
import java.net.URL;
Expand All @@ -62,6 +64,7 @@
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.function.Supplier;

import static java.util.Objects.isNull;
import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_MAPPING_PATH;
Expand Down Expand Up @@ -216,14 +219,21 @@ public void create(ActionListener<CreateIndexResponse> actionListener) throws IO
if (isCreated()) {
return;
}
CreateIndexRequest request = new CreateIndexRequest(MODEL_INDEX_NAME).mapping(getMapping())
.settings(
Settings.builder()
.put("index.hidden", true)
.put("index.number_of_shards", this.numberOfShards)
.put("index.number_of_replicas", this.numberOfReplicas)
);
client.admin().indices().create(request, actionListener);
runWithStashedThreadContext(() -> {
CreateIndexRequest request;
try {
request = new CreateIndexRequest(MODEL_INDEX_NAME).mapping(getMapping())
.settings(
Settings.builder()
.put("index.hidden", true)
.put("index.number_of_shards", this.numberOfShards)
.put("index.number_of_replicas", this.numberOfReplicas)
);
} catch (IOException e) {
throw new RuntimeException(e);
}
client.admin().indices().create(request, actionListener);
});
}

@Override
Expand Down Expand Up @@ -293,8 +303,9 @@ private void putInternal(Model model, ActionListener<IndexResponse> listener, Do
parameters.put(KNNConstants.MODEL_BLOB_PARAMETER, base64Model);
}

IndexRequestBuilder indexRequestBuilder = client.prepareIndex(MODEL_INDEX_NAME);

final IndexRequestBuilder indexRequestBuilder = ModelDao.runWithStashedThreadContext(
() -> client.prepareIndex(MODEL_INDEX_NAME)
);
indexRequestBuilder.setId(model.getModelID());
indexRequestBuilder.setSource(parameters);

Expand All @@ -304,8 +315,8 @@ private void putInternal(Model model, ActionListener<IndexResponse> listener, Do
// After metadata update finishes, remove item from every node's cache if necessary. If no model id is
// passed then nothing needs to be removed from the cache
ActionListener<IndexResponse> onMetaListener;
onMetaListener = ActionListener.wrap(
indexResponse -> client.execute(
onMetaListener = ActionListener.wrap(indexResponse -> {
client.execute(
RemoveModelFromCacheAction.INSTANCE,
new RemoveModelFromCacheRequest(model.getModelID()),
ActionListener.wrap(removeModelFromCacheResponse -> {
Expand All @@ -318,9 +329,8 @@ private void putInternal(Model model, ActionListener<IndexResponse> listener, Do

listener.onFailure(new RuntimeException(failureMessage));
}, listener::onFailure)
),
listener::onFailure
);
);
}, listener::onFailure);

// After the model is indexed, update metadata only if the model is in CREATED state
ActionListener<IndexResponse> onIndexListener;
Expand Down Expand Up @@ -357,16 +367,30 @@ private ActionListener<IndexResponse> getUpdateModelMetadataListener(
);
}

@SneakyThrows
@Override
public Model get(String modelId) throws ExecutionException, InterruptedException {
public Model get(String modelId) {
/*
GET /<model_index>/<modelId>?_local
*/
GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId)
.setPreference("_local");
GetResponse getResponse = getRequestBuilder.execute().get();
Map<String, Object> responseMap = getResponse.getSourceAsMap();
return Model.getModelFromSourceMap(responseMap);
try {
return ModelDao.runWithStashedThreadContext(() -> {
GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId)
.setPreference("_local");
GetResponse getResponse;
try {
getResponse = getRequestBuilder.execute().get();
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
Map<String, Object> responseMap = getResponse.getSourceAsMap();
return Model.getModelFromSourceMap(responseMap);
});
} catch (RuntimeException runtimeException) {
// we need to use RuntimeException as container for real exception to keep signature
// of runWithStashedThreadContext generic
throw runtimeException.getCause();
}
}

/**
Expand All @@ -380,20 +404,22 @@ public void get(String modelId, ActionListener<GetModelResponse> actionListener)
/*
GET /<model_index>/<modelId>?_local
*/
GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId)
.setPreference("_local");

getRequestBuilder.execute(ActionListener.wrap(response -> {
if (response.isSourceEmpty()) {
String errorMessage = String.format("Model \" %s \" does not exist", modelId);
actionListener.onFailure(new ResourceNotFoundException(modelId, errorMessage));
return;
}
final Map<String, Object> responseMap = response.getSourceAsMap();
Model model = Model.getModelFromSourceMap(responseMap);
actionListener.onResponse(new GetModelResponse(model));
ModelDao.runWithStashedThreadContext(() -> {
GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId)
.setPreference("_local");

getRequestBuilder.execute(ActionListener.wrap(response -> {
if (response.isSourceEmpty()) {
String errorMessage = String.format("Model \" %s \" does not exist", modelId);
actionListener.onFailure(new ResourceNotFoundException(modelId, errorMessage));
return;
}
final Map<String, Object> responseMap = response.getSourceAsMap();
Model model = Model.getModelFromSourceMap(responseMap);
actionListener.onResponse(new GetModelResponse(model));

}, actionListener::onFailure));
}, actionListener::onFailure));
});
}

/**
Expand All @@ -404,8 +430,10 @@ public void get(String modelId, ActionListener<GetModelResponse> actionListener)
*/
@Override
public void search(SearchRequest request, ActionListener<SearchResponse> actionListener) {
request.indices(MODEL_INDEX_NAME);
client.search(request, actionListener);
ModelDao.runWithStashedThreadContext(() -> {
request.indices(MODEL_INDEX_NAME);
client.search(request, actionListener);
});
}

@Override
Expand Down Expand Up @@ -505,16 +533,17 @@ public void delete(String modelId, ActionListener<DeleteModelResponse> listener)
);

// Setup delete model request
DeleteRequestBuilder deleteRequestBuilder = new DeleteRequestBuilder(client, DeleteAction.INSTANCE, MODEL_INDEX_NAME);
deleteRequestBuilder.setId(modelId);
deleteRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

// On model metadata removal, delete the model from the index
clearModelMetadataStep.whenComplete(
acknowledgedResponse -> deleteModelFromIndex(modelId, deleteModelFromIndexStep, deleteRequestBuilder),
listener::onFailure
);

ModelDao.runWithStashedThreadContext(() -> {
DeleteRequestBuilder deleteRequestBuilder = new DeleteRequestBuilder(client, DeleteAction.INSTANCE, MODEL_INDEX_NAME);
deleteRequestBuilder.setId(modelId);
deleteRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

// On model metadata removal, delete the model from the index
clearModelMetadataStep.whenComplete(
acknowledgedResponse -> deleteModelFromIndex(modelId, deleteModelFromIndexStep, deleteRequestBuilder),
listener::onFailure
);
});
deleteModelFromIndexStep.whenComplete(deleteResponse -> {
// If model is not deleted, remove modelId from model graveyard and return with error message
if (deleteResponse.getResult() != DocWriteResponse.Result.DELETED) {
Expand Down Expand Up @@ -653,4 +682,26 @@ private String buildRemoveModelErrorMessage(String modelId, RemoveModelFromCache
return stringBuilder.toString();
}
}

/**
* Set the thread context to default, this is needed to allow actions on model system index
* when security plugin is enabled
* @param function runnable that needs to be executed after thread context has been stashed, accepts and returns nothing
*/
private static void runWithStashedThreadContext(Runnable function) {
try (ThreadContext.StoredContext context = OpenSearchKNNModelDao.client.threadPool().getThreadContext().stashContext()) {
function.run();
}
}

/**
* Set the thread context to default, this is needed to allow actions on model system index
* when security plugin is enabled
* @param function supplier function that needs to be executed after thread context has been stashed, return object
*/
private static <T> T runWithStashedThreadContext(Supplier<T> function) {
try (ThreadContext.StoredContext context = OpenSearchKNNModelDao.client.threadPool().getThreadContext().stashContext()) {
return function.get();
}
}
}
Loading

0 comments on commit 80d2061

Please sign in to comment.