diff --git a/langchain4j-spring-boot-starter/pom.xml b/langchain4j-spring-boot-starter/pom.xml index 30440842..7c721efa 100644 --- a/langchain4j-spring-boot-starter/pom.xml +++ b/langchain4j-spring-boot-starter/pom.xml @@ -53,6 +53,13 @@ test + + org.springframework.boot + spring-boot-starter-aop + ${spring.boot.version} + test + + dev.langchain4j langchain4j-core diff --git a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServiceFactory.java b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServiceFactory.java index e2a2bdfc..d5063192 100644 --- a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServiceFactory.java +++ b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServiceFactory.java @@ -1,5 +1,7 @@ package dev.langchain4j.service.spring; +import dev.langchain4j.agent.tool.Tool; +import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.model.chat.ChatLanguageModel; @@ -8,11 +10,20 @@ import dev.langchain4j.rag.RetrievalAugmentor; import dev.langchain4j.rag.content.retriever.ContentRetriever; import dev.langchain4j.service.AiServices; +import dev.langchain4j.service.tool.DefaultToolExecutor; +import dev.langchain4j.service.tool.ToolExecutor; import org.springframework.beans.factory.FactoryBean; +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.HashMap; import java.util.List; +import java.util.Map; +import static dev.langchain4j.agent.tool.ToolSpecifications.toolSpecificationFrom; import static dev.langchain4j.internal.Utils.isNullOrEmpty; +import static org.springframework.aop.framework.AopProxyUtils.ultimateTargetClass; +import static org.springframework.aop.support.AopUtils.isAopProxy; class AiServiceFactory implements FactoryBean { @@ -94,7 +105,13 @@ public Object getObject() { } if (!isNullOrEmpty(tools)) { - builder = builder.tools(tools); + for (Object tool : tools) { + if (isAopProxy(tool)) { + builder = builder.tools(aopEnhancedTools(tool)); + } else { + builder = builder.tools(tool); + } + } } return builder.build(); @@ -120,4 +137,21 @@ public boolean isSingleton() { * (such as java.io.Closeable.close()) will not be called automatically. * Instead, a FactoryBean should implement DisposableBean and delegate any such close call to the underlying object. */ + + private Map aopEnhancedTools(Object enhancedTool) { + Map toolExecutors = new HashMap<>(); + Class originalToolClass = ultimateTargetClass(enhancedTool); + for (Method originalToolMethod : originalToolClass.getDeclaredMethods()) { + if (originalToolMethod.isAnnotationPresent(Tool.class)) { + Arrays.stream(enhancedTool.getClass().getDeclaredMethods()) + .filter(m -> m.getName().equals(originalToolMethod.getName())) + .findFirst() + .ifPresent(enhancedMethod -> { + ToolSpecification toolSpecification = toolSpecificationFrom(originalToolMethod); + toolExecutors.put(toolSpecification, new DefaultToolExecutor(enhancedTool, enhancedMethod)); + }); + } + } + return toolExecutors; + } } diff --git a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServicesAutoConfig.java b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServicesAutoConfig.java index 90567db3..69f6fddf 100644 --- a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServicesAutoConfig.java +++ b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServicesAutoConfig.java @@ -49,7 +49,11 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() { Set tools = new HashSet<>(); for (String beanName : beanFactory.getBeanDefinitionNames()) { try { - Class beanClass = Class.forName(beanFactory.getBeanDefinition(beanName).getBeanClassName()); + String beanClassName = beanFactory.getBeanDefinition(beanName).getBeanClassName(); + if (beanClassName == null) { + continue; + } + Class beanClass = Class.forName(beanClassName); for (Method beanMethod : beanClass.getDeclaredMethods()) { if (beanMethod.isAnnotationPresent(Tool.class)) { tools.add(beanName); diff --git a/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/AiServicesAutoConfigIT.java b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/AiServicesAutoConfigIT.java index 3cd8e1db..49bf50b5 100644 --- a/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/AiServicesAutoConfigIT.java +++ b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/AiServicesAutoConfigIT.java @@ -1,14 +1,22 @@ package dev.langchain4j.service.spring.mode.automatic.withTools; import dev.langchain4j.service.spring.AiServicesAutoConfig; +import dev.langchain4j.service.spring.mode.automatic.withTools.aop.ToolObserverAspect; import org.junit.jupiter.api.Test; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import static dev.langchain4j.service.spring.mode.ApiKeys.OPENAI_API_KEY; +import static dev.langchain4j.service.spring.mode.automatic.withTools.AopEnhancedTools.TOOL_OBSERVER_KEY; +import static dev.langchain4j.service.spring.mode.automatic.withTools.AopEnhancedTools.TOOL_OBSERVER_KEY_NAME_DESCRIPTION; +import static dev.langchain4j.service.spring.mode.automatic.withTools.AopEnhancedTools.TOOL_OBSERVER_PACKAGE_NAME; +import static dev.langchain4j.service.spring.mode.automatic.withTools.AopEnhancedTools.TOOL_OBSERVER_PACKAGE_NAME_DESCRIPTION; import static dev.langchain4j.service.spring.mode.automatic.withTools.PackagePrivateTools.CURRENT_TIME; import static dev.langchain4j.service.spring.mode.automatic.withTools.PublicTools.CURRENT_DATE; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; class AiServicesAutoConfigIT { @@ -61,6 +69,46 @@ void should_create_AI_service_with_tool_that_is_package_private_method_in_packag }); } + @Test + void should_create_AI_service_with_tool_which_is_enhanced_by_spring_aop() { + contextRunner + .withPropertyValues( + "langchain4j.open-ai.chat-model.api-key=" + OPENAI_API_KEY, + "langchain4j.open-ai.chat-model.temperature=0.0", + "langchain4j.open-ai.chat-model.log-requests=true", + "langchain4j.open-ai.chat-model.log-responses=true" + ) + .withUserConfiguration(AiServiceWithToolsApplication.class) + .run(context -> { + + // given + AiServiceWithTools aiService = context.getBean(AiServiceWithTools.class); + + // when + String answer = aiService.chat("Which package is the @ToolObserver annotation located in? " + + "And what is the key of the @ToolObserver annotation?" + + "And What is the current time?"); + + System.out.println("Answer: " + answer); + + // then should use AopEnhancedTools.getAspectPackage() + // & AopEnhancedTools.getToolObserverKey() + // & PackagePrivateTools.getCurrentTime() + assertThat(answer).contains(TOOL_OBSERVER_PACKAGE_NAME); + assertThat(answer).contains(TOOL_OBSERVER_KEY); + assertThat(answer).contains(String.valueOf(CURRENT_TIME.getMinute())); + + // and AOP aspect should be called + // & only for getToolObserverKey() which is annotated with @ToolObserver + ToolObserverAspect aspect = context.getBean(ToolObserverAspect.class); + assertTrue(aspect.aspectHasBeenCalled()); + + assertEquals(1, aspect.getObservedTools().size()); + assertTrue(aspect.getObservedTools().contains(TOOL_OBSERVER_KEY_NAME_DESCRIPTION)); + assertFalse(aspect.getObservedTools().contains(TOOL_OBSERVER_PACKAGE_NAME_DESCRIPTION)); + }); + } + // TODO tools which are not @Beans? // TODO negative cases // TODO no @AiServices in app, just models diff --git a/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/AopEnhancedTools.java b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/AopEnhancedTools.java new file mode 100644 index 00000000..4db600d7 --- /dev/null +++ b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/AopEnhancedTools.java @@ -0,0 +1,28 @@ +package dev.langchain4j.service.spring.mode.automatic.withTools; + +import dev.langchain4j.agent.tool.Tool; +import dev.langchain4j.service.spring.mode.automatic.withTools.aop.ToolObserver; +import org.springframework.stereotype.Component; + +@Component +public class AopEnhancedTools { + + public static final String TOOL_OBSERVER_PACKAGE_NAME_DESCRIPTION = + "Find the package directory where @ToolObserver is located."; + public static final String TOOL_OBSERVER_PACKAGE_NAME = ToolObserver.class.getPackageName(); + + public static final String TOOL_OBSERVER_KEY_NAME_DESCRIPTION = + "Find the key name of @ToolObserver"; + public static final String TOOL_OBSERVER_KEY = "AOP_ENHANCED_TOOLS_SUPPORT_@_1122"; + + @Tool(TOOL_OBSERVER_PACKAGE_NAME_DESCRIPTION) + public String getToolObserverPackageName() { + return TOOL_OBSERVER_PACKAGE_NAME; + } + + @ToolObserver(key = TOOL_OBSERVER_KEY) + @Tool(TOOL_OBSERVER_KEY_NAME_DESCRIPTION) + public String getToolObserverKey() { + return TOOL_OBSERVER_KEY; + } +} diff --git a/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/aop/ToolObserver.java b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/aop/ToolObserver.java new file mode 100644 index 00000000..0c95365f --- /dev/null +++ b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/aop/ToolObserver.java @@ -0,0 +1,18 @@ +package dev.langchain4j.service.spring.mode.automatic.withTools.aop; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Target({ElementType.METHOD}) +@Retention(RetentionPolicy.RUNTIME) +public @interface ToolObserver { + + /** + * key just for example + * + * @return the key + */ + String key(); +} diff --git a/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/aop/ToolObserverAspect.java b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/aop/ToolObserverAspect.java new file mode 100644 index 00000000..68eb013d --- /dev/null +++ b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/aop/ToolObserverAspect.java @@ -0,0 +1,47 @@ +package dev.langchain4j.service.spring.mode.automatic.withTools.aop; + +import dev.langchain4j.agent.tool.Tool; +import org.aspectj.lang.ProceedingJoinPoint; +import org.aspectj.lang.annotation.Around; +import org.aspectj.lang.annotation.Aspect; +import org.aspectj.lang.reflect.MethodSignature; +import org.springframework.stereotype.Component; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +@Aspect +@Component +public class ToolObserverAspect { + + private final List observedTools = new ArrayList<>(); + + @Around("@annotation(toolObserver)") + public Object around(ProceedingJoinPoint joinPoint, ToolObserver toolObserver) throws Throwable { + var signature = (MethodSignature) joinPoint.getSignature(); + var method = signature.getMethod(); + String methodName = method.getName(); + if (method.isAnnotationPresent(Tool.class)) { + Tool toolAnnotation = method.getAnnotation(Tool.class); + observedTools.addAll(Arrays.asList(toolAnnotation.value())); + System.out.printf("Found @Tool %s for method: %s%n%n", Arrays.toString(toolAnnotation.value()), methodName); + } + Object result = joinPoint.proceed(); + System.out.printf(" | key: %s%n | Method name: %s%n | Method arguments: %s%n | Return type: %s%n | Method return value: %s%n%n", + toolObserver.key(), + methodName, + Arrays.toString(joinPoint.getArgs()), + method.getReturnType().getName(), + result); + return result; + } + + public boolean aspectHasBeenCalled() { + return !observedTools.isEmpty(); + } + + public List getObservedTools() { + return observedTools; + } +}