From 374b8354823980256338e5f7520709142ef2abfa Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Wed, 4 Oct 2023 13:58:50 -0700 Subject: [PATCH] =?UTF-8?q?Add=20support=20for=20context=5Fsize=20and=20in?= =?UTF-8?q?clude=20'interaction=5Fid'=20in=20SearchRe=E2=80=A6=20(#1385)?= =?UTF-8?q?=20(#1433)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add support for context_size and include 'interaction_id' in SearchResponse. [Issue #1372] * Added spotless, removed unused code, added more comments. --------- (cherry picked from commit ae6995a85f77c2cde8e11eb0dd4f22280905c696) Signed-off-by: Austin Lee Signed-off-by: HenryL27 Co-authored-by: Austin Lee --- search-processors/build.gradle | 10 + .../GenerativeQAProcessorConstants.java | 3 + .../GenerativeQARequestProcessor.java | 31 ++- .../GenerativeQAResponseProcessor.java | 215 ++++++++++++++---- .../generative/GenerativeSearchResponse.java | 28 ++- .../client/ConversationalMemoryClient.java | 41 ++-- .../client/MachineLearningInternalClient.java | 41 ++-- .../ext/GenerativeQAParamExtBuilder.java | 11 +- .../generative/ext/GenerativeQAParamUtil.java | 14 +- .../ext/GenerativeQAParameters.java | 96 +++++++- .../generative/llm/ChatCompletionInput.java | 10 +- .../generative/llm/ChatCompletionOutput.java | 30 ++- .../generative/llm/DefaultLlmImpl.java | 47 ++-- .../generative/llm/LlmIOUtil.java | 36 ++- .../generative/llm/ModelLocator.java | 4 +- .../generative/prompt/PromptUtil.java | 72 ++++-- .../GenerativeQAParamUtilTests.java | 10 +- .../GenerativeQARequestProcessorTests.java | 22 +- .../GenerativeQAResponseProcessorTests.java | 153 ++++++++++--- .../GenerativeSearchResponseTests.java | 105 ++++++++- .../ConversationalMemoryClientTests.java | 60 +++-- .../MachineLearningInternalClientTests.java | 38 ++-- .../ext/GenerativeQAParamExtBuilderTests.java | 57 +++-- .../ext/GenerativeQAParamUtilTests.java | 4 +- .../ext/GenerativeQAParametersTests.java | 80 +++++-- .../llm/ChatCompletionInputTests.java | 48 +++- .../llm/ChatCompletionOutputTests.java | 32 ++- .../generative/llm/DefaultLlmImplTests.java | 100 ++++++-- .../generative/llm/LlmIOUtilTests.java | 7 +- .../generative/llm/ModelLocatorTests.java | 4 +- .../generative/prompt/PromptUtilTests.java | 53 +++-- 31 files changed, 1114 insertions(+), 348 deletions(-) diff --git a/search-processors/build.gradle b/search-processors/build.gradle index 3911cf407e..273aeb3257 100644 --- a/search-processors/build.gradle +++ b/search-processors/build.gradle @@ -19,6 +19,7 @@ plugins { id 'java' id 'jacoco' id "io.freefair.lombok" + id 'com.diffplug.spotless' version '6.18.0' } repositories { @@ -73,3 +74,12 @@ jacocoTestCoverageVerification { } check.dependsOn jacocoTestCoverageVerification + +spotless { + java { + removeUnusedImports() + importOrder 'java', 'javax', 'org', 'com' + + eclipse().configFile rootProject.file('.eclipseformat.xml') + } +} diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAProcessorConstants.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAProcessorConstants.java index 5b6e8159b8..957e01b302 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAProcessorConstants.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAProcessorConstants.java @@ -36,6 +36,9 @@ public class GenerativeQAProcessorConstants { // The field in search results that contain the context to be sent to the LLM. public static final String CONFIG_NAME_CONTEXT_FIELD_LIST = "context_field_list"; + public static final String CONFIG_NAME_SYSTEM_PROMPT = "system_prompt"; + public static final String CONFIG_NAME_USER_INSTRUCTIONS = "user_instructions"; + public static final Setting RAG_PIPELINE_FEATURE_ENABLED = Setting .boolSetting("plugins.ml_commons.rag_pipeline_feature_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic); diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQARequestProcessor.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQARequestProcessor.java index b0a741575c..0ca3f0668c 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQARequestProcessor.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQARequestProcessor.java @@ -17,6 +17,9 @@ */ package org.opensearch.searchpipelines.questionanswering.generative; +import java.util.Map; +import java.util.function.BooleanSupplier; + import org.opensearch.action.search.SearchRequest; import org.opensearch.ingest.ConfigurationUtils; import org.opensearch.ml.common.exception.MLException; @@ -24,9 +27,6 @@ import org.opensearch.search.pipeline.Processor; import org.opensearch.search.pipeline.SearchRequestProcessor; -import java.util.Map; -import java.util.function.BooleanSupplier; - /** * Defines the request processor for generative QA search pipelines. */ @@ -35,7 +35,13 @@ public class GenerativeQARequestProcessor extends AbstractProcessor implements S private String modelId; private final BooleanSupplier featureFlagSupplier; - protected GenerativeQARequestProcessor(String tag, String description, boolean ignoreFailure, String modelId, BooleanSupplier supplier) { + protected GenerativeQARequestProcessor( + String tag, + String description, + boolean ignoreFailure, + String modelId, + BooleanSupplier supplier + ) { super(tag, description, ignoreFailure); this.modelId = modelId; this.featureFlagSupplier = supplier; @@ -76,12 +82,17 @@ public SearchRequestProcessor create( PipelineContext pipelineContext ) throws Exception { if (featureFlagSupplier.getAsBoolean()) { - return new GenerativeQARequestProcessor(tag, description, ignoreFailure, - ConfigurationUtils.readStringProperty(GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE, - tag, - config, - GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID - ), + return new GenerativeQARequestProcessor( + tag, + description, + ignoreFailure, + ConfigurationUtils + .readStringProperty( + GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE, + tag, + config, + GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID + ), this.featureFlagSupplier ); } else { diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java index 720d880628..111437ab0f 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java @@ -17,11 +17,16 @@ */ package org.opensearch.searchpipelines.questionanswering.generative; -import com.google.gson.Gson; -import com.google.gson.JsonArray; -import lombok.Getter; -import lombok.Setter; -import lombok.extern.log4j.Log4j2; +import static org.opensearch.ingest.ConfigurationUtils.newConfigurationException; + +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.BooleanSupplier; + import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Client; @@ -32,22 +37,20 @@ import org.opensearch.search.pipeline.AbstractProcessor; import org.opensearch.search.pipeline.Processor; import org.opensearch.search.pipeline.SearchResponseProcessor; +import org.opensearch.searchpipelines.questionanswering.generative.client.ConversationalMemoryClient; import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamUtil; import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParameters; -import org.opensearch.searchpipelines.questionanswering.generative.client.ConversationalMemoryClient; import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionOutput; import org.opensearch.searchpipelines.questionanswering.generative.llm.Llm; import org.opensearch.searchpipelines.questionanswering.generative.llm.LlmIOUtil; import org.opensearch.searchpipelines.questionanswering.generative.llm.ModelLocator; import org.opensearch.searchpipelines.questionanswering.generative.prompt.PromptUtil; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.function.BooleanSupplier; +import com.google.gson.JsonArray; -import static org.opensearch.ingest.ConfigurationUtils.newConfigurationException; +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; /** * Defines the response processor for generative QA search pipelines. @@ -58,11 +61,16 @@ public class GenerativeQAResponseProcessor extends AbstractProcessor implements private static final int DEFAULT_CHAT_HISTORY_WINDOW = 10; - // TODO Add "interaction_count". This is how far back in chat history we want to go back when calling LLM. + private static final int DEFAULT_PROCESSOR_TIME_IN_SECONDS = 30; + + // TODO Add "interaction_count". This is how far back in chat history we want to go back when calling LLM. private final String llmModel; private final List contextFields; + private final String systemPrompt; + private final String userInstructions; + @Setter private ConversationalMemoryClient memoryClient; @@ -73,11 +81,23 @@ public class GenerativeQAResponseProcessor extends AbstractProcessor implements private final BooleanSupplier featureFlagSupplier; - protected GenerativeQAResponseProcessor(Client client, String tag, String description, boolean ignoreFailure, - Llm llm, String llmModel, List contextFields, BooleanSupplier supplier) { + protected GenerativeQAResponseProcessor( + Client client, + String tag, + String description, + boolean ignoreFailure, + Llm llm, + String llmModel, + List contextFields, + String systemPrompt, + String userInstructions, + BooleanSupplier supplier + ) { super(tag, description, ignoreFailure); this.llmModel = llmModel; this.contextFields = contextFields; + this.systemPrompt = systemPrompt; + this.userInstructions = userInstructions; this.llm = llm; this.memoryClient = new ConversationalMemoryClient(client); this.featureFlagSupplier = supplier; @@ -93,22 +113,75 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp } GenerativeQAParameters params = GenerativeQAParamUtil.getGenerativeQAParameters(request); + + Integer timeout = params.getTimeout(); + if (timeout == null || timeout == GenerativeQAParameters.SIZE_NULL_VALUE) { + timeout = DEFAULT_PROCESSOR_TIME_IN_SECONDS; + } + log.info("Timeout for this request: {} seconds.", timeout); + String llmQuestion = params.getLlmQuestion(); String llmModel = params.getLlmModel() == null ? this.llmModel : params.getLlmModel(); + if (llmModel == null) { + throw new IllegalArgumentException("llm_model cannot be null."); + } String conversationId = params.getConversationId(); log.info("LLM question: {}, LLM model {}, conversation id: {}", llmQuestion, llmModel, conversationId); - List chatHistory = (conversationId == null) ? Collections.emptyList() : memoryClient.getInteractions(conversationId, DEFAULT_CHAT_HISTORY_WINDOW); - List searchResults = getSearchResults(response); - ChatCompletionOutput output = llm.doChatCompletion(LlmIOUtil.createChatCompletionInput(llmModel, llmQuestion, chatHistory, searchResults)); - String answer = (String) output.getAnswers().get(0); + Instant start = Instant.now(); + Integer interactionSize = params.getInteractionSize(); + if (interactionSize == null || interactionSize == GenerativeQAParameters.SIZE_NULL_VALUE) { + interactionSize = DEFAULT_CHAT_HISTORY_WINDOW; + } + log.info("Using interaction size of {}", interactionSize); + List chatHistory = (conversationId == null) + ? Collections.emptyList() + : memoryClient.getInteractions(conversationId, interactionSize); + log.info("Retrieved chat history. ({})", getDuration(start)); + Integer topN = params.getContextSize(); + if (topN == null) { + topN = GenerativeQAParameters.SIZE_NULL_VALUE; + } + List searchResults = getSearchResults(response, topN); + + log.info("system_prompt: {}", systemPrompt); + log.info("user_instructions: {}", userInstructions); + start = Instant.now(); + ChatCompletionOutput output = llm + .doChatCompletion( + LlmIOUtil + .createChatCompletionInput(systemPrompt, userInstructions, llmModel, llmQuestion, chatHistory, searchResults, timeout) + ); + log.info("doChatCompletion complete. ({})", getDuration(start)); + + String answer = null; + String errorMessage = null; String interactionId = null; - if (conversationId != null) { - interactionId = memoryClient.createInteraction(conversationId, llmQuestion, PromptUtil.DEFAULT_CHAT_COMPLETION_PROMPT_TEMPLATE, answer, - GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, jsonArrayToString(searchResults)); + if (output.isErrorOccurred()) { + errorMessage = output.getErrors().get(0); + } else { + answer = (String) output.getAnswers().get(0); + + if (conversationId != null) { + start = Instant.now(); + interactionId = memoryClient + .createInteraction( + conversationId, + llmQuestion, + PromptUtil.getPromptTemplate(systemPrompt, userInstructions), + answer, + GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, + jsonArrayToString(searchResults) + ); + log.info("Created a new interaction: {} ({})", interactionId, getDuration(start)); + } } - return insertAnswer(response, answer, interactionId); + return insertAnswer(response, answer, errorMessage, interactionId); + } + + long getDuration(Instant start) { + return Duration.between(start, Instant.now()).toMillis(); } @Override @@ -116,22 +189,36 @@ public String getType() { return GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE; } - private SearchResponse insertAnswer(SearchResponse response, String answer, String interactionId) { + private SearchResponse insertAnswer(SearchResponse response, String answer, String errorMessage, String interactionId) { // TODO return the interaction id in the response. - return new GenerativeSearchResponse(answer, response.getInternalResponse(), response.getScrollId(), response.getTotalShards(), response.getSuccessfulShards(), - response.getSkippedShards(), response.getSuccessfulShards(), response.getShardFailures(), response.getClusters()); + return new GenerativeSearchResponse( + answer, + errorMessage, + response.getInternalResponse(), + response.getScrollId(), + response.getTotalShards(), + response.getSuccessfulShards(), + response.getSkippedShards(), + response.getSuccessfulShards(), + response.getShardFailures(), + response.getClusters(), + interactionId + ); } - private List getSearchResults(SearchResponse response) { + private List getSearchResults(SearchResponse response, Integer topN) { List searchResults = new ArrayList<>(); - for (SearchHit hit : response.getHits().getHits()) { - Map docSourceMap = hit.getSourceAsMap(); + SearchHit[] hits = response.getHits().getHits(); + int total = hits.length; + int end = (topN != GenerativeQAParameters.SIZE_NULL_VALUE) ? Math.min(topN, total) : total; + for (int i = 0; i < end; i++) { + Map docSourceMap = hits[i].getSourceAsMap(); for (String contextField : contextFields) { Object context = docSourceMap.get(contextField); if (context == null) { - log.error("Context " + contextField + " not found in search hit " + hit); + log.error("Context " + contextField + " not found in search hit " + hits[i]); // TODO throw a more meaningful error here? throw new RuntimeException(); } @@ -167,36 +254,68 @@ public SearchResponseProcessor create( PipelineContext pipelineContext ) throws Exception { if (this.featureFlagSupplier.getAsBoolean()) { - String modelId = ConfigurationUtils.readOptionalStringProperty(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, - tag, - config, - GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID - ); - String llmModel = ConfigurationUtils.readOptionalStringProperty(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, - tag, - config, - GenerativeQAProcessorConstants.CONFIG_NAME_LLM_MODEL - ); - List contextFields = ConfigurationUtils.readList(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, - tag, - config, - GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST - ); + String modelId = ConfigurationUtils + .readOptionalStringProperty( + GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, + tag, + config, + GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID + ); + String llmModel = ConfigurationUtils + .readOptionalStringProperty( + GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, + tag, + config, + GenerativeQAProcessorConstants.CONFIG_NAME_LLM_MODEL + ); + List contextFields = ConfigurationUtils + .readList( + GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, + tag, + config, + GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST + ); if (contextFields.isEmpty()) { - throw newConfigurationException(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, + throw newConfigurationException( + GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, tag, GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST, "required property can't be empty." ); } - log.info("model_id {}, llm_model {}, context_field_list {}", modelId, llmModel, contextFields); - return new GenerativeQAResponseProcessor(client, + String systemPrompt = ConfigurationUtils + .readOptionalStringProperty( + GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, + tag, + config, + GenerativeQAProcessorConstants.CONFIG_NAME_SYSTEM_PROMPT + ); + String userInstructions = ConfigurationUtils + .readOptionalStringProperty( + GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, + tag, + config, + GenerativeQAProcessorConstants.CONFIG_NAME_USER_INSTRUCTIONS + ); + log + .info( + "model_id {}, llm_model {}, context_field_list {}, system_prompt {}, user_instructions {}", + modelId, + llmModel, + contextFields, + systemPrompt, + userInstructions + ); + return new GenerativeQAResponseProcessor( + client, tag, description, ignoreFailure, ModelLocator.getLlm(modelId, client), llmModel, contextFields, + systemPrompt, + userInstructions, featureFlagSupplier ); } else { diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeSearchResponse.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeSearchResponse.java index 2a22902c9a..655010988a 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeSearchResponse.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeSearchResponse.java @@ -17,13 +17,14 @@ */ package org.opensearch.searchpipelines.questionanswering.generative; +import java.io.IOException; +import java.util.Objects; + import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchResponseSections; import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; - /** * This is an extension of SearchResponse that adds LLM-generated answers to search responses in a dedicated "ext" section. * @@ -33,11 +34,16 @@ public class GenerativeSearchResponse extends SearchResponse { private static final String EXT_SECTION_NAME = "ext"; private static final String GENERATIVE_QA_ANSWER_FIELD_NAME = "answer"; + private static final String GENERATIVE_QA_ERROR_FIELD_NAME = "error"; + private static final String INTERACTION_ID_FIELD_NAME = "interaction_id"; private final String answer; + private String errorMessage; + private final String interactionId; public GenerativeSearchResponse( String answer, + String errorMessage, SearchResponseSections internalResponse, String scrollId, int totalShards, @@ -45,10 +51,15 @@ public GenerativeSearchResponse( int skippedShards, long tookInMillis, ShardSearchFailure[] shardFailures, - Clusters clusters + Clusters clusters, + String interactionId ) { super(internalResponse, scrollId, totalShards, successfulShards, skippedShards, tookInMillis, shardFailures, clusters); this.answer = answer; + if (answer == null) { + this.errorMessage = Objects.requireNonNull(errorMessage, "If answer is not given, errorMessage must be provided."); + } + this.interactionId = interactionId; } @Override @@ -57,7 +68,16 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws innerToXContent(builder, params); /* start of ext */ builder.startObject(EXT_SECTION_NAME); /* start of our stuff */ builder.startObject(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE); - /* body of our stuff */ builder.field(GENERATIVE_QA_ANSWER_FIELD_NAME, this.answer); + if (answer == null) { + builder.field(GENERATIVE_QA_ERROR_FIELD_NAME, this.errorMessage); + } else { + /* body of our stuff */ + builder.field(GENERATIVE_QA_ANSWER_FIELD_NAME, this.answer); + if (this.interactionId != null) { + /* interaction id */ + builder.field(INTERACTION_ID_FIELD_NAME, this.interactionId); + } + } /* end of our stuff */ builder.endObject(); /* end of ext */ builder.endObject(); builder.endObject(); diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClient.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClient.java index 84a32b2368..eca29b3914 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClient.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClient.java @@ -17,9 +17,9 @@ */ package org.opensearch.searchpipelines.questionanswering.generative.client; -import com.google.common.base.Preconditions; -import lombok.AllArgsConstructor; -import lombok.extern.log4j.Log4j2; +import java.util.ArrayList; +import java.util.List; + import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.client.Client; @@ -35,8 +35,10 @@ import org.opensearch.ml.memory.action.conversation.GetInteractionsRequest; import org.opensearch.ml.memory.action.conversation.GetInteractionsResponse; -import java.util.ArrayList; -import java.util.List; +import com.google.common.base.Preconditions; + +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; /** * An OpenSearch client wrapper for conversational memory related calls. @@ -46,22 +48,36 @@ public class ConversationalMemoryClient { private final static Logger logger = LogManager.getLogger(); + private final static long DEFAULT_TIMEOUT_IN_MILLIS = 10_000l; private Client client; public String createConversation(String name) { - CreateConversationResponse response = client.execute(CreateConversationAction.INSTANCE, new CreateConversationRequest(name)).actionGet(); + CreateConversationResponse response = client + .execute(CreateConversationAction.INSTANCE, new CreateConversationRequest(name)) + .actionGet(DEFAULT_TIMEOUT_IN_MILLIS); log.info("createConversation: id: {}", response.getId()); return response.getId(); } - public String createInteraction(String conversationId, String input, String promptTemplate, String response, String origin, String additionalInfo) { + public String createInteraction( + String conversationId, + String input, + String promptTemplate, + String response, + String origin, + String additionalInfo + ) { Preconditions.checkNotNull(conversationId); Preconditions.checkNotNull(input); Preconditions.checkNotNull(response); - CreateInteractionResponse res = client.execute(CreateInteractionAction.INSTANCE, - new CreateInteractionRequest(conversationId, input, promptTemplate, response, origin, additionalInfo)).actionGet(); + CreateInteractionResponse res = client + .execute( + CreateInteractionAction.INSTANCE, + new CreateInteractionRequest(conversationId, input, promptTemplate, response, origin, additionalInfo) + ) + .actionGet(DEFAULT_TIMEOUT_IN_MILLIS); log.info("createInteraction: interactionId: {}", res.getId()); return res.getId(); } @@ -77,8 +93,9 @@ public List getInteractions(String conversationId, int lastN) { boolean allInteractionsFetched = false; int maxResults = lastN; do { - GetInteractionsResponse response = - client.execute(GetInteractionsAction.INSTANCE, new GetInteractionsRequest(conversationId, maxResults, from)).actionGet(); + GetInteractionsResponse response = client + .execute(GetInteractionsAction.INSTANCE, new GetInteractionsRequest(conversationId, maxResults, from)) + .actionGet(DEFAULT_TIMEOUT_IN_MILLIS); List list = response.getInteractions(); if (list != null && !CollectionUtils.isEmpty(list)) { interactions.addAll(list); @@ -97,6 +114,4 @@ public List getInteractions(String conversationId, int lastN) { return interactions; } - - } diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/MachineLearningInternalClient.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/MachineLearningInternalClient.java index 265c20a76d..c49bff254e 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/MachineLearningInternalClient.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/MachineLearningInternalClient.java @@ -5,10 +5,8 @@ package org.opensearch.searchpipelines.questionanswering.generative.client; -import com.google.common.annotations.VisibleForTesting; -import lombok.AccessLevel; -import lombok.RequiredArgsConstructor; -import lombok.experimental.FieldDefaults; +import java.util.function.Function; + import org.opensearch.action.support.PlainActionFuture; import org.opensearch.client.Client; import org.opensearch.common.action.ActionFuture; @@ -20,13 +18,19 @@ import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; -import java.util.function.Function; +import com.google.common.annotations.VisibleForTesting; + +import lombok.AccessLevel; +import lombok.RequiredArgsConstructor; +import lombok.experimental.FieldDefaults; +import lombok.extern.log4j.Log4j2; /** * An internal facing ML client adapted from org.opensearch.ml.client.MachineLearningNodeClient. */ @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @RequiredArgsConstructor +@Log4j2 public class MachineLearningInternalClient { Client client; @@ -41,11 +45,12 @@ public ActionFuture predict(String modelId, MLInput mlInput) { void predict(String modelId, MLInput mlInput, ActionListener listener) { validateMLInput(mlInput, true); - MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest.builder() - .mlInput(mlInput) - .modelId(modelId) - .dispatchTask(true) - .build(); + MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest + .builder() + .mlInput(mlInput) + .modelId(modelId) + .dispatchTask(true) + .build(); client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, getMlPredictionTaskResponseActionListener(listener)); } @@ -60,12 +65,14 @@ private ActionListener getMlPredictionTaskResponseActionListener return actionListener; } - private ActionListener wrapActionListener(final ActionListener listener, final Function recreate) { - ActionListener actionListener = ActionListener.wrap(r-> { - listener.onResponse(recreate.apply(r));; - }, e->{ - listener.onFailure(e); - }); + private ActionListener wrapActionListener( + final ActionListener listener, + final Function recreate + ) { + ActionListener actionListener = ActionListener.wrap(r -> { + listener.onResponse(recreate.apply(r)); + ; + }, e -> { listener.onFailure(e); }); return actionListener; } @@ -73,7 +80,7 @@ private void validateMLInput(MLInput mlInput, boolean requireInput) { if (mlInput == null) { throw new IllegalArgumentException("ML Input can't be null"); } - if(requireInput && mlInput.getInputDataset() == null) { + if (requireInput && mlInput.getInputDataset() == null) { throw new IllegalArgumentException("input data set can't be null"); } } diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilder.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilder.java index 8a6ee8cc65..8cbc719a66 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilder.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilder.java @@ -17,17 +17,18 @@ */ package org.opensearch.searchpipelines.questionanswering.generative.ext; -import lombok.Getter; -import lombok.NoArgsConstructor; -import lombok.Setter; +import java.io.IOException; +import java.util.Objects; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.search.SearchExtBuilder; -import java.io.IOException; -import java.util.Objects; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; /** * This is the extension builder for generative QA search pipelines. diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamUtil.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamUtil.java index 52da6daa02..c0fd558879 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamUtil.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamUtil.java @@ -17,12 +17,13 @@ */ package org.opensearch.searchpipelines.questionanswering.generative.ext; -import lombok.AccessLevel; -import lombok.NoArgsConstructor; +import java.util.Optional; + import org.opensearch.action.search.SearchRequest; import org.opensearch.search.SearchExtBuilder; -import java.util.Optional; +import lombok.AccessLevel; +import lombok.NoArgsConstructor; /** * Utility class for extracting generative QA search pipeline parameters from search requests. @@ -33,7 +34,12 @@ public class GenerativeQAParamUtil { public static GenerativeQAParameters getGenerativeQAParameters(SearchRequest request) { GenerativeQAParamExtBuilder builder = null; if (request.source() != null && request.source().ext() != null && !request.source().ext().isEmpty()) { - Optional b = request.source().ext().stream().filter(bldr -> GenerativeQAParamExtBuilder.PARAMETER_NAME.equals(bldr.getWriteableName())).findFirst(); + Optional b = request + .source() + .ext() + .stream() + .filter(bldr -> GenerativeQAParamExtBuilder.PARAMETER_NAME.equals(bldr.getWriteableName())) + .findFirst(); if (b.isPresent()) { builder = (GenerativeQAParamExtBuilder) b.get(); } diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java index 04d2b53674..2710b26a57 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java @@ -17,42 +17,69 @@ */ package org.opensearch.searchpipelines.questionanswering.generative.ext; -import com.google.common.base.Preconditions; -import lombok.AllArgsConstructor; -import lombok.Getter; -import lombok.NoArgsConstructor; -import lombok.Setter; +import java.io.IOException; +import java.util.Objects; + +import org.opensearch.core.ParseField; +import org.opensearch.core.common.Strings; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; -import org.opensearch.core.ParseField; import org.opensearch.core.xcontent.ObjectParser; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import java.io.IOException; -import java.util.Objects; +import com.google.common.base.Preconditions; + +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; /** * Defines parameters for generative QA search pipelines. * */ -@AllArgsConstructor + @NoArgsConstructor public class GenerativeQAParameters implements Writeable, ToXContentObject { private static final ObjectParser PARSER; + // Optional parameter; if provided, conversational memory will be used for RAG + // and the current interaction will be saved in the conversation referenced by this id. private static final ParseField CONVERSATION_ID = new ParseField("conversation_id"); + + // Optional parameter; if an LLM model is not set at the search pipeline level, one must be + // provided at the search request level. private static final ParseField LLM_MODEL = new ParseField("llm_model"); + + // Required parameter; this is sent to LLMs as part of the user prompt. + // TODO support question rewriting when chat history is not used (conversation_id is not provided). private static final ParseField LLM_QUESTION = new ParseField("llm_question"); + // Optional parameter; this parameter controls the number of search results ("contexts") to + // include in the user prompt. + private static final ParseField CONTEXT_SIZE = new ParseField("context_size"); + + // Optional parameter; this parameter controls the number of the interactions to include + // in the user prompt. + private static final ParseField INTERACTION_SIZE = new ParseField("interaction_size"); + + // Optional parameter; this parameter controls how long the search pipeline waits for a response + // from a remote inference endpoint before timing out the request. + private static final ParseField TIMEOUT = new ParseField("timeout"); + + public static final int SIZE_NULL_VALUE = -1; + static { PARSER = new ObjectParser<>("generative_qa_parameters", GenerativeQAParameters::new); PARSER.declareString(GenerativeQAParameters::setConversationId, CONVERSATION_ID); PARSER.declareString(GenerativeQAParameters::setLlmModel, LLM_MODEL); PARSER.declareString(GenerativeQAParameters::setLlmQuestion, LLM_QUESTION); + PARSER.declareIntOrNull(GenerativeQAParameters::setContextSize, SIZE_NULL_VALUE, CONTEXT_SIZE); + PARSER.declareIntOrNull(GenerativeQAParameters::setInteractionSize, SIZE_NULL_VALUE, INTERACTION_SIZE); + PARSER.declareIntOrNull(GenerativeQAParameters::setTimeout, SIZE_NULL_VALUE, TIMEOUT); } @Setter @@ -67,17 +94,56 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject { @Getter private String llmQuestion; + @Setter + @Getter + private Integer contextSize; + + @Setter + @Getter + private Integer interactionSize; + + @Setter + @Getter + private Integer timeout; + + public GenerativeQAParameters( + String conversationId, + String llmModel, + String llmQuestion, + Integer contextSize, + Integer interactionSize, + Integer timeout + ) { + this.conversationId = conversationId; + this.llmModel = llmModel; + + // TODO: keep this requirement until we can extract the question from the query or from the request processor parameters + // for question rewriting. + Preconditions.checkArgument(!Strings.isNullOrEmpty(llmQuestion), LLM_QUESTION.getPreferredName() + " must be provided."); + this.llmQuestion = llmQuestion; + this.contextSize = (contextSize == null) ? SIZE_NULL_VALUE : contextSize; + this.interactionSize = (interactionSize == null) ? SIZE_NULL_VALUE : interactionSize; + this.timeout = (timeout == null) ? SIZE_NULL_VALUE : timeout; + } + public GenerativeQAParameters(StreamInput input) throws IOException { this.conversationId = input.readOptionalString(); this.llmModel = input.readOptionalString(); this.llmQuestion = input.readString(); + this.contextSize = input.readInt(); + this.interactionSize = input.readInt(); + this.timeout = input.readInt(); } @Override public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException { - return xContentBuilder.field(CONVERSATION_ID.getPreferredName(), this.conversationId) + return xContentBuilder + .field(CONVERSATION_ID.getPreferredName(), this.conversationId) .field(LLM_MODEL.getPreferredName(), this.llmModel) - .field(LLM_QUESTION.getPreferredName(), this.llmQuestion); + .field(LLM_QUESTION.getPreferredName(), this.llmQuestion) + .field(CONTEXT_SIZE.getPreferredName(), this.contextSize) + .field(INTERACTION_SIZE.getPreferredName(), this.interactionSize) + .field(TIMEOUT.getPreferredName(), this.timeout); } @Override @@ -87,6 +153,9 @@ public void writeTo(StreamOutput out) throws IOException { Preconditions.checkNotNull(llmQuestion, "llm_question must not be null."); out.writeString(llmQuestion); + out.writeInt(contextSize); + out.writeInt(interactionSize); + out.writeInt(timeout); } public static GenerativeQAParameters parse(XContentParser parser) throws IOException { @@ -105,6 +174,9 @@ public boolean equals(Object o) { GenerativeQAParameters other = (GenerativeQAParameters) o; return Objects.equals(this.conversationId, other.getConversationId()) && Objects.equals(this.llmModel, other.getLlmModel()) - && Objects.equals(this.llmQuestion, other.getLlmQuestion()); + && Objects.equals(this.llmQuestion, other.getLlmQuestion()) + && (this.contextSize == other.getContextSize()) + && (this.interactionSize == other.getInteractionSize()) + && (this.timeout == other.getTimeout()); } } diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java index faf80b9d7a..85e1173875 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java @@ -17,13 +17,14 @@ */ package org.opensearch.searchpipelines.questionanswering.generative.llm; +import java.util.List; + +import org.opensearch.ml.common.conversation.Interaction; + import lombok.AllArgsConstructor; import lombok.Getter; import lombok.Setter; import lombok.extern.log4j.Log4j2; -import org.opensearch.ml.common.conversation.Interaction; - -import java.util.List; /** * Input for LLMs via HttpConnector @@ -38,4 +39,7 @@ public class ChatCompletionInput { private String question; private List chatHistory; private List contexts; + private int timeoutInSeconds; + private String systemPrompt; + private String userInstructions; } diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionOutput.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionOutput.java index b9bc891a7a..08bda6cefd 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionOutput.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionOutput.java @@ -17,21 +17,43 @@ */ package org.opensearch.searchpipelines.questionanswering.generative.llm; -import lombok.AllArgsConstructor; +import java.util.List; + import lombok.Getter; import lombok.Setter; import lombok.extern.log4j.Log4j2; -import java.util.List; - /** * Output from LLMs via HttpConnector */ @Log4j2 @Getter @Setter -@AllArgsConstructor public class ChatCompletionOutput { private List answers; + private List errors; + + private boolean errorOccurred; + + public ChatCompletionOutput(List answers, List errors) { + + if (answers == null && errors == null) { + throw new IllegalArgumentException("answers and errors can't both be null."); + } + + if (answers == null) { + if (errors.isEmpty()) { + throw new IllegalArgumentException("If answers is not provided, one or more errors must be provided."); + } + this.errorOccurred = true; + } else if (errors == null) { + if (answers.isEmpty()) { + throw new IllegalArgumentException("If errors is not provided, one or more answers must be provided."); + } + } + + this.answers = answers; + this.errors = errors; + } } diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java index 58a3cad64c..beef67b9e9 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java @@ -17,8 +17,12 @@ */ package org.opensearch.searchpipelines.questionanswering.generative.llm; -import com.google.common.annotations.VisibleForTesting; -import lombok.extern.log4j.Log4j2; +import static com.google.common.base.Preconditions.checkNotNull; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + import org.opensearch.client.Client; import org.opensearch.common.action.ActionFuture; import org.opensearch.ml.common.FunctionName; @@ -30,11 +34,9 @@ import org.opensearch.searchpipelines.questionanswering.generative.client.MachineLearningInternalClient; import org.opensearch.searchpipelines.questionanswering.generative.prompt.PromptUtil; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import com.google.common.annotations.VisibleForTesting; -import static com.google.common.base.Preconditions.checkNotNull; +import lombok.extern.log4j.Log4j2; /** * Wrapper for talking to LLMs via OpenSearch HttpConnector. @@ -48,6 +50,7 @@ public class DefaultLlmImpl implements Llm { private static final String CONNECTOR_OUTPUT_MESSAGE = "message"; private static final String CONNECTOR_OUTPUT_MESSAGE_ROLE = "role"; private static final String CONNECTOR_OUTPUT_MESSAGE_CONTENT = "content"; + private static final String CONNECTOR_OUTPUT_ERROR = "error"; private final String openSearchModelId; @@ -75,26 +78,40 @@ public ChatCompletionOutput doChatCompletion(ChatCompletionInput chatCompletionI Map inputParameters = new HashMap<>(); inputParameters.put(CONNECTOR_INPUT_PARAMETER_MODEL, chatCompletionInput.getModel()); - String messages = PromptUtil.getChatCompletionPrompt(chatCompletionInput.getQuestion(), chatCompletionInput.getChatHistory(), chatCompletionInput.getContexts()); + String messages = PromptUtil + .getChatCompletionPrompt( + chatCompletionInput.getSystemPrompt(), + chatCompletionInput.getUserInstructions(), + chatCompletionInput.getQuestion(), + chatCompletionInput.getChatHistory(), + chatCompletionInput.getContexts() + ); inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages); log.info("Messages to LLM: {}", messages); MLInputDataset dataset = RemoteInferenceInputDataSet.builder().parameters(inputParameters).build(); MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataset).build(); ActionFuture future = mlClient.predict(this.openSearchModelId, mlInput); - ModelTensorOutput modelOutput = (ModelTensorOutput) future.actionGet(); + ModelTensorOutput modelOutput = (ModelTensorOutput) future.actionGet(chatCompletionInput.getTimeoutInSeconds() * 1000); // Response from a remote model Map dataAsMap = modelOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap(); log.info("dataAsMap: {}", dataAsMap.toString()); - // TODO dataAsMap can be null or can contain information such as throttling. Handle non-happy cases. + // TODO dataAsMap can be null or can contain information such as throttling. Handle non-happy cases. List choices = (List) dataAsMap.get(CONNECTOR_OUTPUT_CHOICES); - Map firstChoiceMap = (Map) choices.get(0); - log.info("Choices: {}", firstChoiceMap.toString()); - Map message = (Map) firstChoiceMap.get(CONNECTOR_OUTPUT_MESSAGE); - log.info("role: {}, content: {}", message.get(CONNECTOR_OUTPUT_MESSAGE_ROLE), message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT)); - - return new ChatCompletionOutput(List.of(message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT))); + List answers = null; + List errors = null; + if (choices == null) { + Map error = (Map) dataAsMap.get(CONNECTOR_OUTPUT_ERROR); + errors = List.of((String) error.get(CONNECTOR_OUTPUT_MESSAGE)); + } else { + Map firstChoiceMap = (Map) choices.get(0); + log.info("Choices: {}", firstChoiceMap.toString()); + Map message = (Map) firstChoiceMap.get(CONNECTOR_OUTPUT_MESSAGE); + log.info("role: {}, content: {}", message.get(CONNECTOR_OUTPUT_MESSAGE_ROLE), message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT)); + answers = List.of(message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT)); + } + return new ChatCompletionOutput(answers, errors); } } diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java index 5d007420f7..fb95ed63bf 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java @@ -17,19 +17,47 @@ */ package org.opensearch.searchpipelines.questionanswering.generative.llm; -import org.opensearch.ml.common.conversation.Interaction; - import java.util.List; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.searchpipelines.questionanswering.generative.prompt.PromptUtil; + /** * Helper class for creating inputs and outputs for different implementations of LLMs. */ public class LlmIOUtil { - public static ChatCompletionInput createChatCompletionInput(String llmModel, String question, List chatHistory, List contexts) { + public static ChatCompletionInput createChatCompletionInput( + String llmModel, + String question, + List chatHistory, + List contexts, + int timeoutInSeconds + ) { // TODO pick the right subclass based on the modelId. - return new ChatCompletionInput(llmModel, question, chatHistory, contexts); + return createChatCompletionInput( + PromptUtil.DEFAULT_SYSTEM_PROMPT, + null, + llmModel, + question, + chatHistory, + contexts, + timeoutInSeconds + ); + } + + public static ChatCompletionInput createChatCompletionInput( + String systemPrompt, + String userInstructions, + String llmModel, + String question, + List chatHistory, + List contexts, + int timeoutInSeconds + ) { + + return new ChatCompletionInput(llmModel, question, chatHistory, contexts, timeoutInSeconds, systemPrompt, userInstructions); } } diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ModelLocator.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ModelLocator.java index 1b43574374..f9e3d5b811 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ModelLocator.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ModelLocator.java @@ -17,9 +17,10 @@ */ package org.opensearch.searchpipelines.questionanswering.generative.llm; +import org.opensearch.client.Client; + import lombok.AccessLevel; import lombok.NoArgsConstructor; -import org.opensearch.client.Client; /** * Helper class for wiring LLMs based on the model ID. @@ -30,6 +31,7 @@ public class ModelLocator { public static Llm getLlm(String modelId, Client client) { + return new DefaultLlmImpl(modelId, client); } diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java index 10e5a924c6..9c57ffbf0f 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java @@ -17,19 +17,22 @@ */ package org.opensearch.searchpipelines.questionanswering.generative.prompt; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import org.apache.commons.text.StringEscapeUtils; +import org.opensearch.core.common.Strings; +import org.opensearch.ml.common.conversation.Interaction; + import com.google.common.annotations.VisibleForTesting; -import com.google.gson.Gson; import com.google.gson.JsonArray; import com.google.gson.JsonObject; import com.google.gson.JsonPrimitive; + import lombok.AccessLevel; import lombok.Getter; import lombok.NoArgsConstructor; -import org.apache.commons.text.StringEscapeUtils; -import org.opensearch.ml.common.conversation.Interaction; - -import java.util.ArrayList; -import java.util.List; /** * A utility class for producing prompts for LLMs. @@ -40,7 +43,7 @@ @NoArgsConstructor(access = AccessLevel.PRIVATE) public class PromptUtil { - public static final String DEFAULT_CHAT_COMPLETION_PROMPT_TEMPLATE = + public static final String DEFAULT_SYSTEM_PROMPT = "Generate a concise and informative answer in less than 100 words for the given question, taking into context: " + "- An enumerated list of search results" + "- A rephrase of the question that was used to generate the search results" @@ -56,7 +59,17 @@ public static String getQuestionRephrasingPrompt(String originalQuestion, List chatHistory, List contexts) { - return buildMessageParameter(question, chatHistory, contexts); + return getChatCompletionPrompt(DEFAULT_SYSTEM_PROMPT, null, question, chatHistory, contexts); + } + + public static String getChatCompletionPrompt( + String systemPrompt, + String userInstructions, + String question, + List chatHistory, + List contexts + ) { + return buildMessageParameter(systemPrompt, userInstructions, question, chatHistory, contexts); } enum ChatRole { @@ -75,17 +88,32 @@ enum ChatRole { } @VisibleForTesting - static String buildMessageParameter(String question, List chatHistory, List contexts) { + static String buildMessageParameter( + String systemPrompt, + String userInstructions, + String question, + List chatHistory, + List contexts + ) { // TODO better prompt template management is needed here. + if (Strings.isNullOrEmpty(systemPrompt) && Strings.isNullOrEmpty(userInstructions)) { + systemPrompt = DEFAULT_SYSTEM_PROMPT; + } + JsonArray messageArray = new JsonArray(); - messageArray.add(new Message(ChatRole.USER, DEFAULT_CHAT_COMPLETION_PROMPT_TEMPLATE).toJson()); - for (String result : contexts) { - messageArray.add(new Message(ChatRole.USER, "SEARCH RESULT: " + result).toJson()); + + messageArray.addAll(getPromptTemplateAsJsonArray(systemPrompt, userInstructions)); + for (int i = 0; i < contexts.size(); i++) { + messageArray.add(new Message(ChatRole.USER, "SEARCH RESULT " + (i + 1) + ": " + contexts.get(i)).toJson()); } if (!chatHistory.isEmpty()) { - Messages.fromInteractions(chatHistory).getMessages().forEach(m -> messageArray.add(m.toJson())); + // The oldest interaction first + // Collections.reverse(chatHistory); + List messages = Messages.fromInteractions(chatHistory).getMessages(); + Collections.reverse(messages); + messages.forEach(m -> messageArray.add(m.toJson())); } messageArray.add(new Message(ChatRole.USER, "QUESTION: " + question).toJson()); messageArray.add(new Message(ChatRole.USER, "ANSWER:").toJson()); @@ -93,14 +121,27 @@ static String buildMessageParameter(String question, List chatHisto return messageArray.toString(); } - private static Gson gson = new Gson(); + public static String getPromptTemplate(String systemPrompt, String userInstructions) { + return getPromptTemplateAsJsonArray(systemPrompt, userInstructions).toString(); + } + + static JsonArray getPromptTemplateAsJsonArray(String systemPrompt, String userInstructions) { + JsonArray messageArray = new JsonArray(); + + if (!Strings.isNullOrEmpty(systemPrompt)) { + messageArray.add(new Message(ChatRole.SYSTEM, systemPrompt).toJson()); + } + if (!Strings.isNullOrEmpty(userInstructions)) { + messageArray.add(new Message(ChatRole.USER, userInstructions).toJson()); + } + return messageArray; + } @Getter static class Messages { @Getter private List messages = new ArrayList<>(); - //private JsonArray jsonArray = new JsonArray(); public Messages(final List messages) { addMessages(messages); @@ -148,6 +189,7 @@ public void setChatRole(ChatRole chatRole) { json.remove(MESSAGE_FIELD_ROLE); json.add(MESSAGE_FIELD_ROLE, new JsonPrimitive(chatRole.getName())); } + public void setContent(String content) { this.content = StringEscapeUtils.escapeJson(content); json.remove(MESSAGE_FIELD_CONTENT); diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAParamUtilTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAParamUtilTests.java index cbd5122371..129786e40c 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAParamUtilTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAParamUtilTests.java @@ -1,5 +1,10 @@ package org.opensearch.searchpipelines.questionanswering.generative; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.List; + import org.opensearch.action.search.SearchRequest; import org.opensearch.search.SearchExtBuilder; import org.opensearch.search.builder.SearchSourceBuilder; @@ -8,11 +13,6 @@ import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParameters; import org.opensearch.test.OpenSearchTestCase; -import java.util.List; - -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - public class GenerativeQAParamUtilTests extends OpenSearchTestCase { public void testGenerativeQAParametersMissingParams() { diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQARequestProcessorTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQARequestProcessorTests.java index a83ccd1767..69301196ac 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQARequestProcessorTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQARequestProcessorTests.java @@ -17,6 +17,12 @@ */ package org.opensearch.searchpipelines.questionanswering.generative; +import static org.mockito.Mockito.mock; + +import java.util.HashMap; +import java.util.Map; +import java.util.function.BooleanSupplier; + import org.junit.Rule; import org.junit.rules.ExpectedException; import org.opensearch.action.search.SearchRequest; @@ -25,12 +31,6 @@ import org.opensearch.search.pipeline.SearchRequestProcessor; import org.opensearch.test.OpenSearchTestCase; -import java.util.HashMap; -import java.util.Map; -import java.util.function.BooleanSupplier; - -import static org.mockito.Mockito.mock; - public class GenerativeQARequestProcessorTests extends OpenSearchTestCase { private BooleanSupplier alwaysOn = () -> true; @@ -42,8 +42,8 @@ public void testProcessorFactory() throws Exception { Map config = new HashMap<>(); config.put("model_id", "foo"); - SearchRequestProcessor processor = - new GenerativeQARequestProcessor.Factory(alwaysOn).create(null, "tag", "desc", true, config, null); + SearchRequestProcessor processor = new GenerativeQARequestProcessor.Factory(alwaysOn) + .create(null, "tag", "desc", true, config, null); assertTrue(processor instanceof GenerativeQARequestProcessor); } @@ -65,16 +65,16 @@ public void testProcessorFactoryFeatureFlagDisabled() throws Exception { exceptionRule.expectMessage(GenerativeQAProcessorConstants.FEATURE_NOT_ENABLED_ERROR_MSG); Map config = new HashMap<>(); config.put("model_id", "foo"); - Processor processor = - new GenerativeQARequestProcessor.Factory(()->false).create(null, "tag", "desc", true, config, null); + Processor processor = new GenerativeQARequestProcessor.Factory(() -> false).create(null, "tag", "desc", true, config, null); } // Only to be used for the following test case. private boolean featureFlag001 = false; + public void testProcessorFeatureFlagOffOnOff() throws Exception { Map config = new HashMap<>(); config.put("model_id", "foo"); - Processor.Factory factory = new GenerativeQARequestProcessor.Factory(()->featureFlag001); + Processor.Factory factory = new GenerativeQARequestProcessor.Factory(() -> featureFlag001); boolean firstExceptionThrown = false; try { factory.create(null, "tag", "desc", true, config, null); diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java index af8b1d9929..b62f0ab38f 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java @@ -17,6 +17,18 @@ */ package org.opensearch.searchpipelines.questionanswering.generative; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.time.Instant; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BooleanSupplier; + import org.junit.Rule; import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; @@ -24,8 +36,8 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchResponseSections; import org.opensearch.client.Client; -import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.conversation.Interaction; import org.opensearch.ml.common.exception.MLException; @@ -33,26 +45,14 @@ import org.opensearch.search.SearchHits; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.pipeline.Processor; +import org.opensearch.searchpipelines.questionanswering.generative.client.ConversationalMemoryClient; import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamExtBuilder; import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParameters; -import org.opensearch.searchpipelines.questionanswering.generative.client.ConversationalMemoryClient; import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionInput; import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionOutput; import org.opensearch.searchpipelines.questionanswering.generative.llm.Llm; import org.opensearch.test.OpenSearchTestCase; -import java.time.Instant; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.function.BooleanSupplier; - -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyInt; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - public class GenerativeQAResponseProcessorTests extends OpenSearchTestCase { private BooleanSupplier alwaysOn = () -> true; @@ -66,15 +66,28 @@ public void testProcessorFactoryRemoteModel() throws Exception { config.put(GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID, "xyz"); config.put(GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST, List.of("text")); - GenerativeQAResponseProcessor processor = (GenerativeQAResponseProcessor) new GenerativeQAResponseProcessor.Factory(client, alwaysOn) - .create(null, "tag", "desc", true, config, null); + GenerativeQAResponseProcessor processor = (GenerativeQAResponseProcessor) new GenerativeQAResponseProcessor.Factory( + client, + alwaysOn + ).create(null, "tag", "desc", true, config, null); assertNotNull(processor); } public void testGetType() { Client client = mock(Client.class); Llm llm = mock(Llm.class); - GenerativeQAResponseProcessor processor = new GenerativeQAResponseProcessor(client, null, null, false, llm, "foo", List.of("text"), alwaysOn); + GenerativeQAResponseProcessor processor = new GenerativeQAResponseProcessor( + client, + null, + null, + false, + llm, + "foo", + List.of("text"), + "system_prompt", + "user_instructions", + alwaysOn + ); assertEquals(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, processor.getType()); } @@ -84,12 +97,14 @@ public void testProcessResponseNoSearchHits() throws Exception { config.put(GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID, "dummy-model"); config.put(GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST, List.of("text")); - GenerativeQAResponseProcessor processor = (GenerativeQAResponseProcessor) new GenerativeQAResponseProcessor.Factory(client, alwaysOn) - .create(null, "tag", "desc", true, config, null); + GenerativeQAResponseProcessor processor = (GenerativeQAResponseProcessor) new GenerativeQAResponseProcessor.Factory( + client, + alwaysOn + ).create(null, "tag", "desc", true, config, null); SearchRequest request = new SearchRequest(); // mock(SearchRequest.class); SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); // mock(SearchSourceBuilder.class); - GenerativeQAParameters params = new GenerativeQAParameters("12345", "llm_model", "You are kind."); + GenerativeQAParameters params = new GenerativeQAParameters("12345", "llm_model", "You are kind.", null, null, null); GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); extBuilder.setParams(params); request.source(sourceBuilder); @@ -127,16 +142,19 @@ public void testProcessResponse() throws Exception { config.put(GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID, "dummy-model"); config.put(GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST, List.of("text")); - GenerativeQAResponseProcessor processor = (GenerativeQAResponseProcessor) new GenerativeQAResponseProcessor.Factory(client, alwaysOn) - .create(null, "tag", "desc", true, config, null); + GenerativeQAResponseProcessor processor = (GenerativeQAResponseProcessor) new GenerativeQAResponseProcessor.Factory( + client, + alwaysOn + ).create(null, "tag", "desc", true, config, null); ConversationalMemoryClient memoryClient = mock(ConversationalMemoryClient.class); - when(memoryClient.getInteractions(any(), anyInt())).thenReturn(List.of(new Interaction("0", Instant.now(), "1", "question", "", "answer", "foo", "{}"))); + when(memoryClient.getInteractions(any(), anyInt())) + .thenReturn(List.of(new Interaction("0", Instant.now(), "1", "question", "", "answer", "foo", "{}"))); processor.setMemoryClient(memoryClient); SearchRequest request = new SearchRequest(); SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); - GenerativeQAParameters params = new GenerativeQAParameters("12345", "llm_model", "You are kind."); + GenerativeQAParameters params = new GenerativeQAParameters("12345", "llm_model", "You are kind.", null, null, null); GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); extBuilder.setParams(params); request.source(sourceBuilder); @@ -145,7 +163,8 @@ public void testProcessResponse() throws Exception { int numHits = 10; SearchHit[] hitsArray = new SearchHit[numHits]; for (int i = 0; i < numHits; i++) { - XContentBuilder sourceContent = JsonXContent.contentBuilder() + XContentBuilder sourceContent = JsonXContent + .contentBuilder() .startObject() .field("_id", String.valueOf(i)) .field("text", "passage" + i) @@ -173,6 +192,68 @@ public void testProcessResponse() throws Exception { List passages = ((ChatCompletionInput) input).getContexts(); assertEquals("passage0", passages.get(0)); assertEquals("passage1", passages.get(1)); + assertEquals(numHits, passages.size()); + assertTrue(res instanceof GenerativeSearchResponse); + } + + public void testProcessResponseSmallerContextSize() throws Exception { + Client client = mock(Client.class); + Map config = new HashMap<>(); + config.put(GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID, "dummy-model"); + config.put(GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST, List.of("text")); + + GenerativeQAResponseProcessor processor = (GenerativeQAResponseProcessor) new GenerativeQAResponseProcessor.Factory( + client, + alwaysOn + ).create(null, "tag", "desc", true, config, null); + + ConversationalMemoryClient memoryClient = mock(ConversationalMemoryClient.class); + when(memoryClient.getInteractions(any(), anyInt())) + .thenReturn(List.of(new Interaction("0", Instant.now(), "1", "question", "", "answer", "foo", "{}"))); + processor.setMemoryClient(memoryClient); + + SearchRequest request = new SearchRequest(); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + int contextSize = 5; + GenerativeQAParameters params = new GenerativeQAParameters("12345", "llm_model", "You are kind.", contextSize, null, null); + GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); + extBuilder.setParams(params); + request.source(sourceBuilder); + sourceBuilder.ext(List.of(extBuilder)); + + int numHits = 10; + SearchHit[] hitsArray = new SearchHit[numHits]; + for (int i = 0; i < numHits; i++) { + XContentBuilder sourceContent = JsonXContent + .contentBuilder() + .startObject() + .field("_id", String.valueOf(i)) + .field("text", "passage" + i) + .field("title", "This is the title for document " + i) + .endObject(); + hitsArray[i] = new SearchHit(i, "doc" + i, Map.of(), Map.of()); + hitsArray[i].sourceRef(BytesReference.bytes(sourceContent)); + } + + SearchHits searchHits = new SearchHits(hitsArray, null, 1.0f); + SearchResponseSections internal = new SearchResponseSections(searchHits, null, null, false, false, null, 0); + SearchResponse response = new SearchResponse(internal, null, 1, 1, 0, 1, null, null, null); + + Llm llm = mock(Llm.class); + ChatCompletionOutput output = mock(ChatCompletionOutput.class); + when(llm.doChatCompletion(any())).thenReturn(output); + when(output.getAnswers()).thenReturn(List.of("foo")); + processor.setLlm(llm); + + ArgumentCaptor captor = ArgumentCaptor.forClass(ChatCompletionInput.class); + SearchResponse res = processor.processResponse(request, response); + verify(llm).doChatCompletion(captor.capture()); + ChatCompletionInput input = captor.getValue(); + assertTrue(input instanceof ChatCompletionInput); + List passages = ((ChatCompletionInput) input).getContexts(); + assertEquals("passage0", passages.get(0)); + assertEquals("passage1", passages.get(1)); + assertEquals(contextSize, passages.size()); assertTrue(res instanceof GenerativeSearchResponse); } @@ -182,16 +263,19 @@ public void testProcessResponseMissingContextField() throws Exception { config.put(GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID, "dummy-model"); config.put(GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST, List.of("text")); - GenerativeQAResponseProcessor processor = (GenerativeQAResponseProcessor) new GenerativeQAResponseProcessor.Factory(client, alwaysOn) - .create(null, "tag", "desc", true, config, null); + GenerativeQAResponseProcessor processor = (GenerativeQAResponseProcessor) new GenerativeQAResponseProcessor.Factory( + client, + alwaysOn + ).create(null, "tag", "desc", true, config, null); ConversationalMemoryClient memoryClient = mock(ConversationalMemoryClient.class); - when(memoryClient.getInteractions(any(), anyInt())).thenReturn(List.of(new Interaction("0", Instant.now(), "1", "question", "", "answer", "foo", "{}"))); + when(memoryClient.getInteractions(any(), anyInt())) + .thenReturn(List.of(new Interaction("0", Instant.now(), "1", "question", "", "answer", "foo", "{}"))); processor.setMemoryClient(memoryClient); SearchRequest request = new SearchRequest(); SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); - GenerativeQAParameters params = new GenerativeQAParameters("12345", "llm_model", "You are kind."); + GenerativeQAParameters params = new GenerativeQAParameters("12345", "llm_model", "You are kind.", null, null, null); GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); extBuilder.setParams(params); request.source(sourceBuilder); @@ -200,10 +284,11 @@ public void testProcessResponseMissingContextField() throws Exception { int numHits = 10; SearchHit[] hitsArray = new SearchHit[numHits]; for (int i = 0; i < numHits; i++) { - XContentBuilder sourceContent = JsonXContent.contentBuilder() + XContentBuilder sourceContent = JsonXContent + .contentBuilder() .startObject() .field("_id", String.valueOf(i)) - //.field("text", "passage" + i) + // .field("text", "passage" + i) .field("title", "This is the title for document " + i) .endObject(); hitsArray[i] = new SearchHit(i, "doc" + i, Map.of(), Map.of()); @@ -241,13 +326,13 @@ public void testProcessorFactoryFeatureDisabled() throws Exception { config.put(GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID, "xyz"); config.put(GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST, List.of("text")); - Processor processor = - new GenerativeQAResponseProcessor.Factory(client, () -> false) - .create(null, "tag", "desc", true, config, null); + Processor processor = new GenerativeQAResponseProcessor.Factory(client, () -> false) + .create(null, "tag", "desc", true, config, null); } // Use this only for the following test case. private boolean featureEnabled001; + public void testProcessorFeatureOffOnOff() throws Exception { Client client = mock(Client.class); Map config = new HashMap<>(); diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeSearchResponseTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeSearchResponseTests.java index cead38b0a0..fe52d5df58 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeSearchResponseTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeSearchResponseTests.java @@ -17,6 +17,15 @@ */ package org.opensearch.searchpipelines.questionanswering.generative; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.io.OutputStream; + +import org.junit.Rule; +import org.junit.rules.ExpectedException; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchResponseSections; import org.opensearch.action.search.ShardSearchFailure; @@ -28,20 +37,35 @@ import org.opensearch.search.SearchHits; import org.opensearch.test.OpenSearchTestCase; -import java.io.IOException; -import java.io.OutputStream; - -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - public class GenerativeSearchResponseTests extends OpenSearchTestCase { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + public void testToXContent() throws IOException { String answer = "answer"; - SearchResponseSections internal = new SearchResponseSections(new SearchHits(new SearchHit[0], null, 0), null, null, false, false, null, 0); - GenerativeSearchResponse searchResponse = new GenerativeSearchResponse(answer, internal, null, 0, 0, 0, 0, new ShardSearchFailure[0], - SearchResponse.Clusters.EMPTY); + SearchResponseSections internal = new SearchResponseSections( + new SearchHits(new SearchHit[0], null, 0), + null, + null, + false, + false, + null, + 0 + ); + GenerativeSearchResponse searchResponse = new GenerativeSearchResponse( + answer, + null, + internal, + null, + 0, + 0, + 0, + 0, + new ShardSearchFailure[0], + SearchResponse.Clusters.EMPTY, + "iid" + ); XContent xc = mock(XContent.class); OutputStream os = mock(OutputStream.class); XContentGenerator generator = mock(XContentGenerator.class); @@ -50,4 +74,65 @@ public void testToXContent() throws IOException { XContentBuilder actual = searchResponse.toXContent(builder, ToXContent.EMPTY_PARAMS); assertNotNull(actual); } + + public void testToXContentWithError() throws IOException { + String error = "error"; + SearchResponseSections internal = new SearchResponseSections( + new SearchHits(new SearchHit[0], null, 0), + null, + null, + false, + false, + null, + 0 + ); + GenerativeSearchResponse searchResponse = new GenerativeSearchResponse( + null, + error, + internal, + null, + 0, + 0, + 0, + 0, + new ShardSearchFailure[0], + SearchResponse.Clusters.EMPTY, + "iid" + ); + XContent xc = mock(XContent.class); + OutputStream os = mock(OutputStream.class); + XContentGenerator generator = mock(XContentGenerator.class); + when(xc.createGenerator(any(), any(), any())).thenReturn(generator); + XContentBuilder builder = new XContentBuilder(xc, os); + XContentBuilder actual = searchResponse.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertNotNull(actual); + } + + public void testInputValidation() { + exceptionRule.expect(NullPointerException.class); + exceptionRule.expectMessage("If answer is not given, errorMessage must be provided."); + SearchResponseSections internal = new SearchResponseSections( + new SearchHits(new SearchHit[0], null, 0), + null, + null, + false, + false, + null, + 0 + ); + GenerativeSearchResponse searchResponse = new GenerativeSearchResponse( + null, + null, + internal, + null, + 0, + 0, + 0, + 0, + new ShardSearchFailure[0], + SearchResponse.Clusters.EMPTY, + "iid" + ); + } + } diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClientTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClientTests.java index 67038d93cd..7241ba40ed 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClientTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClientTests.java @@ -17,6 +17,14 @@ */ package org.opensearch.searchpipelines.questionanswering.generative.client; +import static org.mockito.Mockito.*; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; +import java.util.stream.IntStream; + import org.mockito.ArgumentCaptor; import org.opensearch.client.Client; import org.opensearch.common.action.ActionFuture; @@ -31,14 +39,6 @@ import org.opensearch.ml.memory.action.conversation.GetInteractionsResponse; import org.opensearch.test.OpenSearchTestCase; -import java.time.Instant; -import java.util.ArrayList; -import java.util.List; -import java.util.UUID; -import java.util.stream.IntStream; - -import static org.mockito.Mockito.*; - public class ConversationalMemoryClientTests extends OpenSearchTestCase { public void testCreateConversation() { @@ -48,7 +48,7 @@ public void testCreateConversation() { String conversationId = UUID.randomUUID().toString(); CreateConversationResponse response = new CreateConversationResponse(conversationId); ActionFuture future = mock(ActionFuture.class); - when(future.actionGet()).thenReturn(response); + when(future.actionGet(anyLong())).thenReturn(response); when(client.execute(eq(CreateConversationAction.INSTANCE), any())).thenReturn(future); String name = "foo"; String actual = memoryClient.createConversation(name); @@ -63,10 +63,14 @@ public void testGetInteractionsNoPagination() { int lastN = 5; String conversationId = UUID.randomUUID().toString(); List interactions = new ArrayList<>(); - IntStream.range(0, lastN).forEach(i -> interactions.add(new Interaction(Integer.toString(i), Instant.now(), conversationId, "foo", "bar", "x", "y", null))); + IntStream + .range(0, lastN) + .forEach( + i -> interactions.add(new Interaction(Integer.toString(i), Instant.now(), conversationId, "foo", "bar", "x", "y", null)) + ); GetInteractionsResponse response = new GetInteractionsResponse(interactions, lastN, false); ActionFuture future = mock(ActionFuture.class); - when(future.actionGet()).thenReturn(response); + when(future.actionGet(anyLong())).thenReturn(response); when(client.execute(eq(GetInteractionsAction.INSTANCE), any())).thenReturn(future); ArgumentCaptor captor = ArgumentCaptor.forClass(GetInteractionsRequest.class); @@ -85,25 +89,31 @@ public void testGetInteractionsWithPagination() { int lastN = 5; String conversationId = UUID.randomUUID().toString(); List firstPage = new ArrayList<>(); - IntStream.range(0, lastN).forEach(i -> firstPage.add(new Interaction(Integer.toString(i), Instant.now(), conversationId, "foo", "bar", "x", "y", null))); + IntStream + .range(0, lastN) + .forEach(i -> firstPage.add(new Interaction(Integer.toString(i), Instant.now(), conversationId, "foo", "bar", "x", "y", null))); GetInteractionsResponse response1 = new GetInteractionsResponse(firstPage, lastN, true); List secondPage = new ArrayList<>(); - IntStream.range(0, lastN).forEach(i -> secondPage.add(new Interaction(Integer.toString(i), Instant.now(), conversationId, "foo", "bar", "x", "y", null))); + IntStream + .range(0, lastN) + .forEach( + i -> secondPage.add(new Interaction(Integer.toString(i), Instant.now(), conversationId, "foo", "bar", "x", "y", null)) + ); GetInteractionsResponse response2 = new GetInteractionsResponse(secondPage, lastN, false); ActionFuture future1 = mock(ActionFuture.class); - when(future1.actionGet()).thenReturn(response1); + when(future1.actionGet(anyLong())).thenReturn(response1); ActionFuture future2 = mock(ActionFuture.class); - when(future2.actionGet()).thenReturn(response2); + when(future2.actionGet(anyLong())).thenReturn(response2); when(client.execute(eq(GetInteractionsAction.INSTANCE), any())).thenReturn(future1).thenReturn(future2); ArgumentCaptor captor = ArgumentCaptor.forClass(GetInteractionsRequest.class); - List actual = memoryClient.getInteractions(conversationId, 2*lastN); + List actual = memoryClient.getInteractions(conversationId, 2 * lastN); // Called twice verify(client, times(2)).execute(eq(GetInteractionsAction.INSTANCE), captor.capture()); List actualRequests = captor.getAllValues(); - assertEquals(2*lastN, actual.size()); + assertEquals(2 * lastN, actual.size()); assertEquals(conversationId, actualRequests.get(0).getConversationId()); - assertEquals(2*lastN, actualRequests.get(0).getMaxResults()); + assertEquals(2 * lastN, actualRequests.get(0).getMaxResults()); assertEquals(0, actualRequests.get(0).getFrom()); assertEquals(lastN, actualRequests.get(1).getFrom()); } @@ -116,10 +126,14 @@ public void testGetInteractionsNoMoreResults() { String conversationId = UUID.randomUUID().toString(); List interactions = new ArrayList<>(); // Return fewer results than requested - IntStream.range(0, found).forEach(i -> interactions.add(new Interaction(Integer.toString(i), Instant.now(), conversationId, "foo", "bar", "x", "y", null))); + IntStream + .range(0, found) + .forEach( + i -> interactions.add(new Interaction(Integer.toString(i), Instant.now(), conversationId, "foo", "bar", "x", "y", null)) + ); GetInteractionsResponse response = new GetInteractionsResponse(interactions, found, false); ActionFuture future = mock(ActionFuture.class); - when(future.actionGet()).thenReturn(response); + when(future.actionGet(anyLong())).thenReturn(response); when(client.execute(eq(GetInteractionsAction.INSTANCE), any())).thenReturn(future); ArgumentCaptor captor = ArgumentCaptor.forClass(GetInteractionsRequest.class); @@ -138,7 +152,7 @@ public void testAvoidInfiniteLoop() { GetInteractionsResponse response1 = new GetInteractionsResponse(null, 0, true); GetInteractionsResponse response2 = new GetInteractionsResponse(List.of(), 0, true); ActionFuture future = mock(ActionFuture.class); - when(future.actionGet()).thenReturn(response1).thenReturn(response2); + when(future.actionGet(anyLong())).thenReturn(response1).thenReturn(response2); when(client.execute(eq(GetInteractionsAction.INSTANCE), any())).thenReturn(future); List actual = memoryClient.getInteractions("1", 10); assertTrue(actual.isEmpty()); @@ -152,7 +166,7 @@ public void testNoResults() { GetInteractionsResponse response1 = new GetInteractionsResponse(null, 0, true); GetInteractionsResponse response2 = new GetInteractionsResponse(List.of(), 0, false); ActionFuture future = mock(ActionFuture.class); - when(future.actionGet()).thenReturn(response1).thenReturn(response2); + when(future.actionGet(anyLong())).thenReturn(response1).thenReturn(response2); when(client.execute(eq(GetInteractionsAction.INSTANCE), any())).thenReturn(future); List actual = memoryClient.getInteractions("1", 10); assertTrue(actual.isEmpty()); @@ -166,7 +180,7 @@ public void testCreateInteraction() { String id = UUID.randomUUID().toString(); CreateInteractionResponse res = new CreateInteractionResponse(id); ActionFuture future = mock(ActionFuture.class); - when(future.actionGet()).thenReturn(res); + when(future.actionGet(anyLong())).thenReturn(res); when(client.execute(eq(CreateInteractionAction.INSTANCE), any())).thenReturn(future); String actual = memoryClient.createInteraction("cid", "input", "prompt", "answer", "origin", "hits"); assertEquals(id, actual); diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/client/MachineLearningInternalClientTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/client/MachineLearningInternalClientTests.java index ce921bac89..6aac45457e 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/client/MachineLearningInternalClientTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/client/MachineLearningInternalClientTests.java @@ -4,6 +4,14 @@ */ package org.opensearch.searchpipelines.questionanswering.generative.client; +import static org.junit.Assert.assertEquals; +import static org.mockito.Answers.RETURNS_DEEP_STUBS; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.verify; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -24,14 +32,6 @@ import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; -import static org.junit.Assert.assertEquals; -import static org.mockito.Answers.RETURNS_DEEP_STUBS; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.verify; - public class MachineLearningInternalClientTests { @Mock(answer = RETURNS_DEEP_STUBS) NodeClient client; @@ -60,36 +60,30 @@ public void setUp() { public void predict() { doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); - MLPredictionOutput predictionOutput = MLPredictionOutput.builder() + MLPredictionOutput predictionOutput = MLPredictionOutput + .builder() .status("Success") .predictionResult(output) .taskId("taskId") .build(); - actionListener.onResponse(MLTaskResponse.builder() - .output(predictionOutput) - .build()); + actionListener.onResponse(MLTaskResponse.builder().output(predictionOutput).build()); return null; }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); ArgumentCaptor dataFrameArgumentCaptor = ArgumentCaptor.forClass(MLOutput.class); - MLInput mlInput = MLInput.builder() - .algorithm(FunctionName.KMEANS) - .inputDataset(input) - .build(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(input).build(); machineLearningInternalClient.predict(null, mlInput, dataFrameActionListener); verify(client).execute(eq(MLPredictionTaskAction.INSTANCE), isA(MLPredictionTaskRequest.class), any()); verify(dataFrameActionListener).onResponse(dataFrameArgumentCaptor.capture()); - assertEquals(output, ((MLPredictionOutput)dataFrameArgumentCaptor.getValue()).getPredictionResult()); + assertEquals(output, ((MLPredictionOutput) dataFrameArgumentCaptor.getValue()).getPredictionResult()); } @Test public void predict_Exception_WithNullAlgorithm() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("algorithm can't be null"); - MLInput mlInput = MLInput.builder() - .inputDataset(input) - .build(); + MLInput mlInput = MLInput.builder().inputDataset(input).build(); machineLearningInternalClient.predict(null, mlInput, dataFrameActionListener); } @@ -97,9 +91,7 @@ public void predict_Exception_WithNullAlgorithm() { public void predict_Exception_WithNullDataSet() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("input data set can't be null"); - MLInput mlInput = MLInput.builder() - .algorithm(FunctionName.KMEANS) - .build(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).build(); machineLearningInternalClient.predict(null, mlInput, dataFrameActionListener); } } diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java index b05b52062c..5aeb1e804f 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java @@ -17,6 +17,15 @@ */ package org.opensearch.searchpipelines.questionanswering.generative.ext; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.EOFException; +import java.io.IOException; + import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.common.bytes.BytesReference; @@ -26,20 +35,11 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.test.OpenSearchTestCase; -import java.io.EOFException; -import java.io.IOException; - -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - public class GenerativeQAParamExtBuilderTests extends OpenSearchTestCase { public void testCtor() throws IOException { GenerativeQAParamExtBuilder builder = new GenerativeQAParamExtBuilder(); - GenerativeQAParameters parameters = new GenerativeQAParameters(); + GenerativeQAParameters parameters = new GenerativeQAParameters("conversation_id", "model_id", "question", null, null, null); builder.setParams(parameters); assertEquals(parameters, builder.getParams()); @@ -79,8 +79,8 @@ public int read() throws IOException { } public void testMiscMethods() throws IOException { - GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c"); - GenerativeQAParameters param2 = new GenerativeQAParameters("a", "b", "d"); + GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", null, null, null); + GenerativeQAParameters param2 = new GenerativeQAParameters("a", "b", "d", null, null, null); GenerativeQAParamExtBuilder builder1 = new GenerativeQAParamExtBuilder(); GenerativeQAParamExtBuilder builder2 = new GenerativeQAParamExtBuilder(); builder1.setParams(param1); @@ -105,7 +105,22 @@ public void testParse() throws IOException { } public void testXContentRoundTrip() throws IOException { - GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c"); + GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", null, null, null); + GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); + extBuilder.setParams(param1); + XContentType xContentType = randomFrom(XContentType.values()); + BytesReference serialized = XContentHelper.toXContent(extBuilder, xContentType, true); + XContentParser parser = createParser(xContentType.xContent(), serialized); + GenerativeQAParamExtBuilder deserialized = GenerativeQAParamExtBuilder.parse(parser); + assertEquals(extBuilder, deserialized); + GenerativeQAParameters parameters = deserialized.getParams(); + assertTrue(GenerativeQAParameters.SIZE_NULL_VALUE == parameters.getContextSize()); + assertTrue(GenerativeQAParameters.SIZE_NULL_VALUE == parameters.getInteractionSize()); + assertTrue(GenerativeQAParameters.SIZE_NULL_VALUE == parameters.getTimeout()); + } + + public void testXContentRoundTripAllValues() throws IOException { + GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", 1, 2, 3); GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); extBuilder.setParams(param1); XContentType xContentType = randomFrom(XContentType.values()); @@ -116,7 +131,21 @@ public void testXContentRoundTrip() throws IOException { } public void testStreamRoundTrip() throws IOException { - GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c"); + GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", null, null, null); + GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); + extBuilder.setParams(param1); + BytesStreamOutput bso = new BytesStreamOutput(); + extBuilder.writeTo(bso); + GenerativeQAParamExtBuilder deserialized = new GenerativeQAParamExtBuilder(bso.bytes().streamInput()); + assertEquals(extBuilder, deserialized); + GenerativeQAParameters parameters = deserialized.getParams(); + assertTrue(GenerativeQAParameters.SIZE_NULL_VALUE == parameters.getContextSize()); + assertTrue(GenerativeQAParameters.SIZE_NULL_VALUE == parameters.getInteractionSize()); + assertTrue(GenerativeQAParameters.SIZE_NULL_VALUE == parameters.getTimeout()); + } + + public void testStreamRoundTripAllValues() throws IOException { + GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", 1, 2, 3); GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); extBuilder.setParams(param1); BytesStreamOutput bso = new BytesStreamOutput(); diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamUtilTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamUtilTests.java index c6cf3e9399..8acc46e628 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamUtilTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamUtilTests.java @@ -17,12 +17,12 @@ */ package org.opensearch.searchpipelines.questionanswering.generative.ext; +import java.util.List; + import org.opensearch.action.search.SearchRequest; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.test.OpenSearchTestCase; -import java.util.List; - public class GenerativeQAParamUtilTests extends OpenSearchTestCase { public void testGenerativeQAParametersMissingParams() { diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java index b2f9d9dc2f..600b1c7a19 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java @@ -17,6 +17,14 @@ */ package org.opensearch.searchpipelines.questionanswering.generative.ext; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +import java.io.IOException; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.List; + import org.opensearch.action.search.SearchRequest; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContent; @@ -25,18 +33,10 @@ import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.test.OpenSearchTestCase; -import java.io.IOException; -import java.io.OutputStream; -import java.util.ArrayList; -import java.util.List; - -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.*; - public class GenerativeQAParametersTests extends OpenSearchTestCase { public void testGenerativeQAParameters() { - GenerativeQAParameters params = new GenerativeQAParameters("conversation_id", "llm_model", "llm_question"); + GenerativeQAParameters params = new GenerativeQAParameters("conversation_id", "llm_model", "llm_question", null, null, null); GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); extBuilder.setParams(params); SearchSourceBuilder srcBulder = SearchSourceBuilder.searchSource().ext(List.of(extBuilder)); @@ -48,6 +48,7 @@ public void testGenerativeQAParameters() { static class DummyStreamOutput extends StreamOutput { List list = new ArrayList<>(); + List intValues = new ArrayList<>(); @Override public void writeString(String str) { @@ -58,6 +59,15 @@ public List getList() { return list; } + @Override + public void writeInt(int i) { + intValues.add(i); + } + + public List getIntValues() { + return this.intValues; + } + @Override public void writeByte(byte b) throws IOException { @@ -83,11 +93,22 @@ public void reset() throws IOException { } } + public void testWriteTo() throws IOException { String conversationId = "a"; String llmModel = "b"; String llmQuestion = "c"; - GenerativeQAParameters parameters = new GenerativeQAParameters(conversationId, llmModel, llmQuestion); + int contextSize = 1; + int interactionSize = 2; + int timeout = 10; + GenerativeQAParameters parameters = new GenerativeQAParameters( + conversationId, + llmModel, + llmQuestion, + contextSize, + interactionSize, + timeout + ); StreamOutput output = new DummyStreamOutput(); parameters.writeTo(output); List actual = ((DummyStreamOutput) output).getList(); @@ -95,26 +116,53 @@ public void testWriteTo() throws IOException { assertEquals(conversationId, actual.get(0)); assertEquals(llmModel, actual.get(1)); assertEquals(llmQuestion, actual.get(2)); + List intValues = ((DummyStreamOutput) output).getIntValues(); + assertTrue(contextSize == intValues.get(0)); + assertTrue(interactionSize == intValues.get(1)); + assertTrue(timeout == intValues.get(2)); } public void testMisc() { String conversationId = "a"; String llmModel = "b"; String llmQuestion = "c"; - GenerativeQAParameters parameters = new GenerativeQAParameters(conversationId, llmModel, llmQuestion); + GenerativeQAParameters parameters = new GenerativeQAParameters(conversationId, llmModel, llmQuestion, null, null, null); assertNotEquals(parameters, null); assertNotEquals(parameters, "foo"); - assertEquals(parameters, new GenerativeQAParameters(conversationId, llmModel, llmQuestion)); - assertNotEquals(parameters, new GenerativeQAParameters("", llmModel, llmQuestion)); - assertNotEquals(parameters, new GenerativeQAParameters(conversationId, "", llmQuestion)); - assertNotEquals(parameters, new GenerativeQAParameters(conversationId, llmModel, "")); + assertEquals(parameters, new GenerativeQAParameters(conversationId, llmModel, llmQuestion, null, null, null)); + assertNotEquals(parameters, new GenerativeQAParameters("", llmModel, llmQuestion, null, null, null)); + assertNotEquals(parameters, new GenerativeQAParameters(conversationId, "", llmQuestion, null, null, null)); + // assertNotEquals(parameters, new GenerativeQAParameters(conversationId, llmModel, "", null)); } public void testToXConent() throws IOException { String conversationId = "a"; String llmModel = "b"; String llmQuestion = "c"; - GenerativeQAParameters parameters = new GenerativeQAParameters(conversationId, llmModel, llmQuestion); + GenerativeQAParameters parameters = new GenerativeQAParameters(conversationId, llmModel, llmQuestion, null, null, null); + XContent xc = mock(XContent.class); + OutputStream os = mock(OutputStream.class); + XContentGenerator generator = mock(XContentGenerator.class); + when(xc.createGenerator(any(), any(), any())).thenReturn(generator); + XContentBuilder builder = new XContentBuilder(xc, os); + assertNotNull(parameters.toXContent(builder, null)); + } + + public void testToXConentAllOptionalParameters() throws IOException { + String conversationId = "a"; + String llmModel = "b"; + String llmQuestion = "c"; + int contextSize = 1; + int interactionSize = 2; + int timeout = 10; + GenerativeQAParameters parameters = new GenerativeQAParameters( + conversationId, + llmModel, + llmQuestion, + contextSize, + interactionSize, + timeout + ); XContent xc = mock(XContent.class); OutputStream os = mock(OutputStream.class); XContentGenerator generator = mock(XContentGenerator.class); diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java index 925b84b8b1..0e34dd0bf1 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java @@ -17,22 +17,32 @@ */ package org.opensearch.searchpipelines.questionanswering.generative.llm; -import org.opensearch.ml.common.conversation.ConversationalIndexConstants; -import org.opensearch.ml.common.conversation.Interaction; -import org.opensearch.test.OpenSearchTestCase; - import java.time.Instant; import java.util.Collections; import java.util.List; import java.util.Map; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.test.OpenSearchTestCase; + public class ChatCompletionInputTests extends OpenSearchTestCase { public void testCtor() { String model = "model"; String question = "question"; + String systemPrompt = "you are the best"; + String userInstructions = "walk this way"; - ChatCompletionInput input = new ChatCompletionInput(model, question, Collections.emptyList(), Collections.emptyList()); + ChatCompletionInput input = new ChatCompletionInput( + model, + question, + Collections.emptyList(), + Collections.emptyList(), + 0, + systemPrompt, + userInstructions + ); assertNotNull(input); } @@ -40,17 +50,33 @@ public void testCtor() { public void testGettersSetters() { String model = "model"; String question = "question"; - List history = List.of(Interaction.fromMap("1", - Map.of( - ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, "convo1", - ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, Instant.now().toString(), - ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, "hello"))); + String systemPrompt = "you are the best"; + String userInstructions = "walk this way"; + + List history = List + .of( + Interaction + .fromMap( + "1", + Map + .of( + ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, + "convo1", + ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, + Instant.now().toString(), + ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, + "hello" + ) + ) + ); List contexts = List.of("result1", "result2"); - ChatCompletionInput input = new ChatCompletionInput(model, question, history, contexts); + ChatCompletionInput input = new ChatCompletionInput(model, question, history, contexts, 0, systemPrompt, userInstructions); assertEquals(model, input.getModel()); assertEquals(question, input.getQuestion()); assertEquals(history.get(0).getConversationId(), input.getChatHistory().get(0).getConversationId()); assertEquals(contexts.get(0), input.getContexts().get(0)); assertEquals(contexts.get(1), input.getContexts().get(1)); + assertEquals(systemPrompt, input.getSystemPrompt()); + assertEquals(userInstructions, input.getUserInstructions()); } } diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionOutputTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionOutputTests.java index c3f6c68688..4768d7489b 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionOutputTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionOutputTests.java @@ -17,20 +17,44 @@ */ package org.opensearch.searchpipelines.questionanswering.generative.llm; -import org.opensearch.test.OpenSearchTestCase; - +import java.util.ArrayList; import java.util.List; +import org.junit.Rule; +import org.junit.rules.ExpectedException; +import org.opensearch.test.OpenSearchTestCase; + public class ChatCompletionOutputTests extends OpenSearchTestCase { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + public void testCtor() { - ChatCompletionOutput output = new ChatCompletionOutput(List.of("answer")); + ChatCompletionOutput output = new ChatCompletionOutput(List.of("answer"), null); assertNotNull(output); } public void testGettersSetters() { String answer = "answer"; - ChatCompletionOutput output = new ChatCompletionOutput(List.of(answer)); + ChatCompletionOutput output = new ChatCompletionOutput(List.of(answer), null); assertEquals(answer, (String) output.getAnswers().get(0)); } + + public void testIllegalArgument1() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("answers and errors can't both be null."); + new ChatCompletionOutput(null, null); + } + + public void testIllegalArgument2() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("If answers is not provided, one or more errors must be provided."); + new ChatCompletionOutput(null, new ArrayList<>()); + } + + public void testIllegalArgument3() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("If errors is not provided, one or more answers must be provided."); + new ChatCompletionOutput(new ArrayList<>(), null); + } } diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java index 0aba017245..218bd65ec9 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java @@ -17,13 +17,22 @@ */ package org.opensearch.searchpipelines.questionanswering.generative.llm; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + import org.json.JSONArray; import org.json.JSONException; import org.json.JSONObject; import org.mockito.ArgumentCaptor; import org.mockito.Mock; -import org.opensearch.common.action.ActionFuture; import org.opensearch.client.Client; +import org.opensearch.common.action.ActionFuture; import org.opensearch.ml.common.conversation.ConversationalIndexConstants; import org.opensearch.ml.common.conversation.Interaction; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; @@ -37,15 +46,6 @@ import org.opensearch.searchpipelines.questionanswering.generative.prompt.PromptUtil; import org.opensearch.test.OpenSearchTestCase; -import java.time.Instant; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Map; - -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.*; - public class DefaultLlmImplTests extends OpenSearchTestCase { @Mock @@ -57,14 +57,35 @@ public void testBuildMessageParameter() { List contexts = new ArrayList<>(); contexts.add("context 1"); contexts.add("context 2"); - List chatHistory = List.of(Interaction.fromMap("convo1", Map.of( - ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, Instant.now().toString(), - ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, "message 1", - ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, "answer1")), - Interaction.fromMap("convo1", Map.of( - ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, Instant.now().toString(), - ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, "message 2", - ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, "answer2"))); + List chatHistory = List + .of( + Interaction + .fromMap( + "convo1", + Map + .of( + ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, + Instant.now().toString(), + ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, + "message 1", + ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, + "answer1" + ) + ), + Interaction + .fromMap( + "convo1", + Map + .of( + ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, + Instant.now().toString(), + ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, + "message 2", + ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, + "answer2" + ) + ) + ); String parameter = PromptUtil.getChatCompletionPrompt(question, chatHistory, contexts); Map parameters = Map.of("model", "foo", "messages", parameter); assertTrue(isJson(parameter)); @@ -81,9 +102,17 @@ public void testChatCompletionApi() throws Exception { ModelTensor tensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, dataAsMap); ModelTensorOutput mlOutput = new ModelTensorOutput(List.of(new ModelTensors(List.of(tensor)))); ActionFuture future = mock(ActionFuture.class); - when(future.actionGet()).thenReturn(mlOutput); + when(future.actionGet(anyLong())).thenReturn(mlOutput); when(mlClient.predict(any(), any())).thenReturn(future); - ChatCompletionInput input = new ChatCompletionInput("model", "question", Collections.emptyList(), Collections.emptyList()); + ChatCompletionInput input = new ChatCompletionInput( + "model", + "question", + Collections.emptyList(), + Collections.emptyList(), + 0, + "prompt", + "instructions" + ); ChatCompletionOutput output = connector.doChatCompletion(input); verify(mlClient, times(1)).predict(any(), captor.capture()); MLInput mlInput = captor.getValue(); @@ -91,6 +120,37 @@ public void testChatCompletionApi() throws Exception { assertEquals("answer", (String) output.getAnswers().get(0)); } + public void testChatCompletionThrowingError() throws Exception { + MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class); + ArgumentCaptor captor = ArgumentCaptor.forClass(MLInput.class); + DefaultLlmImpl connector = new DefaultLlmImpl("model_id", client); + connector.setMlClient(mlClient); + + String errorMessage = "throttled"; + Map messageMap = Map.of("message", errorMessage); + Map dataAsMap = Map.of("error", messageMap); + ModelTensor tensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, dataAsMap); + ModelTensorOutput mlOutput = new ModelTensorOutput(List.of(new ModelTensors(List.of(tensor)))); + ActionFuture future = mock(ActionFuture.class); + when(future.actionGet(anyLong())).thenReturn(mlOutput); + when(mlClient.predict(any(), any())).thenReturn(future); + ChatCompletionInput input = new ChatCompletionInput( + "model", + "question", + Collections.emptyList(), + Collections.emptyList(), + 0, + "prompt", + "instructions" + ); + ChatCompletionOutput output = connector.doChatCompletion(input); + verify(mlClient, times(1)).predict(any(), captor.capture()); + MLInput mlInput = captor.getValue(); + assertTrue(mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet); + assertTrue(output.isErrorOccurred()); + assertEquals(errorMessage, (String) output.getErrors().get(0)); + } + private boolean isJson(String Json) { try { new JSONObject(Json); diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtilTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtilTests.java index 5d8395126b..41d44f18ca 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtilTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtilTests.java @@ -17,10 +17,10 @@ */ package org.opensearch.searchpipelines.questionanswering.generative.llm; -import org.opensearch.test.OpenSearchTestCase; - import java.util.Collections; +import org.opensearch.test.OpenSearchTestCase; + public class LlmIOUtilTests extends OpenSearchTestCase { public void testCtor() { @@ -28,7 +28,8 @@ public void testCtor() { } public void testChatCompletionInput() { - ChatCompletionInput input = LlmIOUtil.createChatCompletionInput("model", "question", Collections.emptyList(), Collections.emptyList()); + ChatCompletionInput input = LlmIOUtil + .createChatCompletionInput("model", "question", Collections.emptyList(), Collections.emptyList(), 0); assertTrue(input instanceof ChatCompletionInput); } } diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ModelLocatorTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ModelLocatorTests.java index dcf3d223fb..a69c010804 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ModelLocatorTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ModelLocatorTests.java @@ -17,11 +17,11 @@ */ package org.opensearch.searchpipelines.questionanswering.generative.llm; +import static org.mockito.Mockito.mock; + import org.opensearch.client.Client; import org.opensearch.test.OpenSearchTestCase; -import static org.mockito.Mockito.mock; - public class ModelLocatorTests extends OpenSearchTestCase { public void testGetRemoteLlm() { diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtilTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtilTests.java index dd3fed1c9d..583ab17149 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtilTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtilTests.java @@ -17,6 +17,12 @@ */ package org.opensearch.searchpipelines.questionanswering.generative.prompt; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + import org.json.JSONArray; import org.json.JSONException; import org.json.JSONObject; @@ -24,12 +30,6 @@ import org.opensearch.ml.common.conversation.Interaction; import org.opensearch.test.OpenSearchTestCase; -import java.time.Instant; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Map; - public class PromptUtilTests extends OpenSearchTestCase { public void testPromptUtilStaticMethods() { @@ -37,19 +37,42 @@ public void testPromptUtilStaticMethods() { } public void testBuildMessageParameter() { + String systemPrompt = "You are the best."; + String userInstructions = null; String question = "Who am I"; List contexts = new ArrayList<>(); - List chatHistory = List.of(Interaction.fromMap("convo1", Map.of( - ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, Instant.now().toString(), - ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, "message 1", - ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, "answer1")), - Interaction.fromMap("convo1", Map.of( - ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, Instant.now().toString(), - ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, "message 2", - ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, "answer2"))); + List chatHistory = List + .of( + Interaction + .fromMap( + "convo1", + Map + .of( + ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, + Instant.now().toString(), + ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, + "message 1", + ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, + "answer1" + ) + ), + Interaction + .fromMap( + "convo1", + Map + .of( + ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, + Instant.now().toString(), + ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, + "message 2", + ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, + "answer2" + ) + ) + ); contexts.add("context 1"); contexts.add("context 2"); - String parameter = PromptUtil.buildMessageParameter(question, chatHistory, contexts); + String parameter = PromptUtil.buildMessageParameter(systemPrompt, userInstructions, question, chatHistory, contexts); Map parameters = Map.of("model", "foo", "messages", parameter); assertTrue(isJson(parameter)); }