diff --git a/contrib/spring-ai/src/main/java/com/google/adk/models/springai/MessageConverter.java b/contrib/spring-ai/src/main/java/com/google/adk/models/springai/MessageConverter.java index 036a898bb..9339c4178 100644 --- a/contrib/spring-ai/src/main/java/com/google/adk/models/springai/MessageConverter.java +++ b/contrib/spring-ai/src/main/java/com/google/adk/models/springai/MessageConverter.java @@ -31,6 +31,7 @@ import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.messages.ToolResponseMessage.ToolResponse; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; @@ -124,7 +125,8 @@ public Prompt toLlmPrompt(LlmRequest llmRequest) { List toolCallbacks = toolConverter.convertToSpringAiTools(llmRequest.tools()); if (!toolCallbacks.isEmpty()) { // Create new ChatOptions with tools included - ToolCallingChatOptions.Builder optionsBuilder = ToolCallingChatOptions.builder(); + ToolCallingChatOptions.Builder optionsBuilder = + ToolCallingChatOptions.builder().internalToolExecutionEnabled(false); // Always set tool callbacks optionsBuilder.toolCallbacks(toolCallbacks); @@ -204,10 +206,26 @@ private List handleUserContent(Content content) { if (part.text().isPresent()) { textBuilder.append(part.text().get()); } else if (part.functionResponse().isPresent()) { - // TODO: Spring AI 1.1.0 ToolResponseMessage constructors are protected - // For now, we skip tool responses in user messages - // This will need to be addressed in a future update when Spring AI provides - // a public API for creating ToolResponseMessage + var functionResponse = part.functionResponse().get(); + functionResponse + .id() + .ifPresent( + id -> + functionResponse + .name() + .ifPresent( + name -> + functionResponse + .response() + .ifPresent( + response -> + toolResponseMessages.add( + ToolResponseMessage.builder() + .responses( + List.of( + new ToolResponse( + id, name, toJson(response)))) + .build())))); } else if (part.inlineData().isPresent()) { // Handle inline media data (images, audio, video, etc.) com.google.genai.types.Blob blob = part.inlineData().get(); @@ -243,12 +261,11 @@ private List handleUserContent(Content content) { } List messages = new ArrayList<>(); - // Create UserMessage with text - // TODO: Media attachments support - UserMessage constructors with media are private in Spring - // AI 1.1.0 - // For now, only text content is supported - messages.add(new UserMessage(textBuilder.toString())); - messages.addAll(toolResponseMessages); + if (!toolResponseMessages.isEmpty()) { + messages.addAll(toolResponseMessages); + } else { + messages.add(UserMessage.builder().text(textBuilder.toString()).media(mediaList).build()); + } return messages; } diff --git a/contrib/spring-ai/src/main/java/com/google/adk/models/springai/ToolConverter.java b/contrib/spring-ai/src/main/java/com/google/adk/models/springai/ToolConverter.java index 95dafadb4..c40bdc78e 100644 --- a/contrib/spring-ai/src/main/java/com/google/adk/models/springai/ToolConverter.java +++ b/contrib/spring-ai/src/main/java/com/google/adk/models/springai/ToolConverter.java @@ -120,34 +120,8 @@ public List convertToSpringAiTools(Map tools) { FunctionDeclaration declaration = tool.declaration().get(); // Create a ToolCallback that wraps the ADK tool - // Create a Function that takes Map input and calls the ADK tool - java.util.function.Function, String> toolFunction = - args -> { - try { - logger.debug("Spring AI calling tool '{}'", tool.name()); - logger.debug("Raw args from Spring AI: {}", args); - logger.debug("Args type: {}", args.getClass().getName()); - logger.debug("Args keys: {}", args.keySet()); - for (Map.Entry entry : args.entrySet()) { - logger.debug( - " {} -> {} ({})", - entry.getKey(), - entry.getValue(), - entry.getValue().getClass().getName()); - } - - // Handle different argument formats that Spring AI might pass - Map processedArgs = processArguments(args, declaration); - logger.debug("Processed args for ADK: {}", processedArgs); - - // Call the ADK tool and wait for the result - Map result = tool.runAsync(processedArgs, null).blockingGet(); - // Convert result back to JSON string - return new com.fasterxml.jackson.databind.ObjectMapper().writeValueAsString(result); - } catch (Exception e) { - throw new RuntimeException("Tool execution failed: " + e.getMessage(), e); - } - }; + // Create a Function that does nothing. Function calling is done by ADK. + java.util.function.Function, String> toolFunction = args -> ""; FunctionToolCallback.Builder callbackBuilder = FunctionToolCallback.builder(tool.name(), toolFunction).description(tool.description()); @@ -181,54 +155,6 @@ public List convertToSpringAiTools(Map tools) { return toolCallbacks; } - /** - * Process arguments from Spring AI format to ADK format. Spring AI might pass arguments in - * different formats depending on the provider. - */ - private Map processArguments( - Map args, FunctionDeclaration declaration) { - // If the arguments already match the expected format, return as-is - if (declaration.parameters().isPresent()) { - var schema = declaration.parameters().get(); - if (schema.properties().isPresent()) { - var expectedParams = schema.properties().get().keySet(); - - // Check if all expected parameters are present at the top level - boolean allParamsPresent = expectedParams.stream().allMatch(args::containsKey); - if (allParamsPresent) { - return args; - } - - // Check if arguments are nested under a single key (common pattern) - if (args.size() == 1) { - var singleValue = args.values().iterator().next(); - if (singleValue instanceof Map) { - @SuppressWarnings("unchecked") - Map nestedArgs = (Map) singleValue; - boolean allNestedParamsPresent = - expectedParams.stream().allMatch(nestedArgs::containsKey); - if (allNestedParamsPresent) { - return nestedArgs; - } - } - } - - // Check if we have a single parameter function and got a direct value - if (expectedParams.size() == 1) { - String expectedParam = expectedParams.iterator().next(); - if (args.size() == 1 && !args.containsKey(expectedParam)) { - // Try to map the single value to the expected parameter name - Object singleValue = args.values().iterator().next(); - return Map.of(expectedParam, singleValue); - } - } - } - } - - // If no processing worked, return original args and let ADK handle the error - return args; - } - /** Simple metadata holder for tool information. */ public static class ToolMetadata { private final String name; diff --git a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/ToolConverterArgumentProcessingTest.java b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/ToolConverterArgumentProcessingTest.java deleted file mode 100644 index 301a145e0..000000000 --- a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/ToolConverterArgumentProcessingTest.java +++ /dev/null @@ -1,128 +0,0 @@ -/* - * Copyright 2025 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.google.adk.models.springai; - -import static org.assertj.core.api.Assertions.assertThat; - -import com.google.adk.tools.FunctionTool; -import java.lang.reflect.Method; -import java.util.List; -import java.util.Map; -import org.junit.jupiter.api.Test; -import org.springframework.ai.tool.ToolCallback; - -/** Test argument processing logic in ToolConverter. */ -class ToolConverterArgumentProcessingTest { - - @Test - void testArgumentProcessingWithCorrectFormat() throws Exception { - // Create tool converter and tool - ToolConverter converter = new ToolConverter(); - FunctionTool tool = FunctionTool.create(WeatherTools.class, "getWeatherInfo"); - Map tools = Map.of("getWeatherInfo", tool); - - // Convert to Spring AI format - List toolCallbacks = converter.convertToSpringAiTools(tools); - assertThat(toolCallbacks).hasSize(1); - - // Test with correct argument format - ToolCallback callback = toolCallbacks.get(0); - Method processArguments = getProcessArgumentsMethod(converter); - - Map correctArgs = Map.of("location", "San Francisco"); - Map processedArgs = - invokeProcessArguments(processArguments, converter, correctArgs, tool.declaration().get()); - - assertThat(processedArgs).isEqualTo(correctArgs); - } - - @Test - void testArgumentProcessingWithNestedFormat() throws Exception { - ToolConverter converter = new ToolConverter(); - FunctionTool tool = FunctionTool.create(WeatherTools.class, "getWeatherInfo"); - - Method processArguments = getProcessArgumentsMethod(converter); - - // Test with nested arguments - Map nestedArgs = Map.of("args", Map.of("location", "San Francisco")); - Map processedArgs = - invokeProcessArguments(processArguments, converter, nestedArgs, tool.declaration().get()); - - assertThat(processedArgs).containsEntry("location", "San Francisco"); - } - - @Test - void testArgumentProcessingWithDirectValue() throws Exception { - ToolConverter converter = new ToolConverter(); - FunctionTool tool = FunctionTool.create(WeatherTools.class, "getWeatherInfo"); - - Method processArguments = getProcessArgumentsMethod(converter); - - // Test with single direct value (wrong key name) - Map directValueArgs = Map.of("value", "San Francisco"); - Map processedArgs = - invokeProcessArguments( - processArguments, converter, directValueArgs, tool.declaration().get()); - - // Should map the single value to the expected parameter name - assertThat(processedArgs).containsEntry("location", "San Francisco"); - } - - @Test - void testArgumentProcessingWithNoMatch() throws Exception { - ToolConverter converter = new ToolConverter(); - FunctionTool tool = FunctionTool.create(WeatherTools.class, "getWeatherInfo"); - - Method processArguments = getProcessArgumentsMethod(converter); - - // Test with completely wrong format - Map wrongArgs = Map.of("city", "San Francisco", "country", "USA"); - Map processedArgs = - invokeProcessArguments(processArguments, converter, wrongArgs, tool.declaration().get()); - - // Should return original args when no processing applies - assertThat(processedArgs).isEqualTo(wrongArgs); - } - - private Method getProcessArgumentsMethod(ToolConverter converter) throws Exception { - Method method = - ToolConverter.class.getDeclaredMethod( - "processArguments", Map.class, com.google.genai.types.FunctionDeclaration.class); - method.setAccessible(true); - return method; - } - - @SuppressWarnings("unchecked") - private Map invokeProcessArguments( - Method method, - ToolConverter converter, - Map args, - com.google.genai.types.FunctionDeclaration declaration) - throws Exception { - return (Map) method.invoke(converter, args, declaration); - } - - public static class WeatherTools { - public static Map getWeatherInfo(String location) { - return Map.of( - "location", location, - "temperature", "72°F", - "condition", "sunny and clear", - "humidity", "45%", - "forecast", "Perfect weather for outdoor activities!"); - } - } -}