diff --git a/src/main/java/org/opensearch/agent/tools/CreateAlertTool.java b/src/main/java/org/opensearch/agent/tools/CreateAlertTool.java index 018005a3..f7bc3311 100644 --- a/src/main/java/org/opensearch/agent/tools/CreateAlertTool.java +++ b/src/main/java/org/opensearch/agent/tools/CreateAlertTool.java @@ -11,6 +11,7 @@ import java.util.Arrays; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Map.Entry; import java.util.Optional; @@ -62,29 +63,55 @@ public class CreateAlertTool implements Tool { private String description = DEFAULT_DESCRIPTION; private final Client client; + @Getter private final String modelId; - private final String TOOL_PROMPT_TEMPLATE; + @Getter + private final String modelType; + @Getter + private final String toolPrompt; private static final String MODEL_ID = "model_id"; private static final String PROMPT_FILE_PATH = "CreateAlertDefaultPrompt.json"; private static final String DEFAULT_QUESTION = "Create an alert as your recommendation based on the context"; private static final Map promptDict = ToolHelper.loadDefaultPromptDictFromFile(CreateAlertTool.class, PROMPT_FILE_PATH); - public CreateAlertTool(Client client, String modelId, String modelType) { + public enum ModelType { + CLAUDE, + OPENAI; + + public static ModelType from(String value) { + if (value.isEmpty()) { + return ModelType.CLAUDE; + } + try { + return ModelType.valueOf(value.toUpperCase(Locale.ROOT)); + } catch (Exception e) { + log.error("Wrong Model type, should be CLAUDE or OPENAI"); + return ModelType.CLAUDE; + } + } + } + + public CreateAlertTool(Client client, String modelId, String modelType, String prompt) { this.client = client; this.modelId = modelId; - if (!promptDict.containsKey(modelType)) { - throw new IllegalArgumentException( - LoggerMessageFormat - .format( - null, - "Failed to find the right prompt for modelType: {}, this tool supports prompts for these models: [{}]", - modelType, - String.join(",", promptDict.keySet()) - ) - ); + this.modelType = String.valueOf(ModelType.from(modelType)); + if (prompt.isEmpty()) { + if (!promptDict.containsKey(this.modelType)) { + throw new IllegalArgumentException( + LoggerMessageFormat + .format( + null, + "Failed to find the right prompt for modelType: {}, this tool supports prompts for these models: [{}]", + modelType, + String.join(",", promptDict.keySet()) + ) + ); + } + this.toolPrompt = promptDict.get(this.modelType); + } else { + this.toolPrompt = prompt; } - TOOL_PROMPT_TEMPLATE = promptDict.get(modelType); } @Override @@ -205,7 +232,7 @@ private ActionRequest constructMLPredictRequest(Map tmpParams, S tmpParams.putIfAbsent("chat_history", ""); tmpParams.putIfAbsent("question", DEFAULT_QUESTION); // In case no question is provided, use a default question. StringSubstitutor substitute = new StringSubstitutor(tmpParams, "${parameters.", "}"); - String finalToolPrompt = substitute.replace(TOOL_PROMPT_TEMPLATE); + String finalToolPrompt = substitute.replace(toolPrompt); tmpParams.put("prompt", finalToolPrompt); RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(tmpParams).build(); @@ -279,7 +306,8 @@ public CreateAlertTool create(Map params) { throw new IllegalArgumentException("model_id cannot be null or blank."); } String modelType = (String) params.getOrDefault("model_type", ModelType.CLAUDE.toString()); - return new CreateAlertTool(client, modelId, modelType); + String prompt = (String) params.getOrDefault("prompt", ""); + return new CreateAlertTool(client, modelId, modelType, prompt); } @Override diff --git a/src/test/java/org/opensearch/agent/tools/CreateAlertToolTests.java b/src/test/java/org/opensearch/agent/tools/CreateAlertToolTests.java index f0b6a245..6df57965 100644 --- a/src/test/java/org/opensearch/agent/tools/CreateAlertToolTests.java +++ b/src/test/java/org/opensearch/agent/tools/CreateAlertToolTests.java @@ -159,16 +159,32 @@ public void testTool_WithBlankModelId() { @Test public void testTool_WithNonSupportedModelType() { - Exception exception = assertThrows( - IllegalArgumentException.class, - () -> CreateAlertTool.Factory - .getInstance() - .create(ImmutableMap.of("model_id", "modelId", "model_type", "non_supported_modelType")) - ); - assertEquals( - "Failed to find the right prompt for modelType: non_supported_modelType, this tool supports prompts for these models: [CLAUDE,OPENAI]", - exception.getMessage() - ); + CreateAlertTool alertTool = CreateAlertTool.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId", "model_type", "non_supported_modelType")); + assertEquals("CLAUDE", alertTool.getModelType()); + } + + @Test + public void testTool_WithEmptyModelType() { + CreateAlertTool alertTool = CreateAlertTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "model_type", "")); + assertEquals("CLAUDE", alertTool.getModelType()); + } + + @Test + public void testToolWithCustomPrompt() { + CreateAlertTool tool = CreateAlertTool.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId", "prompt", "custom prompt")); + assertEquals(CreateAlertTool.TYPE, tool.getName()); + assertEquals("modelId", tool.getModelId()); + assertEquals("custom prompt", tool.getToolPrompt()); + + tool + .run( + ImmutableMap.of("indices", mockedIndexName), + ActionListener.wrap(response -> assertEquals(jsonResponse, response), log::info) + ); } @Test