Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport] 2.x consistent get interactions (#1334) #1432

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -270,17 +270,17 @@ public void checkAccess(String conversationId, ActionListener<Boolean> listener)
String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
ActionListener<Boolean> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
// If security is off - User doesn't exist - you have permission
if (userstr == null || User.parse(userstr) == null) {
internalListener.onResponse(true);
return;
}
GetRequest getRequest = Requests.getRequest(indexName).id(conversationId);
ActionListener<GetResponse> al = ActionListener.wrap(getResponse -> {
// If the conversation doesn't exist, fail
if (!(getResponse.isExists() && getResponse.getId().equals(conversationId))) {
throw new ResourceNotFoundException("Conversation [" + conversationId + "] not found");
}
// If security is off - User doesn't exist - you have permission
if (userstr == null || User.parse(userstr) == null) {
internalListener.onResponse(true);
return;
}
ConversationMeta conversation = ConversationMeta.fromMap(conversationId, getResponse.getSourceAsMap());
String user = User.parse(userstr).getName();
// If you're not the owner of this conversation, you do not have permission
Expand All @@ -290,7 +290,13 @@ public void checkAccess(String conversationId, ActionListener<Boolean> listener)
}
internalListener.onResponse(true);
}, e -> { internalListener.onFailure(e); });
client.get(getRequest, al);
client
.admin()
.indices()
.refresh(Requests.refreshRequest(indexName), ActionListener.wrap(refreshResponse -> { client.get(getRequest, al); }, e -> {
log.error("Failed to refresh conversations index during check access ", e);
internalListener.onFailure(e);
}));
} catch (Exception e) {
listener.onFailure(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,12 @@

import com.google.common.annotations.VisibleForTesting;

import lombok.extern.log4j.Log4j2;

/**
* Class for handling all Conversational Memory operactions
*/
@Log4j2
public class OpenSearchConversationalMemoryHandler implements ConversationalMemoryHandler {

private ConversationMetaIndex conversationMetaIndex;
Expand Down Expand Up @@ -247,19 +250,25 @@ public ActionFuture<List<ConversationMeta>> getConversations(int maxResults) {
public void deleteConversation(String conversationId, ActionListener<Boolean> listener) {
StepListener<Boolean> accessListener = new StepListener<>();
conversationMetaIndex.checkAccess(conversationId, accessListener);

log.info("DELETING CONVERSATION " + conversationId);
accessListener.whenComplete(access -> {
if (access) {
StepListener<Boolean> metaDeleteListener = new StepListener<>();
StepListener<Boolean> interactionsListener = new StepListener<>();

conversationMetaIndex.deleteConversation(conversationId, metaDeleteListener);
interactionsIndex.deleteConversation(conversationId, interactionsListener);

metaDeleteListener.whenComplete(metaResult -> {
interactionsListener
.whenComplete(interactionResult -> { listener.onResponse(metaResult && interactionResult); }, listener::onFailure);
interactionsListener
.whenComplete(
interactionResult -> { conversationMetaIndex.deleteConversation(conversationId, metaDeleteListener); },
listener::onFailure
);

metaDeleteListener.whenComplete(metaDeleteResult -> {
log.info("SUCCESSFUL DELETION OF CONVERSATION " + conversationId);
listener.onResponse(metaDeleteResult && interactionsListener.result());
}, listener::onFailure);

} else {
listener.onResponse(false);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,16 +249,18 @@ public void testCanDeleteConversations() {
});

StepListener<List<Interaction>> inters2 = new StepListener<>();
inters1.whenComplete(ints -> { cmHandler.getInteractions(cid2.result(), 0, 10, inters2); }, e -> {
inters1.whenComplete(ints -> {
cdl.countDown();
assert (false);
}, e -> {
assert (e.getMessage().startsWith("Conversation ["));
cmHandler.getInteractions(cid2.result(), 0, 10, inters2);
});

LatchedActionListener<List<Interaction>> finishAndAssert = new LatchedActionListener<>(ActionListener.wrap(r -> {
assert (del.result());
assert (conversations.result().size() == 1);
assert (conversations.result().get(0).getId().equals(cid2.result()));
assert (inters1.result().size() == 0);
assert (inters2.result().size() == 1);
assert (inters2.result().get(0).getId().equals(iid3.result()));
}, e -> { assert (false); }), cdl);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ public void testDelete_DeleteFails_ThenFail() {

public void testCheckAccess_DoesNotExist_ThenFail() {
setupUser("user");
setupRefreshSuccess();
doReturn(true).when(metadata).hasIndex(anyString());
GetResponse response = mock(GetResponse.class);
doReturn(false).when(response).isExists();
Expand All @@ -423,6 +424,7 @@ public void testCheckAccess_DoesNotExist_ThenFail() {

public void testCheckAccess_WrongId_ThenFail() {
setupUser("user");
setupRefreshSuccess();
doReturn(true).when(metadata).hasIndex(anyString());
GetResponse response = mock(GetResponse.class);
doReturn(true).when(response).isExists();
Expand All @@ -443,6 +445,7 @@ public void testCheckAccess_WrongId_ThenFail() {

public void testCheckAccess_GetFails_ThenFail() {
setupUser("user");
setupRefreshSuccess();
doReturn(true).when(metadata).hasIndex(anyString());
doAnswer(invocation -> {
ActionListener<GetResponse> al = invocation.getArgument(1);
Expand All @@ -459,6 +462,7 @@ public void testCheckAccess_GetFails_ThenFail() {

public void testCheckAccess_ClientFails_ThenFail() {
setupUser("user");
setupRefreshSuccess();
doReturn(true).when(metadata).hasIndex(anyString());
doThrow(new RuntimeException("Client Test Fail")).when(client).get(any(), any());
@SuppressWarnings("unchecked")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.hc.core5.http.message.BasicHeader;
import org.junit.Before;
import org.opensearch.client.Response;
import org.opensearch.client.ResponseException;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.common.conversation.ActionConstants;
import org.opensearch.ml.settings.MLCommonsSettings;
Expand Down Expand Up @@ -163,15 +164,20 @@ public void testDeleteConversation_WithInteractions() throws IOException {
assert (!gcmap.containsKey("next_token"));
assert (((ArrayList) gcmap.get("conversations")).size() == 0);

Response giresponse = TestHelper
.makeRequest(client(), "GET", ActionConstants.GET_INTERACTIONS_REST_PATH.replace("{conversation_id}", cid), null, "", null);
assert (giresponse != null);
assert (TestHelper.restStatus(giresponse) == RestStatus.OK);
HttpEntity gihttpEntity = giresponse.getEntity();
String gientityString = TestHelper.httpEntityToString(gihttpEntity);
Map gimap = gson.fromJson(gientityString, Map.class);
assert (gimap.containsKey("interactions"));
assert (!gimap.containsKey("next_token"));
assert (((ArrayList) gimap.get("interactions")).size() == 0);
try {
Response giresponse = TestHelper
.makeRequest(client(), "GET", ActionConstants.GET_INTERACTIONS_REST_PATH.replace("{conversation_id}", cid), null, "", null);
assert (giresponse != null);
assert (TestHelper.restStatus(giresponse) == RestStatus.OK);
HttpEntity gihttpEntity = giresponse.getEntity();
String gientityString = TestHelper.httpEntityToString(gihttpEntity);
Map gimap = gson.fromJson(gientityString, Map.class);
assert (gimap.containsKey("interactions"));
assert (!gimap.containsKey("next_token"));
assert (((ArrayList) gimap.get("interactions")).size() == 0);
assert (false);
} catch (ResponseException e) {
assert (TestHelper.restStatus(e.getResponse()) == RestStatus.NOT_FOUND);
}
}
}
Loading