Skip to content

Commit

Permalink
Add abstract tool support
Browse files Browse the repository at this point in the history
Signed-off-by: Arjun kumar Giri <arjung@amazon.com>
  • Loading branch information
arjunkumargiri committed Dec 13, 2023
1 parent 44118ef commit dcd4adf
Show file tree
Hide file tree
Showing 18 changed files with 269 additions and 312 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.spi.tools.AbstractTool;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
Expand All @@ -36,7 +37,7 @@
@Log4j2
@Getter
@Setter
public abstract class AbstractRetrieverTool implements Tool {
public abstract class AbstractRetrieverTool extends AbstractTool {
public static final String DEFAULT_DESCRIPTION = "Use this tool to search data in OpenSearch index.";
public static final String INPUT_FIELD = "input";
public static final String INDEX_FIELD = "index";
Expand All @@ -51,12 +52,15 @@ public abstract class AbstractRetrieverTool implements Tool {
protected Integer docSize;

protected AbstractRetrieverTool(
String type,
String description,
Client client,
NamedXContentRegistry xContentRegistry,
String index,
String[] sourceFields,
Integer docSize
) {
super(type, description);
this.client = client;
this.xContentRegistry = xContentRegistry;
this.index = index;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,35 +14,29 @@
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.spi.tools.AbstractTool;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest;

import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;

/**
* This tool supports running any Agent.
*/
@Log4j2
@ToolAnnotation(AgentTool.TYPE)
public class AgentTool implements Tool {
public class AgentTool extends AbstractTool {
public static final String TYPE = "AgentTool";
private final Client client;

private String agentId;
@Setter
@Getter
private String name = TYPE;

private static String DEFAULT_DESCRIPTION = "Use this tool to run any agent.";
@Getter
@Setter
private String description = DEFAULT_DESCRIPTION;

public AgentTool(Client client, String agentId) {
super(TYPE, DEFAULT_DESCRIPTION);
this.client = client;
this.agentId = agentId;
}
Expand All @@ -66,26 +60,6 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)

}

@Override
public String getType() {
return TYPE;
}

@Override
public String getVersion() {
return null;
}

@Override
public String getName() {
return this.name;
}

@Override
public void setName(String s) {
this.name = s;
}

@Override
public boolean validate(Map<String, String> parameters) {
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,47 +43,30 @@
import org.opensearch.core.action.ActionResponse;
import org.opensearch.index.IndexSettings;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.tools.AbstractTool;
import org.opensearch.ml.common.spi.tools.Parser;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;

import lombok.Getter;
import lombok.Setter;

@ToolAnnotation(CatIndexTool.TYPE)
public class CatIndexTool implements Tool {
public class CatIndexTool extends AbstractTool {
public static final String TYPE = "CatIndexTool";
private static final String DEFAULT_DESCRIPTION = "Use this tool to get index information.";

@Setter
@Getter
private String name = CatIndexTool.TYPE;
@Getter
@Setter
private String description = DEFAULT_DESCRIPTION;
@Getter
private String version;

private Client client;
@Setter
private Parser<?, ?> inputParser;
@Setter
private Parser<?, ?> outputParser;
@SuppressWarnings("unused")
private ClusterService clusterService;

public CatIndexTool(Client client, ClusterService clusterService) {
super(TYPE, DEFAULT_DESCRIPTION);
this.client = client;
this.clusterService = clusterService;

outputParser = new Parser<>() {
@Override
public Object parse(Object o) {
@SuppressWarnings("unchecked")
List<ModelTensors> mlModelOutputs = (List<ModelTensors>) o;
return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response");
}
};
this.setOutputParser((Parser<Object, Object>) parser -> {
@SuppressWarnings("unchecked")
List<ModelTensors> mlModelOutputs = (List<ModelTensors>) parser;
return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response");
});
}

@Override
Expand Down Expand Up @@ -295,16 +278,6 @@ public void onFailure(final Exception e) {
}, size);
}

@Override
public String getName() {
return this.name;
}

@Override
public void setName(String s) {
this.name = s;
}

@Override
public boolean validate(Map<String, String> parameters) {
if (parameters == null || parameters.size() == 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,45 +23,27 @@
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.tools.AbstractTool;
import org.opensearch.ml.common.spi.tools.Parser;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;

import lombok.Getter;
import lombok.Setter;

@ToolAnnotation(IndexMappingTool.NAME)
public class IndexMappingTool implements Tool {
public static final String NAME = "IndexMappingTool";
@ToolAnnotation(IndexMappingTool.TYPE)
public class IndexMappingTool extends AbstractTool {
public static final String TYPE = "IndexMappingTool";

private static final String DEFAULT_DESCRIPTION = "Use this tool to get index mapping information.";
@Setter
@Getter
private String name = IndexMappingTool.NAME;
@Getter
@Setter
private String description = DEFAULT_DESCRIPTION;
@Getter
private String type;
@Getter
private String version;
private Client client;
@Setter
private Parser<?, ?> inputParser;
@Setter
private Parser<?, ?> outputParser;

public IndexMappingTool(Client client) {
super(TYPE, DEFAULT_DESCRIPTION);
this.client = client;

outputParser = new Parser<>() {
@Override
public Object parse(Object o) {
@SuppressWarnings("unchecked")
List<ModelTensors> mlModelOutputs = (List<ModelTensors>) o;
return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response");
}
};
this.setOutputParser((Parser<Object, Object>) o -> {
@SuppressWarnings("unchecked")
List<ModelTensors> mlModelOutputs = (List<ModelTensors>) o;
return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response");
});
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,50 +16,36 @@
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.tools.Parser;
import org.opensearch.ml.common.spi.tools.AbstractTool;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;

import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;

/**
* This tool supports running any ml-commons model.
*/
@Log4j2
@ToolAnnotation(MLModelTool.TYPE)
public class MLModelTool implements Tool {
public class MLModelTool extends AbstractTool {
public static final String TYPE = "MLModelTool";

@Setter
@Getter
private String name = TYPE;
private static String DEFAULT_DESCRIPTION = "Use this tool to run any model.";
@Getter
@Setter
private String description = DEFAULT_DESCRIPTION;
private Client client;
private String modelId;
@Setter
private Parser inputParser;
@Setter
private Parser outputParser;

public MLModelTool(Client client, String modelId) {
super(TYPE, DEFAULT_DESCRIPTION);
this.client = client;
this.modelId = modelId;

outputParser = new Parser() {
@Override
public Object parse(Object o) {
List<ModelTensors> mlModelOutputs = (List<ModelTensors>) o;
return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response");
}
};
this.setOutputParser(o -> {
List<ModelTensors> mlModelOutputs = (List<ModelTensors>) o;
return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response");
});
}

@Override
Expand All @@ -72,37 +58,17 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.<MLTaskResponse>wrap(r -> {
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) r.getOutput();
modelTensorOutput.getMlModelOutputs();
if (outputParser == null) {
if (this.getOutputParser() == null) {
listener.onResponse((T) modelTensorOutput.getMlModelOutputs());
} else {
listener.onResponse((T) outputParser.parse(modelTensorOutput.getMlModelOutputs()));
listener.onResponse((T) this.getOutputParser().parse(modelTensorOutput.getMlModelOutputs()));
}
}, e -> {
log.error("Failed to run model " + modelId, e);
listener.onFailure(e);
}));
}

@Override
public String getType() {
return TYPE;
}

@Override
public String getVersion() {
return null;
}

@Override
public String getName() {
return this.name;
}

@Override
public void setName(String s) {
this.name = s;
}

@Override
public boolean validate(Map<String, String> parameters) {
if (parameters == null || parameters.size() == 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,24 @@
import java.util.regex.Pattern;

import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.spi.tools.AbstractTool;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap;
import org.opensearch.script.ScriptService;

import lombok.Getter;
import lombok.Setter;

@ToolAnnotation(MathTool.TYPE)
public class MathTool implements Tool {
public class MathTool extends AbstractTool {
public static final String TYPE = "MathTool";

@Setter
@Getter
private String name = TYPE;

@Setter
private ScriptService scriptService;

private static String DEFAULT_DESCRIPTION = "Use this tool to calculate any math problem.";
@Getter
@Setter
private String description = DEFAULT_DESCRIPTION;

public MathTool(ScriptService scriptService) {
super(TYPE, DEFAULT_DESCRIPTION);
this.scriptService = scriptService;
}

Expand All @@ -59,26 +52,6 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
listener.onResponse((T) result);
}

@Override
public String getType() {
return TYPE;
}

@Override
public String getVersion() {
return null;
}

@Override
public String getName() {
return this.name;
}

@Override
public void setName(String s) {
this.name = s;
}

@Override
public boolean validate(Map<String, String> parameters) {
try {
Expand Down
Loading

0 comments on commit dcd4adf

Please sign in to comment.