Add support for context_size and include 'interaction_id' in SearchRe… #1385
Changes from 1 commit
File: GenerativeQAResponseProcessor.java
@@ -17,7 +17,6 @@
*/
package org.opensearch.searchpipelines.questionanswering.generative;

import com.google.gson.Gson;
import com.google.gson.JsonArray;
import lombok.Getter;
import lombok.Setter;

@@ -41,10 +40,13 @@
import org.opensearch.searchpipelines.questionanswering.generative.llm.ModelLocator;
import org.opensearch.searchpipelines.questionanswering.generative.prompt.PromptUtil;

import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.BooleanSupplier;

import static org.opensearch.ingest.ConfigurationUtils.newConfigurationException;

@@ -58,11 +60,16 @@ public class GenerativeQAResponseProcessor extends AbstractProcessor implements

private static final int DEFAULT_CHAT_HISTORY_WINDOW = 10;

private static final int MAX_PROCESSOR_TIME_IN_SECONDS = 60;

// TODO Add "interaction_count". This is how far back in chat history we want to go back when calling LLM.

private final String llmModel;
private final List<String> contextFields;

private final String systemPrompt;
private final String userInstructions;

@Setter
private ConversationalMemoryClient memoryClient;

@@ -74,10 +81,12 @@ public class GenerativeQAResponseProcessor extends AbstractProcessor implements
private final BooleanSupplier featureFlagSupplier;

protected GenerativeQAResponseProcessor(Client client, String tag, String description, boolean ignoreFailure,
Llm llm, String llmModel, List<String> contextFields, BooleanSupplier supplier) {
Llm llm, String llmModel, List<String> contextFields, String systemPrompt, String userInstructions, BooleanSupplier supplier) {
super(tag, description, ignoreFailure);
this.llmModel = llmModel;
this.contextFields = contextFields;
this.systemPrompt = systemPrompt;
this.userInstructions = userInstructions;
this.llm = llm;
this.memoryClient = new ConversationalMemoryClient(client);
this.featureFlagSupplier = supplier;

@@ -93,45 +102,94 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
}

GenerativeQAParameters params = GenerativeQAParamUtil.getGenerativeQAParameters(request);

Integer timeout = params.getTimeout();
Review comment: one liner?
Reply: Well, that would involve calling …
Review comment: That's a good point. We could do: … But up to you, keep the code as it is if you want.
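For what it's worth, a minimal sketch of the kind of one-liner being discussed, assuming java.util.Optional is acceptable in this code path; the same pattern would also cover the interaction_size and context_size checks further down. Illustrative only, not part of this commit:

// Hypothetical alternative to the explicit null/sentinel check below (requires import java.util.Optional):
int effectiveTimeout = Optional.ofNullable(params.getTimeout())
    .filter(t -> t != GenerativeQAParameters.SIZE_NULL_VALUE)
    .orElse(MAX_PROCESSOR_TIME_IN_SECONDS);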

if (timeout == null || timeout == GenerativeQAParameters.SIZE_NULL_VALUE) {
timeout = MAX_PROCESSOR_TIME_IN_SECONDS;
}
log.info("Timeout for this request: {} seconds.", timeout);

String llmQuestion = params.getLlmQuestion();
String llmModel = params.getLlmModel() == null ? this.llmModel : params.getLlmModel();
if (llmModel == null) {
throw new IllegalArgumentException("llm_model cannot be null.");
}
String conversationId = params.getConversationId();
log.info("LLM question: {}, LLM model {}, conversation id: {}", llmQuestion, llmModel, conversationId);
List<Interaction> chatHistory = (conversationId == null) ? Collections.emptyList() : memoryClient.getInteractions(conversationId, DEFAULT_CHAT_HISTORY_WINDOW);
List<String> searchResults = getSearchResults(response);
ChatCompletionOutput output = llm.doChatCompletion(LlmIOUtil.createChatCompletionInput(llmModel, llmQuestion, chatHistory, searchResults));
String answer = (String) output.getAnswers().get(0);
Instant start = Instant.now();
Integer interactionSize = params.getInteractionSize();
if (interactionSize == null || interactionSize == GenerativeQAParameters.SIZE_NULL_VALUE) {
interactionSize = DEFAULT_CHAT_HISTORY_WINDOW;
}
log.info("Using interaction size of {}", interactionSize);
List<Interaction> chatHistory = (conversationId == null) ? Collections.emptyList() : memoryClient.getInteractions(conversationId, interactionSize);
log.info("Retrieved chat history. ({})", getDuration(start));
Review comment: Just curious to know what's the goal of adding …
Reply: It logs the elapsed time for …

Integer topN = params.getContextSize();
Review comment: one liner?

if (topN == null) {
topN = GenerativeQAParameters.SIZE_NULL_VALUE;
}
List<String> searchResults = getSearchResults(response, topN);

log.info("system_prompt: {}", systemPrompt);
log.info("user_instructions: {}", userInstructions);
start = Instant.now();
ChatCompletionOutput output = llm.doChatCompletion(LlmIOUtil.createChatCompletionInput(systemPrompt, userInstructions, llmModel,
llmQuestion, chatHistory, searchResults, timeout));
log.info("doChatCompletion complete. ({})", getDuration(start));

String answer = null;
String errorMessage = null;
String interactionId = null;
if (conversationId != null) {
interactionId = memoryClient.createInteraction(conversationId, llmQuestion, PromptUtil.DEFAULT_CHAT_COMPLETION_PROMPT_TEMPLATE, answer,
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, jsonArrayToString(searchResults));
if (output.isErrorOccurred()) {
errorMessage = output.getErrors().get(0);
} else {
answer = (String) output.getAnswers().get(0);

if (conversationId != null) {
start = Instant.now();
interactionId = memoryClient.createInteraction(conversationId,
llmQuestion,
PromptUtil.getPromptTemplate(systemPrompt, userInstructions),
answer,
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
jsonArrayToString(searchResults)
);
log.info("Created a new interaction: {} ({})", interactionId, getDuration(start));
}
}

return insertAnswer(response, answer, interactionId);
return insertAnswer(response, answer, errorMessage, interactionId);
}

long getDuration(Instant start) {
return Duration.between(start, Instant.now()).toMillis();
}

@Override
public String getType() {
return GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE;
}

private SearchResponse insertAnswer(SearchResponse response, String answer, String interactionId) {
private SearchResponse insertAnswer(SearchResponse response, String answer, String errorMessage, String interactionId) {

// TODO return the interaction id in the response.

return new GenerativeSearchResponse(answer, response.getInternalResponse(), response.getScrollId(), response.getTotalShards(), response.getSuccessfulShards(),
response.getSkippedShards(), response.getSuccessfulShards(), response.getShardFailures(), response.getClusters());
return new GenerativeSearchResponse(answer, errorMessage, response.getInternalResponse(), response.getScrollId(), response.getTotalShards(), response.getSuccessfulShards(),
response.getSkippedShards(), response.getSuccessfulShards(), response.getShardFailures(), response.getClusters(), interactionId);
}

private List<String> getSearchResults(SearchResponse response) {
private List<String> getSearchResults(SearchResponse response, Integer topN) {
List<String> searchResults = new ArrayList<>();
for (SearchHit hit : response.getHits().getHits()) {
Map<String, Object> docSourceMap = hit.getSourceAsMap();
SearchHit[] hits = response.getHits().getHits();
int total = hits.length;
int end = (topN != GenerativeQAParameters.SIZE_NULL_VALUE) ? Math.min(topN, total) : total;
for (int i = 0; i < end; i++) {
Map<String, Object> docSourceMap = hits[i].getSourceAsMap();
for (String contextField : contextFields) {
Object context = docSourceMap.get(contextField);
if (context == null) {
log.error("Context " + contextField + " not found in search hit " + hit);
log.error("Context " + contextField + " not found in search hit " + hits[i]);
// TODO throw a more meaningful error here?
throw new RuntimeException();
}
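To make the bound above concrete: with context_size = 3 and 10 hits, only the first 3 hits contribute context; when context_size is omitted (topN left at the SIZE_NULL_VALUE sentinel), every hit is used. A hypothetical standalone helper expressing the same rule, for illustration only:

// Mirrors the 'end' computation in getSearchResults above; not part of this commit.
static int numHitsToUse(Integer topN, int totalHits, int sizeNullValue) {
    return (topN != null && topN != sizeNullValue) ? Math.min(topN, totalHits) : totalHits;
}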

@@ -189,14 +247,27 @@ public SearchResponseProcessor create(
"required property can't be empty."
);
}
log.info("model_id {}, llm_model {}, context_field_list {}", modelId, llmModel, contextFields);
String systemPrompt = ConfigurationUtils.readOptionalStringProperty(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
tag,
config,
GenerativeQAProcessorConstants.CONFIG_NAME_SYSTEM_PROMPT
);
String userInstructions = ConfigurationUtils.readOptionalStringProperty(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
tag,
config,
GenerativeQAProcessorConstants.CONFIG_NAME_USER_INSTRUCTIONS
);
log.info("model_id {}, llm_model {}, context_field_list {}, system_prompt {}, user_instructions {}",
modelId, llmModel, contextFields, systemPrompt, userInstructions);
return new GenerativeQAResponseProcessor(client,
tag,
description,
ignoreFailure,
ModelLocator.getLlm(modelId, client),
llmModel,
contextFields,
systemPrompt,
userInstructions,
featureFlagSupplier
);
} else {
File: GenerativeSearchResponse.java
@@ -23,6 +23,7 @@
import org.opensearch.core.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Objects;

/**
* This is an extension of SearchResponse that adds LLM-generated answers to search responses in a dedicated "ext" section.

@@ -33,22 +34,32 @@ public class GenerativeSearchResponse extends SearchResponse {

private static final String EXT_SECTION_NAME = "ext";
private static final String GENERATIVE_QA_ANSWER_FIELD_NAME = "answer";
private static final String GENERATIVE_QA_ERROR_FIELD_NAME = "error";
private static final String INTERACTION_ID_FIELD_NAME = "interaction_id";

private final String answer;
private String errorMessage;
private final String interactionId;

public GenerativeSearchResponse(
String answer,
String errorMessage,
SearchResponseSections internalResponse,
String scrollId,
int totalShards,
int successfulShards,
int skippedShards,
long tookInMillis,
ShardSearchFailure[] shardFailures,
Clusters clusters
Clusters clusters,
String interactionId
) {
super(internalResponse, scrollId, totalShards, successfulShards, skippedShards, tookInMillis, shardFailures, clusters);
this.answer = answer;
if (answer == null) {
this.errorMessage = Objects.requireNonNull(errorMessage, "If answer is not given, errorMessage must be provided.");
}
this.interactionId = interactionId;
}

@Override

@@ -57,7 +68,16 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
innerToXContent(builder, params);
/* start of ext */ builder.startObject(EXT_SECTION_NAME);
/* start of our stuff */ builder.startObject(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE);
/* body of our stuff */ builder.field(GENERATIVE_QA_ANSWER_FIELD_NAME, this.answer);
if (answer == null) {
builder.field(GENERATIVE_QA_ERROR_FIELD_NAME, this.errorMessage);
} else {
/* body of our stuff */
builder.field(GENERATIVE_QA_ANSWER_FIELD_NAME, this.answer);
Review comment: Apologies if I missed it, don't we need a corresponding parser?
Reply: There is a …

if (this.interactionId != null) {
/* interaction id */
builder.field(INTERACTION_ID_FIELD_NAME, this.interactionId);
}
}
/* end of our stuff */ builder.endObject();
/* end of ext */ builder.endObject();
builder.endObject();
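To make the branching above easier to follow, a rough sketch of the emitted ext section; the outer key is whatever RESPONSE_PROCESSOR_TYPE resolves to and the values are placeholders, so treat this as an illustration rather than documented output:

/*
 * Success case:
 *   "ext": {
 *     "<RESPONSE_PROCESSOR_TYPE>": {
 *       "answer": "...",
 *       "interaction_id": "..."   // present only when a conversation id was supplied
 *     }
 *   }
 * Error case: a single "error" field replaces "answer", and no "interaction_id" is written.
 */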
File: ConversationalMemoryClient.java
@@ -22,8 +22,13 @@
import lombok.extern.log4j.Log4j2;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.client.Client;
import org.opensearch.common.action.ActionFuture;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.index.reindex.ScrollableHitSource;
import org.opensearch.ml.common.conversation.Interaction;
import org.opensearch.ml.memory.action.conversation.CreateConversationAction;
import org.opensearch.ml.memory.action.conversation.CreateConversationRequest;

@@ -37,6 +42,7 @@

import java.util.ArrayList;
import java.util.List;
import java.util.function.BiFunction;
Review comment: Did we use these newly imported libraries?
Reply: Removed.

/**
* An OpenSearch client wrapper for conversational memory related calls.

@@ -46,12 +52,13 @@
public class ConversationalMemoryClient {

private final static Logger logger = LogManager.getLogger();
private final static long DEFAULT_TIMEOUT_IN_MILLIS = 10_000l;

private Client client;

public String createConversation(String name) {

CreateConversationResponse response = client.execute(CreateConversationAction.INSTANCE, new CreateConversationRequest(name)).actionGet();
CreateConversationResponse response = client.execute(CreateConversationAction.INSTANCE, new CreateConversationRequest(name)).actionGet(DEFAULT_TIMEOUT_IN_MILLIS);
log.info("createConversation: id: {}", response.getId());
return response.getId();
}

@@ -61,7 +68,7 @@ public String createInteraction(String conversationId, String input, String prom
Preconditions.checkNotNull(input);
Preconditions.checkNotNull(response);
CreateInteractionResponse res = client.execute(CreateInteractionAction.INSTANCE,
new CreateInteractionRequest(conversationId, input, promptTemplate, response, origin, additionalInfo)).actionGet();
new CreateInteractionRequest(conversationId, input, promptTemplate, response, origin, additionalInfo)).actionGet(DEFAULT_TIMEOUT_IN_MILLIS);
log.info("createInteraction: interactionId: {}", res.getId());
return res.getId();
}

@@ -78,7 +85,7 @@ public List<Interaction> getInteractions(String conversationId, int lastN) {
int maxResults = lastN;
do {
GetInteractionsResponse response =
client.execute(GetInteractionsAction.INSTANCE, new GetInteractionsRequest(conversationId, maxResults, from)).actionGet();
client.execute(GetInteractionsAction.INSTANCE, new GetInteractionsRequest(conversationId, maxResults, from)).actionGet(DEFAULT_TIMEOUT_IN_MILLIS);
List<Interaction> list = response.getInteractions();
if (list != null && !CollectionUtils.isEmpty(list)) {
interactions.addAll(list);

@@ -97,6 +104,4 @@ public List<Interaction> getInteractions(String conversationId, int lastN) {

return interactions;
}

}
Review comment: I feel like DEFAULT makes more sense here.
Reply: Hmm… maybe a "max" value doesn't make sense here, although I think 60 seconds is a long time. I will change this to a 30-second default time (timeout).
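A sketch of the follow-up the author describes above, keeping the existing null/sentinel check and only renaming the constant and shortening the default; illustrative until the change actually lands:

// Illustrative only: rename MAX_PROCESSOR_TIME_IN_SECONDS and use a 30-second default.
private static final int DEFAULT_PROCESSOR_TIME_IN_SECONDS = 30;

// Inside processResponse(), unchanged apart from the constant name:
if (timeout == null || timeout == GenerativeQAParameters.SIZE_NULL_VALUE) {
    timeout = DEFAULT_PROCESSOR_TIME_IN_SECONDS;
}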