Skip to content

Commit

Permalink
Adjusted method signature mapping to json schema, to allow collection…
Browse files Browse the repository at this point in the history
…s in tool arguments resolves quarkiverse#671
  • Loading branch information
Tarjei400 committed Nov 3, 2024
1 parent ed12693 commit dfa8052
Show file tree
Hide file tree
Showing 3 changed files with 201 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,11 @@
import static java.util.stream.Collectors.toList;

import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.*;
import java.util.function.BiFunction;
import java.util.stream.Collectors;

import org.jboss.jandex.AnnotationInstance;
import org.jboss.jandex.AnnotationTarget;
import org.jboss.jandex.AnnotationValue;
import org.jboss.jandex.ClassInfo;
import org.jboss.jandex.DotName;
import org.jboss.jandex.IndexView;
import org.jboss.jandex.MethodInfo;
import org.jboss.jandex.MethodParameterInfo;
import org.jboss.jandex.Type;
import org.jboss.jandex.*;
import org.jboss.logging.Logger;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.Opcodes;
Expand Down Expand Up @@ -200,9 +187,9 @@ public void handleTools(CombinedIndexBuildItem indexBuildItem,
AnnotationInstance pInstance = parameter.annotation(P);
if (pInstance != null && pInstance.value("required") != null
&& !pInstance.value("required").asBoolean()) {
builder.addOptionalParameter(parameter.name(), toJsonSchemaProperties(parameter, index));
builder.addOptionalParameter(parameter.name(), toJsonSchemaProperties2(parameter, index));
} else {
builder.addParameter(parameter.name(), toJsonSchemaProperties(parameter, index));
builder.addParameter(parameter.name(), toJsonSchemaProperties2(parameter, index));
}
}

