Skip to content

Commit

Permalink
Resolves #671, Adjusted method signature mapping to json schema, to a…
Browse files Browse the repository at this point in the history
…llow collections in tool arguments
  • Loading branch information
Tarjei400 committed Nov 3, 2024
1 parent ed12693 commit b8417fb
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@
import org.jboss.jandex.AnnotationValue;
import org.jboss.jandex.ClassInfo;
import org.jboss.jandex.DotName;
import org.jboss.jandex.FieldInfo;
import org.jboss.jandex.IndexView;
import org.jboss.jandex.MethodInfo;
import org.jboss.jandex.MethodParameterInfo;
import org.jboss.jandex.ParameterizedType;
import org.jboss.jandex.Type;
import org.jboss.logging.Logger;
import org.objectweb.asm.ClassVisitor;
Expand Down Expand Up @@ -406,11 +408,27 @@ private String generateArgumentMapper(MethodInfo methodInfo, ClassOutput classOu

private Iterable<JsonSchemaProperty> toJsonSchemaProperties(MethodParameterInfo parameter, IndexView index) {
Type type = parameter.type();
DotName typeName = parameter.type().name();

AnnotationInstance pInstance = parameter.annotation(P);

JsonSchemaProperty description = pInstance == null ? null : description(pInstance.value().asString());

return toJsonSchemaProperties(type, index, description);
}

private Iterable<JsonSchemaProperty> toJsonSchemaProperties(Type type, IndexView index, JsonSchemaProperty description) {
DotName typeName = type.name();

if (type.kind() == Type.Kind.WILDCARD_TYPE) {
Type boundType = type.asWildcardType().extendsBound();
if (boundType == null) {
boundType = type.asWildcardType().superBound();
}
if (boundType != null) {
return toJsonSchemaProperties(boundType, index, description);
} else {
throw new IllegalArgumentException("Unsupported wildcard type with no bounds: " + type);
}
}
if (DotNames.STRING.equals(typeName) || DotNames.CHARACTER.equals(typeName)
|| DotNames.PRIMITIVE_CHAR.equals(typeName)) {
return removeNulls(STRING, description);
Expand All @@ -435,17 +453,64 @@ private Iterable<JsonSchemaProperty> toJsonSchemaProperties(MethodParameterInfo
return removeNulls(NUMBER, description);
}

if ((type.kind() == Type.Kind.ARRAY)
|| DotNames.LIST.equals(typeName)
|| DotNames.SET.equals(typeName)) { // TODO something else?
return removeNulls(ARRAY, description); // TODO provide type of array?
// TODO something else?
if (type.kind() == Type.Kind.ARRAY || DotNames.LIST.equals(typeName) || DotNames.SET.equals(typeName)) {
ParameterizedType parameterizedType = type.kind() == Type.Kind.PARAMETERIZED_TYPE ? type.asParameterizedType()
: null;

Type elementType = parameterizedType != null ? parameterizedType.arguments().get(0)
: type.asArrayType().component();

Iterable<JsonSchemaProperty> elementProperties = toJsonSchemaProperties(elementType, index, null);

JsonSchemaProperty itemsSchema;
if (isComplexType(elementType)) {
Map<String, Object> fieldDescription = new HashMap<>();

for (JsonSchemaProperty fieldProperty : elementProperties) {
fieldDescription.put(fieldProperty.key(), fieldProperty.value());
}
itemsSchema = JsonSchemaProperty.from("items", fieldDescription);
} else {
itemsSchema = JsonSchemaProperty.items(elementProperties.iterator().next());
}

return removeNulls(ARRAY, itemsSchema, description);
}

if (isEnum(type, index)) {
return removeNulls(STRING, enums(enumConstants(type)), description);
}

return removeNulls(OBJECT, description); // TODO provide internals
if (type.kind() == Type.Kind.CLASS) {
Map<String, Object> properties = new HashMap<>();
ClassInfo classInfo = index.getClassByName(type.name());

List<String> required = new ArrayList<>();
if (classInfo != null) {
for (FieldInfo field : classInfo.fields()) {
String fieldName = field.name();

Iterable<JsonSchemaProperty> fieldSchema = toJsonSchemaProperties(field.type(), index, null);
Map<String, Object> fieldDescription = new HashMap<>();

for (JsonSchemaProperty fieldProperty : fieldSchema) {
fieldDescription.put(fieldProperty.key(), fieldProperty.value());
}

properties.put(fieldName, fieldDescription);
}
}

JsonSchemaProperty objectSchema = JsonSchemaProperty.from("properties", properties);
return removeNulls(OBJECT, objectSchema, JsonSchemaProperty.from("required", required), description);
}

throw new IllegalArgumentException("Unsupported type: " + type);
}

private boolean isComplexType(Type type) {
return type.kind() == Type.Kind.CLASS || type.kind() == Type.Kind.PARAMETERIZED_TYPE;
}

private Iterable<JsonSchemaProperty> removeNulls(JsonSchemaProperty... properties) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.acme.example.openai.aiservices;

import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

Expand All @@ -26,6 +27,18 @@ public AssistantWithToolsResource(Assistant assistant) {
this.assistant = assistant;
}

public static class TestData {
String foo;
Integer bar;
Double baz;

TestData(String foo, Integer bar, Double baz) {
this.foo = foo;
this.bar = bar;
this.baz = baz;
}
}

@GET
public String get(@RestQuery String message) {
return assistant.chat(message);
Expand Down Expand Up @@ -54,6 +67,25 @@ int add(int a, int b) {
double sqrt(int x) {
return Math.sqrt(x);
}

@Tool("Calculates the the sum of all provided numbers")
double sumAll(List<Double> x) {

return x.stream().reduce(0.0, (a, b) -> a + b);
}

@Tool("Evaluate test data object")
public TestData evaluateTestObject(List<TestData> data) {
return new TestData("Empty", 0, 0.0);
}

@Tool("Calculates all factors of the provided integer.")
List<Integer> getFactors(int x) {
return java.util.stream.IntStream.rangeClosed(1, x)
.filter(i -> x % i == 0)
.boxed()
.toList();
}
}

@RequestScoped
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package org.acme.example.openai.aiservices;

import static io.restassured.RestAssured.given;
import static org.hamcrest.Matchers.containsString;

import java.net.URL;

import org.junit.jupiter.api.Test;

import io.quarkus.test.common.http.TestHTTPEndpoint;
import io.quarkus.test.common.http.TestHTTPResource;
import io.quarkus.test.junit.QuarkusTest;

@QuarkusTest
public class AssistantResourceWithToolsTest {

@TestHTTPEndpoint(AssistantWithToolsResource.class)
@TestHTTPResource
URL url;

@Test
public void get() {
given()
.baseUri(url.toString())
.queryParam("message", "This is a test")
.get()
.then()
.statusCode(200)
.body(containsString("MockGPT"));
}
}

0 comments on commit b8417fb

Please sign in to comment.