Skip to content

Commit 9cf6633

Browse files
sdeleuzemarkpollack
authored andcommitted
Add support for Kotlin functions
Add support for proper Kotlin functions handling by adapting kotlin.jvm.functions.Function1 to java.util.function.Function and kotlin.jvm.functions.Function2 to java.util.function.BiFunction. It also removes the dependency on Spring Cloud Function and net.jodah:typetools which are replaced by leveraging Spring Framework ResolvableType capabilities. Added withInputType Kotlin extension - Add a Kotlin extension function for FunctionCallbackWrapper.Builder.withInputType allowing to specify withInputType<T>() instead of withInputType(T::class.java). - Add Kotlin documentation
1 parent 298b71c commit 9cf6633

File tree

13 files changed

+721
-119
lines changed

13 files changed

+721
-119
lines changed

spring-ai-core/pom.xml

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -54,18 +54,6 @@
5454
<version>${jsonschema.version}</version>
5555
</dependency>
5656

57-
<dependency>
58-
<groupId>org.springframework.cloud</groupId>
59-
<artifactId>spring-cloud-function-context</artifactId>
60-
<version>${spring-cloud-function-context.version}</version>
61-
<exclusions>
62-
<exclusion>
63-
<groupId>org.springframework.boot</groupId>
64-
<artifactId>spring-boot-autoconfigure</artifactId>
65-
</exclusion>
66-
</exclusions>
67-
</dependency>
68-
6957
<!-- production dependencies -->
7058
<dependency>
7159
<groupId>org.antlr</groupId>
@@ -138,6 +126,13 @@
138126
<version>${jackson.version}</version>
139127
</dependency>
140128

129+
<dependency>
130+
<groupId>org.jetbrains.kotlin</groupId>
131+
<artifactId>kotlin-stdlib</artifactId>
132+
<version>${kotlin.version}</version>
133+
<optional>true</optional>
134+
</dependency>
135+
141136
<!-- test dependencies -->
142137
<dependency>
143138
<groupId>org.springframework.boot</groupId>
@@ -146,16 +141,16 @@
146141
</dependency>
147142

148143
<dependency>
149-
<groupId>org.jetbrains.kotlin</groupId>
150-
<artifactId>kotlin-stdlib</artifactId>
151-
<version>${kotlin.version}</version>
144+
<groupId>com.fasterxml.jackson.module</groupId>
145+
<artifactId>jackson-module-kotlin</artifactId>
146+
<version>${jackson.version}</version>
152147
<scope>test</scope>
153148
</dependency>
154149

155150
<dependency>
156-
<groupId>com.fasterxml.jackson.module</groupId>
157-
<artifactId>jackson-module-kotlin</artifactId>
158-
<version>${jackson.version}</version>
151+
<groupId>io.mockk</groupId>
152+
<artifactId>mockk-jvm</artifactId>
153+
<version>1.13.13</version>
159154
<scope>test</scope>
160155
</dependency>
161156

spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java

Lines changed: 58 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,22 @@
1616

1717
package org.springframework.ai.model.function;
1818

19-
import java.lang.reflect.Type;
2019
import java.util.function.BiFunction;
2120
import java.util.function.Function;
2221

2322
import com.fasterxml.jackson.annotation.JsonClassDescription;
23+
import kotlin.jvm.functions.Function1;
24+
import kotlin.jvm.functions.Function2;
2425

