diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java index 8ce9e95aa..7201e1835 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java @@ -26,9 +26,11 @@ import org.jboss.jandex.AnnotationValue; import org.jboss.jandex.ClassInfo; import org.jboss.jandex.DotName; +import org.jboss.jandex.FieldInfo; import org.jboss.jandex.IndexView; import org.jboss.jandex.MethodInfo; import org.jboss.jandex.MethodParameterInfo; +import org.jboss.jandex.ParameterizedType; import org.jboss.jandex.Type; import org.jboss.logging.Logger; import org.objectweb.asm.ClassVisitor; @@ -406,11 +408,27 @@ private String generateArgumentMapper(MethodInfo methodInfo, ClassOutput classOu private Iterable toJsonSchemaProperties(MethodParameterInfo parameter, IndexView index) { Type type = parameter.type(); - DotName typeName = parameter.type().name(); - AnnotationInstance pInstance = parameter.annotation(P); + JsonSchemaProperty description = pInstance == null ? null : description(pInstance.value().asString()); + return toJsonSchemaProperties(type, index, description); + } + + private Iterable toJsonSchemaProperties(Type type, IndexView index, JsonSchemaProperty description) { + DotName typeName = type.name(); + + if (type.kind() == Type.Kind.WILDCARD_TYPE) { + Type boundType = type.asWildcardType().extendsBound(); + if (boundType == null) { + boundType = type.asWildcardType().superBound(); + } + if (boundType != null) { + return toJsonSchemaProperties(boundType, index, description); + } else { + throw new IllegalArgumentException("Unsupported wildcard type with no bounds: " + type); + } + } if (DotNames.STRING.equals(typeName) || DotNames.CHARACTER.equals(typeName) || DotNames.PRIMITIVE_CHAR.equals(typeName)) { return removeNulls(STRING, description); @@ -435,17 +453,64 @@ private Iterable toJsonSchemaProperties(MethodParameterInfo return removeNulls(NUMBER, description); } - if ((type.kind() == Type.Kind.ARRAY) - || DotNames.LIST.equals(typeName) - || DotNames.SET.equals(typeName)) { // TODO something else? - return removeNulls(ARRAY, description); // TODO provide type of array? + // TODO something else? + if (type.kind() == Type.Kind.ARRAY || DotNames.LIST.equals(typeName) || DotNames.SET.equals(typeName)) { + ParameterizedType parameterizedType = type.kind() == Type.Kind.PARAMETERIZED_TYPE ? type.asParameterizedType() + : null; + + Type elementType = parameterizedType != null ? parameterizedType.arguments().get(0) + : type.asArrayType().component(); + + Iterable elementProperties = toJsonSchemaProperties(elementType, index, null); + + JsonSchemaProperty itemsSchema; + if (isComplexType(elementType)) { + Map fieldDescription = new HashMap<>(); + + for (JsonSchemaProperty fieldProperty : elementProperties) { + fieldDescription.put(fieldProperty.key(), fieldProperty.value()); + } + itemsSchema = JsonSchemaProperty.from("items", fieldDescription); + } else { + itemsSchema = JsonSchemaProperty.items(elementProperties.iterator().next()); + } + + return removeNulls(ARRAY, itemsSchema, description); } if (isEnum(type, index)) { return removeNulls(STRING, enums(enumConstants(type)), description); } - return removeNulls(OBJECT, description); // TODO provide internals + if (type.kind() == Type.Kind.CLASS) { + Map properties = new HashMap<>(); + ClassInfo classInfo = index.getClassByName(type.name()); + + List required = new ArrayList<>(); + if (classInfo != null) { + for (FieldInfo field : classInfo.fields()) { + String fieldName = field.name(); + + Iterable fieldSchema = toJsonSchemaProperties(field.type(), index, null); + Map fieldDescription = new HashMap<>(); + + for (JsonSchemaProperty fieldProperty : fieldSchema) { + fieldDescription.put(fieldProperty.key(), fieldProperty.value()); + } + + properties.put(fieldName, fieldDescription); + } + } + + JsonSchemaProperty objectSchema = JsonSchemaProperty.from("properties", properties); + return removeNulls(OBJECT, objectSchema, JsonSchemaProperty.from("required", required), description); + } + + throw new IllegalArgumentException("Unsupported type: " + type); + } + + private boolean isComplexType(Type type) { + return type.kind() == Type.Kind.CLASS || type.kind() == Type.Kind.PARAMETERIZED_TYPE; } private Iterable removeNulls(JsonSchemaProperty... properties) { diff --git a/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/AssistantWithToolsResource.java b/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/AssistantWithToolsResource.java index 813f87e7a..cdec3a989 100644 --- a/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/AssistantWithToolsResource.java +++ b/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/AssistantWithToolsResource.java @@ -1,5 +1,6 @@ package org.acme.example.openai.aiservices; +import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -26,6 +27,18 @@ public AssistantWithToolsResource(Assistant assistant) { this.assistant = assistant; } + public static class TestData { + String foo; + Integer bar; + Double baz; + + TestData(String foo, Integer bar, Double baz) { + this.foo = foo; + this.bar = bar; + this.baz = baz; + } + } + @GET public String get(@RestQuery String message) { return assistant.chat(message); @@ -54,6 +67,25 @@ int add(int a, int b) { double sqrt(int x) { return Math.sqrt(x); } + + @Tool("Calculates the the sum of all provided numbers") + double sumAll(List x) { + + return x.stream().reduce(0.0, (a, b) -> a + b); + } + + @Tool("Evaluate test data object") + public TestData evaluateTestObject(List data) { + return new TestData("Empty", 0, 0.0); + } + + @Tool("Calculates all factors of the provided integer.") + List getFactors(int x) { + return java.util.stream.IntStream.rangeClosed(1, x) + .filter(i -> x % i == 0) + .boxed() + .toList(); + } } @RequestScoped diff --git a/integration-tests/openai/src/test/java/org/acme/example/openai/aiservices/AssistantResourceWithToolsTest.java b/integration-tests/openai/src/test/java/org/acme/example/openai/aiservices/AssistantResourceWithToolsTest.java new file mode 100644 index 000000000..5017569db --- /dev/null +++ b/integration-tests/openai/src/test/java/org/acme/example/openai/aiservices/AssistantResourceWithToolsTest.java @@ -0,0 +1,31 @@ +package org.acme.example.openai.aiservices; + +import static io.restassured.RestAssured.given; +import static org.hamcrest.Matchers.containsString; + +import java.net.URL; + +import org.junit.jupiter.api.Test; + +import io.quarkus.test.common.http.TestHTTPEndpoint; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.test.junit.QuarkusTest; + +@QuarkusTest +public class AssistantResourceWithToolsTest { + + @TestHTTPEndpoint(AssistantWithToolsResource.class) + @TestHTTPResource + URL url; + + @Test + public void get() { + given() + .baseUri(url.toString()) + .queryParam("message", "This is a test") + .get() + .then() + .statusCode(200) + .body(containsString("MockGPT")); + } +}