diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java deleted file mode 100644 index 37abe295..00000000 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java +++ /dev/null @@ -1,408 +0,0 @@ -/* - * Copyright 2024 - 2024 the original author or authors. - */ -package io.modelcontextprotocol.client.transport; - -import java.io.IOException; -import java.util.function.BiConsumer; -import java.util.function.Function; - -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpClientTransport; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; -import io.modelcontextprotocol.util.Assert; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.Disposable; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.publisher.Sinks; -import reactor.core.publisher.SynchronousSink; -import reactor.core.scheduler.Schedulers; -import reactor.util.retry.Retry; -import reactor.util.retry.Retry.RetrySignal; - -import org.springframework.core.ParameterizedTypeReference; -import org.springframework.http.MediaType; -import org.springframework.http.codec.ServerSentEvent; -import org.springframework.web.reactive.function.client.WebClient; - -/** - * Server-Sent Events (SSE) implementation of the - * {@link io.modelcontextprotocol.spec.McpTransport} that follows the MCP HTTP with SSE - * transport specification. - * - *

- * This transport establishes a bidirectional communication channel where: - *

- * - *

- * The message flow follows these steps: - *

    - *
  1. The client establishes an SSE connection to the server's /sse endpoint
  2. - *
  3. The server sends an 'endpoint' event containing the URI for sending messages
  4. - *
- * - * This implementation uses {@link WebClient} for HTTP communications and supports JSON - * serialization/deserialization of messages. - * - * @author Christian Tzolov - * @see MCP - * HTTP with SSE Transport Specification - */ -public class WebFluxSseClientTransport implements McpClientTransport { - - private static final Logger logger = LoggerFactory.getLogger(WebFluxSseClientTransport.class); - - /** - * Event type for JSON-RPC messages received through the SSE connection. The server - * sends messages with this event type to transmit JSON-RPC protocol data. - */ - private static final String MESSAGE_EVENT_TYPE = "message"; - - /** - * Event type for receiving the message endpoint URI from the server. The server MUST - * send this event when a client connects, providing the URI where the client should - * send its messages via HTTP POST. - */ - private static final String ENDPOINT_EVENT_TYPE = "endpoint"; - - /** - * Default SSE endpoint path as specified by the MCP transport specification. This - * endpoint is used to establish the SSE connection with the server. - */ - private static final String DEFAULT_SSE_ENDPOINT = "/sse"; - - /** - * Type reference for parsing SSE events containing string data. - */ - private static final ParameterizedTypeReference> SSE_TYPE = new ParameterizedTypeReference<>() { - }; - - /** - * WebClient instance for handling both SSE connections and HTTP POST requests. Used - * for establishing the SSE connection and sending outbound messages. - */ - private final WebClient webClient; - - /** - * ObjectMapper for serializing outbound messages and deserializing inbound messages. - * Handles conversion between JSON-RPC messages and their string representation. - */ - protected ObjectMapper objectMapper; - - /** - * Subscription for the SSE connection handling inbound messages. Used for cleanup - * during transport shutdown. - */ - private Disposable inboundSubscription; - - /** - * Flag indicating if the transport is in the process of shutting down. Used to - * prevent new operations during shutdown and handle cleanup gracefully. - */ - private volatile boolean isClosing = false; - - /** - * Sink for managing the message endpoint URI provided by the server. Stores the most - * recent endpoint URI and makes it available for outbound message processing. - */ - protected final Sinks.One messageEndpointSink = Sinks.one(); - - /** - * The SSE endpoint URI provided by the server. Used for sending outbound messages via - * HTTP POST requests. - */ - private String sseEndpoint; - - /** - * Constructs a new SseClientTransport with the specified WebClient builder. Uses a - * default ObjectMapper instance for JSON processing. - * @param webClientBuilder the WebClient.Builder to use for creating the WebClient - * instance - * @throws IllegalArgumentException if webClientBuilder is null - */ - public WebFluxSseClientTransport(WebClient.Builder webClientBuilder) { - this(webClientBuilder, new ObjectMapper()); - } - - /** - * Constructs a new SseClientTransport with the specified WebClient builder and - * ObjectMapper. Initializes both inbound and outbound message processing pipelines. - * @param webClientBuilder the WebClient.Builder to use for creating the WebClient - * instance - * @param objectMapper the ObjectMapper to use for JSON processing - * @throws IllegalArgumentException if either parameter is null - */ - public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, ObjectMapper objectMapper) { - this(webClientBuilder, objectMapper, DEFAULT_SSE_ENDPOINT); - } - - /** - * Constructs a new SseClientTransport with the specified WebClient builder and - * ObjectMapper. Initializes both inbound and outbound message processing pipelines. - * @param webClientBuilder the WebClient.Builder to use for creating the WebClient - * instance - * @param objectMapper the ObjectMapper to use for JSON processing - * @param sseEndpoint the SSE endpoint URI to use for establishing the connection - * @throws IllegalArgumentException if either parameter is null - */ - public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, ObjectMapper objectMapper, - String sseEndpoint) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - Assert.notNull(webClientBuilder, "WebClient.Builder must not be null"); - Assert.hasText(sseEndpoint, "SSE endpoint must not be null or empty"); - - this.objectMapper = objectMapper; - this.webClient = webClientBuilder.build(); - this.sseEndpoint = sseEndpoint; - } - - /** - * Establishes a connection to the MCP server using Server-Sent Events (SSE). This - * method initiates the SSE connection and sets up the message processing pipeline. - * - *

- * The connection process follows these steps: - *

    - *
  1. Establishes an SSE connection to the server's /sse endpoint
  2. - *
  3. Waits for the server to send an 'endpoint' event with the message posting - * URI
  4. - *
  5. Sets up message handling for incoming JSON-RPC messages
  6. - *
- * - *

- * The connection is considered established only after receiving the endpoint event - * from the server. - * @param handler a function that processes incoming JSON-RPC messages and returns - * responses - * @return a Mono that completes when the connection is fully established - * @throws McpError if there's an error processing SSE events or if an unrecognized - * event type is received - */ - @Override - public Mono connect(Function, Mono> handler) { - Flux> events = eventStream(); - this.inboundSubscription = events.concatMap(event -> Mono.just(event).handle((e, s) -> { - if (ENDPOINT_EVENT_TYPE.equals(event.event())) { - String messageEndpointUri = event.data(); - if (messageEndpointSink.tryEmitValue(messageEndpointUri).isSuccess()) { - s.complete(); - } - else { - // TODO: clarify with the spec if multiple events can be - // received - s.error(new McpError("Failed to handle SSE endpoint event")); - } - } - else if (MESSAGE_EVENT_TYPE.equals(event.event())) { - try { - JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, event.data()); - s.next(message); - } - catch (IOException ioException) { - s.error(ioException); - } - } - else { - s.error(new McpError("Received unrecognized SSE event type: " + event.event())); - } - }).transform(handler)).subscribe(); - - // The connection is established once the server sends the endpoint event - return messageEndpointSink.asMono().then(); - } - - /** - * Sends a JSON-RPC message to the server using the endpoint provided during - * connection. - * - *

- * Messages are sent via HTTP POST requests to the server-provided endpoint URI. The - * message is serialized to JSON before transmission. If the transport is in the - * process of closing, the message send operation is skipped gracefully. - * @param message the JSON-RPC message to send - * @return a Mono that completes when the message has been sent successfully - * @throws RuntimeException if message serialization fails - */ - @Override - public Mono sendMessage(JSONRPCMessage message) { - // The messageEndpoint is the endpoint URI to send the messages - // It is provided by the server as part of the endpoint event - return messageEndpointSink.asMono().flatMap(messageEndpointUri -> { - if (isClosing) { - return Mono.empty(); - } - try { - String jsonText = this.objectMapper.writeValueAsString(message); - return webClient.post() - .uri(messageEndpointUri) - .contentType(MediaType.APPLICATION_JSON) - .bodyValue(jsonText) - .retrieve() - .toBodilessEntity() - .doOnSuccess(response -> { - logger.debug("Message sent successfully"); - }) - .doOnError(error -> { - if (!isClosing) { - logger.error("Error sending message: {}", error.getMessage()); - } - }); - } - catch (IOException e) { - if (!isClosing) { - return Mono.error(new RuntimeException("Failed to serialize message", e)); - } - return Mono.empty(); - } - }).then(); // TODO: Consider non-200-ok response - } - - /** - * Initializes and starts the inbound SSE event processing. Establishes the SSE - * connection and sets up event handling for both message and endpoint events. - * Includes automatic retry logic for handling transient connection failures. - */ - // visible for tests - protected Flux> eventStream() {// @formatter:off - return this.webClient - .get() - .uri(this.sseEndpoint) - .accept(MediaType.TEXT_EVENT_STREAM) - .retrieve() - .bodyToFlux(SSE_TYPE) - .retryWhen(Retry.from(retrySignal -> retrySignal.handle(inboundRetryHandler))); - } // @formatter:on - - /** - * Retry handler for the inbound SSE stream. Implements the retry logic for handling - * connection failures and other errors. - */ - private BiConsumer> inboundRetryHandler = (retrySpec, sink) -> { - if (isClosing) { - logger.debug("SSE connection closed during shutdown"); - sink.error(retrySpec.failure()); - return; - } - if (retrySpec.failure() instanceof IOException) { - logger.debug("Retrying SSE connection after IO error"); - sink.next(retrySpec); - return; - } - logger.error("Fatal SSE error, not retrying: {}", retrySpec.failure().getMessage()); - sink.error(retrySpec.failure()); - }; - - /** - * Implements graceful shutdown of the transport. Cleans up all resources including - * subscriptions and schedulers. Ensures orderly shutdown of both inbound and outbound - * message processing. - * @return a Mono that completes when shutdown is finished - */ - @Override - public Mono closeGracefully() { // @formatter:off - return Mono.fromRunnable(() -> { - isClosing = true; - - // Dispose of subscriptions - - if (inboundSubscription != null) { - inboundSubscription.dispose(); - } - - }) - .then() - .subscribeOn(Schedulers.boundedElastic()); - } // @formatter:on - - /** - * Unmarshalls data from a generic Object into the specified type using the configured - * ObjectMapper. - * - *

- * This method is particularly useful when working with JSON-RPC parameters or result - * objects that need to be converted to specific Java types. It leverages Jackson's - * type conversion capabilities to handle complex object structures. - * @param the target type to convert the data into - * @param data the source object to convert - * @param typeRef the TypeReference describing the target type - * @return the unmarshalled object of type T - * @throws IllegalArgumentException if the conversion cannot be performed - */ - @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return this.objectMapper.convertValue(data, typeRef); - } - - /** - * Creates a new builder for {@link WebFluxSseClientTransport}. - * @param webClientBuilder the WebClient.Builder to use for creating the WebClient - * instance - * @return a new builder instance - */ - public static Builder builder(WebClient.Builder webClientBuilder) { - return new Builder(webClientBuilder); - } - - /** - * Builder for {@link WebFluxSseClientTransport}. - */ - public static class Builder { - - private final WebClient.Builder webClientBuilder; - - private String sseEndpoint = DEFAULT_SSE_ENDPOINT; - - private ObjectMapper objectMapper = new ObjectMapper(); - - /** - * Creates a new builder with the specified WebClient.Builder. - * @param webClientBuilder the WebClient.Builder to use - */ - public Builder(WebClient.Builder webClientBuilder) { - Assert.notNull(webClientBuilder, "WebClient.Builder must not be null"); - this.webClientBuilder = webClientBuilder; - } - - /** - * Sets the SSE endpoint path. - * @param sseEndpoint the SSE endpoint path - * @return this builder - */ - public Builder sseEndpoint(String sseEndpoint) { - Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); - this.sseEndpoint = sseEndpoint; - return this; - } - - /** - * Sets the object mapper for JSON serialization/deserialization. - * @param objectMapper the object mapper - * @return this builder - */ - public Builder objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "objectMapper must not be null"); - this.objectMapper = objectMapper; - return this; - } - - /** - * Builds a new {@link WebFluxSseClientTransport} instance. - * @return a new transport instance - */ - public WebFluxSseClientTransport build() { - return new WebFluxSseClientTransport(webClientBuilder, objectMapper, sseEndpoint); - } - - } - -} diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportProvider.java new file mode 100644 index 00000000..2033682d --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportProvider.java @@ -0,0 +1,462 @@ +package io.modelcontextprotocol.client.transport; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpClientSession; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpClientTransportProvider; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.MediaType; +import org.springframework.http.codec.ServerSentEvent; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.core.publisher.SynchronousSink; +import reactor.core.scheduler.Schedulers; +import reactor.util.retry.Retry; + +import java.io.IOException; +import java.util.function.BiConsumer; +import java.util.function.Function; + +import static io.modelcontextprotocol.util.Utils.getSessionIdFromUrl; + +/** + * Server-Sent Events (SSE) implementation of the + * {@link io.modelcontextprotocol.spec.McpClientTransportProvider} that provides SSE + * client transport that follows the MCP HTTP with SSE transport specification. + * + *

+ * This transport establishes a bidirectional communication channel where: + *

+ * + *

+ * The message flow follows these steps: + *

    + *
  1. The client establishes an SSE connection to the server's /sse endpoint
  2. + *
  3. The server sends an 'endpoint' event containing the URI for sending messages
  4. + *
+ * + * This implementation uses {@link WebClient} for HTTP communications and supports JSON + * serialization/deserialization of messages. + * + * @author Christian Tzolov + * @author Jermaine Hua + * @see MCP + * HTTP with SSE Transport Specification + */ +public class WebFluxSseClientTransportProvider implements McpClientTransportProvider { + + private static final Logger logger = LoggerFactory.getLogger(WebFluxSseClientTransportProvider.class); + + /** + * Event type for JSON-RPC messages received through the SSE connection. The server + * sends messages with this event type to transmit JSON-RPC protocol data. + */ + private static final String MESSAGE_EVENT_TYPE = "message"; + + /** + * Event type for receiving the message endpoint URI from the server. The server MUST + * send this event when a client connects, providing the URI where the client should + * send its messages via HTTP POST. + */ + private static final String ENDPOINT_EVENT_TYPE = "endpoint"; + + /** + * Default SSE endpoint path as specified by the MCP transport specification. This + * endpoint is used to establish the SSE connection with the server. + */ + private static final String DEFAULT_SSE_ENDPOINT = "/sse"; + + /** + * Type reference for parsing SSE events containing string data. + */ + private static final ParameterizedTypeReference> SSE_TYPE = new ParameterizedTypeReference<>() { + }; + + /** + * WebClient instance for handling both SSE connections and HTTP POST requests. Used + * for establishing the SSE connection and sending outbound messages. + */ + private final WebClient webClient; + + /** + * ObjectMapper for serializing outbound messages and deserializing inbound messages. + * Handles conversion between JSON-RPC messages and their string representation. + */ + protected ObjectMapper objectMapper; + + /** + * Flag indicating if the transport is in the process of shutting down. Used to + * prevent new operations during shutdown and handle cleanup gracefully. + */ + private volatile boolean isClosing = false; + + /** + * The SSE endpoint URI provided by the server. Used for sending outbound messages via + * HTTP POST requests. + */ + private final String sseEndpoint; + + /** + * Retry handler for the inbound SSE stream. Implements the retry logic for handling + * connection failures and other errors. + */ + private final BiConsumer> inboundRetryHandler = (retrySpec, sink) -> { + if (isClosing) { + logger.debug("SSE connection closed during shutdown"); + sink.error(retrySpec.failure()); + return; + } + if (retrySpec.failure() instanceof IOException) { + logger.debug("Retrying SSE connection after IO error"); + sink.next(retrySpec); + return; + } + logger.error("Fatal SSE error, not retrying: {}", retrySpec.failure().getMessage()); + sink.error(retrySpec.failure()); + }; + + /** + * Session factory for creating new sessions + */ + private McpClientSession.Factory sessionFactory; + + /** + * Active client session + */ + private McpClientSession session; + + /** + * Constructs a new SseClientTransportProvider with the specified WebClient builder. + * Uses a default ObjectMapper instance for JSON processing. + * @param webClientBuilder the WebClient.Builder to use for creating the WebClient + * instance + * @throws IllegalArgumentException if webClientBuilder is null + */ + public WebFluxSseClientTransportProvider(WebClient.Builder webClientBuilder) { + this(webClientBuilder, new ObjectMapper()); + } + + /** + * Constructs a new SseClientTransportProvider with the specified WebClient builder + * and ObjectMapper. Initializes both inbound and outbound message processing + * pipelines. + * @param webClientBuilder the WebClient.Builder to use for creating the WebClient + * instance + * @param objectMapper the ObjectMapper to use for JSON processing + * @throws IllegalArgumentException if either parameter is null + */ + public WebFluxSseClientTransportProvider(WebClient.Builder webClientBuilder, ObjectMapper objectMapper) { + this(webClientBuilder, objectMapper, DEFAULT_SSE_ENDPOINT); + } + + /** + * Constructs a new SseClientTransportProvider with the specified WebClient builder + * and ObjectMapper. Initializes both inbound and outbound message processing + * pipelines. + * @param webClientBuilder the WebClient.Builder to use for creating the WebClient + * instance + * @param objectMapper the ObjectMapper to use for JSON processing + * @param sseEndpoint the SSE endpoint URI to use for establishing the connection + * @throws IllegalArgumentException if either parameter is null + */ + public WebFluxSseClientTransportProvider(WebClient.Builder webClientBuilder, ObjectMapper objectMapper, + String sseEndpoint) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + Assert.notNull(webClientBuilder, "WebClient.Builder must not be null"); + Assert.hasText(sseEndpoint, "SSE endpoint must not be null or empty"); + + this.objectMapper = objectMapper; + this.webClient = webClientBuilder.build(); + this.sseEndpoint = sseEndpoint; + } + + @Override + public void setSessionFactory(McpClientSession.Factory sessionFactory) { + this.sessionFactory = sessionFactory; + } + + @Override + public McpClientSession getSession() { + if (session != null) { + return session; + } + McpClientTransport mcpClientTransport = new WebFluxSseClientTransport(); + this.session = sessionFactory.create(mcpClientTransport); + return session; + } + + @Override + public Mono closeGracefully() { + return session.closeGracefully(); + } + + /** + * Initializes and starts the inbound SSE event processing. Establishes the SSE + * connection and sets up event handling for both message and endpoint events. + * Includes automatic retry logic for handling transient connection failures. + */ + // visible for tests + protected Flux> eventStream() {// @formatter:off + return webClient + .get() + .uri(sseEndpoint) + .accept(MediaType.TEXT_EVENT_STREAM) + .retrieve() + .bodyToFlux(SSE_TYPE) + .retryWhen(Retry.from(retrySignal -> retrySignal.handle(inboundRetryHandler))); + } // @formatter:on + + protected class WebFluxSseClientTransport implements McpClientTransport { + + /** + * Subscription for the SSE connection handling inbound messages. Used for cleanup + * during transport shutdown. + */ + private Disposable inboundSubscription; + + /** + * Sink for managing the message endpoint URI provided by the server. Stores the + * most recent endpoint URI and makes it available for outbound message + * processing. + */ + protected final Sinks.One messageEndpointSink = Sinks.one(); + + public WebFluxSseClientTransport() { + } + + /** + * Establishes a connection to the MCP server using Server-Sent Events (SSE). This + * method initiates the SSE connection and sets up the message processing + * pipeline. + * + *

+ * The connection process follows these steps: + *

    + *
  1. Establishes an SSE connection to the server's /sse endpoint
  2. + *
  3. Waits for the server to send an 'endpoint' event with the message posting + * URI
  4. + *
  5. Sets up message handling for incoming JSON-RPC messages
  6. + *
+ * + *

+ * The connection is considered established only after receiving the endpoint + * event from the server. + * @param handler a function that processes incoming JSON-RPC messages and returns + * responses + * @return a Mono that completes when the connection is fully established + * @throws McpError if there's an error processing SSE events or if an + * unrecognized event type is received + */ + @Override + public Mono connect(Function, Mono> handler) { + return connect(); + } + + @Override + public Mono connect() { + Flux> events = eventStream(); + inboundSubscription = events + .concatMap(event -> Mono.just(event).handle((e, s) -> { + if (ENDPOINT_EVENT_TYPE.equals(event.event())) { + String messageEndpointUri = event.data(); + if (messageEndpointSink.tryEmitValue(messageEndpointUri).isSuccess()) { + session.setId(getSessionIdFromUrl(messageEndpointUri)); + s.complete(); + } + else { + // TODO: clarify with the spec if multiple events can be + // received + s.error(new McpError("Failed to handle SSE endpoint event")); + } + } + else if (MESSAGE_EVENT_TYPE.equals(event.event())) { + try { + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, + event.data()); + session.handle(message).subscribe(); + s.next(message); + } + catch (IOException ioException) { + s.error(ioException); + } + } + else { + s.error(new McpError("Received unrecognized SSE event type: " + event.event())); + } + })) + .subscribe(); + + // The connection is established once the server sends the endpoint event + return messageEndpointSink.asMono().then(); + } + + /** + * Sends a JSON-RPC message to the server using the endpoint provided during + * connection. + * + *

+ * Messages are sent via HTTP POST requests to the server-provided endpoint URI. + * The message is serialized to JSON before transmission. If the transport is in + * the process of closing, the message send operation is skipped gracefully. + * @param message the JSON-RPC message to send + * @return a Mono that completes when the message has been sent successfully + * @throws RuntimeException if message serialization fails + */ + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + // The messageEndpoint is the endpoint URI to send the messages + // It is provided by the server as part of the endpoint event + return messageEndpointSink.asMono().flatMap(messageEndpointUri -> { + if (isClosing) { + return Mono.empty(); + } + try { + String jsonText = objectMapper.writeValueAsString(message); + return webClient.post() + .uri(messageEndpointUri) + .contentType(MediaType.APPLICATION_JSON) + .bodyValue(jsonText) + .retrieve() + .toBodilessEntity() + .doOnSuccess(response -> { + logger.debug("Message sent successfully"); + }) + .doOnError(error -> { + if (!isClosing) { + logger.error("Error sending message: {}", error.getMessage()); + } + }); + } + catch (IOException e) { + if (!isClosing) { + return Mono.error(new RuntimeException("Failed to serialize message", e)); + } + return Mono.empty(); + } + }).then(); // TODO: Consider non-200-ok response + } + + /** + * Implements graceful shutdown of the transport. Cleans up all resources + * including subscriptions and schedulers. Ensures orderly shutdown of both + * inbound and outbound message processing. + * @return a Mono that completes when shutdown is finished + */ + @Override + public Mono closeGracefully() { // @formatter:off + return Mono.fromRunnable(() -> { + isClosing = true; + + // Dispose of subscriptions + + if (inboundSubscription != null) { + inboundSubscription.dispose(); + } + + }) + .then() + .subscribeOn(Schedulers.boundedElastic()); + } // @formatter:on + + /** + * Unmarshalls data from a generic Object into the specified type using the + * configured ObjectMapper. + * + *

+ * This method is particularly useful when working with JSON-RPC parameters or + * result objects that need to be converted to specific Java types. It leverages + * Jackson's type conversion capabilities to handle complex object structures. + * @param the target type to convert the data into + * @param data the source object to convert + * @param typeRef the TypeReference describing the target type + * @return the unmarshalled object of type T + * @throws IllegalArgumentException if the conversion cannot be performed + */ + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + } + + /** + * Creates a new builder for + * {@link io.modelcontextprotocol.client.transport.WebFluxSseClientTransportProvider}. + * @param webClientBuilder the WebClient.Builder to use for creating the WebClient + * instance + * @return a new builder instance + */ + public static WebFluxSseClientTransportProvider.Builder builder(WebClient.Builder webClientBuilder) { + return new WebFluxSseClientTransportProvider.Builder(webClientBuilder); + } + + /** + * Builder for + * {@link io.modelcontextprotocol.client.transport.WebFluxSseClientTransportProvider}. + */ + public static class Builder { + + private final WebClient.Builder webClientBuilder; + + private String sseEndpoint = DEFAULT_SSE_ENDPOINT; + + private ObjectMapper objectMapper = new ObjectMapper(); + + /** + * Creates a new builder with the specified WebClient.Builder. + * @param webClientBuilder the WebClient.Builder to use + */ + public Builder(WebClient.Builder webClientBuilder) { + Assert.notNull(webClientBuilder, "WebClient.Builder must not be null"); + this.webClientBuilder = webClientBuilder; + } + + /** + * Sets the SSE endpoint path. + * @param sseEndpoint the SSE endpoint path + * @return this builder + */ + public WebFluxSseClientTransportProvider.Builder sseEndpoint(String sseEndpoint) { + Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); + this.sseEndpoint = sseEndpoint; + return this; + } + + /** + * Sets the object mapper for JSON serialization/deserialization. + * @param objectMapper the object mapper + * @return this builder + */ + public WebFluxSseClientTransportProvider.Builder objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "objectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Builds a new + * {@link io.modelcontextprotocol.client.transport.WebFluxSseClientTransportProvider} + * instance. + * @return a new transport instance + */ + public WebFluxSseClientTransportProvider build() { + return new WebFluxSseClientTransportProvider(webClientBuilder, objectMapper, sseEndpoint); + } + + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 76f908b8..02ca8223 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -14,8 +14,8 @@ import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.client.McpClient; -import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransportProvider; +import io.modelcontextprotocol.client.transport.WebFluxSseClientTransportProvider; import io.modelcontextprotocol.server.McpServer; import io.modelcontextprotocol.server.McpServerFeatures; import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; @@ -78,14 +78,14 @@ public void before() { this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); clientBuilders.put("httpclient", - McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) + McpClient.sync(HttpClientSseClientTransportProvider.builder("http://localhost:" + PORT) .sseEndpoint(CUSTOM_SSE_ENDPOINT) .build())); clientBuilders.put("webflux", - McpClient - .sync(WebFluxSseClientTransport.builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) - .sseEndpoint(CUSTOM_SSE_ENDPOINT) - .build())); + McpClient.sync(WebFluxSseClientTransportProvider + .builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) + .sseEndpoint(CUSTOM_SSE_ENDPOINT) + .build())); } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java index b43c1449..12d5dfef 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java @@ -6,8 +6,8 @@ import java.time.Duration; -import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; -import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.client.transport.WebFluxSseClientTransportProvider; +import io.modelcontextprotocol.spec.McpClientTransportProvider; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; @@ -15,8 +15,6 @@ import org.springframework.web.reactive.function.client.WebClient; /** - * Tests for the {@link McpAsyncClient} with {@link WebFluxSseClientTransport}. - * * @author Christian Tzolov */ @Timeout(15) // Giving extra time beyond the client timeout @@ -31,9 +29,8 @@ class WebFluxSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { .withExposedPorts(3001) .waitingFor(Wait.forHttp("/").forStatusCode(404)); - @Override - protected McpClientTransport createMcpTransport() { - return WebFluxSseClientTransport.builder(WebClient.builder().baseUrl(host)).build(); + protected McpClientTransportProvider createMcpClientTransportProvider() { + return new WebFluxSseClientTransportProvider(WebClient.builder().baseUrl(host)); } @Override diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java index 66ac8a6d..76c2dc40 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java @@ -6,8 +6,8 @@ import java.time.Duration; -import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; -import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.client.transport.WebFluxSseClientTransportProvider; +import io.modelcontextprotocol.spec.McpClientTransportProvider; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; @@ -15,8 +15,6 @@ import org.springframework.web.reactive.function.client.WebClient; /** - * Tests for the {@link McpSyncClient} with {@link WebFluxSseClientTransport}. - * * @author Christian Tzolov */ @Timeout(15) // Giving extra time beyond the client timeout @@ -32,8 +30,8 @@ class WebFluxSseMcpSyncClientTests extends AbstractMcpSyncClientTests { .waitingFor(Wait.forHttp("/").forStatusCode(404)); @Override - protected McpClientTransport createMcpTransport() { - return WebFluxSseClientTransport.builder(WebClient.builder().baseUrl(host)).build(); + protected McpClientTransportProvider createMcpClientTransportProvider() { + return new WebFluxSseClientTransportProvider(WebClient.builder().baseUrl(host)); } @Override diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java index c757d3da..87714baf 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java @@ -9,7 +9,10 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; +import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpClientSession; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; import org.junit.jupiter.api.AfterEach; @@ -31,8 +34,6 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; /** - * Tests for the {@link WebFluxSseClientTransport} class. - * * @author Christian Tzolov */ @Timeout(15) @@ -46,20 +47,22 @@ class WebFluxSseClientTransportTests { .withExposedPorts(3001) .waitingFor(Wait.forHttp("/").forStatusCode(404)); - private TestSseClientTransport transport; + private TestSseClientTransportProvider transportProvider; + + private McpClientTransport transport; private WebClient.Builder webClientBuilder; private ObjectMapper objectMapper; // Test class to access protected methods - static class TestSseClientTransport extends WebFluxSseClientTransport { + static class TestSseClientTransportProvider extends WebFluxSseClientTransportProvider { private final AtomicInteger inboundMessageCount = new AtomicInteger(0); private Sinks.Many> events = Sinks.many().unicast().onBackpressureBuffer(); - public TestSseClientTransport(WebClient.Builder webClientBuilder, ObjectMapper objectMapper) { + public TestSseClientTransportProvider(WebClient.Builder webClientBuilder, ObjectMapper objectMapper) { super(webClientBuilder, objectMapper); } @@ -69,7 +72,7 @@ protected Flux> eventStream() { } public String getLastEndpoint() { - return messageEndpointSink.asMono().block(); + return ((WebFluxSseClientTransport) getSession().getTransport()).messageEndpointSink.asMono().block(); } public int getInboundMessageCount() { @@ -99,7 +102,10 @@ void setUp() { startContainer(); webClientBuilder = WebClient.builder().baseUrl(host); objectMapper = new ObjectMapper(); - transport = new TestSseClientTransport(webClientBuilder, objectMapper); + transportProvider = new TestSseClientTransportProvider(webClientBuilder, objectMapper); + transportProvider.setSessionFactory( + (transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of())); + transport = transportProvider.getSession().getTransport(); transport.connect(Function.identity()).block(); } @@ -117,15 +123,16 @@ void cleanup() { @Test void testEndpointEventHandling() { - assertThat(transport.getLastEndpoint()).startsWith("/message?"); + assertThat(transportProvider.getLastEndpoint()).startsWith("/message?"); } @Test void constructorValidation() { - assertThatThrownBy(() -> new WebFluxSseClientTransport(null)).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> new WebFluxSseClientTransportProvider(null)) + .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("WebClient.Builder must not be null"); - assertThatThrownBy(() -> new WebFluxSseClientTransport(webClientBuilder, null)) + assertThatThrownBy(() -> new WebFluxSseClientTransportProvider(webClientBuilder, null)) .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("ObjectMapper must not be null"); } @@ -133,28 +140,45 @@ void constructorValidation() { @Test void testBuilderPattern() { // Test default builder - WebFluxSseClientTransport transport1 = WebFluxSseClientTransport.builder(webClientBuilder).build(); - assertThatCode(() -> transport1.closeGracefully().block()).doesNotThrowAnyException(); + WebFluxSseClientTransportProvider transportProvider1 = WebFluxSseClientTransportProvider + .builder(webClientBuilder) + .build(); + transportProvider1.setSessionFactory( + (transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of())); + transportProvider1.getSession(); + assertThatCode(() -> transportProvider1.closeGracefully().block()).doesNotThrowAnyException(); // Test builder with custom ObjectMapper ObjectMapper customMapper = new ObjectMapper(); - WebFluxSseClientTransport transport2 = WebFluxSseClientTransport.builder(webClientBuilder) + WebFluxSseClientTransportProvider transportProvider2 = WebFluxSseClientTransportProvider + .builder(webClientBuilder) .objectMapper(customMapper) .build(); - assertThatCode(() -> transport2.closeGracefully().block()).doesNotThrowAnyException(); + transportProvider2.setSessionFactory( + (transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of())); + transportProvider2.getSession(); + assertThatCode(() -> transportProvider2.closeGracefully().block()).doesNotThrowAnyException(); // Test builder with custom SSE endpoint - WebFluxSseClientTransport transport3 = WebFluxSseClientTransport.builder(webClientBuilder) + WebFluxSseClientTransportProvider transportProvider3 = WebFluxSseClientTransportProvider + .builder(webClientBuilder) .sseEndpoint("/custom-sse") .build(); - assertThatCode(() -> transport3.closeGracefully().block()).doesNotThrowAnyException(); + transportProvider3.setSessionFactory( + (transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of())); + transportProvider3.getSession(); + assertThatCode(() -> transportProvider3.closeGracefully().block()).doesNotThrowAnyException(); // Test builder with all custom parameters - WebFluxSseClientTransport transport4 = WebFluxSseClientTransport.builder(webClientBuilder) + WebFluxSseClientTransportProvider transportProvider4 = WebFluxSseClientTransportProvider + .builder(webClientBuilder) .objectMapper(customMapper) .sseEndpoint("/custom-sse") .build(); - assertThatCode(() -> transport4.closeGracefully().block()).doesNotThrowAnyException(); + transportProvider4.setSessionFactory( + (transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of())); + transportProvider4.getSession(); + assertThatCode(() -> transportProvider4.closeGracefully().block()).doesNotThrowAnyException(); } @Test @@ -164,7 +188,7 @@ void testMessageProcessing() { Map.of("key", "value")); // Simulate receiving the message - transport.simulateMessageEvent(""" + transportProvider.simulateMessageEvent(""" { "jsonrpc": "2.0", "method": "test-method", @@ -176,13 +200,13 @@ void testMessageProcessing() { // Subscribe to messages and verify StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); - assertThat(transport.getInboundMessageCount()).isEqualTo(1); + assertThat(transportProvider.getInboundMessageCount()).isEqualTo(1); } @Test void testResponseMessageProcessing() { // Simulate receiving a response message - transport.simulateMessageEvent(""" + transportProvider.simulateMessageEvent(""" { "jsonrpc": "2.0", "id": "test-id", @@ -197,13 +221,13 @@ void testResponseMessageProcessing() { // Verify message handling StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); - assertThat(transport.getInboundMessageCount()).isEqualTo(1); + assertThat(transportProvider.getInboundMessageCount()).isEqualTo(1); } @Test void testErrorMessageProcessing() { // Simulate receiving an error message - transport.simulateMessageEvent(""" + transportProvider.simulateMessageEvent(""" { "jsonrpc": "2.0", "id": "test-id", @@ -221,13 +245,13 @@ void testErrorMessageProcessing() { // Verify message handling StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); - assertThat(transport.getInboundMessageCount()).isEqualTo(1); + assertThat(transportProvider.getInboundMessageCount()).isEqualTo(1); } @Test void testNotificationMessageProcessing() { // Simulate receiving a notification message (no id) - transport.simulateMessageEvent(""" + transportProvider.simulateMessageEvent(""" { "jsonrpc": "2.0", "method": "update", @@ -236,7 +260,7 @@ void testNotificationMessageProcessing() { """); // Verify the notification was processed - assertThat(transport.getInboundMessageCount()).isEqualTo(1); + assertThat(transportProvider.getInboundMessageCount()).isEqualTo(1); } @Test @@ -252,7 +276,7 @@ void testGracefulShutdown() { StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); // Message count should remain 0 after shutdown - assertThat(transport.getInboundMessageCount()).isEqualTo(0); + assertThat(transportProvider.getInboundMessageCount()).isEqualTo(0); } @Test @@ -260,19 +284,23 @@ void testRetryBehavior() { // Create a WebClient that simulates connection failures WebClient.Builder failingWebClientBuilder = WebClient.builder().baseUrl("http://non-existent-host"); - WebFluxSseClientTransport failingTransport = WebFluxSseClientTransport.builder(failingWebClientBuilder).build(); + WebFluxSseClientTransportProvider failingTransportProvider = WebFluxSseClientTransportProvider + .builder(failingWebClientBuilder) + .build(); + failingTransportProvider.setSessionFactory( + (transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of())); // Verify that the transport attempts to reconnect StepVerifier.create(Mono.delay(Duration.ofSeconds(2))).expectNextCount(1).verifyComplete(); // Clean up - failingTransport.closeGracefully().block(); + failingTransportProvider.getSession().getTransport().closeGracefully().block(); } @Test void testMultipleMessageProcessing() { // Simulate receiving multiple messages in sequence - transport.simulateMessageEvent(""" + transportProvider.simulateMessageEvent(""" { "jsonrpc": "2.0", "method": "method1", @@ -281,7 +309,7 @@ void testMultipleMessageProcessing() { } """); - transport.simulateMessageEvent(""" + transportProvider.simulateMessageEvent(""" { "jsonrpc": "2.0", "method": "method2", @@ -301,13 +329,13 @@ void testMultipleMessageProcessing() { StepVerifier.create(transport.sendMessage(message1).then(transport.sendMessage(message2))).verifyComplete(); // Verify message count - assertThat(transport.getInboundMessageCount()).isEqualTo(2); + assertThat(transportProvider.getInboundMessageCount()).isEqualTo(2); } @Test void testMessageOrderPreservation() { // Simulate receiving messages in a specific order - transport.simulateMessageEvent(""" + transportProvider.simulateMessageEvent(""" { "jsonrpc": "2.0", "method": "first", @@ -316,7 +344,7 @@ void testMessageOrderPreservation() { } """); - transport.simulateMessageEvent(""" + transportProvider.simulateMessageEvent(""" { "jsonrpc": "2.0", "method": "second", @@ -325,7 +353,7 @@ void testMessageOrderPreservation() { } """); - transport.simulateMessageEvent(""" + transportProvider.simulateMessageEvent(""" { "jsonrpc": "2.0", "method": "third", @@ -335,7 +363,7 @@ void testMessageOrderPreservation() { """); // Verify message count and order - assertThat(transport.getInboundMessageCount()).isEqualTo(3); + assertThat(transportProvider.getInboundMessageCount()).isEqualTo(3); } } diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java index 0e81104b..96664b72 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java @@ -5,7 +5,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.client.McpClient; -import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransportProvider; import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; import io.modelcontextprotocol.spec.McpSchema; import org.apache.catalina.LifecycleException; @@ -49,7 +49,7 @@ public void before() { throw new RuntimeException("Failed to start Tomcat", e); } - var clientTransport = HttpClientSseClientTransport.builder("http://localhost:" + PORT) + var clientTransport = HttpClientSseClientTransportProvider.builder("http://localhost:" + PORT) .sseEndpoint(CUSTOM_CONTEXT_PATH + WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT) .build(); diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index be01365a..bfc69150 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -11,7 +11,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.client.McpClient; -import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransportProvider; import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; @@ -85,7 +85,8 @@ public void before() { throw new RuntimeException("Failed to start Tomcat", e); } - clientBuilder = McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT).build()); + clientBuilder = McpClient + .sync(HttpClientSseClientTransportProvider.builder("http://localhost:" + PORT).build()); // Get the transport from Spring context mcpServerTransportProvider = tomcatServer.appContext().getBean(WebMvcSseServerTransportProvider.class); diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 5452c8ea..3e8cfbaa 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -12,7 +12,7 @@ import java.util.function.Consumer; import java.util.function.Function; -import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpClientTransportProvider; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; @@ -50,7 +50,7 @@ public abstract class AbstractMcpAsyncClientTests { private static final String ECHO_TEST_MESSAGE = "Hello MCP Spring AI!"; - abstract protected McpClientTransport createMcpTransport(); + abstract protected McpClientTransportProvider createMcpClientTransportProvider(); protected void onStart() { } @@ -66,11 +66,12 @@ protected Duration getInitializationTimeout() { return Duration.ofSeconds(2); } - McpAsyncClient client(McpClientTransport transport) { + McpAsyncClient client(McpClientTransportProvider transport) { return client(transport, Function.identity()); } - McpAsyncClient client(McpClientTransport transport, Function customizer) { + McpAsyncClient client(McpClientTransportProvider transport, + Function customizer) { AtomicReference client = new AtomicReference<>(); assertThatCode(() -> { @@ -85,11 +86,11 @@ McpAsyncClient client(McpClientTransport transport, Function c) { + void withClient(McpClientTransportProvider transport, Consumer c) { withClient(transport, Function.identity(), c); } - void withClient(McpClientTransport transport, Function customizer, + void withClient(McpClientTransportProvider transport, Function customizer, Consumer c) { var client = client(transport, customizer); try { @@ -111,7 +112,7 @@ void tearDown() { } void verifyInitializationTimeout(Function> operation, String action) { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { StepVerifier.withVirtualTime(() -> operation.apply(mcpAsyncClient)) .expectSubscription() .thenAwait(getInitializationTimeout()) @@ -126,7 +127,7 @@ void testConstructorWithInvalidArguments() { assertThatThrownBy(() -> McpClient.async(null).build()).isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport must not be null"); - assertThatThrownBy(() -> McpClient.async(createMcpTransport()).requestTimeout(null).build()) + assertThatThrownBy(() -> McpClient.async(createMcpClientTransportProvider()).requestTimeout(null).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Request timeout must not be null"); } @@ -138,7 +139,7 @@ void testListToolsWithoutInitialization() { @Test void testListTools() { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listTools(null))) .consumeNextWith(result -> { assertThat(result.tools()).isNotNull().isNotEmpty(); @@ -158,7 +159,7 @@ void testPingWithoutInitialization() { @Test void testPing() { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.ping())) .expectNextCount(1) .verifyComplete(); @@ -173,7 +174,7 @@ void testCallToolWithoutInitialization() { @Test void testCallTool() { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(callToolRequest))) @@ -189,7 +190,7 @@ void testCallTool() { @Test void testCallToolWithInvalidTool() { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", ECHO_TEST_MESSAGE)); @@ -207,7 +208,7 @@ void testListResourcesWithoutInitialization() { @Test void testListResources() { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResources(null))) .consumeNextWith(resources -> { assertThat(resources).isNotNull().satisfies(result -> { @@ -226,7 +227,7 @@ void testListResources() { @Test void testMcpAsyncClientState() { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { assertThat(mcpAsyncClient).isNotNull(); }); } @@ -238,7 +239,7 @@ void testListPromptsWithoutInitialization() { @Test void testListPrompts() { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listPrompts(null))) .consumeNextWith(prompts -> { assertThat(prompts).isNotNull().satisfies(result -> { @@ -263,7 +264,7 @@ void testGetPromptWithoutInitialization() { @Test void testGetPrompt() { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { StepVerifier .create(mcpAsyncClient.initialize() .then(mcpAsyncClient.getPrompt(new GetPromptRequest("simple_prompt", Map.of())))) @@ -285,7 +286,7 @@ void testRootsListChangedWithoutInitialization() { @Test void testRootsListChanged() { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.rootsListChangedNotification())) .verifyComplete(); }); @@ -293,15 +294,15 @@ void testRootsListChanged() { @Test void testInitializeWithRootsListProviders() { - withClient(createMcpTransport(), builder -> builder.roots(new Root("file:///test/path", "test-root")), - client -> { + withClient(createMcpClientTransportProvider(), + builder -> builder.roots(new Root("file:///test/path", "test-root")), client -> { StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); }); } @Test void testAddRoot() { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { Root newRoot = new Root("file:///new/test/path", "new-test-root"); StepVerifier.create(mcpAsyncClient.addRoot(newRoot)).verifyComplete(); }); @@ -309,7 +310,7 @@ void testAddRoot() { @Test void testAddRootWithNullValue() { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { StepVerifier.create(mcpAsyncClient.addRoot(null)) .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class).hasMessage("Root must not be null")) .verify(); @@ -318,7 +319,7 @@ void testAddRootWithNullValue() { @Test void testRemoveRoot() { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { Root root = new Root("file:///test/path/to/remove", "root-to-remove"); StepVerifier.create(mcpAsyncClient.addRoot(root)).verifyComplete(); @@ -328,7 +329,7 @@ void testRemoveRoot() { @Test void testRemoveNonExistentRoot() { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { StepVerifier.create(mcpAsyncClient.removeRoot("nonexistent-uri")) .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) .hasMessage("Root with uri 'nonexistent-uri' not found")) @@ -339,7 +340,7 @@ void testRemoveNonExistentRoot() { @Test @Disabled void testReadResource() { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { if (!resources.resources().isEmpty()) { Resource firstResource = resources.resources().get(0); @@ -359,7 +360,7 @@ void testListResourceTemplatesWithoutInitialization() { @Test void testListResourceTemplates() { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResourceTemplates())) .consumeNextWith(result -> { assertThat(result).isNotNull(); @@ -371,7 +372,7 @@ void testListResourceTemplates() { // @Test void testResourceSubscription() { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { if (!resources.resources().isEmpty()) { Resource firstResource = resources.resources().get(0); @@ -394,7 +395,7 @@ void testNotificationHandlers() { AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false); AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false); - withClient(createMcpTransport(), + withClient(createMcpClientTransportProvider(), builder -> builder .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true))) .resourcesChangeConsumer( @@ -414,7 +415,7 @@ void testInitializeWithSamplingCapability() { .message("test") .model("test-model") .build(); - withClient(createMcpTransport(), + withClient(createMcpClientTransportProvider(), builder -> builder.capabilities(capabilities).sampling(request -> Mono.just(createMessageResult)), client -> { StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); @@ -432,8 +433,8 @@ void testInitializeWithAllCapabilities() { Function> samplingHandler = request -> Mono .just(CreateMessageResult.builder().message("test").model("test-model").build()); - withClient(createMcpTransport(), builder -> builder.capabilities(capabilities).sampling(samplingHandler), - client -> + withClient(createMcpClientTransportProvider(), + builder -> builder.capabilities(capabilities).sampling(samplingHandler), client -> StepVerifier.create(client.initialize()).assertNext(result -> { assertThat(result).isNotNull(); @@ -453,7 +454,7 @@ void testLoggingLevelsWithoutInitialization() { @Test void testLoggingLevels() { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { StepVerifier .create(mcpAsyncClient.initialize() .thenMany(Flux.fromArray(McpSchema.LoggingLevel.values()).flatMap(mcpAsyncClient::setLoggingLevel))) @@ -465,7 +466,7 @@ void testLoggingLevels() { void testLoggingConsumer() { AtomicBoolean logReceived = new AtomicBoolean(false); - withClient(createMcpTransport(), + withClient(createMcpClientTransportProvider(), builder -> builder.loggingConsumer(notification -> Mono.fromRunnable(() -> logReceived.set(true))), client -> { StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); @@ -477,7 +478,7 @@ void testLoggingConsumer() { @Test void testLoggingWithNullNotification() { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { StepVerifier.create(mcpAsyncClient.setLoggingLevel(null)) .expectErrorMatches(error -> error.getMessage().contains("Logging level must not be null")) .verify(); diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index 128441f8..0fe30b9d 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -11,7 +11,7 @@ import java.util.function.Consumer; import java.util.function.Function; -import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpClientTransportProvider; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; @@ -49,7 +49,7 @@ public abstract class AbstractMcpSyncClientTests { private static final String TEST_MESSAGE = "Hello MCP Spring AI!"; - abstract protected McpClientTransport createMcpTransport(); + abstract protected McpClientTransportProvider createMcpClientTransportProvider(); protected void onStart() { } @@ -65,11 +65,12 @@ protected Duration getInitializationTimeout() { return Duration.ofSeconds(2); } - McpSyncClient client(McpClientTransport transport) { + McpSyncClient client(McpClientTransportProvider transport) { return client(transport, Function.identity()); } - McpSyncClient client(McpClientTransport transport, Function customizer) { + McpSyncClient client(McpClientTransportProvider transport, + Function customizer) { AtomicReference client = new AtomicReference<>(); assertThatCode(() -> { @@ -84,11 +85,11 @@ McpSyncClient client(McpClientTransport transport, Function c) { + void withClient(McpClientTransportProvider transport, Consumer c) { withClient(transport, Function.identity(), c); } - void withClient(McpClientTransport transport, Function customizer, + void withClient(McpClientTransportProvider transport, Function customizer, Consumer c) { var client = client(transport, customizer); try { @@ -120,7 +121,7 @@ void verifyNotificationTimesOut(Consumer operation, String ac } void verifyCallTimesOut(Function blockingOperation, String action) { - withClient(createMcpTransport(), mcpSyncClient -> { + withClient(createMcpClientTransportProvider(), mcpSyncClient -> { // This scheduler is not replaced by virtual time scheduler Scheduler customScheduler = Schedulers.newBoundedElastic(1, 1, "actualBoundedElastic"); @@ -147,7 +148,7 @@ void testConstructorWithInvalidArguments() { assertThatThrownBy(() -> McpClient.sync(null).build()).isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport must not be null"); - assertThatThrownBy(() -> McpClient.sync(createMcpTransport()).requestTimeout(null).build()) + assertThatThrownBy(() -> McpClient.sync(createMcpClientTransportProvider()).requestTimeout(null).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Request timeout must not be null"); } @@ -159,7 +160,7 @@ void testListToolsWithoutInitialization() { @Test void testListTools() { - withClient(createMcpTransport(), mcpSyncClient -> { + withClient(createMcpClientTransportProvider(), mcpSyncClient -> { mcpSyncClient.initialize(); ListToolsResult tools = mcpSyncClient.listTools(null); @@ -181,7 +182,7 @@ void testCallToolsWithoutInitialization() { @Test void testCallTools() { - withClient(createMcpTransport(), mcpSyncClient -> { + withClient(createMcpClientTransportProvider(), mcpSyncClient -> { mcpSyncClient.initialize(); CallToolResult toolResult = mcpSyncClient.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))); @@ -205,7 +206,7 @@ void testPingWithoutInitialization() { @Test void testPing() { - withClient(createMcpTransport(), mcpSyncClient -> { + withClient(createMcpClientTransportProvider(), mcpSyncClient -> { mcpSyncClient.initialize(); assertThatCode(() -> mcpSyncClient.ping()).doesNotThrowAnyException(); }); @@ -219,7 +220,7 @@ void testCallToolWithoutInitialization() { @Test void testCallTool() { - withClient(createMcpTransport(), mcpSyncClient -> { + withClient(createMcpClientTransportProvider(), mcpSyncClient -> { mcpSyncClient.initialize(); CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); @@ -234,7 +235,7 @@ void testCallTool() { @Test void testCallToolWithInvalidTool() { - withClient(createMcpTransport(), mcpSyncClient -> { + withClient(createMcpClientTransportProvider(), mcpSyncClient -> { CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", TEST_MESSAGE)); assertThatThrownBy(() -> mcpSyncClient.callTool(invalidRequest)).isInstanceOf(Exception.class); @@ -249,7 +250,7 @@ void testRootsListChangedWithoutInitialization() { @Test void testRootsListChanged() { - withClient(createMcpTransport(), mcpSyncClient -> { + withClient(createMcpClientTransportProvider(), mcpSyncClient -> { mcpSyncClient.initialize(); assertThatCode(() -> mcpSyncClient.rootsListChangedNotification()).doesNotThrowAnyException(); }); @@ -262,7 +263,7 @@ void testListResourcesWithoutInitialization() { @Test void testListResources() { - withClient(createMcpTransport(), mcpSyncClient -> { + withClient(createMcpClientTransportProvider(), mcpSyncClient -> { mcpSyncClient.initialize(); ListResourcesResult resources = mcpSyncClient.listResources(null); @@ -280,15 +281,15 @@ void testListResources() { @Test void testClientSessionState() { - withClient(createMcpTransport(), mcpSyncClient -> { + withClient(createMcpClientTransportProvider(), mcpSyncClient -> { assertThat(mcpSyncClient).isNotNull(); }); } @Test void testInitializeWithRootsListProviders() { - withClient(createMcpTransport(), builder -> builder.roots(new Root("file:///test/path", "test-root")), - mcpSyncClient -> { + withClient(createMcpClientTransportProvider(), + builder -> builder.roots(new Root("file:///test/path", "test-root")), mcpSyncClient -> { assertThatCode(() -> { mcpSyncClient.initialize(); @@ -299,7 +300,7 @@ void testInitializeWithRootsListProviders() { @Test void testAddRoot() { - withClient(createMcpTransport(), mcpSyncClient -> { + withClient(createMcpClientTransportProvider(), mcpSyncClient -> { Root newRoot = new Root("file:///new/test/path", "new-test-root"); assertThatCode(() -> mcpSyncClient.addRoot(newRoot)).doesNotThrowAnyException(); }); @@ -307,14 +308,14 @@ void testAddRoot() { @Test void testAddRootWithNullValue() { - withClient(createMcpTransport(), mcpSyncClient -> { + withClient(createMcpClientTransportProvider(), mcpSyncClient -> { assertThatThrownBy(() -> mcpSyncClient.addRoot(null)).hasMessageContaining("Root must not be null"); }); } @Test void testRemoveRoot() { - withClient(createMcpTransport(), mcpSyncClient -> { + withClient(createMcpClientTransportProvider(), mcpSyncClient -> { Root root = new Root("file:///test/path/to/remove", "root-to-remove"); assertThatCode(() -> { mcpSyncClient.addRoot(root); @@ -325,7 +326,7 @@ void testRemoveRoot() { @Test void testRemoveNonExistentRoot() { - withClient(createMcpTransport(), mcpSyncClient -> { + withClient(createMcpClientTransportProvider(), mcpSyncClient -> { assertThatThrownBy(() -> mcpSyncClient.removeRoot("nonexistent-uri")) .hasMessageContaining("Root with uri 'nonexistent-uri' not found"); }); @@ -339,7 +340,7 @@ void testReadResourceWithoutInitialization() { @Test void testReadResource() { - withClient(createMcpTransport(), mcpSyncClient -> { + withClient(createMcpClientTransportProvider(), mcpSyncClient -> { mcpSyncClient.initialize(); ListResourcesResult resources = mcpSyncClient.listResources(null); @@ -360,7 +361,7 @@ void testListResourceTemplatesWithoutInitialization() { @Test void testListResourceTemplates() { - withClient(createMcpTransport(), mcpSyncClient -> { + withClient(createMcpClientTransportProvider(), mcpSyncClient -> { mcpSyncClient.initialize(); ListResourceTemplatesResult result = mcpSyncClient.listResourceTemplates(null); @@ -371,7 +372,7 @@ void testListResourceTemplates() { // @Test void testResourceSubscription() { - withClient(createMcpTransport(), mcpSyncClient -> { + withClient(createMcpClientTransportProvider(), mcpSyncClient -> { ListResourcesResult resources = mcpSyncClient.listResources(null); if (!resources.resources().isEmpty()) { @@ -394,7 +395,7 @@ void testNotificationHandlers() { AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false); AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false); - withClient(createMcpTransport(), + withClient(createMcpClientTransportProvider(), builder -> builder.toolsChangeConsumer(tools -> toolsNotificationReceived.set(true)) .resourcesChangeConsumer(resources -> resourcesNotificationReceived.set(true)) .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)), @@ -419,7 +420,7 @@ void testLoggingLevelsWithoutInitialization() { @Test void testLoggingLevels() { - withClient(createMcpTransport(), mcpSyncClient -> { + withClient(createMcpClientTransportProvider(), mcpSyncClient -> { mcpSyncClient.initialize(); // Test all logging levels for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { @@ -431,7 +432,7 @@ void testLoggingLevels() { @Test void testLoggingConsumer() { AtomicBoolean logReceived = new AtomicBoolean(false); - withClient(createMcpTransport(), builder -> builder.requestTimeout(getRequestTimeout()) + withClient(createMcpClientTransportProvider(), builder -> builder.requestTimeout(getRequestTimeout()) .loggingConsumer(notification -> logReceived.set(true)), client -> { assertThatCode(() -> { client.initialize(); @@ -442,8 +443,9 @@ void testLoggingConsumer() { @Test void testLoggingWithNullNotification() { - withClient(createMcpTransport(), mcpSyncClient -> assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(null)) - .hasMessageContaining("Logging level must not be null")); + withClient(createMcpClientTransportProvider(), + mcpSyncClient -> assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(null)) + .hasMessageContaining("Logging level must not be null")); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index df099836..fd4a51cf 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -8,16 +8,19 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.McpClientSession; import io.modelcontextprotocol.spec.McpClientSession.NotificationHandler; import io.modelcontextprotocol.spec.McpClientSession.RequestHandler; import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpClientTransportProvider; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; @@ -30,6 +33,8 @@ import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; import io.modelcontextprotocol.spec.McpSchema.PaginatedRequest; import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import io.modelcontextprotocol.spec.McpTransport; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.Utils; @@ -100,7 +105,7 @@ public class McpAsyncClient { /** * Client capabilities. */ - private final McpSchema.ClientCapabilities clientCapabilities; + private final ClientCapabilities clientCapabilities; /** * Client implementation information. @@ -135,10 +140,7 @@ public class McpAsyncClient { */ private Function> samplingHandler; - /** - * Client transport implementation. - */ - private final McpTransport transport; + private final ObjectMapper objectMapper; /** * Supported protocol versions. @@ -148,21 +150,21 @@ public class McpAsyncClient { /** * Create a new McpAsyncClient with the given transport and session request-response * timeout. - * @param transport the transport to use. + * @param mcpClientTransportProvider the transport to use. * @param requestTimeout the session request-response timeout. * @param initializationTimeout the max timeout to await for the client-server * @param features the MCP Client supported features. */ - McpAsyncClient(McpClientTransport transport, Duration requestTimeout, Duration initializationTimeout, - McpClientFeatures.Async features) { + McpAsyncClient(McpClientTransportProvider mcpClientTransportProvider, Duration requestTimeout, + Duration initializationTimeout, McpClientFeatures.Async features, ObjectMapper objectMapper) { - Assert.notNull(transport, "Transport must not be null"); + Assert.notNull(mcpClientTransportProvider, "Transport provider must not be null"); Assert.notNull(requestTimeout, "Request timeout must not be null"); Assert.notNull(initializationTimeout, "Initialization timeout must not be null"); this.clientInfo = features.clientInfo(); this.clientCapabilities = features.clientCapabilities(); - this.transport = transport; + this.objectMapper = objectMapper; this.roots = new ConcurrentHashMap<>(features.roots()); this.initializationTimeout = initializationTimeout; @@ -228,8 +230,9 @@ public class McpAsyncClient { notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_MESSAGE, asyncLoggingNotificationHandler(loggingConsumersFinal)); - this.mcpSession = new McpClientSession(requestTimeout, transport, requestHandlers, notificationHandlers); - + mcpClientTransportProvider.setSessionFactory( + (trans) -> new McpClientSession(requestTimeout, trans, requestHandlers, notificationHandlers)); + this.mcpSession = mcpClientTransportProvider.getSession(); } /** @@ -290,6 +293,7 @@ public Mono closeGracefully() { // -------------------------- // Initialization // -------------------------- + /** * The initialization phase MUST be the first interaction between client and server. * During this phase, the client and server: @@ -380,6 +384,7 @@ public Mono ping() { // -------------------------- // Roots // -------------------------- + /** * Adds a new root to the client's root list. * @param root The root to add. @@ -461,9 +466,8 @@ public Mono rootsListChangedNotification() { private RequestHandler rootsListRequestHandler() { return params -> { @SuppressWarnings("unused") - McpSchema.PaginatedRequest request = transport.unmarshalFrom(params, - new TypeReference() { - }); + PaginatedRequest request = objectMapper.convertValue(params, new TypeReference<>() { + }); List roots = this.roots.values().stream().toList(); @@ -476,9 +480,8 @@ private RequestHandler rootsListRequestHandler() { // -------------------------- private RequestHandler samplingCreateMessageHandler() { return params -> { - McpSchema.CreateMessageRequest request = transport.unmarshalFrom(params, - new TypeReference() { - }); + CreateMessageRequest request = objectMapper.convertValue(params, new TypeReference<>() { + }); return this.samplingHandler.apply(request); }; @@ -531,7 +534,7 @@ public Mono listTools(String cursor) { if (this.serverCapabilities.tools() == null) { return Mono.error(new McpError("Server does not provide tools capability")); } - return this.mcpSession.sendRequest(McpSchema.METHOD_TOOLS_LIST, new McpSchema.PaginatedRequest(cursor), + return this.mcpSession.sendRequest(McpSchema.METHOD_TOOLS_LIST, new PaginatedRequest(cursor), LIST_TOOLS_RESULT_TYPE_REF); }); } @@ -588,7 +591,7 @@ public Mono listResources(String cursor) { if (this.serverCapabilities.resources() == null) { return Mono.error(new McpError("Server does not provide the resources capability")); } - return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_LIST, new McpSchema.PaginatedRequest(cursor), + return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_LIST, new PaginatedRequest(cursor), LIST_RESOURCES_RESULT_TYPE_REF); }); } @@ -648,8 +651,8 @@ public Mono listResourceTemplates(String if (this.serverCapabilities.resources() == null) { return Mono.error(new McpError("Server does not provide the resources capability")); } - return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, - new McpSchema.PaginatedRequest(cursor), LIST_RESOURCE_TEMPLATES_RESULT_TYPE_REF); + return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, new PaginatedRequest(cursor), + LIST_RESOURCE_TEMPLATES_RESULT_TYPE_REF); }); } @@ -695,16 +698,16 @@ private NotificationHandler asyncResourcesChangeNotificationHandler( // -------------------------- // Prompts // -------------------------- - private static final TypeReference LIST_PROMPTS_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeReference LIST_PROMPTS_RESULT_TYPE_REF = new TypeReference<>() { }; - private static final TypeReference GET_PROMPT_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeReference GET_PROMPT_RESULT_TYPE_REF = new TypeReference<>() { }; /** * Retrieves the list of all prompts provided by the server. * @return A Mono that completes with the list of prompts result. - * @see McpSchema.ListPromptsResult + * @see ListPromptsResult * @see #getPrompt(GetPromptRequest) */ public Mono listPrompts() { @@ -715,7 +718,7 @@ public Mono listPrompts() { * Retrieves a paginated list of prompts provided by the server. * @param cursor Optional pagination cursor from a previous list request * @return A Mono that completes with the list of prompts result. - * @see McpSchema.ListPromptsResult + * @see ListPromptsResult * @see #getPrompt(GetPromptRequest) */ public Mono listPrompts(String cursor) { @@ -728,8 +731,8 @@ public Mono listPrompts(String cursor) { * including all parameters and instructions for generating AI content. * @param getPromptRequest The request containing the ID of the prompt to retrieve. * @return A Mono that completes with the prompt result. - * @see McpSchema.GetPromptRequest - * @see McpSchema.GetPromptResult + * @see GetPromptRequest + * @see GetPromptResult * @see #listPrompts() */ public Mono getPrompt(GetPromptRequest getPromptRequest) { @@ -763,8 +766,8 @@ private NotificationHandler asyncLoggingNotificationHandler( List>> loggingConsumers) { return params -> { - McpSchema.LoggingMessageNotification loggingMessageNotification = transport.unmarshalFrom(params, - new TypeReference() { + LoggingMessageNotification loggingMessageNotification = objectMapper.convertValue(params, + new TypeReference<>() { }); return Flux.fromIterable(loggingConsumers) @@ -778,7 +781,7 @@ private NotificationHandler asyncLoggingNotificationHandler( * will only receive log messages at or above the specified severity level. * @param loggingLevel The minimum logging level to receive. * @return A Mono that completes when the logging level is set. - * @see McpSchema.LoggingLevel + * @see LoggingLevel */ public Mono setLoggingLevel(LoggingLevel loggingLevel) { if (loggingLevel == null) { diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java index f7b17961..6237cb8b 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java @@ -12,7 +12,9 @@ import java.util.function.Consumer; import java.util.function.Function; +import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpClientTransportProvider; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpTransport; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; @@ -52,7 +54,7 @@ * .requestTimeout(Duration.ofSeconds(5)) * .build(); * } - * + *

* Example of creating a basic asynchronous client:

{@code
  * McpClient.async(transport)
  *     .requestTimeout(Duration.ofSeconds(5))
@@ -114,7 +116,7 @@ public interface McpClient {
 	 * @return A new builder instance for configuring the client
 	 * @throws IllegalArgumentException if transport is null
 	 */
-	static SyncSpec sync(McpClientTransport transport) {
+	static SyncSpec sync(McpClientTransportProvider transport) {
 		return new SyncSpec(transport);
 	}
 
@@ -131,7 +133,7 @@ static SyncSpec sync(McpClientTransport transport) {
 	 * @return A new builder instance for configuring the client
 	 * @throws IllegalArgumentException if transport is null
 	 */
-	static AsyncSpec async(McpClientTransport transport) {
+	static AsyncSpec async(McpClientTransportProvider transport) {
 		return new AsyncSpec(transport);
 	}
 
@@ -153,7 +155,9 @@ static AsyncSpec async(McpClientTransport transport) {
 	 */
 	class SyncSpec {
 
-		private final McpClientTransport transport;
+		private final McpClientTransportProvider transportProvider;
+
+		private ObjectMapper objectMapper;
 
 		private Duration requestTimeout = Duration.ofSeconds(20); // Default timeout
 
@@ -175,9 +179,9 @@ class SyncSpec {
 
 		private Function samplingHandler;
 
-		private SyncSpec(McpClientTransport transport) {
-			Assert.notNull(transport, "Transport must not be null");
-			this.transport = transport;
+		private SyncSpec(McpClientTransportProvider transportProvider) {
+			Assert.notNull(transportProvider, "Transport must not be null");
+			this.transportProvider = transportProvider;
 		}
 
 		/**
@@ -367,9 +371,10 @@ public McpSyncClient build() {
 					this.loggingConsumers, this.samplingHandler);
 
 			McpClientFeatures.Async asyncFeatures = McpClientFeatures.Async.fromSync(syncFeatures);
+			var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper();
 
-			return new McpSyncClient(
-					new McpAsyncClient(transport, this.requestTimeout, this.initializationTimeout, asyncFeatures));
+			return new McpSyncClient(new McpAsyncClient(this.transportProvider, this.requestTimeout,
+					this.initializationTimeout, asyncFeatures, mapper));
 		}
 
 	}
@@ -392,7 +397,9 @@ public McpSyncClient build() {
 	 */
 	class AsyncSpec {
 
-		private final McpClientTransport transport;
+		private final McpClientTransportProvider transportProvider;
+
+		private ObjectMapper objectMapper;
 
 		private Duration requestTimeout = Duration.ofSeconds(20); // Default timeout
 
@@ -414,9 +421,9 @@ class AsyncSpec {
 
 		private Function> samplingHandler;
 
-		private AsyncSpec(McpClientTransport transport) {
-			Assert.notNull(transport, "Transport must not be null");
-			this.transport = transport;
+		private AsyncSpec(McpClientTransportProvider transportProvider) {
+			Assert.notNull(transportProvider, "Transport must not be null");
+			this.transportProvider = transportProvider;
 		}
 
 		/**
@@ -603,10 +610,12 @@ public AsyncSpec loggingConsumers(
 		 * @return a new instance of {@link McpAsyncClient}.
 		 */
 		public McpAsyncClient build() {
-			return new McpAsyncClient(this.transport, this.requestTimeout, this.initializationTimeout,
+			var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper();
+			return new McpAsyncClient(this.transportProvider, this.requestTimeout, this.initializationTimeout,
 					new McpClientFeatures.Async(this.clientInfo, this.capabilities, this.roots,
 							this.toolsChangeConsumers, this.resourcesChangeConsumers, this.promptsChangeConsumers,
-							this.loggingConsumers, this.samplingHandler));
+							this.loggingConsumers, this.samplingHandler),
+					mapper);
 		}
 
 	}
diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java
deleted file mode 100644
index 632d3844..00000000
--- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java
+++ /dev/null
@@ -1,466 +0,0 @@
-/*
- * Copyright 2024 - 2024 the original author or authors.
- */
-package io.modelcontextprotocol.client.transport;
-
-import java.io.IOException;
-import java.net.URI;
-import java.net.http.HttpClient;
-import java.net.http.HttpRequest;
-import java.net.http.HttpResponse;
-import java.time.Duration;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.CountDownLatch;
-import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicReference;
-import java.util.function.Consumer;
-import java.util.function.Function;
-
-import com.fasterxml.jackson.core.type.TypeReference;
-import com.fasterxml.jackson.databind.ObjectMapper;
-import io.modelcontextprotocol.client.transport.FlowSseClient.SseEvent;
-import io.modelcontextprotocol.spec.McpClientTransport;
-import io.modelcontextprotocol.spec.McpError;
-import io.modelcontextprotocol.spec.McpSchema;
-import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage;
-import io.modelcontextprotocol.util.Assert;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import reactor.core.publisher.Mono;
-
-/**
- * Server-Sent Events (SSE) implementation of the
- * {@link io.modelcontextprotocol.spec.McpTransport} that follows the MCP HTTP with SSE
- * transport specification, using Java's HttpClient.
- *
- * 

- * This transport implementation establishes a bidirectional communication channel between - * client and server using SSE for server-to-client messages and HTTP POST requests for - * client-to-server messages. The transport: - *

    - *
  • Establishes an SSE connection to receive server messages
  • - *
  • Handles endpoint discovery through SSE events
  • - *
  • Manages message serialization/deserialization using Jackson
  • - *
  • Provides graceful connection termination
  • - *
- * - *

- * The transport supports two types of SSE events: - *

    - *
  • 'endpoint' - Contains the URL for sending client messages
  • - *
  • 'message' - Contains JSON-RPC message payload
  • - *
- * - * @author Christian Tzolov - * @see io.modelcontextprotocol.spec.McpTransport - * @see io.modelcontextprotocol.spec.McpClientTransport - */ -public class HttpClientSseClientTransport implements McpClientTransport { - - private static final Logger logger = LoggerFactory.getLogger(HttpClientSseClientTransport.class); - - /** SSE event type for JSON-RPC messages */ - private static final String MESSAGE_EVENT_TYPE = "message"; - - /** SSE event type for endpoint discovery */ - private static final String ENDPOINT_EVENT_TYPE = "endpoint"; - - /** Default SSE endpoint path */ - private static final String DEFAULT_SSE_ENDPOINT = "/sse"; - - /** Base URI for the MCP server */ - private final String baseUri; - - /** SSE endpoint path */ - private final String sseEndpoint; - - /** SSE client for handling server-sent events. Uses the /sse endpoint */ - private final FlowSseClient sseClient; - - /** - * HTTP client for sending messages to the server. Uses HTTP POST over the message - * endpoint - */ - private final HttpClient httpClient; - - /** HTTP request builder for building requests to send messages to the server */ - private final HttpRequest.Builder requestBuilder; - - /** JSON object mapper for message serialization/deserialization */ - protected ObjectMapper objectMapper; - - /** Flag indicating if the transport is in closing state */ - private volatile boolean isClosing = false; - - /** Latch for coordinating endpoint discovery */ - private final CountDownLatch closeLatch = new CountDownLatch(1); - - /** Holds the discovered message endpoint URL */ - private final AtomicReference messageEndpoint = new AtomicReference<>(); - - /** Holds the SSE connection future */ - private final AtomicReference> connectionFuture = new AtomicReference<>(); - - /** - * Creates a new transport instance with default HTTP client and object mapper. - * @param baseUri the base URI of the MCP server - * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This - * constructor will be removed in future versions. - */ - @Deprecated(forRemoval = true) - public HttpClientSseClientTransport(String baseUri) { - this(HttpClient.newBuilder(), baseUri, new ObjectMapper()); - } - - /** - * Creates a new transport instance with custom HTTP client builder and object mapper. - * @param clientBuilder the HTTP client builder to use - * @param baseUri the base URI of the MCP server - * @param objectMapper the object mapper for JSON serialization/deserialization - * @throws IllegalArgumentException if objectMapper or clientBuilder is null - * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This - * constructor will be removed in future versions. - */ - @Deprecated(forRemoval = true) - public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String baseUri, ObjectMapper objectMapper) { - this(clientBuilder, baseUri, DEFAULT_SSE_ENDPOINT, objectMapper); - } - - /** - * Creates a new transport instance with custom HTTP client builder and object mapper. - * @param clientBuilder the HTTP client builder to use - * @param baseUri the base URI of the MCP server - * @param sseEndpoint the SSE endpoint path - * @param objectMapper the object mapper for JSON serialization/deserialization - * @throws IllegalArgumentException if objectMapper or clientBuilder is null - * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This - * constructor will be removed in future versions. - */ - @Deprecated(forRemoval = true) - public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String baseUri, String sseEndpoint, - ObjectMapper objectMapper) { - this(clientBuilder, HttpRequest.newBuilder(), baseUri, sseEndpoint, objectMapper); - } - - /** - * Creates a new transport instance with custom HTTP client builder, object mapper, - * and headers. - * @param clientBuilder the HTTP client builder to use - * @param requestBuilder the HTTP request builder to use - * @param baseUri the base URI of the MCP server - * @param sseEndpoint the SSE endpoint path - * @param objectMapper the object mapper for JSON serialization/deserialization - * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null - * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This - * constructor will be removed in future versions. - */ - @Deprecated(forRemoval = true) - public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpRequest.Builder requestBuilder, - String baseUri, String sseEndpoint, ObjectMapper objectMapper) { - this(clientBuilder.connectTimeout(Duration.ofSeconds(10)).build(), requestBuilder, baseUri, sseEndpoint, - objectMapper); - } - - /** - * Creates a new transport instance with custom HTTP client builder, object mapper, - * and headers. - * @param httpClient the HTTP client to use - * @param requestBuilder the HTTP request builder to use - * @param baseUri the base URI of the MCP server - * @param sseEndpoint the SSE endpoint path - * @param objectMapper the object mapper for JSON serialization/deserialization - * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null - */ - HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri, - String sseEndpoint, ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - Assert.hasText(baseUri, "baseUri must not be empty"); - Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); - Assert.notNull(httpClient, "httpClient must not be null"); - Assert.notNull(requestBuilder, "requestBuilder must not be null"); - this.baseUri = baseUri; - this.sseEndpoint = sseEndpoint; - this.objectMapper = objectMapper; - this.httpClient = httpClient; - this.requestBuilder = requestBuilder; - - this.sseClient = new FlowSseClient(this.httpClient, requestBuilder); - } - - /** - * Creates a new builder for {@link HttpClientSseClientTransport}. - * @param baseUri the base URI of the MCP server - * @return a new builder instance - */ - public static Builder builder(String baseUri) { - return new Builder().baseUri(baseUri); - } - - /** - * Builder for {@link HttpClientSseClientTransport}. - */ - public static class Builder { - - private String baseUri; - - private String sseEndpoint = DEFAULT_SSE_ENDPOINT; - - private HttpClient.Builder clientBuilder = HttpClient.newBuilder() - .version(HttpClient.Version.HTTP_1_1) - .connectTimeout(Duration.ofSeconds(10)); - - private ObjectMapper objectMapper = new ObjectMapper(); - - private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder() - .header("Content-Type", "application/json"); - - /** - * Creates a new builder instance. - */ - Builder() { - // Default constructor - } - - /** - * Creates a new builder with the specified base URI. - * @param baseUri the base URI of the MCP server - * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. - * This constructor is deprecated and will be removed or made {@code protected} or - * {@code private} in a future release. - */ - @Deprecated(forRemoval = true) - public Builder(String baseUri) { - Assert.hasText(baseUri, "baseUri must not be empty"); - this.baseUri = baseUri; - } - - /** - * Sets the base URI. - * @param baseUri the base URI - * @return this builder - */ - Builder baseUri(String baseUri) { - Assert.hasText(baseUri, "baseUri must not be empty"); - this.baseUri = baseUri; - return this; - } - - /** - * Sets the SSE endpoint path. - * @param sseEndpoint the SSE endpoint path - * @return this builder - */ - public Builder sseEndpoint(String sseEndpoint) { - Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); - this.sseEndpoint = sseEndpoint; - return this; - } - - /** - * Sets the HTTP client builder. - * @param clientBuilder the HTTP client builder - * @return this builder - */ - public Builder clientBuilder(HttpClient.Builder clientBuilder) { - Assert.notNull(clientBuilder, "clientBuilder must not be null"); - this.clientBuilder = clientBuilder; - return this; - } - - /** - * Customizes the HTTP client builder. - * @param clientCustomizer the consumer to customize the HTTP client builder - * @return this builder - */ - public Builder customizeClient(final Consumer clientCustomizer) { - Assert.notNull(clientCustomizer, "clientCustomizer must not be null"); - clientCustomizer.accept(clientBuilder); - return this; - } - - /** - * Sets the HTTP request builder. - * @param requestBuilder the HTTP request builder - * @return this builder - */ - public Builder requestBuilder(HttpRequest.Builder requestBuilder) { - Assert.notNull(requestBuilder, "requestBuilder must not be null"); - this.requestBuilder = requestBuilder; - return this; - } - - /** - * Customizes the HTTP client builder. - * @param requestCustomizer the consumer to customize the HTTP request builder - * @return this builder - */ - public Builder customizeRequest(final Consumer requestCustomizer) { - Assert.notNull(requestCustomizer, "requestCustomizer must not be null"); - requestCustomizer.accept(requestBuilder); - return this; - } - - /** - * Sets the object mapper for JSON serialization/deserialization. - * @param objectMapper the object mapper - * @return this builder - */ - public Builder objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "objectMapper must not be null"); - this.objectMapper = objectMapper; - return this; - } - - /** - * Builds a new {@link HttpClientSseClientTransport} instance. - * @return a new transport instance - */ - public HttpClientSseClientTransport build() { - return new HttpClientSseClientTransport(clientBuilder.build(), requestBuilder, baseUri, sseEndpoint, - objectMapper); - } - - } - - /** - * Establishes the SSE connection with the server and sets up message handling. - * - *

- * This method: - *

    - *
  • Initiates the SSE connection
  • - *
  • Handles endpoint discovery events
  • - *
  • Processes incoming JSON-RPC messages
  • - *
- * @param handler the function to process received JSON-RPC messages - * @return a Mono that completes when the connection is established - */ - @Override - public Mono connect(Function, Mono> handler) { - CompletableFuture future = new CompletableFuture<>(); - connectionFuture.set(future); - - sseClient.subscribe(this.baseUri + this.sseEndpoint, new FlowSseClient.SseEventHandler() { - @Override - public void onEvent(SseEvent event) { - if (isClosing) { - return; - } - - try { - if (ENDPOINT_EVENT_TYPE.equals(event.type())) { - String endpoint = event.data(); - messageEndpoint.set(endpoint); - closeLatch.countDown(); - future.complete(null); - } - else if (MESSAGE_EVENT_TYPE.equals(event.type())) { - JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, event.data()); - handler.apply(Mono.just(message)).subscribe(); - } - else { - logger.error("Received unrecognized SSE event type: {}", event.type()); - } - } - catch (IOException e) { - logger.error("Error processing SSE event", e); - future.completeExceptionally(e); - } - } - - @Override - public void onError(Throwable error) { - if (!isClosing) { - logger.error("SSE connection error", error); - future.completeExceptionally(error); - } - } - }); - - return Mono.fromFuture(future); - } - - /** - * Sends a JSON-RPC message to the server. - * - *

- * This method waits for the message endpoint to be discovered before sending the - * message. The message is serialized to JSON and sent as an HTTP POST request. - * @param message the JSON-RPC message to send - * @return a Mono that completes when the message is sent - * @throws McpError if the message endpoint is not available or the wait times out - */ - @Override - public Mono sendMessage(JSONRPCMessage message) { - if (isClosing) { - return Mono.empty(); - } - - try { - if (!closeLatch.await(10, TimeUnit.SECONDS)) { - return Mono.error(new McpError("Failed to wait for the message endpoint")); - } - } - catch (InterruptedException e) { - return Mono.error(new McpError("Failed to wait for the message endpoint")); - } - - String endpoint = messageEndpoint.get(); - if (endpoint == null) { - return Mono.error(new McpError("No message endpoint available")); - } - - try { - String jsonText = this.objectMapper.writeValueAsString(message); - HttpRequest request = this.requestBuilder.uri(URI.create(this.baseUri + endpoint)) - .POST(HttpRequest.BodyPublishers.ofString(jsonText)) - .build(); - - return Mono.fromFuture( - httpClient.sendAsync(request, HttpResponse.BodyHandlers.discarding()).thenAccept(response -> { - if (response.statusCode() != 200 && response.statusCode() != 201 && response.statusCode() != 202 - && response.statusCode() != 206) { - logger.error("Error sending message: {}", response.statusCode()); - } - })); - } - catch (IOException e) { - if (!isClosing) { - return Mono.error(new RuntimeException("Failed to serialize message", e)); - } - return Mono.empty(); - } - } - - /** - * Gracefully closes the transport connection. - * - *

- * Sets the closing flag and cancels any pending connection future. This prevents new - * messages from being sent and allows ongoing operations to complete. - * @return a Mono that completes when the closing process is initiated - */ - @Override - public Mono closeGracefully() { - return Mono.fromRunnable(() -> { - isClosing = true; - CompletableFuture future = connectionFuture.get(); - if (future != null && !future.isDone()) { - future.cancel(true); - } - }); - } - - /** - * Unmarshal data to the specified type using the configured object mapper. - * @param data the data to unmarshal - * @param typeRef the type reference for the target type - * @param the target type - * @return the unmarshalled object - */ - @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return this.objectMapper.convertValue(data, typeRef); - } - -} diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportProvider.java new file mode 100644 index 00000000..84f68481 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportProvider.java @@ -0,0 +1,518 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol.client.transport; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.transport.FlowSseClient.SseEvent; +import io.modelcontextprotocol.spec.McpClientSession; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpClientTransportProvider; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; + +import java.io.IOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.time.Duration; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; + +import static io.modelcontextprotocol.util.Utils.getSessionIdFromUrl; + +/** + * Server-Sent Events (SSE) implementation of the + * {@link io.modelcontextprotocol.spec.McpClientTransportProvider} that provides SSE + * client transport that follows the MCP HTTP with SSE transport specification, using + * Java's HttpClient. + * + * @author Christian Tzolov + * @author Jermaine Hua + * @see io.modelcontextprotocol.spec.McpClientTransportProvider + */ +public class HttpClientSseClientTransportProvider implements McpClientTransportProvider { + + private static final Logger logger = LoggerFactory.getLogger(HttpClientSseClientTransportProvider.class); + + /** + * SSE event type for JSON-RPC messages + */ + private static final String MESSAGE_EVENT_TYPE = "message"; + + /** + * SSE event type for endpoint discovery + */ + private static final String ENDPOINT_EVENT_TYPE = "endpoint"; + + /** + * Default SSE endpoint path + */ + private static final String DEFAULT_SSE_ENDPOINT = "/sse"; + + /** + * Base URI for the MCP server + */ + private final String baseUri; + + /** + * SSE endpoint path + */ + private final String sseEndpoint; + + /** + * JSON object mapper for message serialization/deserialization + */ + protected ObjectMapper objectMapper; + + /** + * Flag indicating if the transport is in closing state + */ + private volatile boolean isClosing = false; + + /** + * Latch for coordinating endpoint discovery + */ + private final CountDownLatch closeLatch = new CountDownLatch(1); + + /** + * HTTP request builder for building requests to send messages to the server + */ + private final HttpRequest.Builder requestBuilder; + + private final HttpClient.Builder clientBuilder; + + /** + * Session factory for creating new sessions + */ + private McpClientSession.Factory sessionFactory; + + /** + * Active client session + */ + private McpClientSession session; + + /** + * Creates a new transport instance with default HTTP client and object mapper. + * @param baseUri the base URI of the MCP server + * @deprecated Use {@link HttpClientSseClientTransportProvider#builder(String)} + * instead. This constructor will be removed in future versions. + */ + @Deprecated(forRemoval = true) + public HttpClientSseClientTransportProvider(String baseUri) { + this(HttpClient.newBuilder(), baseUri, new ObjectMapper()); + } + + /** + * Creates a new transport instance with custom HTTP client builder and object mapper. + * @param clientBuilder the HTTP client builder to use + * @param baseUri the base URI of the MCP server + * @param objectMapper the object mapper for JSON serialization/deserialization + * @throws IllegalArgumentException if objectMapper or clientBuilder is null + * @deprecated Use {@link HttpClientSseClientTransportProvider#builder(String)} + * instead. This constructor will be removed in future versions. + */ + @Deprecated(forRemoval = true) + public HttpClientSseClientTransportProvider(HttpClient.Builder clientBuilder, String baseUri, + ObjectMapper objectMapper) { + this(clientBuilder, baseUri, DEFAULT_SSE_ENDPOINT, objectMapper); + } + + /** + * Creates a new transport instance with custom HTTP client builder and object mapper. + * @param clientBuilder the HTTP client builder to use + * @param baseUri the base URI of the MCP server + * @param sseEndpoint the SSE endpoint path + * @param objectMapper the object mapper for JSON serialization/deserialization + * @throws IllegalArgumentException if objectMapper or clientBuilder is null + * @deprecated Use {@link HttpClientSseClientTransportProvider#builder(String)} + * instead. This constructor will be removed in future versions. + */ + @Deprecated(forRemoval = true) + public HttpClientSseClientTransportProvider(HttpClient.Builder clientBuilder, String baseUri, String sseEndpoint, + ObjectMapper objectMapper) { + this(clientBuilder, HttpRequest.newBuilder(), baseUri, sseEndpoint, objectMapper); + } + + /** + * Creates a new transport instance with custom HTTP client builder, object mapper, + * and headers. + * @param clientBuilder the HTTP client builder to use + * @param requestBuilder the HTTP request builder to use + * @param baseUri the base URI of the MCP server + * @param sseEndpoint the SSE endpoint path + * @param objectMapper the object mapper for JSON serialization/deserialization + * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null + * @deprecated Use {@link HttpClientSseClientTransportProvider#builder(String)} + * instead. This constructor will be removed in future versions. + */ + HttpClientSseClientTransportProvider(HttpClient.Builder clientBuilder, HttpRequest.Builder requestBuilder, + String baseUri, String sseEndpoint, ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + Assert.hasText(baseUri, "baseUri must not be empty"); + Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); + Assert.notNull(requestBuilder, "requestBuilder must not be null"); + this.baseUri = baseUri; + this.sseEndpoint = sseEndpoint; + this.objectMapper = objectMapper; + this.requestBuilder = requestBuilder; + this.clientBuilder = clientBuilder; + } + + /** + * Creates a new builder for {@link HttpClientSseClientTransport}. + * @param baseUri the base URI of the MCP server + * @return a new builder instance + */ + public static Builder builder(String baseUri) { + return new Builder().baseUri(baseUri); + } + + /** + * Builder for {@link HttpClientSseClientTransportProvider}. + */ + public static class Builder { + + private String baseUri; + + private String sseEndpoint = DEFAULT_SSE_ENDPOINT; + + private HttpClient.Builder clientBuilder = HttpClient.newBuilder() + .version(HttpClient.Version.HTTP_1_1) + .connectTimeout(Duration.ofSeconds(10)); + + private ObjectMapper objectMapper = new ObjectMapper(); + + private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder() + .header("Content-Type", "application/json"); + + /** + * Creates a new builder instance. + */ + Builder() { + // Default constructor + } + + /** + * Creates a new builder with the specified base URI. + * @param baseUri the base URI of the MCP server + * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. + * This constructor is deprecated and will be removed or made {@code protected} or + * {@code private} in a future release. + */ + @Deprecated(forRemoval = true) + public Builder(String baseUri) { + Assert.hasText(baseUri, "baseUri must not be empty"); + this.baseUri = baseUri; + } + + /** + * Sets the base URI. + * @param baseUri the base URI + * @return this builder + */ + Builder baseUri(String baseUri) { + Assert.hasText(baseUri, "baseUri must not be empty"); + this.baseUri = baseUri; + return this; + } + + /** + * Sets the SSE endpoint path. + * @param sseEndpoint the SSE endpoint path + * @return this builder + */ + public Builder sseEndpoint(String sseEndpoint) { + Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); + this.sseEndpoint = sseEndpoint; + return this; + } + + /** + * Sets the HTTP client builder. + * @param clientBuilder the HTTP client builder + * @return this builder + */ + public Builder clientBuilder(HttpClient.Builder clientBuilder) { + Assert.notNull(clientBuilder, "clientBuilder must not be null"); + this.clientBuilder = clientBuilder; + return this; + } + + /** + * Customizes the HTTP client builder. + * @param clientCustomizer the consumer to customize the HTTP client builder + * @return this builder + */ + public Builder customizeClient(final Consumer clientCustomizer) { + Assert.notNull(clientCustomizer, "clientCustomizer must not be null"); + clientCustomizer.accept(clientBuilder); + return this; + } + + /** + * Sets the HTTP request builder. + * @param requestBuilder the HTTP request builder + * @return this builder + */ + public Builder requestBuilder(HttpRequest.Builder requestBuilder) { + Assert.notNull(requestBuilder, "requestBuilder must not be null"); + this.requestBuilder = requestBuilder; + return this; + } + + /** + * Customizes the HTTP client builder. + * @param requestCustomizer the consumer to customize the HTTP request builder + * @return this builder + */ + public Builder customizeRequest(final Consumer requestCustomizer) { + Assert.notNull(requestCustomizer, "requestCustomizer must not be null"); + requestCustomizer.accept(requestBuilder); + return this; + } + + /** + * Sets the object mapper for JSON serialization/deserialization. + * @param objectMapper the object mapper + * @return this builder + */ + public Builder objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "objectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Builds a new {@link HttpClientSseClientTransportProvider} instance. + * @return a new transport instance + */ + public HttpClientSseClientTransportProvider build() { + return new HttpClientSseClientTransportProvider(clientBuilder, requestBuilder, baseUri, sseEndpoint, + objectMapper); + } + + } + + @Override + public void setSessionFactory(McpClientSession.Factory sessionFactory) { + this.sessionFactory = sessionFactory; + } + + /** + * Gracefully closes the transport connection. + * + *

+ * Sets the closing flag and cancels any pending connection future. This prevents new + * messages from being sent and allows ongoing operations to complete. + * @return a Mono that completes when the closing process is initiated + */ + @Override + public Mono closeGracefully() { + isClosing = true; + return session.closeGracefully().then(); + } + + @Override + public McpClientSession getSession() { + if (session != null) { + return session; + } + + HttpClient httpClient = clientBuilder.build(); + FlowSseClient sseClient = new FlowSseClient(httpClient, requestBuilder); + McpClientTransport mcpClientTransport = new HttpClientSseClientTransport(sseClient, httpClient); + session = sessionFactory.create(mcpClientTransport); + return session; + } + + private class HttpClientSseClientTransport implements McpClientTransport { + + /** + * SSE client for handling server-sent events. Uses the /sse endpoint + */ + private final FlowSseClient sseClient; + + /** + * HTTP client for sending messages to the server. Uses HTTP POST over the message + * endpoint + */ + private final HttpClient httpClient; + + /** + * Holds the discovered message endpoint URL + */ + private final AtomicReference messageEndpoint = new AtomicReference<>(); + + /** + * Holds the SSE connection future + */ + private final AtomicReference> connectionFuture = new AtomicReference<>(); + + public HttpClientSseClientTransport(FlowSseClient sseClient, HttpClient httpClient) { + this.sseClient = sseClient; + this.httpClient = httpClient; + } + + /** + * Establishes the SSE connection with the server and sets up message handling. + * + *

+ * This method: + *

    + *
  • Initiates the SSE connection
  • + *
  • Handles endpoint discovery events
  • + *
  • Processes incoming JSON-RPC messages
  • + *
+ * @param handler the function to process received JSON-RPC messages + * @return a Mono that completes when the connection is established + */ + @Override + public Mono connect(Function, Mono> handler) { + return connect(); + } + + @Override + public Mono connect() { + CompletableFuture future = new CompletableFuture<>(); + connectionFuture.set(future); + sseClient.subscribe(baseUri + sseEndpoint, new FlowSseClient.SseEventHandler() { + @Override + public void onEvent(SseEvent event) { + if (isClosing) { + return; + } + + try { + if (ENDPOINT_EVENT_TYPE.equals(event.type())) { + String endpoint = event.data(); + messageEndpoint.set(endpoint); + closeLatch.countDown(); + future.complete(null); + } + else if (MESSAGE_EVENT_TYPE.equals(event.type())) { + JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, event.data()); + session.handle(message).subscribe(); + } + else { + logger.error("Received unrecognized SSE event type: {}", event.type()); + } + } + catch (IOException e) { + logger.error("Error processing SSE event", e); + future.completeExceptionally(e); + } + } + + @Override + public void onError(Throwable error) { + if (!isClosing) { + logger.error("SSE connection error", error); + future.completeExceptionally(error); + } + } + }); + + return Mono.fromFuture(future); + } + + /** + * Sends a JSON-RPC message to the server. + * + *

+ * This method waits for the message endpoint to be discovered before sending the + * message. The message is serialized to JSON and sent as an HTTP POST request. + * @param message the JSON-RPC message to send + * @return a Mono that completes when the message is sent + * @throws McpError if the message endpoint is not available or the wait times out + */ + @Override + public Mono sendMessage(JSONRPCMessage message) { + if (isClosing) { + return Mono.empty(); + } + + try { + if (!closeLatch.await(10, TimeUnit.SECONDS)) { + return Mono.error(new McpError("Failed to wait for the message endpoint")); + } + } + catch (InterruptedException e) { + return Mono.error(new McpError("Failed to wait for the message endpoint")); + } + + String endpoint = messageEndpoint.get(); + if (endpoint == null) { + return Mono.error(new McpError("No message endpoint available")); + } + + try { + String jsonText = objectMapper.writeValueAsString(message); + HttpRequest request = HttpRequest.newBuilder() + .uri(URI.create(baseUri + endpoint)) + .header("Content-Type", "application/json") + .POST(HttpRequest.BodyPublishers.ofString(jsonText)) + .build(); + + return Mono.fromFuture( + httpClient.sendAsync(request, HttpResponse.BodyHandlers.discarding()).thenAccept(response -> { + if (response.statusCode() != 200 && response.statusCode() != 201 + && response.statusCode() != 202 && response.statusCode() != 206) { + logger.error("Error sending message: {}", response.statusCode()); + } + })); + } + catch (IOException e) { + if (!isClosing) { + return Mono.error(new RuntimeException("Failed to serialize message", e)); + } + return Mono.empty(); + } + } + + /** + * Gracefully closes the transport connection. + * + *

+ * Sets the closing flag and cancels any pending connection future. This prevents + * new messages from being sent and allows ongoing operations to complete. + * @return a Mono that completes when the closing process is initiated + */ + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + isClosing = true; + CompletableFuture future = connectionFuture.get(); + if (future != null && !future.isDone()) { + future.cancel(true); + } + }); + } + + /** + * Unmarshal data to the specified type using the configured object mapper. + * @param data the data to unmarshal + * @param typeRef the type reference for the target type + * @param the target type + * @return the unmarshalled object + */ + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java deleted file mode 100644 index 9d71cbb4..00000000 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java +++ /dev/null @@ -1,394 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.client.transport; - -import java.io.BufferedReader; -import java.io.IOException; -import java.io.InputStreamReader; -import java.nio.charset.StandardCharsets; -import java.time.Duration; -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.Executors; -import java.util.function.Consumer; -import java.util.function.Function; - -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpClientTransport; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; -import io.modelcontextprotocol.util.Assert; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.publisher.Sinks; -import reactor.core.scheduler.Scheduler; -import reactor.core.scheduler.Schedulers; - -/** - * Implementation of the MCP Stdio transport that communicates with a server process using - * standard input/output streams. Messages are exchanged as newline-delimited JSON-RPC - * messages over stdin/stdout, with errors and debug information sent to stderr. - * - * @author Christian Tzolov - * @author Dariusz Jędrzejczyk - */ -public class StdioClientTransport implements McpClientTransport { - - private static final Logger logger = LoggerFactory.getLogger(StdioClientTransport.class); - - private final Sinks.Many inboundSink; - - private final Sinks.Many outboundSink; - - /** The server process being communicated with */ - private Process process; - - private ObjectMapper objectMapper; - - /** Scheduler for handling inbound messages from the server process */ - private Scheduler inboundScheduler; - - /** Scheduler for handling outbound messages to the server process */ - private Scheduler outboundScheduler; - - /** Scheduler for handling error messages from the server process */ - private Scheduler errorScheduler; - - /** Parameters for configuring and starting the server process */ - private final ServerParameters params; - - private final Sinks.Many errorSink; - - private volatile boolean isClosing = false; - - // visible for tests - private Consumer stdErrorHandler = error -> logger.info("STDERR Message received: {}", error); - - /** - * Creates a new StdioClientTransport with the specified parameters and default - * ObjectMapper. - * @param params The parameters for configuring the server process - */ - public StdioClientTransport(ServerParameters params) { - this(params, new ObjectMapper()); - } - - /** - * Creates a new StdioClientTransport with the specified parameters and ObjectMapper. - * @param params The parameters for configuring the server process - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - */ - public StdioClientTransport(ServerParameters params, ObjectMapper objectMapper) { - Assert.notNull(params, "The params can not be null"); - Assert.notNull(objectMapper, "The ObjectMapper can not be null"); - - this.inboundSink = Sinks.many().unicast().onBackpressureBuffer(); - this.outboundSink = Sinks.many().unicast().onBackpressureBuffer(); - - this.params = params; - - this.objectMapper = objectMapper; - - this.errorSink = Sinks.many().unicast().onBackpressureBuffer(); - - // Start threads - this.inboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), "inbound"); - this.outboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), "outbound"); - this.errorScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), "error"); - } - - /** - * Starts the server process and initializes the message processing streams. This - * method sets up the process with the configured command, arguments, and environment, - * then starts the inbound, outbound, and error processing threads. - * @throws RuntimeException if the process fails to start or if the process streams - * are null - */ - @Override - public Mono connect(Function, Mono> handler) { - return Mono.fromRunnable(() -> { - handleIncomingMessages(handler); - handleIncomingErrors(); - - // Prepare command and environment - List fullCommand = new ArrayList<>(); - fullCommand.add(params.getCommand()); - fullCommand.addAll(params.getArgs()); - - ProcessBuilder processBuilder = this.getProcessBuilder(); - processBuilder.command(fullCommand); - processBuilder.environment().putAll(params.getEnv()); - - // Start the process - try { - this.process = processBuilder.start(); - } - catch (IOException e) { - throw new RuntimeException("Failed to start process with command: " + fullCommand, e); - } - - // Validate process streams - if (this.process.getInputStream() == null || process.getOutputStream() == null) { - this.process.destroy(); - throw new RuntimeException("Process input or output stream is null"); - } - - // Start threads - startInboundProcessing(); - startOutboundProcessing(); - startErrorProcessing(); - }).subscribeOn(Schedulers.boundedElastic()); - } - - /** - * Creates and returns a new ProcessBuilder instance. Protected to allow overriding in - * tests. - * @return A new ProcessBuilder instance - */ - protected ProcessBuilder getProcessBuilder() { - return new ProcessBuilder(); - } - - /** - * Sets the handler for processing transport-level errors. - * - *

- * The provided handler will be called when errors occur during transport operations, - * such as connection failures or protocol violations. - *

- * @param errorHandler a consumer that processes error messages - */ - public void setStdErrorHandler(Consumer errorHandler) { - this.stdErrorHandler = errorHandler; - } - - /** - * Waits for the server process to exit. - * @throws RuntimeException if the process is interrupted while waiting - */ - public void awaitForExit() { - try { - this.process.waitFor(); - } - catch (InterruptedException e) { - throw new RuntimeException("Process interrupted", e); - } - } - - /** - * Starts the error processing thread that reads from the process's error stream. - * Error messages are logged and emitted to the error sink. - */ - private void startErrorProcessing() { - this.errorScheduler.schedule(() -> { - try (BufferedReader processErrorReader = new BufferedReader( - new InputStreamReader(process.getErrorStream()))) { - String line; - while (!isClosing && (line = processErrorReader.readLine()) != null) { - try { - if (!this.errorSink.tryEmitNext(line).isSuccess()) { - if (!isClosing) { - logger.error("Failed to emit error message"); - } - break; - } - } - catch (Exception e) { - if (!isClosing) { - logger.error("Error processing error message", e); - } - break; - } - } - } - catch (IOException e) { - if (!isClosing) { - logger.error("Error reading from error stream", e); - } - } - finally { - isClosing = true; - errorSink.tryEmitComplete(); - } - }); - } - - private void handleIncomingMessages(Function, Mono> inboundMessageHandler) { - this.inboundSink.asFlux() - .flatMap(message -> Mono.just(message) - .transform(inboundMessageHandler) - .contextWrite(ctx -> ctx.put("observation", "myObservation"))) - .subscribe(); - } - - private void handleIncomingErrors() { - this.errorSink.asFlux().subscribe(e -> { - this.stdErrorHandler.accept(e); - }); - } - - @Override - public Mono sendMessage(JSONRPCMessage message) { - if (this.outboundSink.tryEmitNext(message).isSuccess()) { - // TODO: essentially we could reschedule ourselves in some time and make - // another attempt with the already read data but pause reading until - // success - // In this approach we delegate the retry and the backpressure onto the - // caller. This might be enough for most cases. - return Mono.empty(); - } - else { - return Mono.error(new RuntimeException("Failed to enqueue message")); - } - } - - /** - * Starts the inbound processing thread that reads JSON-RPC messages from the - * process's input stream. Messages are deserialized and emitted to the inbound sink. - */ - private void startInboundProcessing() { - this.inboundScheduler.schedule(() -> { - try (BufferedReader processReader = new BufferedReader(new InputStreamReader(process.getInputStream()))) { - String line; - while (!isClosing && (line = processReader.readLine()) != null) { - try { - JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, line); - if (!this.inboundSink.tryEmitNext(message).isSuccess()) { - if (!isClosing) { - logger.error("Failed to enqueue inbound message: {}", message); - } - break; - } - } - catch (Exception e) { - if (!isClosing) { - logger.error("Error processing inbound message for line: " + line, e); - } - break; - } - } - } - catch (IOException e) { - if (!isClosing) { - logger.error("Error reading from input stream", e); - } - } - finally { - isClosing = true; - inboundSink.tryEmitComplete(); - } - }); - } - - /** - * Starts the outbound processing thread that writes JSON-RPC messages to the - * process's output stream. Messages are serialized to JSON and written with a newline - * delimiter. - */ - private void startOutboundProcessing() { - this.handleOutbound(messages -> messages - // this bit is important since writes come from user threads, and we - // want to ensure that the actual writing happens on a dedicated thread - .publishOn(outboundScheduler) - .handle((message, s) -> { - if (message != null && !isClosing) { - try { - String jsonMessage = objectMapper.writeValueAsString(message); - // Escape any embedded newlines in the JSON message as per spec: - // https://spec.modelcontextprotocol.io/specification/basic/transports/#stdio - // - Messages are delimited by newlines, and MUST NOT contain - // embedded newlines. - jsonMessage = jsonMessage.replace("\r\n", "\\n").replace("\n", "\\n").replace("\r", "\\n"); - - var os = this.process.getOutputStream(); - synchronized (os) { - os.write(jsonMessage.getBytes(StandardCharsets.UTF_8)); - os.write("\n".getBytes(StandardCharsets.UTF_8)); - os.flush(); - } - s.next(message); - } - catch (IOException e) { - s.error(new RuntimeException(e)); - } - } - })); - } - - protected void handleOutbound(Function, Flux> outboundConsumer) { - outboundConsumer.apply(outboundSink.asFlux()).doOnComplete(() -> { - isClosing = true; - outboundSink.tryEmitComplete(); - }).doOnError(e -> { - if (!isClosing) { - logger.error("Error in outbound processing", e); - isClosing = true; - outboundSink.tryEmitComplete(); - } - }).subscribe(); - } - - /** - * Gracefully closes the transport by destroying the process and disposing of the - * schedulers. This method sends a TERM signal to the process and waits for it to exit - * before cleaning up resources. - * @return A Mono that completes when the transport is closed - */ - @Override - public Mono closeGracefully() { - return Mono.fromRunnable(() -> { - isClosing = true; - logger.debug("Initiating graceful shutdown"); - }).then(Mono.defer(() -> { - // First complete all sinks to stop accepting new messages - inboundSink.tryEmitComplete(); - outboundSink.tryEmitComplete(); - errorSink.tryEmitComplete(); - - // Give a short time for any pending messages to be processed - return Mono.delay(Duration.ofMillis(100)); - })).then(Mono.defer(() -> { - logger.debug("Sending TERM to process"); - if (this.process != null) { - this.process.destroy(); - return Mono.fromFuture(process.onExit()); - } - else { - logger.warn("Process not started"); - return Mono.empty(); - } - })).doOnNext(process -> { - if (process.exitValue() != 0) { - logger.warn("Process terminated with code " + process.exitValue()); - } - }).then(Mono.fromRunnable(() -> { - try { - // The Threads are blocked on readLine so disposeGracefully would not - // interrupt them, therefore we issue an async hard dispose. - inboundScheduler.dispose(); - errorScheduler.dispose(); - outboundScheduler.dispose(); - - logger.debug("Graceful shutdown completed"); - } - catch (Exception e) { - logger.error("Error during graceful shutdown", e); - } - })).then().subscribeOn(Schedulers.boundedElastic()); - } - - public Sinks.Many getErrorSink() { - return this.errorSink; - } - - @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return this.objectMapper.convertValue(data, typeRef); - } - -} diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransportProvider.java new file mode 100644 index 00000000..f8ec7abb --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransportProvider.java @@ -0,0 +1,459 @@ +package io.modelcontextprotocol.client.transport; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpClientSession; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpClientTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.Executors; +import java.util.function.Consumer; +import java.util.function.Function; + +/** + * Implementation of a transport provider that provides the MCP Stdio transport that + * communicates with a server process using standard input/output streams. Messages are + * exchanged as newline-delimited JSON-RPC messages over stdin/stdout, and errors and + * debug information are sent to stderr. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + * @author Jermaine Hua + * @see io.modelcontextprotocol.spec.McpClientTransportProvider + */ +public class StdioClientTransportProvider implements McpClientTransportProvider { + + private static final Logger logger = LoggerFactory.getLogger(StdioClientTransportProvider.class); + + private final ObjectMapper objectMapper; + + /** Parameters for configuring and starting the server process */ + private final ServerParameters params; + + private volatile boolean isClosing = false; + + // visible for tests + private Consumer stdErrorHandler = error -> logger.info("STDERR Message received: {}", error); + + private final Sinks.Many errorSink; + + /** + * Session factory for creating new sessions + */ + private McpClientSession.Factory sessionFactory; + + /** + * Active client session + */ + private McpClientSession session; + + /** + * Creates a new StdioClientTransportProvider with the specified parameters and + * default ObjectMapper. + * @param params The parameters for configuring the server process + */ + public StdioClientTransportProvider(ServerParameters params) { + this(params, new ObjectMapper()); + } + + /** + * Creates a new StdioClientTransportProvider with the specified parameters and + * ObjectMapper. + * @param params The parameters for configuring the server process + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + */ + public StdioClientTransportProvider(ServerParameters params, ObjectMapper objectMapper) { + Assert.notNull(params, "The params can not be null"); + Assert.notNull(objectMapper, "The ObjectMapper can not be null"); + this.objectMapper = objectMapper; + this.params = params; + this.errorSink = Sinks.many().unicast().onBackpressureBuffer(); + } + + @Override + public void setSessionFactory(McpClientSession.Factory sessionFactory) { + this.sessionFactory = sessionFactory; + } + + @Override + public McpClientSession getSession() { + if (session != null) { + return session; + } + McpClientTransport mcpClientTransport = new StdioClientTransport(params); + this.session = sessionFactory.create(mcpClientTransport); + return session; + } + + /** + * Sets the handler for processing transport-level errors. + * + *

+ * The provided handler will be called when errors occur during transport operations, + * such as connection failures or protocol violations. + *

+ * @param errorHandler a consumer that processes error messages + */ + public void setStdErrorHandler(Consumer errorHandler) { + stdErrorHandler = errorHandler; + } + + public Sinks.Many getErrorSink() { + return errorSink; + } + + @Override + public Mono closeGracefully() { + return session.closeGracefully(); + } + + private class StdioClientTransport implements McpClientTransport { + + private final Sinks.Many inboundSink; + + private final Sinks.Many outboundSink; + + /** The server process being communicated with */ + private Process process; + + /** Scheduler for handling inbound messages from the server process */ + private Scheduler inboundScheduler; + + /** Scheduler for handling outbound messages to the server process */ + private Scheduler outboundScheduler; + + /** Scheduler for handling error messages from the server process */ + private Scheduler errorScheduler; + + /** + * Creates a new StdioClientTransport with the specified parameters and default + * ObjectMapper. + * @param params The parameters for configuring the server process + */ + public StdioClientTransport(ServerParameters params) { + this(params, new ObjectMapper()); + } + + /** + * Creates a new StdioClientTransport with the specified parameters and + * ObjectMapper. + * @param params The parameters for configuring the server process + * @param objectMapper The ObjectMapper to use for JSON + * serialization/deserialization + */ + public StdioClientTransport(ServerParameters params, ObjectMapper objectMapper) { + Assert.notNull(params, "The params can not be null"); + Assert.notNull(objectMapper, "The ObjectMapper can not be null"); + + this.inboundSink = Sinks.many().unicast().onBackpressureBuffer(); + this.outboundSink = Sinks.many().unicast().onBackpressureBuffer(); + + // Start threads + this.inboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), "inbound"); + this.outboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), "outbound"); + this.errorScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), "error"); + } + + /** + * Starts the server process and initializes the message processing streams. This + * method sets up the process with the configured command, arguments, and + * environment, then starts the inbound, outbound, and error processing threads. + * @throws RuntimeException if the process fails to start or if the process + * streams are null + */ + @Override + public Mono connect(Function, Mono> handler) { + return connect(); + } + + @Override + public Mono connect() { + return Mono.fromRunnable(() -> { + this.inboundSink.asFlux() + .flatMap(message -> session.handle(message) + .contextWrite(ctx -> ctx.put("observation", "myObservation"))) + .subscribe(); + handleIncomingErrors(); + + // Prepare command and environment + List fullCommand = new ArrayList<>(); + fullCommand.add(params.getCommand()); + fullCommand.addAll(params.getArgs()); + + ProcessBuilder processBuilder = this.getProcessBuilder(); + processBuilder.command(fullCommand); + processBuilder.environment().putAll(params.getEnv()); + + // Start the process + try { + this.process = processBuilder.start(); + } + catch (IOException e) { + throw new RuntimeException("Failed to start process with command: " + fullCommand, e); + } + + // Validate process streams + if (this.process.getInputStream() == null || process.getOutputStream() == null) { + this.process.destroy(); + throw new RuntimeException("Process input or output stream is null"); + } + + // Start threads + startInboundProcessing(); + startOutboundProcessing(); + startErrorProcessing(); + }).subscribeOn(Schedulers.boundedElastic()); + } + + /** + * Creates and returns a new ProcessBuilder instance. Protected to allow + * overriding in tests. + * @return A new ProcessBuilder instance + */ + protected ProcessBuilder getProcessBuilder() { + return new ProcessBuilder(); + } + + /** + * Waits for the server process to exit. + * @throws RuntimeException if the process is interrupted while waiting + */ + public void awaitForExit() { + try { + this.process.waitFor(); + } + catch (InterruptedException e) { + throw new RuntimeException("Process interrupted", e); + } + } + + /** + * Starts the error processing thread that reads from the process's error stream. + * Error messages are logged and emitted to the error sink. + */ + private void startErrorProcessing() { + this.errorScheduler.schedule(() -> { + try (BufferedReader processErrorReader = new BufferedReader( + new InputStreamReader(process.getErrorStream()))) { + String line; + while (!isClosing && (line = processErrorReader.readLine()) != null) { + try { + if (!errorSink.tryEmitNext(line).isSuccess()) { + if (!isClosing) { + logger.error("Failed to emit error message"); + } + break; + } + } + catch (Exception e) { + if (!isClosing) { + logger.error("Error processing error message", e); + } + break; + } + } + } + catch (IOException e) { + if (!isClosing) { + logger.error("Error reading from error stream", e); + } + } + finally { + isClosing = true; + errorSink.tryEmitComplete(); + } + }); + } + + private void handleIncomingMessages( + Function, Mono> inboundMessageHandler) { + this.inboundSink.asFlux() + .flatMap( + message -> session.handle(message).contextWrite(ctx -> ctx.put("observation", "myObservation"))) + .subscribe(); + } + + private void handleIncomingErrors() { + errorSink.asFlux().subscribe(e -> { + stdErrorHandler.accept(e); + }); + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + if (this.outboundSink.tryEmitNext(message).isSuccess()) { + // TODO: essentially we could reschedule ourselves in some time and make + // another attempt with the already read data but pause reading until + // success + // In this approach we delegate the retry and the backpressure onto the + // caller. This might be enough for most cases. + return Mono.empty(); + } + else { + return Mono.error(new RuntimeException("Failed to enqueue message")); + } + } + + /** + * Starts the inbound processing thread that reads JSON-RPC messages from the + * process's input stream. Messages are deserialized and emitted to the inbound + * sink. + */ + private void startInboundProcessing() { + this.inboundScheduler.schedule(() -> { + try (BufferedReader processReader = new BufferedReader( + new InputStreamReader(process.getInputStream()))) { + String line; + while (!isClosing && (line = processReader.readLine()) != null) { + try { + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, line); + if (!this.inboundSink.tryEmitNext(message).isSuccess()) { + if (!isClosing) { + logger.error("Failed to enqueue inbound message: {}", message); + } + break; + } + } + catch (Exception e) { + if (!isClosing) { + logger.error("Error processing inbound message for line: " + line, e); + } + break; + } + } + } + catch (IOException e) { + if (!isClosing) { + logger.error("Error reading from input stream", e); + } + } + finally { + isClosing = true; + inboundSink.tryEmitComplete(); + } + }); + } + + /** + * Starts the outbound processing thread that writes JSON-RPC messages to the + * process's output stream. Messages are serialized to JSON and written with a + * newline delimiter. + */ + private void startOutboundProcessing() { + this.handleOutbound(messages -> messages + // this bit is important since writes come from user threads and we + // want to ensure that the actual writing happens on a dedicated thread + .publishOn(outboundScheduler) + .handle((message, s) -> { + if (message != null && !isClosing) { + try { + String jsonMessage = objectMapper.writeValueAsString(message); + // Escape any embedded newlines in the JSON message as per + // spec: + // https://spec.modelcontextprotocol.io/specification/basic/transports/#stdio + // - Messages are delimited by newlines, and MUST NOT contain + // embedded newlines. + jsonMessage = jsonMessage.replace("\r\n", "\\n").replace("\n", "\\n").replace("\r", "\\n"); + + var os = this.process.getOutputStream(); + synchronized (os) { + os.write(jsonMessage.getBytes(StandardCharsets.UTF_8)); + os.write("\n".getBytes(StandardCharsets.UTF_8)); + os.flush(); + } + s.next(message); + } + catch (IOException e) { + s.error(new RuntimeException(e)); + } + } + })); + } + + protected void handleOutbound( + Function, Flux> outboundConsumer) { + outboundConsumer.apply(outboundSink.asFlux()).doOnComplete(() -> { + isClosing = true; + outboundSink.tryEmitComplete(); + }).doOnError(e -> { + if (!isClosing) { + logger.error("Error in outbound processing", e); + isClosing = true; + outboundSink.tryEmitComplete(); + } + }).subscribe(); + } + + /** + * Gracefully closes the transport by destroying the process and disposing of the + * schedulers. This method sends a TERM signal to the process and waits for it to + * exit before cleaning up resources. + * @return A Mono that completes when the transport is closed + */ + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + isClosing = true; + logger.debug("Initiating graceful shutdown"); + }).then(Mono.defer(() -> { + // First complete all sinks to stop accepting new messages + inboundSink.tryEmitComplete(); + outboundSink.tryEmitComplete(); + errorSink.tryEmitComplete(); + + // Give a short time for any pending messages to be processed + return Mono.delay(Duration.ofMillis(100)); + })).then(Mono.defer(() -> { + logger.debug("Sending TERM to process"); + if (this.process != null) { + this.process.destroy(); + return Mono.fromFuture(process.onExit()); + } + else { + logger.warn("Process not started"); + return Mono.empty(); + } + })).doOnNext(process -> { + if (process.exitValue() != 0) { + logger.warn("Process terminated with code " + process.exitValue()); + } + }).then(Mono.fromRunnable(() -> { + try { + // The Threads are blocked on readLine so disposeGracefully would not + // interrupt them, therefore we issue an async hard dispose. + inboundScheduler.dispose(); + errorScheduler.dispose(); + outboundScheduler.dispose(); + + logger.debug("Graceful shutdown completed"); + } + catch (Exception e) { + logger.error("Error during graceful shutdown", e); + } + })).then().subscribeOn(Schedulers.boundedElastic()); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index 0895e02b..d0974582 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -17,6 +17,7 @@ import reactor.core.Disposable; import reactor.core.publisher.Mono; import reactor.core.publisher.MonoSink; +import reactor.core.scheduler.Schedulers; /** * Default implementation of the MCP (Model Context Protocol) session that manages @@ -40,6 +41,10 @@ public class McpClientSession implements McpSession { /** Logger for this class */ private static final Logger logger = LoggerFactory.getLogger(McpClientSession.class); + private String id; + + private static final String DEFAULT_SESSION_ID = "-1"; + /** Duration to wait for request responses before timing out */ private final Duration requestTimeout; @@ -97,21 +102,30 @@ public interface NotificationHandler { } + public McpClientSession(Duration requestTimeout, McpClientTransport transport, + Map> requestHandlers, Map notificationHandlers) { + + this(DEFAULT_SESSION_ID, requestTimeout, transport, requestHandlers, notificationHandlers); + } + /** * Creates a new McpClientSession with the specified configuration and handlers. + * @param id Unique identifier for the session * @param requestTimeout Duration to wait for responses * @param transport Transport implementation for message exchange * @param requestHandlers Map of method names to request handlers * @param notificationHandlers Map of method names to notification handlers */ - public McpClientSession(Duration requestTimeout, McpClientTransport transport, + public McpClientSession(String id, Duration requestTimeout, McpClientTransport transport, Map> requestHandlers, Map notificationHandlers) { + Assert.notNull(id, "The id can not be null"); Assert.notNull(requestTimeout, "The requestTimeout can not be null"); Assert.notNull(transport, "The transport can not be null"); Assert.notNull(requestHandlers, "The requestHandlers can not be null"); Assert.notNull(notificationHandlers, "The notificationHandlers can not be null"); + this.id = id; this.requestTimeout = requestTimeout; this.transport = transport; this.requestHandlers.putAll(requestHandlers); @@ -122,7 +136,11 @@ public McpClientSession(Duration requestTimeout, McpClientTransport transport, // Observation associated with the individual message - it can be used to // create child Observation and emit it together with the message to the // consumer - this.connection = this.transport.connect(mono -> mono.doOnNext(message -> { + this.connection = this.transport.connect().subscribe(); + } + + public Mono handle(McpSchema.JSONRPCMessage message) { + return Mono.defer(() -> { if (message instanceof McpSchema.JSONRPCResponse response) { logger.debug("Received Response: {}", response); var sink = pendingResponses.remove(response.id()); @@ -132,23 +150,27 @@ public McpClientSession(Duration requestTimeout, McpClientTransport transport, else { sink.success(response); } + return Mono.empty(); } else if (message instanceof McpSchema.JSONRPCRequest request) { logger.debug("Received request: {}", request); - handleIncomingRequest(request).subscribe(response -> transport.sendMessage(response).subscribe(), - error -> { - var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), - null, new McpSchema.JSONRPCResponse.JSONRPCError( - McpSchema.ErrorCodes.INTERNAL_ERROR, error.getMessage(), null)); - transport.sendMessage(errorResponse).subscribe(); - }); + return handleIncomingRequest(request).onErrorResume(error -> { + var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, + new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + error.getMessage(), null)); + return this.transport.sendMessage(errorResponse).then(Mono.empty()); + }).flatMap(this.transport::sendMessage); } else if (message instanceof McpSchema.JSONRPCNotification notification) { logger.debug("Received notification: {}", notification); - handleIncomingNotification(notification).subscribe(null, - error -> logger.error("Error handling notification: {}", error.getMessage())); + return handleIncomingNotification(notification) + .doOnError(error -> logger.error("Error handling notification: {}", error.getMessage())); } - })).subscribe(); + else { + logger.warn("Received unknown message type: {}", message); + return Mono.empty(); + } + }); } /** @@ -285,4 +307,32 @@ public void close() { transport.close(); } + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + public McpClientTransport getTransport() { + return transport; + } + + /** + * Factory for creating client sessions which delegate to a provided 1:1 transport + * with a connected client. + */ + @FunctionalInterface + public interface Factory { + + /** + * Creates a new 1:1 representation of the client-server interaction. + * @param sessionTransport the transport to use for communication with the server. + * @return a new client session. + */ + McpClientSession create(McpClientTransport sessionTransport); + + } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java index f2909124..2c6af589 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java @@ -12,9 +12,14 @@ * * @author Christian Tzolov * @author Dariusz Jędrzejczyk + * @author Jermaine Hua */ public interface McpClientTransport extends McpTransport { Mono connect(Function, Mono> handler); + default Mono connect() { + return connect(mono -> mono); + }; + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransportProvider.java new file mode 100644 index 00000000..3cfb9af1 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransportProvider.java @@ -0,0 +1,39 @@ +package io.modelcontextprotocol.spec; + +import reactor.core.publisher.Mono; + +/** + * @author Jermaine Hua + */ +public interface McpClientTransportProvider { + + /** + * Sets the session factory that will be used to create sessions for new clients. An + * implementation of the MCP server MUST call this method before any MCP interactions + * take place. + * @param sessionFactory the session factory to be used for initiating client sessions + */ + void setSessionFactory(McpClientSession.Factory sessionFactory); + + /** + * Get the active session. + * @return active client session + */ + McpClientSession getSession(); + + /** + * Immediately closes all the transports with connected clients and releases any + * associated resources. + */ + default void close() { + this.closeGracefully().subscribe(); + } + + /** + * Gracefully closes all the transports with connected clients and releases any + * associated resources asynchronously. + * @return a {@link Mono} that completes when the connections have been closed. + */ + Mono closeGracefully(); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java b/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java index 0f799ca0..536c65aa 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java @@ -4,7 +4,11 @@ package io.modelcontextprotocol.util; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.URLDecoder; import java.util.Collection; +import java.util.HashMap; import java.util.Map; import reactor.util.annotation.Nullable; @@ -52,4 +56,27 @@ public static boolean isEmpty(@Nullable Map map) { return (map == null || map.isEmpty()); } + public static String getSessionIdFromUrl(String urlStr) { + URI uri; + try { + uri = new URI(urlStr); + } + catch (URISyntaxException e) { + return null; + } + String query = uri.getQuery(); + if (query == null) { + return null; + } + Map params = new HashMap<>(); + String[] pairs = query.split("&"); + for (String pair : pairs) { + int idx = pair.indexOf("="); + String key = (idx > 0) ? pair.substring(0, idx) : pair; + String value = (idx > 0 && pair.length() > idx + 1) ? pair.substring(idx + 1) : null; + params.put(key, value); + } + return params.get("sessionId"); + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/MockMcpClientTransportProvider.java b/mcp/src/test/java/io/modelcontextprotocol/MockMcpClientTransportProvider.java new file mode 100644 index 00000000..bf15283d --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/MockMcpClientTransportProvider.java @@ -0,0 +1,140 @@ +package io.modelcontextprotocol; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpClientSession; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpClientTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.BiConsumer; +import java.util.function.Function; + +/** + * @author Jermaine Hua + */ +public class MockMcpClientTransportProvider implements McpClientTransportProvider { + + private McpClientSession.Factory sessionFactory; + + private McpClientSession session; + + private final BiConsumer interceptor; + + private MockMcpClientTransport transport; + + public MockMcpClientTransportProvider() { + this((t, msg) -> { + }); + } + + public MockMcpClientTransportProvider(BiConsumer interceptor) { + this.transport = new MockMcpClientTransport(); + this.interceptor = interceptor; + } + + public MockMcpClientTransport getTransport() { + return transport; + } + + @Override + public void setSessionFactory(McpClientSession.Factory sessionFactory) { + this.sessionFactory = sessionFactory; + } + + @Override + public McpClientSession getSession() { + if (session != null) { + return session; + } + session = sessionFactory.create(transport); + return session; + } + + @Override + public Mono closeGracefully() { + return session.closeGracefully(); + } + + public class MockMcpClientTransport implements McpClientTransport { + + private final Sinks.Many inbound = Sinks.many().unicast().onBackpressureBuffer(); + + private final List sent = new ArrayList<>(); + + public MockMcpClientTransport() { + } + + public void simulateIncomingMessage(McpSchema.JSONRPCMessage message) { + if (inbound.tryEmitNext(message).isFailure()) { + throw new RuntimeException("Failed to process incoming message " + message); + } + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + sent.add(message); + interceptor.accept(this, message); + return Mono.empty(); + } + + public McpSchema.JSONRPCRequest getLastSentMessageAsRequest() { + return (McpSchema.JSONRPCRequest) getLastSentMessage(); + } + + public McpSchema.JSONRPCNotification getLastSentMessageAsNotification() { + return (McpSchema.JSONRPCNotification) getLastSentMessage(); + } + + public McpSchema.JSONRPCMessage getLastSentMessage() { + return !sent.isEmpty() ? sent.get(sent.size() - 1) : null; + } + + private volatile boolean connected = false; + + @Override + public Mono connect(Function, Mono> handler) { + if (connected) { + return Mono.error(new IllegalStateException("Already connected")); + } + connected = true; + return inbound.asFlux() + .flatMap(message -> session.handle(message)) + .doFinally(signal -> connected = false) + .then(); + } + + @Override + public Mono connect() { + if (connected) { + return Mono.error(new IllegalStateException("Already connected")); + } + connected = true; + return inbound.asFlux() + .flatMap(message -> session.handle(message)) + .doFinally(signal -> connected = false) + .then(); + } + + @Override + public Mono closeGracefully() { + return Mono.defer(() -> { + connected = false; + inbound.tryEmitComplete(); + // Wait for all subscribers to complete + return Mono.empty(); + }); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return new ObjectMapper().convertValue(data, typeRef); + } + + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 72b409af..2a63ae9f 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -4,15 +4,7 @@ package io.modelcontextprotocol.client; -import java.time.Duration; -import java.util.Map; -import java.util.Objects; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Consumer; -import java.util.function.Function; - -import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpClientTransportProvider; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; @@ -35,6 +27,14 @@ import reactor.core.publisher.Mono; import reactor.test.StepVerifier; +import java.time.Duration; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; + import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -51,7 +51,7 @@ public abstract class AbstractMcpAsyncClientTests { private static final String ECHO_TEST_MESSAGE = "Hello MCP Spring AI!"; - abstract protected McpClientTransport createMcpTransport(); + abstract protected McpClientTransportProvider createMcpClientTransportProvider(); protected void onStart() { } @@ -67,15 +67,16 @@ protected Duration getInitializationTimeout() { return Duration.ofSeconds(2); } - McpAsyncClient client(McpClientTransport transport) { - return client(transport, Function.identity()); + McpAsyncClient client(McpClientTransportProvider transportProvider) { + return client(transportProvider, Function.identity()); } - McpAsyncClient client(McpClientTransport transport, Function customizer) { + McpAsyncClient client(McpClientTransportProvider transportProvider, + Function customizer) { AtomicReference client = new AtomicReference<>(); assertThatCode(() -> { - McpClient.AsyncSpec builder = McpClient.async(transport) + McpClient.AsyncSpec builder = McpClient.async(transportProvider) .requestTimeout(getRequestTimeout()) .initializationTimeout(getInitializationTimeout()) .capabilities(ClientCapabilities.builder().roots(true).build()); @@ -86,13 +87,13 @@ McpAsyncClient client(McpClientTransport transport, Function c) { - withClient(transport, Function.identity(), c); + void withClient(McpClientTransportProvider transportProvider, Consumer c) { + withClient(transportProvider, Function.identity(), c); } - void withClient(McpClientTransport transport, Function customizer, - Consumer c) { - var client = client(transport, customizer); + void withClient(McpClientTransportProvider transportProvider, + Function customizer, Consumer c) { + var client = client(transportProvider, customizer); try { c.accept(client); } @@ -112,7 +113,7 @@ void tearDown() { } void verifyInitializationTimeout(Function> operation, String action) { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { StepVerifier.withVirtualTime(() -> operation.apply(mcpAsyncClient)) .expectSubscription() .thenAwait(getInitializationTimeout()) @@ -127,7 +128,7 @@ void testConstructorWithInvalidArguments() { assertThatThrownBy(() -> McpClient.async(null).build()).isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport must not be null"); - assertThatThrownBy(() -> McpClient.async(createMcpTransport()).requestTimeout(null).build()) + assertThatThrownBy(() -> McpClient.async(createMcpClientTransportProvider()).requestTimeout(null).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Request timeout must not be null"); } @@ -139,7 +140,7 @@ void testListToolsWithoutInitialization() { @Test void testListTools() { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listTools(null))) .consumeNextWith(result -> { assertThat(result.tools()).isNotNull().isNotEmpty(); @@ -159,7 +160,7 @@ void testPingWithoutInitialization() { @Test void testPing() { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.ping())) .expectNextCount(1) .verifyComplete(); @@ -174,7 +175,7 @@ void testCallToolWithoutInitialization() { @Test void testCallTool() { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(callToolRequest))) @@ -190,7 +191,7 @@ void testCallTool() { @Test void testCallToolWithInvalidTool() { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", ECHO_TEST_MESSAGE)); @@ -208,7 +209,7 @@ void testListResourcesWithoutInitialization() { @Test void testListResources() { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResources(null))) .consumeNextWith(resources -> { assertThat(resources).isNotNull().satisfies(result -> { @@ -227,7 +228,7 @@ void testListResources() { @Test void testMcpAsyncClientState() { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { assertThat(mcpAsyncClient).isNotNull(); }); } @@ -239,7 +240,7 @@ void testListPromptsWithoutInitialization() { @Test void testListPrompts() { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listPrompts(null))) .consumeNextWith(prompts -> { assertThat(prompts).isNotNull().satisfies(result -> { @@ -264,7 +265,7 @@ void testGetPromptWithoutInitialization() { @Test void testGetPrompt() { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { StepVerifier .create(mcpAsyncClient.initialize() .then(mcpAsyncClient.getPrompt(new GetPromptRequest("simple_prompt", Map.of())))) @@ -286,7 +287,7 @@ void testRootsListChangedWithoutInitialization() { @Test void testRootsListChanged() { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.rootsListChangedNotification())) .verifyComplete(); }); @@ -294,15 +295,15 @@ void testRootsListChanged() { @Test void testInitializeWithRootsListProviders() { - withClient(createMcpTransport(), builder -> builder.roots(new Root("file:///test/path", "test-root")), - client -> { + withClient(createMcpClientTransportProvider(), + builder -> builder.roots(new Root("file:///test/path", "test-root")), client -> { StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); }); } @Test void testAddRoot() { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { Root newRoot = new Root("file:///new/test/path", "new-test-root"); StepVerifier.create(mcpAsyncClient.addRoot(newRoot)).verifyComplete(); }); @@ -310,7 +311,7 @@ void testAddRoot() { @Test void testAddRootWithNullValue() { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { StepVerifier.create(mcpAsyncClient.addRoot(null)) .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class).hasMessage("Root must not be null")) .verify(); @@ -319,7 +320,7 @@ void testAddRootWithNullValue() { @Test void testRemoveRoot() { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { Root root = new Root("file:///test/path/to/remove", "root-to-remove"); StepVerifier.create(mcpAsyncClient.addRoot(root)).verifyComplete(); @@ -329,7 +330,7 @@ void testRemoveRoot() { @Test void testRemoveNonExistentRoot() { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { StepVerifier.create(mcpAsyncClient.removeRoot("nonexistent-uri")) .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) .hasMessage("Root with uri 'nonexistent-uri' not found")) @@ -340,7 +341,7 @@ void testRemoveNonExistentRoot() { @Test @Disabled void testReadResource() { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { if (!resources.resources().isEmpty()) { Resource firstResource = resources.resources().get(0); @@ -360,7 +361,7 @@ void testListResourceTemplatesWithoutInitialization() { @Test void testListResourceTemplates() { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResourceTemplates())) .consumeNextWith(result -> { assertThat(result).isNotNull(); @@ -372,7 +373,7 @@ void testListResourceTemplates() { // @Test void testResourceSubscription() { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { if (!resources.resources().isEmpty()) { Resource firstResource = resources.resources().get(0); @@ -395,7 +396,7 @@ void testNotificationHandlers() { AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false); AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false); - withClient(createMcpTransport(), + withClient(createMcpClientTransportProvider(), builder -> builder .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true))) .resourcesChangeConsumer( @@ -415,7 +416,7 @@ void testInitializeWithSamplingCapability() { .message("test") .model("test-model") .build(); - withClient(createMcpTransport(), + withClient(createMcpClientTransportProvider(), builder -> builder.capabilities(capabilities).sampling(request -> Mono.just(createMessageResult)), client -> { StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); @@ -433,8 +434,8 @@ void testInitializeWithAllCapabilities() { Function> samplingHandler = request -> Mono .just(CreateMessageResult.builder().message("test").model("test-model").build()); - withClient(createMcpTransport(), builder -> builder.capabilities(capabilities).sampling(samplingHandler), - client -> + withClient(createMcpClientTransportProvider(), + builder -> builder.capabilities(capabilities).sampling(samplingHandler), client -> StepVerifier.create(client.initialize()).assertNext(result -> { assertThat(result).isNotNull(); @@ -454,7 +455,7 @@ void testLoggingLevelsWithoutInitialization() { @Test void testLoggingLevels() { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { StepVerifier .create(mcpAsyncClient.initialize() .thenMany(Flux.fromArray(McpSchema.LoggingLevel.values()).flatMap(mcpAsyncClient::setLoggingLevel))) @@ -466,7 +467,7 @@ void testLoggingLevels() { void testLoggingConsumer() { AtomicBoolean logReceived = new AtomicBoolean(false); - withClient(createMcpTransport(), + withClient(createMcpClientTransportProvider(), builder -> builder.loggingConsumer(notification -> Mono.fromRunnable(() -> logReceived.set(true))), client -> { StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); @@ -478,7 +479,7 @@ void testLoggingConsumer() { @Test void testLoggingWithNullNotification() { - withClient(createMcpTransport(), mcpAsyncClient -> { + withClient(createMcpClientTransportProvider(), mcpAsyncClient -> { StepVerifier.create(mcpAsyncClient.setLoggingLevel(null)) .expectErrorMatches(error -> error.getMessage().contains("Logging level must not be null")) .verify(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index 24c161eb..6af2e2cd 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -11,7 +11,7 @@ import java.util.function.Consumer; import java.util.function.Function; -import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpClientTransportProvider; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; @@ -50,7 +50,7 @@ public abstract class AbstractMcpSyncClientTests { private static final String TEST_MESSAGE = "Hello MCP Spring AI!"; - abstract protected McpClientTransport createMcpTransport(); + abstract protected McpClientTransportProvider createMcpClientTransportProvider(); protected void onStart() { } @@ -66,15 +66,16 @@ protected Duration getInitializationTimeout() { return Duration.ofSeconds(2); } - McpSyncClient client(McpClientTransport transport) { - return client(transport, Function.identity()); + McpSyncClient client(McpClientTransportProvider transportProvider) { + return client(transportProvider, Function.identity()); } - McpSyncClient client(McpClientTransport transport, Function customizer) { + McpSyncClient client(McpClientTransportProvider transportProvider, + Function customizer) { AtomicReference client = new AtomicReference<>(); assertThatCode(() -> { - McpClient.SyncSpec builder = McpClient.sync(transport) + McpClient.SyncSpec builder = McpClient.sync(transportProvider) .requestTimeout(getRequestTimeout()) .initializationTimeout(getInitializationTimeout()) .capabilities(ClientCapabilities.builder().roots(true).build()); @@ -85,13 +86,13 @@ McpSyncClient client(McpClientTransport transport, Function c) { - withClient(transport, Function.identity(), c); + void withClient(McpClientTransportProvider transportProvider, Consumer c) { + withClient(transportProvider, Function.identity(), c); } - void withClient(McpClientTransport transport, Function customizer, - Consumer c) { - var client = client(transport, customizer); + void withClient(McpClientTransportProvider transportProvider, + Function customizer, Consumer c) { + var client = client(transportProvider, customizer); try { c.accept(client); } @@ -121,7 +122,7 @@ void verifyNotificationTimesOut(Consumer operation, String ac } void verifyCallTimesOut(Function blockingOperation, String action) { - withClient(createMcpTransport(), mcpSyncClient -> { + withClient(createMcpClientTransportProvider(), mcpSyncClient -> { // This scheduler is not replaced by virtual time scheduler Scheduler customScheduler = Schedulers.newBoundedElastic(1, 1, "actualBoundedElastic"); @@ -148,7 +149,7 @@ void testConstructorWithInvalidArguments() { assertThatThrownBy(() -> McpClient.sync(null).build()).isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport must not be null"); - assertThatThrownBy(() -> McpClient.sync(createMcpTransport()).requestTimeout(null).build()) + assertThatThrownBy(() -> McpClient.sync(createMcpClientTransportProvider()).requestTimeout(null).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Request timeout must not be null"); } @@ -160,7 +161,7 @@ void testListToolsWithoutInitialization() { @Test void testListTools() { - withClient(createMcpTransport(), mcpSyncClient -> { + withClient(createMcpClientTransportProvider(), mcpSyncClient -> { mcpSyncClient.initialize(); ListToolsResult tools = mcpSyncClient.listTools(null); @@ -182,7 +183,7 @@ void testCallToolsWithoutInitialization() { @Test void testCallTools() { - withClient(createMcpTransport(), mcpSyncClient -> { + withClient(createMcpClientTransportProvider(), mcpSyncClient -> { mcpSyncClient.initialize(); CallToolResult toolResult = mcpSyncClient.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))); @@ -206,7 +207,7 @@ void testPingWithoutInitialization() { @Test void testPing() { - withClient(createMcpTransport(), mcpSyncClient -> { + withClient(createMcpClientTransportProvider(), mcpSyncClient -> { mcpSyncClient.initialize(); assertThatCode(() -> mcpSyncClient.ping()).doesNotThrowAnyException(); }); @@ -220,7 +221,7 @@ void testCallToolWithoutInitialization() { @Test void testCallTool() { - withClient(createMcpTransport(), mcpSyncClient -> { + withClient(createMcpClientTransportProvider(), mcpSyncClient -> { mcpSyncClient.initialize(); CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); @@ -235,7 +236,7 @@ void testCallTool() { @Test void testCallToolWithInvalidTool() { - withClient(createMcpTransport(), mcpSyncClient -> { + withClient(createMcpClientTransportProvider(), mcpSyncClient -> { CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", TEST_MESSAGE)); assertThatThrownBy(() -> mcpSyncClient.callTool(invalidRequest)).isInstanceOf(Exception.class); @@ -250,7 +251,7 @@ void testRootsListChangedWithoutInitialization() { @Test void testRootsListChanged() { - withClient(createMcpTransport(), mcpSyncClient -> { + withClient(createMcpClientTransportProvider(), mcpSyncClient -> { mcpSyncClient.initialize(); assertThatCode(() -> mcpSyncClient.rootsListChangedNotification()).doesNotThrowAnyException(); }); @@ -263,7 +264,7 @@ void testListResourcesWithoutInitialization() { @Test void testListResources() { - withClient(createMcpTransport(), mcpSyncClient -> { + withClient(createMcpClientTransportProvider(), mcpSyncClient -> { mcpSyncClient.initialize(); ListResourcesResult resources = mcpSyncClient.listResources(null); @@ -281,15 +282,15 @@ void testListResources() { @Test void testClientSessionState() { - withClient(createMcpTransport(), mcpSyncClient -> { + withClient(createMcpClientTransportProvider(), mcpSyncClient -> { assertThat(mcpSyncClient).isNotNull(); }); } @Test void testInitializeWithRootsListProviders() { - withClient(createMcpTransport(), builder -> builder.roots(new Root("file:///test/path", "test-root")), - mcpSyncClient -> { + withClient(createMcpClientTransportProvider(), + builder -> builder.roots(new Root("file:///test/path", "test-root")), mcpSyncClient -> { assertThatCode(() -> { mcpSyncClient.initialize(); @@ -300,7 +301,7 @@ void testInitializeWithRootsListProviders() { @Test void testAddRoot() { - withClient(createMcpTransport(), mcpSyncClient -> { + withClient(createMcpClientTransportProvider(), mcpSyncClient -> { Root newRoot = new Root("file:///new/test/path", "new-test-root"); assertThatCode(() -> mcpSyncClient.addRoot(newRoot)).doesNotThrowAnyException(); }); @@ -308,14 +309,14 @@ void testAddRoot() { @Test void testAddRootWithNullValue() { - withClient(createMcpTransport(), mcpSyncClient -> { + withClient(createMcpClientTransportProvider(), mcpSyncClient -> { assertThatThrownBy(() -> mcpSyncClient.addRoot(null)).hasMessageContaining("Root must not be null"); }); } @Test void testRemoveRoot() { - withClient(createMcpTransport(), mcpSyncClient -> { + withClient(createMcpClientTransportProvider(), mcpSyncClient -> { Root root = new Root("file:///test/path/to/remove", "root-to-remove"); assertThatCode(() -> { mcpSyncClient.addRoot(root); @@ -326,7 +327,7 @@ void testRemoveRoot() { @Test void testRemoveNonExistentRoot() { - withClient(createMcpTransport(), mcpSyncClient -> { + withClient(createMcpClientTransportProvider(), mcpSyncClient -> { assertThatThrownBy(() -> mcpSyncClient.removeRoot("nonexistent-uri")) .hasMessageContaining("Root with uri 'nonexistent-uri' not found"); }); @@ -340,7 +341,7 @@ void testReadResourceWithoutInitialization() { @Test void testReadResource() { - withClient(createMcpTransport(), mcpSyncClient -> { + withClient(createMcpClientTransportProvider(), mcpSyncClient -> { mcpSyncClient.initialize(); ListResourcesResult resources = mcpSyncClient.listResources(null); @@ -361,7 +362,7 @@ void testListResourceTemplatesWithoutInitialization() { @Test void testListResourceTemplates() { - withClient(createMcpTransport(), mcpSyncClient -> { + withClient(createMcpClientTransportProvider(), mcpSyncClient -> { mcpSyncClient.initialize(); ListResourceTemplatesResult result = mcpSyncClient.listResourceTemplates(null); @@ -372,7 +373,7 @@ void testListResourceTemplates() { // @Test void testResourceSubscription() { - withClient(createMcpTransport(), mcpSyncClient -> { + withClient(createMcpClientTransportProvider(), mcpSyncClient -> { ListResourcesResult resources = mcpSyncClient.listResources(null); if (!resources.resources().isEmpty()) { @@ -395,7 +396,7 @@ void testNotificationHandlers() { AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false); AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false); - withClient(createMcpTransport(), + withClient(createMcpClientTransportProvider(), builder -> builder.toolsChangeConsumer(tools -> toolsNotificationReceived.set(true)) .resourcesChangeConsumer(resources -> resourcesNotificationReceived.set(true)) .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)), @@ -420,7 +421,7 @@ void testLoggingLevelsWithoutInitialization() { @Test void testLoggingLevels() { - withClient(createMcpTransport(), mcpSyncClient -> { + withClient(createMcpClientTransportProvider(), mcpSyncClient -> { mcpSyncClient.initialize(); // Test all logging levels for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { @@ -432,7 +433,7 @@ void testLoggingLevels() { @Test void testLoggingConsumer() { AtomicBoolean logReceived = new AtomicBoolean(false); - withClient(createMcpTransport(), builder -> builder.requestTimeout(getRequestTimeout()) + withClient(createMcpClientTransportProvider(), builder -> builder.requestTimeout(getRequestTimeout()) .loggingConsumer(notification -> logReceived.set(true)), client -> { assertThatCode(() -> { client.initialize(); @@ -443,8 +444,9 @@ void testLoggingConsumer() { @Test void testLoggingWithNullNotification() { - withClient(createMcpTransport(), mcpSyncClient -> assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(null)) - .hasMessageContaining("Logging level must not be null")); + withClient(createMcpClientTransportProvider(), + mcpSyncClient -> assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(null)) + .hasMessageContaining("Logging level must not be null")); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java index fdff4b77..6a3162b3 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java @@ -4,14 +4,14 @@ package io.modelcontextprotocol.client; -import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransportProvider; +import io.modelcontextprotocol.spec.McpClientTransportProvider; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; /** - * Tests for the {@link McpSyncClient} with {@link HttpClientSseClientTransport}. + * Tests for the {@link McpSyncClient} with * * @author Christian Tzolov */ @@ -28,8 +28,8 @@ class HttpSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { .waitingFor(Wait.forHttp("/").forStatusCode(404)); @Override - protected McpClientTransport createMcpTransport() { - return HttpClientSseClientTransport.builder(host).build(); + protected McpClientTransportProvider createMcpClientTransportProvider() { + return HttpClientSseClientTransportProvider.builder(host).build(); } @Override diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java index 204cf298..de68b901 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java @@ -4,15 +4,13 @@ package io.modelcontextprotocol.client; -import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransportProvider; +import io.modelcontextprotocol.spec.McpClientTransportProvider; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; /** - * Tests for the {@link McpSyncClient} with {@link HttpClientSseClientTransport}. - * * @author Christian Tzolov */ @Timeout(15) // Giving extra time beyond the client timeout @@ -28,8 +26,8 @@ class HttpSseMcpSyncClientTests extends AbstractMcpSyncClientTests { .waitingFor(Wait.forHttp("/").forStatusCode(404)); @Override - protected McpClientTransport createMcpTransport() { - return HttpClientSseClientTransport.builder(host).build(); + protected McpClientTransportProvider createMcpClientTransportProvider() { + return HttpClientSseClientTransportProvider.builder(host).build(); } @Override diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java index 4510b152..3e666128 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java @@ -12,7 +12,7 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.MockMcpClientTransport; +import io.modelcontextprotocol.MockMcpClientTransportProvider; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; @@ -34,16 +34,16 @@ class McpAsyncClientResponseHandlerTests { .resources(true, true) // Enable both resources and resource templates .build(); - private static MockMcpClientTransport initializationEnabledTransport() { + private static MockMcpClientTransportProvider initializationEnabledTransport() { return initializationEnabledTransport(SERVER_CAPABILITIES, SERVER_INFO); } - private static MockMcpClientTransport initializationEnabledTransport( + private static MockMcpClientTransportProvider initializationEnabledTransport( McpSchema.ServerCapabilities mockServerCapabilities, McpSchema.Implementation mockServerInfo) { McpSchema.InitializeResult mockInitResult = new McpSchema.InitializeResult(McpSchema.LATEST_PROTOCOL_VERSION, mockServerCapabilities, mockServerInfo, "Test instructions"); - return new MockMcpClientTransport((t, message) -> { + return new MockMcpClientTransportProvider((t, message) -> { if (message instanceof McpSchema.JSONRPCRequest r && METHOD_INITIALIZE.equals(r.method())) { McpSchema.JSONRPCResponse initResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, r.id(), mockInitResult, null); @@ -59,8 +59,10 @@ void testSuccessfulInitialization() { .tools(false) .resources(true, true) // Enable both resources and resource templates .build(); - MockMcpClientTransport transport = initializationEnabledTransport(serverCapabilities, serverInfo); - McpAsyncClient asyncMcpClient = McpClient.async(transport).build(); + MockMcpClientTransportProvider transportProvider = initializationEnabledTransport(serverCapabilities, + serverInfo); + MockMcpClientTransportProvider.MockMcpClientTransport transport = transportProvider.getTransport(); + McpAsyncClient asyncMcpClient = McpClient.async(transportProvider).build(); // Verify client is not initialized initially assertThat(asyncMcpClient.isInitialized()).isFalse(); @@ -91,8 +93,8 @@ void testSuccessfulInitialization() { @Test void testToolsChangeNotificationHandling() throws JsonProcessingException { - MockMcpClientTransport transport = initializationEnabledTransport(); - + MockMcpClientTransportProvider transportProvider = initializationEnabledTransport(); + MockMcpClientTransportProvider.MockMcpClientTransport transport = transportProvider.getTransport(); // Create a list to store received tools for verification List receivedTools = new ArrayList<>(); @@ -101,7 +103,9 @@ void testToolsChangeNotificationHandling() throws JsonProcessingException { .fromRunnable(() -> receivedTools.addAll(tools)); // Create client with tools change consumer - McpAsyncClient asyncMcpClient = McpClient.async(transport).toolsChangeConsumer(toolsChangeConsumer).build(); + McpAsyncClient asyncMcpClient = McpClient.async(transportProvider) + .toolsChangeConsumer(toolsChangeConsumer) + .build(); assertThat(asyncMcpClient.initialize().block()).isNotNull(); @@ -134,9 +138,10 @@ void testToolsChangeNotificationHandling() throws JsonProcessingException { @Test void testRootsListRequestHandling() { - MockMcpClientTransport transport = initializationEnabledTransport(); + MockMcpClientTransportProvider transportProvider = initializationEnabledTransport(); + MockMcpClientTransportProvider.MockMcpClientTransport transport = transportProvider.getTransport(); - McpAsyncClient asyncMcpClient = McpClient.async(transport) + McpAsyncClient asyncMcpClient = McpClient.async(transportProvider) .roots(new Root("file:///test/path", "test-root")) .build(); @@ -162,7 +167,8 @@ void testRootsListRequestHandling() { @Test void testResourcesChangeNotificationHandling() { - MockMcpClientTransport transport = initializationEnabledTransport(); + MockMcpClientTransportProvider transportProvider = initializationEnabledTransport(); + MockMcpClientTransportProvider.MockMcpClientTransport transport = transportProvider.getTransport(); // Create a list to store received resources for verification List receivedResources = new ArrayList<>(); @@ -172,7 +178,7 @@ void testResourcesChangeNotificationHandling() { .fromRunnable(() -> receivedResources.addAll(resources)); // Create client with resources change consumer - McpAsyncClient asyncMcpClient = McpClient.async(transport) + McpAsyncClient asyncMcpClient = McpClient.async(transportProvider) .resourcesChangeConsumer(resourcesChangeConsumer) .build(); @@ -208,7 +214,8 @@ void testResourcesChangeNotificationHandling() { @Test void testPromptsChangeNotificationHandling() { - MockMcpClientTransport transport = initializationEnabledTransport(); + MockMcpClientTransportProvider transportProvider = initializationEnabledTransport(); + MockMcpClientTransportProvider.MockMcpClientTransport transport = transportProvider.getTransport(); // Create a list to store received prompts for verification List receivedPrompts = new ArrayList<>(); @@ -218,7 +225,9 @@ void testPromptsChangeNotificationHandling() { .fromRunnable(() -> receivedPrompts.addAll(prompts)); // Create client with prompts change consumer - McpAsyncClient asyncMcpClient = McpClient.async(transport).promptsChangeConsumer(promptsChangeConsumer).build(); + McpAsyncClient asyncMcpClient = McpClient.async(transportProvider) + .promptsChangeConsumer(promptsChangeConsumer) + .build(); assertThat(asyncMcpClient.initialize().block()).isNotNull(); @@ -252,7 +261,8 @@ void testPromptsChangeNotificationHandling() { @Test void testSamplingCreateMessageRequestHandling() { - MockMcpClientTransport transport = initializationEnabledTransport(); + MockMcpClientTransportProvider transportProvider = initializationEnabledTransport(); + MockMcpClientTransportProvider.MockMcpClientTransport transport = transportProvider.getTransport(); // Create a test sampling handler that echoes back the input Function> samplingHandler = request -> { @@ -262,7 +272,7 @@ void testSamplingCreateMessageRequestHandling() { }; // Create client with sampling capability and handler - McpAsyncClient asyncMcpClient = McpClient.async(transport) + McpAsyncClient asyncMcpClient = McpClient.async(transportProvider) .capabilities(ClientCapabilities.builder().sampling().build()) .sampling(samplingHandler) .build(); @@ -306,10 +316,11 @@ void testSamplingCreateMessageRequestHandling() { @Test void testSamplingCreateMessageRequestHandlingWithoutCapability() { - MockMcpClientTransport transport = initializationEnabledTransport(); + MockMcpClientTransportProvider transportProvider = initializationEnabledTransport(); + MockMcpClientTransportProvider.MockMcpClientTransport transport = transportProvider.getTransport(); // Create client without sampling capability - McpAsyncClient asyncMcpClient = McpClient.async(transport) + McpAsyncClient asyncMcpClient = McpClient.async(transportProvider) .capabilities(ClientCapabilities.builder().build()) // No sampling capability .build(); @@ -340,12 +351,13 @@ void testSamplingCreateMessageRequestHandlingWithoutCapability() { @Test void testSamplingCreateMessageRequestHandlingWithNullHandler() { - MockMcpClientTransport transport = new MockMcpClientTransport(); + MockMcpClientTransportProvider transportProvider = initializationEnabledTransport(); + MockMcpClientTransportProvider.MockMcpClientTransport transport = transportProvider.getTransport(); // Create client with sampling capability but null handler - assertThatThrownBy( - () -> McpClient.async(transport).capabilities(ClientCapabilities.builder().sampling().build()).build()) - .isInstanceOf(McpError.class) + assertThatThrownBy(() -> McpClient.async(transportProvider) + .capabilities(ClientCapabilities.builder().sampling().build()) + .build()).isInstanceOf(McpError.class) .hasMessage("Sampling handler must not be null when client capabilities include sampling"); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java index bf473849..3b915161 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java @@ -8,6 +8,7 @@ import java.util.List; import io.modelcontextprotocol.MockMcpClientTransport; +import io.modelcontextprotocol.MockMcpClientTransportProvider; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.InitializeResult; @@ -28,8 +29,9 @@ class McpClientProtocolVersionTests { @Test void shouldUseLatestVersionByDefault() { - MockMcpClientTransport transport = new MockMcpClientTransport(); - McpAsyncClient client = McpClient.async(transport) + MockMcpClientTransportProvider transportProvider = new MockMcpClientTransportProvider(); + MockMcpClientTransportProvider.MockMcpClientTransport transport = transportProvider.getTransport(); + McpAsyncClient client = McpClient.async(transportProvider) .clientInfo(CLIENT_INFO) .requestTimeout(REQUEST_TIMEOUT) .build(); @@ -61,8 +63,9 @@ void shouldUseLatestVersionByDefault() { @Test void shouldNegotiateSpecificVersion() { String oldVersion = "0.1.0"; - MockMcpClientTransport transport = new MockMcpClientTransport(); - McpAsyncClient client = McpClient.async(transport) + MockMcpClientTransportProvider transportProvider = new MockMcpClientTransportProvider(); + MockMcpClientTransportProvider.MockMcpClientTransport transport = transportProvider.getTransport(); + McpAsyncClient client = McpClient.async(transportProvider) .clientInfo(CLIENT_INFO) .requestTimeout(REQUEST_TIMEOUT) .build(); @@ -94,8 +97,9 @@ void shouldNegotiateSpecificVersion() { @Test void shouldFailForUnsupportedVersion() { String unsupportedVersion = "999.999.999"; - MockMcpClientTransport transport = new MockMcpClientTransport(); - McpAsyncClient client = McpClient.async(transport) + MockMcpClientTransportProvider transportProvider = new MockMcpClientTransportProvider(); + MockMcpClientTransportProvider.MockMcpClientTransport transport = transportProvider.getTransport(); + McpAsyncClient client = McpClient.async(transportProvider) .clientInfo(CLIENT_INFO) .requestTimeout(REQUEST_TIMEOUT) .build(); @@ -124,8 +128,9 @@ void shouldUseHighestVersionWhenMultipleSupported() { String middleVersion = "0.2.0"; String latestVersion = McpSchema.LATEST_PROTOCOL_VERSION; - MockMcpClientTransport transport = new MockMcpClientTransport(); - McpAsyncClient client = McpClient.async(transport) + MockMcpClientTransportProvider transportProvider = new MockMcpClientTransportProvider(); + MockMcpClientTransportProvider.MockMcpClientTransport transport = transportProvider.getTransport(); + McpAsyncClient client = McpClient.async(transportProvider) .clientInfo(CLIENT_INFO) .requestTimeout(REQUEST_TIMEOUT) .build(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java index c3908013..4f510b01 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java @@ -4,16 +4,14 @@ package io.modelcontextprotocol.client; -import java.time.Duration; - import io.modelcontextprotocol.client.transport.ServerParameters; -import io.modelcontextprotocol.client.transport.StdioClientTransport; -import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.client.transport.StdioClientTransportProvider; +import io.modelcontextprotocol.spec.McpClientTransportProvider; import org.junit.jupiter.api.Timeout; +import java.time.Duration; + /** - * Tests for the {@link McpAsyncClient} with {@link StdioClientTransport}. - * * @author Christian Tzolov * @author Dariusz Jędrzejczyk */ @@ -21,7 +19,7 @@ class StdioMcpAsyncClientTests extends AbstractMcpAsyncClientTests { @Override - protected McpClientTransport createMcpTransport() { + protected McpClientTransportProvider createMcpClientTransportProvider() { ServerParameters stdioParams; if (System.getProperty("os.name").toLowerCase().contains("win")) { stdioParams = ServerParameters.builder("cmd.exe") @@ -33,7 +31,7 @@ protected McpClientTransport createMcpTransport() { .args("-y", "@modelcontextprotocol/server-everything", "dir") .build(); } - return new StdioClientTransport(stdioParams); + return new StdioClientTransportProvider(stdioParams); } protected Duration getInitializationTimeout() { diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java index 8e75c4a3..bfef4e53 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java @@ -5,13 +5,16 @@ package io.modelcontextprotocol.client; import java.time.Duration; +import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import io.modelcontextprotocol.client.transport.ServerParameters; -import io.modelcontextprotocol.client.transport.StdioClientTransport; +import io.modelcontextprotocol.client.transport.StdioClientTransportProvider; +import io.modelcontextprotocol.spec.McpClientSession; import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpClientTransportProvider; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; import reactor.core.publisher.Sinks; @@ -20,8 +23,6 @@ import static org.assertj.core.api.Assertions.assertThat; /** - * Tests for the {@link McpSyncClient} with {@link StdioClientTransport}. - * * @author Christian Tzolov * @author Dariusz Jędrzejczyk */ @@ -29,7 +30,7 @@ class StdioMcpSyncClientTests extends AbstractMcpSyncClientTests { @Override - protected McpClientTransport createMcpTransport() { + protected McpClientTransportProvider createMcpClientTransportProvider() { ServerParameters stdioParams; if (System.getProperty("os.name").toLowerCase().contains("win")) { stdioParams = ServerParameters.builder("cmd.exe") @@ -41,7 +42,7 @@ protected McpClientTransport createMcpTransport() { .args("-y", "@modelcontextprotocol/server-everything", "dir") .build(); } - return new StdioClientTransport(stdioParams); + return new StdioClientTransportProvider(stdioParams); } @Test @@ -49,16 +50,20 @@ void customErrorHandlerShouldReceiveErrors() throws InterruptedException { CountDownLatch latch = new CountDownLatch(1); AtomicReference receivedError = new AtomicReference<>(); - McpClientTransport transport = createMcpTransport(); + McpClientTransportProvider transportProvider = createMcpClientTransportProvider(); + transportProvider.setSessionFactory( + (transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of())); + McpClientTransport transport = transportProvider.getSession().getTransport(); StepVerifier.create(transport.connect(msg -> msg)).verifyComplete(); - ((StdioClientTransport) transport).setStdErrorHandler(error -> { + ((StdioClientTransportProvider) transportProvider).setStdErrorHandler(error -> { receivedError.set(error); latch.countDown(); }); String errorMessage = "Test error"; - ((StdioClientTransport) transport).getErrorSink().emitNext(errorMessage, Sinks.EmitFailureHandler.FAIL_FAST); + ((StdioClientTransportProvider) transportProvider).getErrorSink() + .emitNext(errorMessage, Sinks.EmitFailureHandler.FAIL_FAST); assertThat(latch.await(5, TimeUnit.SECONDS)).isTrue(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java index e5178c0e..6bbd486d 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java @@ -15,6 +15,9 @@ import java.util.function.Consumer; import java.util.function.Function; +import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.spec.McpClientSession; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; import org.junit.jupiter.api.AfterEach; @@ -35,7 +38,8 @@ import com.fasterxml.jackson.databind.ObjectMapper; /** - * Tests for the {@link HttpClientSseClientTransport} class. + * Tests for the {@link HttpClientSseClientTransportProvider.HttpClientSseClientTransport} + * class. * * @author Christian Tzolov */ @@ -53,14 +57,16 @@ class HttpClientSseClientTransportTests { private TestHttpClientSseClientTransport transport; // Test class to access protected methods - static class TestHttpClientSseClientTransport extends HttpClientSseClientTransport { + static class TestHttpClientSseClientTransport implements McpClientTransport { + + McpClientTransport delegate; private final AtomicInteger inboundMessageCount = new AtomicInteger(0); private Sinks.Many> events = Sinks.many().unicast().onBackpressureBuffer(); - public TestHttpClientSseClientTransport(final String baseUri) { - super(HttpClient.newHttpClient(), HttpRequest.newBuilder(), baseUri, "/sse", new ObjectMapper()); + public TestHttpClientSseClientTransport(McpClientTransport transport) { + this.delegate = transport; } public int getInboundMessageCount() { @@ -77,6 +83,31 @@ public void simulateMessageEvent(String jsonMessage) { inboundMessageCount.incrementAndGet(); } + @Override + public Mono connect(Function, Mono> handler) { + return delegate.connect(handler); + } + + @Override + public Mono connect() { + return delegate.connect(); + } + + @Override + public Mono closeGracefully() { + return delegate.closeGracefully(); + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + return delegate.sendMessage(message); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return delegate.unmarshalFrom(data, typeRef); + } + } void startContainer() { @@ -88,8 +119,12 @@ void startContainer() { @BeforeEach void setUp() { startContainer(); - transport = new TestHttpClientSseClientTransport(host); - transport.connect(Function.identity()).block(); + HttpClientSseClientTransportProvider transportProvider = HttpClientSseClientTransportProvider.builder(host) + .build(); + transportProvider.setSessionFactory( + (transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of())); + transport = new TestHttpClientSseClientTransport(transportProvider.getSession().getTransport()); + transport.connect().block(); } @AfterEach @@ -205,14 +240,17 @@ void testGracefulShutdown() { @Test void testRetryBehavior() { // Create a client that simulates connection failures - HttpClientSseClientTransport failingTransport = HttpClientSseClientTransport.builder("http://non-existent-host") + HttpClientSseClientTransportProvider failingTransportProvider = HttpClientSseClientTransportProvider + .builder("http://non-existent-host") .build(); - + failingTransportProvider.setSessionFactory( + (transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of())); + failingTransportProvider.getSession(); // Verify that the transport attempts to reconnect StepVerifier.create(Mono.delay(Duration.ofSeconds(2))).expectNextCount(1).verifyComplete(); // Clean up - failingTransport.closeGracefully().block(); + failingTransportProvider.closeGracefully().block(); } @Test @@ -290,12 +328,15 @@ void testCustomizeClient() { AtomicBoolean customizerCalled = new AtomicBoolean(false); // Create a transport with the customizer - HttpClientSseClientTransport customizedTransport = HttpClientSseClientTransport.builder(host) + HttpClientSseClientTransportProvider customizedTransport = HttpClientSseClientTransportProvider.builder(host) .customizeClient(builder -> { builder.version(HttpClient.Version.HTTP_2); customizerCalled.set(true); }) .build(); + customizedTransport.setSessionFactory( + (transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of())); + customizedTransport.getSession(); // Verify the customizer was called assertThat(customizerCalled.get()).isTrue(); @@ -314,7 +355,7 @@ void testCustomizeRequest() { AtomicReference headerValue = new AtomicReference<>(); // Create a transport with the customizer - HttpClientSseClientTransport customizedTransport = HttpClientSseClientTransport.builder(host) + HttpClientSseClientTransportProvider customizedTransport = HttpClientSseClientTransportProvider.builder(host) // Create a request customizer that adds a custom header .customizeRequest(builder -> { builder.header("X-Custom-Header", "test-value"); @@ -326,6 +367,9 @@ void testCustomizeRequest() { headerValue.set(request.headers().firstValue("X-Custom-Header").orElse(null)); }) .build(); + customizedTransport.setSessionFactory( + (transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of())); + customizedTransport.getSession(); // Verify the customizer was called assertThat(customizerCalled.get()).isTrue(); @@ -345,7 +389,7 @@ void testChainedCustomizations() { AtomicBoolean requestCustomizerCalled = new AtomicBoolean(false); // Create a transport with both customizers chained - HttpClientSseClientTransport customizedTransport = HttpClientSseClientTransport.builder(host) + HttpClientSseClientTransportProvider customizedTransport = HttpClientSseClientTransportProvider.builder(host) .customizeClient(builder -> { builder.connectTimeout(Duration.ofSeconds(30)); clientCustomizerCalled.set(true); @@ -355,6 +399,9 @@ void testChainedCustomizations() { requestCustomizerCalled.set(true); }) .build(); + customizedTransport.setSessionFactory( + (transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of())); + customizedTransport.getSession(); // Verify both customizers were called assertThat(clientCustomizerCalled.get()).isTrue(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java index 212a3c95..4e60e3f8 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java @@ -5,7 +5,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.client.McpClient; -import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransportProvider; import io.modelcontextprotocol.server.McpServer; import io.modelcontextprotocol.spec.McpSchema; import org.apache.catalina.LifecycleException; @@ -54,7 +54,7 @@ public void before() { throw new RuntimeException("Failed to start Tomcat", e); } - this.clientBuilder = McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) + this.clientBuilder = McpClient.sync(HttpClientSseClientTransportProvider.builder("http://localhost:" + PORT) .sseEndpoint(CUSTOM_CONTEXT_PATH + CUSTOM_SSE_ENDPOINT) .build()); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java index 135de83f..40f7be81 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java @@ -13,7 +13,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.client.McpClient; -import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransportProvider; import io.modelcontextprotocol.server.McpServer; import io.modelcontextprotocol.server.McpServerFeatures; import io.modelcontextprotocol.spec.McpError; @@ -76,9 +76,12 @@ public void before() { throw new RuntimeException("Failed to start Tomcat", e); } - this.clientBuilder = McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) + HttpClientSseClientTransportProvider transportProvider = HttpClientSseClientTransportProvider + .builder("http://localhost:" + PORT) .sseEndpoint(CUSTOM_SSE_ENDPOINT) - .build()); + .objectMapper(new ObjectMapper()) + .build(); + this.clientBuilder = McpClient.sync(transportProvider); } @AfterEach diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java index f72be43e..4d7a0705 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java @@ -9,6 +9,7 @@ import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.MockMcpClientTransport; +import io.modelcontextprotocol.MockMcpClientTransportProvider; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -41,13 +42,18 @@ class McpClientSessionTests { private McpClientSession session; - private MockMcpClientTransport transport; + private MockMcpClientTransportProvider.MockMcpClientTransport transport; + + private MockMcpClientTransportProvider transportProvider; @BeforeEach void setUp() { - transport = new MockMcpClientTransport(); - session = new McpClientSession(TIMEOUT, transport, Map.of(), - Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: " + params)))); + transportProvider = new MockMcpClientTransportProvider(); + transportProvider.setSessionFactory((transport) -> new McpClientSession(TIMEOUT, transport, Map.of(), + Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: " + params))))); + session = transportProvider.getSession(); + transport = transportProvider.getTransport(); + } @AfterEach @@ -139,8 +145,11 @@ void testRequestHandling() { String echoMessage = "Hello MCP!"; Map> requestHandlers = Map.of(ECHO_METHOD, params -> Mono.just(params)); - transport = new MockMcpClientTransport(); - session = new McpClientSession(TIMEOUT, transport, requestHandlers, Map.of()); + transportProvider = new MockMcpClientTransportProvider(); + transportProvider + .setSessionFactory((transport) -> new McpClientSession(TIMEOUT, transport, requestHandlers, Map.of())); + session = transportProvider.getSession(); + transport = transportProvider.getTransport(); // Simulate incoming request McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, ECHO_METHOD, @@ -159,9 +168,11 @@ void testRequestHandling() { void testNotificationHandling() { Sinks.One receivedParams = Sinks.one(); - transport = new MockMcpClientTransport(); - session = new McpClientSession(TIMEOUT, transport, Map.of(), - Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> receivedParams.tryEmitValue(params)))); + transportProvider = new MockMcpClientTransportProvider(); + transportProvider.setSessionFactory((transport) -> new McpClientSession(TIMEOUT, transport, Map.of(), + Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> receivedParams.tryEmitValue(params))))); + session = transportProvider.getSession(); + transport = transportProvider.getTransport(); // Simulate incoming notification from the server Map notificationParams = Map.of("status", "ready");