Skip to content

Support Progress Flow #300

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,16 @@ public class McpAsyncClient {
notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_MESSAGE,
asyncLoggingNotificationHandler(loggingConsumersFinal));

// Utility Progress Notification
List<Function<McpSchema.ProgressNotification, Mono<Void>>> progressConsumersFinal = new ArrayList<>();
progressConsumersFinal
.add((notification) -> Mono.fromRunnable(() -> logger.debug("Progress: {}", notification)));
if (!Utils.isEmpty(features.progressConsumers())) {
progressConsumersFinal.addAll(features.progressConsumers());
}
notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_PROGRESS,
asyncProgressNotificationHandler(progressConsumersFinal));

this.transport.setExceptionHandler(this::handleException);
this.sessionSupplier = () -> new McpClientSession(requestTimeout, transport, requestHandlers,
notificationHandlers);
Expand Down Expand Up @@ -985,6 +995,20 @@ private NotificationHandler asyncLoggingNotificationHandler(
};
}

private NotificationHandler asyncProgressNotificationHandler(
List<Function<McpSchema.ProgressNotification, Mono<Void>>> progressConsumers) {

return params -> {
McpSchema.ProgressNotification progressNotification = transport.unmarshalFrom(params,
new TypeReference<McpSchema.ProgressNotification>() {
});

return Flux.fromIterable(progressConsumers)
.flatMap(consumer -> consumer.apply(progressNotification))
.then();
};
}

/**
* Sets the minimum logging level for messages received from the server. The client
* will only receive log messages at or above the specified severity level.
Expand Down
41 changes: 38 additions & 3 deletions mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ class SyncSpec {

private final List<Consumer<McpSchema.LoggingMessageNotification>> loggingConsumers = new ArrayList<>();

private final List<Consumer<McpSchema.ProgressNotification>> progressConsumers = new ArrayList<>();

private Function<CreateMessageRequest, CreateMessageResult> samplingHandler;

private Function<ElicitRequest, ElicitResult> elicitationHandler;
Expand Down Expand Up @@ -377,6 +379,36 @@ public SyncSpec loggingConsumers(List<Consumer<McpSchema.LoggingMessageNotificat
return this;
}

/**
* Adds a consumer to be notified of progress notifications from the server. This
* allows the client to track long-running operations and provide feedback to
* users.
* @param progressConsumer A consumer that receives progress notifications. Must
* not be null.
* @return This builder instance for method chaining
* @throws IllegalArgumentException if progressConsumer is null
*/
public SyncSpec progressConsumer(Consumer<McpSchema.ProgressNotification> progressConsumer) {
Assert.notNull(progressConsumer, "Progress consumer must not be null");
this.progressConsumers.add(progressConsumer);
return this;
}

/**
* Adds a multiple consumers to be notified of progress notifications from the
* server. This allows the client to track long-running operations and provide
* feedback to users.
* @param progressConsumers A list of consumers that receives progress
* notifications. Must not be null.
* @return This builder instance for method chaining
* @throws IllegalArgumentException if progressConsumer is null
*/
public SyncSpec progressConsumers(List<Consumer<McpSchema.ProgressNotification>> progressConsumers) {
Assert.notNull(progressConsumers, "Progress consumers must not be null");
this.progressConsumers.addAll(progressConsumers);
return this;
}

