From 9ed0040e2e0edef67f51ef3d9a217a0c1f1b5206 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Mon, 21 Oct 2024 13:41:16 -0700 Subject: [PATCH] add config field in MLToolSpec for static parameters (#2977) * add config field in MLToolSpec for static parameters Signed-off-by: Jing Zhang * add version control Signed-off-by: Jing Zhang * address comments I Signed-off-by: Jing Zhang * address commits II Signed-off-by: Jing Zhang * address comments III Signed-off-by: Jing Zhang --------- Signed-off-by: Jing Zhang --- .../org/opensearch/ml/common/CommonValue.java | 1 + .../ml/common/agent/MLToolSpec.java | 35 +++- .../ml/common/agent/MLAgentTest.java | 27 ++- .../ml/common/agent/MLToolSpecTest.java | 173 +++++++++++++++++- .../agent/MLAgentGetResponseTest.java | 2 +- .../engine/algorithms/agent/AgentUtils.java | 13 ++ .../MLConversationalFlowAgentRunner.java | 6 + .../algorithms/agent/MLFlowAgentRunner.java | 5 + .../agent/MLChatAgentRunnerTest.java | 79 ++++++++ .../agent/MLFlowAgentRunnerTest.java | 20 ++ .../agents/GetAgentTransportActionTests.java | 2 +- 11 files changed, 341 insertions(+), 22 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index 2dc4790bb2..3adaa8ca2e 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -581,4 +581,5 @@ public class CommonValue { public static final Version VERSION_2_15_0 = Version.fromString("2.15.0"); public static final Version VERSION_2_16_0 = Version.fromString("2.16.0"); public static final Version VERSION_2_17_0 = Version.fromString("2.17.0"); + public static final Version VERSION_2_18_0 = Version.fromString("2.18.0"); } diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java b/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java index 98f7e1f33c..c144d5cda9 100644 --- a/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java @@ -11,11 +11,13 @@ import java.io.IOException; import java.util.Map; +import org.opensearch.Version; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.CommonValue; import lombok.Builder; import lombok.EqualsAndHashCode; @@ -24,20 +26,31 @@ @EqualsAndHashCode @Getter public class MLToolSpec implements ToXContentObject { + public static final Version MINIMAL_SUPPORTED_VERSION_FOR_TOOL_CONFIG = CommonValue.VERSION_2_18_0; + public static final String TOOL_TYPE_FIELD = "type"; public static final String TOOL_NAME_FIELD = "name"; public static final String DESCRIPTION_FIELD = "description"; public static final String PARAMETERS_FIELD = "parameters"; public static final String INCLUDE_OUTPUT_IN_AGENT_RESPONSE = "include_output_in_agent_response"; + public static final String CONFIG_FIELD = "config"; private String type; private String name; private String description; private Map parameters; private boolean includeOutputInAgentResponse; + private Map configMap; @Builder(toBuilder = true) - public MLToolSpec(String type, String name, String description, Map parameters, boolean includeOutputInAgentResponse) { + public MLToolSpec( + String type, + String name, + String description, + Map parameters, + boolean includeOutputInAgentResponse, + Map configMap + ) { if (type == null) { throw new IllegalArgumentException("tool type is null"); } @@ -46,6 +59,7 @@ public MLToolSpec(String type, String name, String description, Map parameters = null; boolean includeOutputInAgentResponse = false; + Map configMap = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -119,6 +148,9 @@ public static MLToolSpec parse(XContentParser parser) throws IOException { case INCLUDE_OUTPUT_IN_AGENT_RESPONSE: includeOutputInAgentResponse = parser.booleanValue(); break; + case CONFIG_FIELD: + configMap = getParameterMap(parser.map()); + break; default: parser.skipChildren(); break; @@ -131,6 +163,7 @@ public static MLToolSpec parse(XContentParser parser) throws IOException { .description(description) .parameters(parameters) .includeOutputInAgentResponse(includeOutputInAgentResponse) + .configMap(configMap) .build(); } diff --git a/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java b/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java index b83758fc23..c72da18a30 100644 --- a/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java +++ b/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java @@ -46,7 +46,7 @@ public void constructor_NullName() { MLAgentType.CONVERSATIONAL.name(), "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), - List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), + List.of(new MLToolSpec("test", "test", "test", Collections.emptyMap(), false, Collections.emptyMap())), null, null, Instant.EPOCH, @@ -66,7 +66,7 @@ public void constructor_NullType() { null, "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), - List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), + List.of(new MLToolSpec("test", "test", "test", Collections.emptyMap(), false, Collections.emptyMap())), null, null, Instant.EPOCH, @@ -86,7 +86,7 @@ public void constructor_NullLLMSpec() { MLAgentType.CONVERSATIONAL.name(), "test", null, - List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), + List.of(new MLToolSpec("test", "test", "test", Collections.emptyMap(), false, Collections.emptyMap())), null, null, Instant.EPOCH, @@ -100,7 +100,14 @@ public void constructor_NullLLMSpec() { public void constructor_DuplicateTool() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("Duplicate tool defined: test_tool_name"); - MLToolSpec mlToolSpec = new MLToolSpec("test_tool_type", "test_tool_name", "test", Collections.EMPTY_MAP, false); + MLToolSpec mlToolSpec = new MLToolSpec( + "test_tool_type", + "test_tool_name", + "test", + Collections.emptyMap(), + false, + Collections.emptyMap() + ); MLAgent agent = new MLAgent( "test_name", MLAgentType.CONVERSATIONAL.name(), @@ -123,7 +130,7 @@ public void writeTo() throws IOException { "CONVERSATIONAL", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), - List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), + List.of(new MLToolSpec("test", "test", "test", Collections.emptyMap(), false, Collections.emptyMap())), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, @@ -150,7 +157,7 @@ public void writeTo_NullLLM() throws IOException { "FLOW", "test", null, - List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), + List.of(new MLToolSpec("test", "test", "test", Collections.emptyMap(), false, Collections.emptyMap())), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, @@ -194,7 +201,7 @@ public void writeTo_NullParameters() throws IOException { MLAgentType.CONVERSATIONAL.name(), "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), - List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), + List.of(new MLToolSpec("test", "test", "test", Collections.emptyMap(), false, Collections.emptyMap())), null, new MLMemorySpec("test", "123", 0), Instant.EPOCH, @@ -216,7 +223,7 @@ public void writeTo_NullMemory() throws IOException { "CONVERSATIONAL", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), - List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), + List.of(new MLToolSpec("test", "test", "test", Collections.emptyMap(), false, Collections.emptyMap())), Map.of("test", "test"), null, Instant.EPOCH, @@ -238,7 +245,7 @@ public void toXContent() throws IOException { "CONVERSATIONAL", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), - List.of(new MLToolSpec("test", "test", "test", Map.of("test", "test"), false)), + List.of(new MLToolSpec("test", "test", "test", Map.of("test", "test"), false, Collections.emptyMap())), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, @@ -294,7 +301,7 @@ public void fromStream() throws IOException { MLAgentType.CONVERSATIONAL.name(), "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), - List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), + List.of(new MLToolSpec("test", "test", "test", Collections.emptyMap(), false, Collections.emptyMap())), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, diff --git a/common/src/test/java/org/opensearch/ml/common/agent/MLToolSpecTest.java b/common/src/test/java/org/opensearch/ml/common/agent/MLToolSpecTest.java index 3d4d9a2ce5..ecbf4d0ba1 100644 --- a/common/src/test/java/org/opensearch/ml/common/agent/MLToolSpecTest.java +++ b/common/src/test/java/org/opensearch/ml/common/agent/MLToolSpecTest.java @@ -22,7 +22,14 @@ public class MLToolSpecTest { @Test public void writeTo() throws IOException { - MLToolSpec spec = new MLToolSpec("test", "test", "test", Map.of("test", "test"), false); + MLToolSpec spec = new MLToolSpec( + "test_type", + "test_name", + "test_desc", + Map.of("test_key", "test_value"), + false, + Map.of("configKey", "configValue") + ); BytesStreamOutput output = new BytesStreamOutput(); spec.writeTo(output); MLToolSpec spec1 = new MLToolSpec(output.bytes().streamInput()); @@ -32,18 +39,98 @@ public void writeTo() throws IOException { Assert.assertEquals(spec.getParameters(), spec1.getParameters()); Assert.assertEquals(spec.getDescription(), spec1.getDescription()); Assert.assertEquals(spec.isIncludeOutputInAgentResponse(), spec1.isIncludeOutputInAgentResponse()); + Assert.assertEquals(spec.getConfigMap(), spec1.getConfigMap()); + } + + @Test + public void writeToEmptyConfigMap() throws IOException { + MLToolSpec spec = new MLToolSpec( + "test_type", + "test_name", + "test_desc", + Map.of("test_key", "test_value"), + false, + Collections.emptyMap() + ); + BytesStreamOutput output = new BytesStreamOutput(); + spec.writeTo(output); + MLToolSpec spec1 = new MLToolSpec(output.bytes().streamInput()); + + Assert.assertEquals(spec.getType(), spec1.getType()); + Assert.assertEquals(spec.getName(), spec1.getName()); + Assert.assertEquals(spec.getParameters(), spec1.getParameters()); + Assert.assertEquals(spec.getDescription(), spec1.getDescription()); + Assert.assertEquals(spec.isIncludeOutputInAgentResponse(), spec1.isIncludeOutputInAgentResponse()); + Assert.assertEquals(spec.getConfigMap(), spec1.getConfigMap()); + } + + @Test + public void writeToNullConfigMap() throws IOException { + MLToolSpec spec = new MLToolSpec("test_type", "test_name", "test_desc", Map.of("test_key", "test_value"), false, null); + BytesStreamOutput output = new BytesStreamOutput(); + spec.writeTo(output); + MLToolSpec spec1 = new MLToolSpec(output.bytes().streamInput()); + + Assert.assertEquals(spec.getType(), spec1.getType()); + Assert.assertEquals(spec.getName(), spec1.getName()); + Assert.assertEquals(spec.getParameters(), spec1.getParameters()); + Assert.assertEquals(spec.getDescription(), spec1.getDescription()); + Assert.assertEquals(spec.isIncludeOutputInAgentResponse(), spec1.isIncludeOutputInAgentResponse()); + Assert.assertNull(spec1.getConfigMap()); } @Test public void toXContent() throws IOException { - MLToolSpec spec = new MLToolSpec("test", "test", "test", Map.of("test", "test"), false); + MLToolSpec spec = new MLToolSpec( + "test_type", + "test_name", + "test_desc", + Map.of("test_key", "test_value"), + false, + Map.of("configKey", "configValue") + ); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); spec.toXContent(builder, ToXContent.EMPTY_PARAMS); String content = TestHelper.xContentBuilderToString(builder); Assert .assertEquals( - "{\"type\":\"test\",\"name\":\"test\",\"description\":\"test\",\"parameters\":{\"test\":\"test\"},\"include_output_in_agent_response\":false}", + "{\"type\":\"test_type\",\"name\":\"test_name\",\"description\":\"test_desc\",\"parameters\":{\"test_key\":\"test_value\"},\"include_output_in_agent_response\":false,\"config\":{\"configKey\":\"configValue\"}}", + content + ); + } + + @Test + public void toXContentEmptyConfigMap() throws IOException { + MLToolSpec spec = new MLToolSpec( + "test_type", + "test_name", + "test_desc", + Map.of("test_key", "test_value"), + false, + Collections.emptyMap() + ); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + spec.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = TestHelper.xContentBuilderToString(builder); + + Assert + .assertEquals( + "{\"type\":\"test_type\",\"name\":\"test_name\",\"description\":\"test_desc\",\"parameters\":{\"test_key\":\"test_value\"},\"include_output_in_agent_response\":false}", + content + ); + } + + @Test + public void toXContentNullConfigMap() throws IOException { + MLToolSpec spec = new MLToolSpec("test_type", "test_name", "test_desc", Map.of("test_key", "test_value"), false, null); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + spec.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = TestHelper.xContentBuilderToString(builder); + + Assert + .assertEquals( + "{\"type\":\"test_type\",\"name\":\"test_name\",\"description\":\"test_desc\",\"parameters\":{\"test_key\":\"test_value\"},\"include_output_in_agent_response\":false}", content ); } @@ -51,7 +138,7 @@ public void toXContent() throws IOException { @Test public void parse() throws IOException { String jsonStr = - "{\"type\":\"test\",\"name\":\"test\",\"description\":\"test\",\"parameters\":{\"test\":\"test\"},\"include_output_in_agent_response\":false}"; + "{\"type\":\"test_type\",\"name\":\"test_name\",\"description\":\"test_desc\",\"parameters\":{\"test_key\":\"test_value\"},\"include_output_in_agent_response\":false,\"config\":{\"configKey\":\"configValue\"}}"; XContentParser parser = XContentType.JSON .xContent() .createParser( @@ -62,16 +149,83 @@ public void parse() throws IOException { parser.nextToken(); MLToolSpec spec = MLToolSpec.parse(parser); - Assert.assertEquals(spec.getType(), "test"); - Assert.assertEquals(spec.getName(), "test"); - Assert.assertEquals(spec.getDescription(), "test"); - Assert.assertEquals(spec.getParameters(), Map.of("test", "test")); + Assert.assertEquals(spec.getType(), "test_type"); + Assert.assertEquals(spec.getName(), "test_name"); + Assert.assertEquals(spec.getDescription(), "test_desc"); + Assert.assertEquals(spec.getParameters(), Map.of("test_key", "test_value")); Assert.assertEquals(spec.isIncludeOutputInAgentResponse(), false); + Assert.assertEquals(spec.getConfigMap(), Map.of("configKey", "configValue")); + } + + @Test + public void parseEmptyConfigMap() throws IOException { + String jsonStr = + "{\"type\":\"test_type\",\"name\":\"test_name\",\"description\":\"test_desc\",\"parameters\":{\"test_key\":\"test_value\"},\"include_output_in_agent_response\":false}"; + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); + parser.nextToken(); + MLToolSpec spec = MLToolSpec.parse(parser); + + Assert.assertEquals(spec.getType(), "test_type"); + Assert.assertEquals(spec.getName(), "test_name"); + Assert.assertEquals(spec.getDescription(), "test_desc"); + Assert.assertEquals(spec.getParameters(), Map.of("test_key", "test_value")); + Assert.assertEquals(spec.isIncludeOutputInAgentResponse(), false); + Assert.assertEquals(spec.getConfigMap(), null); } @Test public void fromStream() throws IOException { - MLToolSpec spec = new MLToolSpec("test", "test", "test", Map.of("test", "test"), false); + MLToolSpec spec = new MLToolSpec( + "test_type", + "test_name", + "test_desc", + Map.of("test_key", "test_value"), + false, + Map.of("configKey", "configValue") + ); + BytesStreamOutput output = new BytesStreamOutput(); + spec.writeTo(output); + MLToolSpec spec1 = MLToolSpec.fromStream(output.bytes().streamInput()); + + Assert.assertEquals(spec.getType(), spec1.getType()); + Assert.assertEquals(spec.getName(), spec1.getName()); + Assert.assertEquals(spec.getParameters(), spec1.getParameters()); + Assert.assertEquals(spec.getDescription(), spec1.getDescription()); + Assert.assertEquals(spec.isIncludeOutputInAgentResponse(), spec1.isIncludeOutputInAgentResponse()); + Assert.assertEquals(spec.getConfigMap(), spec1.getConfigMap()); + } + + @Test + public void fromStreamEmptyConfigMap() throws IOException { + MLToolSpec spec = new MLToolSpec( + "test_type", + "test_name", + "test_desc", + Map.of("test_key", "test_value"), + false, + Collections.emptyMap() + ); + BytesStreamOutput output = new BytesStreamOutput(); + spec.writeTo(output); + MLToolSpec spec1 = MLToolSpec.fromStream(output.bytes().streamInput()); + + Assert.assertEquals(spec.getType(), spec1.getType()); + Assert.assertEquals(spec.getName(), spec1.getName()); + Assert.assertEquals(spec.getParameters(), spec1.getParameters()); + Assert.assertEquals(spec.getDescription(), spec1.getDescription()); + Assert.assertEquals(spec.isIncludeOutputInAgentResponse(), spec1.isIncludeOutputInAgentResponse()); + Assert.assertEquals(spec.getConfigMap(), spec1.getConfigMap()); + } + + @Test + public void fromStreamNullConfigMap() throws IOException { + MLToolSpec spec = new MLToolSpec("test_type", "test_name", "test_desc", Map.of("test_key", "test_value"), false, null); BytesStreamOutput output = new BytesStreamOutput(); spec.writeTo(output); MLToolSpec spec1 = MLToolSpec.fromStream(output.bytes().streamInput()); @@ -81,5 +235,6 @@ public void fromStream() throws IOException { Assert.assertEquals(spec.getParameters(), spec1.getParameters()); Assert.assertEquals(spec.getDescription(), spec1.getDescription()); Assert.assertEquals(spec.isIncludeOutputInAgentResponse(), spec1.isIncludeOutputInAgentResponse()); + Assert.assertEquals(spec.getConfigMap(), spec1.getConfigMap()); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java index cad3794134..50acb7f927 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java @@ -76,7 +76,7 @@ public void writeTo() throws IOException { MLAgentType.CONVERSATIONAL.name(), "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), - List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), + List.of(new MLToolSpec("test", "test", "test", Collections.emptyMap(), false, Collections.emptyMap())), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java index f424b3f624..d8f8d6da94 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java @@ -465,12 +465,25 @@ public static Map constructToolParams( ) { Map toolParams = new HashMap<>(); Map toolSpecParams = toolSpecMap.get(action).getParameters(); + Map toolSpecConfigMap = toolSpecMap.get(action).getConfigMap(); if (toolSpecParams != null) { toolParams.putAll(toolSpecParams); } + if (toolSpecConfigMap != null) { + toolParams.putAll(toolSpecConfigMap); + } if (tools.get(action).useOriginalInput()) { toolParams.put("input", question); lastActionInput.set(question); + } else if (toolSpecConfigMap != null && toolSpecConfigMap.containsKey("input")) { + String input = toolSpecConfigMap.get("input"); + StringSubstitutor substitutor = new StringSubstitutor(toolParams, "${parameters.", "}"); + input = substitutor.replace(input); + toolParams.put("input", input); + if (isJson(input)) { + Map params = getParameterMap(gson.fromJson(input, Map.class)); + toolParams.putAll(params); + } } else { toolParams.put("input", actionInput); if (isJson(actionInput)) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java index 672890c030..3891caf8e7 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java @@ -428,12 +428,18 @@ Map getToolExecuteParams(MLToolSpec toolSpec, Map getToolExecuteParams(MLToolSpec toolSpec, Map getToolExecuteParams(MLToolSpec toolSpec, Map params = createAgentParamsWithAction(FIRST_TOOL, "someInput"); + params.put("question", "raw input"); + doReturn(false).when(firstTool).useOriginalInput(); + + // Run the MLChatAgentRunner. + mlChatAgentRunner.run(mlAgent, params, agentActionListener); + + // Verify that the tool's run method was called. + verify(firstTool).run(any(), any()); + // Verify the size of parameters passed in the tool run method. + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); + verify(firstTool).run((Map) argumentCaptor.capture(), any()); + assertEquals(15, ((Map) argumentCaptor.getValue()).size()); + // The value of input should be "config_value". + assertEquals("config_value", ((Map) argumentCaptor.getValue()).get("input")); + + Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue(); + assertNotNull(modelTensorOutput); + } + + @Test + public void testToolConfigWithInputPlaceholder() { + // Mock tool validation to return false. + when(firstTool.validate(any())).thenReturn(true); + + // Create an MLAgent with a tool including two parameters. + MLAgent mlAgent = createMLAgentWithToolsConfig(ImmutableMap.of("input", "${parameters.key2}")); + + // Create parameters for the agent. + Map params = createAgentParamsWithAction(FIRST_TOOL, "someInput"); + params.put("question", "raw input"); + doReturn(false).when(firstTool).useOriginalInput(); + + // Run the MLChatAgentRunner. + mlChatAgentRunner.run(mlAgent, params, agentActionListener); + + // Verify that the tool's run method was called. + verify(firstTool).run(any(), any()); + // Verify the size of parameters passed in the tool run method. + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); + verify(firstTool).run((Map) argumentCaptor.capture(), any()); + assertEquals(15, ((Map) argumentCaptor.getValue()).size()); + // The value of input should be replaced with the value associated with the key "key2" of the first tool. + assertEquals("value2", ((Map) argumentCaptor.getValue()).get("input")); + + Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue(); + assertNotNull(modelTensorOutput); + } + @Test public void testSaveLastTraceFailure() { // Mock tool validation to return true. @@ -838,6 +898,25 @@ private MLAgent createMLAgentWithTools() { .build(); } + private MLAgent createMLAgentWithToolsConfig(Map configMap) { + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); + MLToolSpec firstToolSpec = MLToolSpec + .builder() + .name(FIRST_TOOL) + .type(FIRST_TOOL) + .parameters(ImmutableMap.of("key1", "value1", "key2", "value2")) + .configMap(configMap) + .build(); + return MLAgent + .builder() + .name("TestAgent") + .type(MLAgentType.CONVERSATIONAL.name()) + .tools(Arrays.asList(firstToolSpec)) + .memory(mlMemorySpec) + .llm(llmSpec) + .build(); + } + private Map createAgentParamsWithAction(String action, String actionInput) { Map params = new HashMap<>(); params.put("action", action); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunnerTest.java index 609609438a..b0225abc49 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunnerTest.java @@ -300,6 +300,26 @@ public void testGetToolExecuteParams() { assertFalse(result.containsKey("toolType.param2")); } + @Test + public void testGetToolExecuteParamsWithConfig() { + MLToolSpec toolSpec = mock(MLToolSpec.class); + when(toolSpec.getParameters()).thenReturn(Map.of("param1", "value1", "tool_key", "value_from_parameters")); + when(toolSpec.getConfigMap()).thenReturn(Map.of("tool_key", "tool_config_value")); + when(toolSpec.getType()).thenReturn("toolType"); + when(toolSpec.getName()).thenReturn("toolName"); + + Map params = Map + .of("toolType.param2", "value2", "toolName.param3", "value3", "param4", "value4", "toolName.tool_key", "dynamic value"); + + Map result = mlFlowAgentRunner.getToolExecuteParams(toolSpec, params); + + assertEquals("value1", result.get("param1")); + assertEquals("value3", result.get("param3")); + assertEquals("value4", result.get("param4")); + assertFalse(result.containsKey("toolType.param2")); + assertEquals("tool_config_value", result.get("tool_key")); + } + @Test public void testGetToolExecuteParamsWithInputSubstitution() { // Setup ToolSpec with parameters diff --git a/plugin/src/test/java/org/opensearch/ml/action/agents/GetAgentTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/agents/GetAgentTransportActionTests.java index 8a0ab62168..d5e2d40e50 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/agents/GetAgentTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/agents/GetAgentTransportActionTests.java @@ -282,7 +282,7 @@ public GetResponse prepareMLAgent(String agentId, boolean isHidden) throws IOExc MLAgentType.CONVERSATIONAL.name(), "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), - List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), + List.of(new MLToolSpec("test", "test", "test", Collections.emptyMap(), false, Collections.emptyMap())), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH,