From 9621ff2332de235657e55df177606db31e2e254a Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Tue, 18 Mar 2025 18:57:22 +0100 Subject: [PATCH 1/4] refactor: improve MCP client timeout handling and reactive testing - Add configurable initialization timeout separate from request timeout - Rename ServletSse* test classes to HttpSse* for better naming consistency - Replace direct .block() calls with StepVerifier for better reactive testing - Change ping() method to return Mono instead of Mono - Improve error handling and reactive programming patterns throughout tests - Chain reactive operations for cleaner test flow Signed-off-by: Christian Tzolov --- .../client/WebFluxSseMcpAsyncClientTests.java | 7 - .../client/WebFluxSseMcpSyncClientTests.java | 7 - .../client/AbstractMcpAsyncClientTests.java | 255 ++++++++--------- .../client/AbstractMcpSyncClientTests.java | 17 +- .../client/McpAsyncClient.java | 19 +- .../client/McpClient.java | 33 ++- .../client/McpSyncClient.java | 7 +- .../client/AbstractMcpAsyncClientTests.java | 257 +++++++++--------- .../client/AbstractMcpSyncClientTests.java | 15 +- ...s.java => HttpSseMcpAsyncClientTests.java} | 7 +- ...ts.java => HttpSseMcpSyncClientTests.java} | 7 +- .../client/StdioMcpSyncClientTests.java | 2 +- 12 files changed, 339 insertions(+), 294 deletions(-) rename mcp/src/test/java/io/modelcontextprotocol/client/{ServletSseMcpAsyncClientTests.java => HttpSseMcpAsyncClientTests.java} (89%) rename mcp/src/test/java/io/modelcontextprotocol/client/{ServletSseMcpSyncClientTests.java => HttpSseMcpSyncClientTests.java} (89%) diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java index 6cd74631..021ce465 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java @@ -4,8 +4,6 @@ package io.modelcontextprotocol.client; -import java.time.Duration; - import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; import io.modelcontextprotocol.spec.ClientMcpTransport; import org.junit.jupiter.api.Timeout; @@ -48,9 +46,4 @@ public void onClose() { container.stop(); } - @Override - protected Duration getTimeoutDuration() { - return Duration.ofMillis(300); - } - } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java index 6b980da4..20eeb1d5 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java @@ -4,8 +4,6 @@ package io.modelcontextprotocol.client; -import java.time.Duration; - import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; import io.modelcontextprotocol.spec.ClientMcpTransport; import org.junit.jupiter.api.Timeout; @@ -48,9 +46,4 @@ protected void onClose() { container.stop(); } - @Override - protected Duration getTimeoutDuration() { - return Duration.ofMillis(300); - } - } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index cdcba4d1..17cc9960 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -58,8 +58,12 @@ protected void onStart() { protected void onClose() { } - protected Duration getTimeoutDuration() { - return Duration.ofSeconds(2); + protected Duration getRequestTimeout() { + return Duration.ofSeconds(10); + } + + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(1); } @BeforeEach @@ -69,7 +73,8 @@ void setUp() { assertThatCode(() -> { mcpAsyncClient = McpClient.async(mcpTransport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) + .initializationTimeout(getInitializationTimeout()) .capabilities(ClientCapabilities.builder().roots(true).build()) .build(); }).doesNotThrowAnyException(); @@ -78,8 +83,7 @@ void setUp() { @AfterEach void tearDown() { if (mcpAsyncClient != null) { - assertThatCode(() -> mcpAsyncClient.closeGracefully().block(Duration.ofSeconds(10))) - .doesNotThrowAnyException(); + StepVerifier.create(mcpAsyncClient.closeGracefully()).verifyComplete(); } onClose(); } @@ -96,87 +100,93 @@ void testConstructorWithInvalidArguments() { @Test void testListToolsWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.listTools(null).block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing tools"); + StepVerifier.create(mcpAsyncClient.listTools(null)).expectErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before listing tools"); + }).verify(); } @Test void testListTools() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - - StepVerifier.create(mcpAsyncClient.listTools(null)).consumeNextWith(result -> { - assertThat(result.tools()).isNotNull().isNotEmpty(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listTools(null))) + .consumeNextWith(result -> { + assertThat(result.tools()).isNotNull().isNotEmpty(); - Tool firstTool = result.tools().get(0); - assertThat(firstTool.name()).isNotNull(); - assertThat(firstTool.description()).isNotNull(); - }).verifyComplete(); + Tool firstTool = result.tools().get(0); + assertThat(firstTool.name()).isNotNull(); + assertThat(firstTool.description()).isNotNull(); + }) + .verifyComplete(); } @Test void testPingWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.ping().block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before pinging the server"); + StepVerifier.create(mcpAsyncClient.ping()) + .expectErrorMatches(error -> error instanceof McpError + && error.getMessage().equals("Client must be initialized before pinging the server")) + .verify(); } @Test void testPing() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - assertThatCode(() -> mcpAsyncClient.ping().block()).doesNotThrowAnyException(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.ping())).verifyComplete(); } @Test void testCallToolWithoutInitialization() { CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); - assertThatThrownBy(() -> mcpAsyncClient.callTool(callToolRequest).block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before calling tools"); + StepVerifier.create(mcpAsyncClient.callTool(callToolRequest)) + .expectErrorMatches(error -> error instanceof McpError + && error.getMessage().equals("Client must be initialized before calling tools")) + .verify(); } @Test void testCallTool() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); - StepVerifier.create(mcpAsyncClient.callTool(callToolRequest)).consumeNextWith(callToolResult -> { - assertThat(callToolResult).isNotNull().satisfies(result -> { - assertThat(result.content()).isNotNull(); - assertThat(result.isError()).isNull(); - }); - }).verifyComplete(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(callToolRequest))) + .consumeNextWith(callToolResult -> { + assertThat(callToolResult).isNotNull(); + assertThat(callToolResult.content()).isNotNull(); + assertThat(callToolResult.isError()).isNull(); + }) + .verifyComplete(); } @Test void testCallToolWithInvalidTool() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", ECHO_TEST_MESSAGE)); - assertThatThrownBy(() -> mcpAsyncClient.callTool(invalidRequest).block()).isInstanceOf(Exception.class); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(invalidRequest))) + .expectError(Exception.class) + .verify(); } @Test void testListResourcesWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.listResources(null).block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing resources"); + StepVerifier.create(mcpAsyncClient.listResources(null)) + .expectErrorMatches(error -> error instanceof McpError + && error.getMessage().equals("Client must be initialized before listing resources")) + .verify(); } @Test void testListResources() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - - StepVerifier.create(mcpAsyncClient.listResources(null)).consumeNextWith(resources -> { - assertThat(resources).isNotNull().satisfies(result -> { - assertThat(result.resources()).isNotNull(); - - if (!result.resources().isEmpty()) { - Resource firstResource = result.resources().get(0); - assertThat(firstResource.uri()).isNotNull(); - assertThat(firstResource.name()).isNotNull(); - } - }); - }).verifyComplete(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResources(null))) + .consumeNextWith(resources -> { + assertThat(resources).isNotNull().satisfies(result -> { + assertThat(result.resources()).isNotNull(); + + if (!result.resources().isEmpty()) { + Resource firstResource = result.resources().get(0); + assertThat(firstResource.uri()).isNotNull(); + assertThat(firstResource.name()).isNotNull(); + } + }); + }) + .verifyComplete(); } @Test @@ -186,40 +196,44 @@ void testMcpAsyncClientState() { @Test void testListPromptsWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.listPrompts(null).block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing prompts"); + StepVerifier.create(mcpAsyncClient.listPrompts(null)) + .expectErrorMatches(error -> error instanceof McpError + && error.getMessage().equals("Client must be initialized before listing prompts")) + .verify(); } @Test void testListPrompts() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - - StepVerifier.create(mcpAsyncClient.listPrompts(null)).consumeNextWith(prompts -> { - assertThat(prompts).isNotNull().satisfies(result -> { - assertThat(result.prompts()).isNotNull(); - - if (!result.prompts().isEmpty()) { - Prompt firstPrompt = result.prompts().get(0); - assertThat(firstPrompt.name()).isNotNull(); - assertThat(firstPrompt.description()).isNotNull(); - } - }); - }).verifyComplete(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listPrompts(null))) + .consumeNextWith(prompts -> { + assertThat(prompts).isNotNull().satisfies(result -> { + assertThat(result.prompts()).isNotNull(); + + if (!result.prompts().isEmpty()) { + Prompt firstPrompt = result.prompts().get(0); + assertThat(firstPrompt.name()).isNotNull(); + assertThat(firstPrompt.description()).isNotNull(); + } + }); + }) + .verifyComplete(); } @Test void testGetPromptWithoutInitialization() { GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); - assertThatThrownBy(() -> mcpAsyncClient.getPrompt(request).block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before getting prompts"); + StepVerifier.create(mcpAsyncClient.getPrompt(request)) + .expectErrorMatches(error -> error instanceof McpError + && error.getMessage().equals("Client must be initialized before getting prompts")) + .verify(); } @Test void testGetPrompt() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); + GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); - StepVerifier.create(mcpAsyncClient.getPrompt(new GetPromptRequest("simple_prompt", Map.of()))) + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.getPrompt(request))) .consumeNextWith(prompt -> { assertThat(prompt).isNotNull().satisfies(result -> { assertThat(result.messages()).isNotEmpty(); @@ -231,15 +245,16 @@ void testGetPrompt() { @Test void testRootsListChangedWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.rootsListChangedNotification().block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before sending roots list changed notification"); + StepVerifier.create(mcpAsyncClient.rootsListChangedNotification()) + .expectErrorMatches(error -> error instanceof McpError && error.getMessage() + .equals("Client must be initialized before sending roots list changed notification")) + .verify(); } @Test void testRootsListChanged() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - - assertThatCode(() -> mcpAsyncClient.rootsListChangedNotification().block()).doesNotThrowAnyException(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.rootsListChangedNotification())) + .verifyComplete(); } @Test @@ -247,39 +262,39 @@ void testInitializeWithRootsListProviders() { var transport = createMcpTransport(); var client = McpClient.async(transport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) .roots(new Root("file:///test/path", "test-root")) .build(); - assertThatCode(() -> client.initialize().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - - assertThatCode(() -> client.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); } @Test void testAddRoot() { Root newRoot = new Root("file:///new/test/path", "new-test-root"); - assertThatCode(() -> mcpAsyncClient.addRoot(newRoot).block()).doesNotThrowAnyException(); + + StepVerifier.create(mcpAsyncClient.addRoot(newRoot)).verifyComplete(); } @Test void testAddRootWithNullValue() { - assertThatThrownBy(() -> mcpAsyncClient.addRoot(null).block()).hasMessageContaining("Root must not be null"); + StepVerifier.create(mcpAsyncClient.addRoot(null)) + .expectErrorMatches(error -> error.getMessage().contains("Root must not be null")) + .verify(); } @Test void testRemoveRoot() { Root root = new Root("file:///test/path/to/remove", "root-to-remove"); - assertThatCode(() -> { - mcpAsyncClient.addRoot(root).block(); - mcpAsyncClient.removeRoot(root.uri()).block(); - }).doesNotThrowAnyException(); + + StepVerifier.create(mcpAsyncClient.addRoot(root).then(mcpAsyncClient.removeRoot(root.uri()))).verifyComplete(); } @Test void testRemoveNonExistentRoot() { - assertThatThrownBy(() -> mcpAsyncClient.removeRoot("nonexistent-uri").block()) - .hasMessageContaining("Root with uri 'nonexistent-uri' not found"); + StepVerifier.create(mcpAsyncClient.removeRoot("nonexistent-uri")) + .expectErrorMatches(error -> error.getMessage().contains("Root with uri 'nonexistent-uri' not found")) + .verify(); } @Test @@ -298,18 +313,20 @@ void testReadResource() { @Test void testListResourceTemplatesWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.listResourceTemplates().block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing resource templates"); + StepVerifier.create(mcpAsyncClient.listResourceTemplates()) + .expectErrorMatches(error -> error instanceof McpError + && error.getMessage().equals("Client must be initialized before listing resource templates")) + .verify(); } @Test void testListResourceTemplates() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - - StepVerifier.create(mcpAsyncClient.listResourceTemplates()).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.resourceTemplates()).isNotNull(); - }).verifyComplete(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResourceTemplates())) + .consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.resourceTemplates()).isNotNull(); + }) + .verifyComplete(); } // @Test @@ -337,16 +354,13 @@ void testNotificationHandlers() { var transport = createMcpTransport(); var client = McpClient.async(transport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true))) .resourcesChangeConsumer(resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true))) .promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true))) .build(); - assertThatCode(() -> { - client.initialize().block(); - client.closeGracefully().block(); - }).doesNotThrowAnyException(); + StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); } @Test @@ -356,15 +370,12 @@ void testInitializeWithSamplingCapability() { var capabilities = ClientCapabilities.builder().sampling().build(); var client = McpClient.async(transport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) .capabilities(capabilities) .sampling(request -> Mono.just(CreateMessageResult.builder().message("test").model("test-model").build())) .build(); - assertThatCode(() -> { - client.initialize().block(Duration.ofSeconds(10)); - client.closeGracefully().block(Duration.ofSeconds(10)); - }).doesNotThrowAnyException(); + StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); } @Test @@ -380,17 +391,17 @@ void testInitializeWithAllCapabilities() { Function> samplingHandler = request -> Mono .just(CreateMessageResult.builder().message("test").model("test-model").build()); var client = McpClient.async(transport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) .capabilities(capabilities) .sampling(samplingHandler) .build(); - assertThatCode(() -> { - var result = client.initialize().block(Duration.ofSeconds(10)); + StepVerifier.create(client.initialize()).consumeNextWith(result -> { assertThat(result).isNotNull(); assertThat(result.capabilities()).isNotNull(); - client.closeGracefully().block(Duration.ofSeconds(10)); - }).doesNotThrowAnyException(); + }).verifyComplete(); + + StepVerifier.create(client.closeGracefully()).verifyComplete(); } // --------------------------------------- @@ -399,19 +410,23 @@ void testInitializeWithAllCapabilities() { @Test void testLoggingLevelsWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.setLoggingLevel(McpSchema.LoggingLevel.DEBUG).block()) - .isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before setting logging level"); + StepVerifier.create(mcpAsyncClient.setLoggingLevel(McpSchema.LoggingLevel.DEBUG)) + .expectErrorMatches(error -> error instanceof McpError + && error.getMessage().equals("Client must be initialized before setting logging level")) + .verify(); } @Test void testLoggingLevels() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); + Mono testAllLevels = mcpAsyncClient.initialize().then(Mono.defer(() -> { + Mono chain = Mono.empty(); + for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { + chain = chain.then(mcpAsyncClient.setLoggingLevel(level)); + } + return chain; + })); - // Test all logging levels - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - StepVerifier.create(mcpAsyncClient.setLoggingLevel(level)).verifyComplete(); - } + StepVerifier.create(testAllLevels).verifyComplete(); } @Test @@ -420,20 +435,18 @@ void testLoggingConsumer() { var transport = createMcpTransport(); var client = McpClient.async(transport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) .loggingConsumer(notification -> Mono.fromRunnable(() -> logReceived.set(true))) .build(); - assertThatCode(() -> { - client.initialize().block(Duration.ofSeconds(10)); - client.closeGracefully().block(Duration.ofSeconds(10)); - }).doesNotThrowAnyException(); + StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); } @Test void testLoggingWithNullNotification() { - assertThatThrownBy(() -> mcpAsyncClient.setLoggingLevel(null).block()) - .hasMessageContaining("Logging level must not be null"); + StepVerifier.create(mcpAsyncClient.setLoggingLevel(null)) + .expectErrorMatches(error -> error.getMessage().contains("Logging level must not be null")) + .verify(); } } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index aeed06cb..ee43a572 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -52,8 +52,12 @@ public abstract class AbstractMcpSyncClientTests { abstract protected void onClose(); - protected Duration getTimeoutDuration() { - return Duration.ofSeconds(2); + protected Duration getRequestTimeout() { + return Duration.ofSeconds(10); + } + + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(1); } @BeforeEach @@ -63,7 +67,8 @@ void setUp() { assertThatCode(() -> { mcpSyncClient = McpClient.sync(mcpTransport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) + .initializationTimeout(getInitializationTimeout()) .capabilities(ClientCapabilities.builder().roots(true).build()) .build(); }).doesNotThrowAnyException(); @@ -215,7 +220,7 @@ void testInitializeWithRootsListProviders() { var transport = createMcpTransport(); var client = McpClient.sync(transport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) .roots(new Root("file:///test/path", "test-root")) .build(); @@ -313,7 +318,7 @@ void testNotificationHandlers() { var transport = createMcpTransport(); var client = McpClient.sync(transport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) .toolsChangeConsumer(tools -> toolsNotificationReceived.set(true)) .resourcesChangeConsumer(resources -> resourcesNotificationReceived.set(true)) .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)) @@ -351,7 +356,7 @@ void testLoggingConsumer() { var transport = createMcpTransport(); var client = McpClient.sync(transport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) .loggingConsumer(notification -> logReceived.set(true)) .build(); diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index b301aa93..4c5fd02c 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -88,7 +88,6 @@ public class McpAsyncClient { /** * The max timeout to await for the client-server connection to be initialized. - * Usually x2 the request timeout. // TODO should we make it configurable? */ private final Duration initializationTimeout; @@ -151,18 +150,21 @@ public class McpAsyncClient { * timeout. * @param transport the transport to use. * @param requestTimeout the session request-response timeout. + * @param initializationTimeout the max timeout to await for the client-server * @param features the MCP Client supported features. */ - McpAsyncClient(ClientMcpTransport transport, Duration requestTimeout, McpClientFeatures.Async features) { + McpAsyncClient(ClientMcpTransport transport, Duration requestTimeout, Duration initializationTimeout, + McpClientFeatures.Async features) { Assert.notNull(transport, "Transport must not be null"); Assert.notNull(requestTimeout, "Request timeout must not be null"); + Assert.notNull(initializationTimeout, "Initialization timeout must not be null"); this.clientInfo = features.clientInfo(); this.clientCapabilities = features.clientCapabilities(); this.transport = transport; this.roots = new ConcurrentHashMap<>(features.roots()); - this.initializationTimeout = requestTimeout.multipliedBy(2); + this.initializationTimeout = initializationTimeout; // Request Handlers Map> requestHandlers = new HashMap<>(); @@ -367,12 +369,13 @@ private Mono withInitializationCheck(String actionName, /** * Sends a ping request to the server. - * @return A Mono that completes with the server's ping response + * @return A Mono that completes when the server responds to the ping */ - public Mono ping() { + public Mono ping() { return this.withInitializationCheck("pinging the server", initializedResult -> this.mcpSession .sendRequest(McpSchema.METHOD_PING, null, new TypeReference() { - })); + }) + .then()); } // -------------------------- @@ -771,7 +774,9 @@ private NotificationHandler asyncLoggingNotificationHandler( * @see McpSchema.LoggingLevel */ public Mono setLoggingLevel(LoggingLevel loggingLevel) { - Assert.notNull(loggingLevel, "Logging level must not be null"); + if (loggingLevel == null) { + return Mono.error(new McpError("Logging level must not be null")); + } return this.withInitializationCheck("setting logging level", initializedResult -> { String levelName = this.transport.unmarshalFrom(loggingLevel, new TypeReference() { diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java index 7ab01b70..fa2690dc 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java @@ -157,6 +157,8 @@ class SyncSpec { private Duration requestTimeout = Duration.ofSeconds(20); // Default timeout + private Duration initializationTimeout = Duration.ofSeconds(20); + private ClientCapabilities capabilities; private Implementation clientInfo = new Implementation("Java SDK MCP Client", "1.0.0"); @@ -193,6 +195,18 @@ public SyncSpec requestTimeout(Duration requestTimeout) { return this; } + /** + * @param initializationTimeout The duration to wait for the initializaiton + * lifecycle step to complete. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if initializationTimeout is null + */ + public SyncSpec initializationTimeout(Duration initializationTimeout) { + Assert.notNull(initializationTimeout, "Initialization timeout must not be null"); + this.initializationTimeout = initializationTimeout; + return this; + } + /** * Sets the client capabilities that will be advertised to the server during * connection initialization. Capabilities define what features the client @@ -354,7 +368,8 @@ public McpSyncClient build() { McpClientFeatures.Async asyncFeatures = McpClientFeatures.Async.fromSync(syncFeatures); - return new McpSyncClient(new McpAsyncClient(transport, this.requestTimeout, asyncFeatures)); + return new McpSyncClient( + new McpAsyncClient(transport, this.requestTimeout, this.initializationTimeout, asyncFeatures)); } } @@ -381,6 +396,8 @@ class AsyncSpec { private Duration requestTimeout = Duration.ofSeconds(20); // Default timeout + private Duration initializationTimeout = Duration.ofSeconds(20); + private ClientCapabilities capabilities; private Implementation clientInfo = new Implementation("Spring AI MCP Client", "0.3.1"); @@ -417,6 +434,18 @@ public AsyncSpec requestTimeout(Duration requestTimeout) { return this; } + /** + * @param initializationTimeout The duration to wait for the initializaiton + * lifecycle step to complete. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if initializationTimeout is null + */ + public AsyncSpec initializationTimeout(Duration initializationTimeout) { + Assert.notNull(initializationTimeout, "Initialization timeout must not be null"); + this.initializationTimeout = initializationTimeout; + return this; + } + /** * Sets the client capabilities that will be advertised to the server during * connection initialization. Capabilities define what features the client @@ -574,7 +603,7 @@ public AsyncSpec loggingConsumers( * @return a new instance of {@link McpAsyncClient}. */ public McpAsyncClient build() { - return new McpAsyncClient(this.transport, this.requestTimeout, + return new McpAsyncClient(this.transport, this.requestTimeout, this.initializationTimeout, new McpClientFeatures.Async(this.clientInfo, this.capabilities, this.roots, this.toolsChangeConsumers, this.resourcesChangeConsumers, this.promptsChangeConsumers, this.loggingConsumers, this.samplingHandler)); diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java index e5d964b7..41f71d05 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java @@ -179,11 +179,10 @@ public void removeRoot(String rootUri) { } /** - * Send a synchronous ping request. - * @return + * Send a synchronous ping request to the server. */ - public Object ping() { - return this.delegate.ping().block(); + public void ping() { + this.delegate.ping().block(); } // -------------------------- diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 661c629e..969c3a86 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -59,7 +59,11 @@ protected void onStart() { protected void onClose() { } - protected Duration getTimeoutDuration() { + protected Duration getRequestTimeout() { + return Duration.ofSeconds(10); + } + + protected Duration getInitializationTimeout() { return Duration.ofSeconds(2); } @@ -70,7 +74,8 @@ void setUp() { assertThatCode(() -> { mcpAsyncClient = McpClient.async(mcpTransport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) + .initializationTimeout(getInitializationTimeout()) .capabilities(ClientCapabilities.builder().roots(true).build()) .build(); }).doesNotThrowAnyException(); @@ -79,105 +84,110 @@ void setUp() { @AfterEach void tearDown() { if (mcpAsyncClient != null) { - assertThatCode(() -> mcpAsyncClient.closeGracefully().block(Duration.ofSeconds(10))) - .doesNotThrowAnyException(); + StepVerifier.create(mcpAsyncClient.closeGracefully()).verifyComplete(); } onClose(); } @Test void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpClient.sync(null).build()).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> McpClient.async(null).build()).isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport must not be null"); - assertThatThrownBy(() -> McpClient.sync(mcpTransport).requestTimeout(null).build()) + assertThatThrownBy(() -> McpClient.async(mcpTransport).requestTimeout(null).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Request timeout must not be null"); } @Test void testListToolsWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.listTools(null).block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing tools"); + StepVerifier.create(mcpAsyncClient.listTools(null)).expectErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before listing tools"); + }).verify(); } @Test void testListTools() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - - StepVerifier.create(mcpAsyncClient.listTools(null)).consumeNextWith(result -> { - assertThat(result.tools()).isNotNull().isNotEmpty(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listTools(null))) + .consumeNextWith(result -> { + assertThat(result.tools()).isNotNull().isNotEmpty(); - Tool firstTool = result.tools().get(0); - assertThat(firstTool.name()).isNotNull(); - assertThat(firstTool.description()).isNotNull(); - }).verifyComplete(); + Tool firstTool = result.tools().get(0); + assertThat(firstTool.name()).isNotNull(); + assertThat(firstTool.description()).isNotNull(); + }) + .verifyComplete(); } @Test void testPingWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.ping().block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before pinging the server"); + StepVerifier.create(mcpAsyncClient.ping()) + .expectErrorMatches(error -> error instanceof McpError + && error.getMessage().equals("Client must be initialized before pinging the server")) + .verify(); } @Test void testPing() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - assertThatCode(() -> mcpAsyncClient.ping().block()).doesNotThrowAnyException(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.ping())).verifyComplete(); } @Test void testCallToolWithoutInitialization() { CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); - assertThatThrownBy(() -> mcpAsyncClient.callTool(callToolRequest).block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before calling tools"); + StepVerifier.create(mcpAsyncClient.callTool(callToolRequest)) + .expectErrorMatches(error -> error instanceof McpError + && error.getMessage().equals("Client must be initialized before calling tools")) + .verify(); } @Test void testCallTool() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); - StepVerifier.create(mcpAsyncClient.callTool(callToolRequest)).consumeNextWith(callToolResult -> { - assertThat(callToolResult).isNotNull().satisfies(result -> { - assertThat(result.content()).isNotNull(); - assertThat(result.isError()).isNull(); - }); - }).verifyComplete(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(callToolRequest))) + .consumeNextWith(callToolResult -> { + assertThat(callToolResult).isNotNull(); + assertThat(callToolResult.content()).isNotNull(); + assertThat(callToolResult.isError()).isNull(); + }) + .verifyComplete(); } @Test void testCallToolWithInvalidTool() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", ECHO_TEST_MESSAGE)); - assertThatThrownBy(() -> mcpAsyncClient.callTool(invalidRequest).block()).isInstanceOf(Exception.class); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(invalidRequest))) + .expectError(Exception.class) + .verify(); } @Test void testListResourcesWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.listResources(null).block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing resources"); + StepVerifier.create(mcpAsyncClient.listResources(null)) + .expectErrorMatches(error -> error instanceof McpError + && error.getMessage().equals("Client must be initialized before listing resources")) + .verify(); } @Test void testListResources() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - - StepVerifier.create(mcpAsyncClient.listResources(null)).consumeNextWith(resources -> { - assertThat(resources).isNotNull().satisfies(result -> { - assertThat(result.resources()).isNotNull(); - - if (!result.resources().isEmpty()) { - Resource firstResource = result.resources().get(0); - assertThat(firstResource.uri()).isNotNull(); - assertThat(firstResource.name()).isNotNull(); - } - }); - }).verifyComplete(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResources(null))) + .consumeNextWith(resources -> { + assertThat(resources).isNotNull().satisfies(result -> { + assertThat(result.resources()).isNotNull(); + + if (!result.resources().isEmpty()) { + Resource firstResource = result.resources().get(0); + assertThat(firstResource.uri()).isNotNull(); + assertThat(firstResource.name()).isNotNull(); + } + }); + }) + .verifyComplete(); } @Test @@ -187,40 +197,44 @@ void testMcpAsyncClientState() { @Test void testListPromptsWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.listPrompts(null).block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing prompts"); + StepVerifier.create(mcpAsyncClient.listPrompts(null)) + .expectErrorMatches(error -> error instanceof McpError + && error.getMessage().equals("Client must be initialized before listing prompts")) + .verify(); } @Test void testListPrompts() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - - StepVerifier.create(mcpAsyncClient.listPrompts(null)).consumeNextWith(prompts -> { - assertThat(prompts).isNotNull().satisfies(result -> { - assertThat(result.prompts()).isNotNull(); - - if (!result.prompts().isEmpty()) { - Prompt firstPrompt = result.prompts().get(0); - assertThat(firstPrompt.name()).isNotNull(); - assertThat(firstPrompt.description()).isNotNull(); - } - }); - }).verifyComplete(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listPrompts(null))) + .consumeNextWith(prompts -> { + assertThat(prompts).isNotNull().satisfies(result -> { + assertThat(result.prompts()).isNotNull(); + + if (!result.prompts().isEmpty()) { + Prompt firstPrompt = result.prompts().get(0); + assertThat(firstPrompt.name()).isNotNull(); + assertThat(firstPrompt.description()).isNotNull(); + } + }); + }) + .verifyComplete(); } @Test void testGetPromptWithoutInitialization() { GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); - assertThatThrownBy(() -> mcpAsyncClient.getPrompt(request).block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before getting prompts"); + StepVerifier.create(mcpAsyncClient.getPrompt(request)) + .expectErrorMatches(error -> error instanceof McpError + && error.getMessage().equals("Client must be initialized before getting prompts")) + .verify(); } @Test void testGetPrompt() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); + GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); - StepVerifier.create(mcpAsyncClient.getPrompt(new GetPromptRequest("simple_prompt", Map.of()))) + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.getPrompt(request))) .consumeNextWith(prompt -> { assertThat(prompt).isNotNull().satisfies(result -> { assertThat(result.messages()).isNotEmpty(); @@ -232,15 +246,16 @@ void testGetPrompt() { @Test void testRootsListChangedWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.rootsListChangedNotification().block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before sending roots list changed notification"); + StepVerifier.create(mcpAsyncClient.rootsListChangedNotification()) + .expectErrorMatches(error -> error instanceof McpError && error.getMessage() + .equals("Client must be initialized before sending roots list changed notification")) + .verify(); } @Test void testRootsListChanged() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - - assertThatCode(() -> mcpAsyncClient.rootsListChangedNotification().block()).doesNotThrowAnyException(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.rootsListChangedNotification())) + .verifyComplete(); } @Test @@ -248,39 +263,39 @@ void testInitializeWithRootsListProviders() { var transport = createMcpTransport(); var client = McpClient.async(transport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) .roots(new Root("file:///test/path", "test-root")) .build(); - assertThatCode(() -> client.initialize().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - - assertThatCode(() -> client.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); } @Test void testAddRoot() { Root newRoot = new Root("file:///new/test/path", "new-test-root"); - assertThatCode(() -> mcpAsyncClient.addRoot(newRoot).block()).doesNotThrowAnyException(); + + StepVerifier.create(mcpAsyncClient.addRoot(newRoot)).verifyComplete(); } @Test void testAddRootWithNullValue() { - assertThatThrownBy(() -> mcpAsyncClient.addRoot(null).block()).hasMessageContaining("Root must not be null"); + StepVerifier.create(mcpAsyncClient.addRoot(null)) + .expectErrorMatches(error -> error.getMessage().contains("Root must not be null")) + .verify(); } @Test void testRemoveRoot() { Root root = new Root("file:///test/path/to/remove", "root-to-remove"); - assertThatCode(() -> { - mcpAsyncClient.addRoot(root).block(); - mcpAsyncClient.removeRoot(root.uri()).block(); - }).doesNotThrowAnyException(); + + StepVerifier.create(mcpAsyncClient.addRoot(root).then(mcpAsyncClient.removeRoot(root.uri()))).verifyComplete(); } @Test void testRemoveNonExistentRoot() { - assertThatThrownBy(() -> mcpAsyncClient.removeRoot("nonexistent-uri").block()) - .hasMessageContaining("Root with uri 'nonexistent-uri' not found"); + StepVerifier.create(mcpAsyncClient.removeRoot("nonexistent-uri")) + .expectErrorMatches(error -> error.getMessage().contains("Root with uri 'nonexistent-uri' not found")) + .verify(); } @Test @@ -299,18 +314,20 @@ void testReadResource() { @Test void testListResourceTemplatesWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.listResourceTemplates().block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing resource templates"); + StepVerifier.create(mcpAsyncClient.listResourceTemplates()) + .expectErrorMatches(error -> error instanceof McpError + && error.getMessage().equals("Client must be initialized before listing resource templates")) + .verify(); } @Test void testListResourceTemplates() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - - StepVerifier.create(mcpAsyncClient.listResourceTemplates()).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.resourceTemplates()).isNotNull(); - }).verifyComplete(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResourceTemplates())) + .consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.resourceTemplates()).isNotNull(); + }) + .verifyComplete(); } // @Test @@ -338,16 +355,13 @@ void testNotificationHandlers() { var transport = createMcpTransport(); var client = McpClient.async(transport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true))) .resourcesChangeConsumer(resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true))) .promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true))) .build(); - assertThatCode(() -> { - client.initialize().block(); - client.closeGracefully().block(); - }).doesNotThrowAnyException(); + StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); } @Test @@ -357,15 +371,12 @@ void testInitializeWithSamplingCapability() { var capabilities = ClientCapabilities.builder().sampling().build(); var client = McpClient.async(transport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) .capabilities(capabilities) .sampling(request -> Mono.just(CreateMessageResult.builder().message("test").model("test-model").build())) .build(); - assertThatCode(() -> { - client.initialize().block(Duration.ofSeconds(10)); - client.closeGracefully().block(Duration.ofSeconds(10)); - }).doesNotThrowAnyException(); + StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); } @Test @@ -381,17 +392,17 @@ void testInitializeWithAllCapabilities() { Function> samplingHandler = request -> Mono .just(CreateMessageResult.builder().message("test").model("test-model").build()); var client = McpClient.async(transport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) .capabilities(capabilities) .sampling(samplingHandler) .build(); - assertThatCode(() -> { - var result = client.initialize().block(Duration.ofSeconds(10)); + StepVerifier.create(client.initialize()).consumeNextWith(result -> { assertThat(result).isNotNull(); assertThat(result.capabilities()).isNotNull(); - client.closeGracefully().block(Duration.ofSeconds(10)); - }).doesNotThrowAnyException(); + }).verifyComplete(); + + StepVerifier.create(client.closeGracefully()).verifyComplete(); } // --------------------------------------- @@ -400,19 +411,23 @@ void testInitializeWithAllCapabilities() { @Test void testLoggingLevelsWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.setLoggingLevel(McpSchema.LoggingLevel.DEBUG).block()) - .isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before setting logging level"); + StepVerifier.create(mcpAsyncClient.setLoggingLevel(McpSchema.LoggingLevel.DEBUG)) + .expectErrorMatches(error -> error instanceof McpError + && error.getMessage().equals("Client must be initialized before setting logging level")) + .verify(); } @Test void testLoggingLevels() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); + Mono testAllLevels = mcpAsyncClient.initialize().then(Mono.defer(() -> { + Mono chain = Mono.empty(); + for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { + chain = chain.then(mcpAsyncClient.setLoggingLevel(level)); + } + return chain; + })); - // Test all logging levels - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - StepVerifier.create(mcpAsyncClient.setLoggingLevel(level)).verifyComplete(); - } + StepVerifier.create(testAllLevels).verifyComplete(); } @Test @@ -421,20 +436,18 @@ void testLoggingConsumer() { var transport = createMcpTransport(); var client = McpClient.async(transport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) .loggingConsumer(notification -> Mono.fromRunnable(() -> logReceived.set(true))) .build(); - assertThatCode(() -> { - client.initialize().block(Duration.ofSeconds(10)); - client.closeGracefully().block(Duration.ofSeconds(10)); - }).doesNotThrowAnyException(); + StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); } @Test void testLoggingWithNullNotification() { - assertThatThrownBy(() -> mcpAsyncClient.setLoggingLevel(null).block()) - .hasMessageContaining("Logging level must not be null"); + StepVerifier.create(mcpAsyncClient.setLoggingLevel(null)) + .expectErrorMatches(error -> error.getMessage().contains("Logging level must not be null")) + .verify(); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index 6f8cf198..a866bfb3 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -53,7 +53,11 @@ public abstract class AbstractMcpSyncClientTests { abstract protected void onClose(); - protected Duration getTimeoutDuration() { + protected Duration getRequestTimeout() { + return Duration.ofSeconds(10); + } + + protected Duration getInitializationTimeout() { return Duration.ofSeconds(2); } @@ -64,7 +68,8 @@ void setUp() { assertThatCode(() -> { mcpSyncClient = McpClient.sync(mcpTransport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) + .initializationTimeout(getInitializationTimeout()) .capabilities(ClientCapabilities.builder().roots(true).build()) .build(); }).doesNotThrowAnyException(); @@ -216,7 +221,7 @@ void testInitializeWithRootsListProviders() { var transport = createMcpTransport(); var client = McpClient.sync(transport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) .roots(new Root("file:///test/path", "test-root")) .build(); @@ -314,7 +319,7 @@ void testNotificationHandlers() { var transport = createMcpTransport(); var client = McpClient.sync(transport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) .toolsChangeConsumer(tools -> toolsNotificationReceived.set(true)) .resourcesChangeConsumer(resources -> resourcesNotificationReceived.set(true)) .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)) @@ -352,7 +357,7 @@ void testLoggingConsumer() { var transport = createMcpTransport(); var client = McpClient.sync(transport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) .loggingConsumer(notification -> logReceived.set(true)) .build(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/ServletSseMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java similarity index 89% rename from mcp/src/test/java/io/modelcontextprotocol/client/ServletSseMcpAsyncClientTests.java rename to mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java index 7cc673fa..ac0fef24 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/ServletSseMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java @@ -18,7 +18,7 @@ * @author Christian Tzolov */ @Timeout(15) // Giving extra time beyond the client timeout -class ServletSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { +class HttpSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { String host = "http://localhost:3004"; @@ -46,9 +46,4 @@ protected void onClose() { container.stop(); } - @Override - protected Duration getTimeoutDuration() { - return Duration.ofMillis(300); - } - } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/ServletSseMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java similarity index 89% rename from mcp/src/test/java/io/modelcontextprotocol/client/ServletSseMcpSyncClientTests.java rename to mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java index 2b8af41a..8772e620 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/ServletSseMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java @@ -18,7 +18,7 @@ * @author Christian Tzolov */ @Timeout(15) // Giving extra time beyond the client timeout -class ServletSseMcpSyncClientTests extends AbstractMcpSyncClientTests { +class HttpSseMcpSyncClientTests extends AbstractMcpSyncClientTests { String host = "http://localhost:3003"; @@ -46,9 +46,4 @@ protected void onClose() { container.stop(); } - @Override - protected Duration getTimeoutDuration() { - return Duration.ofMillis(300); - } - } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java index 7ae65253..3517008c 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java @@ -39,7 +39,7 @@ void customErrorHandlerShouldReceiveErrors() { ((StdioClientTransport) mcpTransport).setStdErrorHandler(error -> receivedError.set(error)); String errorMessage = "Test error"; - ((StdioClientTransport) mcpTransport).getErrorSink().tryEmitNext(errorMessage); + ((StdioClientTransport) mcpTransport).getErrorSink().emitNext(errorMessage, null); assertThat(receivedError.get()).isNotNull().isEqualTo(errorMessage); } From 67bbce0c0bd728c3de44d454c1bfe4bc78dc2f1c Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Wed, 19 Mar 2025 08:27:38 +0100 Subject: [PATCH 2/4] adjust test timeout values Signed-off-by: Christian Tzolov --- .../client/WebFluxSseMcpAsyncClientTests.java | 6 ++++++ .../client/WebFluxSseMcpSyncClientTests.java | 6 ++++++ .../client/AbstractMcpAsyncClientTests.java | 6 +++--- .../client/AbstractMcpSyncClientTests.java | 8 +++++--- .../client/AbstractMcpSyncClientTests.java | 6 ++++-- .../client/StdioMcpAsyncClientTests.java | 6 ++++++ .../client/StdioMcpSyncClientTests.java | 9 +++------ 7 files changed, 33 insertions(+), 14 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java index 021ce465..0dccb27a 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java @@ -4,6 +4,8 @@ package io.modelcontextprotocol.client; +import java.time.Duration; + import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; import io.modelcontextprotocol.spec.ClientMcpTransport; import org.junit.jupiter.api.Timeout; @@ -46,4 +48,8 @@ public void onClose() { container.stop(); } + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(1); + } + } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java index 20eeb1d5..f5cab7b7 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java @@ -4,6 +4,8 @@ package io.modelcontextprotocol.client; +import java.time.Duration; + import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; import io.modelcontextprotocol.spec.ClientMcpTransport; import org.junit.jupiter.api.Timeout; @@ -46,4 +48,8 @@ protected void onClose() { container.stop(); } + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(1); + } + } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 17cc9960..2aa659ca 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -63,7 +63,7 @@ protected Duration getRequestTimeout() { } protected Duration getInitializationTimeout() { - return Duration.ofSeconds(1); + return Duration.ofSeconds(2); } @BeforeEach @@ -90,10 +90,10 @@ void tearDown() { @Test void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpClient.sync(null).build()).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> McpClient.async(null).build()).isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport must not be null"); - assertThatThrownBy(() -> McpClient.sync(mcpTransport).requestTimeout(null).build()) + assertThatThrownBy(() -> McpClient.async(mcpTransport).requestTimeout(null).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Request timeout must not be null"); } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index ee43a572..d1b752fc 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -48,16 +48,18 @@ public abstract class AbstractMcpSyncClientTests { abstract protected ClientMcpTransport createMcpTransport(); - abstract protected void onStart(); + protected void onStart() { + } - abstract protected void onClose(); + protected void onClose() { + } protected Duration getRequestTimeout() { return Duration.ofSeconds(10); } protected Duration getInitializationTimeout() { - return Duration.ofSeconds(1); + return Duration.ofSeconds(2); } @BeforeEach diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index a866bfb3..726632f3 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -49,9 +49,11 @@ public abstract class AbstractMcpSyncClientTests { abstract protected ClientMcpTransport createMcpTransport(); - abstract protected void onStart(); + protected void onStart() { + } - abstract protected void onClose(); + protected void onClose() { + } protected Duration getRequestTimeout() { return Duration.ofSeconds(10); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java index ce74812b..c285e2c6 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java @@ -4,6 +4,8 @@ package io.modelcontextprotocol.client; +import java.time.Duration; + import io.modelcontextprotocol.client.transport.ServerParameters; import io.modelcontextprotocol.client.transport.StdioClientTransport; import io.modelcontextprotocol.spec.ClientMcpTransport; @@ -26,4 +28,8 @@ protected ClientMcpTransport createMcpTransport() { return new StdioClientTransport(stdioParams); } + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(6); + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java index 3517008c..ec351623 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java @@ -4,6 +4,7 @@ package io.modelcontextprotocol.client; +import java.time.Duration; import java.util.concurrent.atomic.AtomicReference; import io.modelcontextprotocol.client.transport.ServerParameters; @@ -44,12 +45,8 @@ void customErrorHandlerShouldReceiveErrors() { assertThat(receivedError.get()).isNotNull().isEqualTo(errorMessage); } - @Override - protected void onStart() { - } - - @Override - protected void onClose() { + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(6); } } From f534ff8ceeb57a3e0f47abe7cfeef551a1f2b958 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Wed, 19 Mar 2025 10:04:37 +0100 Subject: [PATCH 3/4] Address review comments Signed-off-by: Christian Tzolov --- .../client/AbstractMcpAsyncClientTests.java | 3 ++- .../io/modelcontextprotocol/client/McpAsyncClient.java | 7 +++---- .../java/io/modelcontextprotocol/client/McpSyncClient.java | 7 ++++--- .../client/AbstractMcpAsyncClientTests.java | 3 ++- .../client/StdioMcpSyncClientTests.java | 3 ++- 5 files changed, 13 insertions(+), 10 deletions(-) diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 2aa659ca..91dd223c 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -129,7 +129,8 @@ void testPingWithoutInitialization() { @Test void testPing() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.ping())).verifyComplete(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.ping())).consumeNextWith(callToolResult -> { + }).verifyComplete(); } @Test diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index 4c5fd02c..278e360d 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -369,13 +369,12 @@ private Mono withInitializationCheck(String actionName, /** * Sends a ping request to the server. - * @return A Mono that completes when the server responds to the ping + * @return A Mono that completes with the server's ping response */ - public Mono ping() { + public Mono ping() { return this.withInitializationCheck("pinging the server", initializedResult -> this.mcpSession .sendRequest(McpSchema.METHOD_PING, null, new TypeReference() { - }) - .then()); + })); } // -------------------------- diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java index 41f71d05..e5d964b7 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java @@ -179,10 +179,11 @@ public void removeRoot(String rootUri) { } /** - * Send a synchronous ping request to the server. + * Send a synchronous ping request. + * @return */ - public void ping() { - this.delegate.ping().block(); + public Object ping() { + return this.delegate.ping().block(); } // -------------------------- diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 969c3a86..1bc40c52 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -130,7 +130,8 @@ void testPingWithoutInitialization() { @Test void testPing() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.ping())).verifyComplete(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.ping())).consumeNextWith(callToolResult -> { + }).verifyComplete(); } @Test diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java index ec351623..6d759b4b 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java @@ -12,6 +12,7 @@ import io.modelcontextprotocol.spec.ClientMcpTransport; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; +import reactor.core.publisher.Sinks; import static org.assertj.core.api.Assertions.assertThat; @@ -40,7 +41,7 @@ void customErrorHandlerShouldReceiveErrors() { ((StdioClientTransport) mcpTransport).setStdErrorHandler(error -> receivedError.set(error)); String errorMessage = "Test error"; - ((StdioClientTransport) mcpTransport).getErrorSink().emitNext(errorMessage, null); + ((StdioClientTransport) mcpTransport).getErrorSink().emitNext(errorMessage, Sinks.EmitFailureHandler.FAIL_FAST); assertThat(receivedError.get()).isNotNull().isEqualTo(errorMessage); } From 7004cdbd351a9b1bb21ff0529a2f81127e64f6dd Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Wed, 19 Mar 2025 10:26:14 +0100 Subject: [PATCH 4/4] Increase the request timeout to 14 sec Signed-off-by: Christian Tzolov --- .../client/AbstractMcpAsyncClientTests.java | 2 +- .../modelcontextprotocol/client/AbstractMcpSyncClientTests.java | 2 +- .../client/AbstractMcpAsyncClientTests.java | 2 +- .../modelcontextprotocol/client/AbstractMcpSyncClientTests.java | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 91dd223c..a8a59a63 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -59,7 +59,7 @@ protected void onClose() { } protected Duration getRequestTimeout() { - return Duration.ofSeconds(10); + return Duration.ofSeconds(14); } protected Duration getInitializationTimeout() { diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index d1b752fc..0f83e31e 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -55,7 +55,7 @@ protected void onClose() { } protected Duration getRequestTimeout() { - return Duration.ofSeconds(10); + return Duration.ofSeconds(14); } protected Duration getInitializationTimeout() { diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 1bc40c52..39bc4995 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -60,7 +60,7 @@ protected void onClose() { } protected Duration getRequestTimeout() { - return Duration.ofSeconds(10); + return Duration.ofSeconds(14); } protected Duration getInitializationTimeout() { diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index 726632f3..52a0138f 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -56,7 +56,7 @@ protected void onClose() { } protected Duration getRequestTimeout() { - return Duration.ofSeconds(10); + return Duration.ofSeconds(14); } protected Duration getInitializationTimeout() {