Skip to content

Commit 34d0358

Browse files
committed
feat(client): small enhancements + adds Batch to McpSchema to simplify StreamableHttpClientTransport
1 parent ce27db5 commit 34d0358

File tree

2 files changed

+78
-70
lines changed

2 files changed

+78
-70
lines changed

mcp/src/main/java/io/modelcontextprotocol/client/transport/StreamableHttpClientTransport.java

+51-69
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import java.net.http.HttpResponse;
1515
import java.nio.charset.StandardCharsets;
1616
import java.time.Duration;
17+
import java.util.ArrayList;
1718
import java.util.List;
1819
import java.util.concurrent.atomic.AtomicBoolean;
1920
import java.util.concurrent.atomic.AtomicReference;
@@ -192,6 +193,7 @@ public Mono<Void> connect(final Function<Mono<McpSchema.JSONRPCMessage>, Mono<Mc
192193
.doOnTerminate(() -> state.set(TransportState.CLOSED))
193194
.onErrorResume(e -> {
194195
LOGGER.error("Streamable transport connection error", e);
196+
state.set(TransportState.DISCONNECTED);
195197
return Mono.error(e);
196198
}));
197199
}
@@ -204,43 +206,14 @@ public Mono<Void> sendMessage(final McpSchema.JSONRPCMessage message) {
204206
public Mono<Void> sendMessage(final McpSchema.JSONRPCMessage message,
205207
final Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
206208
if (fallbackToSse.get()) {
207-
return sseClientTransport.sendMessage(message);
209+
return fallbackToSse(message);
208210
}
209211

210212
if (state.get() == TransportState.CLOSED) {
211213
return Mono.empty();
212214
}
213215

214-
return sentPost(message, handler).onErrorResume(e -> {
215-
LOGGER.error("Streamable transport sendMessage error", e);
216-
return Mono.error(e);
217-
});
218-
}
219-
220-
/**
221-
* Sends a list of messages to the server.
222-
* @param messages the list of messages to send
223-
* @return a Mono that completes when all messages have been sent
224-
*/
225-
public Mono<Void> sendMessages(final List<McpSchema.JSONRPCMessage> messages,
226-
final Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
227-
if (fallbackToSse.get()) {
228-
return Flux.fromIterable(messages).flatMap(this::sendMessage).then();
229-
}
230-
231-
if (state.get() == TransportState.CLOSED) {
232-
return Mono.empty();
233-
}
234-
235-
return sentPost(messages, handler).onErrorResume(e -> {
236-
LOGGER.error("Streamable transport sendMessages error", e);
237-
return Mono.error(e);
238-
});
239-
}
240-
241-
private Mono<Void> sentPost(final Object msg,
242-
final Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
243-
return serializeJson(msg).flatMap(json -> {
216+
return serializeJson(message).flatMap(json -> {
244217
final HttpRequest request = requestBuilder.copy()
245218
.POST(HttpRequest.BodyPublishers.ofString(json))
246219
.uri(uri)
@@ -256,15 +229,7 @@ private Mono<Void> sentPost(final Object msg,
256229
if (response.statusCode() == 405 || response.statusCode() == 404) {
257230
LOGGER.warn("Operation not allowed, falling back to SSE");
258231
fallbackToSse.set(true);
259-
if (msg instanceof McpSchema.JSONRPCMessage message) {
260-
return sseClientTransport.sendMessage(message);
261-
}
262-
263-
if (msg instanceof List<?> list) {
264-
@SuppressWarnings("unchecked")
265-
final List<McpSchema.JSONRPCMessage> messages = (List<McpSchema.JSONRPCMessage>) list;
266-
return Flux.fromIterable(messages).flatMap(this::sendMessage).then();
267-
}
232+
return fallbackToSse(message);
268233
}
269234

270235
if (response.statusCode() >= 400) {
@@ -274,18 +239,28 @@ private Mono<Void> sentPost(final Object msg,
274239

275240
return handleStreamingResponse(response, handler);
276241
});
242+
}).onErrorResume(e -> {
243+
LOGGER.error("Streamable transport sendMessages error", e);
244+
return Mono.error(e);
277245
});
278246

279247
}
280248

281-
private Mono<String> serializeJson(final Object input) {
249+
private Mono<Void> fallbackToSse(final McpSchema.JSONRPCMessage msg) {
250+
if (msg instanceof McpSchema.JSONRPCBatchRequest batch) {
251+
return Flux.fromIterable(batch.items()).flatMap(sseClientTransport::sendMessage).then();
252+
}
253+
254+
if (msg instanceof McpSchema.JSONRPCBatchResponse batch) {
255+
return Flux.fromIterable(batch.items()).flatMap(sseClientTransport::sendMessage).then();
256+
}
257+
258+
return sseClientTransport.sendMessage(msg);
259+
}
260+
261+
private Mono<String> serializeJson(final McpSchema.JSONRPCMessage msg) {
282262
try {
283-
if (input instanceof McpSchema.JSONRPCMessage || input instanceof List) {
284-
return Mono.just(objectMapper.writeValueAsString(input));
285-
}
286-
else {
287-
return Mono.error(new IllegalArgumentException("Unsupported message type for serialization"));
288-
}
263+
return Mono.just(objectMapper.writeValueAsString(msg));
289264
}
290265
catch (IOException e) {
291266
LOGGER.error("Error serializing JSON-RPC message", e);
@@ -313,9 +288,15 @@ else if (contentType.contains("application/json")) {
313288
private Mono<Void> handleSingleJson(final HttpResponse<InputStream> response,
314289
final Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
315290
return Mono.fromCallable(() -> {
291+
try {
316292
final McpSchema.JSONRPCMessage msg = McpSchema.deserializeJsonRpcMessage(objectMapper,
317293
new String(response.body().readAllBytes(), StandardCharsets.UTF_8));
318294
return handler.apply(Mono.just(msg));
295+
}
296+
catch (IOException e) {
297+
LOGGER.error("Error processing JSON response", e);
298+
return Mono.error(e);
299+
}
319300
}).flatMap(Function.identity()).then();
320301
}
321302

@@ -328,7 +309,7 @@ private Mono<Void> handleJsonStream(final HttpResponse<InputStream> response,
328309
}
329310
catch (IOException e) {
330311
LOGGER.error("Error processing JSON line", e);
331-
return Mono.empty();
312+
return Mono.error(e);
332313
}
333314
}).then();
334315
}
@@ -347,7 +328,7 @@ private Mono<Void> handleSseStream(final HttpResponse<InputStream> response,
347328
if (line.startsWith("event: "))
348329
event = line.substring(7).trim();
349330
else if (line.startsWith("data: "))
350-
data += line.substring(6).trim() + "\n";
331+
data += line.substring(6) + "\n";
351332
else if (line.startsWith("id: "))
352333
id = line.substring(4).trim();
353334
}
@@ -359,34 +340,35 @@ else if (line.startsWith("id: "))
359340
return new FlowSseClient.SseEvent(event, data, id);
360341
})
361342
.filter(sseEvent -> "message".equals(sseEvent.type()))
362-
.doOnNext(sseEvent -> {
363-
lastEventId.set(sseEvent.id());
343+
.concatMap(sseEvent -> {
344+
String rawData = sseEvent.data().trim();
364345
try {
365-
String rawData = sseEvent.data().trim();
366346
JsonNode node = objectMapper.readTree(rawData);
367-
347+
List<McpSchema.JSONRPCMessage> messages = new ArrayList<>();
368348
if (node.isArray()) {
369349
for (JsonNode item : node) {
370-
String rawMessage = objectMapper.writeValueAsString(item);
371-
McpSchema.JSONRPCMessage msg = McpSchema.deserializeJsonRpcMessage(objectMapper,
372-
rawMessage);
373-
handler.apply(Mono.just(msg)).subscribe();
350+
messages.add(McpSchema.deserializeJsonRpcMessage(objectMapper, item.toString()));
374351
}
375-
}
376-
else if (node.isObject()) {
377-
String rawMessage = objectMapper.writeValueAsString(node);
378-
McpSchema.JSONRPCMessage msg = McpSchema.deserializeJsonRpcMessage(objectMapper, rawMessage);
379-
handler.apply(Mono.just(msg)).subscribe();
380-
}
381-
else {
352+
} else if (node.isObject()) {
353+
messages.add(McpSchema.deserializeJsonRpcMessage(objectMapper, node.toString()));
354+
} else {
355+
String warning = "Unexpected JSON in SSE data: " + rawData;
382356
LOGGER.warn("Unexpected JSON in SSE data: {}", rawData);
357+
return Mono.error(new IllegalArgumentException(warning));
383358
}
359+
360+
return Flux.fromIterable(messages)
361+
.concatMap(msg -> handler.apply(Mono.just(msg)))
362+
.then(Mono.fromRunnable(() -> {
363+
if (!sseEvent.id().isEmpty()) {
364+
lastEventId.set(sseEvent.id());
365+
}
366+
}));
367+
} catch (IOException e) {
368+
LOGGER.error("Error parsing SSE JSON: {}", rawData, e);
369+
return Mono.error(e);
384370
}
385-
catch (IOException e) {
386-
LOGGER.error("Error processing SSE event: {}", sseEvent.data(), e);
387-
}
388-
})
389-
.then();
371+
}).then();
390372
}
391373

392374
@Override

mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java

+27-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import java.util.List;
1111
import java.util.Map;
1212

13+
import com.fasterxml.jackson.annotation.JsonIgnore;
1314
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
1415
import com.fasterxml.jackson.annotation.JsonInclude;
1516
import com.fasterxml.jackson.annotation.JsonProperty;
@@ -171,12 +172,37 @@ else if (map.containsKey("result") || map.containsKey("error")) {
171172
// ---------------------------
172173
// JSON-RPC Message Types
173174
// ---------------------------
174-
public sealed interface JSONRPCMessage permits JSONRPCRequest, JSONRPCNotification, JSONRPCResponse {
175+
public sealed interface JSONRPCMessage
176+
permits JSONRPCBatchRequest, JSONRPCBatchResponse, JSONRPCRequest, JSONRPCNotification, JSONRPCResponse {
175177

176178
String jsonrpc();
177179

178180
}
179181

182+
@JsonInclude(JsonInclude.Include.NON_ABSENT)
183+
@JsonIgnoreProperties(ignoreUnknown = true)
184+
public record JSONRPCBatchRequest( // @formatter:off
185+
@JsonProperty("items") List<JSONRPCMessage> items) implements JSONRPCMessage {
186+
187+
@Override
188+
@JsonIgnore
189+
public String jsonrpc() {
190+
return JSONRPC_VERSION;
191+
}
192+
} // @formatter:on
193+
194+
@JsonInclude(JsonInclude.Include.NON_ABSENT)
195+
@JsonIgnoreProperties(ignoreUnknown = true)
196+
public record JSONRPCBatchResponse( // @formatter:off
197+
@JsonProperty("items") List<JSONRPCMessage> items) implements JSONRPCMessage {
198+
199+
@Override
200+
@JsonIgnore
201+
public String jsonrpc() {
202+
return JSONRPC_VERSION;
203+
}
204+
} // @formatter:on
205+
180206
@JsonInclude(JsonInclude.Include.NON_ABSENT)
181207
@JsonIgnoreProperties(ignoreUnknown = true)
182208
public record JSONRPCRequest( // @formatter:off

0 commit comments

Comments
 (0)