diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 6d03601d08..ae629e80ab 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -486,7 +486,12 @@ private void runReAct( if (toolSpecParams != null) { toolParams.putAll(toolSpecParams); } - toolParams.put("input", actionInput); + if (tools.get(action).useOriginalInput()) { + toolParams.put("input", question); + lastActionInput.set(question); + } else { + toolParams.put("input", actionInput); + } if (tools.get(action).validate(toolParams)) { try { String finalAction = action; diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index 5eaa10306a..0518846a36 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -12,6 +12,7 @@ import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -114,6 +115,8 @@ public class MLChatAgentRunnerTest { private ArgumentCaptor> conversationIndexMemoryCapture; @Captor private ArgumentCaptor> mlMemoryManagerCapture; + @Captor + private ArgumentCaptor> ToolParamsCapture; @Before @SuppressWarnings("unchecked") @@ -690,6 +693,35 @@ public void testToolParameters() { assertNotNull(modelTensorOutput); } + @Test + public void testToolUseOriginalInput() { + // Mock tool validation to return false. + when(firstTool.validate(any())).thenReturn(true); + + // Create an MLAgent with a tool including two parameters. + MLAgent mlAgent = createMLAgentWithTools(); + + // Create parameters for the agent. + Map params = createAgentParamsWithAction(FIRST_TOOL, "someInput"); + params.put("question", "raw input"); + doReturn(true).when(firstTool).useOriginalInput(); + + // Run the MLChatAgentRunner. + mlChatAgentRunner.run(mlAgent, params, agentActionListener); + + // Verify that the tool's run method was called. + verify(firstTool).run(any(), any()); + // Verify the size of parameters passed in the tool run method. + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); + verify(firstTool).run((Map) argumentCaptor.capture(), any()); + assertEquals(3, ((Map) argumentCaptor.getValue()).size()); + assertEquals("raw input", ((Map) argumentCaptor.getValue()).get("input")); + + Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue(); + assertNotNull(modelTensorOutput); + } + @Test public void testSaveLastTraceFailure() { // Mock tool validation to return true.