Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support .opensearch-knn-model index as system index with security enabled #827

1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Add CHANGELOG ([#800](https://github.com/opensearch-project/k-NN/pull/800))
* Bump byte-buddy version from 1.12.22 to 1.14.2 ([#804](https://github.com/opensearch-project/k-NN/pull/804))
* Bump numpy version from 1.22.x to 1.24.2 ([#811](https://github.com/opensearch-project/k-NN/pull/811))
* Support .opensearch-knn-model index as system index with security enabled ([#827](https://github.com/opensearch-project/k-NN/pull/827))
### Documentation
### Maintenance
### Refactoring
Expand Down
1 change: 1 addition & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ dependencies {
testImplementation group: 'net.bytebuddy', name: 'byte-buddy', version: '1.14.3'
testImplementation group: 'org.objenesis', name: 'objenesis', version: '3.2'
testImplementation group: 'net.bytebuddy', name: 'byte-buddy-agent', version: '1.14.3'
api "org.opensearch:common-utils:${version}"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why we need this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,6 @@ public static void wipeAllModels() throws IOException {
deleteKNNModel(TEST_MODEL_ID);
deleteKNNModel(TEST_MODEL_ID_DEFAULT);
deleteKNNModel(TEST_MODEL_ID_TRAINING);

Request request = new Request("DELETE", "/" + MODEL_INDEX_NAME);

Response response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
}
}

Expand Down
147 changes: 99 additions & 48 deletions src/main/java/org/opensearch/knn/indices/ModelDao.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import com.google.common.base.Charsets;
import com.google.common.io.Resources;
import lombok.SneakyThrows;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchException;
Expand Down Expand Up @@ -42,17 +43,18 @@
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.plugin.transport.DeleteModelResponse;
import org.opensearch.knn.plugin.transport.GetModelResponse;
import org.opensearch.knn.plugin.transport.RemoveModelFromCacheAction;
import org.opensearch.knn.plugin.transport.RemoveModelFromCacheRequest;
import org.opensearch.knn.plugin.transport.RemoveModelFromCacheResponse;
import org.opensearch.knn.plugin.transport.UpdateModelMetadataAction;
import org.opensearch.knn.plugin.transport.UpdateModelMetadataRequest;
import org.opensearch.knn.plugin.transport.UpdateModelGraveyardAction;
import org.opensearch.knn.plugin.transport.UpdateModelGraveyardRequest;
import org.opensearch.knn.plugin.transport.UpdateModelMetadataAction;
import org.opensearch.knn.plugin.transport.UpdateModelMetadataRequest;

import java.io.IOException;
import java.net.URL;
Expand All @@ -62,6 +64,7 @@
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.function.Supplier;

import static java.util.Objects.isNull;
import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_MAPPING_PATH;
Expand Down Expand Up @@ -216,14 +219,21 @@ public void create(ActionListener<CreateIndexResponse> actionListener) throws IO
if (isCreated()) {
return;
}
CreateIndexRequest request = new CreateIndexRequest(MODEL_INDEX_NAME).mapping(getMapping())
.settings(
Settings.builder()
.put("index.hidden", true)
.put("index.number_of_shards", this.numberOfShards)
.put("index.number_of_replicas", this.numberOfReplicas)
);
client.admin().indices().create(request, actionListener);
runWithStashedThreadContext(() -> {
CreateIndexRequest request;
try {
request = new CreateIndexRequest(MODEL_INDEX_NAME).mapping(getMapping())
.settings(
Settings.builder()
.put("index.hidden", true)
.put("index.number_of_shards", this.numberOfShards)
.put("index.number_of_replicas", this.numberOfReplicas)
);
} catch (IOException e) {
throw new RuntimeException(e);
}
client.admin().indices().create(request, actionListener);
});
}

@Override
Expand Down Expand Up @@ -293,8 +303,9 @@ private void putInternal(Model model, ActionListener<IndexResponse> listener, Do
parameters.put(KNNConstants.MODEL_BLOB_PARAMETER, base64Model);
}

IndexRequestBuilder indexRequestBuilder = client.prepareIndex(MODEL_INDEX_NAME);

final IndexRequestBuilder indexRequestBuilder = ModelDao.runWithStashedThreadContext(
() -> client.prepareIndex(MODEL_INDEX_NAME)
);
indexRequestBuilder.setId(model.getModelID());
indexRequestBuilder.setSource(parameters);

Expand All @@ -304,8 +315,8 @@ private void putInternal(Model model, ActionListener<IndexResponse> listener, Do
// After metadata update finishes, remove item from every node's cache if necessary. If no model id is
// passed then nothing needs to be removed from the cache
ActionListener<IndexResponse> onMetaListener;
onMetaListener = ActionListener.wrap(
indexResponse -> client.execute(
onMetaListener = ActionListener.wrap(indexResponse -> {
client.execute(
RemoveModelFromCacheAction.INSTANCE,
new RemoveModelFromCacheRequest(model.getModelID()),
ActionListener.wrap(removeModelFromCacheResponse -> {
Expand All @@ -318,9 +329,8 @@ private void putInternal(Model model, ActionListener<IndexResponse> listener, Do

listener.onFailure(new RuntimeException(failureMessage));
}, listener::onFailure)
),
listener::onFailure
);
);
}, listener::onFailure);

// After the model is indexed, update metadata only if the model is in CREATED state
ActionListener<IndexResponse> onIndexListener;
Expand Down Expand Up @@ -357,16 +367,30 @@ private ActionListener<IndexResponse> getUpdateModelMetadataListener(
);
}

@SneakyThrows
@Override
public Model get(String modelId) throws ExecutionException, InterruptedException {
public Model get(String modelId) {
/*
GET /<model_index>/<modelId>?_local
*/
GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId)
.setPreference("_local");
GetResponse getResponse = getRequestBuilder.execute().get();
Map<String, Object> responseMap = getResponse.getSourceAsMap();
return Model.getModelFromSourceMap(responseMap);
try {
return ModelDao.runWithStashedThreadContext(() -> {
GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId)
.setPreference("_local");
GetResponse getResponse;
try {
getResponse = getRequestBuilder.execute().get();
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
Map<String, Object> responseMap = getResponse.getSourceAsMap();
return Model.getModelFromSourceMap(responseMap);
});
} catch (RuntimeException runtimeException) {
// we need to use RuntimeException as container for real exception to keep signature
// of runWithStashedThreadContext generic
throw runtimeException.getCause();
}
}

/**
Expand All @@ -380,20 +404,22 @@ public void get(String modelId, ActionListener<GetModelResponse> actionListener)
/*
GET /<model_index>/<modelId>?_local
*/
GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId)
.setPreference("_local");

getRequestBuilder.execute(ActionListener.wrap(response -> {
if (response.isSourceEmpty()) {
String errorMessage = String.format("Model \" %s \" does not exist", modelId);
actionListener.onFailure(new ResourceNotFoundException(modelId, errorMessage));
return;
}
final Map<String, Object> responseMap = response.getSourceAsMap();
Model model = Model.getModelFromSourceMap(responseMap);
actionListener.onResponse(new GetModelResponse(model));
ModelDao.runWithStashedThreadContext(() -> {
GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId)
.setPreference("_local");

getRequestBuilder.execute(ActionListener.wrap(response -> {
if (response.isSourceEmpty()) {
String errorMessage = String.format("Model \" %s \" does not exist", modelId);
actionListener.onFailure(new ResourceNotFoundException(modelId, errorMessage));
return;
}
final Map<String, Object> responseMap = response.getSourceAsMap();
Model model = Model.getModelFromSourceMap(responseMap);
actionListener.onResponse(new GetModelResponse(model));

}, actionListener::onFailure));
}, actionListener::onFailure));
});
}

/**
Expand All @@ -404,8 +430,10 @@ public void get(String modelId, ActionListener<GetModelResponse> actionListener)
*/
@Override
public void search(SearchRequest request, ActionListener<SearchResponse> actionListener) {
request.indices(MODEL_INDEX_NAME);
client.search(request, actionListener);
ModelDao.runWithStashedThreadContext(() -> {
request.indices(MODEL_INDEX_NAME);
client.search(request, actionListener);
jmazanec15 marked this conversation as resolved.
Show resolved Hide resolved
});
}

@Override
Expand Down Expand Up @@ -505,16 +533,17 @@ public void delete(String modelId, ActionListener<DeleteModelResponse> listener)
);

// Setup delete model request
DeleteRequestBuilder deleteRequestBuilder = new DeleteRequestBuilder(client, DeleteAction.INSTANCE, MODEL_INDEX_NAME);
deleteRequestBuilder.setId(modelId);
deleteRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

// On model metadata removal, delete the model from the index
clearModelMetadataStep.whenComplete(
acknowledgedResponse -> deleteModelFromIndex(modelId, deleteModelFromIndexStep, deleteRequestBuilder),
listener::onFailure
);

ModelDao.runWithStashedThreadContext(() -> {
DeleteRequestBuilder deleteRequestBuilder = new DeleteRequestBuilder(client, DeleteAction.INSTANCE, MODEL_INDEX_NAME);
deleteRequestBuilder.setId(modelId);
deleteRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

// On model metadata removal, delete the model from the index
clearModelMetadataStep.whenComplete(
acknowledgedResponse -> deleteModelFromIndex(modelId, deleteModelFromIndexStep, deleteRequestBuilder),
listener::onFailure
);
});
deleteModelFromIndexStep.whenComplete(deleteResponse -> {
// If model is not deleted, remove modelId from model graveyard and return with error message
if (deleteResponse.getResult() != DocWriteResponse.Result.DELETED) {
Expand Down Expand Up @@ -653,4 +682,26 @@ private String buildRemoveModelErrorMessage(String modelId, RemoveModelFromCache
return stringBuilder.toString();
}
}

/**
* Set the thread context to default, this is needed to allow actions on model system index
* when security plugin is enabled
* @param function runnable that needs to be executed after thread context has been stashed, accepts and returns nothing
*/
private static void runWithStashedThreadContext(Runnable function) {
try (ThreadContext.StoredContext context = OpenSearchKNNModelDao.client.threadPool().getThreadContext().stashContext()) {
function.run();
}
}

/**
* Set the thread context to default, this is needed to allow actions on model system index
* when security plugin is enabled
* @param function supplier function that needs to be executed after thread context has been stashed, return object
*/
private static <T> T runWithStashedThreadContext(Supplier<T> function) {
try (ThreadContext.StoredContext context = OpenSearchKNNModelDao.client.threadPool().getThreadContext().stashContext()) {
return function.get();
}
}
}
Loading