Skip to content

Support for declarative build function callbacks in the Spring container #1105

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@
/**
* A Spring {@link ApplicationContextAware} implementation that provides a way to retrieve
* a {@link Function} from the Spring context and wrap it into a {@link FunctionCallback}.
*
* <p>
* The name of the function is determined by the bean name.
*
* <p>
* The description of the function is determined by the following rules:
* <ul>
* <li>Provided as a default description</li>
Expand Down Expand Up @@ -73,7 +73,13 @@ public FunctionCallback getFunctionCallback(@NonNull String beanName, @Nullable
"Functional bean with name: " + beanName + " does not exist in the context.");
}

if (!Function.class.isAssignableFrom(FunctionTypeUtils.getRawType(beanType))) {
Class<?> beanClass = FunctionTypeUtils.getRawType(beanType);

if (FunctionCallback.class.isAssignableFrom(beanClass)){
return (FunctionCallback) applicationContext.getBean(beanName);
}

if (!Function.class.isAssignableFrom(beanClass)) {
throw new IllegalArgumentException(
"Function call Bean must be of type Function. Found: " + beanType.getTypeName());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
/*
* Copyright 2024 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.model.function;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.aop.framework.autoproxy.AutoProxyUtils;
import org.springframework.aop.scope.ScopedObject;
import org.springframework.aop.scope.ScopedProxyUtils;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanInitializationException;
import org.springframework.beans.factory.SmartInitializingSingleton;
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.core.MethodIntrospector;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;
import org.springframework.stereotype.Component;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.CollectionUtils;
import org.springframework.util.ReflectionUtils;

import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.Collections;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/**
* {@link BeanFactoryPostProcessor} that processes {@link FunctionCalling} annotations.
* <p>
* <p>Any such annotated method is registered as a {@link MethodFunctionCallback} bean in the
* application context.
* <p>
* <p>Processing of {@code @FunctionCalling} annotations can be customized through the
* {@link FunctionCalling} annotation.
* <p>
*
* @see FunctionCalling
* @see MethodFunctionCallback
* @author kamosama
*/
public class FunctionCallbackMethodProcessor
implements SmartInitializingSingleton, BeanFactoryPostProcessor {

protected final Log logger = LogFactory.getLog(getClass());

private final Set<Class<?>> nonAnnotatedClasses = Collections.newSetFromMap(new ConcurrentHashMap<>(64));

@Nullable
private ConfigurableListableBeanFactory beanFactory;


@Override
public void afterSingletonsInstantiated() {
Assert.state(this.beanFactory != null, "No ConfigurableListableBeanFactory set");

String[] beanNames = beanFactory.getBeanNamesForType(Object.class);

for (String beanName : beanNames) {
if (ScopedProxyUtils.isScopedTarget(beanName)) {
continue;
}
Class<?> type = null;
try {
type = AutoProxyUtils.determineTargetClass(beanFactory, beanName);
} catch (Throwable ex) {
// An unresolvable bean type, probably from a lazy bean - let's ignore it.
if (logger.isDebugEnabled()) {
logger.debug("Could not resolve target class for bean with name '" + beanName + "'", ex);
}
}
if (type == null) {
continue;
}
if (ScopedObject.class.isAssignableFrom(type)) {
try {
Class<?> targetClass = AutoProxyUtils.determineTargetClass(
beanFactory, ScopedProxyUtils.getTargetBeanName(beanName));
if (targetClass != null) {
type = targetClass;
}
} catch (Throwable ex) {
// An invalid scoped proxy arrangement - let's ignore it.
if (logger.isDebugEnabled()) {
logger.debug("Could not resolve target bean for scoped proxy '" + beanName + "'", ex);
}
}
}
try {
processBean(beanName, type);
} catch (Throwable ex) {
throw new BeanInitializationException("Failed to process @FunctionCalling " +
"annotation on bean with name '" + beanName + "'", ex);
}
}
}

private void processBean(final String beanName, final Class<?> targetType) {
Assert.state(this.beanFactory != null, "No ConfigurableListableBeanFactory set");

if (!this.nonAnnotatedClasses.contains(targetType)
&& AnnotationUtils.isCandidateClass(targetType, FunctionCalling.class)
&& !isSpringContainerClass(targetType)
) {

Map<Method, FunctionCalling> annotatedMethods = null;
try {
annotatedMethods = MethodIntrospector.selectMethods(targetType,
(MethodIntrospector.MetadataLookup<FunctionCalling>) method ->
AnnotatedElementUtils.findMergedAnnotation(method, FunctionCalling.class));
} catch (Throwable ex) {
// An unresolvable type in a method signature, probably from a lazy bean - let's ignore it.
if (logger.isDebugEnabled()) {
logger.debug("Could not resolve methods for bean with name '" + beanName + "'", ex);
}
}

if (CollectionUtils.isEmpty(annotatedMethods)) {
this.nonAnnotatedClasses.add(targetType);
if (logger.isTraceEnabled()) {
logger.trace("No @FunctionCalling annotations found on bean class: " + targetType.getName());
}
} else {
// Non-empty set of methods
annotatedMethods.forEach((method, annotation) -> {
String name = annotation.name().isEmpty() ? method.getName() : annotation.name();
ReflectionUtils.makeAccessible(method);
var functionObject = Modifier.isStatic(method.getModifiers()) ? null : beanFactory.getBean(beanName);
MethodFunctionCallback callback = MethodFunctionCallback.builder()
.withFunctionObject(functionObject)
.withMethod(method)
.withDescription(annotation.description())
.build();
beanFactory.registerSingleton(name, callback);
});

if (logger.isDebugEnabled()) {
logger.debug(annotatedMethods.size() + " @FunctionCalling methods processed on bean '" +
beanName + "': " + annotatedMethods);
}
}
}
}

/**
* Determine whether the given class is an {@code org.springframework}
* bean class that is not annotated as a user or test {@link Component}...
* which indicates that there is no {@link FunctionCalling} to be found there.
*/
private static boolean isSpringContainerClass(Class<?> clazz) {
return (clazz.getName().startsWith("org.springframework.") &&
!AnnotatedElementUtils.isAnnotated(ClassUtils.getUserClass(clazz), Component.class));
}

@Override
public void postProcessBeanFactory(@NonNull ConfigurableListableBeanFactory beanFactory) throws BeansException {
this.beanFactory = beanFactory;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright 2024 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.model.function;

import java.lang.annotation.*;

/**
* Annotation to indicate that a method is a AI function calling.
*
* @see FunctionCallbackMethodProcessor
* @author kamosama
*/
@Target({ElementType.METHOD, ElementType.ANNOTATION_TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface FunctionCalling {

String name() default "";

String description();

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package org.springframework.ai.model.function;

import org.springframework.context.annotation.Configuration;

import java.time.LocalDateTime;

@Configuration
public class FunctionCallConfig {

@FunctionCalling(name = "dateTime", description = "get the current date and time")
public String dateTime(String location) {
return location + " dateTime:" + LocalDateTime.now();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Copyright 2024 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.model.function;

import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;

public class FunctionCallbackMethodProcessorIT {

private static final Logger logger = LoggerFactory.getLogger(FunctionCallbackMethodProcessorIT.class);


private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
.withUserConfiguration(FunctionCallConfig.class)
.withBean(FunctionCallbackContext.class)
.withBean(FunctionCallbackMethodProcessor.class);

@Test
public void testFunctionCallbackMethodProcessor() {
contextRunner.run(context -> {
FunctionCallbackContext callbackContext = context.getBean(FunctionCallbackContext.class);
FunctionCallback functionCallback = callbackContext.getFunctionCallback("dateTime", null);
logger.info("FunctionCallback: name:{}, description:{}",
functionCallback.getName(), functionCallback.getDescription());
String result = functionCallback.call("{\"location\":\"New York\"}");
logger.info("Result: {}", result);
assert result.contains("New York");
});
}

}
Loading