Skip to content

Commit

Permalink
Abstract tool support to create ML agent tools
Browse files Browse the repository at this point in the history
Signed-off-by: Arjun kumar Giri <arjung@amazon.com>
  • Loading branch information
arjunkumargiri committed Nov 14, 2023
1 parent 2b04af5 commit 27af0c7
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 163 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package org.opensearch.ml.engine.tools;

import lombok.Getter;
import lombok.Setter;
import org.opensearch.ml.common.spi.tools.Parser;
import org.opensearch.ml.common.spi.tools.Tool;

import java.util.Map;

public abstract class AbstractTool implements Tool {

/**
* Name of the tool to be used in prompt.
*/
@Setter
@Getter
private String name;

/**
* Default description of the tool. This description will be used by LLM to select next tool to execute.
*/
@Getter
@Setter
private String description;

/**
* Tool type mapping to the corresponding run function. Tool type will be used by agent framework to identify the tool.
*/
@Getter
private String type;

/**
* Current tool version.
*/
@Getter
protected String version;

/**
* Parser used to read tool input.
*/
@Setter
protected Parser inputParser;

/**
* Parser used to write tool output.
*/
@Setter
protected Parser outputParser;

/**
* Default tool constructor.
*
* @param type
* @param name
* @param description
*/
protected AbstractTool(String type, String name, String description) {
this.type = type;
this.name = name;
this.description = description;
}

protected AbstractTool(String type, String description) {
this(type, type, description);
}

/**
* Validate tool input and check if request could be processed by the tool.
*
* @param parameters
* @return
*/
@Override
public abstract boolean validate(Map<String, String> parameters);

}

Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

package org.opensearch.ml.engine.tools;

import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;
import org.opensearch.action.ActionRequest;
import org.opensearch.client.Client;
Expand All @@ -27,19 +25,15 @@
*/
@Log4j2
@ToolAnnotation(AgentTool.TYPE)
public class AgentTool implements Tool {
public class AgentTool extends AbstractTool {
public static final String TYPE = "AgentTool";
private static String DEFAULT_DESCRIPTION = "Use this tool to run any agent.";
private final Client client;

private String agentId;
@Setter @Getter
private String name = TYPE;

private static String DEFAULT_DESCRIPTION = "Use this tool to run any agent.";
@Getter @Setter
private String description = DEFAULT_DESCRIPTION;

public AgentTool(Client client, String agentId) {
super(TYPE, DEFAULT_DESCRIPTION);
this.client = client;
this.agentId = agentId;
}
Expand All @@ -58,26 +52,6 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)

}

@Override
public String getType() {
return TYPE;
}

@Override
public String getVersion() {
return null;
}

@Override
public String getName() {
return this.name;
}

@Override
public void setName(String s) {
this.name = s;
}

