Skip to content

Commit

Permalink
add config field in MLToolSpec for static parameters (#2977)
Browse files Browse the repository at this point in the history
* add config field in MLToolSpec for static parameters

Signed-off-by: Jing Zhang <jngz@amazon.com>

* add version control

Signed-off-by: Jing Zhang <jngz@amazon.com>

* address comments I

Signed-off-by: Jing Zhang <jngz@amazon.com>

* address commits II

Signed-off-by: Jing Zhang <jngz@amazon.com>

* address comments III

Signed-off-by: Jing Zhang <jngz@amazon.com>

---------

Signed-off-by: Jing Zhang <jngz@amazon.com>
  • Loading branch information
jngz-es authored Oct 21, 2024
1 parent 1bbaddd commit 9ed0040
Show file tree
Hide file tree
Showing 11 changed files with 341 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -581,4 +581,5 @@ public class CommonValue {
public static final Version VERSION_2_15_0 = Version.fromString("2.15.0");
public static final Version VERSION_2_16_0 = Version.fromString("2.16.0");
public static final Version VERSION_2_17_0 = Version.fromString("2.17.0");
public static final Version VERSION_2_18_0 = Version.fromString("2.18.0");
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@
import java.io.IOException;
import java.util.Map;

import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.CommonValue;

import lombok.Builder;
import lombok.EqualsAndHashCode;
Expand All @@ -24,20 +26,31 @@
@EqualsAndHashCode
@Getter
public class MLToolSpec implements ToXContentObject {
public static final Version MINIMAL_SUPPORTED_VERSION_FOR_TOOL_CONFIG = CommonValue.VERSION_2_18_0;

public static final String TOOL_TYPE_FIELD = "type";
public static final String TOOL_NAME_FIELD = "name";
public static final String DESCRIPTION_FIELD = "description";
public static final String PARAMETERS_FIELD = "parameters";
public static final String INCLUDE_OUTPUT_IN_AGENT_RESPONSE = "include_output_in_agent_response";
public static final String CONFIG_FIELD = "config";

private String type;
private String name;
private String description;
private Map<String, String> parameters;
private boolean includeOutputInAgentResponse;
private Map<String, String> configMap;

@Builder(toBuilder = true)
public MLToolSpec(String type, String name, String description, Map<String, String> parameters, boolean includeOutputInAgentResponse) {
public MLToolSpec(
String type,
String name,
String description,
Map<String, String> parameters,
boolean includeOutputInAgentResponse,
Map<String, String> configMap
) {
if (type == null) {
throw new IllegalArgumentException("tool type is null");
}
Expand All @@ -46,6 +59,7 @@ public MLToolSpec(String type, String name, String description, Map<String, Stri
this.description = description;
this.parameters = parameters;
this.includeOutputInAgentResponse = includeOutputInAgentResponse;
this.configMap = configMap;
}

public MLToolSpec(StreamInput input) throws IOException {
Expand All @@ -56,6 +70,9 @@ public MLToolSpec(StreamInput input) throws IOException {
parameters = input.readMap(StreamInput::readString, StreamInput::readOptionalString);
}
includeOutputInAgentResponse = input.readBoolean();
if (input.getVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_TOOL_CONFIG) && input.readBoolean()) {
configMap = input.readMap(StreamInput::readString, StreamInput::readOptionalString);
}
}

public void writeTo(StreamOutput out) throws IOException {
Expand All @@ -69,6 +86,14 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(false);
}
out.writeBoolean(includeOutputInAgentResponse);
if (out.getVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_TOOL_CONFIG)) {
if (configMap != null) {
out.writeBoolean(true);
out.writeMap(configMap, StreamOutput::writeString, StreamOutput::writeOptionalString);
} else {
out.writeBoolean(false);
}
}
}

@Override
Expand All @@ -87,6 +112,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(PARAMETERS_FIELD, parameters);
}
builder.field(INCLUDE_OUTPUT_IN_AGENT_RESPONSE, includeOutputInAgentResponse);
if (configMap != null && !configMap.isEmpty()) {
builder.field(CONFIG_FIELD, configMap);
}
builder.endObject();
return builder;
}
Expand All @@ -97,6 +125,7 @@ public static MLToolSpec parse(XContentParser parser) throws IOException {
String description = null;
Map<String, String> parameters = null;
boolean includeOutputInAgentResponse = false;
Map<String, String> configMap = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand All @@ -119,6 +148,9 @@ public static MLToolSpec parse(XContentParser parser) throws IOException {
case INCLUDE_OUTPUT_IN_AGENT_RESPONSE:
includeOutputInAgentResponse = parser.booleanValue();
break;
case CONFIG_FIELD:
configMap = getParameterMap(parser.map());
break;
default:
parser.skipChildren();
break;
Expand All @@ -131,6 +163,7 @@ public static MLToolSpec parse(XContentParser parser) throws IOException {
.description(description)
.parameters(parameters)
.includeOutputInAgentResponse(includeOutputInAgentResponse)
.configMap(configMap)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public void constructor_NullName() {
MLAgentType.CONVERSATIONAL.name(),
"test",
new LLMSpec("test_model", Map.of("test_key", "test_value")),
List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)),
List.of(new MLToolSpec("test", "test", "test", Collections.emptyMap(), false, Collections.emptyMap())),
null,
null,
Instant.EPOCH,
Expand All @@ -66,7 +66,7 @@ public void constructor_NullType() {
null,
"test",
new LLMSpec("test_model", Map.of("test_key", "test_value")),
List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)),
List.of(new MLToolSpec("test", "test", "test", Collections.emptyMap(), false, Collections.emptyMap())),
null,
null,
Instant.EPOCH,
Expand All @@ -86,7 +86,7 @@ public void constructor_NullLLMSpec() {
MLAgentType.CONVERSATIONAL.name(),
"test",
null,
List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)),
List.of(new MLToolSpec("test", "test", "test", Collections.emptyMap(), false, Collections.emptyMap())),
null,
null,
Instant.EPOCH,
Expand All @@ -100,7 +100,14 @@ public void constructor_NullLLMSpec() {
public void constructor_DuplicateTool() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Duplicate tool defined: test_tool_name");
MLToolSpec mlToolSpec = new MLToolSpec("test_tool_type", "test_tool_name", "test", Collections.EMPTY_MAP, false);
MLToolSpec mlToolSpec = new MLToolSpec(
"test_tool_type",
"test_tool_name",
"test",
Collections.emptyMap(),
false,
Collections.emptyMap()
);
MLAgent agent = new MLAgent(
"test_name",
MLAgentType.CONVERSATIONAL.name(),
Expand All @@ -123,7 +130,7 @@ public void writeTo() throws IOException {
"CONVERSATIONAL",
"test",
new LLMSpec("test_model", Map.of("test_key", "test_value")),
List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)),
List.of(new MLToolSpec("test", "test", "test", Collections.emptyMap(), false, Collections.emptyMap())),
Map.of("test", "test"),
new MLMemorySpec("test", "123", 0),
Instant.EPOCH,
Expand All @@ -150,7 +157,7 @@ public void writeTo_NullLLM() throws IOException {
"FLOW",
"test",
null,
List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)),
List.of(new MLToolSpec("test", "test", "test", Collections.emptyMap(), false, Collections.emptyMap())),
Map.of("test", "test"),
new MLMemorySpec("test", "123", 0),
Instant.EPOCH,
Expand Down Expand Up @@ -194,7 +201,7 @@ public void writeTo_NullParameters() throws IOException {
MLAgentType.CONVERSATIONAL.name(),
"test",
new LLMSpec("test_model", Map.of("test_key", "test_value")),
List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)),
List.of(new MLToolSpec("test", "test", "test", Collections.emptyMap(), false, Collections.emptyMap())),
null,
new MLMemorySpec("test", "123", 0),
Instant.EPOCH,
Expand All @@ -216,7 +223,7 @@ public void writeTo_NullMemory() throws IOException {
"CONVERSATIONAL",
"test",
new LLMSpec("test_model", Map.of("test_key", "test_value")),
List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)),
List.of(new MLToolSpec("test", "test", "test", Collections.emptyMap(), false, Collections.emptyMap())),
Map.of("test", "test"),
null,
Instant.EPOCH,
Expand All @@ -238,7 +245,7 @@ public void toXContent() throws IOException {
"CONVERSATIONAL",
"test",
new LLMSpec("test_model", Map.of("test_key", "test_value")),
List.of(new MLToolSpec("test", "test", "test", Map.of("test", "test"), false)),
List.of(new MLToolSpec("test", "test", "test", Map.of("test", "test"), false, Collections.emptyMap())),
Map.of("test", "test"),
new MLMemorySpec("test", "123", 0),
Instant.EPOCH,
Expand Down Expand Up @@ -294,7 +301,7 @@ public void fromStream() throws IOException {
MLAgentType.CONVERSATIONAL.name(),
"test",
new LLMSpec("test_model", Map.of("test_key", "test_value")),
List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)),
List.of(new MLToolSpec("test", "test", "test", Collections.emptyMap(), false, Collections.emptyMap())),
Map.of("test", "test"),
new MLMemorySpec("test", "123", 0),
Instant.EPOCH,
Expand Down
Loading

0 comments on commit 9ed0040

Please sign in to comment.