Expand Down Expand Up @@ -438,6 +425,21 @@ private Iterable<JsonSchemaProperty> toJsonSchemaProperties(MethodParameterInfo
if ((type.kind() == Type.Kind.ARRAY)
|| DotNames.LIST.equals(typeName)
|| DotNames.SET.equals(typeName)) { // TODO something else?
// Handle the case where List or Set has a parameterized type
//if (type.kind() == Type.Kind.PARAMETERIZED_TYPE) {
ParameterizedType parameterizedType = type.asParameterizedType();
if (parameterizedType != null) {
Type elementType = parameterizedType.arguments().get(0); // Get the first generic type argument

// Now you can process `elementType`, e.g., recursively call `toJsonSchemaProperties` if needed
//Iterable<JsonSchemaProperty> elementProperties = toJsonSchemaProperties(elementType, index);

return removeNulls(ARRAY, description);
}

// return removeNulls(ARRAY, description, elementProperties); // Include element type information
// }

return removeNulls(ARRAY, description); // TODO provide type of array?
}

Expand All @@ -448,6 +450,124 @@ private Iterable<JsonSchemaProperty> toJsonSchemaProperties(MethodParameterInfo
return removeNulls(OBJECT, description); // TODO provide internals
}

private Iterable<JsonSchemaProperty> toJsonSchemaProperties2(MethodParameterInfo parameter, IndexView index) {
// Retrieve type and annotation information
Type type = parameter.type();
AnnotationInstance pInstance = parameter.annotation(P);

// Extract description if the @P annotation is present
JsonSchemaProperty description = pInstance == null ? null : description(pInstance.value().asString());

// Call recursive helper with extracted Type and description
return toJsonSchemaProperties2(type, index, description);
}

// Recursive helper method that accepts Type and description
private Iterable<JsonSchemaProperty> toJsonSchemaProperties2(Type type, IndexView index, JsonSchemaProperty description) {
DotName typeName = type.name();

// Handle wildcard types by resolving to their bounds
if (type.kind() == Type.Kind.WILDCARD_TYPE) {
Type boundType = type.asWildcardType().extendsBound();
if (boundType == null) {
// If there is no extends bound, check for a super bound
boundType = type.asWildcardType().superBound();
}
if (boundType != null) {
// Recursively call toJsonSchemaProperties2 on the resolved bound type
return toJsonSchemaProperties2(boundType, index, description);
} else {
throw new IllegalArgumentException("Unsupported wildcard type with no bounds: " + type);
}
}
// Handle basic types
if (DotNames.STRING.equals(typeName) || DotNames.CHARACTER.equals(typeName)
|| DotNames.PRIMITIVE_CHAR.equals(typeName)) {
return removeNulls(STRING, description);
}

if (DotNames.BOOLEAN.equals(typeName) || DotNames.PRIMITIVE_BOOLEAN.equals(typeName)) {
return removeNulls(BOOLEAN, description);
}

if (DotNames.BYTE.equals(typeName) || DotNames.PRIMITIVE_BYTE.equals(typeName)
|| DotNames.SHORT.equals(typeName) || DotNames.PRIMITIVE_SHORT.equals(typeName)
|| DotNames.INTEGER.equals(typeName) || DotNames.PRIMITIVE_INT.equals(typeName)
|| DotNames.LONG.equals(typeName) || DotNames.PRIMITIVE_LONG.equals(typeName)
|| DotNames.BIG_INTEGER.equals(typeName)) {
return removeNulls(INTEGER, description);
}

if (DotNames.FLOAT.equals(typeName) || DotNames.PRIMITIVE_FLOAT.equals(typeName)
|| DotNames.DOUBLE.equals(typeName) || DotNames.PRIMITIVE_DOUBLE.equals(typeName)
|| DotNames.BIG_DECIMAL.equals(typeName)) {
return removeNulls(NUMBER, description);
}

// Handle collections or arrays recursively
if (type.kind() == Type.Kind.ARRAY || DotNames.LIST.equals(typeName) || DotNames.SET.equals(typeName)) {
ParameterizedType parameterizedType = type.kind() == Type.Kind.PARAMETERIZED_TYPE ? type.asParameterizedType()
: null;

// For parameterized collections, get the element type
Type elementType = parameterizedType != null ? parameterizedType.arguments().get(0)
: type.asArrayType().component();

// Recursively get the schema for the element type
Iterable<JsonSchemaProperty> elementProperties = toJsonSchemaProperties2(elementType, index, null);

// Use items or objectItems based on whether element type is simple or complex
JsonSchemaProperty itemsSchema;
if (isComplexType(elementType)) {
itemsSchema = JsonSchemaProperty.objectItems(elementProperties.iterator().next());
} else {
itemsSchema = JsonSchemaProperty.items(elementProperties.iterator().next());
}

return removeNulls(ARRAY, itemsSchema, description);
}

// Handle enums
if (isEnum(type, index)) {
return removeNulls(STRING, enums(enumConstants(type)), description);
}

// Handle complex objects recursively by processing their fields
if (type.kind() == Type.Kind.CLASS) {
Map<String, Object> properties = new HashMap<>();
ClassInfo classInfo = index.getClassByName(type.name());

List<String> required = new ArrayList<>();
if (classInfo != null) {
for (FieldInfo field : classInfo.fields()) {
String fieldName = field.name();

Iterable<JsonSchemaProperty> fieldSchema = toJsonSchemaProperties2(field.type(), index, null);
Iterator<JsonSchemaProperty> iterator = fieldSchema.iterator();
JsonSchemaProperty fieldProperty;
Map<String, Object> fieldDescription = new HashMap<>();
while (iterator.hasNext()) {
fieldProperty = iterator.next();
fieldDescription.put(fieldProperty.key(), fieldProperty.value());
}

properties.put(fieldName, fieldDescription);
}
}

// Create a JsonSchemaProperty for the object with its properties
JsonSchemaProperty objectSchema = JsonSchemaProperty.from("properties", properties);
return removeNulls(OBJECT, objectSchema, JsonSchemaProperty.from("required", required), description);
}

throw new IllegalArgumentException("Unsupported type: " + type);
}

// Utility method to determine if a type is complex (e.g., an object or parameterized)
private boolean isComplexType(Type type) {
return type.kind() == Type.Kind.CLASS || type.kind() == Type.Kind.PARAMETERIZED_TYPE;
}

private Iterable<JsonSchemaProperty> removeNulls(JsonSchemaProperty... properties) {
return stream(properties)
.filter(Objects::nonNull)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.acme.example.openai.aiservices;

import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

Expand All @@ -26,6 +27,18 @@ public AssistantWithToolsResource(Assistant assistant) {
this.assistant = assistant;
}

class TestData {
String foo;
Integer bar;
Double baz;

TestData(String foo, Integer bar, Double baz) {
this.foo = foo;
this.bar = bar;
this.baz = baz;
}
}

@GET
public String get(@RestQuery String message) {
return assistant.chat(message);
Expand Down Expand Up @@ -54,6 +67,26 @@ int add(int a, int b) {
double sqrt(int x) {
return Math.sqrt(x);
}

@Tool("Calculates the the sum of all provided numbers")
double sumAll(List<Double> x) {

return x.stream().reduce(0.0, (a, b) -> a + b);
}

@Tool("Evaluate test data object")
TestData evaluateTestObject(TestData data) {

return data;
}

@Tool("Calculates all factors of the provided integer.")
List<Integer> getFactors(int x) {
return java.util.stream.IntStream.rangeClosed(1, x)
.filter(i -> x % i == 0)
.boxed()
.toList();
}
}

@RequestScoped
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package org.acme.example.openai.aiservices;

import static io.restassured.RestAssured.given;
import static org.hamcrest.Matchers.containsString;

import java.net.URL;

import org.junit.jupiter.api.Test;

import io.quarkus.test.common.http.TestHTTPEndpoint;
import io.quarkus.test.common.http.TestHTTPResource;
import io.quarkus.test.junit.QuarkusTest;

@QuarkusTest
public class AssistantResourceWithToolsTest {

@TestHTTPEndpoint(AssistantWithToolsResource.class)
@TestHTTPResource
URL url;

@Test
public void get() {
given()
.baseUri(url.toString())
.queryParam("message", "This is a test")
.get()
.then()
.statusCode(200)
.body(containsString("MockGPT"));
}
}

0 comments on commit dfa8052

Please sign in to comment.