From 75a772c493ef2ef67db53051512a95cbb3bc123d Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Fri, 31 Mar 2023 17:25:58 -0700 Subject: [PATCH] Stashing thread context to allow model operations to regular client Signed-off-by: Martin Gaievski --- .../org/opensearch/knn/indices/ModelDao.java | 19 ++++++++++++------- .../transport/DeleteModelTransportAction.java | 15 ++++++++++++--- .../transport/GetModelTransportAction.java | 17 +++++++++++++---- .../transport/SearchModelTransportAction.java | 10 ++++++++-- .../TrainingJobRouterTransportAction.java | 14 ++++++++++---- .../TrainingModelTransportAction.java | 15 ++++++++++++--- .../plugin/action/RestGetModelHandlerIT.java | 9 +-------- 7 files changed, 68 insertions(+), 31 deletions(-) diff --git a/src/main/java/org/opensearch/knn/indices/ModelDao.java b/src/main/java/org/opensearch/knn/indices/ModelDao.java index 0d5d75d30f..25b8b36c1d 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelDao.java +++ b/src/main/java/org/opensearch/knn/indices/ModelDao.java @@ -42,6 +42,7 @@ import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.index.IndexNotFoundException; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.plugin.transport.DeleteModelResponse; @@ -362,11 +363,13 @@ public Model get(String modelId) throws ExecutionException, InterruptedException /* GET //?_local */ - GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId) - .setPreference("_local"); - GetResponse getResponse = getRequestBuilder.execute().get(); - Map responseMap = getResponse.getSourceAsMap(); - return Model.getModelFromSourceMap(responseMap); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId) + .setPreference("_local"); + GetResponse getResponse = getRequestBuilder.execute().get(); + Map responseMap = getResponse.getSourceAsMap(); + return Model.getModelFromSourceMap(responseMap); + } } /** @@ -404,8 +407,10 @@ public void get(String modelId, ActionListener actionListener) */ @Override public void search(SearchRequest request, ActionListener actionListener) { - request.indices(MODEL_INDEX_NAME); - client.search(request, actionListener); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + request.indices(MODEL_INDEX_NAME); + client.search(request, actionListener); + } } @Override diff --git a/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelTransportAction.java index ee7f9e9397..16d2b0865f 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelTransportAction.java @@ -14,7 +14,9 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.knn.indices.ModelDao; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -22,16 +24,23 @@ public class DeleteModelTransportAction extends HandledTransportAction { private final ModelDao modelDao; + private final Client client; @Inject - public DeleteModelTransportAction(TransportService transportService, ActionFilters filters) { + public DeleteModelTransportAction(TransportService transportService, ActionFilters filters, Client client) { super(DeleteModelAction.NAME, transportService, filters, DeleteModelRequest::new); this.modelDao = ModelDao.OpenSearchKNNModelDao.getInstance(); + this.client = client; } @Override protected void doExecute(Task task, DeleteModelRequest request, ActionListener listener) { - String modelID = request.getModelID(); - modelDao.delete(modelID, listener); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + String modelID = request.getModelID(); + modelDao.delete(modelID, listener); + } catch (Exception e) { + logger.error(e); + listener.onFailure(e); + } } } diff --git a/src/main/java/org/opensearch/knn/plugin/transport/GetModelTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/GetModelTransportAction.java index e47a42d8d7..7449ff53a4 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/GetModelTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/GetModelTransportAction.java @@ -15,7 +15,9 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.knn.indices.ModelDao; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -27,17 +29,24 @@ public class GetModelTransportAction extends HandledTransportAction actionListener) { - String modelID = request.getModelID(); - - modelDao.get(modelID, actionListener); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + String modelID = request.getModelID(); + modelDao.get(modelID, actionListener); + } catch (Exception e) { + logger.error(e); + actionListener.onFailure(e); + } } } diff --git a/src/main/java/org/opensearch/knn/plugin/transport/SearchModelTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/SearchModelTransportAction.java index 4d9f670596..70d93be0ec 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/SearchModelTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/SearchModelTransportAction.java @@ -16,7 +16,9 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.knn.indices.ModelDao; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -26,17 +28,21 @@ public class SearchModelTransportAction extends HandledTransportAction { private ModelDao modelDao; + private final Client client; + @Inject - public SearchModelTransportAction(TransportService transportService, ActionFilters actionFilters) { + public SearchModelTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) { super(SearchModelAction.NAME, transportService, actionFilters, SearchRequest::new); this.modelDao = ModelDao.OpenSearchKNNModelDao.getInstance(); + this.client = client; } @Override protected void doExecute(Task task, SearchRequest request, ActionListener listener) { - try { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { this.modelDao.search(request, listener); } catch (IOException e) { + logger.error(e); listener.onFailure(e); } } diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java index 774029c58a..c3fc7fa977 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java @@ -23,6 +23,7 @@ import org.opensearch.common.ValidationException; import org.opensearch.common.collect.ImmutableOpenMap; import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportRequestOptions; @@ -58,10 +59,15 @@ protected void doExecute(Task task, TrainingModelRequest request, ActionListener // Get the size of the training request and then route the request. We get/set this here, as opposed to in // TrainingModelTransportAction, because in the future, we may want to use size to factor into our routing // decision. - getTrainingIndexSizeInKB(request, ActionListener.wrap(size -> { - request.setTrainingDataSizeInKB(size); - routeRequest(request, listener); - }, listener::onFailure)); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + getTrainingIndexSizeInKB(request, ActionListener.wrap(size -> { + request.setTrainingDataSizeInKB(size); + routeRequest(request, listener); + }, listener::onFailure)); + } catch (Exception e) { + logger.error(e); + listener.onFailure(e); + } } protected void routeRequest(TrainingModelRequest request, ActionListener listener) { diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java index a3c4be16ec..87c8615d7b 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java @@ -14,8 +14,10 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.memory.NativeMemoryEntryContext; import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy; @@ -34,10 +36,18 @@ public class TrainingModelTransportAction extends HandledTransportAction method = xContentBuilderToMap(getModelMethodBuilder()); - ingestDataAndTrainModel( - modelId, - trainingIndexName, - trainingFieldName, - dimension, - modelDescription, - method - ); + ingestDataAndTrainModel(modelId, trainingIndexName, trainingFieldName, dimension, modelDescription, method); assertTrainingSucceeds(modelId, NUM_OF_ATTEMPTS, DELAY_MILLI_SEC); String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, modelId);