From 2ab0e257d2055828782b1ebc3ea5fc160b1a4640 Mon Sep 17 00:00:00 2001 From: Phani Pemmaraju Date: Mon, 22 Sep 2025 10:57:58 +0100 Subject: [PATCH] fix(session): always dispose in closeGracefully --- .../spec/DefaultMcpTransportSession.java | 33 ++++- .../spec/DefaultMcpTransportSessionTests.java | 110 +++++++++++++++ ...luxSseCloseGracefullyIntegrationTests.java | 131 ++++++++++++++++++ 3 files changed, 272 insertions(+), 2 deletions(-) create mode 100644 mcp-core/src/test/java/io/modelcontextprotocol/spec/DefaultMcpTransportSessionTests.java create mode 100644 mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseCloseGracefullyIntegrationTests.java diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java index fdb7bfd89..c3dbeb869 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java @@ -9,6 +9,7 @@ import org.slf4j.LoggerFactory; import reactor.core.Disposable; import reactor.core.Disposables; +import reactor.core.Exceptions; import reactor.core.publisher.Mono; import java.util.Optional; @@ -77,8 +78,36 @@ public void close() { @Override public Mono closeGracefully() { - return Mono.from(this.onClose.apply(this.sessionId.get())) - .then(Mono.fromRunnable(this.openConnections::dispose)); + return Mono.defer(() -> { + final String sessionId = this.sessionId.get(); + + final AtomicReference primary = new AtomicReference<>(null); + + // Subscribe to onClose publisher and capture any error + return Mono.from(this.onClose.apply(sessionId)).onErrorResume(err -> { + primary.set(err); + return Mono.empty(); + }) + // Always dispose openConnections + .then(Mono.defer(() -> { + try { + this.openConnections.dispose(); + } + catch (Throwable disposeEx) { + if (primary.get() != null) { + primary.get().addSuppressed(disposeEx); + } + else { + primary.set(disposeEx); + } + } + + // Re-emit the original error (with suppressed dispose error), + // complete + Throwable throwable = primary.get(); + return (throwable == null) ? Mono.empty() : Mono.error(Exceptions.propagate(throwable)); + })); + }); } } diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/spec/DefaultMcpTransportSessionTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/DefaultMcpTransportSessionTests.java new file mode 100644 index 000000000..756221e2c --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/spec/DefaultMcpTransportSessionTests.java @@ -0,0 +1,110 @@ +package io.modelcontextprotocol.spec; + +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import org.reactivestreams.Publisher; +import org.springframework.util.ReflectionUtils; +import reactor.core.Disposable; +import reactor.core.publisher.Mono; + +import java.lang.reflect.Field; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link DefaultMcpTransportSession}. + * + * @author Phani Pemmaraju + */ +class DefaultMcpTransportSessionTests { + + /** Minimal Disposable to flag that dispose() was called. */ + static final class FlagDisposable implements Disposable { + + final AtomicBoolean disposed = new AtomicBoolean(false); + + @Override + public void dispose() { + disposed.set(true); + } + + @Override + public boolean isDisposed() { + return disposed.get(); + } + + } + + @Test + void closeGracefully_disposes_when_onClose_throws() { + @SuppressWarnings("unchecked") + Function> onClose = Mockito.mock(Function.class); + Mockito.when(onClose.apply(Mockito.any())).thenReturn(Mono.error(new RuntimeException("runtime-exception"))); + + // construct session with required ctor + var session = new DefaultMcpTransportSession(onClose); + + // seed session id + setField(session, "sessionId", new AtomicReference<>("sessionId-123")); + + // get the existing final composite and add a child flag-disposable + Disposable.Composite composite = (Disposable.Composite) getField(session, "openConnections"); + FlagDisposable flag = new FlagDisposable(); + composite.add(flag); + + // act + assert: original onClose error is propagated + assertThatThrownBy(() -> session.closeGracefully().block()).isInstanceOf(RuntimeException.class) + .hasMessageContaining("runtime-exception"); + + // and the child disposable was disposed => proves composite.dispose() executed + assertThat(flag.isDisposed()).isTrue(); + } + + @Test + void closeGracefully_propagates_onClose_error_and_disposes_children() { + // onClose fails again + @SuppressWarnings("unchecked") + Function> onClose = Mockito.mock(Function.class); + Mockito.when(onClose.apply(Mockito.any())).thenReturn(Mono.error(new RuntimeException("runtime-exception"))); + + var session = new DefaultMcpTransportSession(onClose); + setField(session, "sessionId", new AtomicReference<>("sessionId-xyz")); + + Disposable.Composite composite = (Disposable.Composite) getField(session, "openConnections"); + FlagDisposable a = new FlagDisposable(); + FlagDisposable b = new FlagDisposable(); + composite.add(a); + composite.add(b); + + Throwable thrown = Assertions.catchThrowable(() -> session.closeGracefully().block()); + + // primary error is from onClose + assertThat(thrown).isInstanceOf(RuntimeException.class).hasMessageContaining("runtime-exception"); + + // both children disposed + assertThat(a.isDisposed()).isTrue(); + assertThat(b.isDisposed()).isTrue(); + } + + private static void setField(Object target, String fieldName, Object value) { + Field f = ReflectionUtils.findField(target.getClass(), fieldName); + if (f == null) + throw new IllegalArgumentException("No such field: " + fieldName); + ReflectionUtils.makeAccessible(f); + ReflectionUtils.setField(f, target, value); + } + + private static Object getField(Object target, String fieldName) { + Field f = ReflectionUtils.findField(target.getClass(), fieldName); + if (f == null) + throw new IllegalArgumentException("No such field: " + fieldName); + ReflectionUtils.makeAccessible(f); + return ReflectionUtils.getField(f, target); + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseCloseGracefullyIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseCloseGracefullyIntegrationTests.java new file mode 100644 index 000000000..c43dbc814 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseCloseGracefullyIntegrationTests.java @@ -0,0 +1,131 @@ +package io.modelcontextprotocol; + +import java.time.Duration; +import java.util.Map; +import java.util.stream.Stream; + +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.*; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.server.RouterFunctions; +import org.springframework.web.reactive.function.server.ServerRequest; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServer.AsyncSpecification; +import io.modelcontextprotocol.server.McpServer.SingleSessionSyncSpecification; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; +import reactor.core.publisher.Hooks; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +@Timeout(15) +public class WebFluxSseCloseGracefullyIntegrationTests extends AbstractMcpClientServerIntegrationTests { + + private int port; + + private static final String CUSTOM_SSE_ENDPOINT = "/somePath/sse"; + + private static final String DEFAULT_MESSAGE_ENDPOINT = "/mcp/message"; + + private DisposableServer httpServer; + + private WebFluxSseServerTransportProvider mcpServerTransportProvider; + + static McpTransportContextExtractor TEST_CONTEXT_EXTRACTOR = (r) -> McpTransportContext + .create(Map.of("important", "value")); + + static Stream clientsForTesting() { + return Stream.of(Arguments.of("httpclient"), Arguments.of("webflux")); + } + + @Override + protected void prepareClients(int port, String mcpEndpoint) { + clientBuilders + .put("httpclient", + McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + port) + .sseEndpoint(CUSTOM_SSE_ENDPOINT) + .build()).requestTimeout(Duration.ofSeconds(10))); + + clientBuilders.put("webflux", McpClient + .sync(WebFluxSseClientTransport.builder(org.springframework.web.reactive.function.client.WebClient.builder() + .baseUrl("http://localhost:" + port)).sseEndpoint(CUSTOM_SSE_ENDPOINT).build()) + .requestTimeout(Duration.ofSeconds(10))); + } + + @Override + protected AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(mcpServerTransportProvider); + } + + @Override + protected SingleSessionSyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(mcpServerTransportProvider); + } + + @BeforeEach + void before() { + // Build the transport provider with BOTH endpoints (message required) + this.mcpServerTransportProvider = new WebFluxSseServerTransportProvider.Builder() + .messageEndpoint(DEFAULT_MESSAGE_ENDPOINT) + .sseEndpoint(CUSTOM_SSE_ENDPOINT) + .contextExtractor(TEST_CONTEXT_EXTRACTOR) + .build(); + + // Wire session factory + prepareSyncServerBuilder().build(); + + // Bind on ephemeral port and discover the actual port + var httpHandler = RouterFunctions.toHttpHandler(mcpServerTransportProvider.getRouterFunction()); + var adapter = new ReactorHttpHandlerAdapter(httpHandler); + this.httpServer = HttpServer.create().port(0).handle(adapter).bindNow(); + this.port = httpServer.port(); + + // Build clients using the discovered port + prepareClients(this.port, null); + + // keep your onErrorDropped suppression if you need it for noisy Reactor paths + Hooks.onErrorDropped(e -> { + }); + } + + @AfterEach + void after() { + if (httpServer != null) + httpServer.disposeNow(); + Hooks.resetOnErrorDropped(); + } + + @ParameterizedTest(name = "closeGracefully after outage: {0}") + @MethodSource("clientsForTesting") + @DisplayName("closeGracefully() signals failure after server outage (WebFlux/SSE, sync client)") + void closeGracefully_disposes_after_server_unavailable(String clientKey) { + var reactiveClient = io.modelcontextprotocol.client.McpClient + .async(WebFluxSseClientTransport.builder(WebClient.builder().baseUrl("http://localhost:" + this.port)) + .sseEndpoint(CUSTOM_SSE_ENDPOINT) + .build()) + .requestTimeout(Duration.ofSeconds(10)) + .build(); + + reactiveClient.initialize().block(Duration.ofSeconds(5)); + + httpServer.disposeNow(); + + Assertions.assertThatCode(() -> reactiveClient.closeGracefully().block(Duration.ofSeconds(5))) + .doesNotThrowAnyException(); + + Assertions.assertThatThrownBy(() -> reactiveClient.initialize().block(Duration.ofSeconds(3))) + .isInstanceOf(Exception.class); + + } + +}