Skip to content

Commit

Permalink
feat: Add support for Consumer and Supplier function callbacks
Browse files Browse the repository at this point in the history
- Add support for Java Consumer and Supplier functional interfaces in function callbacks
- Handle void type inputs and outputs in function callbacks
- Add test cases for void responses, Consumer callbacks, and Supplier callbacks
- Update ModelOptionsUtils to properly handle void type schemas

Resolves spring-projects#1718 and spring-projects#1277
  • Loading branch information
tzolov committed Nov 11, 2024
1 parent 5e86583 commit bc3514d
Show file tree
Hide file tree
Showing 5 changed files with 287 additions and 121 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import org.springframework.beans.BeanWrapper;
import org.springframework.beans.BeanWrapperImpl;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.CollectionUtils;
import org.springframework.util.ObjectUtils;

Expand Down Expand Up @@ -358,8 +359,13 @@ public static String getJsonSchema(Class<?> clazz, boolean toUpperCaseTypeValues
}

ObjectNode node = SCHEMA_GENERATOR_CACHE.get().generateSchema(clazz);
if (toUpperCaseTypeValues) { // Required for OpenAPI 3.0 (at least Vertex AI
// version of it).

if (ClassUtils.isVoidType(clazz) && node.get("properties") == null) {
node.putObject("properties");
}

// Required for OpenAPI 3.0 (at least Vertex AI version of it).
if (toUpperCaseTypeValues) {
toUpperCaseTypeValues(node);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
package org.springframework.ai.model.function;

import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;

import com.fasterxml.jackson.annotation.JsonClassDescription;
import kotlin.jvm.functions.Function1;
Expand Down Expand Up @@ -129,6 +131,22 @@ else if (KotlinDelegate.isKotlinBiFunction(functionType.toClass())) {
.withInputType(functionInputClass)
.build();
}
if (bean instanceof Consumer<?> consumer) {
return FunctionCallbackWrapper.builder(consumer)
.withName(beanName)
.withSchemaType(this.schemaType)
.withDescription(functionDescription)
.withInputType(functionInputClass)
.build();
}
if (bean instanceof Supplier<?> supplier) {
return FunctionCallbackWrapper.builder(supplier)
.withName(beanName)
.withSchemaType(this.schemaType)
.withDescription(functionDescription)
.withInputType(functionInputClass)
.build();
}
else if (bean instanceof BiFunction<?, ?, ?>) {
return FunctionCallbackWrapper.builder((BiFunction<?, ToolContext, ?>) bean)
.withName(beanName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
package org.springframework.ai.model.function;

import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;

import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
Expand Down Expand Up @@ -59,6 +61,19 @@ public static <I, O> Builder<I, O> builder(Function<I, O> function) {
return new Builder<>(function);
}

public static <Void, O> Builder<Void, O> builder(Supplier<O> supplier) {
Function<Void, O> function = (input) -> supplier.get();
return new Builder<>(function);
}

public static <I, Void> Builder<I, Void> builder(Consumer<I> consumer) {
Function<I, Void> function = (input) -> {
consumer.accept(input);
return null;
};
return new Builder<>(function);
}

@Override
public O apply(I input, ToolContext context) {
return this.biFunction.apply(input, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
import java.lang.reflect.Modifier;
import java.util.Arrays;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;

import kotlin.jvm.functions.Function1;
import kotlin.jvm.functions.Function2;
Expand Down Expand Up @@ -189,6 +191,12 @@ public static ResolvableType getFunctionArgumentType(ResolvableType functionType
else if (BiFunction.class.isAssignableFrom(resolvableClass)) {
functionArgumentResolvableType = functionType.as(BiFunction.class);
}
else if (Supplier.class.isAssignableFrom(resolvableClass)) {
return ResolvableType.forClass(Void.class);
}
else if (Consumer.class.isAssignableFrom(resolvableClass)) {
functionArgumentResolvableType = functionType.as(Consumer.class);
}
else if (KotlinDetector.isKotlinPresent()) {
if (KotlinDelegate.isKotlinFunction(resolvableClass)) {
functionArgumentResolvableType = KotlinDelegate.adaptToKotlinFunctionType(functionType);
Expand Down
Loading

0 comments on commit bc3514d

Please sign in to comment.