2526
import org.springframework.ai.chat.model.ToolContext;
2627
import org.springframework.beans.BeansException;
27-
import org.springframework.cloud.function.context.catalog.FunctionTypeUtils;
28-
import org.springframework.cloud.function.context.config.FunctionContextUtils;
28+
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
29+
import org.springframework.beans.factory.config.BeanDefinition;
2930
import org.springframework.context.ApplicationContext;
3031
import org.springframework.context.ApplicationContextAware;
3132
import org.springframework.context.annotation.Description;
3233
import org.springframework.context.support.GenericApplicationContext;
34+
import org.springframework.core.ResolvableType;
3335
import org.springframework.lang.NonNull;
3436
import org.springframework.lang.Nullable;
3537
import org.springframework.util.StringUtils;
@@ -49,6 +51,7 @@
4951
*
5052
* @author Christian Tzolov
5153
* @author Christopher Smith
54+
* @author Sebastien Deleuze
5255
*/
5356
public class FunctionCallbackContext implements ApplicationContextAware {
5457

@@ -68,23 +71,19 @@ public void setApplicationContext(@NonNull ApplicationContext applicationContext
6871
@SuppressWarnings({ "rawtypes", "unchecked" })
6972
public FunctionCallback getFunctionCallback(@NonNull String beanName, @Nullable String defaultDescription) {
7073

71-
Type beanType = FunctionContextUtils.findType(this.applicationContext.getBeanFactory(), beanName);
72-
73-
if (beanType == null) {
74-
throw new IllegalArgumentException(
75-
"Functional bean with name: " + beanName + " does not exist in the context.");
74+
BeanDefinition beanDefinition;
75+
try {
76+
beanDefinition = this.applicationContext.getBeanDefinition(beanName);
7677
}
77-
78-
if (!Function.class.isAssignableFrom(FunctionTypeUtils.getRawType(beanType))
79-
&& !BiFunction.class.isAssignableFrom(FunctionTypeUtils.getRawType(beanType))) {
78+
catch (NoSuchBeanDefinitionException ex) {
8079
throw new IllegalArgumentException(
81-
"Function call Bean must be of type Function or BiFunction. Found: " + beanType.getTypeName());
80+
"Functional bean with name " + beanName + " does not exist in the context.");
8281
}
8382

84-
Type functionInputType = TypeResolverHelper.getFunctionArgumentType(beanType, 0);
83+
ResolvableType functionType = beanDefinition.getResolvableType();
84+
ResolvableType functionInputType = TypeResolverHelper.getFunctionArgumentType(functionType.getType(), 0);
8585

86-
Class<?> functionInputClass = FunctionTypeUtils.getRawType(functionInputType);
87-
String functionName = beanName;
86+
Class<?> functionInputClass = functionInputType.toClass();
8887
String functionDescription = defaultDescription;
8988

9089
if (!StringUtils.hasText(functionDescription)) {
@@ -114,24 +113,40 @@ public FunctionCallback getFunctionCallback(@NonNull String beanName, @Nullable
114113

115114
Object bean = this.applicationContext.getBean(beanName);
116115

117-
if (bean instanceof Function<?, ?> function) {
116+
if (KotlinDelegate.isKotlinFunction(functionType.toClass())) {
117+
return FunctionCallbackWrapper.builder(KotlinDelegate.wrapKotlinFunction(bean))
118+
.withName(beanName)
119+
.withSchemaType(this.schemaType)
120+
.withDescription(functionDescription)
121+
.withInputType(functionInputClass)
122+
.build();
123+
}
124+
else if (KotlinDelegate.isKotlinBiFunction(functionType.toClass())) {
125+
return FunctionCallbackWrapper.builder(KotlinDelegate.wrapKotlinBiFunction(bean))
126+
.withName(beanName)
127+
.withSchemaType(this.schemaType)
128+
.withDescription(functionDescription)
129+
.withInputType(functionInputClass)
130+
.build();
131+
}
132+
else if (bean instanceof Function<?, ?> function) {
118133
return FunctionCallbackWrapper.builder(function)
119-
.withName(functionName)
134+
.withName(beanName)
120135
.withSchemaType(this.schemaType)
121136
.withDescription(functionDescription)
122137
.withInputType(functionInputClass)
123138
.build();
124139
}
125-
else if (bean instanceof BiFunction<?, ?, ?> biFunction) {
126-
return FunctionCallbackWrapper.builder((BiFunction<?, ToolContext, ?>) biFunction)
127-
.withName(functionName)
140+
else if (bean instanceof BiFunction<?, ?, ?>) {
141+
return FunctionCallbackWrapper.builder((BiFunction<?, ToolContext, ?>) bean)
142+
.withName(beanName)
128143
.withSchemaType(this.schemaType)
129144
.withDescription(functionDescription)
130145
.withInputType(functionInputClass)
131146
.build();
132147
}
133148
else {
134-
throw new IllegalArgumentException("Bean must be of type Function");
149+
throw new IllegalStateException();
135150
}
136151
}
137152

@@ -141,4 +156,26 @@ public enum SchemaType {
141156

142157
}
143158

159+
private static class KotlinDelegate {
160+
161+
public static boolean isKotlinFunction(Class<?> clazz) {
162+
return Function1.class.isAssignableFrom(clazz);
163+
}
164+
165+
@SuppressWarnings("unchecked")
166+
public static Function<?, ?> wrapKotlinFunction(Object function) {
167+
return t -> ((Function1<Object, Object>) function).invoke(t);
168+
}
169+
170+
public static boolean isKotlinBiFunction(Class<?> clazz) {
171+
return Function2.class.isAssignableFrom(clazz);
172+
}
173+
174+
@SuppressWarnings("unchecked")
175+
public static BiFunction<?, ToolContext, ?> wrapKotlinBiFunction(Object function) {
176+
return (t, u) -> ((Function2<Object, ToolContext, Object>) function).invoke(t, u);
177+
}
178+
179+
}
180+
144181
}

spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java

Lines changed: 52 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,22 @@
1616

1717
package org.springframework.ai.model.function;
1818

19-
import java.lang.reflect.GenericArrayType;
20-
import java.lang.reflect.ParameterizedType;
2119
import java.lang.reflect.Type;
2220
import java.util.function.BiFunction;
2321
import java.util.function.Function;
2422

25-
import net.jodah.typetools.TypeResolver;
23+
import kotlin.jvm.functions.Function1;
24+
import kotlin.jvm.functions.Function2;
2625

27-
import org.springframework.cloud.function.context.catalog.FunctionTypeUtils;
26+
import org.springframework.core.KotlinDetector;
27+
import org.springframework.core.ResolvableType;
2828

2929
/**
3030
* A utility class that provides methods for resolving types and classes related to
3131
* functions.
3232
*
3333
* @author Christian Tzolov
34+
* @author Sebastien Dekeuze
3435
*/
3536
public abstract class TypeResolverHelper {
3637

@@ -68,12 +69,9 @@ public static Class<?> getFunctionOutputClass(Class<? extends Function<?, ?>> fu
6869
* @return The class of the specified function argument.
6970
*/
7071
public static Class<?> getFunctionArgumentClass(Class<? extends Function<?, ?>> functionClass, int argumentIndex) {
71-
Type type = TypeResolver.reify(Function.class, functionClass);
72-
73-
var argumentType = type instanceof ParameterizedType
74-
? ((ParameterizedType) type).getActualTypeArguments()[argumentIndex] : Object.class;
75-
76-
return toRawClass(argumentType);
72+
ResolvableType resolvableType = ResolvableType.forClass(functionClass).as(Function.class);
73+
return (resolvableType == ResolvableType.NONE ? Object.class
74+
: resolvableType.getGeneric(argumentIndex).toClass());
7775
}
7876

7977
/**
@@ -84,80 +82,65 @@ public static Class<?> getFunctionArgumentClass(Class<? extends Function<?, ?>>
8482
*/
8583
public static Class<?> getBiFunctionArgumentClass(Class<? extends BiFunction<?, ?, ?>> biFunctionClass,
8684
int argumentIndex) {
87-
Type type = TypeResolver.reify(BiFunction.class, biFunctionClass);
88-
89-
Type argumentType = type instanceof ParameterizedType
90-
? ((ParameterizedType) type).getActualTypeArguments()[argumentIndex] : Object.class;
91-
92-
return toRawClass(argumentType);
93-
}
94-
95-
/**
96-
* Returns the input type of a given function class.
97-
* @param functionClass The class of the function.
98-
* @return The input type of the function.
99-
*/
100-
public static Type getFunctionInputType(Class<? extends Function<?, ?>> functionClass) {
101-
return getFunctionArgumentType(functionClass, 0);
102-
}
103-
104-
/**
105-
* Retrieves the output type of a given function class.
106-
* @param functionClass The function class.
107-
* @return The output type of the function.
108-
*/
109-
public static Type getFunctionOutputType(Class<? extends Function<?, ?>> functionClass) {
110-
return getFunctionArgumentType(functionClass, 1);
85+
ResolvableType resolvableType = ResolvableType.forClass(biFunctionClass).as(BiFunction.class);
86+
return (resolvableType == ResolvableType.NONE ? Object.class
87+
: resolvableType.getGeneric(argumentIndex).toClass());
11188
}
11289

11390
/**
11491
* Retrieves the type of a specific argument in a given function class.
115-
* @param functionClass The function class.
116-
* @param argumentIndex The index of the argument whose type should be retrieved.
117-
* @return The type of the specified function argument.
118-
*/
119-
public static Type getFunctionArgumentType(Class<? extends Function<?, ?>> functionClass, int argumentIndex) {
120-
Type functionType = TypeResolver.reify(Function.class, functionClass);
121-
return getFunctionArgumentType(functionType, argumentIndex);
122-
}
123-
124-
/**
125-
* Retrieves the type of a specific argument in a given function type.
12692
* @param functionType The function type.
12793
* @param argumentIndex The index of the argument whose type should be retrieved.
12894
* @return The type of the specified function argument.
95+
* @throws IllegalArgumentException if functionType is not a supported type
12996
*/
130-
public static Type getFunctionArgumentType(Type functionType, int argumentIndex) {
131-
132-
// Resolves: https://github.com/spring-projects/spring-ai/issues/726
133-
if (!(functionType instanceof ParameterizedType)) {
134-
Class<?> functionalClass = FunctionTypeUtils.getRawType(functionType);
135-
// Resolves: https://github.com/spring-projects/spring-ai/issues/1576
136-
if (BiFunction.class.isAssignableFrom(functionalClass)) {
137-
functionType = TypeResolver.reify(BiFunction.class, (Class<BiFunction<?, ?, ?>>) functionalClass);
97+
public static ResolvableType getFunctionArgumentType(Type functionType, int argumentIndex) {
98+
99+
ResolvableType resolvableType = ResolvableType.forType(functionType);
100+
Class<?> resolvableClass = resolvableType.toClass();
101+
ResolvableType functionArgumentResolvableType = ResolvableType.NONE;
102+
103+
if (Function.class.isAssignableFrom(resolvableClass)) {
104+
functionArgumentResolvableType = resolvableType.as(Function.class);
105+
}
106+
else if (BiFunction.class.isAssignableFrom(resolvableClass)) {
107+
functionArgumentResolvableType = resolvableType.as(BiFunction.class);
108+
}
109+
else if (KotlinDetector.isKotlinPresent()) {
110+
if (KotlinDelegate.isKotlinFunction(resolvableClass)) {
111+
functionArgumentResolvableType = KotlinDelegate.adaptToKotlinFunctionType(resolvableType);
138112
}
139-
else {
140-
functionType = FunctionTypeUtils.discoverFunctionTypeFromClass(functionalClass);
113+
else if (KotlinDelegate.isKotlinBiFunction(resolvableClass)) {
114+
functionArgumentResolvableType = KotlinDelegate.adaptToKotlinBiFunctionType(resolvableType);
141115
}
142116
}
143117

144-
var argumentType = functionType instanceof ParameterizedType
145-
? ((ParameterizedType) functionType).getActualTypeArguments()[argumentIndex] : Object.class;
118+
if (functionArgumentResolvableType == ResolvableType.NONE) {
119+
throw new IllegalArgumentException(
120+
"Type must be a Function, BiFunction, Function1 or Function2. Found: " + resolvableType);
121+
}
146122

147-
return argumentType;
123+
return functionArgumentResolvableType.getGeneric(argumentIndex);
148124
}
149125

150-
/**
151-
* Effectively converts {@link Type} which could be {@link ParameterizedType} to raw
152-
* Class (no generics).
153-
* @param type actual {@link Type} instance
154-
* @return instance of {@link Class} as raw representation of the provided
155-
* {@link Type}
156-
*/
157-
public static Class<?> toRawClass(Type type) {
158-
return type != null
159-
? TypeResolver.resolveRawClass(type instanceof GenericArrayType ? type : TypeResolver.reify(type), null)
160-
: null;
126+
private static class KotlinDelegate {
127+
128+
public static boolean isKotlinFunction(Class<?> clazz) {
129+
return Function1.class.isAssignableFrom(clazz);
130+
}
131+
132+
public static ResolvableType adaptToKotlinFunctionType(ResolvableType resolvableType) {
133+
return resolvableType.as(Function1.class);
134+
}
135+
136+
public static boolean isKotlinBiFunction(Class<?> clazz) {
137+
return Function2.class.isAssignableFrom(clazz);
138+
}
139+
140+
public static ResolvableType adaptToKotlinBiFunctionType(ResolvableType resolvableType) {
141+
return resolvableType.as(Function2.class);
142+
}
143+
161144
}
162145

163146
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.model.function
18+
19+
/**
20+
* Extension for [FunctionCallbackWrapper.Builder.withInputType] providing a `withInputType<Foo>()`
21+
* variant.
22+
*
23+
* @author Sebastien Deleuze
24+
*/
25+
inline fun <reified T> FunctionCallbackWrapper.Builder<*, *>.withInputType() =
26+
withInputType(T::class.java)

0 commit comments

Comments
 (0)