From 46a856224cae958654c612128827acbbef93dac6 Mon Sep 17 00:00:00 2001 From: Daniel Garnier-Moiroux Date: Fri, 27 Jun 2025 11:47:12 +0200 Subject: [PATCH] Add option for immediate execution in McpSyncServer - The McpSyncServer wraps an async server. By default, reactive operations are scheduled on a bounded-elastic scheduler, to offload blocking work and prevent accidental blocking of non-blocking operations. - With the default behavior, there will be thead ops, even in a blocking context, which means thread-locals from the request thread will be lost. This is inconenvient for frameworks that store state in thread-locals. - This commit adds the ability to avoid offloading, when the user is sure they are executing code in a blocking environment. Work happens in the calling thread, and thread-locals are available throughout the execution. --- .../server/McpServer.java | 25 +++++++- .../server/McpServerFeatures.java | 59 +++++++++++-------- .../server/McpSyncServer.java | 28 ++++++++- ...rverTransportProviderIntegrationTests.java | 38 +++++++++++- .../transport/McpTestServletFilter.java | 43 ++++++++++++++ .../server/transport/TomcatTestUtil.java | 13 ++++ 6 files changed, 172 insertions(+), 34 deletions(-) create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/transport/McpTestServletFilter.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java index d6ec2cc30..637b7f92a 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2024-2025 the original author or authors. */ package io.modelcontextprotocol.server; @@ -695,6 +695,8 @@ class SyncSpecification { private Duration requestTimeout = Duration.ofSeconds(10); // Default timeout + private boolean immediateExecution = false; + private SyncSpecification(McpServerTransportProvider transportProvider) { Assert.notNull(transportProvider, "Transport provider must not be null"); this.transportProvider = transportProvider; @@ -1116,6 +1118,22 @@ public SyncSpecification objectMapper(ObjectMapper objectMapper) { return this; } + /** + * Enable on "immediate execution" of the operations on the underlying + * {@link McpAsyncServer}. Defaults to false, which does blocking code offloading + * to prevent accidental blocking of the non-blocking transport. + *

+ * Do NOT set to true if the underlying transport is a non-blocking + * implementation. + * @param immediateExecution When true, do not offload work asynchronously. + * @return This builder instance for method chaining. + * + */ + public SyncSpecification immediateExecution(boolean immediateExecution) { + this.immediateExecution = immediateExecution; + return this; + } + /** * Builds a synchronous MCP server that provides blocking operations. * @return A new instance of {@link McpSyncServer} configured with this builder's @@ -1125,12 +1143,13 @@ public McpSyncServer build() { McpServerFeatures.Sync syncFeatures = new McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities, this.tools, this.resources, this.resourceTemplates, this.prompts, this.completions, this.rootsChangeHandlers, this.instructions); - McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures); + McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures, + this.immediateExecution); var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); var asyncServer = new McpAsyncServer(this.transportProvider, mapper, asyncFeatures, this.requestTimeout, this.uriTemplateManagerFactory); - return new McpSyncServer(asyncServer); + return new McpSyncServer(asyncServer, this.immediateExecution); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java index 8311f5d41..e61722a82 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2024-2025 the original author or authors. */ package io.modelcontextprotocol.server; @@ -95,28 +95,30 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s * blocking code offloading to prevent accidental blocking of the non-blocking * transport. * @param syncSpec a potentially blocking, synchronous specification. + * @param immediateExecution when true, do not offload. Do NOT set to true when + * using a non-blocking transport. * @return a specification which is protected from blocking calls specified by the * user. */ - static Async fromSync(Sync syncSpec) { + static Async fromSync(Sync syncSpec, boolean immediateExecution) { List tools = new ArrayList<>(); for (var tool : syncSpec.tools()) { - tools.add(AsyncToolSpecification.fromSync(tool)); + tools.add(AsyncToolSpecification.fromSync(tool, immediateExecution)); } Map resources = new HashMap<>(); syncSpec.resources().forEach((key, resource) -> { - resources.put(key, AsyncResourceSpecification.fromSync(resource)); + resources.put(key, AsyncResourceSpecification.fromSync(resource, immediateExecution)); }); Map prompts = new HashMap<>(); syncSpec.prompts().forEach((key, prompt) -> { - prompts.put(key, AsyncPromptSpecification.fromSync(prompt)); + prompts.put(key, AsyncPromptSpecification.fromSync(prompt, immediateExecution)); }); Map completions = new HashMap<>(); syncSpec.completions().forEach((key, completion) -> { - completions.put(key, AsyncCompletionSpecification.fromSync(completion)); + completions.put(key, AsyncCompletionSpecification.fromSync(completion, immediateExecution)); }); List, Mono>> rootChangeConsumers = new ArrayList<>(); @@ -239,15 +241,15 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se public record AsyncToolSpecification(McpSchema.Tool tool, BiFunction, Mono> call) { - static AsyncToolSpecification fromSync(SyncToolSpecification tool) { + static AsyncToolSpecification fromSync(SyncToolSpecification tool, boolean immediate) { // FIXME: This is temporary, proper validation should be implemented if (tool == null) { return null; } - return new AsyncToolSpecification(tool.tool(), - (exchange, map) -> Mono - .fromCallable(() -> tool.call().apply(new McpSyncServerExchange(exchange), map)) - .subscribeOn(Schedulers.boundedElastic())); + return new AsyncToolSpecification(tool.tool(), (exchange, map) -> { + var toolResult = Mono.fromCallable(() -> tool.call().apply(new McpSyncServerExchange(exchange), map)); + return immediate ? toolResult : toolResult.subscribeOn(Schedulers.boundedElastic()); + }); } } @@ -281,15 +283,16 @@ static AsyncToolSpecification fromSync(SyncToolSpecification tool) { public record AsyncResourceSpecification(McpSchema.Resource resource, BiFunction> readHandler) { - static AsyncResourceSpecification fromSync(SyncResourceSpecification resource) { + static AsyncResourceSpecification fromSync(SyncResourceSpecification resource, boolean immediateExecution) { // FIXME: This is temporary, proper validation should be implemented if (resource == null) { return null; } - return new AsyncResourceSpecification(resource.resource(), - (exchange, req) -> Mono - .fromCallable(() -> resource.readHandler().apply(new McpSyncServerExchange(exchange), req)) - .subscribeOn(Schedulers.boundedElastic())); + return new AsyncResourceSpecification(resource.resource(), (exchange, req) -> { + var resourceResult = Mono + .fromCallable(() -> resource.readHandler().apply(new McpSyncServerExchange(exchange), req)); + return immediateExecution ? resourceResult : resourceResult.subscribeOn(Schedulers.boundedElastic()); + }); } } @@ -327,15 +330,16 @@ static AsyncResourceSpecification fromSync(SyncResourceSpecification resource) { public record AsyncPromptSpecification(McpSchema.Prompt prompt, BiFunction> promptHandler) { - static AsyncPromptSpecification fromSync(SyncPromptSpecification prompt) { + static AsyncPromptSpecification fromSync(SyncPromptSpecification prompt, boolean immediateExecution) { // FIXME: This is temporary, proper validation should be implemented if (prompt == null) { return null; } - return new AsyncPromptSpecification(prompt.prompt(), - (exchange, req) -> Mono - .fromCallable(() -> prompt.promptHandler().apply(new McpSyncServerExchange(exchange), req)) - .subscribeOn(Schedulers.boundedElastic())); + return new AsyncPromptSpecification(prompt.prompt(), (exchange, req) -> { + var promptResult = Mono + .fromCallable(() -> prompt.promptHandler().apply(new McpSyncServerExchange(exchange), req)); + return immediateExecution ? promptResult : promptResult.subscribeOn(Schedulers.boundedElastic()); + }); } } @@ -366,14 +370,17 @@ public record AsyncCompletionSpecification(McpSchema.CompleteReference reference * @return an asynchronous wrapper of the provided sync specification, or * {@code null} if input is null */ - static AsyncCompletionSpecification fromSync(SyncCompletionSpecification completion) { + static AsyncCompletionSpecification fromSync(SyncCompletionSpecification completion, + boolean immediateExecution) { if (completion == null) { return null; } - return new AsyncCompletionSpecification(completion.referenceKey(), - (exchange, request) -> Mono.fromCallable( - () -> completion.completionHandler().apply(new McpSyncServerExchange(exchange), request)) - .subscribeOn(Schedulers.boundedElastic())); + return new AsyncCompletionSpecification(completion.referenceKey(), (exchange, request) -> { + var completionResult = Mono.fromCallable( + () -> completion.completionHandler().apply(new McpSyncServerExchange(exchange), request)); + return immediateExecution ? completionResult + : completionResult.subscribeOn(Schedulers.boundedElastic()); + }); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java index 91f8d9e4c..5adda1a74 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java @@ -54,13 +54,27 @@ public class McpSyncServer { */ private final McpAsyncServer asyncServer; + private final boolean immediateExecution; + /** * Creates a new synchronous server that wraps the provided async server. * @param asyncServer The async server to wrap */ public McpSyncServer(McpAsyncServer asyncServer) { + this(asyncServer, false); + } + + /** + * Creates a new synchronous server that wraps the provided async server. + * @param asyncServer The async server to wrap + * @param immediateExecution Tools, prompts, and resources handlers execute work + * without blocking code offloading. Do NOT set to true if the {@code asyncServer}'s + * transport is non-blocking. + */ + public McpSyncServer(McpAsyncServer asyncServer, boolean immediateExecution) { Assert.notNull(asyncServer, "Async server must not be null"); this.asyncServer = asyncServer; + this.immediateExecution = immediateExecution; } /** @@ -68,7 +82,9 @@ public McpSyncServer(McpAsyncServer asyncServer) { * @param toolHandler The tool handler to add */ public void addTool(McpServerFeatures.SyncToolSpecification toolHandler) { - this.asyncServer.addTool(McpServerFeatures.AsyncToolSpecification.fromSync(toolHandler)).block(); + this.asyncServer + .addTool(McpServerFeatures.AsyncToolSpecification.fromSync(toolHandler, this.immediateExecution)) + .block(); } /** @@ -84,7 +100,10 @@ public void removeTool(String toolName) { * @param resourceHandler The resource handler to add */ public void addResource(McpServerFeatures.SyncResourceSpecification resourceHandler) { - this.asyncServer.addResource(McpServerFeatures.AsyncResourceSpecification.fromSync(resourceHandler)).block(); + this.asyncServer + .addResource( + McpServerFeatures.AsyncResourceSpecification.fromSync(resourceHandler, this.immediateExecution)) + .block(); } /** @@ -100,7 +119,10 @@ public void removeResource(String resourceUri) { * @param promptSpecification The prompt specification to add */ public void addPrompt(McpServerFeatures.SyncPromptSpecification promptSpecification) { - this.asyncServer.addPrompt(McpServerFeatures.AsyncPromptSpecification.fromSync(promptSpecification)).block(); + this.asyncServer + .addPrompt( + McpServerFeatures.AsyncPromptSpecification.fromSync(promptSpecification, this.immediateExecution)) + .block(); } /** diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java index 4bd98b406..dcc7917d0 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2024 - 2025 the original author or authors. */ package io.modelcontextprotocol.server.transport; @@ -37,7 +37,6 @@ import org.apache.catalina.startup.Tomcat; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; @@ -46,6 +45,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.InstanceOfAssertFactories.type; import static org.awaitility.Awaitility.await; import static org.mockito.Mockito.mock; @@ -728,6 +728,9 @@ void testToolCallSuccess() { var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + assertThat(McpTestServletFilter.getThreadLocalValue()) + .as("blocking code exectuion should be offloaded") + .isNull(); // perform a blocking call to a remote service String response = RestClient.create() .get() @@ -758,6 +761,37 @@ void testToolCallSuccess() { mcpServer.close(); } + @Test + void testToolCallImmediateExecution() { + McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + var threadLocalValue = McpTestServletFilter.getThreadLocalValue(); + return CallToolResult.builder() + .addTextContent(threadLocalValue != null ? threadLocalValue : "") + .build(); + }); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .immediateExecution(true) + .build(); + + try (var mcpClient = clientBuilder.build()) { + mcpClient.initialize(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).first() + .asInstanceOf(type(McpSchema.TextContent.class)) + .extracting(McpSchema.TextContent::text) + .isEqualTo(McpTestServletFilter.THREAD_LOCAL_VALUE); + } + + mcpServer.close(); + } + @Test void testToolListChangeHandlingSuccess() { diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/McpTestServletFilter.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/McpTestServletFilter.java new file mode 100644 index 000000000..cc2543aa9 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/McpTestServletFilter.java @@ -0,0 +1,43 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.IOException; + +import jakarta.servlet.Filter; +import jakarta.servlet.FilterChain; +import jakarta.servlet.ServletException; +import jakarta.servlet.ServletRequest; +import jakarta.servlet.ServletResponse; + +/** + * Simple {@link Filter} which sets a value in a thread local. Used to verify whether MCP + * executions happen on the thread processing the request or are offloaded. + * + * @author Daniel Garnier-Moiroux + */ +public class McpTestServletFilter implements Filter { + + public static final String THREAD_LOCAL_VALUE = McpTestServletFilter.class.getName(); + + private static final ThreadLocal holder = new ThreadLocal<>(); + + @Override + public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) + throws IOException, ServletException { + holder.set(THREAD_LOCAL_VALUE); + try { + filterChain.doFilter(servletRequest, servletResponse); + } + finally { + holder.remove(); + } + } + + public static String getThreadLocalValue() { + return holder.get(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java index f61cdc413..5a3928e02 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java @@ -10,9 +10,12 @@ import jakarta.servlet.Servlet; import org.apache.catalina.Context; import org.apache.catalina.startup.Tomcat; +import org.apache.tomcat.util.descriptor.web.FilterDef; +import org.apache.tomcat.util.descriptor.web.FilterMap; /** * @author Christian Tzolov + * @author Daniel Garnier-Moiroux */ public class TomcatTestUtil { @@ -39,6 +42,16 @@ public static Tomcat createTomcatServer(String contextPath, int port, Servlet se context.addChild(wrapper); context.addServletMappingDecoded("/*", "mcpServlet"); + var filterDef = new FilterDef(); + filterDef.setFilterClass(McpTestServletFilter.class.getName()); + filterDef.setFilterName(McpTestServletFilter.class.getSimpleName()); + context.addFilterDef(filterDef); + + var filterMap = new FilterMap(); + filterMap.setFilterName(McpTestServletFilter.class.getSimpleName()); + filterMap.addURLPattern("/*"); + context.addFilterMap(filterMap); + var connector = tomcat.getConnector(); connector.setAsyncTimeout(3000);