@Override
public boolean validate(Map<String, String> parameters) {
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

package org.opensearch.ml.engine.tools;

import lombok.Getter;
import lombok.Setter;
import org.apache.logging.log4j.util.Strings;
import org.opensearch.action.admin.cluster.health.ClusterHealthRequest;
import org.opensearch.action.admin.cluster.health.ClusterHealthResponse;
Expand Down Expand Up @@ -49,28 +47,16 @@
import static org.opensearch.ml.common.utils.StringUtils.gson;

@ToolAnnotation(CatIndexTool.TYPE)
public class CatIndexTool implements Tool {
public class CatIndexTool extends AbstractTool {
public static final String TYPE = "CatIndexTool";
private static final String DEFAULT_DESCRIPTION = "Use this tool to get index information.";

@Setter
@Getter
private String name = CatIndexTool.TYPE;
@Getter
@Setter
private String description = DEFAULT_DESCRIPTION;
@Getter
private String version;

private Client client;
@Setter
private Parser<?, ?> inputParser;
@Setter
private Parser<?, ?> outputParser;
@SuppressWarnings("unused")
private ClusterService clusterService;

public CatIndexTool(Client client, ClusterService clusterService) {
super(TYPE, DEFAULT_DESCRIPTION);
this.client = client;
this.clusterService = clusterService;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

package org.opensearch.ml.engine.tools;

import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;
import org.opensearch.action.ActionRequest;
import org.opensearch.client.Client;
Expand All @@ -31,22 +29,14 @@
*/
@Log4j2
@ToolAnnotation(MLModelTool.TYPE)
public class MLModelTool implements Tool {
public class MLModelTool extends AbstractTool {
public static final String TYPE = "MLModelTool";

@Setter @Getter
private String name = TYPE;
private static String DEFAULT_DESCRIPTION = "Use this tool to run any model.";
@Getter @Setter
private String description = DEFAULT_DESCRIPTION;
private Client client;
private String modelId;
@Setter
private Parser inputParser;
@Setter
private Parser outputParser;

public MLModelTool(Client client, String modelId) {
super(TYPE, DEFAULT_DESCRIPTION);
this.client = client;
this.modelId = modelId;

Expand Down Expand Up @@ -78,26 +68,6 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
}));
}

@Override
public String getType() {
return TYPE;
}

@Override
public String getVersion() {
return null;
}

@Override
public String getName() {
return this.name;
}

@Override
public void setName(String s) {
this.name = s;
}

@Override
public boolean validate(Map<String, String> parameters) {
if (parameters == null || parameters.size() == 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

package org.opensearch.ml.engine.tools;

import lombok.Getter;
import lombok.Setter;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.spi.tools.Tool;
Expand All @@ -20,20 +19,15 @@
import static org.opensearch.ml.engine.utils.ScriptUtils.executeScript;

@ToolAnnotation(MathTool.TYPE)
public class MathTool implements Tool {
public class MathTool extends AbstractTool {
public static final String TYPE = "MathTool";

@Setter @Getter
private String name = TYPE;
private static String DEFAULT_DESCRIPTION = "Use this tool to calculate any math problem.";

@Setter
private ScriptService scriptService;

private static String DEFAULT_DESCRIPTION = "Use this tool to calculate any math problem.";
@Getter @Setter
private String description = DEFAULT_DESCRIPTION;

public MathTool(ScriptService scriptService) {
super(TYPE, DEFAULT_DESCRIPTION);
this.scriptService = scriptService;
}

Expand All @@ -56,26 +50,6 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
listener.onResponse((T)result);
}

@Override
public String getType() {
return TYPE;
}

@Override
public String getVersion() {
return null;
}

@Override
public String getName() {
return this.name;
}

@Override
public void setName(String s) {
this.name = s;
}

@Override
public boolean validate(Map<String, String> parameters) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,15 @@

@Log4j2
@ToolAnnotation(PainlessScriptTool.TYPE)
public class PainlessScriptTool implements Tool {
public class PainlessScriptTool extends AbstractTool {
public static final String TYPE = "PainlessScriptTool";

@Setter @Getter
private String name = TYPE;
private static String DEFAULT_DESCRIPTION = "Use this tool to get index information.";
@Getter @Setter
private String description = DEFAULT_DESCRIPTION;

private Client client;
private String modelId;
@Setter
private Parser inputParser;
@Setter
private Parser outputParser;
private ScriptService scriptService;

public PainlessScriptTool(Client client, ScriptService scriptService) {
super(TYPE, DEFAULT_DESCRIPTION);
this.client = client;
this.scriptService = scriptService;

Expand All @@ -64,26 +56,6 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
listener.onResponse((T)s);
}

@Override
public String getType() {
return TYPE;
}

@Override
public String getVersion() {
return null;
}

@Override
public String getName() {
return this.name;
}

@Override
public void setName(String s) {
this.name = s;
}

@Override
public boolean validate(Map<String, String> parameters) {
if (parameters == null || parameters.size() == 0) {
Expand Down
Loading

0 comments on commit 27af0c7

Please sign in to comment.