messageEndpointSink = Sinks.one();
-
- /**
- * The SSE endpoint URI provided by the server. Used for sending outbound messages via
- * HTTP POST requests.
- */
- private String sseEndpoint;
-
- /**
- * Constructs a new SseClientTransport with the specified WebClient builder. Uses a
- * default ObjectMapper instance for JSON processing.
- * @param webClientBuilder the WebClient.Builder to use for creating the WebClient
- * instance
- * @throws IllegalArgumentException if webClientBuilder is null
- */
- public WebFluxSseClientTransport(WebClient.Builder webClientBuilder) {
- this(webClientBuilder, new ObjectMapper());
- }
-
- /**
- * Constructs a new SseClientTransport with the specified WebClient builder and
- * ObjectMapper. Initializes both inbound and outbound message processing pipelines.
- * @param webClientBuilder the WebClient.Builder to use for creating the WebClient
- * instance
- * @param objectMapper the ObjectMapper to use for JSON processing
- * @throws IllegalArgumentException if either parameter is null
- */
- public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, ObjectMapper objectMapper) {
- this(webClientBuilder, objectMapper, DEFAULT_SSE_ENDPOINT);
- }
-
- /**
- * Constructs a new SseClientTransport with the specified WebClient builder and
- * ObjectMapper. Initializes both inbound and outbound message processing pipelines.
- * @param webClientBuilder the WebClient.Builder to use for creating the WebClient
- * instance
- * @param objectMapper the ObjectMapper to use for JSON processing
- * @param sseEndpoint the SSE endpoint URI to use for establishing the connection
- * @throws IllegalArgumentException if either parameter is null
- */
- public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, ObjectMapper objectMapper,
- String sseEndpoint) {
- Assert.notNull(objectMapper, "ObjectMapper must not be null");
- Assert.notNull(webClientBuilder, "WebClient.Builder must not be null");
- Assert.hasText(sseEndpoint, "SSE endpoint must not be null or empty");
-
- this.objectMapper = objectMapper;
- this.webClient = webClientBuilder.build();
- this.sseEndpoint = sseEndpoint;
- }
-
- /**
- * Establishes a connection to the MCP server using Server-Sent Events (SSE). This
- * method initiates the SSE connection and sets up the message processing pipeline.
- *
- *
- * The connection process follows these steps:
- *
- * Establishes an SSE connection to the server's /sse endpoint
- * Waits for the server to send an 'endpoint' event with the message posting
- * URI
- * Sets up message handling for incoming JSON-RPC messages
- *
- *
- *
- * The connection is considered established only after receiving the endpoint event
- * from the server.
- * @param handler a function that processes incoming JSON-RPC messages and returns
- * responses
- * @return a Mono that completes when the connection is fully established
- * @throws McpError if there's an error processing SSE events or if an unrecognized
- * event type is received
- */
- @Override
- public Mono connect(Function, Mono> handler) {
- Flux> events = eventStream();
- this.inboundSubscription = events.concatMap(event -> Mono.just(event).handle((e, s) -> {
- if (ENDPOINT_EVENT_TYPE.equals(event.event())) {
- String messageEndpointUri = event.data();
- if (messageEndpointSink.tryEmitValue(messageEndpointUri).isSuccess()) {
- s.complete();
- }
- else {
- // TODO: clarify with the spec if multiple events can be
- // received
- s.error(new McpError("Failed to handle SSE endpoint event"));
- }
- }
- else if (MESSAGE_EVENT_TYPE.equals(event.event())) {
- try {
- JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, event.data());
- s.next(message);
- }
- catch (IOException ioException) {
- s.error(ioException);
- }
- }
- else {
- s.error(new McpError("Received unrecognized SSE event type: " + event.event()));
- }
- }).transform(handler)).subscribe();
-
- // The connection is established once the server sends the endpoint event
- return messageEndpointSink.asMono().then();
- }
-
- /**
- * Sends a JSON-RPC message to the server using the endpoint provided during
- * connection.
- *
- *
- * Messages are sent via HTTP POST requests to the server-provided endpoint URI. The
- * message is serialized to JSON before transmission. If the transport is in the
- * process of closing, the message send operation is skipped gracefully.
- * @param message the JSON-RPC message to send
- * @return a Mono that completes when the message has been sent successfully
- * @throws RuntimeException if message serialization fails
- */
- @Override
- public Mono sendMessage(JSONRPCMessage message) {
- // The messageEndpoint is the endpoint URI to send the messages
- // It is provided by the server as part of the endpoint event
- return messageEndpointSink.asMono().flatMap(messageEndpointUri -> {
- if (isClosing) {
- return Mono.empty();
- }
- try {
- String jsonText = this.objectMapper.writeValueAsString(message);
- return webClient.post()
- .uri(messageEndpointUri)
- .contentType(MediaType.APPLICATION_JSON)
- .bodyValue(jsonText)
- .retrieve()
- .toBodilessEntity()
- .doOnSuccess(response -> {
- logger.debug("Message sent successfully");
- })
- .doOnError(error -> {
- if (!isClosing) {
- logger.error("Error sending message: {}", error.getMessage());
- }
- });
- }
- catch (IOException e) {
- if (!isClosing) {
- return Mono.error(new RuntimeException("Failed to serialize message", e));
- }
- return Mono.empty();
- }
- }).then(); // TODO: Consider non-200-ok response
- }
-
- /**
- * Initializes and starts the inbound SSE event processing. Establishes the SSE
- * connection and sets up event handling for both message and endpoint events.
- * Includes automatic retry logic for handling transient connection failures.
- */
- // visible for tests
- protected Flux> eventStream() {// @formatter:off
- return this.webClient
- .get()
- .uri(this.sseEndpoint)
- .accept(MediaType.TEXT_EVENT_STREAM)
- .retrieve()
- .bodyToFlux(SSE_TYPE)
- .retryWhen(Retry.from(retrySignal -> retrySignal.handle(inboundRetryHandler)));
- } // @formatter:on
-
- /**
- * Retry handler for the inbound SSE stream. Implements the retry logic for handling
- * connection failures and other errors.
- */
- private BiConsumer> inboundRetryHandler = (retrySpec, sink) -> {
- if (isClosing) {
- logger.debug("SSE connection closed during shutdown");
- sink.error(retrySpec.failure());
- return;
- }
- if (retrySpec.failure() instanceof IOException) {
- logger.debug("Retrying SSE connection after IO error");
- sink.next(retrySpec);
- return;
- }
- logger.error("Fatal SSE error, not retrying: {}", retrySpec.failure().getMessage());
- sink.error(retrySpec.failure());
- };
-
- /**
- * Implements graceful shutdown of the transport. Cleans up all resources including
- * subscriptions and schedulers. Ensures orderly shutdown of both inbound and outbound
- * message processing.
- * @return a Mono that completes when shutdown is finished
- */
- @Override
- public Mono closeGracefully() { // @formatter:off
- return Mono.fromRunnable(() -> {
- isClosing = true;
-
- // Dispose of subscriptions
-
- if (inboundSubscription != null) {
- inboundSubscription.dispose();
- }
-
- })
- .then()
- .subscribeOn(Schedulers.boundedElastic());
- } // @formatter:on
-
- /**
- * Unmarshalls data from a generic Object into the specified type using the configured
- * ObjectMapper.
- *
- *
- * This method is particularly useful when working with JSON-RPC parameters or result
- * objects that need to be converted to specific Java types. It leverages Jackson's
- * type conversion capabilities to handle complex object structures.
- * @param the target type to convert the data into
- * @param data the source object to convert
- * @param typeRef the TypeReference describing the target type
- * @return the unmarshalled object of type T
- * @throws IllegalArgumentException if the conversion cannot be performed
- */
- @Override
- public T unmarshalFrom(Object data, TypeReference typeRef) {
- return this.objectMapper.convertValue(data, typeRef);
- }
-
- /**
- * Creates a new builder for {@link WebFluxSseClientTransport}.
- * @param webClientBuilder the WebClient.Builder to use for creating the WebClient
- * instance
- * @return a new builder instance
- */
- public static Builder builder(WebClient.Builder webClientBuilder) {
- return new Builder(webClientBuilder);
- }
-
- /**
- * Builder for {@link WebFluxSseClientTransport}.
- */
- public static class Builder {
-
- private final WebClient.Builder webClientBuilder;
-
- private String sseEndpoint = DEFAULT_SSE_ENDPOINT;
-
- private ObjectMapper objectMapper = new ObjectMapper();
-
- /**
- * Creates a new builder with the specified WebClient.Builder.
- * @param webClientBuilder the WebClient.Builder to use
- */
- public Builder(WebClient.Builder webClientBuilder) {
- Assert.notNull(webClientBuilder, "WebClient.Builder must not be null");
- this.webClientBuilder = webClientBuilder;
- }
-
- /**
- * Sets the SSE endpoint path.
- * @param sseEndpoint the SSE endpoint path
- * @return this builder
- */
- public Builder sseEndpoint(String sseEndpoint) {
- Assert.hasText(sseEndpoint, "sseEndpoint must not be empty");
- this.sseEndpoint = sseEndpoint;
- return this;
- }
-
- /**
- * Sets the object mapper for JSON serialization/deserialization.
- * @param objectMapper the object mapper
- * @return this builder
- */
- public Builder objectMapper(ObjectMapper objectMapper) {
- Assert.notNull(objectMapper, "objectMapper must not be null");
- this.objectMapper = objectMapper;
- return this;
- }
-
- /**
- * Builds a new {@link WebFluxSseClientTransport} instance.
- * @return a new transport instance
- */
- public WebFluxSseClientTransport build() {
- return new WebFluxSseClientTransport(webClientBuilder, objectMapper, sseEndpoint);
- }
-
- }
-
-}
diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportProvider.java
new file mode 100644
index 00000000..2033682d
--- /dev/null
+++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportProvider.java
@@ -0,0 +1,462 @@
+package io.modelcontextprotocol.client.transport;
+
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import io.modelcontextprotocol.spec.McpClientSession;
+import io.modelcontextprotocol.spec.McpClientTransport;
+import io.modelcontextprotocol.spec.McpClientTransportProvider;
+import io.modelcontextprotocol.spec.McpError;
+import io.modelcontextprotocol.spec.McpSchema;
+import io.modelcontextprotocol.util.Assert;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.core.ParameterizedTypeReference;
+import org.springframework.http.MediaType;
+import org.springframework.http.codec.ServerSentEvent;
+import org.springframework.web.reactive.function.client.WebClient;
+import reactor.core.Disposable;
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.Mono;
+import reactor.core.publisher.Sinks;
+import reactor.core.publisher.SynchronousSink;
+import reactor.core.scheduler.Schedulers;
+import reactor.util.retry.Retry;
+
+import java.io.IOException;
+import java.util.function.BiConsumer;
+import java.util.function.Function;
+
+import static io.modelcontextprotocol.util.Utils.getSessionIdFromUrl;
+
+/**
+ * Server-Sent Events (SSE) implementation of the
+ * {@link io.modelcontextprotocol.spec.McpClientTransportProvider} that provides SSE
+ * client transport that follows the MCP HTTP with SSE transport specification.
+ *
+ *
+ * This transport establishes a bidirectional communication channel where:
+ *
+ * Inbound messages are received through an SSE connection from the server
+ * Outbound messages are sent via HTTP POST requests to a server-provided
+ * endpoint
+ *
+ *
+ *
+ * The message flow follows these steps:
+ *
+ * The client establishes an SSE connection to the server's /sse endpoint
+ * The server sends an 'endpoint' event containing the URI for sending messages
+ *
+ *
+ * This implementation uses {@link WebClient} for HTTP communications and supports JSON
+ * serialization/deserialization of messages.
+ *
+ * @author Christian Tzolov
+ * @author Jermaine Hua
+ * @see MCP
+ * HTTP with SSE Transport Specification
+ */
+public class WebFluxSseClientTransportProvider implements McpClientTransportProvider {
+
+ private static final Logger logger = LoggerFactory.getLogger(WebFluxSseClientTransportProvider.class);
+
+ /**
+ * Event type for JSON-RPC messages received through the SSE connection. The server
+ * sends messages with this event type to transmit JSON-RPC protocol data.
+ */
+ private static final String MESSAGE_EVENT_TYPE = "message";
+
+ /**
+ * Event type for receiving the message endpoint URI from the server. The server MUST
+ * send this event when a client connects, providing the URI where the client should
+ * send its messages via HTTP POST.
+ */
+ private static final String ENDPOINT_EVENT_TYPE = "endpoint";
+
+ /**
+ * Default SSE endpoint path as specified by the MCP transport specification. This
+ * endpoint is used to establish the SSE connection with the server.
+ */
+ private static final String DEFAULT_SSE_ENDPOINT = "/sse";
+
+ /**
+ * Type reference for parsing SSE events containing string data.
+ */
+ private static final ParameterizedTypeReference> SSE_TYPE = new ParameterizedTypeReference<>() {
+ };
+
+ /**
+ * WebClient instance for handling both SSE connections and HTTP POST requests. Used
+ * for establishing the SSE connection and sending outbound messages.
+ */
+ private final WebClient webClient;
+
+ /**
+ * ObjectMapper for serializing outbound messages and deserializing inbound messages.
+ * Handles conversion between JSON-RPC messages and their string representation.
+ */
+ protected ObjectMapper objectMapper;
+
+ /**
+ * Flag indicating if the transport is in the process of shutting down. Used to
+ * prevent new operations during shutdown and handle cleanup gracefully.
+ */
+ private volatile boolean isClosing = false;
+
+ /**
+ * The SSE endpoint URI provided by the server. Used for sending outbound messages via
+ * HTTP POST requests.
+ */
+ private final String sseEndpoint;
+
+ /**
+ * Retry handler for the inbound SSE stream. Implements the retry logic for handling
+ * connection failures and other errors.
+ */
+ private final BiConsumer> inboundRetryHandler = (retrySpec, sink) -> {
+ if (isClosing) {
+ logger.debug("SSE connection closed during shutdown");
+ sink.error(retrySpec.failure());
+ return;
+ }
+ if (retrySpec.failure() instanceof IOException) {
+ logger.debug("Retrying SSE connection after IO error");
+ sink.next(retrySpec);
+ return;
+ }
+ logger.error("Fatal SSE error, not retrying: {}", retrySpec.failure().getMessage());
+ sink.error(retrySpec.failure());
+ };
+
+ /**
+ * Session factory for creating new sessions
+ */
+ private McpClientSession.Factory sessionFactory;
+
+ /**
+ * Active client session
+ */
+ private McpClientSession session;
+
+ /**
+ * Constructs a new SseClientTransportProvider with the specified WebClient builder.
+ * Uses a default ObjectMapper instance for JSON processing.
+ * @param webClientBuilder the WebClient.Builder to use for creating the WebClient
+ * instance
+ * @throws IllegalArgumentException if webClientBuilder is null
+ */
+ public WebFluxSseClientTransportProvider(WebClient.Builder webClientBuilder) {
+ this(webClientBuilder, new ObjectMapper());
+ }
+
+ /**
+ * Constructs a new SseClientTransportProvider with the specified WebClient builder
+ * and ObjectMapper. Initializes both inbound and outbound message processing
+ * pipelines.
+ * @param webClientBuilder the WebClient.Builder to use for creating the WebClient
+ * instance
+ * @param objectMapper the ObjectMapper to use for JSON processing
+ * @throws IllegalArgumentException if either parameter is null
+ */
+ public WebFluxSseClientTransportProvider(WebClient.Builder webClientBuilder, ObjectMapper objectMapper) {
+ this(webClientBuilder, objectMapper, DEFAULT_SSE_ENDPOINT);
+ }
+
+ /**
+ * Constructs a new SseClientTransportProvider with the specified WebClient builder
+ * and ObjectMapper. Initializes both inbound and outbound message processing
+ * pipelines.
+ * @param webClientBuilder the WebClient.Builder to use for creating the WebClient
+ * instance
+ * @param objectMapper the ObjectMapper to use for JSON processing
+ * @param sseEndpoint the SSE endpoint URI to use for establishing the connection
+ * @throws IllegalArgumentException if either parameter is null
+ */
+ public WebFluxSseClientTransportProvider(WebClient.Builder webClientBuilder, ObjectMapper objectMapper,
+ String sseEndpoint) {
+ Assert.notNull(objectMapper, "ObjectMapper must not be null");
+ Assert.notNull(webClientBuilder, "WebClient.Builder must not be null");
+ Assert.hasText(sseEndpoint, "SSE endpoint must not be null or empty");
+
+ this.objectMapper = objectMapper;
+ this.webClient = webClientBuilder.build();
+ this.sseEndpoint = sseEndpoint;
+ }
+
+ @Override
+ public void setSessionFactory(McpClientSession.Factory sessionFactory) {
+ this.sessionFactory = sessionFactory;
+ }
+
+ @Override
+ public McpClientSession getSession() {
+ if (session != null) {
+ return session;
+ }
+ McpClientTransport mcpClientTransport = new WebFluxSseClientTransport();
+ this.session = sessionFactory.create(mcpClientTransport);
+ return session;
+ }
+
+ @Override
+ public Mono closeGracefully() {
+ return session.closeGracefully();
+ }
+
+ /**
+ * Initializes and starts the inbound SSE event processing. Establishes the SSE
+ * connection and sets up event handling for both message and endpoint events.
+ * Includes automatic retry logic for handling transient connection failures.
+ */
+ // visible for tests
+ protected Flux> eventStream() {// @formatter:off
+ return webClient
+ .get()
+ .uri(sseEndpoint)
+ .accept(MediaType.TEXT_EVENT_STREAM)
+ .retrieve()
+ .bodyToFlux(SSE_TYPE)
+ .retryWhen(Retry.from(retrySignal -> retrySignal.handle(inboundRetryHandler)));
+ } // @formatter:on
+
+ protected class WebFluxSseClientTransport implements McpClientTransport {
+
+ /**
+ * Subscription for the SSE connection handling inbound messages. Used for cleanup
+ * during transport shutdown.
+ */
+ private Disposable inboundSubscription;
+
+ /**
+ * Sink for managing the message endpoint URI provided by the server. Stores the
+ * most recent endpoint URI and makes it available for outbound message
+ * processing.
+ */
+ protected final Sinks.One messageEndpointSink = Sinks.one();
+
+ public WebFluxSseClientTransport() {
+ }
+
+ /**
+ * Establishes a connection to the MCP server using Server-Sent Events (SSE). This
+ * method initiates the SSE connection and sets up the message processing
+ * pipeline.
+ *
+ *
+ * The connection process follows these steps:
+ *
+ * Establishes an SSE connection to the server's /sse endpoint
+ * Waits for the server to send an 'endpoint' event with the message posting
+ * URI
+ * Sets up message handling for incoming JSON-RPC messages
+ *
+ *
+ *
+ * The connection is considered established only after receiving the endpoint
+ * event from the server.
+ * @param handler a function that processes incoming JSON-RPC messages and returns
+ * responses
+ * @return a Mono that completes when the connection is fully established
+ * @throws McpError if there's an error processing SSE events or if an
+ * unrecognized event type is received
+ */
+ @Override
+ public Mono connect(Function, Mono> handler) {
+ return connect();
+ }
+
+ @Override
+ public Mono connect() {
+ Flux> events = eventStream();
+ inboundSubscription = events
+ .concatMap(event -> Mono.just(event).handle((e, s) -> {
+ if (ENDPOINT_EVENT_TYPE.equals(event.event())) {
+ String messageEndpointUri = event.data();
+ if (messageEndpointSink.tryEmitValue(messageEndpointUri).isSuccess()) {
+ session.setId(getSessionIdFromUrl(messageEndpointUri));
+ s.complete();
+ }
+ else {
+ // TODO: clarify with the spec if multiple events can be
+ // received
+ s.error(new McpError("Failed to handle SSE endpoint event"));
+ }
+ }
+ else if (MESSAGE_EVENT_TYPE.equals(event.event())) {
+ try {
+ McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper,
+ event.data());
+ session.handle(message).subscribe();
+ s.next(message);
+ }
+ catch (IOException ioException) {
+ s.error(ioException);
+ }
+ }
+ else {
+ s.error(new McpError("Received unrecognized SSE event type: " + event.event()));
+ }
+ }))
+ .subscribe();
+
+ // The connection is established once the server sends the endpoint event
+ return messageEndpointSink.asMono().then();
+ }
+
+ /**
+ * Sends a JSON-RPC message to the server using the endpoint provided during
+ * connection.
+ *
+ *
+ * Messages are sent via HTTP POST requests to the server-provided endpoint URI.
+ * The message is serialized to JSON before transmission. If the transport is in
+ * the process of closing, the message send operation is skipped gracefully.
+ * @param message the JSON-RPC message to send
+ * @return a Mono that completes when the message has been sent successfully
+ * @throws RuntimeException if message serialization fails
+ */
+ @Override
+ public Mono sendMessage(McpSchema.JSONRPCMessage message) {
+ // The messageEndpoint is the endpoint URI to send the messages
+ // It is provided by the server as part of the endpoint event
+ return messageEndpointSink.asMono().flatMap(messageEndpointUri -> {
+ if (isClosing) {
+ return Mono.empty();
+ }
+ try {
+ String jsonText = objectMapper.writeValueAsString(message);
+ return webClient.post()
+ .uri(messageEndpointUri)
+ .contentType(MediaType.APPLICATION_JSON)
+ .bodyValue(jsonText)
+ .retrieve()
+ .toBodilessEntity()
+ .doOnSuccess(response -> {
+ logger.debug("Message sent successfully");
+ })
+ .doOnError(error -> {
+ if (!isClosing) {
+ logger.error("Error sending message: {}", error.getMessage());
+ }
+ });
+ }
+ catch (IOException e) {
+ if (!isClosing) {
+ return Mono.error(new RuntimeException("Failed to serialize message", e));
+ }
+ return Mono.empty();
+ }
+ }).then(); // TODO: Consider non-200-ok response
+ }
+
+ /**
+ * Implements graceful shutdown of the transport. Cleans up all resources
+ * including subscriptions and schedulers. Ensures orderly shutdown of both
+ * inbound and outbound message processing.
+ * @return a Mono that completes when shutdown is finished
+ */
+ @Override
+ public Mono closeGracefully() { // @formatter:off
+ return Mono.fromRunnable(() -> {
+ isClosing = true;
+
+ // Dispose of subscriptions
+
+ if (inboundSubscription != null) {
+ inboundSubscription.dispose();
+ }
+
+ })
+ .then()
+ .subscribeOn(Schedulers.boundedElastic());
+ } // @formatter:on
+
+ /**
+ * Unmarshalls data from a generic Object into the specified type using the
+ * configured ObjectMapper.
+ *
+ *
+ * This method is particularly useful when working with JSON-RPC parameters or
+ * result objects that need to be converted to specific Java types. It leverages
+ * Jackson's type conversion capabilities to handle complex object structures.
+ * @param the target type to convert the data into
+ * @param data the source object to convert
+ * @param typeRef the TypeReference describing the target type
+ * @return the unmarshalled object of type T
+ * @throws IllegalArgumentException if the conversion cannot be performed
+ */
+ @Override
+ public T unmarshalFrom(Object data, TypeReference typeRef) {
+ return objectMapper.convertValue(data, typeRef);
+ }
+
+ }
+
+ /**
+ * Creates a new builder for
+ * {@link io.modelcontextprotocol.client.transport.WebFluxSseClientTransportProvider}.
+ * @param webClientBuilder the WebClient.Builder to use for creating the WebClient
+ * instance
+ * @return a new builder instance
+ */
+ public static WebFluxSseClientTransportProvider.Builder builder(WebClient.Builder webClientBuilder) {
+ return new WebFluxSseClientTransportProvider.Builder(webClientBuilder);
+ }
+
+ /**
+ * Builder for
+ * {@link io.modelcontextprotocol.client.transport.WebFluxSseClientTransportProvider}.
+ */
+ public static class Builder {
+
+ private final WebClient.Builder webClientBuilder;
+
+ private String sseEndpoint = DEFAULT_SSE_ENDPOINT;
+
+ private ObjectMapper objectMapper = new ObjectMapper();
+
+ /**
+ * Creates a new builder with the specified WebClient.Builder.
+ * @param webClientBuilder the WebClient.Builder to use
+ */
+ public Builder(WebClient.Builder webClientBuilder) {
+ Assert.notNull(webClientBuilder, "WebClient.Builder must not be null");
+ this.webClientBuilder = webClientBuilder;
+ }
+
+ /**
+ * Sets the SSE endpoint path.
+ * @param sseEndpoint the SSE endpoint path
+ * @return this builder
+ */
+ public WebFluxSseClientTransportProvider.Builder sseEndpoint(String sseEndpoint) {
+ Assert.hasText(sseEndpoint, "sseEndpoint must not be empty");
+ this.sseEndpoint = sseEndpoint;
+ return this;
+ }
+
+ /**
+ * Sets the object mapper for JSON serialization/deserialization.
+ * @param objectMapper the object mapper
+ * @return this builder
+ */
+ public WebFluxSseClientTransportProvider.Builder objectMapper(ObjectMapper objectMapper) {
+ Assert.notNull(objectMapper, "objectMapper must not be null");
+ this.objectMapper = objectMapper;
+ return this;
+ }
+
+ /**
+ * Builds a new
+ * {@link io.modelcontextprotocol.client.transport.WebFluxSseClientTransportProvider}
+ * instance.
+ * @return a new transport instance
+ */
+ public WebFluxSseClientTransportProvider build() {
+ return new WebFluxSseClientTransportProvider(webClientBuilder, objectMapper, sseEndpoint);
+ }
+
+ }
+
+}
diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java
index 76f908b8..02ca8223 100644
--- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java
+++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java
@@ -14,8 +14,8 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.client.McpClient;
-import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
-import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport;
+import io.modelcontextprotocol.client.transport.HttpClientSseClientTransportProvider;
+import io.modelcontextprotocol.client.transport.WebFluxSseClientTransportProvider;
import io.modelcontextprotocol.server.McpServer;
import io.modelcontextprotocol.server.McpServerFeatures;
import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider;
@@ -78,14 +78,14 @@ public void before() {
this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow();
clientBuilders.put("httpclient",
- McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT)
+ McpClient.sync(HttpClientSseClientTransportProvider.builder("http://localhost:" + PORT)
.sseEndpoint(CUSTOM_SSE_ENDPOINT)
.build()));
clientBuilders.put("webflux",
- McpClient
- .sync(WebFluxSseClientTransport.builder(WebClient.builder().baseUrl("http://localhost:" + PORT))
- .sseEndpoint(CUSTOM_SSE_ENDPOINT)
- .build()));
+ McpClient.sync(WebFluxSseClientTransportProvider
+ .builder(WebClient.builder().baseUrl("http://localhost:" + PORT))
+ .sseEndpoint(CUSTOM_SSE_ENDPOINT)
+ .build()));
}
diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java
index b43c1449..12d5dfef 100644
--- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java
+++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java
@@ -6,8 +6,8 @@
import java.time.Duration;
-import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport;
-import io.modelcontextprotocol.spec.McpClientTransport;
+import io.modelcontextprotocol.client.transport.WebFluxSseClientTransportProvider;
+import io.modelcontextprotocol.spec.McpClientTransportProvider;
import org.junit.jupiter.api.Timeout;
import org.testcontainers.containers.GenericContainer;
import org.testcontainers.containers.wait.strategy.Wait;
@@ -15,8 +15,6 @@
import org.springframework.web.reactive.function.client.WebClient;
/**
- * Tests for the {@link McpAsyncClient} with {@link WebFluxSseClientTransport}.
- *
* @author Christian Tzolov
*/
@Timeout(15) // Giving extra time beyond the client timeout
@@ -31,9 +29,8 @@ class WebFluxSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests {
.withExposedPorts(3001)
.waitingFor(Wait.forHttp("/").forStatusCode(404));
- @Override
- protected McpClientTransport createMcpTransport() {
- return WebFluxSseClientTransport.builder(WebClient.builder().baseUrl(host)).build();
+ protected McpClientTransportProvider createMcpClientTransportProvider() {
+ return new WebFluxSseClientTransportProvider(WebClient.builder().baseUrl(host));
}
@Override
diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java
index 66ac8a6d..76c2dc40 100644
--- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java
+++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java
@@ -6,8 +6,8 @@
import java.time.Duration;
-import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport;
-import io.modelcontextprotocol.spec.McpClientTransport;
+import io.modelcontextprotocol.client.transport.WebFluxSseClientTransportProvider;
+import io.modelcontextprotocol.spec.McpClientTransportProvider;
import org.junit.jupiter.api.Timeout;
import org.testcontainers.containers.GenericContainer;
import org.testcontainers.containers.wait.strategy.Wait;
@@ -15,8 +15,6 @@
import org.springframework.web.reactive.function.client.WebClient;
/**
- * Tests for the {@link McpSyncClient} with {@link WebFluxSseClientTransport}.
- *
* @author Christian Tzolov
*/
@Timeout(15) // Giving extra time beyond the client timeout
@@ -32,8 +30,8 @@ class WebFluxSseMcpSyncClientTests extends AbstractMcpSyncClientTests {
.waitingFor(Wait.forHttp("/").forStatusCode(404));
@Override
- protected McpClientTransport createMcpTransport() {
- return WebFluxSseClientTransport.builder(WebClient.builder().baseUrl(host)).build();
+ protected McpClientTransportProvider createMcpClientTransportProvider() {
+ return new WebFluxSseClientTransportProvider(WebClient.builder().baseUrl(host));
}
@Override
diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java
index c757d3da..87714baf 100644
--- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java
+++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java
@@ -9,7 +9,10 @@
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
+import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
+import io.modelcontextprotocol.spec.McpClientSession;
+import io.modelcontextprotocol.spec.McpClientTransport;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest;
import org.junit.jupiter.api.AfterEach;
@@ -31,8 +34,6 @@
import static org.assertj.core.api.Assertions.assertThatThrownBy;
/**
- * Tests for the {@link WebFluxSseClientTransport} class.
- *
* @author Christian Tzolov
*/
@Timeout(15)
@@ -46,20 +47,22 @@ class WebFluxSseClientTransportTests {
.withExposedPorts(3001)
.waitingFor(Wait.forHttp("/").forStatusCode(404));
- private TestSseClientTransport transport;
+ private TestSseClientTransportProvider transportProvider;
+
+ private McpClientTransport transport;
private WebClient.Builder webClientBuilder;
private ObjectMapper objectMapper;
// Test class to access protected methods
- static class TestSseClientTransport extends WebFluxSseClientTransport {
+ static class TestSseClientTransportProvider extends WebFluxSseClientTransportProvider {
private final AtomicInteger inboundMessageCount = new AtomicInteger(0);
private Sinks.Many> events = Sinks.many().unicast().onBackpressureBuffer();
- public TestSseClientTransport(WebClient.Builder webClientBuilder, ObjectMapper objectMapper) {
+ public TestSseClientTransportProvider(WebClient.Builder webClientBuilder, ObjectMapper objectMapper) {
super(webClientBuilder, objectMapper);
}
@@ -69,7 +72,7 @@ protected Flux> eventStream() {
}
public String getLastEndpoint() {
- return messageEndpointSink.asMono().block();
+ return ((WebFluxSseClientTransport) getSession().getTransport()).messageEndpointSink.asMono().block();
}
public int getInboundMessageCount() {
@@ -99,7 +102,10 @@ void setUp() {
startContainer();
webClientBuilder = WebClient.builder().baseUrl(host);
objectMapper = new ObjectMapper();
- transport = new TestSseClientTransport(webClientBuilder, objectMapper);
+ transportProvider = new TestSseClientTransportProvider(webClientBuilder, objectMapper);
+ transportProvider.setSessionFactory(
+ (transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of()));
+ transport = transportProvider.getSession().getTransport();
transport.connect(Function.identity()).block();
}
@@ -117,15 +123,16 @@ void cleanup() {
@Test
void testEndpointEventHandling() {
- assertThat(transport.getLastEndpoint()).startsWith("/message?");
+ assertThat(transportProvider.getLastEndpoint()).startsWith("/message?");
}
@Test
void constructorValidation() {
- assertThatThrownBy(() -> new WebFluxSseClientTransport(null)).isInstanceOf(IllegalArgumentException.class)
+ assertThatThrownBy(() -> new WebFluxSseClientTransportProvider(null))
+ .isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("WebClient.Builder must not be null");
- assertThatThrownBy(() -> new WebFluxSseClientTransport(webClientBuilder, null))
+ assertThatThrownBy(() -> new WebFluxSseClientTransportProvider(webClientBuilder, null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("ObjectMapper must not be null");
}
@@ -133,28 +140,45 @@ void constructorValidation() {
@Test
void testBuilderPattern() {
// Test default builder
- WebFluxSseClientTransport transport1 = WebFluxSseClientTransport.builder(webClientBuilder).build();
- assertThatCode(() -> transport1.closeGracefully().block()).doesNotThrowAnyException();
+ WebFluxSseClientTransportProvider transportProvider1 = WebFluxSseClientTransportProvider
+ .builder(webClientBuilder)
+ .build();
+ transportProvider1.setSessionFactory(
+ (transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of()));
+ transportProvider1.getSession();
+ assertThatCode(() -> transportProvider1.closeGracefully().block()).doesNotThrowAnyException();
// Test builder with custom ObjectMapper
ObjectMapper customMapper = new ObjectMapper();
- WebFluxSseClientTransport transport2 = WebFluxSseClientTransport.builder(webClientBuilder)
+ WebFluxSseClientTransportProvider transportProvider2 = WebFluxSseClientTransportProvider
+ .builder(webClientBuilder)
.objectMapper(customMapper)
.build();
- assertThatCode(() -> transport2.closeGracefully().block()).doesNotThrowAnyException();
+ transportProvider2.setSessionFactory(
+ (transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of()));
+ transportProvider2.getSession();
+ assertThatCode(() -> transportProvider2.closeGracefully().block()).doesNotThrowAnyException();
// Test builder with custom SSE endpoint
- WebFluxSseClientTransport transport3 = WebFluxSseClientTransport.builder(webClientBuilder)
+ WebFluxSseClientTransportProvider transportProvider3 = WebFluxSseClientTransportProvider
+ .builder(webClientBuilder)
.sseEndpoint("/custom-sse")
.build();
- assertThatCode(() -> transport3.closeGracefully().block()).doesNotThrowAnyException();
+ transportProvider3.setSessionFactory(
+ (transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of()));
+ transportProvider3.getSession();
+ assertThatCode(() -> transportProvider3.closeGracefully().block()).doesNotThrowAnyException();
// Test builder with all custom parameters
- WebFluxSseClientTransport transport4 = WebFluxSseClientTransport.builder(webClientBuilder)
+ WebFluxSseClientTransportProvider transportProvider4 = WebFluxSseClientTransportProvider
+ .builder(webClientBuilder)
.objectMapper(customMapper)
.sseEndpoint("/custom-sse")
.build();
- assertThatCode(() -> transport4.closeGracefully().block()).doesNotThrowAnyException();
+ transportProvider4.setSessionFactory(
+ (transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of()));
+ transportProvider4.getSession();
+ assertThatCode(() -> transportProvider4.closeGracefully().block()).doesNotThrowAnyException();
}
@Test
@@ -164,7 +188,7 @@ void testMessageProcessing() {
Map.of("key", "value"));
// Simulate receiving the message
- transport.simulateMessageEvent("""
+ transportProvider.simulateMessageEvent("""
{
"jsonrpc": "2.0",
"method": "test-method",
@@ -176,13 +200,13 @@ void testMessageProcessing() {
// Subscribe to messages and verify
StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete();
- assertThat(transport.getInboundMessageCount()).isEqualTo(1);
+ assertThat(transportProvider.getInboundMessageCount()).isEqualTo(1);
}
@Test
void testResponseMessageProcessing() {
// Simulate receiving a response message
- transport.simulateMessageEvent("""
+ transportProvider.simulateMessageEvent("""
{
"jsonrpc": "2.0",
"id": "test-id",
@@ -197,13 +221,13 @@ void testResponseMessageProcessing() {
// Verify message handling
StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete();
- assertThat(transport.getInboundMessageCount()).isEqualTo(1);
+ assertThat(transportProvider.getInboundMessageCount()).isEqualTo(1);
}
@Test
void testErrorMessageProcessing() {
// Simulate receiving an error message
- transport.simulateMessageEvent("""
+ transportProvider.simulateMessageEvent("""
{
"jsonrpc": "2.0",
"id": "test-id",
@@ -221,13 +245,13 @@ void testErrorMessageProcessing() {
// Verify message handling
StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete();
- assertThat(transport.getInboundMessageCount()).isEqualTo(1);
+ assertThat(transportProvider.getInboundMessageCount()).isEqualTo(1);
}
@Test
void testNotificationMessageProcessing() {
// Simulate receiving a notification message (no id)
- transport.simulateMessageEvent("""
+ transportProvider.simulateMessageEvent("""
{
"jsonrpc": "2.0",
"method": "update",
@@ -236,7 +260,7 @@ void testNotificationMessageProcessing() {
""");
// Verify the notification was processed
- assertThat(transport.getInboundMessageCount()).isEqualTo(1);
+ assertThat(transportProvider.getInboundMessageCount()).isEqualTo(1);
}
@Test
@@ -252,7 +276,7 @@ void testGracefulShutdown() {
StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete();
// Message count should remain 0 after shutdown
- assertThat(transport.getInboundMessageCount()).isEqualTo(0);
+ assertThat(transportProvider.getInboundMessageCount()).isEqualTo(0);
}
@Test
@@ -260,19 +284,23 @@ void testRetryBehavior() {
// Create a WebClient that simulates connection failures
WebClient.Builder failingWebClientBuilder = WebClient.builder().baseUrl("http://non-existent-host");
- WebFluxSseClientTransport failingTransport = WebFluxSseClientTransport.builder(failingWebClientBuilder).build();
+ WebFluxSseClientTransportProvider failingTransportProvider = WebFluxSseClientTransportProvider
+ .builder(failingWebClientBuilder)
+ .build();
+ failingTransportProvider.setSessionFactory(
+ (transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of()));
// Verify that the transport attempts to reconnect
StepVerifier.create(Mono.delay(Duration.ofSeconds(2))).expectNextCount(1).verifyComplete();
// Clean up
- failingTransport.closeGracefully().block();
+ failingTransportProvider.getSession().getTransport().closeGracefully().block();
}
@Test
void testMultipleMessageProcessing() {
// Simulate receiving multiple messages in sequence
- transport.simulateMessageEvent("""
+ transportProvider.simulateMessageEvent("""
{
"jsonrpc": "2.0",
"method": "method1",
@@ -281,7 +309,7 @@ void testMultipleMessageProcessing() {
}
""");
- transport.simulateMessageEvent("""
+ transportProvider.simulateMessageEvent("""
{
"jsonrpc": "2.0",
"method": "method2",
@@ -301,13 +329,13 @@ void testMultipleMessageProcessing() {
StepVerifier.create(transport.sendMessage(message1).then(transport.sendMessage(message2))).verifyComplete();
// Verify message count
- assertThat(transport.getInboundMessageCount()).isEqualTo(2);
+ assertThat(transportProvider.getInboundMessageCount()).isEqualTo(2);
}
@Test
void testMessageOrderPreservation() {
// Simulate receiving messages in a specific order
- transport.simulateMessageEvent("""
+ transportProvider.simulateMessageEvent("""
{
"jsonrpc": "2.0",
"method": "first",
@@ -316,7 +344,7 @@ void testMessageOrderPreservation() {
}
""");
- transport.simulateMessageEvent("""
+ transportProvider.simulateMessageEvent("""
{
"jsonrpc": "2.0",
"method": "second",
@@ -325,7 +353,7 @@ void testMessageOrderPreservation() {
}
""");
- transport.simulateMessageEvent("""
+ transportProvider.simulateMessageEvent("""
{
"jsonrpc": "2.0",
"method": "third",
@@ -335,7 +363,7 @@ void testMessageOrderPreservation() {
""");
// Verify message count and order
- assertThat(transport.getInboundMessageCount()).isEqualTo(3);
+ assertThat(transportProvider.getInboundMessageCount()).isEqualTo(3);
}
}
diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java
index 0e81104b..96664b72 100644
--- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java
+++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java
@@ -5,7 +5,7 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.client.McpClient;
-import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
+import io.modelcontextprotocol.client.transport.HttpClientSseClientTransportProvider;
import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider;
import io.modelcontextprotocol.spec.McpSchema;
import org.apache.catalina.LifecycleException;
@@ -49,7 +49,7 @@ public void before() {
throw new RuntimeException("Failed to start Tomcat", e);
}
- var clientTransport = HttpClientSseClientTransport.builder("http://localhost:" + PORT)
+ var clientTransport = HttpClientSseClientTransportProvider.builder("http://localhost:" + PORT)
.sseEndpoint(CUSTOM_CONTEXT_PATH + WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT)
.build();
diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java
index be01365a..bfc69150 100644
--- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java
+++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java
@@ -11,7 +11,7 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.client.McpClient;
-import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
+import io.modelcontextprotocol.client.transport.HttpClientSseClientTransportProvider;
import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
@@ -85,7 +85,8 @@ public void before() {
throw new RuntimeException("Failed to start Tomcat", e);
}
- clientBuilder = McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT).build());
+ clientBuilder = McpClient
+ .sync(HttpClientSseClientTransportProvider.builder("http://localhost:" + PORT).build());
// Get the transport from Spring context
mcpServerTransportProvider = tomcatServer.appContext().getBean(WebMvcSseServerTransportProvider.class);
diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java
index 5452c8ea..3e8cfbaa 100644
--- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java
+++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java
@@ -12,7 +12,7 @@
import java.util.function.Consumer;
import java.util.function.Function;
-import io.modelcontextprotocol.spec.McpClientTransport;
+import io.modelcontextprotocol.spec.McpClientTransportProvider;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.CallToolRequest;
@@ -50,7 +50,7 @@ public abstract class AbstractMcpAsyncClientTests {
private static final String ECHO_TEST_MESSAGE = "Hello MCP Spring AI!";
- abstract protected McpClientTransport createMcpTransport();
+ abstract protected McpClientTransportProvider createMcpClientTransportProvider();
protected void onStart() {
}
@@ -66,11 +66,12 @@ protected Duration getInitializationTimeout() {
return Duration.ofSeconds(2);
}
- McpAsyncClient client(McpClientTransport transport) {
+ McpAsyncClient client(McpClientTransportProvider transport) {
return client(transport, Function.identity());
}
- McpAsyncClient client(McpClientTransport transport, Function customizer) {
+ McpAsyncClient client(McpClientTransportProvider transport,
+ Function customizer) {
AtomicReference client = new AtomicReference<>();
assertThatCode(() -> {
@@ -85,11 +86,11 @@ McpAsyncClient client(McpClientTransport transport, Function c) {
+ void withClient(McpClientTransportProvider transport, Consumer c) {
withClient(transport, Function.identity(), c);
}
- void withClient(McpClientTransport transport, Function customizer,
+ void withClient(McpClientTransportProvider transport, Function customizer,
Consumer c) {
var client = client(transport, customizer);
try {
@@ -111,7 +112,7 @@ void tearDown() {
}
void verifyInitializationTimeout(Function> operation, String action) {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
StepVerifier.withVirtualTime(() -> operation.apply(mcpAsyncClient))
.expectSubscription()
.thenAwait(getInitializationTimeout())
@@ -126,7 +127,7 @@ void testConstructorWithInvalidArguments() {
assertThatThrownBy(() -> McpClient.async(null).build()).isInstanceOf(IllegalArgumentException.class)
.hasMessage("Transport must not be null");
- assertThatThrownBy(() -> McpClient.async(createMcpTransport()).requestTimeout(null).build())
+ assertThatThrownBy(() -> McpClient.async(createMcpClientTransportProvider()).requestTimeout(null).build())
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Request timeout must not be null");
}
@@ -138,7 +139,7 @@ void testListToolsWithoutInitialization() {
@Test
void testListTools() {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listTools(null)))
.consumeNextWith(result -> {
assertThat(result.tools()).isNotNull().isNotEmpty();
@@ -158,7 +159,7 @@ void testPingWithoutInitialization() {
@Test
void testPing() {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.ping()))
.expectNextCount(1)
.verifyComplete();
@@ -173,7 +174,7 @@ void testCallToolWithoutInitialization() {
@Test
void testCallTool() {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE));
StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(callToolRequest)))
@@ -189,7 +190,7 @@ void testCallTool() {
@Test
void testCallToolWithInvalidTool() {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool",
Map.of("message", ECHO_TEST_MESSAGE));
@@ -207,7 +208,7 @@ void testListResourcesWithoutInitialization() {
@Test
void testListResources() {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResources(null)))
.consumeNextWith(resources -> {
assertThat(resources).isNotNull().satisfies(result -> {
@@ -226,7 +227,7 @@ void testListResources() {
@Test
void testMcpAsyncClientState() {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
assertThat(mcpAsyncClient).isNotNull();
});
}
@@ -238,7 +239,7 @@ void testListPromptsWithoutInitialization() {
@Test
void testListPrompts() {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listPrompts(null)))
.consumeNextWith(prompts -> {
assertThat(prompts).isNotNull().satisfies(result -> {
@@ -263,7 +264,7 @@ void testGetPromptWithoutInitialization() {
@Test
void testGetPrompt() {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
StepVerifier
.create(mcpAsyncClient.initialize()
.then(mcpAsyncClient.getPrompt(new GetPromptRequest("simple_prompt", Map.of()))))
@@ -285,7 +286,7 @@ void testRootsListChangedWithoutInitialization() {
@Test
void testRootsListChanged() {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.rootsListChangedNotification()))
.verifyComplete();
});
@@ -293,15 +294,15 @@ void testRootsListChanged() {
@Test
void testInitializeWithRootsListProviders() {
- withClient(createMcpTransport(), builder -> builder.roots(new Root("file:///test/path", "test-root")),
- client -> {
+ withClient(createMcpClientTransportProvider(),
+ builder -> builder.roots(new Root("file:///test/path", "test-root")), client -> {
StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete();
});
}
@Test
void testAddRoot() {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
Root newRoot = new Root("file:///new/test/path", "new-test-root");
StepVerifier.create(mcpAsyncClient.addRoot(newRoot)).verifyComplete();
});
@@ -309,7 +310,7 @@ void testAddRoot() {
@Test
void testAddRootWithNullValue() {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
StepVerifier.create(mcpAsyncClient.addRoot(null))
.consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class).hasMessage("Root must not be null"))
.verify();
@@ -318,7 +319,7 @@ void testAddRootWithNullValue() {
@Test
void testRemoveRoot() {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
Root root = new Root("file:///test/path/to/remove", "root-to-remove");
StepVerifier.create(mcpAsyncClient.addRoot(root)).verifyComplete();
@@ -328,7 +329,7 @@ void testRemoveRoot() {
@Test
void testRemoveNonExistentRoot() {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
StepVerifier.create(mcpAsyncClient.removeRoot("nonexistent-uri"))
.consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class)
.hasMessage("Root with uri 'nonexistent-uri' not found"))
@@ -339,7 +340,7 @@ void testRemoveNonExistentRoot() {
@Test
@Disabled
void testReadResource() {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> {
if (!resources.resources().isEmpty()) {
Resource firstResource = resources.resources().get(0);
@@ -359,7 +360,7 @@ void testListResourceTemplatesWithoutInitialization() {
@Test
void testListResourceTemplates() {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResourceTemplates()))
.consumeNextWith(result -> {
assertThat(result).isNotNull();
@@ -371,7 +372,7 @@ void testListResourceTemplates() {
// @Test
void testResourceSubscription() {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> {
if (!resources.resources().isEmpty()) {
Resource firstResource = resources.resources().get(0);
@@ -394,7 +395,7 @@ void testNotificationHandlers() {
AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false);
AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false);
- withClient(createMcpTransport(),
+ withClient(createMcpClientTransportProvider(),
builder -> builder
.toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true)))
.resourcesChangeConsumer(
@@ -414,7 +415,7 @@ void testInitializeWithSamplingCapability() {
.message("test")
.model("test-model")
.build();
- withClient(createMcpTransport(),
+ withClient(createMcpClientTransportProvider(),
builder -> builder.capabilities(capabilities).sampling(request -> Mono.just(createMessageResult)),
client -> {
StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete();
@@ -432,8 +433,8 @@ void testInitializeWithAllCapabilities() {
Function> samplingHandler = request -> Mono
.just(CreateMessageResult.builder().message("test").model("test-model").build());
- withClient(createMcpTransport(), builder -> builder.capabilities(capabilities).sampling(samplingHandler),
- client ->
+ withClient(createMcpClientTransportProvider(),
+ builder -> builder.capabilities(capabilities).sampling(samplingHandler), client ->
StepVerifier.create(client.initialize()).assertNext(result -> {
assertThat(result).isNotNull();
@@ -453,7 +454,7 @@ void testLoggingLevelsWithoutInitialization() {
@Test
void testLoggingLevels() {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
StepVerifier
.create(mcpAsyncClient.initialize()
.thenMany(Flux.fromArray(McpSchema.LoggingLevel.values()).flatMap(mcpAsyncClient::setLoggingLevel)))
@@ -465,7 +466,7 @@ void testLoggingLevels() {
void testLoggingConsumer() {
AtomicBoolean logReceived = new AtomicBoolean(false);
- withClient(createMcpTransport(),
+ withClient(createMcpClientTransportProvider(),
builder -> builder.loggingConsumer(notification -> Mono.fromRunnable(() -> logReceived.set(true))),
client -> {
StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete();
@@ -477,7 +478,7 @@ void testLoggingConsumer() {
@Test
void testLoggingWithNullNotification() {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
StepVerifier.create(mcpAsyncClient.setLoggingLevel(null))
.expectErrorMatches(error -> error.getMessage().contains("Logging level must not be null"))
.verify();
diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java
index 128441f8..0fe30b9d 100644
--- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java
+++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java
@@ -11,7 +11,7 @@
import java.util.function.Consumer;
import java.util.function.Function;
-import io.modelcontextprotocol.spec.McpClientTransport;
+import io.modelcontextprotocol.spec.McpClientTransportProvider;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.CallToolRequest;
@@ -49,7 +49,7 @@ public abstract class AbstractMcpSyncClientTests {
private static final String TEST_MESSAGE = "Hello MCP Spring AI!";
- abstract protected McpClientTransport createMcpTransport();
+ abstract protected McpClientTransportProvider createMcpClientTransportProvider();
protected void onStart() {
}
@@ -65,11 +65,12 @@ protected Duration getInitializationTimeout() {
return Duration.ofSeconds(2);
}
- McpSyncClient client(McpClientTransport transport) {
+ McpSyncClient client(McpClientTransportProvider transport) {
return client(transport, Function.identity());
}
- McpSyncClient client(McpClientTransport transport, Function customizer) {
+ McpSyncClient client(McpClientTransportProvider transport,
+ Function customizer) {
AtomicReference client = new AtomicReference<>();
assertThatCode(() -> {
@@ -84,11 +85,11 @@ McpSyncClient client(McpClientTransport transport, Function c) {
+ void withClient(McpClientTransportProvider transport, Consumer c) {
withClient(transport, Function.identity(), c);
}
- void withClient(McpClientTransport transport, Function customizer,
+ void withClient(McpClientTransportProvider transport, Function customizer,
Consumer c) {
var client = client(transport, customizer);
try {
@@ -120,7 +121,7 @@ void verifyNotificationTimesOut(Consumer operation, String ac
}
void verifyCallTimesOut(Function blockingOperation, String action) {
- withClient(createMcpTransport(), mcpSyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpSyncClient -> {
// This scheduler is not replaced by virtual time scheduler
Scheduler customScheduler = Schedulers.newBoundedElastic(1, 1, "actualBoundedElastic");
@@ -147,7 +148,7 @@ void testConstructorWithInvalidArguments() {
assertThatThrownBy(() -> McpClient.sync(null).build()).isInstanceOf(IllegalArgumentException.class)
.hasMessage("Transport must not be null");
- assertThatThrownBy(() -> McpClient.sync(createMcpTransport()).requestTimeout(null).build())
+ assertThatThrownBy(() -> McpClient.sync(createMcpClientTransportProvider()).requestTimeout(null).build())
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Request timeout must not be null");
}
@@ -159,7 +160,7 @@ void testListToolsWithoutInitialization() {
@Test
void testListTools() {
- withClient(createMcpTransport(), mcpSyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpSyncClient -> {
mcpSyncClient.initialize();
ListToolsResult tools = mcpSyncClient.listTools(null);
@@ -181,7 +182,7 @@ void testCallToolsWithoutInitialization() {
@Test
void testCallTools() {
- withClient(createMcpTransport(), mcpSyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpSyncClient -> {
mcpSyncClient.initialize();
CallToolResult toolResult = mcpSyncClient.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4)));
@@ -205,7 +206,7 @@ void testPingWithoutInitialization() {
@Test
void testPing() {
- withClient(createMcpTransport(), mcpSyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpSyncClient -> {
mcpSyncClient.initialize();
assertThatCode(() -> mcpSyncClient.ping()).doesNotThrowAnyException();
});
@@ -219,7 +220,7 @@ void testCallToolWithoutInitialization() {
@Test
void testCallTool() {
- withClient(createMcpTransport(), mcpSyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpSyncClient -> {
mcpSyncClient.initialize();
CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE));
@@ -234,7 +235,7 @@ void testCallTool() {
@Test
void testCallToolWithInvalidTool() {
- withClient(createMcpTransport(), mcpSyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpSyncClient -> {
CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", TEST_MESSAGE));
assertThatThrownBy(() -> mcpSyncClient.callTool(invalidRequest)).isInstanceOf(Exception.class);
@@ -249,7 +250,7 @@ void testRootsListChangedWithoutInitialization() {
@Test
void testRootsListChanged() {
- withClient(createMcpTransport(), mcpSyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpSyncClient -> {
mcpSyncClient.initialize();
assertThatCode(() -> mcpSyncClient.rootsListChangedNotification()).doesNotThrowAnyException();
});
@@ -262,7 +263,7 @@ void testListResourcesWithoutInitialization() {
@Test
void testListResources() {
- withClient(createMcpTransport(), mcpSyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpSyncClient -> {
mcpSyncClient.initialize();
ListResourcesResult resources = mcpSyncClient.listResources(null);
@@ -280,15 +281,15 @@ void testListResources() {
@Test
void testClientSessionState() {
- withClient(createMcpTransport(), mcpSyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpSyncClient -> {
assertThat(mcpSyncClient).isNotNull();
});
}
@Test
void testInitializeWithRootsListProviders() {
- withClient(createMcpTransport(), builder -> builder.roots(new Root("file:///test/path", "test-root")),
- mcpSyncClient -> {
+ withClient(createMcpClientTransportProvider(),
+ builder -> builder.roots(new Root("file:///test/path", "test-root")), mcpSyncClient -> {
assertThatCode(() -> {
mcpSyncClient.initialize();
@@ -299,7 +300,7 @@ void testInitializeWithRootsListProviders() {
@Test
void testAddRoot() {
- withClient(createMcpTransport(), mcpSyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpSyncClient -> {
Root newRoot = new Root("file:///new/test/path", "new-test-root");
assertThatCode(() -> mcpSyncClient.addRoot(newRoot)).doesNotThrowAnyException();
});
@@ -307,14 +308,14 @@ void testAddRoot() {
@Test
void testAddRootWithNullValue() {
- withClient(createMcpTransport(), mcpSyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpSyncClient -> {
assertThatThrownBy(() -> mcpSyncClient.addRoot(null)).hasMessageContaining("Root must not be null");
});
}
@Test
void testRemoveRoot() {
- withClient(createMcpTransport(), mcpSyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpSyncClient -> {
Root root = new Root("file:///test/path/to/remove", "root-to-remove");
assertThatCode(() -> {
mcpSyncClient.addRoot(root);
@@ -325,7 +326,7 @@ void testRemoveRoot() {
@Test
void testRemoveNonExistentRoot() {
- withClient(createMcpTransport(), mcpSyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpSyncClient -> {
assertThatThrownBy(() -> mcpSyncClient.removeRoot("nonexistent-uri"))
.hasMessageContaining("Root with uri 'nonexistent-uri' not found");
});
@@ -339,7 +340,7 @@ void testReadResourceWithoutInitialization() {
@Test
void testReadResource() {
- withClient(createMcpTransport(), mcpSyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpSyncClient -> {
mcpSyncClient.initialize();
ListResourcesResult resources = mcpSyncClient.listResources(null);
@@ -360,7 +361,7 @@ void testListResourceTemplatesWithoutInitialization() {
@Test
void testListResourceTemplates() {
- withClient(createMcpTransport(), mcpSyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpSyncClient -> {
mcpSyncClient.initialize();
ListResourceTemplatesResult result = mcpSyncClient.listResourceTemplates(null);
@@ -371,7 +372,7 @@ void testListResourceTemplates() {
// @Test
void testResourceSubscription() {
- withClient(createMcpTransport(), mcpSyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpSyncClient -> {
ListResourcesResult resources = mcpSyncClient.listResources(null);
if (!resources.resources().isEmpty()) {
@@ -394,7 +395,7 @@ void testNotificationHandlers() {
AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false);
AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false);
- withClient(createMcpTransport(),
+ withClient(createMcpClientTransportProvider(),
builder -> builder.toolsChangeConsumer(tools -> toolsNotificationReceived.set(true))
.resourcesChangeConsumer(resources -> resourcesNotificationReceived.set(true))
.promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)),
@@ -419,7 +420,7 @@ void testLoggingLevelsWithoutInitialization() {
@Test
void testLoggingLevels() {
- withClient(createMcpTransport(), mcpSyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpSyncClient -> {
mcpSyncClient.initialize();
// Test all logging levels
for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) {
@@ -431,7 +432,7 @@ void testLoggingLevels() {
@Test
void testLoggingConsumer() {
AtomicBoolean logReceived = new AtomicBoolean(false);
- withClient(createMcpTransport(), builder -> builder.requestTimeout(getRequestTimeout())
+ withClient(createMcpClientTransportProvider(), builder -> builder.requestTimeout(getRequestTimeout())
.loggingConsumer(notification -> logReceived.set(true)), client -> {
assertThatCode(() -> {
client.initialize();
@@ -442,8 +443,9 @@ void testLoggingConsumer() {
@Test
void testLoggingWithNullNotification() {
- withClient(createMcpTransport(), mcpSyncClient -> assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(null))
- .hasMessageContaining("Logging level must not be null"));
+ withClient(createMcpClientTransportProvider(),
+ mcpSyncClient -> assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(null))
+ .hasMessageContaining("Logging level must not be null"));
}
}
diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java
index df099836..fd4a51cf 100644
--- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java
+++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java
@@ -8,16 +8,19 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import com.fasterxml.jackson.core.type.TypeReference;
+import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.spec.McpClientSession;
import io.modelcontextprotocol.spec.McpClientSession.NotificationHandler;
import io.modelcontextprotocol.spec.McpClientSession.RequestHandler;
import io.modelcontextprotocol.spec.McpClientTransport;
+import io.modelcontextprotocol.spec.McpClientTransportProvider;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities;
@@ -30,6 +33,8 @@
import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification;
import io.modelcontextprotocol.spec.McpSchema.PaginatedRequest;
import io.modelcontextprotocol.spec.McpSchema.Root;
+import io.modelcontextprotocol.spec.McpServerSession;
+import io.modelcontextprotocol.spec.McpServerTransportProvider;
import io.modelcontextprotocol.spec.McpTransport;
import io.modelcontextprotocol.util.Assert;
import io.modelcontextprotocol.util.Utils;
@@ -100,7 +105,7 @@ public class McpAsyncClient {
/**
* Client capabilities.
*/
- private final McpSchema.ClientCapabilities clientCapabilities;
+ private final ClientCapabilities clientCapabilities;
/**
* Client implementation information.
@@ -135,10 +140,7 @@ public class McpAsyncClient {
*/
private Function> samplingHandler;
- /**
- * Client transport implementation.
- */
- private final McpTransport transport;
+ private final ObjectMapper objectMapper;
/**
* Supported protocol versions.
@@ -148,21 +150,21 @@ public class McpAsyncClient {
/**
* Create a new McpAsyncClient with the given transport and session request-response
* timeout.
- * @param transport the transport to use.
+ * @param mcpClientTransportProvider the transport to use.
* @param requestTimeout the session request-response timeout.
* @param initializationTimeout the max timeout to await for the client-server
* @param features the MCP Client supported features.
*/
- McpAsyncClient(McpClientTransport transport, Duration requestTimeout, Duration initializationTimeout,
- McpClientFeatures.Async features) {
+ McpAsyncClient(McpClientTransportProvider mcpClientTransportProvider, Duration requestTimeout,
+ Duration initializationTimeout, McpClientFeatures.Async features, ObjectMapper objectMapper) {
- Assert.notNull(transport, "Transport must not be null");
+ Assert.notNull(mcpClientTransportProvider, "Transport provider must not be null");
Assert.notNull(requestTimeout, "Request timeout must not be null");
Assert.notNull(initializationTimeout, "Initialization timeout must not be null");
this.clientInfo = features.clientInfo();
this.clientCapabilities = features.clientCapabilities();
- this.transport = transport;
+ this.objectMapper = objectMapper;
this.roots = new ConcurrentHashMap<>(features.roots());
this.initializationTimeout = initializationTimeout;
@@ -228,8 +230,9 @@ public class McpAsyncClient {
notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_MESSAGE,
asyncLoggingNotificationHandler(loggingConsumersFinal));
- this.mcpSession = new McpClientSession(requestTimeout, transport, requestHandlers, notificationHandlers);
-
+ mcpClientTransportProvider.setSessionFactory(
+ (trans) -> new McpClientSession(requestTimeout, trans, requestHandlers, notificationHandlers));
+ this.mcpSession = mcpClientTransportProvider.getSession();
}
/**
@@ -290,6 +293,7 @@ public Mono closeGracefully() {
// --------------------------
// Initialization
// --------------------------
+
/**
* The initialization phase MUST be the first interaction between client and server.
* During this phase, the client and server:
@@ -380,6 +384,7 @@ public Mono ping() {
// --------------------------
// Roots
// --------------------------
+
/**
* Adds a new root to the client's root list.
* @param root The root to add.
@@ -461,9 +466,8 @@ public Mono rootsListChangedNotification() {
private RequestHandler rootsListRequestHandler() {
return params -> {
@SuppressWarnings("unused")
- McpSchema.PaginatedRequest request = transport.unmarshalFrom(params,
- new TypeReference() {
- });
+ PaginatedRequest request = objectMapper.convertValue(params, new TypeReference<>() {
+ });
List roots = this.roots.values().stream().toList();
@@ -476,9 +480,8 @@ private RequestHandler rootsListRequestHandler() {
// --------------------------
private RequestHandler samplingCreateMessageHandler() {
return params -> {
- McpSchema.CreateMessageRequest request = transport.unmarshalFrom(params,
- new TypeReference() {
- });
+ CreateMessageRequest request = objectMapper.convertValue(params, new TypeReference<>() {
+ });
return this.samplingHandler.apply(request);
};
@@ -531,7 +534,7 @@ public Mono listTools(String cursor) {
if (this.serverCapabilities.tools() == null) {
return Mono.error(new McpError("Server does not provide tools capability"));
}
- return this.mcpSession.sendRequest(McpSchema.METHOD_TOOLS_LIST, new McpSchema.PaginatedRequest(cursor),
+ return this.mcpSession.sendRequest(McpSchema.METHOD_TOOLS_LIST, new PaginatedRequest(cursor),
LIST_TOOLS_RESULT_TYPE_REF);
});
}
@@ -588,7 +591,7 @@ public Mono listResources(String cursor) {
if (this.serverCapabilities.resources() == null) {
return Mono.error(new McpError("Server does not provide the resources capability"));
}
- return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_LIST, new McpSchema.PaginatedRequest(cursor),
+ return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_LIST, new PaginatedRequest(cursor),
LIST_RESOURCES_RESULT_TYPE_REF);
});
}
@@ -648,8 +651,8 @@ public Mono listResourceTemplates(String
if (this.serverCapabilities.resources() == null) {
return Mono.error(new McpError("Server does not provide the resources capability"));
}
- return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST,
- new McpSchema.PaginatedRequest(cursor), LIST_RESOURCE_TEMPLATES_RESULT_TYPE_REF);
+ return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, new PaginatedRequest(cursor),
+ LIST_RESOURCE_TEMPLATES_RESULT_TYPE_REF);
});
}
@@ -695,16 +698,16 @@ private NotificationHandler asyncResourcesChangeNotificationHandler(
// --------------------------
// Prompts
// --------------------------
- private static final TypeReference LIST_PROMPTS_RESULT_TYPE_REF = new TypeReference<>() {
+ private static final TypeReference LIST_PROMPTS_RESULT_TYPE_REF = new TypeReference<>() {
};
- private static final TypeReference GET_PROMPT_RESULT_TYPE_REF = new TypeReference<>() {
+ private static final TypeReference GET_PROMPT_RESULT_TYPE_REF = new TypeReference<>() {
};
/**
* Retrieves the list of all prompts provided by the server.
* @return A Mono that completes with the list of prompts result.
- * @see McpSchema.ListPromptsResult
+ * @see ListPromptsResult
* @see #getPrompt(GetPromptRequest)
*/
public Mono listPrompts() {
@@ -715,7 +718,7 @@ public Mono listPrompts() {
* Retrieves a paginated list of prompts provided by the server.
* @param cursor Optional pagination cursor from a previous list request
* @return A Mono that completes with the list of prompts result.
- * @see McpSchema.ListPromptsResult
+ * @see ListPromptsResult
* @see #getPrompt(GetPromptRequest)
*/
public Mono listPrompts(String cursor) {
@@ -728,8 +731,8 @@ public Mono listPrompts(String cursor) {
* including all parameters and instructions for generating AI content.
* @param getPromptRequest The request containing the ID of the prompt to retrieve.
* @return A Mono that completes with the prompt result.
- * @see McpSchema.GetPromptRequest
- * @see McpSchema.GetPromptResult
+ * @see GetPromptRequest
+ * @see GetPromptResult
* @see #listPrompts()
*/
public Mono getPrompt(GetPromptRequest getPromptRequest) {
@@ -763,8 +766,8 @@ private NotificationHandler asyncLoggingNotificationHandler(
List>> loggingConsumers) {
return params -> {
- McpSchema.LoggingMessageNotification loggingMessageNotification = transport.unmarshalFrom(params,
- new TypeReference() {
+ LoggingMessageNotification loggingMessageNotification = objectMapper.convertValue(params,
+ new TypeReference<>() {
});
return Flux.fromIterable(loggingConsumers)
@@ -778,7 +781,7 @@ private NotificationHandler asyncLoggingNotificationHandler(
* will only receive log messages at or above the specified severity level.
* @param loggingLevel The minimum logging level to receive.
* @return A Mono that completes when the logging level is set.
- * @see McpSchema.LoggingLevel
+ * @see LoggingLevel
*/
public Mono setLoggingLevel(LoggingLevel loggingLevel) {
if (loggingLevel == null) {
diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java
index f7b17961..6237cb8b 100644
--- a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java
+++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java
@@ -12,7 +12,9 @@
import java.util.function.Consumer;
import java.util.function.Function;
+import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.spec.McpClientTransport;
+import io.modelcontextprotocol.spec.McpClientTransportProvider;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpTransport;
import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities;
@@ -52,7 +54,7 @@
* .requestTimeout(Duration.ofSeconds(5))
* .build();
* }
- *
+ *
* Example of creating a basic asynchronous client:
{@code
* McpClient.async(transport)
* .requestTimeout(Duration.ofSeconds(5))
@@ -114,7 +116,7 @@ public interface McpClient {
* @return A new builder instance for configuring the client
* @throws IllegalArgumentException if transport is null
*/
- static SyncSpec sync(McpClientTransport transport) {
+ static SyncSpec sync(McpClientTransportProvider transport) {
return new SyncSpec(transport);
}
@@ -131,7 +133,7 @@ static SyncSpec sync(McpClientTransport transport) {
* @return A new builder instance for configuring the client
* @throws IllegalArgumentException if transport is null
*/
- static AsyncSpec async(McpClientTransport transport) {
+ static AsyncSpec async(McpClientTransportProvider transport) {
return new AsyncSpec(transport);
}
@@ -153,7 +155,9 @@ static AsyncSpec async(McpClientTransport transport) {
*/
class SyncSpec {
- private final McpClientTransport transport;
+ private final McpClientTransportProvider transportProvider;
+
+ private ObjectMapper objectMapper;
private Duration requestTimeout = Duration.ofSeconds(20); // Default timeout
@@ -175,9 +179,9 @@ class SyncSpec {
private Function samplingHandler;
- private SyncSpec(McpClientTransport transport) {
- Assert.notNull(transport, "Transport must not be null");
- this.transport = transport;
+ private SyncSpec(McpClientTransportProvider transportProvider) {
+ Assert.notNull(transportProvider, "Transport must not be null");
+ this.transportProvider = transportProvider;
}
/**
@@ -367,9 +371,10 @@ public McpSyncClient build() {
this.loggingConsumers, this.samplingHandler);
McpClientFeatures.Async asyncFeatures = McpClientFeatures.Async.fromSync(syncFeatures);
+ var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper();
- return new McpSyncClient(
- new McpAsyncClient(transport, this.requestTimeout, this.initializationTimeout, asyncFeatures));
+ return new McpSyncClient(new McpAsyncClient(this.transportProvider, this.requestTimeout,
+ this.initializationTimeout, asyncFeatures, mapper));
}
}
@@ -392,7 +397,9 @@ public McpSyncClient build() {
*/
class AsyncSpec {
- private final McpClientTransport transport;
+ private final McpClientTransportProvider transportProvider;
+
+ private ObjectMapper objectMapper;
private Duration requestTimeout = Duration.ofSeconds(20); // Default timeout
@@ -414,9 +421,9 @@ class AsyncSpec {
private Function> samplingHandler;
- private AsyncSpec(McpClientTransport transport) {
- Assert.notNull(transport, "Transport must not be null");
- this.transport = transport;
+ private AsyncSpec(McpClientTransportProvider transportProvider) {
+ Assert.notNull(transportProvider, "Transport must not be null");
+ this.transportProvider = transportProvider;
}
/**
@@ -603,10 +610,12 @@ public AsyncSpec loggingConsumers(
* @return a new instance of {@link McpAsyncClient}.
*/
public McpAsyncClient build() {
- return new McpAsyncClient(this.transport, this.requestTimeout, this.initializationTimeout,
+ var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper();
+ return new McpAsyncClient(this.transportProvider, this.requestTimeout, this.initializationTimeout,
new McpClientFeatures.Async(this.clientInfo, this.capabilities, this.roots,
this.toolsChangeConsumers, this.resourcesChangeConsumers, this.promptsChangeConsumers,
- this.loggingConsumers, this.samplingHandler));
+ this.loggingConsumers, this.samplingHandler),
+ mapper);
}
}
diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java
deleted file mode 100644
index 632d3844..00000000
--- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java
+++ /dev/null
@@ -1,466 +0,0 @@
-/*
- * Copyright 2024 - 2024 the original author or authors.
- */
-package io.modelcontextprotocol.client.transport;
-
-import java.io.IOException;
-import java.net.URI;
-import java.net.http.HttpClient;
-import java.net.http.HttpRequest;
-import java.net.http.HttpResponse;
-import java.time.Duration;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.CountDownLatch;
-import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicReference;
-import java.util.function.Consumer;
-import java.util.function.Function;
-
-import com.fasterxml.jackson.core.type.TypeReference;
-import com.fasterxml.jackson.databind.ObjectMapper;
-import io.modelcontextprotocol.client.transport.FlowSseClient.SseEvent;
-import io.modelcontextprotocol.spec.McpClientTransport;
-import io.modelcontextprotocol.spec.McpError;
-import io.modelcontextprotocol.spec.McpSchema;
-import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage;
-import io.modelcontextprotocol.util.Assert;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import reactor.core.publisher.Mono;
-
-/**
- * Server-Sent Events (SSE) implementation of the
- * {@link io.modelcontextprotocol.spec.McpTransport} that follows the MCP HTTP with SSE
- * transport specification, using Java's HttpClient.
- *
- *
- * This transport implementation establishes a bidirectional communication channel between
- * client and server using SSE for server-to-client messages and HTTP POST requests for
- * client-to-server messages. The transport:
- *
- * Establishes an SSE connection to receive server messages
- * Handles endpoint discovery through SSE events
- * Manages message serialization/deserialization using Jackson
- * Provides graceful connection termination
- *
- *
- *
- * The transport supports two types of SSE events:
- *
- * 'endpoint' - Contains the URL for sending client messages
- * 'message' - Contains JSON-RPC message payload
- *
- *
- * @author Christian Tzolov
- * @see io.modelcontextprotocol.spec.McpTransport
- * @see io.modelcontextprotocol.spec.McpClientTransport
- */
-public class HttpClientSseClientTransport implements McpClientTransport {
-
- private static final Logger logger = LoggerFactory.getLogger(HttpClientSseClientTransport.class);
-
- /** SSE event type for JSON-RPC messages */
- private static final String MESSAGE_EVENT_TYPE = "message";
-
- /** SSE event type for endpoint discovery */
- private static final String ENDPOINT_EVENT_TYPE = "endpoint";
-
- /** Default SSE endpoint path */
- private static final String DEFAULT_SSE_ENDPOINT = "/sse";
-
- /** Base URI for the MCP server */
- private final String baseUri;
-
- /** SSE endpoint path */
- private final String sseEndpoint;
-
- /** SSE client for handling server-sent events. Uses the /sse endpoint */
- private final FlowSseClient sseClient;
-
- /**
- * HTTP client for sending messages to the server. Uses HTTP POST over the message
- * endpoint
- */
- private final HttpClient httpClient;
-
- /** HTTP request builder for building requests to send messages to the server */
- private final HttpRequest.Builder requestBuilder;
-
- /** JSON object mapper for message serialization/deserialization */
- protected ObjectMapper objectMapper;
-
- /** Flag indicating if the transport is in closing state */
- private volatile boolean isClosing = false;
-
- /** Latch for coordinating endpoint discovery */
- private final CountDownLatch closeLatch = new CountDownLatch(1);
-
- /** Holds the discovered message endpoint URL */
- private final AtomicReference messageEndpoint = new AtomicReference<>();
-
- /** Holds the SSE connection future */
- private final AtomicReference> connectionFuture = new AtomicReference<>();
-
- /**
- * Creates a new transport instance with default HTTP client and object mapper.
- * @param baseUri the base URI of the MCP server
- * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This
- * constructor will be removed in future versions.
- */
- @Deprecated(forRemoval = true)
- public HttpClientSseClientTransport(String baseUri) {
- this(HttpClient.newBuilder(), baseUri, new ObjectMapper());
- }
-
- /**
- * Creates a new transport instance with custom HTTP client builder and object mapper.
- * @param clientBuilder the HTTP client builder to use
- * @param baseUri the base URI of the MCP server
- * @param objectMapper the object mapper for JSON serialization/deserialization
- * @throws IllegalArgumentException if objectMapper or clientBuilder is null
- * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This
- * constructor will be removed in future versions.
- */
- @Deprecated(forRemoval = true)
- public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String baseUri, ObjectMapper objectMapper) {
- this(clientBuilder, baseUri, DEFAULT_SSE_ENDPOINT, objectMapper);
- }
-
- /**
- * Creates a new transport instance with custom HTTP client builder and object mapper.
- * @param clientBuilder the HTTP client builder to use
- * @param baseUri the base URI of the MCP server
- * @param sseEndpoint the SSE endpoint path
- * @param objectMapper the object mapper for JSON serialization/deserialization
- * @throws IllegalArgumentException if objectMapper or clientBuilder is null
- * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This
- * constructor will be removed in future versions.
- */
- @Deprecated(forRemoval = true)
- public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String baseUri, String sseEndpoint,
- ObjectMapper objectMapper) {
- this(clientBuilder, HttpRequest.newBuilder(), baseUri, sseEndpoint, objectMapper);
- }
-
- /**
- * Creates a new transport instance with custom HTTP client builder, object mapper,
- * and headers.
- * @param clientBuilder the HTTP client builder to use
- * @param requestBuilder the HTTP request builder to use
- * @param baseUri the base URI of the MCP server
- * @param sseEndpoint the SSE endpoint path
- * @param objectMapper the object mapper for JSON serialization/deserialization
- * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null
- * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This
- * constructor will be removed in future versions.
- */
- @Deprecated(forRemoval = true)
- public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpRequest.Builder requestBuilder,
- String baseUri, String sseEndpoint, ObjectMapper objectMapper) {
- this(clientBuilder.connectTimeout(Duration.ofSeconds(10)).build(), requestBuilder, baseUri, sseEndpoint,
- objectMapper);
- }
-
- /**
- * Creates a new transport instance with custom HTTP client builder, object mapper,
- * and headers.
- * @param httpClient the HTTP client to use
- * @param requestBuilder the HTTP request builder to use
- * @param baseUri the base URI of the MCP server
- * @param sseEndpoint the SSE endpoint path
- * @param objectMapper the object mapper for JSON serialization/deserialization
- * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null
- */
- HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri,
- String sseEndpoint, ObjectMapper objectMapper) {
- Assert.notNull(objectMapper, "ObjectMapper must not be null");
- Assert.hasText(baseUri, "baseUri must not be empty");
- Assert.hasText(sseEndpoint, "sseEndpoint must not be empty");
- Assert.notNull(httpClient, "httpClient must not be null");
- Assert.notNull(requestBuilder, "requestBuilder must not be null");
- this.baseUri = baseUri;
- this.sseEndpoint = sseEndpoint;
- this.objectMapper = objectMapper;
- this.httpClient = httpClient;
- this.requestBuilder = requestBuilder;
-
- this.sseClient = new FlowSseClient(this.httpClient, requestBuilder);
- }
-
- /**
- * Creates a new builder for {@link HttpClientSseClientTransport}.
- * @param baseUri the base URI of the MCP server
- * @return a new builder instance
- */
- public static Builder builder(String baseUri) {
- return new Builder().baseUri(baseUri);
- }
-
- /**
- * Builder for {@link HttpClientSseClientTransport}.
- */
- public static class Builder {
-
- private String baseUri;
-
- private String sseEndpoint = DEFAULT_SSE_ENDPOINT;
-
- private HttpClient.Builder clientBuilder = HttpClient.newBuilder()
- .version(HttpClient.Version.HTTP_1_1)
- .connectTimeout(Duration.ofSeconds(10));
-
- private ObjectMapper objectMapper = new ObjectMapper();
-
- private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder()
- .header("Content-Type", "application/json");
-
- /**
- * Creates a new builder instance.
- */
- Builder() {
- // Default constructor
- }
-
- /**
- * Creates a new builder with the specified base URI.
- * @param baseUri the base URI of the MCP server
- * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead.
- * This constructor is deprecated and will be removed or made {@code protected} or
- * {@code private} in a future release.
- */
- @Deprecated(forRemoval = true)
- public Builder(String baseUri) {
- Assert.hasText(baseUri, "baseUri must not be empty");
- this.baseUri = baseUri;
- }
-
- /**
- * Sets the base URI.
- * @param baseUri the base URI
- * @return this builder
- */
- Builder baseUri(String baseUri) {
- Assert.hasText(baseUri, "baseUri must not be empty");
- this.baseUri = baseUri;
- return this;
- }
-
- /**
- * Sets the SSE endpoint path.
- * @param sseEndpoint the SSE endpoint path
- * @return this builder
- */
- public Builder sseEndpoint(String sseEndpoint) {
- Assert.hasText(sseEndpoint, "sseEndpoint must not be empty");
- this.sseEndpoint = sseEndpoint;
- return this;
- }
-
- /**
- * Sets the HTTP client builder.
- * @param clientBuilder the HTTP client builder
- * @return this builder
- */
- public Builder clientBuilder(HttpClient.Builder clientBuilder) {
- Assert.notNull(clientBuilder, "clientBuilder must not be null");
- this.clientBuilder = clientBuilder;
- return this;
- }
-
- /**
- * Customizes the HTTP client builder.
- * @param clientCustomizer the consumer to customize the HTTP client builder
- * @return this builder
- */
- public Builder customizeClient(final Consumer clientCustomizer) {
- Assert.notNull(clientCustomizer, "clientCustomizer must not be null");
- clientCustomizer.accept(clientBuilder);
- return this;
- }
-
- /**
- * Sets the HTTP request builder.
- * @param requestBuilder the HTTP request builder
- * @return this builder
- */
- public Builder requestBuilder(HttpRequest.Builder requestBuilder) {
- Assert.notNull(requestBuilder, "requestBuilder must not be null");
- this.requestBuilder = requestBuilder;
- return this;
- }
-
- /**
- * Customizes the HTTP client builder.
- * @param requestCustomizer the consumer to customize the HTTP request builder
- * @return this builder
- */
- public Builder customizeRequest(final Consumer requestCustomizer) {
- Assert.notNull(requestCustomizer, "requestCustomizer must not be null");
- requestCustomizer.accept(requestBuilder);
- return this;
- }
-
- /**
- * Sets the object mapper for JSON serialization/deserialization.
- * @param objectMapper the object mapper
- * @return this builder
- */
- public Builder objectMapper(ObjectMapper objectMapper) {
- Assert.notNull(objectMapper, "objectMapper must not be null");
- this.objectMapper = objectMapper;
- return this;
- }
-
- /**
- * Builds a new {@link HttpClientSseClientTransport} instance.
- * @return a new transport instance
- */
- public HttpClientSseClientTransport build() {
- return new HttpClientSseClientTransport(clientBuilder.build(), requestBuilder, baseUri, sseEndpoint,
- objectMapper);
- }
-
- }
-
- /**
- * Establishes the SSE connection with the server and sets up message handling.
- *
- *
- * This method:
- *
- * Initiates the SSE connection
- * Handles endpoint discovery events
- * Processes incoming JSON-RPC messages
- *
- * @param handler the function to process received JSON-RPC messages
- * @return a Mono that completes when the connection is established
- */
- @Override
- public Mono connect(Function, Mono> handler) {
- CompletableFuture future = new CompletableFuture<>();
- connectionFuture.set(future);
-
- sseClient.subscribe(this.baseUri + this.sseEndpoint, new FlowSseClient.SseEventHandler() {
- @Override
- public void onEvent(SseEvent event) {
- if (isClosing) {
- return;
- }
-
- try {
- if (ENDPOINT_EVENT_TYPE.equals(event.type())) {
- String endpoint = event.data();
- messageEndpoint.set(endpoint);
- closeLatch.countDown();
- future.complete(null);
- }
- else if (MESSAGE_EVENT_TYPE.equals(event.type())) {
- JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, event.data());
- handler.apply(Mono.just(message)).subscribe();
- }
- else {
- logger.error("Received unrecognized SSE event type: {}", event.type());
- }
- }
- catch (IOException e) {
- logger.error("Error processing SSE event", e);
- future.completeExceptionally(e);
- }
- }
-
- @Override
- public void onError(Throwable error) {
- if (!isClosing) {
- logger.error("SSE connection error", error);
- future.completeExceptionally(error);
- }
- }
- });
-
- return Mono.fromFuture(future);
- }
-
- /**
- * Sends a JSON-RPC message to the server.
- *
- *
- * This method waits for the message endpoint to be discovered before sending the
- * message. The message is serialized to JSON and sent as an HTTP POST request.
- * @param message the JSON-RPC message to send
- * @return a Mono that completes when the message is sent
- * @throws McpError if the message endpoint is not available or the wait times out
- */
- @Override
- public Mono sendMessage(JSONRPCMessage message) {
- if (isClosing) {
- return Mono.empty();
- }
-
- try {
- if (!closeLatch.await(10, TimeUnit.SECONDS)) {
- return Mono.error(new McpError("Failed to wait for the message endpoint"));
- }
- }
- catch (InterruptedException e) {
- return Mono.error(new McpError("Failed to wait for the message endpoint"));
- }
-
- String endpoint = messageEndpoint.get();
- if (endpoint == null) {
- return Mono.error(new McpError("No message endpoint available"));
- }
-
- try {
- String jsonText = this.objectMapper.writeValueAsString(message);
- HttpRequest request = this.requestBuilder.uri(URI.create(this.baseUri + endpoint))
- .POST(HttpRequest.BodyPublishers.ofString(jsonText))
- .build();
-
- return Mono.fromFuture(
- httpClient.sendAsync(request, HttpResponse.BodyHandlers.discarding()).thenAccept(response -> {
- if (response.statusCode() != 200 && response.statusCode() != 201 && response.statusCode() != 202
- && response.statusCode() != 206) {
- logger.error("Error sending message: {}", response.statusCode());
- }
- }));
- }
- catch (IOException e) {
- if (!isClosing) {
- return Mono.error(new RuntimeException("Failed to serialize message", e));
- }
- return Mono.empty();
- }
- }
-
- /**
- * Gracefully closes the transport connection.
- *
- *
- * Sets the closing flag and cancels any pending connection future. This prevents new
- * messages from being sent and allows ongoing operations to complete.
- * @return a Mono that completes when the closing process is initiated
- */
- @Override
- public Mono closeGracefully() {
- return Mono.fromRunnable(() -> {
- isClosing = true;
- CompletableFuture future = connectionFuture.get();
- if (future != null && !future.isDone()) {
- future.cancel(true);
- }
- });
- }
-
- /**
- * Unmarshal data to the specified type using the configured object mapper.
- * @param data the data to unmarshal
- * @param typeRef the type reference for the target type
- * @param the target type
- * @return the unmarshalled object
- */
- @Override
- public T unmarshalFrom(Object data, TypeReference typeRef) {
- return this.objectMapper.convertValue(data, typeRef);
- }
-
-}
diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportProvider.java
new file mode 100644
index 00000000..84f68481
--- /dev/null
+++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportProvider.java
@@ -0,0 +1,518 @@
+/*
+ * Copyright 2024 - 2024 the original author or authors.
+ */
+package io.modelcontextprotocol.client.transport;
+
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import io.modelcontextprotocol.client.transport.FlowSseClient.SseEvent;
+import io.modelcontextprotocol.spec.McpClientSession;
+import io.modelcontextprotocol.spec.McpClientTransport;
+import io.modelcontextprotocol.spec.McpClientTransportProvider;
+import io.modelcontextprotocol.spec.McpError;
+import io.modelcontextprotocol.spec.McpSchema;
+import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage;
+import io.modelcontextprotocol.util.Assert;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import reactor.core.publisher.Mono;
+
+import java.io.IOException;
+import java.net.URI;
+import java.net.http.HttpClient;
+import java.net.http.HttpRequest;
+import java.net.http.HttpResponse;
+import java.time.Duration;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Consumer;
+import java.util.function.Function;
+
+import static io.modelcontextprotocol.util.Utils.getSessionIdFromUrl;
+
+/**
+ * Server-Sent Events (SSE) implementation of the
+ * {@link io.modelcontextprotocol.spec.McpClientTransportProvider} that provides SSE
+ * client transport that follows the MCP HTTP with SSE transport specification, using
+ * Java's HttpClient.
+ *
+ * @author Christian Tzolov
+ * @author Jermaine Hua
+ * @see io.modelcontextprotocol.spec.McpClientTransportProvider
+ */
+public class HttpClientSseClientTransportProvider implements McpClientTransportProvider {
+
+ private static final Logger logger = LoggerFactory.getLogger(HttpClientSseClientTransportProvider.class);
+
+ /**
+ * SSE event type for JSON-RPC messages
+ */
+ private static final String MESSAGE_EVENT_TYPE = "message";
+
+ /**
+ * SSE event type for endpoint discovery
+ */
+ private static final String ENDPOINT_EVENT_TYPE = "endpoint";
+
+ /**
+ * Default SSE endpoint path
+ */
+ private static final String DEFAULT_SSE_ENDPOINT = "/sse";
+
+ /**
+ * Base URI for the MCP server
+ */
+ private final String baseUri;
+
+ /**
+ * SSE endpoint path
+ */
+ private final String sseEndpoint;
+
+ /**
+ * JSON object mapper for message serialization/deserialization
+ */
+ protected ObjectMapper objectMapper;
+
+ /**
+ * Flag indicating if the transport is in closing state
+ */
+ private volatile boolean isClosing = false;
+
+ /**
+ * Latch for coordinating endpoint discovery
+ */
+ private final CountDownLatch closeLatch = new CountDownLatch(1);
+
+ /**
+ * HTTP request builder for building requests to send messages to the server
+ */
+ private final HttpRequest.Builder requestBuilder;
+
+ private final HttpClient.Builder clientBuilder;
+
+ /**
+ * Session factory for creating new sessions
+ */
+ private McpClientSession.Factory sessionFactory;
+
+ /**
+ * Active client session
+ */
+ private McpClientSession session;
+
+ /**
+ * Creates a new transport instance with default HTTP client and object mapper.
+ * @param baseUri the base URI of the MCP server
+ * @deprecated Use {@link HttpClientSseClientTransportProvider#builder(String)}
+ * instead. This constructor will be removed in future versions.
+ */
+ @Deprecated(forRemoval = true)
+ public HttpClientSseClientTransportProvider(String baseUri) {
+ this(HttpClient.newBuilder(), baseUri, new ObjectMapper());
+ }
+
+ /**
+ * Creates a new transport instance with custom HTTP client builder and object mapper.
+ * @param clientBuilder the HTTP client builder to use
+ * @param baseUri the base URI of the MCP server
+ * @param objectMapper the object mapper for JSON serialization/deserialization
+ * @throws IllegalArgumentException if objectMapper or clientBuilder is null
+ * @deprecated Use {@link HttpClientSseClientTransportProvider#builder(String)}
+ * instead. This constructor will be removed in future versions.
+ */
+ @Deprecated(forRemoval = true)
+ public HttpClientSseClientTransportProvider(HttpClient.Builder clientBuilder, String baseUri,
+ ObjectMapper objectMapper) {
+ this(clientBuilder, baseUri, DEFAULT_SSE_ENDPOINT, objectMapper);
+ }
+
+ /**
+ * Creates a new transport instance with custom HTTP client builder and object mapper.
+ * @param clientBuilder the HTTP client builder to use
+ * @param baseUri the base URI of the MCP server
+ * @param sseEndpoint the SSE endpoint path
+ * @param objectMapper the object mapper for JSON serialization/deserialization
+ * @throws IllegalArgumentException if objectMapper or clientBuilder is null
+ * @deprecated Use {@link HttpClientSseClientTransportProvider#builder(String)}
+ * instead. This constructor will be removed in future versions.
+ */
+ @Deprecated(forRemoval = true)
+ public HttpClientSseClientTransportProvider(HttpClient.Builder clientBuilder, String baseUri, String sseEndpoint,
+ ObjectMapper objectMapper) {
+ this(clientBuilder, HttpRequest.newBuilder(), baseUri, sseEndpoint, objectMapper);
+ }
+
+ /**
+ * Creates a new transport instance with custom HTTP client builder, object mapper,
+ * and headers.
+ * @param clientBuilder the HTTP client builder to use
+ * @param requestBuilder the HTTP request builder to use
+ * @param baseUri the base URI of the MCP server
+ * @param sseEndpoint the SSE endpoint path
+ * @param objectMapper the object mapper for JSON serialization/deserialization
+ * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null
+ * @deprecated Use {@link HttpClientSseClientTransportProvider#builder(String)}
+ * instead. This constructor will be removed in future versions.
+ */
+ HttpClientSseClientTransportProvider(HttpClient.Builder clientBuilder, HttpRequest.Builder requestBuilder,
+ String baseUri, String sseEndpoint, ObjectMapper objectMapper) {
+ Assert.notNull(objectMapper, "ObjectMapper must not be null");
+ Assert.hasText(baseUri, "baseUri must not be empty");
+ Assert.hasText(sseEndpoint, "sseEndpoint must not be empty");
+ Assert.notNull(requestBuilder, "requestBuilder must not be null");
+ this.baseUri = baseUri;
+ this.sseEndpoint = sseEndpoint;
+ this.objectMapper = objectMapper;
+ this.requestBuilder = requestBuilder;
+ this.clientBuilder = clientBuilder;
+ }
+
+ /**
+ * Creates a new builder for {@link HttpClientSseClientTransport}.
+ * @param baseUri the base URI of the MCP server
+ * @return a new builder instance
+ */
+ public static Builder builder(String baseUri) {
+ return new Builder().baseUri(baseUri);
+ }
+
+ /**
+ * Builder for {@link HttpClientSseClientTransportProvider}.
+ */
+ public static class Builder {
+
+ private String baseUri;
+
+ private String sseEndpoint = DEFAULT_SSE_ENDPOINT;
+
+ private HttpClient.Builder clientBuilder = HttpClient.newBuilder()
+ .version(HttpClient.Version.HTTP_1_1)
+ .connectTimeout(Duration.ofSeconds(10));
+
+ private ObjectMapper objectMapper = new ObjectMapper();
+
+ private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder()
+ .header("Content-Type", "application/json");
+
+ /**
+ * Creates a new builder instance.
+ */
+ Builder() {
+ // Default constructor
+ }
+
+ /**
+ * Creates a new builder with the specified base URI.
+ * @param baseUri the base URI of the MCP server
+ * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead.
+ * This constructor is deprecated and will be removed or made {@code protected} or
+ * {@code private} in a future release.
+ */
+ @Deprecated(forRemoval = true)
+ public Builder(String baseUri) {
+ Assert.hasText(baseUri, "baseUri must not be empty");
+ this.baseUri = baseUri;
+ }
+
+ /**
+ * Sets the base URI.
+ * @param baseUri the base URI
+ * @return this builder
+ */
+ Builder baseUri(String baseUri) {
+ Assert.hasText(baseUri, "baseUri must not be empty");
+ this.baseUri = baseUri;
+ return this;
+ }
+
+ /**
+ * Sets the SSE endpoint path.
+ * @param sseEndpoint the SSE endpoint path
+ * @return this builder
+ */
+ public Builder sseEndpoint(String sseEndpoint) {
+ Assert.hasText(sseEndpoint, "sseEndpoint must not be empty");
+ this.sseEndpoint = sseEndpoint;
+ return this;
+ }
+
+ /**
+ * Sets the HTTP client builder.
+ * @param clientBuilder the HTTP client builder
+ * @return this builder
+ */
+ public Builder clientBuilder(HttpClient.Builder clientBuilder) {
+ Assert.notNull(clientBuilder, "clientBuilder must not be null");
+ this.clientBuilder = clientBuilder;
+ return this;
+ }
+
+ /**
+ * Customizes the HTTP client builder.
+ * @param clientCustomizer the consumer to customize the HTTP client builder
+ * @return this builder
+ */
+ public Builder customizeClient(final Consumer clientCustomizer) {
+ Assert.notNull(clientCustomizer, "clientCustomizer must not be null");
+ clientCustomizer.accept(clientBuilder);
+ return this;
+ }
+
+ /**
+ * Sets the HTTP request builder.
+ * @param requestBuilder the HTTP request builder
+ * @return this builder
+ */
+ public Builder requestBuilder(HttpRequest.Builder requestBuilder) {
+ Assert.notNull(requestBuilder, "requestBuilder must not be null");
+ this.requestBuilder = requestBuilder;
+ return this;
+ }
+
+ /**
+ * Customizes the HTTP client builder.
+ * @param requestCustomizer the consumer to customize the HTTP request builder
+ * @return this builder
+ */
+ public Builder customizeRequest(final Consumer requestCustomizer) {
+ Assert.notNull(requestCustomizer, "requestCustomizer must not be null");
+ requestCustomizer.accept(requestBuilder);
+ return this;
+ }
+
+ /**
+ * Sets the object mapper for JSON serialization/deserialization.
+ * @param objectMapper the object mapper
+ * @return this builder
+ */
+ public Builder objectMapper(ObjectMapper objectMapper) {
+ Assert.notNull(objectMapper, "objectMapper must not be null");
+ this.objectMapper = objectMapper;
+ return this;
+ }
+
+ /**
+ * Builds a new {@link HttpClientSseClientTransportProvider} instance.
+ * @return a new transport instance
+ */
+ public HttpClientSseClientTransportProvider build() {
+ return new HttpClientSseClientTransportProvider(clientBuilder, requestBuilder, baseUri, sseEndpoint,
+ objectMapper);
+ }
+
+ }
+
+ @Override
+ public void setSessionFactory(McpClientSession.Factory sessionFactory) {
+ this.sessionFactory = sessionFactory;
+ }
+
+ /**
+ * Gracefully closes the transport connection.
+ *
+ *
+ * Sets the closing flag and cancels any pending connection future. This prevents new
+ * messages from being sent and allows ongoing operations to complete.
+ * @return a Mono that completes when the closing process is initiated
+ */
+ @Override
+ public Mono closeGracefully() {
+ isClosing = true;
+ return session.closeGracefully().then();
+ }
+
+ @Override
+ public McpClientSession getSession() {
+ if (session != null) {
+ return session;
+ }
+
+ HttpClient httpClient = clientBuilder.build();
+ FlowSseClient sseClient = new FlowSseClient(httpClient, requestBuilder);
+ McpClientTransport mcpClientTransport = new HttpClientSseClientTransport(sseClient, httpClient);
+ session = sessionFactory.create(mcpClientTransport);
+ return session;
+ }
+
+ private class HttpClientSseClientTransport implements McpClientTransport {
+
+ /**
+ * SSE client for handling server-sent events. Uses the /sse endpoint
+ */
+ private final FlowSseClient sseClient;
+
+ /**
+ * HTTP client for sending messages to the server. Uses HTTP POST over the message
+ * endpoint
+ */
+ private final HttpClient httpClient;
+
+ /**
+ * Holds the discovered message endpoint URL
+ */
+ private final AtomicReference messageEndpoint = new AtomicReference<>();
+
+ /**
+ * Holds the SSE connection future
+ */
+ private final AtomicReference> connectionFuture = new AtomicReference<>();
+
+ public HttpClientSseClientTransport(FlowSseClient sseClient, HttpClient httpClient) {
+ this.sseClient = sseClient;
+ this.httpClient = httpClient;
+ }
+
+ /**
+ * Establishes the SSE connection with the server and sets up message handling.
+ *
+ *
+ * This method:
+ *
+ * Initiates the SSE connection
+ * Handles endpoint discovery events
+ * Processes incoming JSON-RPC messages
+ *
+ * @param handler the function to process received JSON-RPC messages
+ * @return a Mono that completes when the connection is established
+ */
+ @Override
+ public Mono connect(Function, Mono> handler) {
+ return connect();
+ }
+
+ @Override
+ public Mono connect() {
+ CompletableFuture future = new CompletableFuture<>();
+ connectionFuture.set(future);
+ sseClient.subscribe(baseUri + sseEndpoint, new FlowSseClient.SseEventHandler() {
+ @Override
+ public void onEvent(SseEvent event) {
+ if (isClosing) {
+ return;
+ }
+
+ try {
+ if (ENDPOINT_EVENT_TYPE.equals(event.type())) {
+ String endpoint = event.data();
+ messageEndpoint.set(endpoint);
+ closeLatch.countDown();
+ future.complete(null);
+ }
+ else if (MESSAGE_EVENT_TYPE.equals(event.type())) {
+ JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, event.data());
+ session.handle(message).subscribe();
+ }
+ else {
+ logger.error("Received unrecognized SSE event type: {}", event.type());
+ }
+ }
+ catch (IOException e) {
+ logger.error("Error processing SSE event", e);
+ future.completeExceptionally(e);
+ }
+ }
+
+ @Override
+ public void onError(Throwable error) {
+ if (!isClosing) {
+ logger.error("SSE connection error", error);
+ future.completeExceptionally(error);
+ }
+ }
+ });
+
+ return Mono.fromFuture(future);
+ }
+
+ /**
+ * Sends a JSON-RPC message to the server.
+ *
+ *
+ * This method waits for the message endpoint to be discovered before sending the
+ * message. The message is serialized to JSON and sent as an HTTP POST request.
+ * @param message the JSON-RPC message to send
+ * @return a Mono that completes when the message is sent
+ * @throws McpError if the message endpoint is not available or the wait times out
+ */
+ @Override
+ public Mono sendMessage(JSONRPCMessage message) {
+ if (isClosing) {
+ return Mono.empty();
+ }
+
+ try {
+ if (!closeLatch.await(10, TimeUnit.SECONDS)) {
+ return Mono.error(new McpError("Failed to wait for the message endpoint"));
+ }
+ }
+ catch (InterruptedException e) {
+ return Mono.error(new McpError("Failed to wait for the message endpoint"));
+ }
+
+ String endpoint = messageEndpoint.get();
+ if (endpoint == null) {
+ return Mono.error(new McpError("No message endpoint available"));
+ }
+
+ try {
+ String jsonText = objectMapper.writeValueAsString(message);
+ HttpRequest request = HttpRequest.newBuilder()
+ .uri(URI.create(baseUri + endpoint))
+ .header("Content-Type", "application/json")
+ .POST(HttpRequest.BodyPublishers.ofString(jsonText))
+ .build();
+
+ return Mono.fromFuture(
+ httpClient.sendAsync(request, HttpResponse.BodyHandlers.discarding()).thenAccept(response -> {
+ if (response.statusCode() != 200 && response.statusCode() != 201
+ && response.statusCode() != 202 && response.statusCode() != 206) {
+ logger.error("Error sending message: {}", response.statusCode());
+ }
+ }));
+ }
+ catch (IOException e) {
+ if (!isClosing) {
+ return Mono.error(new RuntimeException("Failed to serialize message", e));
+ }
+ return Mono.empty();
+ }
+ }
+
+ /**
+ * Gracefully closes the transport connection.
+ *
+ *
+ * Sets the closing flag and cancels any pending connection future. This prevents
+ * new messages from being sent and allows ongoing operations to complete.
+ * @return a Mono that completes when the closing process is initiated
+ */
+ @Override
+ public Mono closeGracefully() {
+ return Mono.fromRunnable(() -> {
+ isClosing = true;
+ CompletableFuture future = connectionFuture.get();
+ if (future != null && !future.isDone()) {
+ future.cancel(true);
+ }
+ });
+ }
+
+ /**
+ * Unmarshal data to the specified type using the configured object mapper.
+ * @param data the data to unmarshal
+ * @param typeRef the type reference for the target type
+ * @param the target type
+ * @return the unmarshalled object
+ */
+ @Override
+ public T unmarshalFrom(Object data, TypeReference typeRef) {
+ return objectMapper.convertValue(data, typeRef);
+ }
+
+ }
+
+}
diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java
deleted file mode 100644
index 9d71cbb4..00000000
--- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java
+++ /dev/null
@@ -1,394 +0,0 @@
-/*
- * Copyright 2024-2024 the original author or authors.
- */
-
-package io.modelcontextprotocol.client.transport;
-
-import java.io.BufferedReader;
-import java.io.IOException;
-import java.io.InputStreamReader;
-import java.nio.charset.StandardCharsets;
-import java.time.Duration;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.concurrent.Executors;
-import java.util.function.Consumer;
-import java.util.function.Function;
-
-import com.fasterxml.jackson.core.type.TypeReference;
-import com.fasterxml.jackson.databind.ObjectMapper;
-import io.modelcontextprotocol.spec.McpClientTransport;
-import io.modelcontextprotocol.spec.McpSchema;
-import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage;
-import io.modelcontextprotocol.util.Assert;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import reactor.core.publisher.Flux;
-import reactor.core.publisher.Mono;
-import reactor.core.publisher.Sinks;
-import reactor.core.scheduler.Scheduler;
-import reactor.core.scheduler.Schedulers;
-
-/**
- * Implementation of the MCP Stdio transport that communicates with a server process using
- * standard input/output streams. Messages are exchanged as newline-delimited JSON-RPC
- * messages over stdin/stdout, with errors and debug information sent to stderr.
- *
- * @author Christian Tzolov
- * @author Dariusz Jędrzejczyk
- */
-public class StdioClientTransport implements McpClientTransport {
-
- private static final Logger logger = LoggerFactory.getLogger(StdioClientTransport.class);
-
- private final Sinks.Many inboundSink;
-
- private final Sinks.Many outboundSink;
-
- /** The server process being communicated with */
- private Process process;
-
- private ObjectMapper objectMapper;
-
- /** Scheduler for handling inbound messages from the server process */
- private Scheduler inboundScheduler;
-
- /** Scheduler for handling outbound messages to the server process */
- private Scheduler outboundScheduler;
-
- /** Scheduler for handling error messages from the server process */
- private Scheduler errorScheduler;
-
- /** Parameters for configuring and starting the server process */
- private final ServerParameters params;
-
- private final Sinks.Many errorSink;
-
- private volatile boolean isClosing = false;
-
- // visible for tests
- private Consumer stdErrorHandler = error -> logger.info("STDERR Message received: {}", error);
-
- /**
- * Creates a new StdioClientTransport with the specified parameters and default
- * ObjectMapper.
- * @param params The parameters for configuring the server process
- */
- public StdioClientTransport(ServerParameters params) {
- this(params, new ObjectMapper());
- }
-
- /**
- * Creates a new StdioClientTransport with the specified parameters and ObjectMapper.
- * @param params The parameters for configuring the server process
- * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
- */
- public StdioClientTransport(ServerParameters params, ObjectMapper objectMapper) {
- Assert.notNull(params, "The params can not be null");
- Assert.notNull(objectMapper, "The ObjectMapper can not be null");
-
- this.inboundSink = Sinks.many().unicast().onBackpressureBuffer();
- this.outboundSink = Sinks.many().unicast().onBackpressureBuffer();
-
- this.params = params;
-
- this.objectMapper = objectMapper;
-
- this.errorSink = Sinks.many().unicast().onBackpressureBuffer();
-
- // Start threads
- this.inboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), "inbound");
- this.outboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), "outbound");
- this.errorScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), "error");
- }
-
- /**
- * Starts the server process and initializes the message processing streams. This
- * method sets up the process with the configured command, arguments, and environment,
- * then starts the inbound, outbound, and error processing threads.
- * @throws RuntimeException if the process fails to start or if the process streams
- * are null
- */
- @Override
- public Mono connect(Function, Mono> handler) {
- return Mono.fromRunnable(() -> {
- handleIncomingMessages(handler);
- handleIncomingErrors();
-
- // Prepare command and environment
- List fullCommand = new ArrayList<>();
- fullCommand.add(params.getCommand());
- fullCommand.addAll(params.getArgs());
-
- ProcessBuilder processBuilder = this.getProcessBuilder();
- processBuilder.command(fullCommand);
- processBuilder.environment().putAll(params.getEnv());
-
- // Start the process
- try {
- this.process = processBuilder.start();
- }
- catch (IOException e) {
- throw new RuntimeException("Failed to start process with command: " + fullCommand, e);
- }
-
- // Validate process streams
- if (this.process.getInputStream() == null || process.getOutputStream() == null) {
- this.process.destroy();
- throw new RuntimeException("Process input or output stream is null");
- }
-
- // Start threads
- startInboundProcessing();
- startOutboundProcessing();
- startErrorProcessing();
- }).subscribeOn(Schedulers.boundedElastic());
- }
-
- /**
- * Creates and returns a new ProcessBuilder instance. Protected to allow overriding in
- * tests.
- * @return A new ProcessBuilder instance
- */
- protected ProcessBuilder getProcessBuilder() {
- return new ProcessBuilder();
- }
-
- /**
- * Sets the handler for processing transport-level errors.
- *
- *
- * The provided handler will be called when errors occur during transport operations,
- * such as connection failures or protocol violations.
- *
- * @param errorHandler a consumer that processes error messages
- */
- public void setStdErrorHandler(Consumer errorHandler) {
- this.stdErrorHandler = errorHandler;
- }
-
- /**
- * Waits for the server process to exit.
- * @throws RuntimeException if the process is interrupted while waiting
- */
- public void awaitForExit() {
- try {
- this.process.waitFor();
- }
- catch (InterruptedException e) {
- throw new RuntimeException("Process interrupted", e);
- }
- }
-
- /**
- * Starts the error processing thread that reads from the process's error stream.
- * Error messages are logged and emitted to the error sink.
- */
- private void startErrorProcessing() {
- this.errorScheduler.schedule(() -> {
- try (BufferedReader processErrorReader = new BufferedReader(
- new InputStreamReader(process.getErrorStream()))) {
- String line;
- while (!isClosing && (line = processErrorReader.readLine()) != null) {
- try {
- if (!this.errorSink.tryEmitNext(line).isSuccess()) {
- if (!isClosing) {
- logger.error("Failed to emit error message");
- }
- break;
- }
- }
- catch (Exception e) {
- if (!isClosing) {
- logger.error("Error processing error message", e);
- }
- break;
- }
- }
- }
- catch (IOException e) {
- if (!isClosing) {
- logger.error("Error reading from error stream", e);
- }
- }
- finally {
- isClosing = true;
- errorSink.tryEmitComplete();
- }
- });
- }
-
- private void handleIncomingMessages(Function, Mono> inboundMessageHandler) {
- this.inboundSink.asFlux()
- .flatMap(message -> Mono.just(message)
- .transform(inboundMessageHandler)
- .contextWrite(ctx -> ctx.put("observation", "myObservation")))
- .subscribe();
- }
-
- private void handleIncomingErrors() {
- this.errorSink.asFlux().subscribe(e -> {
- this.stdErrorHandler.accept(e);
- });
- }
-
- @Override
- public Mono sendMessage(JSONRPCMessage message) {
- if (this.outboundSink.tryEmitNext(message).isSuccess()) {
- // TODO: essentially we could reschedule ourselves in some time and make
- // another attempt with the already read data but pause reading until
- // success
- // In this approach we delegate the retry and the backpressure onto the
- // caller. This might be enough for most cases.
- return Mono.empty();
- }
- else {
- return Mono.error(new RuntimeException("Failed to enqueue message"));
- }
- }
-
- /**
- * Starts the inbound processing thread that reads JSON-RPC messages from the
- * process's input stream. Messages are deserialized and emitted to the inbound sink.
- */
- private void startInboundProcessing() {
- this.inboundScheduler.schedule(() -> {
- try (BufferedReader processReader = new BufferedReader(new InputStreamReader(process.getInputStream()))) {
- String line;
- while (!isClosing && (line = processReader.readLine()) != null) {
- try {
- JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, line);
- if (!this.inboundSink.tryEmitNext(message).isSuccess()) {
- if (!isClosing) {
- logger.error("Failed to enqueue inbound message: {}", message);
- }
- break;
- }
- }
- catch (Exception e) {
- if (!isClosing) {
- logger.error("Error processing inbound message for line: " + line, e);
- }
- break;
- }
- }
- }
- catch (IOException e) {
- if (!isClosing) {
- logger.error("Error reading from input stream", e);
- }
- }
- finally {
- isClosing = true;
- inboundSink.tryEmitComplete();
- }
- });
- }
-
- /**
- * Starts the outbound processing thread that writes JSON-RPC messages to the
- * process's output stream. Messages are serialized to JSON and written with a newline
- * delimiter.
- */
- private void startOutboundProcessing() {
- this.handleOutbound(messages -> messages
- // this bit is important since writes come from user threads, and we
- // want to ensure that the actual writing happens on a dedicated thread
- .publishOn(outboundScheduler)
- .handle((message, s) -> {
- if (message != null && !isClosing) {
- try {
- String jsonMessage = objectMapper.writeValueAsString(message);
- // Escape any embedded newlines in the JSON message as per spec:
- // https://spec.modelcontextprotocol.io/specification/basic/transports/#stdio
- // - Messages are delimited by newlines, and MUST NOT contain
- // embedded newlines.
- jsonMessage = jsonMessage.replace("\r\n", "\\n").replace("\n", "\\n").replace("\r", "\\n");
-
- var os = this.process.getOutputStream();
- synchronized (os) {
- os.write(jsonMessage.getBytes(StandardCharsets.UTF_8));
- os.write("\n".getBytes(StandardCharsets.UTF_8));
- os.flush();
- }
- s.next(message);
- }
- catch (IOException e) {
- s.error(new RuntimeException(e));
- }
- }
- }));
- }
-
- protected void handleOutbound(Function, Flux> outboundConsumer) {
- outboundConsumer.apply(outboundSink.asFlux()).doOnComplete(() -> {
- isClosing = true;
- outboundSink.tryEmitComplete();
- }).doOnError(e -> {
- if (!isClosing) {
- logger.error("Error in outbound processing", e);
- isClosing = true;
- outboundSink.tryEmitComplete();
- }
- }).subscribe();
- }
-
- /**
- * Gracefully closes the transport by destroying the process and disposing of the
- * schedulers. This method sends a TERM signal to the process and waits for it to exit
- * before cleaning up resources.
- * @return A Mono that completes when the transport is closed
- */
- @Override
- public Mono closeGracefully() {
- return Mono.fromRunnable(() -> {
- isClosing = true;
- logger.debug("Initiating graceful shutdown");
- }).then(Mono.defer(() -> {
- // First complete all sinks to stop accepting new messages
- inboundSink.tryEmitComplete();
- outboundSink.tryEmitComplete();
- errorSink.tryEmitComplete();
-
- // Give a short time for any pending messages to be processed
- return Mono.delay(Duration.ofMillis(100));
- })).then(Mono.defer(() -> {
- logger.debug("Sending TERM to process");
- if (this.process != null) {
- this.process.destroy();
- return Mono.fromFuture(process.onExit());
- }
- else {
- logger.warn("Process not started");
- return Mono.empty();
- }
- })).doOnNext(process -> {
- if (process.exitValue() != 0) {
- logger.warn("Process terminated with code " + process.exitValue());
- }
- }).then(Mono.fromRunnable(() -> {
- try {
- // The Threads are blocked on readLine so disposeGracefully would not
- // interrupt them, therefore we issue an async hard dispose.
- inboundScheduler.dispose();
- errorScheduler.dispose();
- outboundScheduler.dispose();
-
- logger.debug("Graceful shutdown completed");
- }
- catch (Exception e) {
- logger.error("Error during graceful shutdown", e);
- }
- })).then().subscribeOn(Schedulers.boundedElastic());
- }
-
- public Sinks.Many getErrorSink() {
- return this.errorSink;
- }
-
- @Override
- public T unmarshalFrom(Object data, TypeReference typeRef) {
- return this.objectMapper.convertValue(data, typeRef);
- }
-
-}
diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransportProvider.java
new file mode 100644
index 00000000..f8ec7abb
--- /dev/null
+++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransportProvider.java
@@ -0,0 +1,459 @@
+package io.modelcontextprotocol.client.transport;
+
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import io.modelcontextprotocol.spec.McpClientSession;
+import io.modelcontextprotocol.spec.McpClientTransport;
+import io.modelcontextprotocol.spec.McpClientTransportProvider;
+import io.modelcontextprotocol.spec.McpSchema;
+import io.modelcontextprotocol.util.Assert;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.Mono;
+import reactor.core.publisher.Sinks;
+import reactor.core.scheduler.Scheduler;
+import reactor.core.scheduler.Schedulers;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.nio.charset.StandardCharsets;
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.Executors;
+import java.util.function.Consumer;
+import java.util.function.Function;
+
+/**
+ * Implementation of a transport provider that provides the MCP Stdio transport that
+ * communicates with a server process using standard input/output streams. Messages are
+ * exchanged as newline-delimited JSON-RPC messages over stdin/stdout, and errors and
+ * debug information are sent to stderr.
+ *
+ * @author Christian Tzolov
+ * @author Dariusz Jędrzejczyk
+ * @author Jermaine Hua
+ * @see io.modelcontextprotocol.spec.McpClientTransportProvider
+ */
+public class StdioClientTransportProvider implements McpClientTransportProvider {
+
+ private static final Logger logger = LoggerFactory.getLogger(StdioClientTransportProvider.class);
+
+ private final ObjectMapper objectMapper;
+
+ /** Parameters for configuring and starting the server process */
+ private final ServerParameters params;
+
+ private volatile boolean isClosing = false;
+
+ // visible for tests
+ private Consumer stdErrorHandler = error -> logger.info("STDERR Message received: {}", error);
+
+ private final Sinks.Many errorSink;
+
+ /**
+ * Session factory for creating new sessions
+ */
+ private McpClientSession.Factory sessionFactory;
+
+ /**
+ * Active client session
+ */
+ private McpClientSession session;
+
+ /**
+ * Creates a new StdioClientTransportProvider with the specified parameters and
+ * default ObjectMapper.
+ * @param params The parameters for configuring the server process
+ */
+ public StdioClientTransportProvider(ServerParameters params) {
+ this(params, new ObjectMapper());
+ }
+
+ /**
+ * Creates a new StdioClientTransportProvider with the specified parameters and
+ * ObjectMapper.
+ * @param params The parameters for configuring the server process
+ * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
+ */
+ public StdioClientTransportProvider(ServerParameters params, ObjectMapper objectMapper) {
+ Assert.notNull(params, "The params can not be null");
+ Assert.notNull(objectMapper, "The ObjectMapper can not be null");
+ this.objectMapper = objectMapper;
+ this.params = params;
+ this.errorSink = Sinks.many().unicast().onBackpressureBuffer();
+ }
+
+ @Override
+ public void setSessionFactory(McpClientSession.Factory sessionFactory) {
+ this.sessionFactory = sessionFactory;
+ }
+
+ @Override
+ public McpClientSession getSession() {
+ if (session != null) {
+ return session;
+ }
+ McpClientTransport mcpClientTransport = new StdioClientTransport(params);
+ this.session = sessionFactory.create(mcpClientTransport);
+ return session;
+ }
+
+ /**
+ * Sets the handler for processing transport-level errors.
+ *
+ *
+ * The provided handler will be called when errors occur during transport operations,
+ * such as connection failures or protocol violations.
+ *
+ * @param errorHandler a consumer that processes error messages
+ */
+ public void setStdErrorHandler(Consumer errorHandler) {
+ stdErrorHandler = errorHandler;
+ }
+
+ public Sinks.Many getErrorSink() {
+ return errorSink;
+ }
+
+ @Override
+ public Mono closeGracefully() {
+ return session.closeGracefully();
+ }
+
+ private class StdioClientTransport implements McpClientTransport {
+
+ private final Sinks.Many inboundSink;
+
+ private final Sinks.Many outboundSink;
+
+ /** The server process being communicated with */
+ private Process process;
+
+ /** Scheduler for handling inbound messages from the server process */
+ private Scheduler inboundScheduler;
+
+ /** Scheduler for handling outbound messages to the server process */
+ private Scheduler outboundScheduler;
+
+ /** Scheduler for handling error messages from the server process */
+ private Scheduler errorScheduler;
+
+ /**
+ * Creates a new StdioClientTransport with the specified parameters and default
+ * ObjectMapper.
+ * @param params The parameters for configuring the server process
+ */
+ public StdioClientTransport(ServerParameters params) {
+ this(params, new ObjectMapper());
+ }
+
+ /**
+ * Creates a new StdioClientTransport with the specified parameters and
+ * ObjectMapper.
+ * @param params The parameters for configuring the server process
+ * @param objectMapper The ObjectMapper to use for JSON
+ * serialization/deserialization
+ */
+ public StdioClientTransport(ServerParameters params, ObjectMapper objectMapper) {
+ Assert.notNull(params, "The params can not be null");
+ Assert.notNull(objectMapper, "The ObjectMapper can not be null");
+
+ this.inboundSink = Sinks.many().unicast().onBackpressureBuffer();
+ this.outboundSink = Sinks.many().unicast().onBackpressureBuffer();
+
+ // Start threads
+ this.inboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), "inbound");
+ this.outboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), "outbound");
+ this.errorScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), "error");
+ }
+
+ /**
+ * Starts the server process and initializes the message processing streams. This
+ * method sets up the process with the configured command, arguments, and
+ * environment, then starts the inbound, outbound, and error processing threads.
+ * @throws RuntimeException if the process fails to start or if the process
+ * streams are null
+ */
+ @Override
+ public Mono connect(Function, Mono> handler) {
+ return connect();
+ }
+
+ @Override
+ public Mono connect() {
+ return Mono.fromRunnable(() -> {
+ this.inboundSink.asFlux()
+ .flatMap(message -> session.handle(message)
+ .contextWrite(ctx -> ctx.put("observation", "myObservation")))
+ .subscribe();
+ handleIncomingErrors();
+
+ // Prepare command and environment
+ List fullCommand = new ArrayList<>();
+ fullCommand.add(params.getCommand());
+ fullCommand.addAll(params.getArgs());
+
+ ProcessBuilder processBuilder = this.getProcessBuilder();
+ processBuilder.command(fullCommand);
+ processBuilder.environment().putAll(params.getEnv());
+
+ // Start the process
+ try {
+ this.process = processBuilder.start();
+ }
+ catch (IOException e) {
+ throw new RuntimeException("Failed to start process with command: " + fullCommand, e);
+ }
+
+ // Validate process streams
+ if (this.process.getInputStream() == null || process.getOutputStream() == null) {
+ this.process.destroy();
+ throw new RuntimeException("Process input or output stream is null");
+ }
+
+ // Start threads
+ startInboundProcessing();
+ startOutboundProcessing();
+ startErrorProcessing();
+ }).subscribeOn(Schedulers.boundedElastic());
+ }
+
+ /**
+ * Creates and returns a new ProcessBuilder instance. Protected to allow
+ * overriding in tests.
+ * @return A new ProcessBuilder instance
+ */
+ protected ProcessBuilder getProcessBuilder() {
+ return new ProcessBuilder();
+ }
+
+ /**
+ * Waits for the server process to exit.
+ * @throws RuntimeException if the process is interrupted while waiting
+ */
+ public void awaitForExit() {
+ try {
+ this.process.waitFor();
+ }
+ catch (InterruptedException e) {
+ throw new RuntimeException("Process interrupted", e);
+ }
+ }
+
+ /**
+ * Starts the error processing thread that reads from the process's error stream.
+ * Error messages are logged and emitted to the error sink.
+ */
+ private void startErrorProcessing() {
+ this.errorScheduler.schedule(() -> {
+ try (BufferedReader processErrorReader = new BufferedReader(
+ new InputStreamReader(process.getErrorStream()))) {
+ String line;
+ while (!isClosing && (line = processErrorReader.readLine()) != null) {
+ try {
+ if (!errorSink.tryEmitNext(line).isSuccess()) {
+ if (!isClosing) {
+ logger.error("Failed to emit error message");
+ }
+ break;
+ }
+ }
+ catch (Exception e) {
+ if (!isClosing) {
+ logger.error("Error processing error message", e);
+ }
+ break;
+ }
+ }
+ }
+ catch (IOException e) {
+ if (!isClosing) {
+ logger.error("Error reading from error stream", e);
+ }
+ }
+ finally {
+ isClosing = true;
+ errorSink.tryEmitComplete();
+ }
+ });
+ }
+
+ private void handleIncomingMessages(
+ Function, Mono> inboundMessageHandler) {
+ this.inboundSink.asFlux()
+ .flatMap(
+ message -> session.handle(message).contextWrite(ctx -> ctx.put("observation", "myObservation")))
+ .subscribe();
+ }
+
+ private void handleIncomingErrors() {
+ errorSink.asFlux().subscribe(e -> {
+ stdErrorHandler.accept(e);
+ });
+ }
+
+ @Override
+ public Mono sendMessage(McpSchema.JSONRPCMessage message) {
+ if (this.outboundSink.tryEmitNext(message).isSuccess()) {
+ // TODO: essentially we could reschedule ourselves in some time and make
+ // another attempt with the already read data but pause reading until
+ // success
+ // In this approach we delegate the retry and the backpressure onto the
+ // caller. This might be enough for most cases.
+ return Mono.empty();
+ }
+ else {
+ return Mono.error(new RuntimeException("Failed to enqueue message"));
+ }
+ }
+
+ /**
+ * Starts the inbound processing thread that reads JSON-RPC messages from the
+ * process's input stream. Messages are deserialized and emitted to the inbound
+ * sink.
+ */
+ private void startInboundProcessing() {
+ this.inboundScheduler.schedule(() -> {
+ try (BufferedReader processReader = new BufferedReader(
+ new InputStreamReader(process.getInputStream()))) {
+ String line;
+ while (!isClosing && (line = processReader.readLine()) != null) {
+ try {
+ McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, line);
+ if (!this.inboundSink.tryEmitNext(message).isSuccess()) {
+ if (!isClosing) {
+ logger.error("Failed to enqueue inbound message: {}", message);
+ }
+ break;
+ }
+ }
+ catch (Exception e) {
+ if (!isClosing) {
+ logger.error("Error processing inbound message for line: " + line, e);
+ }
+ break;
+ }
+ }
+ }
+ catch (IOException e) {
+ if (!isClosing) {
+ logger.error("Error reading from input stream", e);
+ }
+ }
+ finally {
+ isClosing = true;
+ inboundSink.tryEmitComplete();
+ }
+ });
+ }
+
+ /**
+ * Starts the outbound processing thread that writes JSON-RPC messages to the
+ * process's output stream. Messages are serialized to JSON and written with a
+ * newline delimiter.
+ */
+ private void startOutboundProcessing() {
+ this.handleOutbound(messages -> messages
+ // this bit is important since writes come from user threads and we
+ // want to ensure that the actual writing happens on a dedicated thread
+ .publishOn(outboundScheduler)
+ .handle((message, s) -> {
+ if (message != null && !isClosing) {
+ try {
+ String jsonMessage = objectMapper.writeValueAsString(message);
+ // Escape any embedded newlines in the JSON message as per
+ // spec:
+ // https://spec.modelcontextprotocol.io/specification/basic/transports/#stdio
+ // - Messages are delimited by newlines, and MUST NOT contain
+ // embedded newlines.
+ jsonMessage = jsonMessage.replace("\r\n", "\\n").replace("\n", "\\n").replace("\r", "\\n");
+
+ var os = this.process.getOutputStream();
+ synchronized (os) {
+ os.write(jsonMessage.getBytes(StandardCharsets.UTF_8));
+ os.write("\n".getBytes(StandardCharsets.UTF_8));
+ os.flush();
+ }
+ s.next(message);
+ }
+ catch (IOException e) {
+ s.error(new RuntimeException(e));
+ }
+ }
+ }));
+ }
+
+ protected void handleOutbound(
+ Function, Flux> outboundConsumer) {
+ outboundConsumer.apply(outboundSink.asFlux()).doOnComplete(() -> {
+ isClosing = true;
+ outboundSink.tryEmitComplete();
+ }).doOnError(e -> {
+ if (!isClosing) {
+ logger.error("Error in outbound processing", e);
+ isClosing = true;
+ outboundSink.tryEmitComplete();
+ }
+ }).subscribe();
+ }
+
+ /**
+ * Gracefully closes the transport by destroying the process and disposing of the
+ * schedulers. This method sends a TERM signal to the process and waits for it to
+ * exit before cleaning up resources.
+ * @return A Mono that completes when the transport is closed
+ */
+ @Override
+ public Mono closeGracefully() {
+ return Mono.fromRunnable(() -> {
+ isClosing = true;
+ logger.debug("Initiating graceful shutdown");
+ }).then(Mono.defer(() -> {
+ // First complete all sinks to stop accepting new messages
+ inboundSink.tryEmitComplete();
+ outboundSink.tryEmitComplete();
+ errorSink.tryEmitComplete();
+
+ // Give a short time for any pending messages to be processed
+ return Mono.delay(Duration.ofMillis(100));
+ })).then(Mono.defer(() -> {
+ logger.debug("Sending TERM to process");
+ if (this.process != null) {
+ this.process.destroy();
+ return Mono.fromFuture(process.onExit());
+ }
+ else {
+ logger.warn("Process not started");
+ return Mono.empty();
+ }
+ })).doOnNext(process -> {
+ if (process.exitValue() != 0) {
+ logger.warn("Process terminated with code " + process.exitValue());
+ }
+ }).then(Mono.fromRunnable(() -> {
+ try {
+ // The Threads are blocked on readLine so disposeGracefully would not
+ // interrupt them, therefore we issue an async hard dispose.
+ inboundScheduler.dispose();
+ errorScheduler.dispose();
+ outboundScheduler.dispose();
+
+ logger.debug("Graceful shutdown completed");
+ }
+ catch (Exception e) {
+ logger.error("Error during graceful shutdown", e);
+ }
+ })).then().subscribeOn(Schedulers.boundedElastic());
+ }
+
+ @Override
+ public T unmarshalFrom(Object data, TypeReference typeRef) {
+ return objectMapper.convertValue(data, typeRef);
+ }
+
+ }
+
+}
diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java
index 0895e02b..d0974582 100644
--- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java
+++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java
@@ -17,6 +17,7 @@
import reactor.core.Disposable;
import reactor.core.publisher.Mono;
import reactor.core.publisher.MonoSink;
+import reactor.core.scheduler.Schedulers;
/**
* Default implementation of the MCP (Model Context Protocol) session that manages
@@ -40,6 +41,10 @@ public class McpClientSession implements McpSession {
/** Logger for this class */
private static final Logger logger = LoggerFactory.getLogger(McpClientSession.class);
+ private String id;
+
+ private static final String DEFAULT_SESSION_ID = "-1";
+
/** Duration to wait for request responses before timing out */
private final Duration requestTimeout;
@@ -97,21 +102,30 @@ public interface NotificationHandler {
}
+ public McpClientSession(Duration requestTimeout, McpClientTransport transport,
+ Map> requestHandlers, Map notificationHandlers) {
+
+ this(DEFAULT_SESSION_ID, requestTimeout, transport, requestHandlers, notificationHandlers);
+ }
+
/**
* Creates a new McpClientSession with the specified configuration and handlers.
+ * @param id Unique identifier for the session
* @param requestTimeout Duration to wait for responses
* @param transport Transport implementation for message exchange
* @param requestHandlers Map of method names to request handlers
* @param notificationHandlers Map of method names to notification handlers
*/
- public McpClientSession(Duration requestTimeout, McpClientTransport transport,
+ public McpClientSession(String id, Duration requestTimeout, McpClientTransport transport,
Map> requestHandlers, Map notificationHandlers) {
+ Assert.notNull(id, "The id can not be null");
Assert.notNull(requestTimeout, "The requestTimeout can not be null");
Assert.notNull(transport, "The transport can not be null");
Assert.notNull(requestHandlers, "The requestHandlers can not be null");
Assert.notNull(notificationHandlers, "The notificationHandlers can not be null");
+ this.id = id;
this.requestTimeout = requestTimeout;
this.transport = transport;
this.requestHandlers.putAll(requestHandlers);
@@ -122,7 +136,11 @@ public McpClientSession(Duration requestTimeout, McpClientTransport transport,
// Observation associated with the individual message - it can be used to
// create child Observation and emit it together with the message to the
// consumer
- this.connection = this.transport.connect(mono -> mono.doOnNext(message -> {
+ this.connection = this.transport.connect().subscribe();
+ }
+
+ public Mono handle(McpSchema.JSONRPCMessage message) {
+ return Mono.defer(() -> {
if (message instanceof McpSchema.JSONRPCResponse response) {
logger.debug("Received Response: {}", response);
var sink = pendingResponses.remove(response.id());
@@ -132,23 +150,27 @@ public McpClientSession(Duration requestTimeout, McpClientTransport transport,
else {
sink.success(response);
}
+ return Mono.empty();
}
else if (message instanceof McpSchema.JSONRPCRequest request) {
logger.debug("Received request: {}", request);
- handleIncomingRequest(request).subscribe(response -> transport.sendMessage(response).subscribe(),
- error -> {
- var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(),
- null, new McpSchema.JSONRPCResponse.JSONRPCError(
- McpSchema.ErrorCodes.INTERNAL_ERROR, error.getMessage(), null));
- transport.sendMessage(errorResponse).subscribe();
- });
+ return handleIncomingRequest(request).onErrorResume(error -> {
+ var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null,
+ new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR,
+ error.getMessage(), null));
+ return this.transport.sendMessage(errorResponse).then(Mono.empty());
+ }).flatMap(this.transport::sendMessage);
}
else if (message instanceof McpSchema.JSONRPCNotification notification) {
logger.debug("Received notification: {}", notification);
- handleIncomingNotification(notification).subscribe(null,
- error -> logger.error("Error handling notification: {}", error.getMessage()));
+ return handleIncomingNotification(notification)
+ .doOnError(error -> logger.error("Error handling notification: {}", error.getMessage()));
}
- })).subscribe();
+ else {
+ logger.warn("Received unknown message type: {}", message);
+ return Mono.empty();
+ }
+ });
}
/**
@@ -285,4 +307,32 @@ public void close() {
transport.close();
}
+ public String getId() {
+ return id;
+ }
+
+ public void setId(String id) {
+ this.id = id;
+ }
+
+ public McpClientTransport getTransport() {
+ return transport;
+ }
+
+ /**
+ * Factory for creating client sessions which delegate to a provided 1:1 transport
+ * with a connected client.
+ */
+ @FunctionalInterface
+ public interface Factory {
+
+ /**
+ * Creates a new 1:1 representation of the client-server interaction.
+ * @param sessionTransport the transport to use for communication with the server.
+ * @return a new client session.
+ */
+ McpClientSession create(McpClientTransport sessionTransport);
+
+ }
+
}
diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java
index f2909124..2c6af589 100644
--- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java
+++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java
@@ -12,9 +12,14 @@
*
* @author Christian Tzolov
* @author Dariusz Jędrzejczyk
+ * @author Jermaine Hua
*/
public interface McpClientTransport extends McpTransport {
Mono connect(Function, Mono> handler);
+ default Mono connect() {
+ return connect(mono -> mono);
+ };
+
}
diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransportProvider.java
new file mode 100644
index 00000000..3cfb9af1
--- /dev/null
+++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransportProvider.java
@@ -0,0 +1,39 @@
+package io.modelcontextprotocol.spec;
+
+import reactor.core.publisher.Mono;
+
+/**
+ * @author Jermaine Hua
+ */
+public interface McpClientTransportProvider {
+
+ /**
+ * Sets the session factory that will be used to create sessions for new clients. An
+ * implementation of the MCP server MUST call this method before any MCP interactions
+ * take place.
+ * @param sessionFactory the session factory to be used for initiating client sessions
+ */
+ void setSessionFactory(McpClientSession.Factory sessionFactory);
+
+ /**
+ * Get the active session.
+ * @return active client session
+ */
+ McpClientSession getSession();
+
+ /**
+ * Immediately closes all the transports with connected clients and releases any
+ * associated resources.
+ */
+ default void close() {
+ this.closeGracefully().subscribe();
+ }
+
+ /**
+ * Gracefully closes all the transports with connected clients and releases any
+ * associated resources asynchronously.
+ * @return a {@link Mono} that completes when the connections have been closed.
+ */
+ Mono closeGracefully();
+
+}
diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java b/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java
index 0f799ca0..536c65aa 100644
--- a/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java
+++ b/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java
@@ -4,7 +4,11 @@
package io.modelcontextprotocol.util;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.net.URLDecoder;
import java.util.Collection;
+import java.util.HashMap;
import java.util.Map;
import reactor.util.annotation.Nullable;
@@ -52,4 +56,27 @@ public static boolean isEmpty(@Nullable Map, ?> map) {
return (map == null || map.isEmpty());
}
+ public static String getSessionIdFromUrl(String urlStr) {
+ URI uri;
+ try {
+ uri = new URI(urlStr);
+ }
+ catch (URISyntaxException e) {
+ return null;
+ }
+ String query = uri.getQuery();
+ if (query == null) {
+ return null;
+ }
+ Map params = new HashMap<>();
+ String[] pairs = query.split("&");
+ for (String pair : pairs) {
+ int idx = pair.indexOf("=");
+ String key = (idx > 0) ? pair.substring(0, idx) : pair;
+ String value = (idx > 0 && pair.length() > idx + 1) ? pair.substring(idx + 1) : null;
+ params.put(key, value);
+ }
+ return params.get("sessionId");
+ }
+
}
diff --git a/mcp/src/test/java/io/modelcontextprotocol/MockMcpClientTransportProvider.java b/mcp/src/test/java/io/modelcontextprotocol/MockMcpClientTransportProvider.java
new file mode 100644
index 00000000..bf15283d
--- /dev/null
+++ b/mcp/src/test/java/io/modelcontextprotocol/MockMcpClientTransportProvider.java
@@ -0,0 +1,140 @@
+package io.modelcontextprotocol;
+
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import io.modelcontextprotocol.spec.McpClientSession;
+import io.modelcontextprotocol.spec.McpClientTransport;
+import io.modelcontextprotocol.spec.McpClientTransportProvider;
+import io.modelcontextprotocol.spec.McpSchema;
+import reactor.core.publisher.Mono;
+import reactor.core.publisher.Sinks;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.BiConsumer;
+import java.util.function.Function;
+
+/**
+ * @author Jermaine Hua
+ */
+public class MockMcpClientTransportProvider implements McpClientTransportProvider {
+
+ private McpClientSession.Factory sessionFactory;
+
+ private McpClientSession session;
+
+ private final BiConsumer interceptor;
+
+ private MockMcpClientTransport transport;
+
+ public MockMcpClientTransportProvider() {
+ this((t, msg) -> {
+ });
+ }
+
+ public MockMcpClientTransportProvider(BiConsumer interceptor) {
+ this.transport = new MockMcpClientTransport();
+ this.interceptor = interceptor;
+ }
+
+ public MockMcpClientTransport getTransport() {
+ return transport;
+ }
+
+ @Override
+ public void setSessionFactory(McpClientSession.Factory sessionFactory) {
+ this.sessionFactory = sessionFactory;
+ }
+
+ @Override
+ public McpClientSession getSession() {
+ if (session != null) {
+ return session;
+ }
+ session = sessionFactory.create(transport);
+ return session;
+ }
+
+ @Override
+ public Mono closeGracefully() {
+ return session.closeGracefully();
+ }
+
+ public class MockMcpClientTransport implements McpClientTransport {
+
+ private final Sinks.Many inbound = Sinks.many().unicast().onBackpressureBuffer();
+
+ private final List sent = new ArrayList<>();
+
+ public MockMcpClientTransport() {
+ }
+
+ public void simulateIncomingMessage(McpSchema.JSONRPCMessage message) {
+ if (inbound.tryEmitNext(message).isFailure()) {
+ throw new RuntimeException("Failed to process incoming message " + message);
+ }
+ }
+
+ @Override
+ public Mono sendMessage(McpSchema.JSONRPCMessage message) {
+ sent.add(message);
+ interceptor.accept(this, message);
+ return Mono.empty();
+ }
+
+ public McpSchema.JSONRPCRequest getLastSentMessageAsRequest() {
+ return (McpSchema.JSONRPCRequest) getLastSentMessage();
+ }
+
+ public McpSchema.JSONRPCNotification getLastSentMessageAsNotification() {
+ return (McpSchema.JSONRPCNotification) getLastSentMessage();
+ }
+
+ public McpSchema.JSONRPCMessage getLastSentMessage() {
+ return !sent.isEmpty() ? sent.get(sent.size() - 1) : null;
+ }
+
+ private volatile boolean connected = false;
+
+ @Override
+ public Mono connect(Function, Mono> handler) {
+ if (connected) {
+ return Mono.error(new IllegalStateException("Already connected"));
+ }
+ connected = true;
+ return inbound.asFlux()
+ .flatMap(message -> session.handle(message))
+ .doFinally(signal -> connected = false)
+ .then();
+ }
+
+ @Override
+ public Mono connect() {
+ if (connected) {
+ return Mono.error(new IllegalStateException("Already connected"));
+ }
+ connected = true;
+ return inbound.asFlux()
+ .flatMap(message -> session.handle(message))
+ .doFinally(signal -> connected = false)
+ .then();
+ }
+
+ @Override
+ public Mono closeGracefully() {
+ return Mono.defer(() -> {
+ connected = false;
+ inbound.tryEmitComplete();
+ // Wait for all subscribers to complete
+ return Mono.empty();
+ });
+ }
+
+ @Override
+ public T unmarshalFrom(Object data, TypeReference typeRef) {
+ return new ObjectMapper().convertValue(data, typeRef);
+ }
+
+ }
+
+}
diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java
index 72b409af..2a63ae9f 100644
--- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java
+++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java
@@ -4,15 +4,7 @@
package io.modelcontextprotocol.client;
-import java.time.Duration;
-import java.util.Map;
-import java.util.Objects;
-import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.concurrent.atomic.AtomicReference;
-import java.util.function.Consumer;
-import java.util.function.Function;
-
-import io.modelcontextprotocol.spec.McpClientTransport;
+import io.modelcontextprotocol.spec.McpClientTransportProvider;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.CallToolRequest;
@@ -35,6 +27,14 @@
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
+import java.time.Duration;
+import java.util.Map;
+import java.util.Objects;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Consumer;
+import java.util.function.Function;
+
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -51,7 +51,7 @@ public abstract class AbstractMcpAsyncClientTests {
private static final String ECHO_TEST_MESSAGE = "Hello MCP Spring AI!";
- abstract protected McpClientTransport createMcpTransport();
+ abstract protected McpClientTransportProvider createMcpClientTransportProvider();
protected void onStart() {
}
@@ -67,15 +67,16 @@ protected Duration getInitializationTimeout() {
return Duration.ofSeconds(2);
}
- McpAsyncClient client(McpClientTransport transport) {
- return client(transport, Function.identity());
+ McpAsyncClient client(McpClientTransportProvider transportProvider) {
+ return client(transportProvider, Function.identity());
}
- McpAsyncClient client(McpClientTransport transport, Function customizer) {
+ McpAsyncClient client(McpClientTransportProvider transportProvider,
+ Function customizer) {
AtomicReference client = new AtomicReference<>();
assertThatCode(() -> {
- McpClient.AsyncSpec builder = McpClient.async(transport)
+ McpClient.AsyncSpec builder = McpClient.async(transportProvider)
.requestTimeout(getRequestTimeout())
.initializationTimeout(getInitializationTimeout())
.capabilities(ClientCapabilities.builder().roots(true).build());
@@ -86,13 +87,13 @@ McpAsyncClient client(McpClientTransport transport, Function c) {
- withClient(transport, Function.identity(), c);
+ void withClient(McpClientTransportProvider transportProvider, Consumer c) {
+ withClient(transportProvider, Function.identity(), c);
}
- void withClient(McpClientTransport transport, Function customizer,
- Consumer c) {
- var client = client(transport, customizer);
+ void withClient(McpClientTransportProvider transportProvider,
+ Function customizer, Consumer c) {
+ var client = client(transportProvider, customizer);
try {
c.accept(client);
}
@@ -112,7 +113,7 @@ void tearDown() {
}
void verifyInitializationTimeout(Function> operation, String action) {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
StepVerifier.withVirtualTime(() -> operation.apply(mcpAsyncClient))
.expectSubscription()
.thenAwait(getInitializationTimeout())
@@ -127,7 +128,7 @@ void testConstructorWithInvalidArguments() {
assertThatThrownBy(() -> McpClient.async(null).build()).isInstanceOf(IllegalArgumentException.class)
.hasMessage("Transport must not be null");
- assertThatThrownBy(() -> McpClient.async(createMcpTransport()).requestTimeout(null).build())
+ assertThatThrownBy(() -> McpClient.async(createMcpClientTransportProvider()).requestTimeout(null).build())
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Request timeout must not be null");
}
@@ -139,7 +140,7 @@ void testListToolsWithoutInitialization() {
@Test
void testListTools() {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listTools(null)))
.consumeNextWith(result -> {
assertThat(result.tools()).isNotNull().isNotEmpty();
@@ -159,7 +160,7 @@ void testPingWithoutInitialization() {
@Test
void testPing() {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.ping()))
.expectNextCount(1)
.verifyComplete();
@@ -174,7 +175,7 @@ void testCallToolWithoutInitialization() {
@Test
void testCallTool() {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE));
StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(callToolRequest)))
@@ -190,7 +191,7 @@ void testCallTool() {
@Test
void testCallToolWithInvalidTool() {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool",
Map.of("message", ECHO_TEST_MESSAGE));
@@ -208,7 +209,7 @@ void testListResourcesWithoutInitialization() {
@Test
void testListResources() {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResources(null)))
.consumeNextWith(resources -> {
assertThat(resources).isNotNull().satisfies(result -> {
@@ -227,7 +228,7 @@ void testListResources() {
@Test
void testMcpAsyncClientState() {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
assertThat(mcpAsyncClient).isNotNull();
});
}
@@ -239,7 +240,7 @@ void testListPromptsWithoutInitialization() {
@Test
void testListPrompts() {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listPrompts(null)))
.consumeNextWith(prompts -> {
assertThat(prompts).isNotNull().satisfies(result -> {
@@ -264,7 +265,7 @@ void testGetPromptWithoutInitialization() {
@Test
void testGetPrompt() {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
StepVerifier
.create(mcpAsyncClient.initialize()
.then(mcpAsyncClient.getPrompt(new GetPromptRequest("simple_prompt", Map.of()))))
@@ -286,7 +287,7 @@ void testRootsListChangedWithoutInitialization() {
@Test
void testRootsListChanged() {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.rootsListChangedNotification()))
.verifyComplete();
});
@@ -294,15 +295,15 @@ void testRootsListChanged() {
@Test
void testInitializeWithRootsListProviders() {
- withClient(createMcpTransport(), builder -> builder.roots(new Root("file:///test/path", "test-root")),
- client -> {
+ withClient(createMcpClientTransportProvider(),
+ builder -> builder.roots(new Root("file:///test/path", "test-root")), client -> {
StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete();
});
}
@Test
void testAddRoot() {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
Root newRoot = new Root("file:///new/test/path", "new-test-root");
StepVerifier.create(mcpAsyncClient.addRoot(newRoot)).verifyComplete();
});
@@ -310,7 +311,7 @@ void testAddRoot() {
@Test
void testAddRootWithNullValue() {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
StepVerifier.create(mcpAsyncClient.addRoot(null))
.consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class).hasMessage("Root must not be null"))
.verify();
@@ -319,7 +320,7 @@ void testAddRootWithNullValue() {
@Test
void testRemoveRoot() {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
Root root = new Root("file:///test/path/to/remove", "root-to-remove");
StepVerifier.create(mcpAsyncClient.addRoot(root)).verifyComplete();
@@ -329,7 +330,7 @@ void testRemoveRoot() {
@Test
void testRemoveNonExistentRoot() {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
StepVerifier.create(mcpAsyncClient.removeRoot("nonexistent-uri"))
.consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class)
.hasMessage("Root with uri 'nonexistent-uri' not found"))
@@ -340,7 +341,7 @@ void testRemoveNonExistentRoot() {
@Test
@Disabled
void testReadResource() {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> {
if (!resources.resources().isEmpty()) {
Resource firstResource = resources.resources().get(0);
@@ -360,7 +361,7 @@ void testListResourceTemplatesWithoutInitialization() {
@Test
void testListResourceTemplates() {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResourceTemplates()))
.consumeNextWith(result -> {
assertThat(result).isNotNull();
@@ -372,7 +373,7 @@ void testListResourceTemplates() {
// @Test
void testResourceSubscription() {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> {
if (!resources.resources().isEmpty()) {
Resource firstResource = resources.resources().get(0);
@@ -395,7 +396,7 @@ void testNotificationHandlers() {
AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false);
AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false);
- withClient(createMcpTransport(),
+ withClient(createMcpClientTransportProvider(),
builder -> builder
.toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true)))
.resourcesChangeConsumer(
@@ -415,7 +416,7 @@ void testInitializeWithSamplingCapability() {
.message("test")
.model("test-model")
.build();
- withClient(createMcpTransport(),
+ withClient(createMcpClientTransportProvider(),
builder -> builder.capabilities(capabilities).sampling(request -> Mono.just(createMessageResult)),
client -> {
StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete();
@@ -433,8 +434,8 @@ void testInitializeWithAllCapabilities() {
Function> samplingHandler = request -> Mono
.just(CreateMessageResult.builder().message("test").model("test-model").build());
- withClient(createMcpTransport(), builder -> builder.capabilities(capabilities).sampling(samplingHandler),
- client ->
+ withClient(createMcpClientTransportProvider(),
+ builder -> builder.capabilities(capabilities).sampling(samplingHandler), client ->
StepVerifier.create(client.initialize()).assertNext(result -> {
assertThat(result).isNotNull();
@@ -454,7 +455,7 @@ void testLoggingLevelsWithoutInitialization() {
@Test
void testLoggingLevels() {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
StepVerifier
.create(mcpAsyncClient.initialize()
.thenMany(Flux.fromArray(McpSchema.LoggingLevel.values()).flatMap(mcpAsyncClient::setLoggingLevel)))
@@ -466,7 +467,7 @@ void testLoggingLevels() {
void testLoggingConsumer() {
AtomicBoolean logReceived = new AtomicBoolean(false);
- withClient(createMcpTransport(),
+ withClient(createMcpClientTransportProvider(),
builder -> builder.loggingConsumer(notification -> Mono.fromRunnable(() -> logReceived.set(true))),
client -> {
StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete();
@@ -478,7 +479,7 @@ void testLoggingConsumer() {
@Test
void testLoggingWithNullNotification() {
- withClient(createMcpTransport(), mcpAsyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpAsyncClient -> {
StepVerifier.create(mcpAsyncClient.setLoggingLevel(null))
.expectErrorMatches(error -> error.getMessage().contains("Logging level must not be null"))
.verify();
diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java
index 24c161eb..6af2e2cd 100644
--- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java
+++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java
@@ -11,7 +11,7 @@
import java.util.function.Consumer;
import java.util.function.Function;
-import io.modelcontextprotocol.spec.McpClientTransport;
+import io.modelcontextprotocol.spec.McpClientTransportProvider;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.CallToolRequest;
@@ -50,7 +50,7 @@ public abstract class AbstractMcpSyncClientTests {
private static final String TEST_MESSAGE = "Hello MCP Spring AI!";
- abstract protected McpClientTransport createMcpTransport();
+ abstract protected McpClientTransportProvider createMcpClientTransportProvider();
protected void onStart() {
}
@@ -66,15 +66,16 @@ protected Duration getInitializationTimeout() {
return Duration.ofSeconds(2);
}
- McpSyncClient client(McpClientTransport transport) {
- return client(transport, Function.identity());
+ McpSyncClient client(McpClientTransportProvider transportProvider) {
+ return client(transportProvider, Function.identity());
}
- McpSyncClient client(McpClientTransport transport, Function customizer) {
+ McpSyncClient client(McpClientTransportProvider transportProvider,
+ Function customizer) {
AtomicReference client = new AtomicReference<>();
assertThatCode(() -> {
- McpClient.SyncSpec builder = McpClient.sync(transport)
+ McpClient.SyncSpec builder = McpClient.sync(transportProvider)
.requestTimeout(getRequestTimeout())
.initializationTimeout(getInitializationTimeout())
.capabilities(ClientCapabilities.builder().roots(true).build());
@@ -85,13 +86,13 @@ McpSyncClient client(McpClientTransport transport, Function c) {
- withClient(transport, Function.identity(), c);
+ void withClient(McpClientTransportProvider transportProvider, Consumer c) {
+ withClient(transportProvider, Function.identity(), c);
}
- void withClient(McpClientTransport transport, Function customizer,
- Consumer c) {
- var client = client(transport, customizer);
+ void withClient(McpClientTransportProvider transportProvider,
+ Function customizer, Consumer c) {
+ var client = client(transportProvider, customizer);
try {
c.accept(client);
}
@@ -121,7 +122,7 @@ void verifyNotificationTimesOut(Consumer operation, String ac
}
void verifyCallTimesOut(Function blockingOperation, String action) {
- withClient(createMcpTransport(), mcpSyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpSyncClient -> {
// This scheduler is not replaced by virtual time scheduler
Scheduler customScheduler = Schedulers.newBoundedElastic(1, 1, "actualBoundedElastic");
@@ -148,7 +149,7 @@ void testConstructorWithInvalidArguments() {
assertThatThrownBy(() -> McpClient.sync(null).build()).isInstanceOf(IllegalArgumentException.class)
.hasMessage("Transport must not be null");
- assertThatThrownBy(() -> McpClient.sync(createMcpTransport()).requestTimeout(null).build())
+ assertThatThrownBy(() -> McpClient.sync(createMcpClientTransportProvider()).requestTimeout(null).build())
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Request timeout must not be null");
}
@@ -160,7 +161,7 @@ void testListToolsWithoutInitialization() {
@Test
void testListTools() {
- withClient(createMcpTransport(), mcpSyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpSyncClient -> {
mcpSyncClient.initialize();
ListToolsResult tools = mcpSyncClient.listTools(null);
@@ -182,7 +183,7 @@ void testCallToolsWithoutInitialization() {
@Test
void testCallTools() {
- withClient(createMcpTransport(), mcpSyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpSyncClient -> {
mcpSyncClient.initialize();
CallToolResult toolResult = mcpSyncClient.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4)));
@@ -206,7 +207,7 @@ void testPingWithoutInitialization() {
@Test
void testPing() {
- withClient(createMcpTransport(), mcpSyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpSyncClient -> {
mcpSyncClient.initialize();
assertThatCode(() -> mcpSyncClient.ping()).doesNotThrowAnyException();
});
@@ -220,7 +221,7 @@ void testCallToolWithoutInitialization() {
@Test
void testCallTool() {
- withClient(createMcpTransport(), mcpSyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpSyncClient -> {
mcpSyncClient.initialize();
CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE));
@@ -235,7 +236,7 @@ void testCallTool() {
@Test
void testCallToolWithInvalidTool() {
- withClient(createMcpTransport(), mcpSyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpSyncClient -> {
CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", TEST_MESSAGE));
assertThatThrownBy(() -> mcpSyncClient.callTool(invalidRequest)).isInstanceOf(Exception.class);
@@ -250,7 +251,7 @@ void testRootsListChangedWithoutInitialization() {
@Test
void testRootsListChanged() {
- withClient(createMcpTransport(), mcpSyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpSyncClient -> {
mcpSyncClient.initialize();
assertThatCode(() -> mcpSyncClient.rootsListChangedNotification()).doesNotThrowAnyException();
});
@@ -263,7 +264,7 @@ void testListResourcesWithoutInitialization() {
@Test
void testListResources() {
- withClient(createMcpTransport(), mcpSyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpSyncClient -> {
mcpSyncClient.initialize();
ListResourcesResult resources = mcpSyncClient.listResources(null);
@@ -281,15 +282,15 @@ void testListResources() {
@Test
void testClientSessionState() {
- withClient(createMcpTransport(), mcpSyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpSyncClient -> {
assertThat(mcpSyncClient).isNotNull();
});
}
@Test
void testInitializeWithRootsListProviders() {
- withClient(createMcpTransport(), builder -> builder.roots(new Root("file:///test/path", "test-root")),
- mcpSyncClient -> {
+ withClient(createMcpClientTransportProvider(),
+ builder -> builder.roots(new Root("file:///test/path", "test-root")), mcpSyncClient -> {
assertThatCode(() -> {
mcpSyncClient.initialize();
@@ -300,7 +301,7 @@ void testInitializeWithRootsListProviders() {
@Test
void testAddRoot() {
- withClient(createMcpTransport(), mcpSyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpSyncClient -> {
Root newRoot = new Root("file:///new/test/path", "new-test-root");
assertThatCode(() -> mcpSyncClient.addRoot(newRoot)).doesNotThrowAnyException();
});
@@ -308,14 +309,14 @@ void testAddRoot() {
@Test
void testAddRootWithNullValue() {
- withClient(createMcpTransport(), mcpSyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpSyncClient -> {
assertThatThrownBy(() -> mcpSyncClient.addRoot(null)).hasMessageContaining("Root must not be null");
});
}
@Test
void testRemoveRoot() {
- withClient(createMcpTransport(), mcpSyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpSyncClient -> {
Root root = new Root("file:///test/path/to/remove", "root-to-remove");
assertThatCode(() -> {
mcpSyncClient.addRoot(root);
@@ -326,7 +327,7 @@ void testRemoveRoot() {
@Test
void testRemoveNonExistentRoot() {
- withClient(createMcpTransport(), mcpSyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpSyncClient -> {
assertThatThrownBy(() -> mcpSyncClient.removeRoot("nonexistent-uri"))
.hasMessageContaining("Root with uri 'nonexistent-uri' not found");
});
@@ -340,7 +341,7 @@ void testReadResourceWithoutInitialization() {
@Test
void testReadResource() {
- withClient(createMcpTransport(), mcpSyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpSyncClient -> {
mcpSyncClient.initialize();
ListResourcesResult resources = mcpSyncClient.listResources(null);
@@ -361,7 +362,7 @@ void testListResourceTemplatesWithoutInitialization() {
@Test
void testListResourceTemplates() {
- withClient(createMcpTransport(), mcpSyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpSyncClient -> {
mcpSyncClient.initialize();
ListResourceTemplatesResult result = mcpSyncClient.listResourceTemplates(null);
@@ -372,7 +373,7 @@ void testListResourceTemplates() {
// @Test
void testResourceSubscription() {
- withClient(createMcpTransport(), mcpSyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpSyncClient -> {
ListResourcesResult resources = mcpSyncClient.listResources(null);
if (!resources.resources().isEmpty()) {
@@ -395,7 +396,7 @@ void testNotificationHandlers() {
AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false);
AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false);
- withClient(createMcpTransport(),
+ withClient(createMcpClientTransportProvider(),
builder -> builder.toolsChangeConsumer(tools -> toolsNotificationReceived.set(true))
.resourcesChangeConsumer(resources -> resourcesNotificationReceived.set(true))
.promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)),
@@ -420,7 +421,7 @@ void testLoggingLevelsWithoutInitialization() {
@Test
void testLoggingLevels() {
- withClient(createMcpTransport(), mcpSyncClient -> {
+ withClient(createMcpClientTransportProvider(), mcpSyncClient -> {
mcpSyncClient.initialize();
// Test all logging levels
for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) {
@@ -432,7 +433,7 @@ void testLoggingLevels() {
@Test
void testLoggingConsumer() {
AtomicBoolean logReceived = new AtomicBoolean(false);
- withClient(createMcpTransport(), builder -> builder.requestTimeout(getRequestTimeout())
+ withClient(createMcpClientTransportProvider(), builder -> builder.requestTimeout(getRequestTimeout())
.loggingConsumer(notification -> logReceived.set(true)), client -> {
assertThatCode(() -> {
client.initialize();
@@ -443,8 +444,9 @@ void testLoggingConsumer() {
@Test
void testLoggingWithNullNotification() {
- withClient(createMcpTransport(), mcpSyncClient -> assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(null))
- .hasMessageContaining("Logging level must not be null"));
+ withClient(createMcpClientTransportProvider(),
+ mcpSyncClient -> assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(null))
+ .hasMessageContaining("Logging level must not be null"));
}
}
diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java
index fdff4b77..6a3162b3 100644
--- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java
+++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java
@@ -4,14 +4,14 @@
package io.modelcontextprotocol.client;
-import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
-import io.modelcontextprotocol.spec.McpClientTransport;
+import io.modelcontextprotocol.client.transport.HttpClientSseClientTransportProvider;
+import io.modelcontextprotocol.spec.McpClientTransportProvider;
import org.junit.jupiter.api.Timeout;
import org.testcontainers.containers.GenericContainer;
import org.testcontainers.containers.wait.strategy.Wait;
/**
- * Tests for the {@link McpSyncClient} with {@link HttpClientSseClientTransport}.
+ * Tests for the {@link McpSyncClient} with
*
* @author Christian Tzolov
*/
@@ -28,8 +28,8 @@ class HttpSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests {
.waitingFor(Wait.forHttp("/").forStatusCode(404));
@Override
- protected McpClientTransport createMcpTransport() {
- return HttpClientSseClientTransport.builder(host).build();
+ protected McpClientTransportProvider createMcpClientTransportProvider() {
+ return HttpClientSseClientTransportProvider.builder(host).build();
}
@Override
diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java
index 204cf298..de68b901 100644
--- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java
+++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java
@@ -4,15 +4,13 @@
package io.modelcontextprotocol.client;
-import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
-import io.modelcontextprotocol.spec.McpClientTransport;
+import io.modelcontextprotocol.client.transport.HttpClientSseClientTransportProvider;
+import io.modelcontextprotocol.spec.McpClientTransportProvider;
import org.junit.jupiter.api.Timeout;
import org.testcontainers.containers.GenericContainer;
import org.testcontainers.containers.wait.strategy.Wait;
/**
- * Tests for the {@link McpSyncClient} with {@link HttpClientSseClientTransport}.
- *
* @author Christian Tzolov
*/
@Timeout(15) // Giving extra time beyond the client timeout
@@ -28,8 +26,8 @@ class HttpSseMcpSyncClientTests extends AbstractMcpSyncClientTests {
.waitingFor(Wait.forHttp("/").forStatusCode(404));
@Override
- protected McpClientTransport createMcpTransport() {
- return HttpClientSseClientTransport.builder(host).build();
+ protected McpClientTransportProvider createMcpClientTransportProvider() {
+ return HttpClientSseClientTransportProvider.builder(host).build();
}
@Override
diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java
index 4510b152..3e666128 100644
--- a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java
+++ b/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java
@@ -12,7 +12,7 @@
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
-import io.modelcontextprotocol.MockMcpClientTransport;
+import io.modelcontextprotocol.MockMcpClientTransportProvider;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities;
@@ -34,16 +34,16 @@ class McpAsyncClientResponseHandlerTests {
.resources(true, true) // Enable both resources and resource templates
.build();
- private static MockMcpClientTransport initializationEnabledTransport() {
+ private static MockMcpClientTransportProvider initializationEnabledTransport() {
return initializationEnabledTransport(SERVER_CAPABILITIES, SERVER_INFO);
}
- private static MockMcpClientTransport initializationEnabledTransport(
+ private static MockMcpClientTransportProvider initializationEnabledTransport(
McpSchema.ServerCapabilities mockServerCapabilities, McpSchema.Implementation mockServerInfo) {
McpSchema.InitializeResult mockInitResult = new McpSchema.InitializeResult(McpSchema.LATEST_PROTOCOL_VERSION,
mockServerCapabilities, mockServerInfo, "Test instructions");
- return new MockMcpClientTransport((t, message) -> {
+ return new MockMcpClientTransportProvider((t, message) -> {
if (message instanceof McpSchema.JSONRPCRequest r && METHOD_INITIALIZE.equals(r.method())) {
McpSchema.JSONRPCResponse initResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION,
r.id(), mockInitResult, null);
@@ -59,8 +59,10 @@ void testSuccessfulInitialization() {
.tools(false)
.resources(true, true) // Enable both resources and resource templates
.build();
- MockMcpClientTransport transport = initializationEnabledTransport(serverCapabilities, serverInfo);
- McpAsyncClient asyncMcpClient = McpClient.async(transport).build();
+ MockMcpClientTransportProvider transportProvider = initializationEnabledTransport(serverCapabilities,
+ serverInfo);
+ MockMcpClientTransportProvider.MockMcpClientTransport transport = transportProvider.getTransport();
+ McpAsyncClient asyncMcpClient = McpClient.async(transportProvider).build();
// Verify client is not initialized initially
assertThat(asyncMcpClient.isInitialized()).isFalse();
@@ -91,8 +93,8 @@ void testSuccessfulInitialization() {
@Test
void testToolsChangeNotificationHandling() throws JsonProcessingException {
- MockMcpClientTransport transport = initializationEnabledTransport();
-
+ MockMcpClientTransportProvider transportProvider = initializationEnabledTransport();
+ MockMcpClientTransportProvider.MockMcpClientTransport transport = transportProvider.getTransport();
// Create a list to store received tools for verification
List receivedTools = new ArrayList<>();
@@ -101,7 +103,9 @@ void testToolsChangeNotificationHandling() throws JsonProcessingException {
.fromRunnable(() -> receivedTools.addAll(tools));
// Create client with tools change consumer
- McpAsyncClient asyncMcpClient = McpClient.async(transport).toolsChangeConsumer(toolsChangeConsumer).build();
+ McpAsyncClient asyncMcpClient = McpClient.async(transportProvider)
+ .toolsChangeConsumer(toolsChangeConsumer)
+ .build();
assertThat(asyncMcpClient.initialize().block()).isNotNull();
@@ -134,9 +138,10 @@ void testToolsChangeNotificationHandling() throws JsonProcessingException {
@Test
void testRootsListRequestHandling() {
- MockMcpClientTransport transport = initializationEnabledTransport();
+ MockMcpClientTransportProvider transportProvider = initializationEnabledTransport();
+ MockMcpClientTransportProvider.MockMcpClientTransport transport = transportProvider.getTransport();
- McpAsyncClient asyncMcpClient = McpClient.async(transport)
+ McpAsyncClient asyncMcpClient = McpClient.async(transportProvider)
.roots(new Root("file:///test/path", "test-root"))
.build();
@@ -162,7 +167,8 @@ void testRootsListRequestHandling() {
@Test
void testResourcesChangeNotificationHandling() {
- MockMcpClientTransport transport = initializationEnabledTransport();
+ MockMcpClientTransportProvider transportProvider = initializationEnabledTransport();
+ MockMcpClientTransportProvider.MockMcpClientTransport transport = transportProvider.getTransport();
// Create a list to store received resources for verification
List receivedResources = new ArrayList<>();
@@ -172,7 +178,7 @@ void testResourcesChangeNotificationHandling() {
.fromRunnable(() -> receivedResources.addAll(resources));
// Create client with resources change consumer
- McpAsyncClient asyncMcpClient = McpClient.async(transport)
+ McpAsyncClient asyncMcpClient = McpClient.async(transportProvider)
.resourcesChangeConsumer(resourcesChangeConsumer)
.build();
@@ -208,7 +214,8 @@ void testResourcesChangeNotificationHandling() {
@Test
void testPromptsChangeNotificationHandling() {
- MockMcpClientTransport transport = initializationEnabledTransport();
+ MockMcpClientTransportProvider transportProvider = initializationEnabledTransport();
+ MockMcpClientTransportProvider.MockMcpClientTransport transport = transportProvider.getTransport();
// Create a list to store received prompts for verification
List receivedPrompts = new ArrayList<>();
@@ -218,7 +225,9 @@ void testPromptsChangeNotificationHandling() {
.fromRunnable(() -> receivedPrompts.addAll(prompts));
// Create client with prompts change consumer
- McpAsyncClient asyncMcpClient = McpClient.async(transport).promptsChangeConsumer(promptsChangeConsumer).build();
+ McpAsyncClient asyncMcpClient = McpClient.async(transportProvider)
+ .promptsChangeConsumer(promptsChangeConsumer)
+ .build();
assertThat(asyncMcpClient.initialize().block()).isNotNull();
@@ -252,7 +261,8 @@ void testPromptsChangeNotificationHandling() {
@Test
void testSamplingCreateMessageRequestHandling() {
- MockMcpClientTransport transport = initializationEnabledTransport();
+ MockMcpClientTransportProvider transportProvider = initializationEnabledTransport();
+ MockMcpClientTransportProvider.MockMcpClientTransport transport = transportProvider.getTransport();
// Create a test sampling handler that echoes back the input
Function> samplingHandler = request -> {
@@ -262,7 +272,7 @@ void testSamplingCreateMessageRequestHandling() {
};
// Create client with sampling capability and handler
- McpAsyncClient asyncMcpClient = McpClient.async(transport)
+ McpAsyncClient asyncMcpClient = McpClient.async(transportProvider)
.capabilities(ClientCapabilities.builder().sampling().build())
.sampling(samplingHandler)
.build();
@@ -306,10 +316,11 @@ void testSamplingCreateMessageRequestHandling() {
@Test
void testSamplingCreateMessageRequestHandlingWithoutCapability() {
- MockMcpClientTransport transport = initializationEnabledTransport();
+ MockMcpClientTransportProvider transportProvider = initializationEnabledTransport();
+ MockMcpClientTransportProvider.MockMcpClientTransport transport = transportProvider.getTransport();
// Create client without sampling capability
- McpAsyncClient asyncMcpClient = McpClient.async(transport)
+ McpAsyncClient asyncMcpClient = McpClient.async(transportProvider)
.capabilities(ClientCapabilities.builder().build()) // No sampling capability
.build();
@@ -340,12 +351,13 @@ void testSamplingCreateMessageRequestHandlingWithoutCapability() {
@Test
void testSamplingCreateMessageRequestHandlingWithNullHandler() {
- MockMcpClientTransport transport = new MockMcpClientTransport();
+ MockMcpClientTransportProvider transportProvider = initializationEnabledTransport();
+ MockMcpClientTransportProvider.MockMcpClientTransport transport = transportProvider.getTransport();
// Create client with sampling capability but null handler
- assertThatThrownBy(
- () -> McpClient.async(transport).capabilities(ClientCapabilities.builder().sampling().build()).build())
- .isInstanceOf(McpError.class)
+ assertThatThrownBy(() -> McpClient.async(transportProvider)
+ .capabilities(ClientCapabilities.builder().sampling().build())
+ .build()).isInstanceOf(McpError.class)
.hasMessage("Sampling handler must not be null when client capabilities include sampling");
}
diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java
index bf473849..3b915161 100644
--- a/mcp/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java
+++ b/mcp/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java
@@ -8,6 +8,7 @@
import java.util.List;
import io.modelcontextprotocol.MockMcpClientTransport;
+import io.modelcontextprotocol.MockMcpClientTransportProvider;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.InitializeResult;
@@ -28,8 +29,9 @@ class McpClientProtocolVersionTests {
@Test
void shouldUseLatestVersionByDefault() {
- MockMcpClientTransport transport = new MockMcpClientTransport();
- McpAsyncClient client = McpClient.async(transport)
+ MockMcpClientTransportProvider transportProvider = new MockMcpClientTransportProvider();
+ MockMcpClientTransportProvider.MockMcpClientTransport transport = transportProvider.getTransport();
+ McpAsyncClient client = McpClient.async(transportProvider)
.clientInfo(CLIENT_INFO)
.requestTimeout(REQUEST_TIMEOUT)
.build();
@@ -61,8 +63,9 @@ void shouldUseLatestVersionByDefault() {
@Test
void shouldNegotiateSpecificVersion() {
String oldVersion = "0.1.0";
- MockMcpClientTransport transport = new MockMcpClientTransport();
- McpAsyncClient client = McpClient.async(transport)
+ MockMcpClientTransportProvider transportProvider = new MockMcpClientTransportProvider();
+ MockMcpClientTransportProvider.MockMcpClientTransport transport = transportProvider.getTransport();
+ McpAsyncClient client = McpClient.async(transportProvider)
.clientInfo(CLIENT_INFO)
.requestTimeout(REQUEST_TIMEOUT)
.build();
@@ -94,8 +97,9 @@ void shouldNegotiateSpecificVersion() {
@Test
void shouldFailForUnsupportedVersion() {
String unsupportedVersion = "999.999.999";
- MockMcpClientTransport transport = new MockMcpClientTransport();
- McpAsyncClient client = McpClient.async(transport)
+ MockMcpClientTransportProvider transportProvider = new MockMcpClientTransportProvider();
+ MockMcpClientTransportProvider.MockMcpClientTransport transport = transportProvider.getTransport();
+ McpAsyncClient client = McpClient.async(transportProvider)
.clientInfo(CLIENT_INFO)
.requestTimeout(REQUEST_TIMEOUT)
.build();
@@ -124,8 +128,9 @@ void shouldUseHighestVersionWhenMultipleSupported() {
String middleVersion = "0.2.0";
String latestVersion = McpSchema.LATEST_PROTOCOL_VERSION;
- MockMcpClientTransport transport = new MockMcpClientTransport();
- McpAsyncClient client = McpClient.async(transport)
+ MockMcpClientTransportProvider transportProvider = new MockMcpClientTransportProvider();
+ MockMcpClientTransportProvider.MockMcpClientTransport transport = transportProvider.getTransport();
+ McpAsyncClient client = McpClient.async(transportProvider)
.clientInfo(CLIENT_INFO)
.requestTimeout(REQUEST_TIMEOUT)
.build();
diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java
index c3908013..4f510b01 100644
--- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java
+++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java
@@ -4,16 +4,14 @@
package io.modelcontextprotocol.client;
-import java.time.Duration;
-
import io.modelcontextprotocol.client.transport.ServerParameters;
-import io.modelcontextprotocol.client.transport.StdioClientTransport;
-import io.modelcontextprotocol.spec.McpClientTransport;
+import io.modelcontextprotocol.client.transport.StdioClientTransportProvider;
+import io.modelcontextprotocol.spec.McpClientTransportProvider;
import org.junit.jupiter.api.Timeout;
+import java.time.Duration;
+
/**
- * Tests for the {@link McpAsyncClient} with {@link StdioClientTransport}.
- *
* @author Christian Tzolov
* @author Dariusz Jędrzejczyk
*/
@@ -21,7 +19,7 @@
class StdioMcpAsyncClientTests extends AbstractMcpAsyncClientTests {
@Override
- protected McpClientTransport createMcpTransport() {
+ protected McpClientTransportProvider createMcpClientTransportProvider() {
ServerParameters stdioParams;
if (System.getProperty("os.name").toLowerCase().contains("win")) {
stdioParams = ServerParameters.builder("cmd.exe")
@@ -33,7 +31,7 @@ protected McpClientTransport createMcpTransport() {
.args("-y", "@modelcontextprotocol/server-everything", "dir")
.build();
}
- return new StdioClientTransport(stdioParams);
+ return new StdioClientTransportProvider(stdioParams);
}
protected Duration getInitializationTimeout() {
diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java
index 8e75c4a3..bfef4e53 100644
--- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java
+++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java
@@ -5,13 +5,16 @@
package io.modelcontextprotocol.client;
import java.time.Duration;
+import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import io.modelcontextprotocol.client.transport.ServerParameters;
-import io.modelcontextprotocol.client.transport.StdioClientTransport;
+import io.modelcontextprotocol.client.transport.StdioClientTransportProvider;
+import io.modelcontextprotocol.spec.McpClientSession;
import io.modelcontextprotocol.spec.McpClientTransport;
+import io.modelcontextprotocol.spec.McpClientTransportProvider;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import reactor.core.publisher.Sinks;
@@ -20,8 +23,6 @@
import static org.assertj.core.api.Assertions.assertThat;
/**
- * Tests for the {@link McpSyncClient} with {@link StdioClientTransport}.
- *
* @author Christian Tzolov
* @author Dariusz Jędrzejczyk
*/
@@ -29,7 +30,7 @@
class StdioMcpSyncClientTests extends AbstractMcpSyncClientTests {
@Override
- protected McpClientTransport createMcpTransport() {
+ protected McpClientTransportProvider createMcpClientTransportProvider() {
ServerParameters stdioParams;
if (System.getProperty("os.name").toLowerCase().contains("win")) {
stdioParams = ServerParameters.builder("cmd.exe")
@@ -41,7 +42,7 @@ protected McpClientTransport createMcpTransport() {
.args("-y", "@modelcontextprotocol/server-everything", "dir")
.build();
}
- return new StdioClientTransport(stdioParams);
+ return new StdioClientTransportProvider(stdioParams);
}
@Test
@@ -49,16 +50,20 @@ void customErrorHandlerShouldReceiveErrors() throws InterruptedException {
CountDownLatch latch = new CountDownLatch(1);
AtomicReference receivedError = new AtomicReference<>();
- McpClientTransport transport = createMcpTransport();
+ McpClientTransportProvider transportProvider = createMcpClientTransportProvider();
+ transportProvider.setSessionFactory(
+ (transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of()));
+ McpClientTransport transport = transportProvider.getSession().getTransport();
StepVerifier.create(transport.connect(msg -> msg)).verifyComplete();
- ((StdioClientTransport) transport).setStdErrorHandler(error -> {
+ ((StdioClientTransportProvider) transportProvider).setStdErrorHandler(error -> {
receivedError.set(error);
latch.countDown();
});
String errorMessage = "Test error";
- ((StdioClientTransport) transport).getErrorSink().emitNext(errorMessage, Sinks.EmitFailureHandler.FAIL_FAST);
+ ((StdioClientTransportProvider) transportProvider).getErrorSink()
+ .emitNext(errorMessage, Sinks.EmitFailureHandler.FAIL_FAST);
assertThat(latch.await(5, TimeUnit.SECONDS)).isTrue();
diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java
index e5178c0e..6bbd486d 100644
--- a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java
+++ b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java
@@ -15,6 +15,9 @@
import java.util.function.Consumer;
import java.util.function.Function;
+import com.fasterxml.jackson.core.type.TypeReference;
+import io.modelcontextprotocol.spec.McpClientSession;
+import io.modelcontextprotocol.spec.McpClientTransport;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest;
import org.junit.jupiter.api.AfterEach;
@@ -35,7 +38,8 @@
import com.fasterxml.jackson.databind.ObjectMapper;
/**
- * Tests for the {@link HttpClientSseClientTransport} class.
+ * Tests for the {@link HttpClientSseClientTransportProvider.HttpClientSseClientTransport}
+ * class.
*
* @author Christian Tzolov
*/
@@ -53,14 +57,16 @@ class HttpClientSseClientTransportTests {
private TestHttpClientSseClientTransport transport;
// Test class to access protected methods
- static class TestHttpClientSseClientTransport extends HttpClientSseClientTransport {
+ static class TestHttpClientSseClientTransport implements McpClientTransport {
+
+ McpClientTransport delegate;
private final AtomicInteger inboundMessageCount = new AtomicInteger(0);
private Sinks.Many> events = Sinks.many().unicast().onBackpressureBuffer();
- public TestHttpClientSseClientTransport(final String baseUri) {
- super(HttpClient.newHttpClient(), HttpRequest.newBuilder(), baseUri, "/sse", new ObjectMapper());
+ public TestHttpClientSseClientTransport(McpClientTransport transport) {
+ this.delegate = transport;
}
public int getInboundMessageCount() {
@@ -77,6 +83,31 @@ public void simulateMessageEvent(String jsonMessage) {
inboundMessageCount.incrementAndGet();
}
+ @Override
+ public Mono connect(Function, Mono> handler) {
+ return delegate.connect(handler);
+ }
+
+ @Override
+ public Mono connect() {
+ return delegate.connect();
+ }
+
+ @Override
+ public Mono closeGracefully() {
+ return delegate.closeGracefully();
+ }
+
+ @Override
+ public Mono sendMessage(McpSchema.JSONRPCMessage message) {
+ return delegate.sendMessage(message);
+ }
+
+ @Override
+ public T unmarshalFrom(Object data, TypeReference typeRef) {
+ return delegate.unmarshalFrom(data, typeRef);
+ }
+
}
void startContainer() {
@@ -88,8 +119,12 @@ void startContainer() {
@BeforeEach
void setUp() {
startContainer();
- transport = new TestHttpClientSseClientTransport(host);
- transport.connect(Function.identity()).block();
+ HttpClientSseClientTransportProvider transportProvider = HttpClientSseClientTransportProvider.builder(host)
+ .build();
+ transportProvider.setSessionFactory(
+ (transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of()));
+ transport = new TestHttpClientSseClientTransport(transportProvider.getSession().getTransport());
+ transport.connect().block();
}
@AfterEach
@@ -205,14 +240,17 @@ void testGracefulShutdown() {
@Test
void testRetryBehavior() {
// Create a client that simulates connection failures
- HttpClientSseClientTransport failingTransport = HttpClientSseClientTransport.builder("http://non-existent-host")
+ HttpClientSseClientTransportProvider failingTransportProvider = HttpClientSseClientTransportProvider
+ .builder("http://non-existent-host")
.build();
-
+ failingTransportProvider.setSessionFactory(
+ (transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of()));
+ failingTransportProvider.getSession();
// Verify that the transport attempts to reconnect
StepVerifier.create(Mono.delay(Duration.ofSeconds(2))).expectNextCount(1).verifyComplete();
// Clean up
- failingTransport.closeGracefully().block();
+ failingTransportProvider.closeGracefully().block();
}
@Test
@@ -290,12 +328,15 @@ void testCustomizeClient() {
AtomicBoolean customizerCalled = new AtomicBoolean(false);
// Create a transport with the customizer
- HttpClientSseClientTransport customizedTransport = HttpClientSseClientTransport.builder(host)
+ HttpClientSseClientTransportProvider customizedTransport = HttpClientSseClientTransportProvider.builder(host)
.customizeClient(builder -> {
builder.version(HttpClient.Version.HTTP_2);
customizerCalled.set(true);
})
.build();
+ customizedTransport.setSessionFactory(
+ (transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of()));
+ customizedTransport.getSession();
// Verify the customizer was called
assertThat(customizerCalled.get()).isTrue();
@@ -314,7 +355,7 @@ void testCustomizeRequest() {
AtomicReference headerValue = new AtomicReference<>();
// Create a transport with the customizer
- HttpClientSseClientTransport customizedTransport = HttpClientSseClientTransport.builder(host)
+ HttpClientSseClientTransportProvider customizedTransport = HttpClientSseClientTransportProvider.builder(host)
// Create a request customizer that adds a custom header
.customizeRequest(builder -> {
builder.header("X-Custom-Header", "test-value");
@@ -326,6 +367,9 @@ void testCustomizeRequest() {
headerValue.set(request.headers().firstValue("X-Custom-Header").orElse(null));
})
.build();
+ customizedTransport.setSessionFactory(
+ (transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of()));
+ customizedTransport.getSession();
// Verify the customizer was called
assertThat(customizerCalled.get()).isTrue();
@@ -345,7 +389,7 @@ void testChainedCustomizations() {
AtomicBoolean requestCustomizerCalled = new AtomicBoolean(false);
// Create a transport with both customizers chained
- HttpClientSseClientTransport customizedTransport = HttpClientSseClientTransport.builder(host)
+ HttpClientSseClientTransportProvider customizedTransport = HttpClientSseClientTransportProvider.builder(host)
.customizeClient(builder -> {
builder.connectTimeout(Duration.ofSeconds(30));
clientCustomizerCalled.set(true);
@@ -355,6 +399,9 @@ void testChainedCustomizations() {
requestCustomizerCalled.set(true);
})
.build();
+ customizedTransport.setSessionFactory(
+ (transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of()));
+ customizedTransport.getSession();
// Verify both customizers were called
assertThat(clientCustomizerCalled.get()).isTrue();
diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java
index 212a3c95..4e60e3f8 100644
--- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java
+++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java
@@ -5,7 +5,7 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.client.McpClient;
-import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
+import io.modelcontextprotocol.client.transport.HttpClientSseClientTransportProvider;
import io.modelcontextprotocol.server.McpServer;
import io.modelcontextprotocol.spec.McpSchema;
import org.apache.catalina.LifecycleException;
@@ -54,7 +54,7 @@ public void before() {
throw new RuntimeException("Failed to start Tomcat", e);
}
- this.clientBuilder = McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT)
+ this.clientBuilder = McpClient.sync(HttpClientSseClientTransportProvider.builder("http://localhost:" + PORT)
.sseEndpoint(CUSTOM_CONTEXT_PATH + CUSTOM_SSE_ENDPOINT)
.build());
}
diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java
index 135de83f..40f7be81 100644
--- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java
+++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java
@@ -13,7 +13,7 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.client.McpClient;
-import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
+import io.modelcontextprotocol.client.transport.HttpClientSseClientTransportProvider;
import io.modelcontextprotocol.server.McpServer;
import io.modelcontextprotocol.server.McpServerFeatures;
import io.modelcontextprotocol.spec.McpError;
@@ -76,9 +76,12 @@ public void before() {
throw new RuntimeException("Failed to start Tomcat", e);
}
- this.clientBuilder = McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT)
+ HttpClientSseClientTransportProvider transportProvider = HttpClientSseClientTransportProvider
+ .builder("http://localhost:" + PORT)
.sseEndpoint(CUSTOM_SSE_ENDPOINT)
- .build());
+ .objectMapper(new ObjectMapper())
+ .build();
+ this.clientBuilder = McpClient.sync(transportProvider);
}
@AfterEach
diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java
index f72be43e..4d7a0705 100644
--- a/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java
+++ b/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java
@@ -9,6 +9,7 @@
import com.fasterxml.jackson.core.type.TypeReference;
import io.modelcontextprotocol.MockMcpClientTransport;
+import io.modelcontextprotocol.MockMcpClientTransportProvider;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -41,13 +42,18 @@ class McpClientSessionTests {
private McpClientSession session;
- private MockMcpClientTransport transport;
+ private MockMcpClientTransportProvider.MockMcpClientTransport transport;
+
+ private MockMcpClientTransportProvider transportProvider;
@BeforeEach
void setUp() {
- transport = new MockMcpClientTransport();
- session = new McpClientSession(TIMEOUT, transport, Map.of(),
- Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: " + params))));
+ transportProvider = new MockMcpClientTransportProvider();
+ transportProvider.setSessionFactory((transport) -> new McpClientSession(TIMEOUT, transport, Map.of(),
+ Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: " + params)))));
+ session = transportProvider.getSession();
+ transport = transportProvider.getTransport();
+
}
@AfterEach
@@ -139,8 +145,11 @@ void testRequestHandling() {
String echoMessage = "Hello MCP!";
Map> requestHandlers = Map.of(ECHO_METHOD,
params -> Mono.just(params));
- transport = new MockMcpClientTransport();
- session = new McpClientSession(TIMEOUT, transport, requestHandlers, Map.of());
+ transportProvider = new MockMcpClientTransportProvider();
+ transportProvider
+ .setSessionFactory((transport) -> new McpClientSession(TIMEOUT, transport, requestHandlers, Map.of()));
+ session = transportProvider.getSession();
+ transport = transportProvider.getTransport();
// Simulate incoming request
McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, ECHO_METHOD,
@@ -159,9 +168,11 @@ void testRequestHandling() {
void testNotificationHandling() {
Sinks.One receivedParams = Sinks.one();
- transport = new MockMcpClientTransport();
- session = new McpClientSession(TIMEOUT, transport, Map.of(),
- Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> receivedParams.tryEmitValue(params))));
+ transportProvider = new MockMcpClientTransportProvider();
+ transportProvider.setSessionFactory((transport) -> new McpClientSession(TIMEOUT, transport, Map.of(),
+ Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> receivedParams.tryEmitValue(params)))));
+ session = transportProvider.getSession();
+ transport = transportProvider.getTransport();
// Simulate incoming notification from the server
Map notificationParams = Map.of("status", "ready");