Skip to content

Commit

Permalink
feat(core): Add Supplier and Consumer function callback support
Browse files Browse the repository at this point in the history
Add support for no-argument Supplier and single-argument Consumer function
callbacks in the Spring AI core module. This enhancement allows:
- Registration of Supplier<O> callbacks with no input (Void) type
- Registration of Consumer<I> callbacks with no output (Void) type
- Support for Kotlin Function0 (equivalent to Java Supplier)
- Handle empty properties for Void input types in schema generation
- Enhance FunctionCallback builder to support Supplier/Consumer patterns

Additional changes:
- Add test coverage for both Supplier and Consumer callbacks in various scenarios
- Enhance TypeResolverHelper to support Consumer input type resolution
- Support lambda-style function declarations for improved ergonomics
- Add test cases for void input/output handling in OpenAI chat model
- Include examples of function calls without return values
- Add support for parameterless functions through Supplier interface

Resolves #1718 , #1277 , #1118, #860
  • Loading branch information
tzolov committed Nov 16, 2024
1 parent 018257a commit e7cc220
Show file tree
Hide file tree
Showing 11 changed files with 466 additions and 127 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.BiFunction;
import java.util.stream.Collectors;

Expand All @@ -28,6 +29,7 @@
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
Expand Down Expand Up @@ -59,6 +61,25 @@ class OpenAiChatModelFunctionCallingIT {
@Autowired
ChatModel chatModel;

@Test
void functionCallSupplier() {

Map<String, Object> state = new ConcurrentHashMap<>();

// @formatter:off
String response = ChatClient.create(this.chatModel).prompt()
.user("Turn the light on in the living room")
.functions(FunctionCallback.builder()
.function("turnsLightOnInTheLivingRoom", () -> state.put("Light", "ON"))
.build())
.call()
.content();
// @formatter:on

logger.info("Response: {}", response);
assertThat(state).containsEntry("Light", "ON");
}

@Test
void functionCallTest() {
functionCallTest(OpenAiChatOptions.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,11 @@ public static String getJsonSchema(Type inputType, boolean toUpperCaseTypeValues
}

ObjectNode node = SCHEMA_GENERATOR_CACHE.get().generateSchema(inputType);

if ((inputType == Void.class) && !node.has("properties")) {
node.putObject("properties");
}

if (toUpperCaseTypeValues) { // Required for OpenAPI 3.0 (at least Vertex AI
// version of it).
toUpperCaseTypeValues(node);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
import java.lang.reflect.Type;
import java.util.Arrays;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.DeserializationFeature;
Expand All @@ -43,7 +45,7 @@

/**
* Default implementation of the {@link FunctionCallback.Builder}.
*
*
* @author Christian Tzolov
* @since 1.0.0
*/
Expand Down Expand Up @@ -137,6 +139,20 @@ public <I, O> FunctionInvokingSpec<I, O> function(String name, BiFunction<I, Too
return new DefaultFunctionInvokingSpec<>(name, biFunction);
}

@Override
public <O> FunctionInvokingSpec<Void, O> function(String name, Supplier<O> supplier) {
Function<Void, O> function = (input) -> supplier.get();
return new DefaultFunctionInvokingSpec<>(name, function).inputType(Void.class);
}

public <I> FunctionInvokingSpec<I, Void> function(String name, Consumer<I> consumer) {
Function<I, Void> function = (I input) -> {
consumer.accept(input);
return null;
};
return new DefaultFunctionInvokingSpec<>(name, function);
}

@Override
public MethodInvokingSpec method(String methodName, Class<?>... argumentTypes) {
return new DefaultMethodInvokingSpec(methodName, argumentTypes);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
package org.springframework.ai.model.function;

import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;

import com.fasterxml.jackson.databind.ObjectMapper;

Expand Down Expand Up @@ -141,6 +143,16 @@ interface Builder {
*/
<I, O> FunctionInvokingSpec<I, O> function(String name, BiFunction<I, ToolContext, O> biFunction);

/**
* Builds a {@link Supplier} invoking {@link FunctionCallback} instance.
*/
<O> FunctionInvokingSpec<Void, O> function(String name, Supplier<O> supplier);

/**
* Builds a {@link Consumer} invoking {@link FunctionCallback} instance.
*/
<I> FunctionInvokingSpec<I, Void> function(String name, Consumer<I> consumer);

/**
* Builds a Method invoking {@link FunctionCallback} instance.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@
package org.springframework.ai.model.function;

import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;

import com.fasterxml.jackson.annotation.JsonClassDescription;
import kotlin.jvm.functions.Function0;
import kotlin.jvm.functions.Function1;
import kotlin.jvm.functions.Function2;

Expand All @@ -30,6 +33,7 @@
import org.springframework.context.annotation.Description;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.core.KotlinDetector;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.ResolvableType;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;
Expand Down Expand Up @@ -71,7 +75,8 @@ public void setApplicationContext(@NonNull ApplicationContext applicationContext
public FunctionCallback getFunctionCallback(@NonNull String beanName, @Nullable String defaultDescription) {

ResolvableType functionType = TypeResolverHelper.resolveBeanType(this.applicationContext, beanName);
ResolvableType functionInputType = TypeResolverHelper.getFunctionArgumentType(functionType, 0);
ResolvableType functionInputType = (ResolvableType.forType(Supplier.class).isAssignableFrom(functionType))
? ResolvableType.forType(Void.class) : TypeResolverHelper.getFunctionArgumentType(functionType, 0);

Class<?> functionInputClass = functionInputType.toClass();
String functionDescription = defaultDescription;
Expand Down Expand Up @@ -109,15 +114,23 @@ public FunctionCallback getFunctionCallback(@NonNull String beanName, @Nullable
.schemaType(this.schemaType)
.description(functionDescription)
.function(beanName, KotlinDelegate.wrapKotlinFunction(bean))
.inputType(functionInputClass)
.inputType(ParameterizedTypeReference.forType(functionInputType.getType()))
.build();
}
else if (KotlinDelegate.isKotlinBiFunction(functionType.toClass())) {
return FunctionCallback.builder()
.description(functionDescription)
.schemaType(this.schemaType)
.function(beanName, KotlinDelegate.wrapKotlinBiFunction(bean))
.inputType(functionInputClass)
.inputType(ParameterizedTypeReference.forType(functionInputType.getType()))
.build();
}
else if (KotlinDelegate.isKotlinSupplier(functionType.toClass())) {
return FunctionCallback.builder()
.description(functionDescription)
.schemaType(this.schemaType)
.function(beanName, KotlinDelegate.wrapKotlinSupplier(bean))
.inputType(ParameterizedTypeReference.forType(functionInputType.getType()))
.build();
}
}
Expand All @@ -126,15 +139,31 @@ else if (KotlinDelegate.isKotlinBiFunction(functionType.toClass())) {
.schemaType(this.schemaType)
.description(functionDescription)
.function(beanName, function)
.inputType(functionInputClass)
.inputType(ParameterizedTypeReference.forType(functionInputType.getType()))
.build();
}
else if (bean instanceof BiFunction<?, ?, ?>) {
return FunctionCallback.builder()
.description(functionDescription)
.schemaType(this.schemaType)
.function(beanName, (BiFunction<?, ToolContext, ?>) bean)
.inputType(functionInputClass)
.inputType(ParameterizedTypeReference.forType(functionInputType.getType()))
.build();
}
else if (bean instanceof Supplier<?> supplier) {
return FunctionCallback.builder()
.description(functionDescription)
.schemaType(this.schemaType)
.function(beanName, supplier)
.inputType(ParameterizedTypeReference.forType(functionInputType.getType()))
.build();
}
else if (bean instanceof Consumer<?> consumer) {
return FunctionCallback.builder()
.description(functionDescription)
.schemaType(this.schemaType)
.function(beanName, consumer)
.inputType(ParameterizedTypeReference.forType(functionInputType.getType()))
.build();
}
else {
Expand All @@ -150,6 +179,15 @@ public enum SchemaType {

private static class KotlinDelegate {

public static boolean isKotlinSupplier(Class<?> clazz) {
return Function0.class.isAssignableFrom(clazz);
}

@SuppressWarnings("unchecked")
public static Supplier<?> wrapKotlinSupplier(Object function) {
return () -> ((Function0<Object>) function).invoke();
}

public static boolean isKotlinFunction(Class<?> clazz) {
return Function1.class.isAssignableFrom(clazz);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@
import java.lang.reflect.Modifier;
import java.util.Arrays;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;

import kotlin.jvm.functions.Function0;
import kotlin.jvm.functions.Function1;
import kotlin.jvm.functions.Function2;

Expand All @@ -44,6 +47,16 @@
*/
public abstract class TypeResolverHelper {

/**
* Returns the input class of a given Consumer class.
* @param consumerClass The consumer class.
* @return The input class of the consumer.
*/
public static Class<?> getConsumerInputClass(Class<? extends Consumer<?>> consumerClass) {
ResolvableType resolvableType = ResolvableType.forClass(consumerClass).as(Consumer.class);
return (resolvableType == ResolvableType.NONE ? Object.class : resolvableType.getGeneric(0).toClass());
}

/**
* Returns the input class of a given function class.
* @param biFunctionClass The function class.
Expand Down Expand Up @@ -199,13 +212,22 @@ public static ResolvableType getFunctionArgumentType(ResolvableType functionType
else if (BiFunction.class.isAssignableFrom(resolvableClass)) {
functionArgumentResolvableType = functionType.as(BiFunction.class);
}
else if (Supplier.class.isAssignableFrom(resolvableClass)) {
functionArgumentResolvableType = functionType.as(Supplier.class);
}
else if (Consumer.class.isAssignableFrom(resolvableClass)) {
functionArgumentResolvableType = functionType.as(Consumer.class);
}
else if (KotlinDetector.isKotlinPresent()) {
if (KotlinDelegate.isKotlinFunction(resolvableClass)) {
functionArgumentResolvableType = KotlinDelegate.adaptToKotlinFunctionType(functionType);
}
else if (KotlinDelegate.isKotlinBiFunction(resolvableClass)) {
functionArgumentResolvableType = KotlinDelegate.adaptToKotlinBiFunctionType(functionType);
}
else if (KotlinDelegate.isKotlinSupplier(resolvableClass)) {
functionArgumentResolvableType = KotlinDelegate.adaptToKotlinSupplierType(functionType);
}
}

if (functionArgumentResolvableType == ResolvableType.NONE) {
Expand All @@ -218,6 +240,14 @@ else if (KotlinDelegate.isKotlinBiFunction(resolvableClass)) {

private static class KotlinDelegate {

public static boolean isKotlinSupplier(Class<?> clazz) {
return Function0.class.isAssignableFrom(clazz);
}

public static ResolvableType adaptToKotlinSupplierType(ResolvableType resolvableType) {
return resolvableType.as(Function0.class);
}

public static boolean isKotlinFunction(Class<?> clazz) {
return Function1.class.isAssignableFrom(clazz);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.springframework.ai.model.function;

import java.util.function.Consumer;
import java.util.function.Function;

import org.junit.jupiter.params.ParameterizedTest;
Expand All @@ -39,7 +40,7 @@ public class TypeResolverHelperIT {

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "weatherClassDefinition", "weatherFunctionDefinition", "standaloneWeatherFunction",
"scannedStandaloneWeatherFunction", "componentWeatherFunction" })
"scannedStandaloneWeatherFunction", "componentWeatherFunction", "weatherConsumer" })
void beanInputTypeResolutionWithResolvableType(String beanName) {
assertThat(this.applicationContext).isNotNull();
ResolvableType functionType = TypeResolverHelper.resolveBeanType(this.applicationContext, beanName);
Expand Down Expand Up @@ -89,6 +90,13 @@ StandaloneWeatherFunction standaloneWeatherFunction() {
return new StandaloneWeatherFunction();
}

@Bean
Consumer<WeatherRequest> weatherConsumer() {
return (weatherRequest) -> {
System.out.println(weatherRequest);
};
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.springframework.ai.model.function;

import java.util.function.Consumer;
import java.util.function.Function;

import com.fasterxml.jackson.annotation.JsonClassDescription;
Expand All @@ -35,6 +36,12 @@
*/
public class TypeResolverHelperTests {

@Test
public void testGetConsumerInputType() {
Class<?> inputType = TypeResolverHelper.getConsumerInputClass(MyConsumer.class);
assertThat(inputType).isEqualTo(Request.class);
}

@Test
public void testGetFunctionInputType() {
Class<?> inputType = TypeResolverHelper.getFunctionInputClass(MockWeatherService.class);
Expand Down Expand Up @@ -63,6 +70,14 @@ public String apply(Response response) {

}

public static class MyConsumer implements Consumer<Request> {

@Override
public void accept(Request request) {
}

}

public static class MockWeatherService implements Function<Request, Response> {

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ void functionCallTest() {
var promptOptions = MistralAiChatOptions.builder()
.withFunctionCallbacks(List.of(FunctionCallback.builder()
.description("Get payment status of a transaction")
.function("retrievePaymentStatus", transaction -> new Status(DATA.get(transaction).status()))
.function("retrievePaymentStatus",
(Transaction transaction) -> new Status(DATA.get(transaction).status()))
.inputType(Transaction.class)
.build()))
.build();
Expand Down
Loading

0 comments on commit e7cc220

Please sign in to comment.