Skip to content

Commit

Permalink
Abstract thread context stashing into a method
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <gaievski@amazon.com>

Fixing issues related to exception handling

Signed-off-by: Martin Gaievski <gaievski@amazon.com>
  • Loading branch information
martin-gaievski committed Apr 4, 2023
1 parent 0fc30b9 commit fb1111d
Showing 1 changed file with 68 additions and 48 deletions.
116 changes: 68 additions & 48 deletions src/main/java/org/opensearch/knn/indices/ModelDao.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import com.google.common.base.Charsets;
import com.google.common.io.Resources;
import lombok.SneakyThrows;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchException;
Expand Down Expand Up @@ -50,10 +51,10 @@
import org.opensearch.knn.plugin.transport.RemoveModelFromCacheAction;
import org.opensearch.knn.plugin.transport.RemoveModelFromCacheRequest;
import org.opensearch.knn.plugin.transport.RemoveModelFromCacheResponse;
import org.opensearch.knn.plugin.transport.UpdateModelMetadataAction;
import org.opensearch.knn.plugin.transport.UpdateModelMetadataRequest;
import org.opensearch.knn.plugin.transport.UpdateModelGraveyardAction;
import org.opensearch.knn.plugin.transport.UpdateModelGraveyardRequest;
import org.opensearch.knn.plugin.transport.UpdateModelMetadataAction;
import org.opensearch.knn.plugin.transport.UpdateModelMetadataRequest;

import java.io.IOException;
import java.net.URL;
Expand All @@ -63,6 +64,7 @@
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.function.Supplier;

import static java.util.Objects.isNull;
import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_MAPPING_PATH;
Expand Down Expand Up @@ -217,20 +219,21 @@ public void create(ActionListener<CreateIndexResponse> actionListener) throws IO
if (isCreated()) {
return;
}
// temporary setting thread context to default, this is needed to allow actions on model system index
// when security plugin is enabled
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
CreateIndexRequest request = new CreateIndexRequest(MODEL_INDEX_NAME).mapping(getMapping())
.settings(
Settings.builder()
.put("index.hidden", true)
.put("index.number_of_shards", this.numberOfShards)
.put("index.number_of_replicas", this.numberOfReplicas)
);
runWithStashedThreadContext(() -> {
CreateIndexRequest request;
try {
request = new CreateIndexRequest(MODEL_INDEX_NAME).mapping(getMapping())
.settings(
Settings.builder()
.put("index.hidden", true)
.put("index.number_of_shards", this.numberOfShards)
.put("index.number_of_replicas", this.numberOfReplicas)
);
} catch (IOException e) {
throw new RuntimeException(e);
}
client.admin().indices().create(request, actionListener);
} catch (Exception e) {
actionListener.onFailure(e);
}
});
}

@Override
Expand Down Expand Up @@ -300,13 +303,9 @@ private void putInternal(Model model, ActionListener<IndexResponse> listener, Do
parameters.put(KNNConstants.MODEL_BLOB_PARAMETER, base64Model);
}

IndexRequestBuilder indexRequestBuilder;
// temporary setting thread context to default, this is needed to allow actions on model system index
// when security plugin is enabled
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
indexRequestBuilder = client.prepareIndex(MODEL_INDEX_NAME);
}

final IndexRequestBuilder indexRequestBuilder = ModelDao.runWithStashedThreadContext(
() -> client.prepareIndex(MODEL_INDEX_NAME)
);
indexRequestBuilder.setId(model.getModelID());
indexRequestBuilder.setSource(parameters);

Expand Down Expand Up @@ -368,19 +367,29 @@ private ActionListener<IndexResponse> getUpdateModelMetadataListener(
);
}

