diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java index b73dd81330..ae0391b38a 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java @@ -17,8 +17,15 @@ */ package org.opensearch.ml.rest; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; +import static org.opensearch.ml.utils.TestHelper.makeRequest; +import static org.opensearch.ml.utils.TestHelper.toHttpEntity; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Locale; +import java.util.Map; +import java.util.Set; + import org.apache.hc.core5.http.HttpHeaders; import org.apache.hc.core5.http.io.entity.EntityUtils; import org.apache.hc.core5.http.message.BasicHeader; @@ -28,155 +35,165 @@ import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.utils.TestHelper; -import java.nio.file.Files; -import java.nio.file.Path; -import java.util.Locale; -import java.util.Map; -import java.util.Set; - -import static org.opensearch.ml.utils.TestHelper.makeRequest; -import static org.opensearch.ml.utils.TestHelper.toHttpEntity; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; public class RestMLRAGSearchProcessorIT extends RestMLRemoteInferenceIT { private static final String OPENAI_KEY = System.getenv("OPENAI_KEY"); - private static final String OPENAI_CONNECTOR_BLUEPRINT = - "{\n" - + " \"name\": \"OpenAI Chat Connector\",\n" - + " \"description\": \"The connector to public OpenAI model service for GPT 3.5\",\n" - + " \"version\": 2,\n" - + " \"protocol\": \"http\",\n" - + " \"parameters\": {\n" - + " \"endpoint\": \"api.openai.com\",\n" - + " \"model\": \"gpt-3.5-turbo\",\n" - + " \"temperature\": 0\n" - + " },\n" - + " \"credential\": {\n" - + " \"openAI_key\": \"" + OPENAI_KEY + "\"\n" - + " },\n" - + " \"actions\": [\n" - + " {\n" - + " \"action_type\": \"predict\",\n" - + " \"method\": \"POST\",\n" - + " \"url\": \"https://${parameters.endpoint}/v1/chat/completions\",\n" - + " \"headers\": {\n" - + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" - + " },\n" - + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages}, \\\"temperature\\\": ${parameters.temperature} }\"\n" - + " }\n" - + " ]\n" - + "}"; + private static final String OPENAI_CONNECTOR_BLUEPRINT = "{\n" + + " \"name\": \"OpenAI Chat Connector\",\n" + + " \"description\": \"The connector to public OpenAI model service for GPT 3.5\",\n" + + " \"version\": 2,\n" + + " \"protocol\": \"http\",\n" + + " \"parameters\": {\n" + + " \"endpoint\": \"api.openai.com\",\n" + + " \"model\": \"gpt-3.5-turbo\",\n" + + " \"temperature\": 0\n" + + " },\n" + + " \"credential\": {\n" + + " \"openAI_key\": \"" + + OPENAI_KEY + + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://${parameters.endpoint}/v1/chat/completions\",\n" + + " \"headers\": {\n" + + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages}, \\\"temperature\\\": ${parameters.temperature} }\"\n" + + " }\n" + + " ]\n" + + "}"; private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID"); private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY"); private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN"); private static final String GITHUB_CI_AWS_REGION = "us-west-2"; - private static final String BEDROCK_CONNECTOR_BLUEPRINT1 = - "{\n" - + " \"name\": \"Bedrock Connector: claude2\",\n" - + " \"description\": \"The connector to bedrock claude2 model\",\n" - + " \"version\": 1,\n" - + " \"protocol\": \"aws_sigv4\",\n" - + " \"parameters\": {\n" - + " \"region\": \"" + GITHUB_CI_AWS_REGION + "\",\n" - + " \"service_name\": \"bedrock\"\n" - + " },\n" - + " \"credential\": {\n" - + " \"access_key\": \"" + AWS_ACCESS_KEY_ID + "\",\n" - + " \"secret_key\": \"" + AWS_SECRET_ACCESS_KEY + "\",\n" - + " \"session_token\": \"" + AWS_SESSION_TOKEN + "\"\n" - + " },\n" - + " \"actions\": [\n" - + " {\n" - + " \"action_type\": \"predict\",\n" - + " \"method\": \"POST\",\n" - + " \"headers\": {\n" - + " \"content-type\": \"application/json\"\n" - + " },\n" - + " \"url\": \"https://bedrock-runtime." + GITHUB_CI_AWS_REGION + ".amazonaws.com/model/anthropic.claude-v2/invoke\",\n" - + " \"request_body\": \"{\\\"prompt\\\":\\\"\\\\n\\\\nHuman: ${parameters.inputs}\\\\n\\\\nAssistant:\\\",\\\"max_tokens_to_sample\\\":300,\\\"temperature\\\":0.5,\\\"top_k\\\":250,\\\"top_p\\\":1,\\\"stop_sequences\\\":[\\\"\\\\\\\\n\\\\\\\\nHuman:\\\"]}\"\n" - + " }\n" - + " ]\n" - + "}"; - private static final String BEDROCK_CONNECTOR_BLUEPRINT2 = - "{\n" - + " \"name\": \"Bedrock Connector: claude2\",\n" - + " \"description\": \"The connector to bedrock claude2 model\",\n" - + " \"version\": 1,\n" - + " \"protocol\": \"aws_sigv4\",\n" - + " \"parameters\": {\n" - + " \"region\": \"" + GITHUB_CI_AWS_REGION + "\",\n" - + " \"service_name\": \"bedrock\"\n" - + " },\n" - + " \"credential\": {\n" - + " \"access_key\": \"" + AWS_ACCESS_KEY_ID + "\",\n" - + " \"secret_key\": \"" + AWS_SECRET_ACCESS_KEY + "\"\n" - + " },\n" - + " \"actions\": [\n" - + " {\n" - + " \"action_type\": \"predict\",\n" - + " \"method\": \"POST\",\n" - + " \"headers\": {\n" - + " \"content-type\": \"application/json\"\n" - + " },\n" - + " \"url\": \"https://bedrock-runtime." + GITHUB_CI_AWS_REGION + ".amazonaws.com/model/anthropic.claude-v2/invoke\",\n" - + " \"request_body\": \"{\\\"prompt\\\":\\\"\\\\n\\\\nHuman: ${parameters.inputs}\\\\n\\\\nAssistant:\\\",\\\"max_tokens_to_sample\\\":300,\\\"temperature\\\":0.5,\\\"top_k\\\":250,\\\"top_p\\\":1,\\\"stop_sequences\\\":[\\\"\\\\\\\\n\\\\\\\\nHuman:\\\"]}\"\n" - + " }\n" - + " ]\n" - + "}"; - - private static final String BEDROCK_CONNECTOR_BLUEPRINT = AWS_SESSION_TOKEN == null ? BEDROCK_CONNECTOR_BLUEPRINT2 : BEDROCK_CONNECTOR_BLUEPRINT1; - private static final String PIPELINE_TEMPLATE = - "{\n" - + " \"response_processors\": [\n" - + " {\n" - + " \"retrieval_augmented_generation\": {\n" - + " \"tag\": \"%s\",\n" - + " \"description\": \"%s\",\n" - + " \"model_id\": \"%s\",\n" - + " \"system_prompt\": \"%s\",\n" - + " \"user_instructions\": \"%s\",\n" - + " \"context_field_list\": [\"%s\"]\n" - + " }\n" - + " }\n" - + " ]\n" - + "}"; - - private static final String BM25_SEARCH_REQUEST_TEMPLATE = - "{\n" - + " \"_source\": [\"%s\"],\n" - + " \"query\" : {\n" - + " \"match\": {\"%s\": \"%s\"}\n" - + " },\n" - + " \"ext\": {\n" - + " \"generative_qa_parameters\": {\n" - + " \"llm_model\": \"%s\",\n" - + " \"llm_question\": \"%s\",\n" - + " \"context_size\": %d,\n" - + " \"interaction_size\": %d,\n" - + " \"timeout\": %d\n" - + " }\n" - + " }\n" - + "}"; - - private static final String BM25_SEARCH_REQUEST_WITH_CONVO_TEMPLATE = - "{\n" - + " \"_source\": [\"%s\"],\n" - + " \"query\" : {\n" - + " \"match\": {\"%s\": \"%s\"}\n" - + " },\n" - + " \"ext\": {\n" - + " \"generative_qa_parameters\": {\n" - + " \"llm_model\": \"%s\",\n" - + " \"llm_question\": \"%s\",\n" - + " \"conversation_id\": \"%s\",\n" - + " \"context_size\": %d,\n" - + " \"interaction_size\": %d,\n" - + " \"timeout\": %d\n" - + " }\n" - + " }\n" - + "}"; + private static final String BEDROCK_CONNECTOR_BLUEPRINT1 = "{\n" + + " \"name\": \"Bedrock Connector: claude2\",\n" + + " \"description\": \"The connector to bedrock claude2 model\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"aws_sigv4\",\n" + + " \"parameters\": {\n" + + " \"region\": \"" + + GITHUB_CI_AWS_REGION + + "\",\n" + + " \"service_name\": \"bedrock\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"access_key\": \"" + + AWS_ACCESS_KEY_ID + + "\",\n" + + " \"secret_key\": \"" + + AWS_SECRET_ACCESS_KEY + + "\",\n" + + " \"session_token\": \"" + + AWS_SESSION_TOKEN + + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"headers\": {\n" + + " \"content-type\": \"application/json\"\n" + + " },\n" + + " \"url\": \"https://bedrock-runtime." + + GITHUB_CI_AWS_REGION + + ".amazonaws.com/model/anthropic.claude-v2/invoke\",\n" + + " \"request_body\": \"{\\\"prompt\\\":\\\"\\\\n\\\\nHuman: ${parameters.inputs}\\\\n\\\\nAssistant:\\\",\\\"max_tokens_to_sample\\\":300,\\\"temperature\\\":0.5,\\\"top_k\\\":250,\\\"top_p\\\":1,\\\"stop_sequences\\\":[\\\"\\\\\\\\n\\\\\\\\nHuman:\\\"]}\"\n" + + " }\n" + + " ]\n" + + "}"; + private static final String BEDROCK_CONNECTOR_BLUEPRINT2 = "{\n" + + " \"name\": \"Bedrock Connector: claude2\",\n" + + " \"description\": \"The connector to bedrock claude2 model\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"aws_sigv4\",\n" + + " \"parameters\": {\n" + + " \"region\": \"" + + GITHUB_CI_AWS_REGION + + "\",\n" + + " \"service_name\": \"bedrock\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"access_key\": \"" + + AWS_ACCESS_KEY_ID + + "\",\n" + + " \"secret_key\": \"" + + AWS_SECRET_ACCESS_KEY + + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"headers\": {\n" + + " \"content-type\": \"application/json\"\n" + + " },\n" + + " \"url\": \"https://bedrock-runtime." + + GITHUB_CI_AWS_REGION + + ".amazonaws.com/model/anthropic.claude-v2/invoke\",\n" + + " \"request_body\": \"{\\\"prompt\\\":\\\"\\\\n\\\\nHuman: ${parameters.inputs}\\\\n\\\\nAssistant:\\\",\\\"max_tokens_to_sample\\\":300,\\\"temperature\\\":0.5,\\\"top_k\\\":250,\\\"top_p\\\":1,\\\"stop_sequences\\\":[\\\"\\\\\\\\n\\\\\\\\nHuman:\\\"]}\"\n" + + " }\n" + + " ]\n" + + "}"; + + private static final String BEDROCK_CONNECTOR_BLUEPRINT = AWS_SESSION_TOKEN == null + ? BEDROCK_CONNECTOR_BLUEPRINT2 + : BEDROCK_CONNECTOR_BLUEPRINT1; + private static final String PIPELINE_TEMPLATE = "{\n" + + " \"response_processors\": [\n" + + " {\n" + + " \"retrieval_augmented_generation\": {\n" + + " \"tag\": \"%s\",\n" + + " \"description\": \"%s\",\n" + + " \"model_id\": \"%s\",\n" + + " \"system_prompt\": \"%s\",\n" + + " \"user_instructions\": \"%s\",\n" + + " \"context_field_list\": [\"%s\"]\n" + + " }\n" + + " }\n" + + " ]\n" + + "}"; + + private static final String BM25_SEARCH_REQUEST_TEMPLATE = "{\n" + + " \"_source\": [\"%s\"],\n" + + " \"query\" : {\n" + + " \"match\": {\"%s\": \"%s\"}\n" + + " },\n" + + " \"ext\": {\n" + + " \"generative_qa_parameters\": {\n" + + " \"llm_model\": \"%s\",\n" + + " \"llm_question\": \"%s\",\n" + + " \"context_size\": %d,\n" + + " \"interaction_size\": %d,\n" + + " \"timeout\": %d\n" + + " }\n" + + " }\n" + + "}"; + + private static final String BM25_SEARCH_REQUEST_WITH_CONVO_TEMPLATE = "{\n" + + " \"_source\": [\"%s\"],\n" + + " \"query\" : {\n" + + " \"match\": {\"%s\": \"%s\"}\n" + + " },\n" + + " \"ext\": {\n" + + " \"generative_qa_parameters\": {\n" + + " \"llm_model\": \"%s\",\n" + + " \"llm_question\": \"%s\",\n" + + " \"conversation_id\": \"%s\",\n" + + " \"context_size\": %d,\n" + + " \"interaction_size\": %d,\n" + + " \"timeout\": %d\n" + + " }\n" + + " }\n" + + "}"; private static final String OPENAI_MODEL = "gpt-3.5-turbo"; private static final String BEDROCK_ANTHROPIC_CLAUDE = "bedrock/anthropic-claude"; @@ -479,35 +496,53 @@ private Response createSearchPipeline(String pipeline, PipelineParameters parame String.format(Locale.ROOT, "/_search/pipeline/%s", pipeline), null, toHttpEntity( - String.format( - Locale.ROOT, - PIPELINE_TEMPLATE, - parameters.tag, parameters.description, - parameters.modelId, parameters.systemPrompt, parameters.userInstructions, parameters.context_field - ) + String + .format( + Locale.ROOT, + PIPELINE_TEMPLATE, + parameters.tag, + parameters.description, + parameters.modelId, + parameters.systemPrompt, + parameters.userInstructions, + parameters.context_field + ) ), ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) ); } - private Response performSearch(String indexName, String pipeline, int size, SearchRequestParameters requestParameters) throws Exception { - - String httpEntity = (requestParameters.conversationId == null) ? - String.format( - Locale.ROOT, - BM25_SEARCH_REQUEST_TEMPLATE, - requestParameters.source, requestParameters.source, requestParameters.match, - requestParameters.llmModel, requestParameters.llmQuestion, requestParameters.contextSize, - requestParameters.interactionSize, requestParameters.timeout - ) - : - String.format( - Locale.ROOT, - BM25_SEARCH_REQUEST_WITH_CONVO_TEMPLATE, - requestParameters.source, requestParameters.source, requestParameters.match, - requestParameters.llmModel, requestParameters.llmQuestion, requestParameters.conversationId, requestParameters.contextSize, - requestParameters.interactionSize, requestParameters.timeout - ); + private Response performSearch(String indexName, String pipeline, int size, SearchRequestParameters requestParameters) + throws Exception { + + String httpEntity = (requestParameters.conversationId == null) + ? String + .format( + Locale.ROOT, + BM25_SEARCH_REQUEST_TEMPLATE, + requestParameters.source, + requestParameters.source, + requestParameters.match, + requestParameters.llmModel, + requestParameters.llmQuestion, + requestParameters.contextSize, + requestParameters.interactionSize, + requestParameters.timeout + ) + : String + .format( + Locale.ROOT, + BM25_SEARCH_REQUEST_WITH_CONVO_TEMPLATE, + requestParameters.source, + requestParameters.source, + requestParameters.match, + requestParameters.llmModel, + requestParameters.llmQuestion, + requestParameters.conversationId, + requestParameters.contextSize, + requestParameters.interactionSize, + requestParameters.timeout + ); return makeRequest( client(), "POST", @@ -524,13 +559,7 @@ private String createConversation(String name) throws Exception { "POST", "/_plugins/_ml/memory/conversation", null, - toHttpEntity( - String.format( - Locale.ROOT, - "{\"name\": \"%s\"}", - name - ) - ), + toHttpEntity(String.format(Locale.ROOT, "{\"name\": \"%s\"}", name)), ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) ); return (String) ((Map) parseResponseToMap(response)).get("conversation_id"); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java index b65f4ae8f7..e2c8e6dd0e 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java @@ -64,8 +64,8 @@ public class RestMLRemoteInferenceIT extends MLCommonsRestTestCase { @Before public void setup() throws IOException, InterruptedException { disableClusterConnectorAccessControl(); - // TODO Do we really need to wait this long? This adds 20s to every test case run. - // Can we instead check the cluster state and move on? + // TODO Do we really need to wait this long? This adds 20s to every test case run. + // Can we instead check the cluster state and move on? Thread.sleep(20000); }