/**
* Create an instance of {@link McpSyncClient} with the provided configurations or
* sensible defaults.
Expand All @@ -385,7 +417,8 @@ public SyncSpec loggingConsumers(List<Consumer<McpSchema.LoggingMessageNotificat
public McpSyncClient build() {
McpClientFeatures.Sync syncFeatures = new McpClientFeatures.Sync(this.clientInfo, this.capabilities,
this.roots, this.toolsChangeConsumers, this.resourcesChangeConsumers, this.resourcesUpdateConsumers,
this.promptsChangeConsumers, this.loggingConsumers, this.samplingHandler, this.elicitationHandler);
this.promptsChangeConsumers, this.loggingConsumers, this.progressConsumers, this.samplingHandler,
this.elicitationHandler);

McpClientFeatures.Async asyncFeatures = McpClientFeatures.Async.fromSync(syncFeatures);

Expand Down Expand Up @@ -435,6 +468,8 @@ class AsyncSpec {

private final List<Function<McpSchema.LoggingMessageNotification, Mono<Void>>> loggingConsumers = new ArrayList<>();

private final List<Function<McpSchema.ProgressNotification, Mono<Void>>> progressConsumers = new ArrayList<>();

private Function<CreateMessageRequest, Mono<CreateMessageResult>> samplingHandler;

private Function<ElicitRequest, Mono<ElicitResult>> elicitationHandler;
Expand Down Expand Up @@ -663,8 +698,8 @@ public McpAsyncClient build() {
return new McpAsyncClient(this.transport, this.requestTimeout, this.initializationTimeout,
new McpClientFeatures.Async(this.clientInfo, this.capabilities, this.roots,
this.toolsChangeConsumers, this.resourcesChangeConsumers, this.resourcesUpdateConsumers,
this.promptsChangeConsumers, this.loggingConsumers, this.samplingHandler,
this.elicitationHandler));
this.promptsChangeConsumers, this.loggingConsumers, this.progressConsumers,
this.samplingHandler, this.elicitationHandler));
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class McpClientFeatures {
* @param resourcesChangeConsumers the resources change consumers.
* @param promptsChangeConsumers the prompts change consumers.
* @param loggingConsumers the logging consumers.
* @param progressConsumers the progress consumers.
* @param samplingHandler the sampling handler.
* @param elicitationHandler the elicitation handler.
*/
Expand All @@ -68,6 +69,7 @@ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c
List<Function<List<McpSchema.ResourceContents>, Mono<Void>>> resourcesUpdateConsumers,
List<Function<List<McpSchema.Prompt>, Mono<Void>>> promptsChangeConsumers,
List<Function<McpSchema.LoggingMessageNotification, Mono<Void>>> loggingConsumers,
List<Function<McpSchema.ProgressNotification, Mono<Void>>> progressConsumers,
Function<McpSchema.CreateMessageRequest, Mono<McpSchema.CreateMessageResult>> samplingHandler,
Function<McpSchema.ElicitRequest, Mono<McpSchema.ElicitResult>> elicitationHandler) {

Expand All @@ -79,6 +81,7 @@ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c
* @param resourcesChangeConsumers the resources change consumers.
* @param promptsChangeConsumers the prompts change consumers.
* @param loggingConsumers the logging consumers.
* @param progressConsumers the progressconsumers.
* @param samplingHandler the sampling handler.
* @param elicitationHandler the elicitation handler.
*/
Expand All @@ -89,6 +92,7 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c
List<Function<List<McpSchema.ResourceContents>, Mono<Void>>> resourcesUpdateConsumers,
List<Function<List<McpSchema.Prompt>, Mono<Void>>> promptsChangeConsumers,
List<Function<McpSchema.LoggingMessageNotification, Mono<Void>>> loggingConsumers,
List<Function<McpSchema.ProgressNotification, Mono<Void>>> progressConsumers,
Function<McpSchema.CreateMessageRequest, Mono<McpSchema.CreateMessageResult>> samplingHandler,
Function<McpSchema.ElicitRequest, Mono<McpSchema.ElicitResult>> elicitationHandler) {

Expand All @@ -106,6 +110,7 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c
this.resourcesUpdateConsumers = resourcesUpdateConsumers != null ? resourcesUpdateConsumers : List.of();
this.promptsChangeConsumers = promptsChangeConsumers != null ? promptsChangeConsumers : List.of();
this.loggingConsumers = loggingConsumers != null ? loggingConsumers : List.of();
this.progressConsumers = progressConsumers != null ? progressConsumers : List.of();
this.samplingHandler = samplingHandler;
this.elicitationHandler = elicitationHandler;
}
Expand Down Expand Up @@ -149,6 +154,12 @@ public static Async fromSync(Sync syncSpec) {
.subscribeOn(Schedulers.boundedElastic()));
}

List<Function<McpSchema.ProgressNotification, Mono<Void>>> progressConsumers = new ArrayList<>();
for (Consumer<McpSchema.ProgressNotification> consumer : syncSpec.progressConsumers()) {
progressConsumers.add(p -> Mono.<Void>fromRunnable(() -> consumer.accept(p))
.subscribeOn(Schedulers.boundedElastic()));
}

Function<McpSchema.CreateMessageRequest, Mono<McpSchema.CreateMessageResult>> samplingHandler = r -> Mono
.fromCallable(() -> syncSpec.samplingHandler().apply(r))
.subscribeOn(Schedulers.boundedElastic());
Expand All @@ -159,7 +170,7 @@ public static Async fromSync(Sync syncSpec) {

return new Async(syncSpec.clientInfo(), syncSpec.clientCapabilities(), syncSpec.roots(),
toolsChangeConsumers, resourcesChangeConsumers, resourcesUpdateConsumers, promptsChangeConsumers,
loggingConsumers, samplingHandler, elicitationHandler);
loggingConsumers, progressConsumers, samplingHandler, elicitationHandler);
}
}

