diff --git a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java index 8879306773..5b8f7a3b01 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java @@ -34,7 +34,11 @@ import com.google.gson.JsonObject; import com.google.gson.JsonParser; import com.google.gson.JsonSyntaxException; +import com.jayway.jsonpath.Configuration; +import com.jayway.jsonpath.InvalidJsonException; import com.jayway.jsonpath.JsonPath; +import com.jayway.jsonpath.Option; +import com.jayway.jsonpath.PathNotFoundException; import com.networknt.schema.JsonSchema; import com.networknt.schema.JsonSchemaFactory; import com.networknt.schema.SpecVersion; @@ -347,6 +351,94 @@ public static JsonObject getJsonObjectFromString(String jsonString) { return JsonParser.parseString(jsonString).getAsJsonObject(); } + /** + * Checks if a specified JSON path exists within a given JSON object. + * + * This method attempts to read the value at the specified path in the JSON object. + * If the path exists, it returns true. If a PathNotFoundException is thrown, + * indicating that the path does not exist, it returns false. + * + * @param json The JSON object to check. This can be a Map, List, or any object + * that JsonPath can parse. + * @param path The JSON path to check for existence. This should be a valid + * JsonPath expression (e.g., "$.store.book[0].title"). + * @return true if the path exists in the JSON object, false otherwise. + * @throws IllegalArgumentException if the json object is null or if the path is null or empty. + * @throws PathNotFoundException if there's an error in parsing the JSON or the path. + */ + public static boolean pathExists(Object json, String path) { + if (json == null) { + throw new IllegalArgumentException("JSON object cannot be null"); + } + if (path == null || path.isEmpty()) { + throw new IllegalArgumentException("Path cannot be null or empty"); + } + if (!isValidJSONPath(path)) { + throw new IllegalArgumentException("the field path is not a valid json path: " + path); + } + try { + JsonPath.read(json, path); + return true; + } catch (PathNotFoundException e) { + return false; + } catch (InvalidJsonException e) { + throw new IllegalArgumentException("Invalid JSON input", e); + } + } + + /** + * Prepares nested structures in a JSON object based on the given field path. + * + * This method ensures that all intermediate nested objects exist in the JSON object + * for a given field path. If any part of the path doesn't exist, it creates new empty objects + * (HashMaps) for those parts. + * + * @param jsonObject The JSON object to be updated. + * @param fieldPath The full path of the field, potentially including nested structures. + * @return The updated JSON object with necessary nested structures in place. + * + * @throws IllegalArgumentException If there's an issue with JSON parsing or path manipulation. + * + * @implNote This method uses JsonPath for JSON manipulation and StringUtils for path existence checks. + * It handles paths both with and without a leading "$." notation. + * Each non-existent intermediate object in the path is created as an empty HashMap. + * + * @see JsonPath + * @see StringUtils + */ + public static Object prepareNestedStructures(Object jsonObject, String fieldPath) { + + if (fieldPath == null) { + throw new IllegalArgumentException("the field path is null"); + } + if (!isValidJSONPath(fieldPath)) { + throw new IllegalArgumentException("the field path is not a valid json path: " + fieldPath); + } + String path = fieldPath.startsWith("$.") ? fieldPath.substring(2) : fieldPath; + String[] pathParts = path.split("\\."); + Configuration suppressExceptionConfiguration = Configuration + .builder() + .options(Option.SUPPRESS_EXCEPTIONS, Option.DEFAULT_PATH_LEAF_TO_NULL) + .build(); + StringBuilder currentPath = new StringBuilder("$"); + + for (int i = 0; i < pathParts.length - 1; i++) { + currentPath.append(".").append(pathParts[i]); + if (!StringUtils.pathExists(jsonObject, currentPath.toString())) { + try { + jsonObject = JsonPath + .using(suppressExceptionConfiguration) + .parse(jsonObject) + .set(currentPath.toString(), new java.util.HashMap<>()) + .json(); + } catch (Exception e) { + throw new IllegalArgumentException("Error creating nested structure for path: " + currentPath, e); + } + } + } + return jsonObject; + } + public static void validateSchema(String schemaString, String instanceString) { try { // parse the schema JSON as string diff --git a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java index d440c44faf..72ec6a05ba 100644 --- a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java @@ -6,7 +6,10 @@ package org.opensearch.ml.common.utils; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; import static org.opensearch.ml.common.utils.StringUtils.TO_STRING_FUNCTION_NAME; import static org.opensearch.ml.common.utils.StringUtils.collectToStringPrefixes; import static org.opensearch.ml.common.utils.StringUtils.getJsonPath; @@ -22,6 +25,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import org.apache.commons.text.StringSubstitutor; @@ -29,40 +33,42 @@ import org.junit.Test; import org.opensearch.OpenSearchParseException; +import com.jayway.jsonpath.JsonPath; + public class StringUtilsTest { @Test public void isJson_True() { - Assert.assertTrue(StringUtils.isJson("{}")); - Assert.assertTrue(StringUtils.isJson("[]")); - Assert.assertTrue(StringUtils.isJson("{\"key\": \"value\"}")); - Assert.assertTrue(StringUtils.isJson("{\"key\": 123}")); - Assert.assertTrue(StringUtils.isJson("[1, 2, 3]")); - Assert.assertTrue(StringUtils.isJson("[\"a\", \"b\"]")); - Assert.assertTrue(StringUtils.isJson("[1, \"a\"]")); - Assert.assertTrue(StringUtils.isJson("{\"key1\": \"value\", \"key2\": 123}")); - Assert.assertTrue(StringUtils.isJson("{}")); - Assert.assertTrue(StringUtils.isJson("[]")); - Assert.assertTrue(StringUtils.isJson("[ ]")); - Assert.assertTrue(StringUtils.isJson("[,]")); - Assert.assertTrue(StringUtils.isJson("[abc]")); - Assert.assertTrue(StringUtils.isJson("[\"abc\", 123]")); + assertTrue(StringUtils.isJson("{}")); + assertTrue(StringUtils.isJson("[]")); + assertTrue(StringUtils.isJson("{\"key\": \"value\"}")); + assertTrue(StringUtils.isJson("{\"key\": 123}")); + assertTrue(StringUtils.isJson("[1, 2, 3]")); + assertTrue(StringUtils.isJson("[\"a\", \"b\"]")); + assertTrue(StringUtils.isJson("[1, \"a\"]")); + assertTrue(StringUtils.isJson("{\"key1\": \"value\", \"key2\": 123}")); + assertTrue(StringUtils.isJson("{}")); + assertTrue(StringUtils.isJson("[]")); + assertTrue(StringUtils.isJson("[ ]")); + assertTrue(StringUtils.isJson("[,]")); + assertTrue(StringUtils.isJson("[abc]")); + assertTrue(StringUtils.isJson("[\"abc\", 123]")); } @Test public void isJson_False() { - Assert.assertFalse(StringUtils.isJson("{")); - Assert.assertFalse(StringUtils.isJson("[")); - Assert.assertFalse(StringUtils.isJson("{\"key\": \"value}")); - Assert.assertFalse(StringUtils.isJson("{\"key\": \"value\", \"key\": 123}")); - Assert.assertFalse(StringUtils.isJson("[1, \"a]")); - Assert.assertFalse(StringUtils.isJson("[]\"")); - Assert.assertFalse(StringUtils.isJson("[ ]\"")); - Assert.assertFalse(StringUtils.isJson("[,]\"")); - Assert.assertFalse(StringUtils.isJson("[,\"]")); - Assert.assertFalse(StringUtils.isJson("[]\"123\"")); - Assert.assertFalse(StringUtils.isJson("[abc\"]")); - Assert.assertFalse(StringUtils.isJson("[abc\n123]")); + assertFalse(StringUtils.isJson("{")); + assertFalse(StringUtils.isJson("[")); + assertFalse(StringUtils.isJson("{\"key\": \"value}")); + assertFalse(StringUtils.isJson("{\"key\": \"value\", \"key\": 123}")); + assertFalse(StringUtils.isJson("[1, \"a]")); + assertFalse(StringUtils.isJson("[]\"")); + assertFalse(StringUtils.isJson("[ ]\"")); + assertFalse(StringUtils.isJson("[,]\"")); + assertFalse(StringUtils.isJson("[,\"]")); + assertFalse(StringUtils.isJson("[]\"123\"")); + assertFalse(StringUtils.isJson("[abc\"]")); + assertFalse(StringUtils.isJson("[abc\n123]")); } @Test @@ -84,7 +90,7 @@ public void fromJson_NestedMap() { Map response = StringUtils .fromJson("{\"key\": {\"nested_key\": \"nested_value\", \"nested_array\": [1, \"a\"]}}", "response"); assertEquals(1, response.size()); - Assert.assertTrue(response.get("key") instanceof Map); + assertTrue(response.get("key") instanceof Map); Map nestedMap = (Map) response.get("key"); assertEquals("nested_value", nestedMap.get("nested_key")); List list = (List) nestedMap.get("nested_array"); @@ -97,7 +103,7 @@ public void fromJson_NestedMap() { public void fromJson_SimpleList() { Map response = StringUtils.fromJson("[1, \"a\"]", "response"); assertEquals(1, response.size()); - Assert.assertTrue(response.get("response") instanceof List); + assertTrue(response.get("response") instanceof List); List list = (List) response.get("response"); assertEquals(1.0, list.get(0)); assertEquals("a", list.get(1)); @@ -107,12 +113,12 @@ public void fromJson_SimpleList() { public void fromJson_NestedList() { Map response = StringUtils.fromJson("[1, \"a\", [2, 3], {\"key\": \"value\"}]", "response"); assertEquals(1, response.size()); - Assert.assertTrue(response.get("response") instanceof List); + assertTrue(response.get("response") instanceof List); List list = (List) response.get("response"); assertEquals(1.0, list.get(0)); assertEquals("a", list.get(1)); - Assert.assertTrue(list.get(2) instanceof List); - Assert.assertTrue(list.get(3) instanceof Map); + assertTrue(list.get(2) instanceof List); + assertTrue(list.get(3) instanceof Map); } @Test @@ -152,23 +158,23 @@ public void processTextDocs() { List processedDocs = StringUtils.processTextDocs(Arrays.asList("abc \n\n123\"4", null, "[1.01,\"abc\"]")); assertEquals(3, processedDocs.size()); assertEquals("abc \\n\\n123\\\"4", processedDocs.get(0)); - Assert.assertNull(processedDocs.get(1)); + assertNull(processedDocs.get(1)); assertEquals("[1.01,\\\"abc\\\"]", processedDocs.get(2)); } @Test public void isEscapeUsed() { - Assert.assertFalse(StringUtils.isEscapeUsed("String escape")); - Assert.assertTrue(StringUtils.isEscapeUsed(" escape(\"abc\n123\")")); + assertFalse(StringUtils.isEscapeUsed("String escape")); + assertTrue(StringUtils.isEscapeUsed(" escape(\"abc\n123\")")); } @Test public void containsEscapeMethod() { - Assert.assertFalse(StringUtils.containsEscapeMethod("String escape")); - Assert.assertFalse(StringUtils.containsEscapeMethod("String escape()")); - Assert.assertFalse(StringUtils.containsEscapeMethod(" escape(\"abc\n123\")")); - Assert.assertTrue(StringUtils.containsEscapeMethod("String escape(def abc)")); - Assert.assertTrue(StringUtils.containsEscapeMethod("String escape(String input)")); + assertFalse(StringUtils.containsEscapeMethod("String escape")); + assertFalse(StringUtils.containsEscapeMethod("String escape()")); + assertFalse(StringUtils.containsEscapeMethod(" escape(\"abc\n123\")")); + assertTrue(StringUtils.containsEscapeMethod("String escape(def abc)")); + assertTrue(StringUtils.containsEscapeMethod("String escape(String input)")); } @Test @@ -183,7 +189,7 @@ public void addDefaultMethod_Escape() { String input = "return escape(\"abc\n123\");"; String result = StringUtils.addDefaultMethod(input); Assert.assertNotEquals(input, result); - Assert.assertTrue(result.startsWith(StringUtils.DEFAULT_ESCAPE_FUNCTION)); + assertTrue(result.startsWith(StringUtils.DEFAULT_ESCAPE_FUNCTION)); } @Test @@ -464,51 +470,174 @@ public void testGetJsonPath_ValidJsonPathWithoutSource() { @Test public void testisValidJSONPath_InvalidInputs() { - Assert.assertFalse(isValidJSONPath("..bar")); - Assert.assertFalse(isValidJSONPath(".")); - Assert.assertFalse(isValidJSONPath("..")); - Assert.assertFalse(isValidJSONPath("foo.bar.")); - Assert.assertFalse(isValidJSONPath(".foo.bar.")); + assertFalse(isValidJSONPath("..bar")); + assertFalse(isValidJSONPath(".")); + assertFalse(isValidJSONPath("..")); + assertFalse(isValidJSONPath("foo.bar.")); + assertFalse(isValidJSONPath(".foo.bar.")); } @Test public void testisValidJSONPath_NullInput() { - Assert.assertFalse(isValidJSONPath(null)); + assertFalse(isValidJSONPath(null)); } @Test public void testisValidJSONPath_EmptyInput() { - Assert.assertFalse(isValidJSONPath("")); + assertFalse(isValidJSONPath("")); } @Test public void testisValidJSONPath_ValidInputs() { - Assert.assertTrue(isValidJSONPath("foo")); - Assert.assertTrue(isValidJSONPath("foo.bar")); - Assert.assertTrue(isValidJSONPath("foo.bar.baz")); - Assert.assertTrue(isValidJSONPath("foo.bar.baz.qux")); - Assert.assertTrue(isValidJSONPath(".foo")); - Assert.assertTrue(isValidJSONPath("$.foo")); - Assert.assertTrue(isValidJSONPath(".foo.bar")); - Assert.assertTrue(isValidJSONPath("$.foo.bar")); + assertTrue(isValidJSONPath("foo")); + assertTrue(isValidJSONPath("foo.bar")); + assertTrue(isValidJSONPath("foo.bar.baz")); + assertTrue(isValidJSONPath("foo.bar.baz.qux")); + assertTrue(isValidJSONPath(".foo")); + assertTrue(isValidJSONPath("$.foo")); + assertTrue(isValidJSONPath(".foo.bar")); + assertTrue(isValidJSONPath("$.foo.bar")); } @Test public void testisValidJSONPath_WithFilter() { - Assert.assertTrue(isValidJSONPath("$.store['book']")); - Assert.assertTrue(isValidJSONPath("$['store']['book'][0]['title']")); - Assert.assertTrue(isValidJSONPath("$.store.book[0]")); - Assert.assertTrue(isValidJSONPath("$.store.book[1,2]")); - Assert.assertTrue(isValidJSONPath("$.store.book[-1:] ")); - Assert.assertTrue(isValidJSONPath("$.store.book[0:2]")); - Assert.assertTrue(isValidJSONPath("$.store.book[*]")); - Assert.assertTrue(isValidJSONPath("$.store.book[?(@.price < 10)]")); - Assert.assertTrue(isValidJSONPath("$.store.book[?(@.author == 'J.K. Rowling')]")); - Assert.assertTrue(isValidJSONPath("$..author")); - Assert.assertTrue(isValidJSONPath("$..book[?(@.price > 15)]")); - Assert.assertTrue(isValidJSONPath("$.store.book[0,1]")); - Assert.assertTrue(isValidJSONPath("$['store','warehouse']")); - Assert.assertTrue(isValidJSONPath("$..book[?(@.price > 20)].title")); + assertTrue(isValidJSONPath("$.store['book']")); + assertTrue(isValidJSONPath("$['store']['book'][0]['title']")); + assertTrue(isValidJSONPath("$.store.book[0]")); + assertTrue(isValidJSONPath("$.store.book[1,2]")); + assertTrue(isValidJSONPath("$.store.book[-1:] ")); + assertTrue(isValidJSONPath("$.store.book[0:2]")); + assertTrue(isValidJSONPath("$.store.book[*]")); + assertTrue(isValidJSONPath("$.store.book[?(@.price < 10)]")); + assertTrue(isValidJSONPath("$.store.book[?(@.author == 'J.K. Rowling')]")); + assertTrue(isValidJSONPath("$..author")); + assertTrue(isValidJSONPath("$..book[?(@.price > 15)]")); + assertTrue(isValidJSONPath("$.store.book[0,1]")); + assertTrue(isValidJSONPath("$['store','warehouse']")); + assertTrue(isValidJSONPath("$..book[?(@.price > 20)].title")); + } + + @Test + public void testPathExists_ExistingPath() { + Object json = JsonPath.parse("{\"a\":{\"b\":42}}").json(); + assertTrue(StringUtils.pathExists(json, "$.a.b")); + } + + @Test + public void testPathExists_NonExistingPath() { + Object json = JsonPath.parse("{\"a\":{\"b\":42}}").json(); + assertFalse(StringUtils.pathExists(json, "$.a.c")); + } + + @Test + public void testPathExists_EmptyObject() { + Object json = JsonPath.parse("{}").json(); + assertFalse(StringUtils.pathExists(json, "$.a")); + } + + @Test + public void testPathExists_NullJson() { + assertThrows(IllegalArgumentException.class, () -> StringUtils.pathExists(null, "$.a")); + } + + @Test + public void testPathExists_NullPath() { + Object json = JsonPath.parse("{\"a\":42}").json(); + assertThrows(IllegalArgumentException.class, () -> StringUtils.pathExists(json, null)); + } + + @Test + public void testPathExists_EmptyPath() { + Object json = JsonPath.parse("{\"a\":42}").json(); + assertThrows(IllegalArgumentException.class, () -> StringUtils.pathExists(json, "")); + } + + @Test + public void testPathExists_InvalidPath() { + Object json = JsonPath.parse("{\"a\":42}").json(); + assertThrows(IllegalArgumentException.class, () -> StringUtils.pathExists(json, "This is not a valid path")); + } + + @Test + public void testPathExists_ArrayElement() { + Object json = JsonPath.parse("{\"a\":[1,2,3]}").json(); + assertTrue(StringUtils.pathExists(json, "$.a[1]")); + assertFalse(StringUtils.pathExists(json, "$.a[3]")); + } + + @Test + public void testPathExists_NestedStructure() { + Object json = JsonPath.parse("{\"a\":{\"b\":{\"c\":{\"d\":42}}}}").json(); + assertTrue(StringUtils.pathExists(json, "$.a.b.c.d")); + assertFalse(StringUtils.pathExists(json, "$.a.b.c.e")); + } + + @Test + public void testPrepareNestedStructures_EmptyObject() { + Object jsonObject = new HashMap<>(); + Object result = StringUtils.prepareNestedStructures(jsonObject, "a.b.c"); + assertTrue(JsonPath.read(result, "$.a.b") instanceof Map); + } + + @Test + public void testPrepareNestedStructures_ExistingStructure() { + Object jsonObject = JsonPath.parse("{\"a\":{\"b\":{}}}").json(); + Object result = StringUtils.prepareNestedStructures(jsonObject, "a.b.c"); + assertTrue(JsonPath.read(result, "$.a.b") instanceof Map); + } + + @Test + public void testPrepareNestedStructures_PartiallyExistingStructure() { + Object jsonObject = JsonPath.parse("{\"a\":{}}").json(); + Object result = StringUtils.prepareNestedStructures(jsonObject, "a.b.c.d"); + assertTrue(JsonPath.read(result, "$.a.b.c") instanceof Map); + } + + @Test + public void testPrepareNestedStructures_WithDollarSign() { + Object jsonObject = new HashMap<>(); + Object result = StringUtils.prepareNestedStructures(jsonObject, "$.a.b.c"); + assertTrue(JsonPath.read(result, "$.a.b") instanceof Map); + } + + @Test + public void testPrepareNestedStructures_SingleLevel() { + Object jsonObject = new HashMap<>(); + Object result = StringUtils.prepareNestedStructures(jsonObject, "a"); + assertEquals(jsonObject, result); + } + + @Test + public void testPrepareNestedStructures_ExistingValue() { + Object jsonObject = JsonPath.parse("{\"a\":{\"b\":42}}").json(); + Object result = StringUtils.prepareNestedStructures(jsonObject, "a.b.c"); + assertEquals(Optional.ofNullable(42), Optional.ofNullable(JsonPath.read(result, "$.a.b"))); + } + + @Test + public void testPrepareNestedStructures_NullInput() { + assertThrows(IllegalArgumentException.class, () -> StringUtils.prepareNestedStructures(null, "a.b.c")); + } + + @Test + public void testPrepareNestedStructures_NullPath() { + Object jsonObject = new HashMap<>(); + assertThrows(IllegalArgumentException.class, () -> StringUtils.prepareNestedStructures(jsonObject, null)); + } + + @Test + public void testPrepareNestedStructures_ComplexPath() { + Object jsonObject = new HashMap<>(); + Object result = StringUtils.prepareNestedStructures(jsonObject, "a.b.c.d.e.f"); + assertTrue(JsonPath.read(result, "$.a.b.c.d.e") instanceof Map); + } + + @Test + public void testPrepareNestedStructures_MixedExistingAndNew() { + Object jsonObject = JsonPath.parse("{\"a\":{\"b\":42,\"c\":{}}}").json(); + Object result = StringUtils.prepareNestedStructures(jsonObject, "a.c.d.e"); + assertEquals(Optional.of(42), Optional.of(JsonPath.read(result, "$.a.b"))); + assertTrue(JsonPath.read(result, "$.a.c.d") instanceof Map); } @Test diff --git a/memory/build.gradle b/memory/build.gradle index 940a6b9621..86198c4521 100644 --- a/memory/build.gradle +++ b/memory/build.gradle @@ -41,6 +41,7 @@ dependencies { testImplementation("com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}") testImplementation("com.fasterxml.jackson.core:jackson-databind:${versions.jackson_databind}") testImplementation group: 'com.networknt' , name: 'json-schema-validator', version: '1.4.0' + testImplementation 'com.jayway.jsonpath:json-path:2.9.0' } test { diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 065e0ec371..5067bdc138 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -273,6 +273,7 @@ import org.opensearch.ml.rest.RestMemorySearchInteractionsAction; import org.opensearch.ml.rest.RestMemoryUpdateConversationAction; import org.opensearch.ml.rest.RestMemoryUpdateInteractionAction; +import org.opensearch.ml.searchext.MLInferenceRequestParametersExtBuilder; import org.opensearch.ml.settings.MLCommonsSettings; import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.stats.MLClusterLevelStat; @@ -999,6 +1000,15 @@ public List> getSearchExts() { ) ); + searchExts + .add( + new SearchPlugin.SearchExtSpec<>( + MLInferenceRequestParametersExtBuilder.NAME, + input -> new MLInferenceRequestParametersExtBuilder(input), + parser -> MLInferenceRequestParametersExtBuilder.parse(parser) + ) + ); + return searchExts; } diff --git a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java index 8782addc82..7f9d0feb96 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java @@ -5,6 +5,7 @@ package org.opensearch.ml.processor; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.utils.StringUtils.toJson; import static org.opensearch.ml.processor.InferenceProcessorAttributes.INPUT_MAP; import static org.opensearch.ml.processor.InferenceProcessorAttributes.MAX_PREDICTION_TASKS; import static org.opensearch.ml.processor.InferenceProcessorAttributes.MODEL_CONFIG; @@ -46,6 +47,7 @@ import com.jayway.jsonpath.Configuration; import com.jayway.jsonpath.JsonPath; import com.jayway.jsonpath.Option; +import com.jayway.jsonpath.PathNotFoundException; import com.jayway.jsonpath.ReadContext; /** @@ -149,8 +151,7 @@ public void processRequestAsync( } String queryString = request.source().toString(); - - rewriteQueryString(request, queryString, requestListener); + rewriteQueryString(request, queryString, requestListener, requestContext); } catch (Exception e) { if (ignoreFailure) { @@ -164,13 +165,18 @@ public void processRequestAsync( /** * Rewrites the query string based on the input and output mappings and the ML model output. * - * @param request the {@link SearchRequest} to be rewritten - * @param queryString the original query string + * @param request the {@link SearchRequest} to be rewritten + * @param queryString the original query string * @param requestListener the {@link ActionListener} to be notified when the rewriting is complete + * @param requestContext * @throws IOException if an I/O error occurs during the rewriting process */ - private void rewriteQueryString(SearchRequest request, String queryString, ActionListener requestListener) - throws IOException { + private void rewriteQueryString( + SearchRequest request, + String queryString, + ActionListener requestListener, + PipelineProcessingContext requestContext + ) throws IOException { List> processInputMap = inferenceProcessorAttributes.getInputMaps(); List> processOutputMap = inferenceProcessorAttributes.getOutputMaps(); int inputMapSize = (processInputMap != null) ? processInputMap.size() : 0; @@ -198,7 +204,8 @@ private void rewriteQueryString(SearchRequest request, String queryString, Actio request, queryString, requestListener, - processOutputMap + processOutputMap, + requestContext ); GroupedActionListener> batchPredictionListener = createBatchPredictionListener( rewriteRequestListener, @@ -219,13 +226,15 @@ private void rewriteQueryString(SearchRequest request, String queryString, Actio * @param queryString the original query string * @param requestListener the {@link ActionListener} to be notified when the query string or query template is updated * @param processOutputMap the list of output mappings + * @param requestContext * @return an {@link ActionListener} that handles the response from the ML model inference */ private ActionListener> createRewriteRequestListener( SearchRequest request, String queryString, ActionListener requestListener, - List> processOutputMap + List> processOutputMap, + PipelineProcessingContext requestContext ) { return new ActionListener<>() { @Override @@ -237,12 +246,10 @@ public void onResponse(Map multipleMLOutputs) { try { if (queryTemplate == null) { Object incomeQueryObject = JsonPath.parse(queryString).read("$"); - updateIncomeQueryObject(incomeQueryObject, outputMapping, mlOutput); - SearchSourceBuilder searchSourceBuilder = getSearchSourceBuilder( - xContentRegistry, - StringUtils.toJson(incomeQueryObject) - ); + updateIncomeQueryObject(incomeQueryObject, outputMapping, mlOutput, requestContext); + SearchSourceBuilder searchSourceBuilder = getSearchSourceBuilder(xContentRegistry, toJson(incomeQueryObject)); request.source(searchSourceBuilder); + requestListener.onResponse(request); } else { String newQueryString = updateQueryTemplate(queryTemplate, outputMapping, mlOutput); @@ -273,13 +280,52 @@ public void onFailure(Exception e) { } } - private void updateIncomeQueryObject(Object incomeQueryObject, Map outputMapping, MLOutput mlOutput) { + /** + * Updates the income query object with values from the ML output based on the provided output mapping. + * + * This method iterates through the output mapping, retrieves corresponding values from the ML output, + * and updates the income query object accordingly. It also handles nested JSON structures and updates + * the request context with the new values. + * + * @param incomeQueryObject The object representing the income query to be updated. + * @param outputMapping A map containing the mapping between new query fields and model output field names. + * @param mlOutput The MLOutput object containing the results from the machine learning model. + * @param requestContext The context object for the current pipeline processing request. + * + * @throws IllegalArgumentException If a specified JSON path cannot be found in the query string. + * + * @implNote This method uses JsonPath for JSON manipulation and supports both regular and extended (ext) fields. + * For extended fields, it creates nested structures if they don't exist. + * The method also updates the request context with new field values for further processing. + * + * @see JsonPath + * @see PipelineProcessingContext + * @see MLOutput + */ + private void updateIncomeQueryObject( + Object incomeQueryObject, + Map outputMapping, + MLOutput mlOutput, + PipelineProcessingContext requestContext + ) { for (Map.Entry outputMapEntry : outputMapping.entrySet()) { - String newQueryField = outputMapEntry.getKey(); - String modelOutputFieldName = outputMapEntry.getValue(); - Object modelOutputValue = getModelOutputValue(mlOutput, modelOutputFieldName, ignoreMissing, fullResponsePath); - String jsonPathExpression = "$." + newQueryField; - JsonPath.parse(incomeQueryObject).set(jsonPathExpression, modelOutputValue); + String newQueryField = null; + try { + newQueryField = outputMapEntry.getKey(); + String modelOutputFieldName = outputMapEntry.getValue(); + Object modelOutputValue = getModelOutputValue(mlOutput, modelOutputFieldName, ignoreMissing, fullResponsePath); + + if (newQueryField.startsWith("$.ext.") || newQueryField.startsWith("ext.")) { + incomeQueryObject = StringUtils.prepareNestedStructures(incomeQueryObject, newQueryField); + } + + JsonPath.using(suppressExceptionConfiguration).parse(incomeQueryObject).set(newQueryField, modelOutputValue); + + requestContext.setAttribute(newQueryField, modelOutputValue); + + } catch (PathNotFoundException e) { + throw new IllegalArgumentException("can not find path " + newQueryField + "in query string"); + } } } @@ -300,12 +346,12 @@ private String updateQueryTemplate(String queryTemplate, Map out /** * Creates a {@link GroupedActionListener} that collects the responses from multiple ML model inferences. * - * @param rewriteRequestListner the {@link ActionListener} to be notified when all ML model inferences are complete + * @param rewriteRequestListener the {@link ActionListener} to be notified when all ML model inferences are complete * @param inputMapSize the number of input mappings * @return a {@link GroupedActionListener} that handles the responses from multiple ML model inferences */ private GroupedActionListener> createBatchPredictionListener( - ActionListener> rewriteRequestListner, + ActionListener> rewriteRequestListener, int inputMapSize ) { return new GroupedActionListener<>(new ActionListener<>() { @@ -315,13 +361,13 @@ public void onResponse(Collection> mlOutputMapCollection) for (Map mlOutputMap : mlOutputMapCollection) { mlOutputMaps.putAll(mlOutputMap); } - rewriteRequestListner.onResponse(mlOutputMaps); + rewriteRequestListener.onResponse(mlOutputMaps); } @Override public void onFailure(Exception e) { logger.error("Prediction Failed:", e); - rewriteRequestListner.onFailure(e); + rewriteRequestListener.onFailure(e); } }, Math.max(inputMapSize, 1)); } @@ -358,11 +404,12 @@ private boolean validateQueryFieldInQueryString( for (Map outputMap : processOutputMap) { for (Map.Entry entry : outputMap.entrySet()) { String queryField = entry.getKey(); - Object pathData = jsonData.read(queryField); - if (pathData == null) { - throw new IllegalArgumentException( - "cannot find field: " + queryField + " in query string: " + jsonData.jsonString() - ); + // output writing to search extension can be new field + if (!queryField.startsWith("ext.") && !queryField.startsWith("$.ext.")) { + Object pathData = jsonData.read(queryField); + if (pathData == null) { + throw new IllegalArgumentException(); + } } } } @@ -402,7 +449,7 @@ private void processPredictions( // model field as key, query field name as value String modelInputFieldName = entry.getKey(); String queryFieldName = entry.getValue(); - String queryFieldValue = StringUtils.toJson(JsonPath.parse(newQuery).read(queryFieldName)); + String queryFieldValue = toJson(JsonPath.parse(newQuery).read(queryFieldName)); modelParameters.put(modelInputFieldName, queryFieldValue); } } @@ -446,14 +493,20 @@ public void onFailure(Exception e) { /** * Creates a SearchSourceBuilder instance from the given query string. * + * This method parses the provided query string, substitutes parameters, and constructs + * a SearchSourceBuilder object. It handles JSON content and performs variable substitution + * using a StringSubstitutor. + * * @param xContentRegistry the XContentRegistry instance to be used for parsing - * @param queryString the query template string to be parsed + * @param queryString the query template string to be parsed * @return a SearchSourceBuilder instance created from the query string - * @throws IOException if an I/O error occurs during parsing + * @throws IOException if an I/O error occurs during parsing or content creation */ private static SearchSourceBuilder getSearchSourceBuilder(NamedXContentRegistry xContentRegistry, String queryString) throws IOException { + // MLInferenceRequestParametersExtBuilder mlInferenceExtBuilder = new MLInferenceRequestParametersExtBuilder(); SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + // SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().ext(List.of(mlInferenceExtBuilder)); XContentParser queryParser = XContentType.JSON .xContent() @@ -461,7 +514,9 @@ private static SearchSourceBuilder getSearchSourceBuilder(NamedXContentRegistry ensureExpectedToken(XContentParser.Token.START_OBJECT, queryParser.nextToken(), queryParser); searchSourceBuilder.parseXContent(queryParser); + return searchSourceBuilder; + } /** diff --git a/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java b/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java index d32308d2ef..8a25b985f9 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java @@ -105,11 +105,11 @@ default ActionRequest getMLModelInferenceRequest( * Retrieves the model output value from the given ModelTensorOutput for the specified modelOutputFieldName. * It handles cases where the output contains a single tensor or multiple tensors. * - * @param modelTensorOutput the ModelTensorOutput containing the model output - * @param modelOutputFieldName the name of the field in the model output to retrieve the value for - * @param ignoreMissing a flag indicating whether to ignore missing fields or throw an exception + * @param modelTensorOutput the ModelTensorOutput containing the model output + * @param modelOutputFieldName the name of the field in the model output to retrieve the value for + * @param ignoreMissing a flag indicating whether to ignore missing fields or throw an exception * @return the model output value as an Object - * @throws RuntimeException if there is an error retrieving the model output value + * @throws RuntimeException if there is an error retrieving the model output value */ default Object getModelOutputValue(ModelTensorOutput modelTensorOutput, String modelOutputFieldName, boolean ignoreMissing) { Object modelOutputValue; @@ -298,6 +298,7 @@ default boolean hasField(Object json, String path) { * Writes a new dot path for a nested object within the given JSON object. * This method is useful when dealing with arrays or nested objects in the JSON structure. * for example foo.*.bar.*.quk to be [foo.0.bar.0.quk, foo.0.bar.1.quk..] + * * @param json the JSON object containing the nested object * @param dotPath the dot path representing the location of the nested object * @return a list of dot paths representing the new locations of the nested object diff --git a/plugin/src/main/java/org/opensearch/ml/searchext/MLInferenceRequestParameters.java b/plugin/src/main/java/org/opensearch/ml/searchext/MLInferenceRequestParameters.java new file mode 100644 index 0000000000..042b3d915e --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/searchext/MLInferenceRequestParameters.java @@ -0,0 +1,79 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.ml.searchext; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + +@NoArgsConstructor +public class MLInferenceRequestParameters implements Writeable, ToXContentObject { + static final String ML_INFERENCE_FIELD = "ml_inference"; + + @Setter + @Getter + private Map params; + + public MLInferenceRequestParameters(Map params) { + this.params = params; + + } + + public MLInferenceRequestParameters(StreamInput input) throws IOException { + this.params = input.readMap(); + } + + /** + * Write this into the {@linkplain StreamOutput}. + * + * @param out + */ + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeMap(this.params); + } + + public static MLInferenceRequestParameters parse(XContentParser parser) throws IOException { + return new MLInferenceRequestParameters(parser.map()); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(ML_INFERENCE_FIELD); + return builder.map(this.params); + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + MLInferenceRequestParameters config = (MLInferenceRequestParameters) o; + + return params.equals(config.getParams()); + } + + @Override + public int hashCode() { + return Objects.hash(params); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/searchext/MLInferenceRequestParametersExtBuilder.java b/plugin/src/main/java/org/opensearch/ml/searchext/MLInferenceRequestParametersExtBuilder.java new file mode 100644 index 0000000000..c8c9ffd8aa --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/searchext/MLInferenceRequestParametersExtBuilder.java @@ -0,0 +1,82 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.ml.searchext; + +import static org.opensearch.ml.searchext.MLInferenceRequestParameters.ML_INFERENCE_FIELD; + +import java.io.IOException; +import java.util.Objects; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.SearchExtBuilder; + +public class MLInferenceRequestParametersExtBuilder extends SearchExtBuilder { + private static final Logger logger = LogManager.getLogger(MLInferenceRequestParametersExtBuilder.class); + public static final String NAME = ML_INFERENCE_FIELD; + private MLInferenceRequestParameters requestParameters; + + public MLInferenceRequestParametersExtBuilder() {} + + public MLInferenceRequestParametersExtBuilder(StreamInput input) throws IOException { + this.requestParameters = new MLInferenceRequestParameters(input); + } + + public MLInferenceRequestParameters getRequestParameters() { + return requestParameters; + } + + public void setRequestParameters(MLInferenceRequestParameters requestParameters) { + this.requestParameters = requestParameters; + } + + @Override + public int hashCode() { + return Objects.hash(this.getClass(), this.requestParameters); + } + + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + if (!(obj instanceof MLInferenceRequestParametersExtBuilder)) { + return false; + } + MLInferenceRequestParametersExtBuilder o = (MLInferenceRequestParametersExtBuilder) obj; + return this.requestParameters.equals(o.requestParameters); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + + requestParameters.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder.value(requestParameters); + } + + public static MLInferenceRequestParametersExtBuilder parse(XContentParser parser) throws IOException { + + MLInferenceRequestParametersExtBuilder extBuilder = new MLInferenceRequestParametersExtBuilder(); + MLInferenceRequestParameters requestParameters = MLInferenceRequestParameters.parse(parser); + extBuilder.setRequestParameters(requestParameters); + return extBuilder; + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/searchext/MLInferenceRequestParametersUtil.java b/plugin/src/main/java/org/opensearch/ml/searchext/MLInferenceRequestParametersUtil.java new file mode 100644 index 0000000000..1073a55b40 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/searchext/MLInferenceRequestParametersUtil.java @@ -0,0 +1,38 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.ml.searchext; + +import java.util.List; +import java.util.stream.Collectors; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.search.SearchExtBuilder; + +public class MLInferenceRequestParametersUtil { + + public static MLInferenceRequestParameters getMLInferenceRequestParameters(SearchRequest searchRequest) { + MLInferenceRequestParametersExtBuilder mLInferenceRequestParametersExtBuilder = null; + if (searchRequest.source() != null && searchRequest.source().ext() != null && !searchRequest.source().ext().isEmpty()) { + List extBuilders = searchRequest + .source() + .ext() + .stream() + .filter(extBuilder -> MLInferenceRequestParametersExtBuilder.NAME.equals(extBuilder.getWriteableName())) + .collect(Collectors.toList()); + + if (!extBuilders.isEmpty()) { + mLInferenceRequestParametersExtBuilder = (MLInferenceRequestParametersExtBuilder) extBuilders.get(0); + } + } + MLInferenceRequestParameters mlInferenceRequestParameters = null; + if (mLInferenceRequestParametersExtBuilder != null) { + mlInferenceRequestParameters = mLInferenceRequestParametersExtBuilder.getRequestParameters(); + } + return mlInferenceRequestParameters; + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/plugin/MachineLearningPluginTests.java b/plugin/src/test/java/org/opensearch/ml/plugin/MachineLearningPluginTests.java index 6da9cb406a..0ff85c939a 100644 --- a/plugin/src/test/java/org/opensearch/ml/plugin/MachineLearningPluginTests.java +++ b/plugin/src/test/java/org/opensearch/ml/plugin/MachineLearningPluginTests.java @@ -40,6 +40,7 @@ import org.opensearch.ml.engine.tools.MLModelTool; import org.opensearch.ml.processor.MLInferenceSearchRequestProcessor; import org.opensearch.ml.processor.MLInferenceSearchResponseProcessor; +import org.opensearch.ml.searchext.MLInferenceRequestParametersExtBuilder; import org.opensearch.plugins.ExtensiblePlugin; import org.opensearch.plugins.SearchPipelinePlugin; import org.opensearch.plugins.SearchPlugin; @@ -66,9 +67,11 @@ public void setUp() { @Test public void testGetSearchExts() { List> searchExts = plugin.getSearchExts(); - assertEquals(1, searchExts.size()); - SearchPlugin.SearchExtSpec spec = searchExts.get(0); - assertEquals(GenerativeQAParamExtBuilder.PARAMETER_NAME, spec.getName().getPreferredName()); + assertEquals(2, searchExts.size()); + SearchPlugin.SearchExtSpec spec1 = searchExts.get(0); + assertEquals(GenerativeQAParamExtBuilder.PARAMETER_NAME, spec1.getName().getPreferredName()); + SearchPlugin.SearchExtSpec spec2 = searchExts.get(1); + assertEquals(MLInferenceRequestParametersExtBuilder.NAME, spec2.getName().getPreferredName()); } @Test diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java index 353d2be1a3..771e882310 100644 --- a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java +++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java @@ -35,6 +35,9 @@ import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; +import org.opensearch.ml.searchext.MLInferenceRequestParameters; +import org.opensearch.ml.searchext.MLInferenceRequestParametersExtBuilder; +import org.opensearch.plugins.SearchPlugin; import org.opensearch.search.SearchModule; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.pipeline.PipelineProcessingContext; @@ -48,15 +51,27 @@ public class MLInferenceSearchRequestProcessorTests extends AbstractBuilderTestC @Mock private PipelineProcessingContext requestContext; - static public final NamedXContentRegistry TEST_XCONTENT_REGISTRY_FOR_QUERY = new NamedXContentRegistry( - new SearchModule(Settings.EMPTY, List.of()).getNamedXContents() - ); + static public NamedXContentRegistry TEST_XCONTENT_REGISTRY_FOR_QUERY; private static final String PROCESSOR_TAG = "inference"; private static final String DESCRIPTION = "inference_test"; @Before public void setup() { MockitoAnnotations.openMocks(this); + + TEST_XCONTENT_REGISTRY_FOR_QUERY = new NamedXContentRegistry(new SearchModule(Settings.EMPTY, List.of(new SearchPlugin() { + @Override + public List> getSearchExts() { + return List + .of( + new SearchExtSpec<>( + MLInferenceRequestParametersExtBuilder.NAME, + MLInferenceRequestParametersExtBuilder::new, + parser -> MLInferenceRequestParametersExtBuilder.parse(parser) + ) + ); + } + })).getNamedXContents()); } /** @@ -183,7 +198,7 @@ public void onResponse(SearchRequest newSearchRequest) { @Override public void onFailure(Exception e) { - throw new RuntimeException("Failed in executing processRequestAsync."); + throw new RuntimeException("Failed in executing processRequestAsync." + e.getMessage()); } }; @@ -240,7 +255,7 @@ public void onResponse(SearchRequest newSearchRequest) { @Override public void onFailure(Exception e) { - throw new RuntimeException("Failed in executing processRequestAsync."); + throw new RuntimeException("Failed in executing processRequestAsync." + e.getMessage()); } }; @@ -1021,6 +1036,139 @@ public void onFailure(Exception ex) { } + /** + * Tests the successful rewriting of a single string in a term query based on the model output. + * + * @throws Exception if an error occurs during the test + */ + public void testExecute_rewriteTermQueryWriteToExtensionSuccess() throws Exception { + + /** + * example term query: {"query":{"term":{"text":{"value":"foo","boost":1.0}}}} + */ + String modelInputField = "inputs"; + String originalQueryField = "query.term.text.value"; + String newQueryField = "$.ext.ml_inference.llm_response"; + String modelOutputField = "response"; + MLInferenceSearchRequestProcessor requestProcessor = getMlInferenceSearchRequestProcessor( + null, + modelInputField, + originalQueryField, + newQueryField, + modelOutputField, + false, + false + ); + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "eng")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); + SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery); + SearchRequest request = new SearchRequest().source(source); + + Map llmResponse = new HashMap<>(); + llmResponse.put("llm_response", "eng"); + MLInferenceRequestParameters requestParameters = new MLInferenceRequestParameters(llmResponse); + MLInferenceRequestParametersExtBuilder mlInferenceExtBuilder = new MLInferenceRequestParametersExtBuilder(); + mlInferenceExtBuilder.setRequestParameters(requestParameters); + SearchSourceBuilder expectedSource = new SearchSourceBuilder().query(incomingQuery).ext(List.of(mlInferenceExtBuilder)); + SearchRequest expectRequest = new SearchRequest().source(expectedSource); + + ActionListener Listener = new ActionListener<>() { + @Override + public void onResponse(SearchRequest newSearchRequest) { + assertEquals(incomingQuery, newSearchRequest.source().query()); + assertEquals(expectRequest.source().toString(), newSearchRequest.source().toString()); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException("Failed in executing processRequestAsync." + e.getMessage()); + } + }; + + requestProcessor.processRequestAsync(request, requestContext, Listener); + + } + + /** + * Tests the successful rewriting of a single string in a term query based on the model output. + * + * @throws Exception if an error occurs during the test + */ + public void testExecute_rewriteTermQueryReadAndWriteToExtensionSuccess() throws Exception { + + /** + * example term query: {"query":{"term":{"text":{"value":"foo","boost":1.0}}}} + */ + String modelInputField = "inputs"; + String originalQueryField = "ext.ml_inference.question"; + String newQueryField = "ext.ml_inference.llm_response"; + String modelOutputField = "response"; + MLInferenceSearchRequestProcessor requestProcessor = getMlInferenceSearchRequestProcessor( + null, + modelInputField, + originalQueryField, + newQueryField, + modelOutputField, + false, + false + ); + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "eng")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); + + Map llmQuestion = new HashMap<>(); + llmQuestion.put("question", "what language is this text in?"); + MLInferenceRequestParameters requestParameters = new MLInferenceRequestParameters(llmQuestion); + MLInferenceRequestParametersExtBuilder mlInferenceExtBuilder = new MLInferenceRequestParametersExtBuilder(); + mlInferenceExtBuilder.setRequestParameters(requestParameters); + SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery).ext(List.of(mlInferenceExtBuilder)); + + SearchRequest request = new SearchRequest().source(source); + + // expecting new request with ml inference search extensions + Map params = new HashMap<>(); + params.put("question", "what language is this text in?"); + params.put("llm_response", "eng"); + MLInferenceRequestParameters expectedRequestParameters = new MLInferenceRequestParameters(params); + MLInferenceRequestParametersExtBuilder expectedMlInferenceExtBuilder = new MLInferenceRequestParametersExtBuilder(); + expectedMlInferenceExtBuilder.setRequestParameters(expectedRequestParameters); + SearchSourceBuilder expectedSource = new SearchSourceBuilder().query(incomingQuery).ext(List.of(expectedMlInferenceExtBuilder)); + SearchRequest expectRequest = new SearchRequest().source(expectedSource); + + ActionListener Listener = new ActionListener<>() { + @Override + public void onResponse(SearchRequest newSearchRequest) { + assertEquals(incomingQuery, newSearchRequest.source().query()); + assertEquals(expectRequest.toString(), newSearchRequest.toString()); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException("Failed in executing processRequestAsync." + e.getMessage()); + } + }; + + requestProcessor.processRequestAsync(request, requestContext, Listener); + + } + /** * Helper method to create an instance of the MLInferenceSearchRequestProcessor with the specified parameters. * diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java index dedae5f1bd..1ad9eee136 100644 --- a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java +++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java @@ -66,6 +66,8 @@ import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; +import org.opensearch.ml.searchext.MLInferenceRequestParameters; +import org.opensearch.ml.searchext.MLInferenceRequestParametersExtBuilder; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.SearchModule; @@ -518,7 +520,84 @@ public void onFailure(Exception e) { toJson(inputDataSet.getParameters()), "{\"text_docs\":\"[\\\"value 0\\\",\\\"value 1\\\",\\\"value 2\\\",\\\"value 3\\\",\\\"value 4\\\"]\"}" ); + } + + /** + * Tests the successful processing of a response with a single pair of input and output mappings. + * read the query text into model config + * with query extensions + * @throws Exception if an error occurs during the test + */ + @Test + public void testProcessResponseSuccessReadQueryTextFromExt() throws Exception { + String modelInputField = "text_docs"; + String originalDocumentField = "text"; + String newDocumentField = "similarity_score"; + String modelOutputField = "response"; + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put(modelInputField, originalDocumentField); + input.put("query_text", "_request.ext.ml_inference.query_text"); + inputMap.add(input); + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + outputMap.add(output); + Map modelConfig = new HashMap<>(); + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + inputMap, + outputMap, + modelConfig, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "text_similarity", + false, + false, + false, + "{ \"query_text\": \"${input_map.query_text}\", \"text_docs\":${input_map.text_docs}}", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + false + ); + assertEquals(responseProcessor.getType(), TYPE); + SearchRequest request = getSearchRequestWithExtension("query_text", "query.term.text.value"); + String fieldName = "text"; + SearchResponse response = getSearchResponse(5, true, fieldName); + + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of("response", Arrays.asList(0.0, 1.0, 2.0, 3.0, 4.0))) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + assertEquals(newSearchResponse.getHits().getHits().length, 5); + assertEquals(newSearchResponse.getHits().getHits()[0].getSourceAsMap().get(newDocumentField), 0.0); + assertEquals(newSearchResponse.getHits().getHits()[1].getSourceAsMap().get(newDocumentField), 1.0); + assertEquals(newSearchResponse.getHits().getHits()[2].getSourceAsMap().get(newDocumentField), 2.0); + assertEquals(newSearchResponse.getHits().getHits()[3].getSourceAsMap().get(newDocumentField), 3.0); + assertEquals(newSearchResponse.getHits().getHits()[4].getSourceAsMap().get(newDocumentField), 4.0); + } + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } + }; + + responseProcessor.processResponseAsync(request, response, responseContext, listener); } /** @@ -3931,6 +4010,22 @@ private static SearchRequest getSearchRequest() { return request; } + private static SearchRequest getSearchRequestWithExtension(String queryText, String queryPath) { + QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); + + Map params = new HashMap<>(); + params.put(queryText, queryPath); + MLInferenceRequestParameters requestParameters = new MLInferenceRequestParameters(params); + + MLInferenceRequestParametersExtBuilder extBuilder = new MLInferenceRequestParametersExtBuilder(); + extBuilder.setRequestParameters(requestParameters); + SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery).ext(List.of(extBuilder)); + ; + SearchRequest request = new SearchRequest().source(source); + + return request; + } + private static Map generateInferenceResult(String response) { Map inferenceResult = new HashMap<>(); List> inferenceResults = new ArrayList<>(); diff --git a/plugin/src/test/java/org/opensearch/ml/searchext/MLInferenceRequestParametersExtBuilderTests.java b/plugin/src/test/java/org/opensearch/ml/searchext/MLInferenceRequestParametersExtBuilderTests.java new file mode 100644 index 0000000000..bf705c6649 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/searchext/MLInferenceRequestParametersExtBuilderTests.java @@ -0,0 +1,306 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.ml.searchext; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.searchext.MLInferenceRequestParameters.ML_INFERENCE_FIELD; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Assert; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentHelper; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.plugins.SearchPlugin; +import org.opensearch.search.SearchModule; +import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamExtBuilder; +import org.opensearch.test.OpenSearchTestCase; + +public class MLInferenceRequestParametersExtBuilderTests extends OpenSearchTestCase { + + public NamedXContentRegistry xContentRegistry = new NamedXContentRegistry(new SearchModule(Settings.EMPTY, List.of(new SearchPlugin() { + @Override + public List> getSearchExts() { + return List + .of( + new SearchPlugin.SearchExtSpec<>( + MLInferenceRequestParametersExtBuilder.NAME, + MLInferenceRequestParametersExtBuilder::new, + parser -> MLInferenceRequestParametersExtBuilder.parse(parser) + ) + ); + } + })).getNamedXContents()); + + public void testParse() throws IOException { + String requiredJsonStr = "{\"llm_question\":\"this is test llm question\"}"; + + XContentParser parser = XContentType.JSON.xContent().createParser(xContentRegistry, null, requiredJsonStr); + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + MLInferenceRequestParametersExtBuilder builder = MLInferenceRequestParametersExtBuilder.parse(parser); + assertNotNull(builder); + assertNotNull(builder.getRequestParameters()); + MLInferenceRequestParameters params = builder.getRequestParameters(); + Assert.assertEquals("this is test llm question", params.getParams().get("llm_question")); + } + + @Test + public void testMultipleParameters() throws IOException { + Map params = new HashMap<>(); + params.put("query_text", "foo"); + params.put("model_id", "model1"); + params.put("max_tokens", 100); + MLInferenceRequestParameters requestParameters = new MLInferenceRequestParameters(params); + MLInferenceRequestParametersExtBuilder builder = new MLInferenceRequestParametersExtBuilder(); + builder.setRequestParameters(requestParameters); + + BytesStreamOutput out = new BytesStreamOutput(); + builder.writeTo(out); + + MLInferenceRequestParametersExtBuilder deserialized = new MLInferenceRequestParametersExtBuilder(out.bytes().streamInput()); + assertEquals(builder, deserialized); + assertEquals(params, deserialized.getRequestParameters().getParams()); + } + + @Test + public void testParseWithEmptyObject() throws IOException { + String emptyJsonStr = "{}"; + XContentParser parser = XContentType.JSON.xContent().createParser(xContentRegistry, null, emptyJsonStr); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + MLInferenceRequestParametersExtBuilder builder = MLInferenceRequestParametersExtBuilder.parse(parser); + assertNotNull(builder); + assertNotNull(builder.getRequestParameters()); + assertTrue(builder.getRequestParameters().getParams().isEmpty()); + } + + @Test + public void testWriteableName() throws IOException { + MLInferenceRequestParametersExtBuilder builder = new MLInferenceRequestParametersExtBuilder(); + assertEquals(builder.getWriteableName(), ML_INFERENCE_FIELD); + } + + @Test + public void testEquals() throws IOException { + MLInferenceRequestParametersExtBuilder MlInferenceParamBuilder = new MLInferenceRequestParametersExtBuilder(); + GenerativeQAParamExtBuilder qaParamExtBuilder = new GenerativeQAParamExtBuilder(); + assertEquals(MlInferenceParamBuilder.equals(qaParamExtBuilder), false); + assertEquals(MlInferenceParamBuilder.equals(null), false); + } + + @Test + public void testMLInferenceRequestParametersEqualsWithNull() { + Map params = new HashMap<>(); + params.put("query_text", "foo"); + MLInferenceRequestParameters parameters = new MLInferenceRequestParameters(params); + assertFalse(parameters.equals(null)); + } + + @Test + public void testMLInferenceRequestParametersEqualsWithDifferentClass() { + Map params = new HashMap<>(); + params.put("query_text", "foo"); + MLInferenceRequestParameters parameters = new MLInferenceRequestParameters(params); + assertFalse(parameters.equals("not a MLInferenceRequestParameters object")); + } + + @Test + public void testMLInferenceRequestParametersToXContentWithEmptyParams() throws IOException { + MLInferenceRequestParameters parameters = new MLInferenceRequestParameters(new HashMap<>()); + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + parameters.toXContent(builder, ToXContent.EMPTY_PARAMS); + builder.endObject(); + assertEquals("{\"ml_inference\":{}}", builder.toString()); + } + + @Test + public void testMLInferenceRequestParametersExtBuilderToXContentWithEmptyParams() throws IOException { + MLInferenceRequestParameters parameters = new MLInferenceRequestParameters(new HashMap<>()); + MLInferenceRequestParametersExtBuilder builder = new MLInferenceRequestParametersExtBuilder(); + builder.setRequestParameters(parameters); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); + xContentBuilder.startObject(); + builder.toXContent(xContentBuilder, ToXContent.EMPTY_PARAMS); + xContentBuilder.endObject(); + assertEquals("{\"ml_inference\":{}}", xContentBuilder.toString()); + } + + @Test + public void testMLInferenceRequestParametersStreamRoundTripWithNullParams() throws IOException { + MLInferenceRequestParameters original = new MLInferenceRequestParameters(); + original.setParams(null); + BytesStreamOutput out = new BytesStreamOutput(); + original.writeTo(out); + MLInferenceRequestParameters deserialized = new MLInferenceRequestParameters(out.bytes().streamInput()); + assertNull(deserialized.getParams()); + } + + @Test + public void testMLInferenceRequestParametersExtBuilderStreamRoundTripWithNullParams() throws IOException { + MLInferenceRequestParametersExtBuilder original = new MLInferenceRequestParametersExtBuilder(); + original.setRequestParameters(null); + BytesStreamOutput out = new BytesStreamOutput(); + assertThrows(NullPointerException.class, () -> original.writeTo(out)); + } + + @Test + public void testEqualsAndHashCode() { + Map params1 = new HashMap<>(); + params1.put("query_text", "foo"); + MLInferenceRequestParameters requestParameters1 = new MLInferenceRequestParameters(params1); + MLInferenceRequestParametersExtBuilder builder1 = new MLInferenceRequestParametersExtBuilder(); + builder1.setRequestParameters(requestParameters1); + + Map params2 = new HashMap<>(); + params2.put("query_text", "foo"); + MLInferenceRequestParameters requestParameters2 = new MLInferenceRequestParameters(params2); + MLInferenceRequestParametersExtBuilder builder2 = new MLInferenceRequestParametersExtBuilder(); + builder2.setRequestParameters(requestParameters2); + + assertEquals(builder1, builder2); + assertEquals(builder1.hashCode(), builder2.hashCode()); + + Map params3 = new HashMap<>(); + params3.put("query_text", "bar"); + MLInferenceRequestParameters requestParameters3 = new MLInferenceRequestParameters(params3); + MLInferenceRequestParametersExtBuilder builder3 = new MLInferenceRequestParametersExtBuilder(); + builder3.setRequestParameters(requestParameters3); + + assertNotEquals(builder1, builder3); + assertNotEquals(builder1.hashCode(), builder3.hashCode()); + } + + @Test + public void testXContentRoundTrip() throws IOException { + Map params = new HashMap<>(); + params.put("query_text", "foo"); + MLInferenceRequestParameters requestParameters = new MLInferenceRequestParameters(params); + MLInferenceRequestParametersExtBuilder mlInferenceExtBuilder = new MLInferenceRequestParametersExtBuilder(); + mlInferenceExtBuilder.setRequestParameters(requestParameters); + + XContentType xContentType = randomFrom(XContentType.values()); + BytesReference serialized = XContentHelper.toXContent(mlInferenceExtBuilder, xContentType, true); + + XContentParser parser = createParser(xContentType.xContent(), serialized); + + MLInferenceRequestParametersExtBuilder deserialized = MLInferenceRequestParametersExtBuilder.parse(parser); + + assertEquals(deserialized.getRequestParameters().getParams().get(ML_INFERENCE_FIELD), params); + + } + + @Test + public void testStreamRoundTrip() throws IOException { + Map params = new HashMap<>(); + params.put("query_text", "foo"); + MLInferenceRequestParameters requestParameters = new MLInferenceRequestParameters(); + requestParameters.setParams(params); + MLInferenceRequestParametersExtBuilder mlInferenceExtBuilder = new MLInferenceRequestParametersExtBuilder(); + mlInferenceExtBuilder.setRequestParameters(requestParameters); + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + mlInferenceExtBuilder.writeTo(bytesStreamOutput); + + MLInferenceRequestParametersExtBuilder deserialized = new MLInferenceRequestParametersExtBuilder( + bytesStreamOutput.bytes().streamInput() + ); + assertEquals(mlInferenceExtBuilder, deserialized); + } + + @Test + public void testNullRequestParameters() throws IOException { + MLInferenceRequestParametersExtBuilder builder = new MLInferenceRequestParametersExtBuilder(); + assertNull(builder.getRequestParameters()); + + BytesStreamOutput out = new BytesStreamOutput(); + + // Expect NullPointerException when writing null requestParameters + assertThrows(NullPointerException.class, () -> builder.writeTo(out)); + + // Test that we can still create a new builder with null requestParameters + MLInferenceRequestParametersExtBuilder newBuilder = new MLInferenceRequestParametersExtBuilder(); + assertNull(newBuilder.getRequestParameters()); + } + + @Test + public void testEmptyRequestParameters() throws IOException { + MLInferenceRequestParameters emptyParams = new MLInferenceRequestParameters(new HashMap<>()); + MLInferenceRequestParametersExtBuilder builder = new MLInferenceRequestParametersExtBuilder(); + builder.setRequestParameters(emptyParams); + + BytesStreamOutput out = new BytesStreamOutput(); + builder.writeTo(out); + + MLInferenceRequestParametersExtBuilder deserialized = new MLInferenceRequestParametersExtBuilder(out.bytes().streamInput()); + assertNotNull(deserialized.getRequestParameters()); + assertTrue(deserialized.getRequestParameters().getParams().isEmpty()); + } + + @Test + public void testToXContent() throws IOException { + Map params = new HashMap<>(); + params.put("query_text", "foo"); + MLInferenceRequestParameters requestParameters = new MLInferenceRequestParameters(params); + MLInferenceRequestParametersExtBuilder builder = new MLInferenceRequestParametersExtBuilder(); + builder.setRequestParameters(requestParameters); + + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); + xContentBuilder.startObject(); + builder.toXContent(xContentBuilder, ToXContent.EMPTY_PARAMS); + xContentBuilder.endObject(); + + String expected = "{\"ml_inference\":{\"query_text\":\"foo\"}}"; + assertEquals(expected, xContentBuilder.toString()); + } + + @Test + public void testMLInferenceRequestParametersEqualsAndHashCode() { + Map params1 = new HashMap<>(); + params1.put("query_text", "foo"); + MLInferenceRequestParameters requestParameters1 = new MLInferenceRequestParameters(params1); + + Map params2 = new HashMap<>(); + params2.put("query_text", "foo"); + MLInferenceRequestParameters requestParameters2 = new MLInferenceRequestParameters(params2); + + Map params3 = new HashMap<>(); + params3.put("query_text", "bar"); + MLInferenceRequestParameters requestParameters3 = new MLInferenceRequestParameters(params3); + + assertEquals(requestParameters1, requestParameters2); + assertEquals(requestParameters1.hashCode(), requestParameters2.hashCode()); + assertNotEquals(requestParameters1, requestParameters3); + assertNotEquals(requestParameters1.hashCode(), requestParameters3.hashCode()); + } + + @Test + public void testMLInferenceRequestParametersToXContent() throws IOException { + Map params = new HashMap<>(); + params.put("query_text", "foo"); + MLInferenceRequestParameters requestParameters = new MLInferenceRequestParameters(params); + + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); + xContentBuilder.startObject(); + requestParameters.toXContent(xContentBuilder, ToXContent.EMPTY_PARAMS); + xContentBuilder.endObject(); + + String expected = "{\"ml_inference\":{\"query_text\":\"foo\"}}"; + assertEquals(expected, xContentBuilder.toString()); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/searchext/MLInferenceRequestParametersUtilTests.java b/plugin/src/test/java/org/opensearch/ml/searchext/MLInferenceRequestParametersUtilTests.java new file mode 100644 index 0000000000..ea2dc55d26 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/searchext/MLInferenceRequestParametersUtilTests.java @@ -0,0 +1,51 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ +package org.opensearch.ml.searchext; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Test; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.search.builder.SearchSourceBuilder; + +public class MLInferenceRequestParametersUtilTests { + @Test + public void testExtractParameters() { + Map params = new HashMap<>(); + params.put("query_text", "foo"); + MLInferenceRequestParameters expected = new MLInferenceRequestParameters(params); + MLInferenceRequestParametersExtBuilder extBuilder = new MLInferenceRequestParametersExtBuilder(); + extBuilder.setRequestParameters(expected); + SearchSourceBuilder sourceBuilder = SearchSourceBuilder.searchSource().ext(List.of(extBuilder)); + SearchRequest request = new SearchRequest("my_index").source(sourceBuilder); + MLInferenceRequestParameters actual = MLInferenceRequestParametersUtil.getMLInferenceRequestParameters(request); + assertEquals(expected, actual); + } + + @Test + public void testExtractParametersWithNullSource() { + SearchRequest request = new SearchRequest(); + MLInferenceRequestParameters result = MLInferenceRequestParametersUtil.getMLInferenceRequestParameters(request); + assertNull(result); + } + + @Test + public void testExtractParametersWithEmptyExt() { + SearchRequest request = new SearchRequest(); + request.source(new SearchSourceBuilder()); + MLInferenceRequestParameters result = MLInferenceRequestParametersUtil.getMLInferenceRequestParameters(request); + assertNull(result); + } + +}