diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AbstractRetrieverTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AbstractRetrieverTool.java index 587dfeb7f9..afc5d2eff6 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AbstractRetrieverTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AbstractRetrieverTool.java @@ -22,6 +22,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.spi.tools.AbstractTool; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; @@ -36,7 +37,7 @@ @Log4j2 @Getter @Setter -public abstract class AbstractRetrieverTool implements Tool { +public abstract class AbstractRetrieverTool extends AbstractTool { public static final String DEFAULT_DESCRIPTION = "Use this tool to search data in OpenSearch index."; public static final String INPUT_FIELD = "input"; public static final String INDEX_FIELD = "index"; @@ -51,12 +52,15 @@ public abstract class AbstractRetrieverTool implements Tool { protected Integer docSize; protected AbstractRetrieverTool( + String type, + String description, Client client, NamedXContentRegistry xContentRegistry, String index, String[] sourceFields, Integer docSize ) { + super(type, description); this.client = client; this.xContentRegistry = xContentRegistry; this.index = index; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java index 04d6942773..9159e12b97 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java @@ -14,13 +14,12 @@ import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.execute.agent.AgentMLInput; import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.spi.tools.AbstractTool; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction; import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest; -import lombok.Getter; -import lombok.Setter; import lombok.extern.log4j.Log4j2; /** @@ -28,21 +27,16 @@ */ @Log4j2 @ToolAnnotation(AgentTool.TYPE) -public class AgentTool implements Tool { +public class AgentTool extends AbstractTool { public static final String TYPE = "AgentTool"; private final Client client; private String agentId; - @Setter - @Getter - private String name = TYPE; private static String DEFAULT_DESCRIPTION = "Use this tool to run any agent."; - @Getter - @Setter - private String description = DEFAULT_DESCRIPTION; public AgentTool(Client client, String agentId) { + super(TYPE, DEFAULT_DESCRIPTION); this.client = client; this.agentId = agentId; } @@ -66,26 +60,6 @@ public void run(Map parameters, ActionListener listener) } - @Override - public String getType() { - return TYPE; - } - - @Override - public String getVersion() { - return null; - } - - @Override - public String getName() { - return this.name; - } - - @Override - public void setName(String s) { - this.name = s; - } - @Override public boolean validate(Map parameters) { return true; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java index c74f53ba17..4e3f973e19 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java @@ -43,47 +43,30 @@ import org.opensearch.core.action.ActionResponse; import org.opensearch.index.IndexSettings; import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.spi.tools.AbstractTool; import org.opensearch.ml.common.spi.tools.Parser; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; -import lombok.Getter; -import lombok.Setter; - @ToolAnnotation(CatIndexTool.TYPE) -public class CatIndexTool implements Tool { +public class CatIndexTool extends AbstractTool { public static final String TYPE = "CatIndexTool"; private static final String DEFAULT_DESCRIPTION = "Use this tool to get index information."; - @Setter - @Getter - private String name = CatIndexTool.TYPE; - @Getter - @Setter - private String description = DEFAULT_DESCRIPTION; - @Getter - private String version; - private Client client; - @Setter - private Parser inputParser; - @Setter - private Parser outputParser; @SuppressWarnings("unused") private ClusterService clusterService; public CatIndexTool(Client client, ClusterService clusterService) { + super(TYPE, DEFAULT_DESCRIPTION); this.client = client; this.clusterService = clusterService; - outputParser = new Parser<>() { - @Override - public Object parse(Object o) { - @SuppressWarnings("unchecked") - List mlModelOutputs = (List) o; - return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response"); - } - }; + this.setOutputParser((Parser) parser -> { + @SuppressWarnings("unchecked") + List mlModelOutputs = (List) parser; + return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response"); + }); } @Override @@ -295,16 +278,6 @@ public void onFailure(final Exception e) { }, size); } - @Override - public String getName() { - return this.name; - } - - @Override - public void setName(String s) { - this.name = s; - } - @Override public boolean validate(Map parameters) { if (parameters == null || parameters.size() == 0) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/IndexMappingTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/IndexMappingTool.java index 99a7955cd0..a7045dee42 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/IndexMappingTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/IndexMappingTool.java @@ -23,45 +23,27 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.spi.tools.AbstractTool; import org.opensearch.ml.common.spi.tools.Parser; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; -import lombok.Getter; -import lombok.Setter; - -@ToolAnnotation(IndexMappingTool.NAME) -public class IndexMappingTool implements Tool { - public static final String NAME = "IndexMappingTool"; +@ToolAnnotation(IndexMappingTool.TYPE) +public class IndexMappingTool extends AbstractTool { + public static final String TYPE = "IndexMappingTool"; private static final String DEFAULT_DESCRIPTION = "Use this tool to get index mapping information."; - @Setter - @Getter - private String name = IndexMappingTool.NAME; - @Getter - @Setter - private String description = DEFAULT_DESCRIPTION; - @Getter - private String type; - @Getter - private String version; private Client client; - @Setter - private Parser inputParser; - @Setter - private Parser outputParser; public IndexMappingTool(Client client) { + super(TYPE, DEFAULT_DESCRIPTION); this.client = client; - outputParser = new Parser<>() { - @Override - public Object parse(Object o) { - @SuppressWarnings("unchecked") - List mlModelOutputs = (List) o; - return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response"); - } - }; + this.setOutputParser((Parser) o -> { + @SuppressWarnings("unchecked") + List mlModelOutputs = (List) o; + return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response"); + }); } @Override diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java index 0855741f88..6f1fb0b3e9 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java @@ -16,15 +16,13 @@ import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; -import org.opensearch.ml.common.spi.tools.Parser; +import org.opensearch.ml.common.spi.tools.AbstractTool; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; -import lombok.Getter; -import lombok.Setter; import lombok.extern.log4j.Log4j2; /** @@ -32,34 +30,22 @@ */ @Log4j2 @ToolAnnotation(MLModelTool.TYPE) -public class MLModelTool implements Tool { +public class MLModelTool extends AbstractTool { public static final String TYPE = "MLModelTool"; - @Setter - @Getter - private String name = TYPE; private static String DEFAULT_DESCRIPTION = "Use this tool to run any model."; - @Getter - @Setter - private String description = DEFAULT_DESCRIPTION; private Client client; private String modelId; - @Setter - private Parser inputParser; - @Setter - private Parser outputParser; public MLModelTool(Client client, String modelId) { + super(TYPE, DEFAULT_DESCRIPTION); this.client = client; this.modelId = modelId; - outputParser = new Parser() { - @Override - public Object parse(Object o) { - List mlModelOutputs = (List) o; - return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response"); - } - }; + this.setOutputParser(o -> { + List mlModelOutputs = (List) o; + return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response"); + }); } @Override @@ -72,10 +58,10 @@ public void run(Map parameters, ActionListener listener) client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(r -> { ModelTensorOutput modelTensorOutput = (ModelTensorOutput) r.getOutput(); modelTensorOutput.getMlModelOutputs(); - if (outputParser == null) { + if (this.getOutputParser() == null) { listener.onResponse((T) modelTensorOutput.getMlModelOutputs()); } else { - listener.onResponse((T) outputParser.parse(modelTensorOutput.getMlModelOutputs())); + listener.onResponse((T) this.getOutputParser().parse(modelTensorOutput.getMlModelOutputs())); } }, e -> { log.error("Failed to run model " + modelId, e); @@ -83,26 +69,6 @@ public void run(Map parameters, ActionListener listener) })); } - @Override - public String getType() { - return TYPE; - } - - @Override - public String getVersion() { - return null; - } - - @Override - public String getName() { - return this.name; - } - - @Override - public void setName(String s) { - this.name = s; - } - @Override public boolean validate(Map parameters) { if (parameters == null || parameters.size() == 0) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MathTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MathTool.java index 9262559071..0b395b6f78 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MathTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MathTool.java @@ -12,31 +12,24 @@ import java.util.regex.Pattern; import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.spi.tools.AbstractTool; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; import org.opensearch.script.ScriptService; -import lombok.Getter; import lombok.Setter; @ToolAnnotation(MathTool.TYPE) -public class MathTool implements Tool { +public class MathTool extends AbstractTool { public static final String TYPE = "MathTool"; - - @Setter - @Getter - private String name = TYPE; - @Setter private ScriptService scriptService; private static String DEFAULT_DESCRIPTION = "Use this tool to calculate any math problem."; - @Getter - @Setter - private String description = DEFAULT_DESCRIPTION; public MathTool(ScriptService scriptService) { + super(TYPE, DEFAULT_DESCRIPTION); this.scriptService = scriptService; } @@ -59,26 +52,6 @@ public void run(Map parameters, ActionListener listener) listener.onResponse((T) result); } - @Override - public String getType() { - return TYPE; - } - - @Override - public String getVersion() { - return null; - } - - @Override - public String getName() { - return this.name; - } - - @Override - public void setName(String s) { - this.name = s; - } - @Override public boolean validate(Map parameters) { try { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/NeuralSparseTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/NeuralSparseTool.java index 33a3f45fb1..f9ee4d3f61 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/NeuralSparseTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/NeuralSparseTool.java @@ -30,7 +30,6 @@ public class NeuralSparseTool extends AbstractRetrieverTool { public static final String TYPE = "NeuralSparseTool"; public static final String MODEL_ID_FIELD = "model_id"; public static final String EMBEDDING_FIELD = "embedding_field"; - private String name = TYPE; private String modelId; private String embeddingField; @@ -45,7 +44,7 @@ public NeuralSparseTool( Integer docSize, String modelId ) { - super(client, xContentRegistry, index, sourceFields, docSize); + super(TYPE, DEFAULT_DESCRIPTION, client, xContentRegistry, index, sourceFields, docSize); this.modelId = modelId; this.embeddingField = embeddingField; } @@ -67,21 +66,6 @@ protected String getQueryBody(String queryText) { + " }"; } - @Override - public String getType() { - return TYPE; - } - - @Override - public String getName() { - return this.name; - } - - @Override - public void setName(String s) { - this.name = s; - } - public static class Factory extends AbstractRetrieverTool.Factory { private static Factory INSTANCE; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/PainlessScriptTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/PainlessScriptTool.java index f320c5a0ee..529c28ea28 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/PainlessScriptTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/PainlessScriptTool.java @@ -13,47 +13,31 @@ import org.opensearch.client.Client; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.output.model.ModelTensors; -import org.opensearch.ml.common.spi.tools.Parser; +import org.opensearch.ml.common.spi.tools.AbstractTool; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; import org.opensearch.ml.engine.utils.ScriptUtils; import org.opensearch.script.ScriptService; -import lombok.Getter; -import lombok.Setter; import lombok.extern.log4j.Log4j2; @Log4j2 @ToolAnnotation(PainlessScriptTool.TYPE) -public class PainlessScriptTool implements Tool { +public class PainlessScriptTool extends AbstractTool { public static final String TYPE = "PainlessScriptTool"; - - @Setter - @Getter - private String name = TYPE; private static String DEFAULT_DESCRIPTION = "Use this tool to get index information."; - @Getter - @Setter - private String description = DEFAULT_DESCRIPTION; private Client client; - private String modelId; - @Setter - private Parser inputParser; - @Setter - private Parser outputParser; private ScriptService scriptService; public PainlessScriptTool(Client client, ScriptService scriptService) { + super(TYPE, DEFAULT_DESCRIPTION); this.client = client; this.scriptService = scriptService; - outputParser = new Parser() { - @Override - public Object parse(Object o) { - List mlModelOutputs = (List) o; - return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response"); - } - }; + this.setOutputParser(o -> { + List mlModelOutputs = (List) o; + return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response"); + }); } @Override @@ -64,26 +48,6 @@ public void run(Map parameters, ActionListener listener) listener.onResponse((T) s); } - @Override - public String getType() { - return TYPE; - } - - @Override - public String getVersion() { - return null; - } - - @Override - public String getName() { - return this.name; - } - - @Override - public void setName(String s) { - this.name = s; - } - @Override public boolean validate(Map parameters) { if (parameters == null || parameters.size() == 0) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/SearchAlertsTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/SearchAlertsTool.java index 39006dff5e..bb90a15a49 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/SearchAlertsTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/SearchAlertsTool.java @@ -19,6 +19,7 @@ import org.opensearch.commons.alerting.model.Table; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.spi.tools.AbstractTool; import org.opensearch.ml.common.spi.tools.Parser; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; @@ -27,39 +28,26 @@ import lombok.Setter; @ToolAnnotation(SearchAlertsTool.TYPE) -public class SearchAlertsTool implements Tool { +public class SearchAlertsTool extends AbstractTool { public static final String TYPE = "SearchAlertsTool"; private static final String DEFAULT_DESCRIPTION = "Use this tool to search alerts."; @Setter @Getter private String name = TYPE; - @Getter - @Setter - private String description = DEFAULT_DESCRIPTION; - @Getter - private String type; - @Getter - private String version; private Client client; - @Setter - private Parser inputParser; - @Setter - private Parser outputParser; public SearchAlertsTool(Client client) { + super(TYPE, DEFAULT_DESCRIPTION); this.client = client; // probably keep this overridden output parser. need to ensure the output matches what's expected - outputParser = new Parser<>() { - @Override - public Object parse(Object o) { - @SuppressWarnings("unchecked") - List mlModelOutputs = (List) o; - return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response"); - } - }; + this.setOutputParser((Parser) o -> { + @SuppressWarnings("unchecked") + List mlModelOutputs = (List) o; + return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response"); + }); } @Override @@ -121,11 +109,6 @@ public boolean validate(Map parameters) { return true; } - @Override - public String getType() { - return TYPE; - } - /** * Factory for the {@link SearchAlertsTool} */ diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java index 468bf8f7d1..13b8e3c42c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java @@ -46,7 +46,7 @@ public VectorDBTool( Integer docSize, String modelId ) { - super(client, xContentRegistry, index, sourceFields, docSize); + super(TYPE, DEFAULT_DESCRIPTION, client, xContentRegistry, index, sourceFields, docSize); this.modelId = modelId; this.embeddingField = embeddingField; this.k = k == null ? 10 : k; @@ -71,21 +71,6 @@ protected String getQueryBody(String queryText) { + " }"; } - @Override - public String getType() { - return TYPE; - } - - @Override - public String getName() { - return this.name; - } - - @Override - public void setName(String s) { - this.name = s; - } - public static class Factory extends AbstractRetrieverTool.Factory { private static Factory INSTANCE; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VisualizationsTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VisualizationsTool.java index 91c3237ac8..618a069efe 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VisualizationsTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VisualizationsTool.java @@ -18,6 +18,7 @@ import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilders; +import org.opensearch.ml.common.spi.tools.AbstractTool; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting; @@ -31,7 +32,7 @@ @Log4j2 @ToolAnnotation(VisualizationsTool.TYPE) -public class VisualizationsTool implements Tool { +public class VisualizationsTool extends AbstractTool { public static final String NAME = "Find Visualizations"; public static final String TYPE = "VisualizationTool"; public static final String VERSION = "v1.0"; @@ -39,15 +40,14 @@ public class VisualizationsTool implements Tool { public static final String SAVED_OBJECT_TYPE = "visualization"; private static final String DEFAULT_DESCRIPTION = "Use this tool to find user created visualizations. This tool takes the visualization name as input and returns the first 3 matching visualizations"; - private String description = DEFAULT_DESCRIPTION; - private String name = NAME; private final Client client; @Getter private final String index; @Builder public VisualizationsTool(Client client, String index) { + super(TYPE, DEFAULT_DESCRIPTION); this.client = client; this.index = index; } @@ -103,36 +103,6 @@ String trimIdPrefix(String id) { return id; } - @Override - public String getType() { - return TYPE; - } - - @Override - public String getVersion() { - return VERSION; - } - - @Override - public String getName() { - return name; - } - - @Override - public void setName(String name) { - this.name = name; - } - - @Override - public String getDescription() { - return description; - } - - @Override - public void setDescription(String description) { - this.description = description; - } - @Override public boolean validate(Map parameters) { return parameters.containsKey("input") && !Strings.isNullOrEmpty(parameters.get("input")); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AbstractRetrieverToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AbstractRetrieverToolTests.java index f5251498da..48bf72fac9 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AbstractRetrieverToolTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AbstractRetrieverToolTests.java @@ -60,7 +60,15 @@ public void setup() { AbstractRetrieverTool.class, Mockito .withSettings() - .useConstructor(null, TEST_XCONTENT_REGISTRY_FOR_QUERY, TEST_INDEX, TEST_SOURCE_FIELDS, TEST_DOC_SIZE) + .useConstructor( + "type", + "description", + null, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + TEST_INDEX, + TEST_SOURCE_FIELDS, + TEST_DOC_SIZE + ) .defaultAnswer(Mockito.CALLS_REAL_METHODS) ); when(mockedImpl.getQueryBody(any(String.class))).thenReturn(TEST_QUERY); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/IndexMappingToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/IndexMappingToolTests.java index 9585b0b953..3c1f73fdad 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/IndexMappingToolTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/IndexMappingToolTests.java @@ -174,7 +174,7 @@ public void testRunAsyncIndexMapping() throws Exception { @Test public void testTool() { Tool tool = IndexMappingTool.Factory.getInstance().create(Collections.emptyMap()); - assertEquals(IndexMappingTool.NAME, tool.getName()); + assertEquals(IndexMappingTool.TYPE, tool.getName()); assertTrue(tool.validate(indexParams)); assertTrue(tool.validate(otherParams)); assertFalse(tool.validate(emptyParams)); diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 6556fb639d..2723c05d6e 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -528,7 +528,7 @@ public Collection createComponents( toolFactories.put(PainlessScriptTool.TYPE, PainlessScriptTool.Factory.getInstance()); toolFactories.put(VisualizationsTool.TYPE, VisualizationsTool.Factory.getInstance()); toolFactories.put(SearchAlertsTool.TYPE, SearchAlertsTool.Factory.getInstance()); - toolFactories.put(IndexMappingTool.NAME, IndexMappingTool.Factory.getInstance()); + toolFactories.put(IndexMappingTool.TYPE, IndexMappingTool.Factory.getInstance()); if (externalToolFactories != null) { toolFactories.putAll(externalToolFactories); diff --git a/spi/build.gradle b/spi/build.gradle index 0f3f5a77a4..4c3a8e43f4 100644 --- a/spi/build.gradle +++ b/spi/build.gradle @@ -7,15 +7,15 @@ import com.github.jengelman.gradle.plugins.shadow.ShadowBasePlugin import org.opensearch.gradle.test.RestIntegTestTask plugins { - id 'com.github.johnrengelman.shadow' + id 'java' + id "io.freefair.lombok" id 'jacoco' + id 'com.github.johnrengelman.shadow' id 'maven-publish' + id 'com.diffplug.spotless' version '6.18.0' id 'signing' } -apply plugin: 'opensearch.java' -apply plugin: 'opensearch.testclusters' -apply plugin: 'opensearch.java-rest-test' repositories { mavenLocal() @@ -48,7 +48,8 @@ dependencies { compileOnly "org.opensearch:opensearch:${opensearch_version}" testImplementation "org.opensearch.test:framework:${opensearch_version}" - testImplementation "org.apache.logging.log4j:log4j-core:${versions.log4j}" + testImplementation group: 'junit', name: 'junit', version: '4.13.2' + testImplementation group: 'org.mockito', name: 'mockito-core', version: '4.4.0' } configurations.all { @@ -79,10 +80,6 @@ task integTest(type: RestIntegTestTask) { } check.dependsOn integTest -testClusters.javaRestTest { - testDistribution = 'INTEG_TEST' -} - task sourcesJar(type: Jar) { archiveClassifier.set 'sources' from sourceSets.main.allJava diff --git a/spi/src/main/java/org/opensearch/ml/common/spi/tools/AbstractTool.java b/spi/src/main/java/org/opensearch/ml/common/spi/tools/AbstractTool.java new file mode 100644 index 0000000000..f833b41c3b --- /dev/null +++ b/spi/src/main/java/org/opensearch/ml/common/spi/tools/AbstractTool.java @@ -0,0 +1,118 @@ +package org.opensearch.ml.common.spi.tools; + +import java.util.Map; + +/** + * Abstract tool used to simplify tool creation. Concrete Tool implementation needs to be thread safe. + */ +public abstract class AbstractTool implements Tool { + + /** + * Name of the tool to be used in prompt. + */ + private String name; + + /** + * Default description of the tool. This description will be used by LLM to select next tool to execute. + */ + private String description; + + /** + * Tool type mapping to the corresponding run function. Tool type will be used by agent framework to identify the tool. + */ + private String type; + + /** + * Current tool version. + */ + private String version; + + /** + * Parser used to read tool input. + */ + private Parser inputParser; + + /** + * Parser used to write tool output. + */ + private Parser outputParser; + + /** + * Default tool constructor. + * + * @param type + * @param name + * @param description + */ + protected AbstractTool(String type, String name, String description) { + this.type = type; + this.name = name; + this.description = description; + } + + protected AbstractTool(String type, String description) { + this(type, type, description); + } + + @Override + public void setName(String name) { + this.name = name; + } + + @Override + public String getName() { + return this.name; + } + + @Override + public void setDescription(String description) { + this.description = description; + } + + @Override + public String getDescription() { + return this.description; + } + + @Override + public String getType() { + return this.type; + } + + public void setVersion(String version) { + this.version = version; + } + + @Override + public String getVersion() { + return this.version; + } + + @Override + public void setInputParser(Parser inputParser) { + this.inputParser = inputParser; + } + + public Parser getInputParser() { + return this.inputParser; + } + + @Override + public void setOutputParser(Parser outputParser) { + this.outputParser = outputParser; + } + + public Parser getOutputParser() { + return this.outputParser; + } + + /** + * Validate tool input and check if request could be processed by the tool. + * + * @param parameters + * @return + */ + @Override + public abstract boolean validate(Map parameters); + +} diff --git a/spi/src/test/java/org.opensearch.ml.common.spi.tools/AbstractToolTest.java b/spi/src/test/java/org.opensearch.ml.common.spi.tools/AbstractToolTest.java new file mode 100644 index 0000000000..ee489942cb --- /dev/null +++ b/spi/src/test/java/org.opensearch.ml.common.spi.tools/AbstractToolTest.java @@ -0,0 +1,75 @@ +package org.opensearch.ml.common.spi.tools; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.Mockito; + +public class AbstractToolTest { + + public static final String TOOL_TYPE = "tool_type"; + public static final String TOOL_NAME = "tool_name"; + public static final String TOOL_DESCRIPTION = "tool_description"; + private AbstractTool abstractTool; + + @Mock + private Parser mockInputParser; + + @Mock + private Parser mockOutputParser; + + @Before + public void setup() { + abstractTool = Mockito.mock(AbstractTool.class, + Mockito.withSettings().useConstructor(TOOL_TYPE, TOOL_DESCRIPTION).defaultAnswer(Mockito.CALLS_REAL_METHODS)); + } + + @Test + public void testConstructorValueIsPersisted() { + Assert.assertEquals(TOOL_TYPE, abstractTool.getType()); + Assert.assertEquals(TOOL_TYPE, abstractTool.getName()); + Assert.assertEquals(TOOL_DESCRIPTION, abstractTool.getDescription()); + } + + @Test + public void testGetterSetterName() { + abstractTool.setName("test_name"); + Assert.assertEquals("test_name", abstractTool.getName()); + } + + @Test + public void testGetterSetterDescription() { + abstractTool.setDescription("test_description"); + Assert.assertEquals("test_description", abstractTool.getDescription()); + } + + + @Test + public void testGetterSetterVersion() { + abstractTool.setVersion("test_version"); + Assert.assertEquals("test_version", abstractTool.getVersion()); + } + + @Test + public void testGetterSetterInputParser() { + abstractTool.setInputParser(mockInputParser); + Assert.assertEquals(mockInputParser, abstractTool.getInputParser()); + } + + @Test + public void testGetterSetterOutputParser() { + abstractTool.setOutputParser(mockOutputParser); + Assert.assertEquals(mockOutputParser, abstractTool.getOutputParser()); + } + + @Test + public void testConstructor() { + abstractTool = Mockito.mock(AbstractTool.class, + Mockito.withSettings().useConstructor(TOOL_TYPE, TOOL_NAME, TOOL_DESCRIPTION).defaultAnswer(Mockito.CALLS_REAL_METHODS)); + Assert.assertEquals(TOOL_TYPE, abstractTool.getType()); + Assert.assertEquals(TOOL_NAME, abstractTool.getName()); + Assert.assertEquals(TOOL_DESCRIPTION, abstractTool.getDescription()); + } + +} diff --git a/spi/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker b/spi/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker new file mode 100644 index 0000000000..ca6ee9cea8 --- /dev/null +++ b/spi/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker @@ -0,0 +1 @@ +mock-maker-inline \ No newline at end of file