Skip to content

Commit

Permalink
Stashing thread context to allow model operations to regular client
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <gaievski@amazon.com>
  • Loading branch information
martin-gaievski committed Apr 1, 2023
1 parent b21b6e2 commit 75a772c
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 31 deletions.
19 changes: 12 additions & 7 deletions src/main/java/org/opensearch/knn/indices/ModelDao.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.plugin.transport.DeleteModelResponse;
Expand Down Expand Up @@ -362,11 +363,13 @@ public Model get(String modelId) throws ExecutionException, InterruptedException
/*
GET /<model_index>/<modelId>?_local
*/
GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId)
.setPreference("_local");
GetResponse getResponse = getRequestBuilder.execute().get();
Map<String, Object> responseMap = getResponse.getSourceAsMap();
return Model.getModelFromSourceMap(responseMap);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId)
.setPreference("_local");
GetResponse getResponse = getRequestBuilder.execute().get();
Map<String, Object> responseMap = getResponse.getSourceAsMap();
return Model.getModelFromSourceMap(responseMap);
}
}

/**
Expand Down Expand Up @@ -404,8 +407,10 @@ public void get(String modelId, ActionListener<GetModelResponse> actionListener)
*/
@Override
public void search(SearchRequest request, ActionListener<SearchResponse> actionListener) {
request.indices(MODEL_INDEX_NAME);
client.search(request, actionListener);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
request.indices(MODEL_INDEX_NAME);
client.search(request, actionListener);
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,33 @@
import org.opensearch.action.ActionListener;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

public class DeleteModelTransportAction extends HandledTransportAction<DeleteModelRequest, DeleteModelResponse> {

private final ModelDao modelDao;
private final Client client;

@Inject
public DeleteModelTransportAction(TransportService transportService, ActionFilters filters) {
public DeleteModelTransportAction(TransportService transportService, ActionFilters filters, Client client) {
super(DeleteModelAction.NAME, transportService, filters, DeleteModelRequest::new);
this.modelDao = ModelDao.OpenSearchKNNModelDao.getInstance();
this.client = client;
}

@Override
protected void doExecute(Task task, DeleteModelRequest request, ActionListener<DeleteModelResponse> listener) {
String modelID = request.getModelID();
modelDao.delete(modelID, listener);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
String modelID = request.getModelID();
modelDao.delete(modelID, listener);
} catch (Exception e) {
logger.error(e);
listener.onFailure(e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
import org.opensearch.action.ActionListener;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;
Expand All @@ -27,17 +29,24 @@ public class GetModelTransportAction extends HandledTransportAction<GetModelRequ
private static final Logger LOG = LogManager.getLogger(GetModelTransportAction.class);
private ModelDao modelDao;

private final Client client;

@Inject
public GetModelTransportAction(TransportService transportService, ActionFilters actionFilters) {
public GetModelTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) {
super(GetModelAction.NAME, transportService, actionFilters, GetModelRequest::new);
this.modelDao = ModelDao.OpenSearchKNNModelDao.getInstance();
this.client = client;
}

@Override
protected void doExecute(Task task, GetModelRequest request, ActionListener<GetModelResponse> actionListener) {
String modelID = request.getModelID();

modelDao.get(modelID, actionListener);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
String modelID = request.getModelID();

modelDao.get(modelID, actionListener);
} catch (Exception e) {
logger.error(e);
actionListener.onFailure(e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;
Expand All @@ -26,17 +28,21 @@
public class SearchModelTransportAction extends HandledTransportAction<SearchRequest, SearchResponse> {
private ModelDao modelDao;

private final Client client;

@Inject
public SearchModelTransportAction(TransportService transportService, ActionFilters actionFilters) {
public SearchModelTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) {
super(SearchModelAction.NAME, transportService, actionFilters, SearchRequest::new);
this.modelDao = ModelDao.OpenSearchKNNModelDao.getInstance();
this.client = client;
}

@Override
protected void doExecute(Task task, SearchRequest request, ActionListener<SearchResponse> listener) {
try {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
this.modelDao.search(request, listener);
} catch (IOException e) {
logger.error(e);
listener.onFailure(e);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.opensearch.common.ValidationException;
import org.opensearch.common.collect.ImmutableOpenMap;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportRequestOptions;
Expand Down Expand Up @@ -58,10 +59,15 @@ protected void doExecute(Task task, TrainingModelRequest request, ActionListener
// Get the size of the training request and then route the request. We get/set this here, as opposed to in
// TrainingModelTransportAction, because in the future, we may want to use size to factor into our routing
// decision.
getTrainingIndexSizeInKB(request, ActionListener.wrap(size -> {
request.setTrainingDataSizeInKB(size);
routeRequest(request, listener);
}, listener::onFailure));
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
getTrainingIndexSizeInKB(request, ActionListener.wrap(size -> {
request.setTrainingDataSizeInKB(size);
routeRequest(request, listener);
}, listener::onFailure));
} catch (Exception e) {
logger.error(e);
listener.onFailure(e);
}
}

protected void routeRequest(TrainingModelRequest request, ActionListener<TrainingModelResponse> listener) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
import org.opensearch.action.ActionListener;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
import org.opensearch.knn.index.memory.NativeMemoryEntryContext;
import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy;
Expand All @@ -34,10 +36,18 @@ public class TrainingModelTransportAction extends HandledTransportAction<Trainin

private final ClusterService clusterService;

private final Client client;

@Inject
public TrainingModelTransportAction(TransportService transportService, ActionFilters actionFilters, ClusterService clusterService) {
public TrainingModelTransportAction(
TransportService transportService,
ActionFilters actionFilters,
ClusterService clusterService,
Client client
) {
super(TrainingModelAction.NAME, transportService, actionFilters, TrainingModelRequest::new);
this.clusterService = clusterService;
this.client = client;
}

@Override
Expand Down Expand Up @@ -74,8 +84,7 @@ protected void doExecute(Task task, TrainingModelRequest request, ActionListener
KNNCounter.TRAINING_ERRORS.increment();
listener.onFailure(ex);
});

try {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
TrainingJobRunner.getInstance()
.execute(
trainingJob,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,7 @@ public void testGetModelExistsWithFilter() throws Exception {

createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension);
Map<String, Object> method = xContentBuilderToMap(getModelMethodBuilder());
ingestDataAndTrainModel(
modelId,
trainingIndexName,
trainingFieldName,
dimension,
modelDescription,
method
);
ingestDataAndTrainModel(modelId, trainingIndexName, trainingFieldName, dimension, modelDescription, method);
assertTrainingSucceeds(modelId, NUM_OF_ATTEMPTS, DELAY_MILLI_SEC);

String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, modelId);
Expand Down

0 comments on commit 75a772c

Please sign in to comment.