Skip to content

Commit

Permalink
Removed unnecessary gradle dependencies
Browse files Browse the repository at this point in the history
Signed-off-by: Arjun kumar Giri <arjung@amazon.com>
  • Loading branch information
arjunkumargiri committed Dec 14, 2023
1 parent 07fdfeb commit 8a73866
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.tools.Parser;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
Expand Down Expand Up @@ -71,10 +70,6 @@ public class RAGTool extends AbstractRetrieverTool {
private String embeddingModelId;
private Integer docSize;
private Integer k;
@Setter
private Parser inputParser;
@Setter
private Parser outputParser;

@Builder
public RAGTool(
Expand All @@ -88,7 +83,7 @@ public RAGTool(
String embeddingModelId,
String modelId
) {
super(client, xContentRegistry, index, sourceFields, docSize);
super(TYPE, DEFAULT_DESCRIPTION, client, xContentRegistry, index, sourceFields, docSize);
this.client = client;
this.xContentRegistry = xContentRegistry;
this.index = index;
Expand All @@ -99,13 +94,10 @@ public RAGTool(
this.k = k == null ? DEFAULT_K : k;
this.modelId = modelId;

outputParser = new Parser() {
@Override
public Object parse(Object o) {
List<ModelTensors> mlModelOutputs = (List<ModelTensors>) o;
return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response");
}
};
this.setOutputParser(o -> {
List<ModelTensors> mlModelOutputs = (List<ModelTensors>) o;
return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response");
});
}

@Override
Expand Down Expand Up @@ -195,10 +187,10 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.<MLTaskResponse>wrap(resp -> {
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) resp.getOutput();
modelTensorOutput.getMlModelOutputs();
if (outputParser == null) {
if (this.getOutputParser() == null) {
listener.onResponse((T) modelTensorOutput.getMlModelOutputs());
} else {
listener.onResponse((T) outputParser.parse(modelTensorOutput.getMlModelOutputs()));
listener.onResponse((T) this.getOutputParser().parse(modelTensorOutput.getMlModelOutputs()));
}
}, e -> {
log.error("Failed to run model " + modelId, e);
Expand All @@ -214,21 +206,6 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
}
}

@Override
public String getType() {
return TYPE;
}

@Override
public String getName() {
return this.name;
}

@Override
public void setName(String s) {
this.name = s;
}

@Override
public boolean validate(Map<String, String> parameters) {
if (parameters == null || parameters.size() == 0) {
Expand Down
9 changes: 0 additions & 9 deletions spi/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import org.opensearch.gradle.test.RestIntegTestTask

plugins {
id 'java'
id "io.freefair.lombok"
id 'jacoco'
id 'com.github.johnrengelman.shadow'
id 'maven-publish'
Expand Down Expand Up @@ -72,14 +71,6 @@ test {
systemProperty 'tests.security.manager', 'false'
}

task integTest(type: RestIntegTestTask) {
description 'Run integ test with opensearch test framework'
group 'verification'
systemProperty 'tests.security.manager', 'false'
dependsOn test
}
check.dependsOn integTest

task sourcesJar(type: Jar) {
archiveClassifier.set 'sources'
from sourceSets.main.allJava
Expand Down

0 comments on commit 8a73866

Please sign in to comment.