diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotations2IT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotations2IT.java new file mode 100644 index 00000000000..808ae9ddd39 --- /dev/null +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotations2IT.java @@ -0,0 +1,488 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.mcp.server.autoconfigure; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.server.McpSyncServer; +import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.CompleteRequest; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; +import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; +import io.modelcontextprotocol.spec.McpSchema.ProgressNotification; +import io.modelcontextprotocol.spec.McpSchema.PromptMessage; +import io.modelcontextprotocol.spec.McpSchema.PromptReference; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceRequest; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import net.javacrumbs.jsonunit.assertj.JsonAssertions; +import net.javacrumbs.jsonunit.core.Option; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import 
org.slf4j.LoggerFactory; +import org.springaicommunity.mcp.annotation.McpArg; +import org.springaicommunity.mcp.annotation.McpComplete; +import org.springaicommunity.mcp.annotation.McpElicitation; +import org.springaicommunity.mcp.annotation.McpLogging; +import org.springaicommunity.mcp.annotation.McpMeta; +import org.springaicommunity.mcp.annotation.McpProgress; +import org.springaicommunity.mcp.annotation.McpPrompt; +import org.springaicommunity.mcp.annotation.McpResource; +import org.springaicommunity.mcp.annotation.McpSampling; +import org.springaicommunity.mcp.annotation.McpTool; +import org.springaicommunity.mcp.annotation.McpToolParam; +import org.springaicommunity.mcp.context.McpSyncRequestContext; +import org.springaicommunity.mcp.context.StructuredElicitResult; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +import org.springframework.ai.mcp.client.common.autoconfigure.McpClientAutoConfiguration; +import org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration; +import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration; +import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientSpecificationFactoryAutoConfiguration; +import org.springframework.ai.mcp.client.webflux.autoconfigure.StreamableHttpWebFluxTransportAutoConfiguration; +import org.springframework.ai.mcp.server.common.autoconfigure.McpServerAutoConfiguration; +import org.springframework.ai.mcp.server.common.autoconfigure.ToolCallbackConverterAutoConfiguration; +import org.springframework.ai.mcp.server.common.autoconfigure.annotations.McpServerAnnotationScannerAutoConfiguration; +import org.springframework.ai.mcp.server.common.autoconfigure.annotations.McpServerSpecificationFactoryAutoConfiguration; +import org.springframework.ai.mcp.server.common.autoconfigure.properties.McpServerProperties; +import 
org.springframework.ai.mcp.server.common.autoconfigure.properties.McpServerStreamableHttpProperties; +import org.springframework.beans.factory.ObjectProvider; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.core.ResolvableType; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.test.util.TestSocketUtils; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.RouterFunctions; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.InstanceOfAssertFactories.map; + +public class StreamableMcpAnnotations2IT { + + private final ApplicationContextRunner serverContextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.mcp.server.protocol=STREAMABLE") + .withConfiguration(AutoConfigurations.of(McpServerAutoConfiguration.class, + ToolCallbackConverterAutoConfiguration.class, McpServerStreamableHttpWebFluxAutoConfiguration.class, + McpServerAnnotationScannerAutoConfiguration.class, + McpServerSpecificationFactoryAutoConfiguration.class)); + + private final ApplicationContextRunner clientApplicationContext = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(McpToolCallbackAutoConfiguration.class, + McpClientAutoConfiguration.class, StreamableHttpWebFluxTransportAutoConfiguration.class, + McpClientAnnotationScannerAutoConfiguration.class, + McpClientSpecificationFactoryAutoConfiguration.class)); + + @Test + void clientServerCapabilities() { + + int serverPort = TestSocketUtils.findAvailableTcpPort(); + + 
this.serverContextRunner.withUserConfiguration(TestMcpServerConfiguration.class) + .withPropertyValues(// @formatter:off + "spring.ai.mcp.server.name=test-mcp-server", + // "spring.ai.mcp.server.type=ASYNC", + // "spring.ai.mcp.server.protocol=SSE", + "spring.ai.mcp.server.version=1.0.0", + "spring.ai.mcp.server.streamable-http.keep-alive-interval=1s", + // "spring.ai.mcp.server.requestTimeout=1m", + "spring.ai.mcp.server.streamable-http.mcp-endpoint=/mcp") // @formatter:on + .run(serverContext -> { + // Verify all required beans are present + assertThat(serverContext).hasSingleBean(WebFluxStreamableServerTransportProvider.class); + assertThat(serverContext).hasSingleBean(RouterFunction.class); + assertThat(serverContext).hasSingleBean(McpSyncServer.class); + + // Verify server properties are configured correctly + McpServerProperties properties = serverContext.getBean(McpServerProperties.class); + assertThat(properties.getName()).isEqualTo("test-mcp-server"); + assertThat(properties.getVersion()).isEqualTo("1.0.0"); + + McpServerStreamableHttpProperties streamableHttpProperties = serverContext + .getBean(McpServerStreamableHttpProperties.class); + assertThat(streamableHttpProperties.getMcpEndpoint()).isEqualTo("/mcp"); + assertThat(streamableHttpProperties.getKeepAliveInterval()).isEqualTo(Duration.ofSeconds(1)); + + var httpServer = startHttpServer(serverContext, serverPort); + + // Stop the manually started Netty server even if a client-side assertion fails + try { + this.clientApplicationContext.withUserConfiguration(TestMcpClientConfiguration.class) + .withPropertyValues(// @formatter:off + "spring.ai.mcp.client.streamable-http.connections.server1.url=http://localhost:" + serverPort, + // "spring.ai.mcp.client.sse.connections.server1.url=http://localhost:" + serverPort, + // "spring.ai.mcp.client.request-timeout=20m", + "spring.ai.mcp.client.initialized=false") // @formatter:on + .run(clientContext -> { + McpSyncClient mcpClient = getMcpSyncClient(clientContext); + assertThat(mcpClient).isNotNull(); + var initResult = mcpClient.initialize(); +
assertThat(initResult).isNotNull(); + + // TOOLS / SAMPLING / ELICITATION + + // tool list + assertThat(mcpClient.listTools().tools()).hasSize(2); + + // Call a tool that sends progress notifications + CallToolRequest toolRequest = CallToolRequest.builder() + .name("tool1") + .arguments(Map.of()) + .progressToken("test-progress-token") + .build(); + + CallToolResult response = mcpClient.callTool(toolRequest); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + String responseText = ((TextContent) response.content().get(0)).text(); + assertThat(responseText).contains("CALL RESPONSE"); + assertThat(responseText).contains("Response Test Sampling Message with model hint OpenAi"); + assertThat(responseText).contains("ElicitResult"); + + // PROGRESS + TestMcpClientConfiguration.TestContext testContext = clientContext + .getBean(TestMcpClientConfiguration.TestContext.class); + assertThat(testContext.progressLatch.await(5, TimeUnit.SECONDS)) + .as("Should receive progress notifications in reasonable time") + .isTrue(); + assertThat(testContext.progressNotifications).hasSize(3); + + // Index the notifications by message so the checks do not depend on arrival order + Map<String, ProgressNotification> notificationMap = testContext.progressNotifications + .stream() + .collect(Collectors.toMap(n -> n.message(), n -> n)); + + // First notification should be 0.0/1.0 progress + assertThat(notificationMap.get("tool call start").progressToken()) + .isEqualTo("test-progress-token"); + assertThat(notificationMap.get("tool call start").progress()).isEqualTo(0.0); + assertThat(notificationMap.get("tool call start").total()).isEqualTo(1.0); + assertThat(notificationMap.get("tool call start").message()).isEqualTo("tool call start"); + + // Second notification should be 0.5/1.0 progress + assertThat(notificationMap.get("elicitation completed").progressToken()) + .isEqualTo("test-progress-token"); + assertThat(notificationMap.get("elicitation completed").progress()).isEqualTo(0.5); + assertThat(notificationMap.get("elicitation completed").total()).isEqualTo(1.0); +
assertThat(notificationMap.get("elicitation completed").message()) + .isEqualTo("elicitation completed"); + + // Third notification should be 1.0/1.0 progress + assertThat(notificationMap.get("sampling completed").progressToken()) + .isEqualTo("test-progress-token"); + assertThat(notificationMap.get("sampling completed").progress()).isEqualTo(1.0); + assertThat(notificationMap.get("sampling completed").total()).isEqualTo(1.0); + assertThat(notificationMap.get("sampling completed").message()).isEqualTo("sampling completed"); + + // TOOL STRUCTURED OUTPUT + // Call tool with valid structured output + CallToolResult calculatorToolResponse = mcpClient.callTool(new McpSchema.CallToolRequest( + "calculator", Map.of("expression", "2 + 3"), Map.of("meta1", "value1"))); + + assertThat(calculatorToolResponse).isNotNull(); + assertThat(calculatorToolResponse.isError()).isFalse(); + + assertThat(calculatorToolResponse.structuredContent()).isNotNull(); + + assertThat(calculatorToolResponse.structuredContent()) + .asInstanceOf(map(String.class, Object.class)) + .containsEntry("result", 5.0) + .containsEntry("operation", "2 + 3") + .containsEntry("timestamp", "2024-01-01T10:00:00Z"); + + JsonAssertions.assertThatJson(calculatorToolResponse.structuredContent()) + .when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(JsonAssertions.json(""" + {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); + + assertThat(calculatorToolResponse.meta()).containsEntry("meta1Response", "value1"); + + // RESOURCES + assertThat(mcpClient.listResources()).isNotNull(); + assertThat(mcpClient.listResources().resources()).hasSize(1); + assertThat(mcpClient.listResources().resources().get(0)) + .isEqualToComparingFieldByFieldRecursively(Resource.builder() + .uri("file://resource") + .name("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build()); + + // PROMPT / COMPLETION + + // list
prompts + assertThat(mcpClient.listPrompts()).isNotNull(); + assertThat(mcpClient.listPrompts().prompts()).hasSize(1); + + // get prompt + GetPromptResult promptResult = mcpClient + .getPrompt(new GetPromptRequest("code-completion", Map.of("language", "java"))); + assertThat(promptResult).isNotNull(); + var logMessage = testContext.loggingNotificationRef.get(); + assertThat(logMessage).isNotNull(); + assertThat(logMessage.level()).isEqualTo(LoggingLevel.INFO); + assertThat(logMessage.logger()).isEqualTo("test-logger"); + assertThat(logMessage.data()).contains("Hello java! How can I assist you today?"); + + // completion + CompleteRequest completeRequest = new CompleteRequest( + new PromptReference("ref/prompt", "code-completion", "Code completion"), + new CompleteRequest.CompleteArgument("language", "py")); + + CompleteResult completeResult = mcpClient.completeCompletion(completeRequest); + + assertThat(completeResult).isNotNull(); + assertThat(completeResult.completion().total()).isEqualTo(10); + assertThat(completeResult.completion().values()).containsExactly("python", "pytorch", "pyside"); + assertThat(completeResult.meta()).isNull(); + + // logging message + logMessage = testContext.loggingNotificationRef.get(); + assertThat(logMessage).isNotNull(); + assertThat(logMessage.level()).isEqualTo(LoggingLevel.INFO); + assertThat(logMessage.logger()).isEqualTo("server"); + assertThat(logMessage.data()).contains("Code completion requested"); + + }); + } + finally { + stopHttpServer(httpServer); + } + }); + } + + // Helper methods to start and stop the HTTP server + private static DisposableServer startHttpServer(ApplicationContext serverContext, int port) { + WebFluxStreamableServerTransportProvider mcpStreamableServerTransport = serverContext + .getBean(WebFluxStreamableServerTransportProvider.class); + HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpStreamableServerTransport.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new
ReactorHttpHandlerAdapter(httpHandler); + return HttpServer.create().port(port).handle(adapter).bindNow(); + } + + private static void stopHttpServer(DisposableServer server) { + if (server != null) { + server.disposeNow(); + } + } + + // Helper method to get the MCP sync client: looks up the List<McpSyncClient> bean + // and returns its first element + private static McpSyncClient getMcpSyncClient(ApplicationContext clientContext) { + ObjectProvider<List<McpSyncClient>> mcpClients = clientContext + .getBeanProvider(ResolvableType.forClassWithGenerics(List.class, McpSyncClient.class)); + return mcpClients.getIfAvailable().get(0); + } + + record ElicitInput(String message) { + } + + public static class TestMcpServerConfiguration { + + @Bean + public McpServerHandlers serverSideSpecProviders() { + return new McpServerHandlers(); + } + + public static class McpServerHandlers { + + @McpTool(description = "Test tool", name = "tool1") + public String toolWithSamplingAndElicitation(McpSyncRequestContext ctx, @McpToolParam String input) { + + ctx.info("Tool1 Started!"); + + ctx.progress(p -> p.progress(0.0).total(1.0).message("tool call start")); + + ctx.ping(); // call client ping + + // call elicitation + var elicitationResult = ctx.elicit(e -> e.message("Test message"), ElicitInput.class); + + ctx.progress(p -> p.progress(0.50).total(1.0).message("elicitation completed")); + + // call sampling + CreateMessageResult samplingResponse = ctx.sample(s -> s.message("Test Sampling Message") + .modelPreferences(pref -> pref.modelHints("OpenAi", "Ollama") + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0))); + + ctx.progress(p -> p.progress(1.0).total(1.0).message("sampling completed")); + + ctx.info("Tool1 Done!"); + + return "CALL RESPONSE: " + samplingResponse.toString() + ", " + elicitationResult.toString(); + } + + @McpTool(name = "calculator", description = "Performs mathematical calculations") + public CallToolResult calculator(@McpToolParam String expression, McpMeta meta) { + double result = evaluateExpression(expression); + return
CallToolResult.builder() + .structuredContent( + Map.of("result", result, "operation", expression, "timestamp", "2024-01-01T10:00:00Z")) + .meta(Map.of("meta1Response", meta.get("meta1"))) // echo the request's "meta1" entry back in the result meta + .build(); + } + + private static double evaluateExpression(String expression) { + // Simple expression evaluator for testing: a fixed lookup of the literal + // expressions below; any other input evaluates to 0.0 + return switch (expression) { + case "2 + 3" -> 5.0; + case "10 * 2" -> 20.0; + case "7 + 8" -> 15.0; + case "5 + 3" -> 8.0; + default -> 0.0; + }; + } + + // Resource handler: pings the client, then returns three JVM system properties + // serialised as a JSON text resource for the requested URI + @McpResource(name = "Test Resource", uri = "file://resource", mimeType = "text/plain", + description = "Test resource description") + public ReadResourceResult testResource(McpSyncRequestContext ctx, ReadResourceRequest request) { + + ctx.ping(); + + try { + var systemInfo = Map.of("os", System.getProperty("os.name"), "os_version", + System.getProperty("os.version"), "java_version", System.getProperty("java.version")); + String jsonContent = new ObjectMapper().writeValueAsString(systemInfo); + return new ReadResourceResult(List + .of(new McpSchema.TextResourceContents(request.uri(), "application/json", jsonContent))); + } + catch (Exception e) { + throw new RuntimeException("Failed to generate system info", e); + } + } + + @McpPrompt(name = "code-completion", description = "this is code review prompt") + public GetPromptResult codeCompletionPrompt(McpSyncRequestContext ctx, + @McpArg(name = "language", required = false) String languageArgument) { + + // Greet with the requested language, defaulting to "java" when the argument is null + String message = "Hello " + ((languageArgument == null) ? "java" : languageArgument) + + "!
How can I assist you today?"; + + ctx.log(l -> l.logger("test-logger").message(message)); + + var userMessage = new PromptMessage(Role.USER, new TextContent(message)); + + return new GetPromptResult("A personalized greeting message", List.of(userMessage)); + } + + // "code-completion" is the name of the prompt this completion handler is attached to + @McpComplete(prompt = "code-completion") + public CompleteResult codeCompletion(McpSyncRequestContext ctx) { + ctx.info("Code completion requested"); + var expectedValues = List.of("python", "pytorch", "pyside"); + return new McpSchema.CompleteResult(new CompleteResult.CompleteCompletion(expectedValues, 10, // total + true // hasMore + )); + } + + } + + } + + public static class TestMcpClientConfiguration { + + @Bean + public TestContext testContext() { + return new TestContext(); + } + + @Bean + public TestMcpClientHandlers mcpClientHandlers(TestContext testContext) { + return new TestMcpClientHandlers(testContext); + } + + // State written by the client handlers below and read by the test's assertions + public static class TestContext { + + final AtomicReference<LoggingMessageNotification> loggingNotificationRef = new AtomicReference<>(); + + final CountDownLatch progressLatch = new CountDownLatch(3); + + final List<ProgressNotification> progressNotifications = new CopyOnWriteArrayList<>(); + + } + + public static class TestMcpClientHandlers { + + private static final Logger logger = LoggerFactory.getLogger(TestMcpClientHandlers.class); + + private final TestContext testContext; + + public TestMcpClientHandlers(TestContext testContext) { + this.testContext = testContext; + } + + @McpProgress(clients = "server1") + public void progressHandler(ProgressNotification progressNotification) { + logger.info("MCP PROGRESS: [{}] progress: {} total: {} message: {}", + progressNotification.progressToken(), progressNotification.progress(), + progressNotification.total(), progressNotification.message()); + this.testContext.progressNotifications.add(progressNotification); + this.testContext.progressLatch.countDown(); + } + + @McpLogging(clients = "server1") + public void
loggingHandler(LoggingMessageNotification loggingMessage) { + this.testContext.loggingNotificationRef.set(loggingMessage); + logger.info("MCP LOGGING: [{}] {}", loggingMessage.level(), loggingMessage.data()); + } + + @McpSampling(clients = "server1") + public CreateMessageResult samplingHandler(CreateMessageRequest llmRequest) { + logger.info("MCP SAMPLING: {}", llmRequest); + + String userPrompt = ((McpSchema.TextContent) llmRequest.messages().get(0).content()).text(); + String modelHint = llmRequest.modelPreferences().hints().get(0).name(); + + return CreateMessageResult.builder() + .content(new McpSchema.TextContent("Response " + userPrompt + " with model hint " + modelHint)) + .build(); + } + + @McpElicitation(clients = "server1") + public StructuredElicitResult elicitationHandler(McpSchema.ElicitRequest request) { + logger.info("MCP ELICITATION: {}", request); + ElicitInput elicitData = new ElicitInput(request.message()); + return StructuredElicitResult.builder().structuredContent(elicitData).build(); + } + + } + + } + +} diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/mcp/mcp-annotations-examples.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/mcp/mcp-annotations-examples.adoc index 6afc272c794..a623f9d230c 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/mcp/mcp-annotations-examples.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/mcp/mcp-annotations-examples.adoc @@ -55,15 +55,13 @@ public class CalculatorTools { description = "Calculate a complex mathematical expression") public CallToolResult calculateExpression( CallToolRequest request, - McpSyncServerExchange exchange) { + McpSyncRequestContext context) { Map args = request.arguments(); String expression = (String) args.get("expression"); - exchange.loggingNotification(LoggingMessageNotification.builder() - .level(LoggingLevel.INFO) - .data("Calculating: " + expression) - .build()); + // Use convenient logging method + 
context.info("Calculating: " + expression); try { double result = evaluateExpression(expression); @@ -141,19 +139,20 @@ public class DocumentServer { @McpTool(name = "analyze-document", description = "Analyze document content") public String analyzeDocument( - @McpProgressToken String progressToken, + McpSyncRequestContext context, @McpToolParam(description = "Document ID", required = true) String docId, - @McpToolParam(description = "Analysis type", required = false) String type, - McpSyncServerExchange exchange) { + @McpToolParam(description = "Analysis type", required = false) String type) { Document doc = documents.get(docId); if (doc == null) { return "Document not found"; } + // Access progress token from context + String progressToken = context.request().progressToken(); + if (progressToken != null) { - exchange.progressNotification(new ProgressNotification( - progressToken, 0.0, 1.0, "Starting analysis")); + context.progress(p -> p.progress(0.0).total(1.0).message("Starting analysis")); } // Perform analysis @@ -161,8 +160,7 @@ public class DocumentServer { String result = performAnalysis(doc, analysisType); if (progressToken != null) { - exchange.progressNotification(new ProgressNotification( - progressToken, 1.0, 1.0, "Analysis complete")); + context.progress(p -> p.progress(1.0).total(1.0).message("Analysis complete")); } return result; @@ -381,21 +379,22 @@ public class AsyncDataProcessor { @McpTool(name = "process-stream", description = "Process data stream") public Flux processStream( - @McpToolParam(description = "Item count", required = true) int count, - @McpProgressToken String progressToken, - McpAsyncServerExchange exchange) { + McpAsyncRequestContext context, + @McpToolParam(description = "Item count", required = true) int count) { + + // Access progress token from context + String progressToken = context.request().progressToken(); return Flux.range(1, count) .delayElements(Duration.ofMillis(100)) - .doOnNext(i -> { + .flatMap(i -> { if 
(progressToken != null) { double progress = (double) i / count; - exchange.progressNotification(new ProgressNotification( - progressToken, progress, 1.0, - "Processing item " + i)); + return context.progress(p -> p.progress(progress).total(1.0).message("Processing item " + i)) + .thenReturn("Processed item " + i); } - }) - .map(i -> "Processed item " + i); + return Mono.just("Processed item " + i); + }); } @McpResource(uri = "async-data://{id}", name = "Async Data") diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/mcp/mcp-annotations-overview.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/mcp/mcp-annotations-overview.adoc index 3368ebd2f1c..c472ee375b4 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/mcp/mcp-annotations-overview.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/mcp/mcp-annotations-overview.adoc @@ -36,11 +36,11 @@ For MCP Clients, the following annotations are provided: === Special Parameters and Annotations -* `McpSyncServerExchange` - Special parameter type for stateful synchronous operations that provides access to server exchange functionality including logging notifications, progress updates, and other server-side operations. This parameter is automatically injected and excluded from JSON schema generation -* `McpAsyncServerExchange` - Special parameter type for stateful asynchronous operations that provides access to server exchange functionality with reactive support. This parameter is automatically injected and excluded from JSON schema generation +* `McpSyncRequestContext` - Special parameter type for synchronous operations that provides a unified interface for accessing MCP request context, including the original request, server exchange (for stateful operations), transport context (for stateless operations), and convenient methods for logging, progress, sampling, and elicitation. This parameter is automatically injected and excluded from JSON schema generation. 
**Supported in Complete, Prompt, Resource, and Tool methods.** +* `McpAsyncRequestContext` - Special parameter type for asynchronous operations that provides the same unified interface as `McpSyncRequestContext` but with reactive (Mono-based) return types. This parameter is automatically injected and excluded from JSON schema generation. **Supported in Complete, Prompt, Resource, and Tool methods.** * `McpTransportContext` - Special parameter type for stateless operations that provides lightweight access to transport-level context without full server exchange functionality. This parameter is automatically injected and excluded from JSON schema generation -* `@McpProgressToken` - Marks a method parameter to receive the progress token from the request. This parameter is automatically injected and excluded from the generated JSON schema -* `McpMeta` - Special parameter type that provides access to metadata from MCP requests, notifications, and results. This parameter is automatically injected and excluded from parameter count limits and JSON schema generation +* `@McpProgressToken` - Marks a method parameter to receive the progress token from the request. This parameter is automatically injected and excluded from the generated JSON schema. **Note:** When using `McpSyncRequestContext` or `McpAsyncRequestContext`, the progress token can be accessed via `ctx.request().progressToken()` instead of using this annotation. +* `McpMeta` - Special parameter type that provides access to metadata from MCP requests, notifications, and results. This parameter is automatically injected and excluded from parameter count limits and JSON schema generation. **Note:** When using `McpSyncRequestContext` or `McpAsyncRequestContext`, metadata can be obtained via `ctx.requestMeta()` instead. 
== Getting Started diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/mcp/mcp-annotations-server.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/mcp/mcp-annotations-server.adoc index 53ecf6ac1db..369ed5a9d9e 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/mcp/mcp-annotations-server.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/mcp/mcp-annotations-server.adoc @@ -45,26 +45,25 @@ public AreaResult calculateRectangleArea( } ---- -==== With Server Exchange +==== With Request Context -Tools can access the server exchange for advanced operations: +Tools can access the request context for advanced operations: [source,java] ---- -@McpTool(name = "process-data", description = "Process data with server context") +@McpTool(name = "process-data", description = "Process data with request context") public String processData( - McpSyncServerExchange exchange, + McpSyncRequestContext context, @McpToolParam(description = "Data to process", required = true) String data) { // Send logging notification - exchange.loggingNotification(LoggingMessageNotification.builder() - .level(LoggingLevel.INFO) - .data("Processing data: " + data) - .build()); + context.info("Processing data: " + data); - // Send progress notification if progress token is available - exchange.progressNotification(new ProgressNotification( - progressToken, 0.5, 1.0, "Processing...")); + // Send progress notification (using convenient method) + context.progress(p -> p.progress(0.5).total(1.0).message("Processing...")); + + // Ping the client + context.ping(); return "Processed: " + data.toUpperCase(); } @@ -97,18 +96,18 @@ Tools can receive progress tokens for tracking long-running operations: ---- @McpTool(name = "long-task", description = "Long-running task with progress") public String performLongTask( - @McpProgressToken String progressToken, - @McpToolParam(description = "Task name", required = true) String taskName, - McpSyncServerExchange 
exchange) { + McpSyncRequestContext context, + @McpToolParam(description = "Task name", required = true) String taskName) { + + // Access progress token from context + String progressToken = context.request().progressToken(); if (progressToken != null) { - exchange.progressNotification(new ProgressNotification( - progressToken, 0.0, 1.0, "Starting task")); + context.progress(p -> p.progress(0.0).total(1.0).message("Starting task")); // Perform work... - exchange.progressNotification(new ProgressNotification( - progressToken, 1.0, 1.0, "Task completed")); + context.progress(p -> p.progress(1.0).total(1.0).message("Task completed")); } return "Task " + taskName + " completed"; @@ -156,22 +155,23 @@ public ReadResourceResult getUserProfile(String username) { } ---- -==== With Server Exchange +==== With Request Context [source,java] ---- @McpResource( uri = "data://{id}", name = "Data Resource", - description = "Resource with server context") + description = "Resource with request context") public ReadResourceResult getData( - McpSyncServerExchange exchange, + McpSyncRequestContext context, String id) { - exchange.loggingNotification(LoggingMessageNotification.builder() - .level(LoggingLevel.INFO) - .data("Accessing resource: " + id) - .build()); + // Send logging notification using convenient method + context.info("Accessing resource: " + id); + + // Ping the client + context.ping(); String data = fetchData(id); @@ -302,53 +302,285 @@ public CompleteResult completeCode(String prefix) { == Stateless vs Stateful Implementations -=== Stateful (with McpSyncServerExchange/McpAsyncServerExchange) +=== Unified Request Context (Recommended) -Stateful implementations have access to the full server exchange context: +Use `McpSyncRequestContext` or `McpAsyncRequestContext` for a unified interface that works with both stateful and stateless operations: [source,java] ---- -@McpTool(name = "stateful-tool", description = "Tool with server exchange") -public String statefulTool( - 
McpSyncServerExchange exchange, +public record UserInfo(String name, String email, int age) {} + +@McpTool(name = "unified-tool", description = "Tool with unified request context") +public String unifiedTool( + McpSyncRequestContext context, @McpToolParam(description = "Input", required = true) String input) { - // Access server exchange features - exchange.loggingNotification(...); - exchange.progressNotification(...); - exchange.ping(); + // Access request and metadata + String progressToken = context.request().progressToken(); + + // Logging with convenient methods + context.info("Processing: " + input); + + // Progress notifications (note: the client must include a progress token + // with its request in order to receive progress updates) + context.progress(50); // Simple percentage - // Can call client methods - CreateMessageResult result = exchange.createMessage(...); - ElicitResult elicitResult = exchange.createElicitation(...); + // Ping client + context.ping(); - return "Processed with full context"; + // Check capabilities before using + if (context.elicitEnabled()) { + // Request user input (only in stateful mode) + StructuredElicitResult<UserInfo> elicitResult = context.elicit(UserInfo.class); + if (elicitResult.action() == ElicitResult.Action.ACCEPT) { + // Use elicited data + } + } + + if (context.sampleEnabled()) { + // Request LLM sampling (only in stateful mode) + CreateMessageResult samplingResult = context.sample("Generate response"); + // Use sampling result + } + + return "Processed with unified context"; } ---- -=== Stateless (with McpTransportContext or without) +=== Simple Operations (No Context) -Stateless implementations are simpler and don't require server exchange: +For simple operations, you can omit context parameters entirely: [source,java] ---- -@McpTool(name = "stateless-tool", description = "Simple stateless tool") +@McpTool(name = "simple-add", description = "Simple addition") public int simpleAdd( @McpToolParam(description = "First number", 
required = true) int a, @McpToolParam(description = "Second number", required = true) int b) { return a + b; } +---- + +=== Lightweight Stateless (with McpTransportContext) -// With transport context if needed -@McpTool(name = "stateless-with-context", description = "Stateless with context") -public String withContext( +For stateless operations where you need minimal transport context: + +[source,java] +---- +@McpTool(name = "stateless-tool", description = "Stateless with transport context") +public String statelessTool( McpTransportContext context, @McpToolParam(description = "Input", required = true) String input) { - // Limited context access + // Access transport-level context only + // No bidirectional operations (roots, elicitation, sampling) return "Processed: " + input; } ---- +[IMPORTANT] +**Stateless servers do not support bidirectional operations.** + +Therefore, methods that declare a `McpSyncRequestContext` or `McpAsyncRequestContext` parameter are ignored (not registered) in stateless mode. + +== Method Filtering by Server Type + +The MCP annotations framework automatically filters annotated methods based on the server type and method characteristics. This ensures that only appropriate methods are registered for each server configuration. +A warning is logged for each filtered method to help with debugging. 
+ +=== Synchronous vs Asynchronous Filtering + +==== Synchronous Servers + +Synchronous servers (configured with `spring.ai.mcp.server.type=SYNC`) use synchronous providers that: + +* **Accept** methods with non-reactive return types: + - Primitive types (`int`, `double`, `boolean`) + - Object types (`String`, `Integer`, custom POJOs) + - MCP types (`CallToolResult`, `ReadResourceResult`, `GetPromptResult`, `CompleteResult`) + - Collections (`List`, `Map`) + +* **Filter out** methods with reactive return types: + - `Mono` + - `Flux` + - `Publisher` + +[source,java] +---- +@Component +public class SyncTools { + + @McpTool(name = "sync-tool", description = "Synchronous tool") + public String syncTool(String input) { + // This method WILL be registered on sync servers + return "Processed: " + input; + } + + @McpTool(name = "async-tool", description = "Async tool") + public Mono<String> asyncTool(String input) { + // This method will be FILTERED OUT on sync servers + // A warning will be logged + return Mono.just("Processed: " + input); + } +} +---- + +==== Asynchronous Servers + +Asynchronous servers (configured with `spring.ai.mcp.server.type=ASYNC`) use asynchronous providers that: + +* **Accept** methods with reactive return types: + - `Mono` (for single results) + - `Flux` (for streaming results) + - `Publisher` (generic reactive type) + +* **Filter out** methods with non-reactive return types: + - Primitive types + - Object types + - Collections + - MCP result types + +[source,java] +---- +@Component +public class AsyncTools { + + @McpTool(name = "async-tool", description = "Async tool") + public Mono<String> asyncTool(String input) { + // This method WILL be registered on async servers + return Mono.just("Processed: " + input); + } + + @McpTool(name = "sync-tool", description = "Sync tool") + public String syncTool(String input) { + // This method will be FILTERED OUT on async servers + // A warning will be logged + return "Processed: " + input; + } +} +---- + +=== Stateful vs 
Stateless Filtering + +==== Stateful Servers + +Stateful servers support bidirectional communication and accept methods with: + +* **Bidirectional context parameters**: + - `McpSyncRequestContext` (for sync operations) + - `McpAsyncRequestContext` (for async operations) + - `McpSyncServerExchange` (legacy, for sync operations) + - `McpAsyncServerExchange` (legacy, for async operations) + +* Support for bidirectional operations: + - `roots()` - Access root directories + - `elicit()` - Request user input + - `sample()` - Request LLM sampling + +[source,java] +---- +@Component +public class StatefulTools { + + @McpTool(name = "interactive-tool", description = "Tool with bidirectional operations") + public String interactiveTool( + McpSyncRequestContext context, + @McpToolParam(description = "Input", required = true) String input) { + + // This method WILL be registered on stateful servers + // Can use elicitation, sampling, roots + if (context.sampleEnabled()) { + var samplingResult = context.sample("Generate response"); + // Process sampling result... 
+ } + + return "Processed with context"; + } +} +---- + +==== Stateless Servers + +Stateless servers are optimized for simple request-response patterns and: + +* **Filter out** methods with bidirectional context parameters: + - Methods with `McpSyncRequestContext` are skipped + - Methods with `McpAsyncRequestContext` are skipped + - Methods with `McpSyncServerExchange` are skipped + - Methods with `McpAsyncServerExchange` are skipped + - A warning is logged for each filtered method + +* **Accept** methods with: + - `McpTransportContext` (lightweight stateless context) + - No context parameter at all + - Only regular `@McpToolParam` parameters + +* Do **not** support bidirectional operations: + - `roots()` - Not available + - `elicit()` - Not available + - `sample()` - Not available + +[source,java] +---- +@Component +public class StatelessTools { + + @McpTool(name = "simple-tool", description = "Simple stateless tool") + public String simpleTool(@McpToolParam(description = "Input") String input) { + // This method WILL be registered on stateless servers + return "Processed: " + input; + } + + @McpTool(name = "context-tool", description = "Tool with transport context") + public String contextTool( + McpTransportContext context, + @McpToolParam(description = "Input") String input) { + // This method WILL be registered on stateless servers + return "Processed: " + input; + } + + @McpTool(name = "bidirectional-tool", description = "Tool with bidirectional context") + public String bidirectionalTool( + McpSyncRequestContext context, + @McpToolParam(description = "Input") String input) { + // This method will be FILTERED OUT on stateless servers + // A warning will be logged + return "Processed with sampling"; + } +} +---- + +=== Filtering Summary + +[cols="1,2,2"] +|=== +|Server Type |Accepted Methods |Filtered Methods + +|**Sync Stateful** +|Non-reactive returns + bidirectional context +|Reactive returns (Mono/Flux) + +|**Async Stateful** +|Reactive returns (Mono/Flux) 
+ bidirectional context +|Non-reactive returns + +|**Sync Stateless** +|Non-reactive returns + no bidirectional context +|Reactive returns OR bidirectional context parameters + +|**Async Stateless** +|Reactive returns (Mono/Flux) + no bidirectional context +|Non-reactive returns OR bidirectional context parameters +|=== + +[TIP] +**Best Practices for Method Filtering:** + +1. **Keep methods aligned** with your server type - use sync methods for sync servers, async for async servers +2. **Separate stateful and stateless** implementations into different classes for clarity +3. **Check logs** during startup for filtered method warnings +4. **Use the right context** - `McpSyncRequestContext`/`McpAsyncRequestContext` for stateful, `McpTransportContext` for stateless +5. **Test both modes** if you support both stateful and stateless deployments + == Async Support All server annotations support asynchronous implementations using Reactor: diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/mcp/mcp-annotations-special-params.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/mcp/mcp-annotations-special-params.adoc index 86c5f404c8d..2239c5c1f72 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/mcp/mcp-annotations-special-params.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/mcp/mcp-annotations-special-params.adoc @@ -161,87 +161,95 @@ public ReadResourceResult getLargeFile( } ---- -=== McpSyncServerExchange / McpAsyncServerExchange +=== McpSyncRequestContext / McpAsyncRequestContext -Server exchange objects provide full access to server-side MCP operations. +Request context objects provide unified access to MCP request information and server-side operations. 
==== Overview -* Provides stateful context for server operations +* Provides unified interface for both stateful and stateless operations * Automatically injected when used as a parameter * Excluded from JSON schema generation -* Enables advanced features like logging, progress notifications, and client calls +* Enables advanced features like logging, progress notifications, sampling, and elicitation +* Works with both stateful (server exchange) and stateless (transport context) modes -==== McpSyncServerExchange Features +==== McpSyncRequestContext Features [source,java] ---- +public record UserInfo(String name, String email, int age) {} + @McpTool(name = "advanced-tool", description = "Tool with full server capabilities") public String advancedTool( - McpSyncServerExchange exchange, + McpSyncRequestContext context, @McpToolParam(description = "Input", required = true) String input) { // Send logging notification - exchange.loggingNotification(LoggingMessageNotification.builder() - .level(LoggingLevel.INFO) - .logger("advanced-tool") - .data("Processing: " + input) - .build()); + context.info("Processing: " + input); // Ping the client - exchange.ping(); - - // Request additional information from user - ElicitRequest elicitRequest = ElicitRequest.builder() - .message("Need additional information") - .requestedSchema(Map.of( - "type", "object", - "properties", Map.of( - "confirmation", Map.of("type", "boolean") - ) - )) - .build(); - - ElicitResult elicitResult = exchange.createElicitation(elicitRequest); - - // Request LLM sampling - CreateMessageRequest messageRequest = CreateMessageRequest.builder() - .messages(List.of(new SamplingMessage(Role.USER, - new TextContent("Process: " + input)))) - .modelPreferences(ModelPreferences.builder() - .hints(List.of(ModelHint.of("gpt-4"))) - .build()) - .build(); + context.ping(); + + // Send progress updates + context.progress(50); // 50% complete + + // Check if elicitation is supported before using it + if 
(context.elicitEnabled()) { + // Request additional information from user + StructuredElicitResult<UserInfo> elicitResult = context.elicit( + e -> e.message("Need additional information"), + UserInfo.class + ); + + if (elicitResult.action() == ElicitResult.Action.ACCEPT) { + UserInfo userInfo = elicitResult.structuredContent(); + // Use the user information + } + } - CreateMessageResult samplingResult = exchange.createMessage(messageRequest); + // Check if sampling is supported before using it + if (context.sampleEnabled()) { + // Request LLM sampling + CreateMessageResult samplingResult = context.sample( + s -> s.message("Process: " + input) + .modelPreferences(pref -> pref.modelHints("gpt-4")) + ); + } return "Processed with advanced features"; } ---- -==== McpAsyncServerExchange Features +==== McpAsyncRequestContext Features [source,java] ---- +public record UserInfo(String name, String email, int age) {} + @McpTool(name = "async-advanced-tool", description = "Async tool with server capabilities") public Mono<String> asyncAdvancedTool( - McpAsyncServerExchange exchange, + McpAsyncRequestContext context, @McpToolParam(description = "Input", required = true) String input) { - return Mono.fromCallable(() -> { - // Send async logging - exchange.loggingNotification(LoggingMessageNotification.builder() - .level(LoggingLevel.INFO) - .data("Async processing: " + input) - .build()); - - return "Started processing"; - }) - .flatMap(msg -> { - // Chain async operations - return exchange.createMessage(/* request */) - .map(result -> "Completed: " + result); - }); + return context.info("Async processing: " + input) + .then(context.progress(25)) + .then(context.ping()) + .flatMap(v -> { + // Perform elicitation if supported + if (context.elicitEnabled()) { + return context.elicitation(UserInfo.class) + .map(userInfo -> "Processing for user: " + userInfo.name()); + } + return Mono.just("Processing..."); + }) + .flatMap(msg -> { + // Perform sampling if supported + if (context.sampleEnabled()) { 
+ return context.sampling("Process: " + input) + .map(result -> "Completed: " + result); + } + return Mono.just("Completed: " + msg); + }); } ---- @@ -458,10 +466,42 @@ public String safeProgress( === Choose the Right Context -* Use `McpSyncServerExchange` / `McpAsyncServerExchange` for stateful operations -* Use `McpTransportContext` for simple stateless operations +* Use `McpSyncRequestContext` / `McpAsyncRequestContext` for unified access to request context, supporting both stateful and stateless operations with convenient helper methods +* Use `McpTransportContext` for simple stateless operations when you only need transport-level context * Omit context parameters entirely for the simplest cases +=== Capability Checking + +Always check capability support before using client features: + +[source,java] +---- +@McpTool(name = "capability-aware", description = "Tool that checks capabilities") +public String capabilityAware( + McpSyncRequestContext context, + @McpToolParam(description = "Data", required = true) String data) { + + // Check if elicitation is supported before using it + if (context.elicitEnabled()) { + // Safe to use elicitation + var result = context.elicit(UserInfo.class); + // Process result... + } + + // Check if sampling is supported before using it + if (context.sampleEnabled()) { + // Safe to use sampling + var samplingResult = context.sample("Process: " + data); + // Process result... + } + + // Note: Stateless servers do not support bidirectional operations + // (roots, elicitation, sampling) and will return false for these checks + + return "Processed with capability awareness"; +} +---- + == Additional Resources * xref:api/mcp/mcp-annotations-overview.adoc[MCP Annotations Overview]