Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEATURE] tools enhanced by AOP support #80

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions langchain4j-spring-boot-starter/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
<version>${spring.boot.version}</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package dev.langchain4j.service.spring;

import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
Expand All @@ -8,11 +10,20 @@
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.tool.DefaultToolExecutor;
import dev.langchain4j.service.tool.ToolExecutor;
import org.springframework.beans.factory.FactoryBean;

import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static dev.langchain4j.agent.tool.ToolSpecifications.toolSpecificationFrom;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static org.springframework.aop.framework.AopProxyUtils.ultimateTargetClass;
import static org.springframework.aop.support.AopUtils.isAopProxy;

class AiServiceFactory implements FactoryBean<Object> {

Expand Down Expand Up @@ -94,7 +105,13 @@ public Object getObject() {
}

if (!isNullOrEmpty(tools)) {
builder = builder.tools(tools);
for (Object tool : tools) {
if (isAopProxy(tool)) {
builder = builder.tools(aopEnhancedTools(tool));
} else {
builder = builder.tools(tool);
}
}
}

return builder.build();
Expand All @@ -120,4 +137,21 @@ public boolean isSingleton() {
* (such as java.io.Closeable.close()) will not be called automatically.
* Instead, a FactoryBean should implement DisposableBean and delegate any such close call to the underlying object.
*/

private Map<ToolSpecification, ToolExecutor> aopEnhancedTools(Object enhancedTool) {
Map<ToolSpecification, ToolExecutor> toolExecutors = new HashMap<>();
Class<?> originalToolClass = ultimateTargetClass(enhancedTool);
for (Method originalToolMethod : originalToolClass.getDeclaredMethods()) {
if (originalToolMethod.isAnnotationPresent(Tool.class)) {
Arrays.stream(enhancedTool.getClass().getDeclaredMethods())
.filter(m -> m.getName().equals(originalToolMethod.getName()))
.findFirst()
.ifPresent(enhancedMethod -> {
ToolSpecification toolSpecification = toolSpecificationFrom(originalToolMethod);
toolExecutors.put(toolSpecification, new DefaultToolExecutor(enhancedTool, enhancedMethod));
});
}
}
return toolExecutors;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,11 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() {
Set<String> tools = new HashSet<>();
for (String beanName : beanFactory.getBeanDefinitionNames()) {
try {
Class<?> beanClass = Class.forName(beanFactory.getBeanDefinition(beanName).getBeanClassName());
String beanClassName = beanFactory.getBeanDefinition(beanName).getBeanClassName();
if (beanClassName == null) {
continue;
langchain4j marked this conversation as resolved.
Show resolved Hide resolved
}
Class<?> beanClass = Class.forName(beanClassName);
for (Method beanMethod : beanClass.getDeclaredMethods()) {
if (beanMethod.isAnnotationPresent(Tool.class)) {
tools.add(beanName);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
package dev.langchain4j.service.spring.mode.automatic.withTools;

import dev.langchain4j.service.spring.AiServicesAutoConfig;
import dev.langchain4j.service.spring.mode.automatic.withTools.aop.ToolObserverAspect;
import org.junit.jupiter.api.Test;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;

import static dev.langchain4j.service.spring.mode.ApiKeys.OPENAI_API_KEY;
import static dev.langchain4j.service.spring.mode.automatic.withTools.AopEnhancedTools.TOOL_OBSERVER_KEY;
import static dev.langchain4j.service.spring.mode.automatic.withTools.AopEnhancedTools.TOOL_OBSERVER_KEY_NAME_DESCRIPTION;
import static dev.langchain4j.service.spring.mode.automatic.withTools.AopEnhancedTools.TOOL_OBSERVER_PACKAGE_NAME;
import static dev.langchain4j.service.spring.mode.automatic.withTools.AopEnhancedTools.TOOL_OBSERVER_PACKAGE_NAME_DESCRIPTION;
import static dev.langchain4j.service.spring.mode.automatic.withTools.PackagePrivateTools.CURRENT_TIME;
import static dev.langchain4j.service.spring.mode.automatic.withTools.PublicTools.CURRENT_DATE;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;

class AiServicesAutoConfigIT {

Expand Down Expand Up @@ -61,6 +69,46 @@ void should_create_AI_service_with_tool_that_is_package_private_method_in_packag
});
}

@Test
void should_create_AI_service_with_tool_which_is_enhanced_by_spring_aop() {
contextRunner
.withPropertyValues(
"langchain4j.open-ai.chat-model.api-key=" + OPENAI_API_KEY,
"langchain4j.open-ai.chat-model.temperature=0.0",
"langchain4j.open-ai.chat-model.log-requests=true",
"langchain4j.open-ai.chat-model.log-responses=true"
)
.withUserConfiguration(AiServiceWithToolsApplication.class)
.run(context -> {

// given
AiServiceWithTools aiService = context.getBean(AiServiceWithTools.class);

// when
String answer = aiService.chat("Which package is the @ToolObserver annotation located in? " +
"And what is the key of the @ToolObserver annotation?" +
"And What is the current time?");

System.out.println("Answer: " + answer);

// then should use AopEnhancedTools.getAspectPackage()
// & AopEnhancedTools.getToolObserverKey()
// & PackagePrivateTools.getCurrentTime()
assertThat(answer).contains(TOOL_OBSERVER_PACKAGE_NAME);
assertThat(answer).contains(TOOL_OBSERVER_KEY);
assertThat(answer).contains(String.valueOf(CURRENT_TIME.getMinute()));

// and AOP aspect should be called
// & only for getToolObserverKey() which is annotated with @ToolObserver
ToolObserverAspect aspect = context.getBean(ToolObserverAspect.class);
assertTrue(aspect.aspectHasBeenCalled());

assertEquals(1, aspect.getObservedTools().size());
assertTrue(aspect.getObservedTools().contains(TOOL_OBSERVER_KEY_NAME_DESCRIPTION));
assertFalse(aspect.getObservedTools().contains(TOOL_OBSERVER_PACKAGE_NAME_DESCRIPTION));
});
}

// TODO tools which are not @Beans?
// TODO negative cases
// TODO no @AiServices in app, just models
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package dev.langchain4j.service.spring.mode.automatic.withTools;

import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.service.spring.mode.automatic.withTools.aop.ToolObserver;
import org.springframework.stereotype.Component;

@Component
public class AopEnhancedTools {

public static final String TOOL_OBSERVER_PACKAGE_NAME_DESCRIPTION =
"Find the package directory where @ToolObserver is located.";
public static final String TOOL_OBSERVER_PACKAGE_NAME = ToolObserver.class.getPackageName();

public static final String TOOL_OBSERVER_KEY_NAME_DESCRIPTION =
"Find the key name of @ToolObserver";
public static final String TOOL_OBSERVER_KEY = "AOP_ENHANCED_TOOLS_SUPPORT_@_1122";

@Tool(TOOL_OBSERVER_PACKAGE_NAME_DESCRIPTION)
public String getToolObserverPackageName() {
return TOOL_OBSERVER_PACKAGE_NAME;
}

@ToolObserver(key = TOOL_OBSERVER_KEY)
@Tool(TOOL_OBSERVER_KEY_NAME_DESCRIPTION)
public String getToolObserverKey() {
return TOOL_OBSERVER_KEY;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package dev.langchain4j.service.spring.mode.automatic.withTools.aop;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface ToolObserver {

/**
* key just for example
*
* @return the key
*/
String key();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package dev.langchain4j.service.spring.mode.automatic.withTools.aop;

import dev.langchain4j.agent.tool.Tool;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.stereotype.Component;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

@Aspect
@Component
public class ToolObserverAspect {

private final List<String> observedTools = new ArrayList<>();

@Around("@annotation(toolObserver)")
public Object around(ProceedingJoinPoint joinPoint, ToolObserver toolObserver) throws Throwable {
var signature = (MethodSignature) joinPoint.getSignature();
var method = signature.getMethod();
String methodName = method.getName();
if (method.isAnnotationPresent(Tool.class)) {
Tool toolAnnotation = method.getAnnotation(Tool.class);
observedTools.addAll(Arrays.asList(toolAnnotation.value()));
System.out.printf("Found @Tool %s for method: %s%n%n", Arrays.toString(toolAnnotation.value()), methodName);
}
Object result = joinPoint.proceed();
System.out.printf(" | key: %s%n | Method name: %s%n | Method arguments: %s%n | Return type: %s%n | Method return value: %s%n%n",
toolObserver.key(),
methodName,
Arrays.toString(joinPoint.getArgs()),
method.getReturnType().getName(),
result);
return result;
}

public boolean aspectHasBeenCalled() {
return !observedTools.isEmpty();
}

public List<String> getObservedTools() {
return observedTools;
}
}