Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Support .opensearch-knn-model index as system index with security enabled (#827) #842

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Add 2.6.0 to BWC Version Matrix ([#810](https://github.com/opensearch-project/k-NN/pull/810))
* Update BWC Version with OpenSearch Version Bump ([#813](https://github.com/opensearch-project/k-NN/pull/813))
* Bump numpy version from 1.22.x to 1.24.2 ([#811](https://github.com/opensearch-project/k-NN/pull/811))
* Support .opensearch-knn-model index as system index with security enabled ([#827](https://github.com/opensearch-project/k-NN/pull/827))
### Documentation
### Maintenance
### Refactoring
1 change: 1 addition & 0 deletions build.gradle
Original file line number Diff line number Diff line change
@@ -178,6 +178,7 @@ dependencies {
testImplementation group: 'net.bytebuddy', name: 'byte-buddy', version: '1.14.3'
testImplementation group: 'org.objenesis', name: 'objenesis', version: '3.2'
testImplementation group: 'net.bytebuddy', name: 'byte-buddy-agent', version: '1.14.3'
api "org.opensearch:common-utils:${version}"
}


Original file line number Diff line number Diff line change
@@ -182,11 +182,6 @@ public static void wipeAllModels() throws IOException {
deleteKNNModel(TEST_MODEL_ID);
deleteKNNModel(TEST_MODEL_ID_DEFAULT);
deleteKNNModel(TEST_MODEL_ID_TRAINING);

Request request = new Request("DELETE", "/" + MODEL_INDEX_NAME);

Response response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
}
}

147 changes: 99 additions & 48 deletions src/main/java/org/opensearch/knn/indices/ModelDao.java
Original file line number Diff line number Diff line change
@@ -13,6 +13,7 @@

import com.google.common.base.Charsets;
import com.google.common.io.Resources;
import lombok.SneakyThrows;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchException;
@@ -42,17 +43,18 @@
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.plugin.transport.DeleteModelResponse;
import org.opensearch.knn.plugin.transport.GetModelResponse;
import org.opensearch.knn.plugin.transport.RemoveModelFromCacheAction;
import org.opensearch.knn.plugin.transport.RemoveModelFromCacheRequest;
import org.opensearch.knn.plugin.transport.RemoveModelFromCacheResponse;
import org.opensearch.knn.plugin.transport.UpdateModelMetadataAction;
import org.opensearch.knn.plugin.transport.UpdateModelMetadataRequest;
import org.opensearch.knn.plugin.transport.UpdateModelGraveyardAction;
import org.opensearch.knn.plugin.transport.UpdateModelGraveyardRequest;
import org.opensearch.knn.plugin.transport.UpdateModelMetadataAction;
import org.opensearch.knn.plugin.transport.UpdateModelMetadataRequest;

import java.io.IOException;
import java.net.URL;
@@ -62,6 +64,7 @@
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.function.Supplier;

import static java.util.Objects.isNull;
import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_MAPPING_PATH;
@@ -216,14 +219,21 @@ public void create(ActionListener<CreateIndexResponse> actionListener) throws IO
if (isCreated()) {
return;
}
CreateIndexRequest request = new CreateIndexRequest(MODEL_INDEX_NAME).mapping(getMapping())
.settings(
Settings.builder()
.put("index.hidden", true)
.put("index.number_of_shards", this.numberOfShards)
.put("index.number_of_replicas", this.numberOfReplicas)
);
client.admin().indices().create(request, actionListener);
runWithStashedThreadContext(() -> {
CreateIndexRequest request;
try {
request = new CreateIndexRequest(MODEL_INDEX_NAME).mapping(getMapping())
.settings(
Settings.builder()
.put("index.hidden", true)
.put("index.number_of_shards", this.numberOfShards)
.put("index.number_of_replicas", this.numberOfReplicas)
);
} catch (IOException e) {
throw new RuntimeException(e);
}
client.admin().indices().create(request, actionListener);
});
}

@Override
@@ -293,8 +303,9 @@ private void putInternal(Model model, ActionListener<IndexResponse> listener, Do
parameters.put(KNNConstants.MODEL_BLOB_PARAMETER, base64Model);
}

IndexRequestBuilder indexRequestBuilder = client.prepareIndex(MODEL_INDEX_NAME);

final IndexRequestBuilder indexRequestBuilder = ModelDao.runWithStashedThreadContext(
() -> client.prepareIndex(MODEL_INDEX_NAME)
);
indexRequestBuilder.setId(model.getModelID());
indexRequestBuilder.setSource(parameters);

@@ -304,8 +315,8 @@ private void putInternal(Model model, ActionListener<IndexResponse> listener, Do
// After metadata update finishes, remove item from every node's cache if necessary. If no model id is
// passed then nothing needs to be removed from the cache
ActionListener<IndexResponse> onMetaListener;
onMetaListener = ActionListener.wrap(
indexResponse -> client.execute(
onMetaListener = ActionListener.wrap(indexResponse -> {
client.execute(
RemoveModelFromCacheAction.INSTANCE,
new RemoveModelFromCacheRequest(model.getModelID()),
ActionListener.wrap(removeModelFromCacheResponse -> {
@@ -318,9 +329,8 @@ private void putInternal(Model model, ActionListener<IndexResponse> listener, Do

listener.onFailure(new RuntimeException(failureMessage));
}, listener::onFailure)
),
listener::onFailure
);
);
}, listener::onFailure);

// After the model is indexed, update metadata only if the model is in CREATED state
ActionListener<IndexResponse> onIndexListener;
@@ -357,16 +367,30 @@ private ActionListener<IndexResponse> getUpdateModelMetadataListener(
);
}

@SneakyThrows
@Override
public Model get(String modelId) throws ExecutionException, InterruptedException {
public Model get(String modelId) {
/*
GET /<model_index>/<modelId>?_local
*/
GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId)
.setPreference("_local");
GetResponse getResponse = getRequestBuilder.execute().get();
Map<String, Object> responseMap = getResponse.getSourceAsMap();
return Model.getModelFromSourceMap(responseMap);
try {
return ModelDao.runWithStashedThreadContext(() -> {
GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId)
.setPreference("_local");
GetResponse getResponse;
try {
getResponse = getRequestBuilder.execute().get();
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
Map<String, Object> responseMap = getResponse.getSourceAsMap();
return Model.getModelFromSourceMap(responseMap);
});
} catch (RuntimeException runtimeException) {
// we need to use RuntimeException as container for real exception to keep signature
// of runWithStashedThreadContext generic
throw runtimeException.getCause();
}
}

/**
@@ -380,20 +404,22 @@ public void get(String modelId, ActionListener<GetModelResponse> actionListener)
/*
GET /<model_index>/<modelId>?_local
*/
GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId)
.setPreference("_local");

getRequestBuilder.execute(ActionListener.wrap(response -> {
if (response.isSourceEmpty()) {
String errorMessage = String.format("Model \" %s \" does not exist", modelId);
actionListener.onFailure(new ResourceNotFoundException(modelId, errorMessage));
return;
}
final Map<String, Object> responseMap = response.getSourceAsMap();
Model model = Model.getModelFromSourceMap(responseMap);
actionListener.onResponse(new GetModelResponse(model));
ModelDao.runWithStashedThreadContext(() -> {
GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId)
.setPreference("_local");

getRequestBuilder.execute(ActionListener.wrap(response -> {
if (response.isSourceEmpty()) {
String errorMessage = String.format("Model \" %s \" does not exist", modelId);
actionListener.onFailure(new ResourceNotFoundException(modelId, errorMessage));
return;
}
final Map<String, Object> responseMap = response.getSourceAsMap();
Model model = Model.getModelFromSourceMap(responseMap);
actionListener.onResponse(new GetModelResponse(model));

}, actionListener::onFailure));
}, actionListener::onFailure));
});
}

/**
@@ -404,8 +430,10 @@ public void get(String modelId, ActionListener<GetModelResponse> actionListener)
*/
@Override
public void search(SearchRequest request, ActionListener<SearchResponse> actionListener) {
request.indices(MODEL_INDEX_NAME);
client.search(request, actionListener);
ModelDao.runWithStashedThreadContext(() -> {
request.indices(MODEL_INDEX_NAME);
client.search(request, actionListener);
});
}

@Override
@@ -505,16 +533,17 @@ public void delete(String modelId, ActionListener<DeleteModelResponse> listener)
);

// Setup delete model request
DeleteRequestBuilder deleteRequestBuilder = new DeleteRequestBuilder(client, DeleteAction.INSTANCE, MODEL_INDEX_NAME);
deleteRequestBuilder.setId(modelId);
deleteRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

// On model metadata removal, delete the model from the index
clearModelMetadataStep.whenComplete(
acknowledgedResponse -> deleteModelFromIndex(modelId, deleteModelFromIndexStep, deleteRequestBuilder),
listener::onFailure
);

ModelDao.runWithStashedThreadContext(() -> {
DeleteRequestBuilder deleteRequestBuilder = new DeleteRequestBuilder(client, DeleteAction.INSTANCE, MODEL_INDEX_NAME);
deleteRequestBuilder.setId(modelId);
deleteRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

// On model metadata removal, delete the model from the index
clearModelMetadataStep.whenComplete(
acknowledgedResponse -> deleteModelFromIndex(modelId, deleteModelFromIndexStep, deleteRequestBuilder),
listener::onFailure
);
});
deleteModelFromIndexStep.whenComplete(deleteResponse -> {
// If model is not deleted, remove modelId from model graveyard and return with error message
if (deleteResponse.getResult() != DocWriteResponse.Result.DELETED) {
@@ -653,4 +682,26 @@ private String buildRemoveModelErrorMessage(String modelId, RemoveModelFromCache
return stringBuilder.toString();
}
}

/**
* Set the thread context to default, this is needed to allow actions on model system index
* when security plugin is enabled
* @param function runnable that needs to be executed after thread context has been stashed, accepts and returns nothing
*/
private static void runWithStashedThreadContext(Runnable function) {
try (ThreadContext.StoredContext context = OpenSearchKNNModelDao.client.threadPool().getThreadContext().stashContext()) {
function.run();
}
}

/**
* Set the thread context to default, this is needed to allow actions on model system index
* when security plugin is enabled
* @param function supplier function that needs to be executed after thread context has been stashed, return object
*/
private static <T> T runWithStashedThreadContext(Supplier<T> function) {
try (ThreadContext.StoredContext context = OpenSearchKNNModelDao.client.threadPool().getThreadContext().stashContext()) {
return function.get();
}
}
}
Loading