Expand All @@ -174,6 +185,7 @@ public static Async fromSync(Sync syncSpec) {
* @param resourcesChangeConsumers the resources change consumers.
* @param promptsChangeConsumers the prompts change consumers.
* @param loggingConsumers the logging consumers.
* @param progressConsumers the progress consumers.
* @param samplingHandler the sampling handler.
* @param elicitationHandler the elicitation handler.
*/
Expand All @@ -183,6 +195,7 @@ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabili
List<Consumer<List<McpSchema.ResourceContents>>> resourcesUpdateConsumers,
List<Consumer<List<McpSchema.Prompt>>> promptsChangeConsumers,
List<Consumer<McpSchema.LoggingMessageNotification>> loggingConsumers,
List<Consumer<McpSchema.ProgressNotification>> progressConsumers,
Function<McpSchema.CreateMessageRequest, McpSchema.CreateMessageResult> samplingHandler,
Function<McpSchema.ElicitRequest, McpSchema.ElicitResult> elicitationHandler) {

Expand All @@ -196,6 +209,7 @@ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabili
* @param resourcesUpdateConsumers the resource update consumers.
* @param promptsChangeConsumers the prompts change consumers.
* @param loggingConsumers the logging consumers.
* @param progressConsumers the progress consumers.
* @param samplingHandler the sampling handler.
* @param elicitationHandler the elicitation handler.
*/
Expand All @@ -205,6 +219,7 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl
List<Consumer<List<McpSchema.ResourceContents>>> resourcesUpdateConsumers,
List<Consumer<List<McpSchema.Prompt>>> promptsChangeConsumers,
List<Consumer<McpSchema.LoggingMessageNotification>> loggingConsumers,
List<Consumer<McpSchema.ProgressNotification>> progressConsumers,
Function<McpSchema.CreateMessageRequest, McpSchema.CreateMessageResult> samplingHandler,
Function<McpSchema.ElicitRequest, McpSchema.ElicitResult> elicitationHandler) {

Expand All @@ -222,6 +237,7 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl
this.resourcesUpdateConsumers = resourcesUpdateConsumers != null ? resourcesUpdateConsumers : List.of();
this.promptsChangeConsumers = promptsChangeConsumers != null ? promptsChangeConsumers : List.of();
this.loggingConsumers = loggingConsumers != null ? loggingConsumers : List.of();
this.progressConsumers = progressConsumers != null ? progressConsumers : List.of();
this.samplingHandler = samplingHandler;
this.elicitationHandler = elicitationHandler;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ private McpServerSession.RequestHandler<CallToolResult> toolsCallRequestHandler(
return Mono.error(new McpError("Tool not found: " + callToolRequest.name()));
}

return toolSpecification.map(tool -> tool.call().apply(exchange, callToolRequest.arguments()))
return toolSpecification.map(tool -> tool.call().apply(exchange, callToolRequest))
.orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name())));
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,16 @@ public Mono<McpSchema.ListRootsResult> listRoots(String cursor) {
LIST_ROOTS_RESULT_TYPE_REF);
}

public Mono<Void> notification(String method, Object params) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is too generic. We would need something like progress(ProgressNotificaiton pn) instead.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought restricting the method to only “progress” would block users until the SDK releases a new protocol adaptation, and it did in some cases: loggingNotification(LoggingMessageNotification loggingMessageNotification).

I agree there is also progress(ProgressNotification pn) should be exists.

Also, there might have users that who want to send custom notifications to.

if (method == null || method.isEmpty()) {
return Mono.error(new McpError("Method must not be null or empty"));
}
if (params == null) {
return Mono.error(new McpError("Params must not be null"));
}
return this.session.sendNotification(method, params);
}