@SneakyThrows
@Override
public Model get(String modelId) throws ExecutionException, InterruptedException {
public Model get(String modelId) {
/*
GET /<model_index>/<modelId>?_local
*/
// temporary setting thread context to default, this is needed to allow actions on model system index when security plugin is
// enabled
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId)
.setPreference("_local");
GetResponse getResponse = getRequestBuilder.execute().get();
Map<String, Object> responseMap = getResponse.getSourceAsMap();
return Model.getModelFromSourceMap(responseMap);
try {
return ModelDao.runWithStashedThreadContext(() -> {
GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId)
.setPreference("_local");
GetResponse getResponse;
try {
getResponse = getRequestBuilder.execute().get();
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
Map<String, Object> responseMap = getResponse.getSourceAsMap();
return Model.getModelFromSourceMap(responseMap);
});
} catch (RuntimeException runtimeException) {
// we need to use RuntimeException as container for real exception to keep signature
// of runWithStashedThreadContext generic
throw runtimeException.getCause();
}
}

Expand All @@ -395,9 +404,7 @@ public void get(String modelId, ActionListener<GetModelResponse> actionListener)
/*
GET /<model_index>/<modelId>?_local
*/
// temporary setting thread context to default, this is needed to allow actions on model system index
// when security plugin is enabled
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ModelDao.runWithStashedThreadContext(() -> {
GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId)
.setPreference("_local");

Expand All @@ -412,9 +419,7 @@ public void get(String modelId, ActionListener<GetModelResponse> actionListener)
actionListener.onResponse(new GetModelResponse(model));

}, actionListener::onFailure));
} catch (Exception e) {
actionListener.onFailure(e);
}
});
}

/**
Expand All @@ -425,12 +430,10 @@ public void get(String modelId, ActionListener<GetModelResponse> actionListener)
*/
@Override
public void search(SearchRequest request, ActionListener<SearchResponse> actionListener) {
// temporary setting thread context to default, this is needed to allow actions on model system index when security plugin is
// enabled
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ModelDao.runWithStashedThreadContext(() -> {
request.indices(MODEL_INDEX_NAME);
client.search(request, actionListener);
}
});
}

@Override
Expand Down Expand Up @@ -530,9 +533,7 @@ public void delete(String modelId, ActionListener<DeleteModelResponse> listener)
);

// Setup delete model request
// temporary setting thread context to default, this is needed to allow actions on model system index
// when security plugin is enabled
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ModelDao.runWithStashedThreadContext(() -> {
DeleteRequestBuilder deleteRequestBuilder = new DeleteRequestBuilder(client, DeleteAction.INSTANCE, MODEL_INDEX_NAME);
deleteRequestBuilder.setId(modelId);
deleteRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
Expand All @@ -542,10 +543,7 @@ public void delete(String modelId, ActionListener<DeleteModelResponse> listener)
acknowledgedResponse -> deleteModelFromIndex(modelId, deleteModelFromIndexStep, deleteRequestBuilder),
listener::onFailure
);
} catch (Exception e) {
listener.onFailure(e);
}

});
deleteModelFromIndexStep.whenComplete(deleteResponse -> {
// If model is not deleted, remove modelId from model graveyard and return with error message
if (deleteResponse.getResult() != DocWriteResponse.Result.DELETED) {
Expand Down Expand Up @@ -684,4 +682,26 @@ private String buildRemoveModelErrorMessage(String modelId, RemoveModelFromCache
return stringBuilder.toString();
}
}

/**
* Set the thread context to default, this is needed to allow actions on model system index
* when security plugin is enabled
* @param function runnable that needs to be executed after thread context has been stashed, accepts and returns nothing
*/
private static void runWithStashedThreadContext(Runnable function) {
try (ThreadContext.StoredContext context = OpenSearchKNNModelDao.client.threadPool().getThreadContext().stashContext()) {
function.run();
}
}

/**
* Set the thread context to default, this is needed to allow actions on model system index
* when security plugin is enabled
* @param function supplier function that needs to be executed after thread context has been stashed, return object
*/
private static <T> T runWithStashedThreadContext(Supplier<T> function) {
try (ThreadContext.StoredContext context = OpenSearchKNNModelDao.client.threadPool().getThreadContext().stashContext()) {
return function.get();
}
}
}

0 comments on commit fb1111d

Please sign in to comment.