Skip to content

Commit

Permalink
ChatClient register functions with explicit input type
Browse files Browse the repository at this point in the history
 The Lambda functions do not retain the type information, so we need to provide the input type explicitly.

 Resolves #1052

 Co-authored-by: liuzhifei <2679431923@qq.com>
  • Loading branch information
tzolov committed Jul 23, 2024
1 parent 03d1d50 commit 6270d62
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,17 @@

import static org.assertj.core.api.Assertions.assertThat;

import java.lang.reflect.Method;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.DefaultChatClient;
import org.springframework.ai.openai.OpenAiTestConfiguration;
import org.springframework.ai.openai.api.tool.MockWeatherService;
import org.springframework.ai.openai.testutils.AbstractIT;
Expand Down Expand Up @@ -123,4 +126,47 @@ void streamFunctionCallTest() {

}

@Test
void functionCallWithExplicitInputType() throws NoSuchMethodException {

var chatClient = ChatClient.create(chatModel);

Method currentTemp = MyFunction.class.getMethod("getCurrentTemp", MyFunction.Req.class);

// NOTE: Lambda functions do not retain the type information, so we need to
// provide the input type explicitly.
MyFunction myFunction = new MyFunction();
Function<MyFunction.Req, Object> function = createFunction(myFunction, currentTemp);

ChatClient.ChatClientRequestSpec chatClientRequestSpec = chatClient.prompt()
.user("What's the weather like in Shanghai?")
.function("currentTemp", "get current temp", MyFunction.Req.class, function);

String content = chatClientRequestSpec.call().content();

assertThat(content).contains("23");
}

public static <T, R> Function<T, R> createFunction(Object obj, Method method) {
return (T t) -> {
try {
return (R) method.invoke(obj, t);
}
catch (Exception e) {
throw new RuntimeException(e);
}
};
}

public static class MyFunction {

public record Req(String city) {
}

public String getCurrentTemp(Req req) {
return "23";
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,9 @@ interface ChatClientRequestSpec {
<I, O> ChatClientRequestSpec function(String name, String description,
java.util.function.Function<I, O> function);

<I, O> ChatClientRequestSpec function(String name, String description, Class<I> inputType,
java.util.function.Function<I, O> function);

ChatClientRequestSpec functions(String... functionBeanNames);

ChatClientRequestSpec system(String text);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,11 @@ public <T extends ChatOptions> ChatClientRequestSpec options(T options) {

public <I, O> ChatClientRequestSpec function(String name, String description,
java.util.function.Function<I, O> function) {
return this.function(name, description, null, function);
}

public <I, O> ChatClientRequestSpec function(String name, String description, Class<I> inputType,
java.util.function.Function<I, O> function) {

Assert.hasText(name, "the name must be non-null and non-empty");
Assert.hasText(description, "the description must be non-null and non-empty");
Expand All @@ -618,6 +623,7 @@ public <I, O> ChatClientRequestSpec function(String name, String description,
var fcw = FunctionCallbackWrapper.builder(function)
.withDescription(description)
.withName(name)
.withInputType(inputType)
.withResponseConverter(Object::toString)
.build();
this.functionCallbacks.add(fcw);
Expand Down

0 comments on commit 6270d62

Please sign in to comment.