/**
* Send a logging message notification to the client. Messages below the current
* minimum logging level will be filtered out.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ public AsyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabi
* Example usage: <pre>{@code
* .tool(
* new Tool("calculator", "Performs calculations", schema),
* (exchange, args) -> Mono.fromSupplier(() -> calculate(args))
* (exchange, request) -> Mono.fromSupplier(() -> calculate(request))
* .map(result -> new CallToolResult("Result: " + result))
* )
* }</pre>
Expand All @@ -323,7 +323,7 @@ public AsyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabi
* @throws IllegalArgumentException if tool or handler is null
*/
public AsyncSpecification tool(McpSchema.Tool tool,
BiFunction<McpAsyncServerExchange, Map<String, Object>, Mono<CallToolResult>> handler) {
BiFunction<McpAsyncServerExchange, McpSchema.CallToolRequest, Mono<CallToolResult>> handler) {
Assert.notNull(tool, "Tool must not be null");
Assert.notNull(handler, "Handler must not be null");

Expand Down Expand Up @@ -801,7 +801,7 @@ public SyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabil
* Example usage: <pre>{@code
* .tool(
* new Tool("calculator", "Performs calculations", schema),
* (exchange, args) -> new CallToolResult("Result: " + calculate(args))
* (exchange, request) -> new CallToolResult("Result: " + calculate(request))
* )
* }</pre>
* @param tool The tool definition including name, description, and schema. Must
Expand All @@ -814,7 +814,7 @@ public SyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabil
* @throws IllegalArgumentException if tool or handler is null
*/
public SyncSpecification tool(McpSchema.Tool tool,
BiFunction<McpSyncServerExchange, Map<String, Object>, McpSchema.CallToolResult> handler) {
BiFunction<McpSyncServerExchange, McpSchema.CallToolRequest, McpSchema.CallToolResult> handler) {
Assert.notNull(tool, "Tool must not be null");
Assert.notNull(handler, "Handler must not be null");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,8 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se
* .required("expression")
* .property("expression", JsonSchemaType.STRING)
* ),
* (exchange, args) -> {
* String expr = (String) args.get("expression");
* (exchange, request) -> {
* String expr = (String) request.arguments().get("expression");
* return Mono.fromSupplier(() -> evaluate(expr))
* .map(result -> new CallToolResult("Result: " + result));
* }
Expand All @@ -237,7 +237,7 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se
* connected client. The second arguments is a map of tool arguments.
*/
public record AsyncToolSpecification(McpSchema.Tool tool,
BiFunction<McpAsyncServerExchange, Map<String, Object>, Mono<McpSchema.CallToolResult>> call) {
BiFunction<McpAsyncServerExchange, McpSchema.CallToolRequest, Mono<McpSchema.CallToolResult>> call) {

static AsyncToolSpecification fromSync(SyncToolSpecification tool) {
// FIXME: This is temporary, proper validation should be implemented
Expand Down Expand Up @@ -413,7 +413,7 @@ static AsyncCompletionSpecification fromSync(SyncCompletionSpecification complet
* client. The second arguments is a map of arguments passed to the tool.
*/
public record SyncToolSpecification(McpSchema.Tool tool,
BiFunction<McpSyncServerExchange, Map<String, Object>, McpSchema.CallToolResult> call) {
BiFunction<McpSyncServerExchange, McpSchema.CallToolRequest, McpSchema.CallToolResult> call) {
}

/**
Expand Down
30 changes: 26 additions & 4 deletions mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io.modelcontextprotocol.util.Assert;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.util.annotation.Nullable;

/**
* Based on the <a href="http://www.jsonrpc.org/specification">JSON-RPC 2.0
Expand Down Expand Up @@ -57,6 +58,8 @@ private McpSchema() {

public static final String METHOD_PING = "ping";

public static final String METHOD_NOTIFICATION_PROGRESS = "notifications/progress";

// Tool Methods
public static final String METHOD_TOOLS_LIST = "tools/list";

Expand Down Expand Up @@ -867,15 +870,22 @@ private static JsonSchema parseSchema(String schema) {
* tools/list.
* @param arguments Arguments to pass to the tool. These must conform to the tool's
* input schema.
* @param _meta Optional metadata about the request. This can include additional
* information like `progressToken`
*/
@JsonInclude(JsonInclude.Include.NON_ABSENT)
@JsonIgnoreProperties(ignoreUnknown = true)
public record CallToolRequest(// @formatter:off
@JsonProperty("name") String name,
@JsonProperty("arguments") Map<String, Object> arguments) implements Request {
@JsonProperty("arguments") Map<String, Object> arguments,
@Nullable @JsonProperty("_meta") Map<String, Object> _meta) implements Request {

public CallToolRequest(String name, String jsonArguments) {
this(name, parseJsonArguments(jsonArguments));
this(name, parseJsonArguments(jsonArguments), null);
}

public CallToolRequest(String name, Map<String, Object> arguments) {
this(name, arguments, null);
}

private static Map<String, Object> parseJsonArguments(String jsonArguments) {
Expand Down Expand Up @@ -1309,11 +1319,23 @@ public record PaginatedResult(@JsonProperty("nextCursor") String nextCursor) {
// ---------------------------
// Progress and Logging
// ---------------------------

/**
* The Model Context Protocol (MCP) supports optional progress tracking for
* long-running operations through notification messages. Either side can send
* progress notifications to provide updates about operation status.
*
* @param progressToken The original progress token
* @param progress The current progress value so far
* @param total An optional “total” value
* @param message An optional “message” value
*/
@JsonIgnoreProperties(ignoreUnknown = true)
public record ProgressNotification(// @formatter:off
@JsonProperty("progressToken") String progressToken,
@JsonProperty("progress") double progress,
@JsonProperty("total") Double total) {
@JsonProperty("progress") Double progress,
@JsonProperty("total") Double total,
@JsonProperty("message") String message) {
}// @formatter:on

/**
Expand Down
Loading