Skip to content

feat(server): tool call support context #172

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -314,14 +314,16 @@ private Mono<ServerResponse> handleMessage(ServerRequest request) {
return request.bodyToMono(String.class).flatMap(body -> {
try {
McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body);
return session.handle(message).flatMap(response -> ServerResponse.ok().build()).onErrorResume(error -> {
logger.error("Error processing message: {}", error.getMessage());
// TODO: instead of signalling the error, just respond with 200 OK
// - the error is signalled on the SSE connection
// return ServerResponse.ok().build();
return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR)
.bodyValue(new McpError(error.getMessage()));
});
return session.handle(request, message)
.flatMap(response -> ServerResponse.ok().build())
.onErrorResume(error -> {
logger.error("Error processing message: {}", error.getMessage());
// TODO: instead of signalling the error, just respond with 200 OK
// - the error is signalled on the SSE connection
// return ServerResponse.ok().build();
return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR)
.bodyValue(new McpError(error.getMessage()));
});
}
catch (IllegalArgumentException | IOException e) {
logger.error("Failed to deserialize message: {}", e.getMessage());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -802,4 +802,4 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) {
mcpServer.close();
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ private ServerResponse handleMessage(ServerRequest request) {
McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body);

// Process the message through the session's handle method
session.handle(message).block(); // Block for WebMVC compatibility
session.handle(request, message).block(); // Block for WebMVC compatibility

return ServerResponse.ok().build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ void testAddTool() {
.build();

StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(newTool,
(exchange, args) -> Mono.just(new CallToolResult(List.of(), false)))))
(exchange, request) -> Mono.just(new CallToolResult(List.of(), false)))))
.verifyComplete();

assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException();
Expand All @@ -123,12 +123,12 @@ void testAddDuplicateTool() {
var mcpAsyncServer = McpServer.async(createMcpTransportProvider())
.serverInfo("test-server", "1.0.0")
.capabilities(ServerCapabilities.builder().tools(true).build())
.tool(duplicateTool, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false)))
.tool(duplicateTool, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false)))
.build();

StepVerifier
.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(duplicateTool,
(exchange, args) -> Mono.just(new CallToolResult(List.of(), false)))))
(exchange, request) -> Mono.just(new CallToolResult(List.of(), false)))))
.verifyErrorSatisfies(error -> {
assertThat(error).isInstanceOf(McpError.class)
.hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists");
Expand All @@ -144,7 +144,7 @@ void testRemoveTool() {
var mcpAsyncServer = McpServer.async(createMcpTransportProvider())
.serverInfo("test-server", "1.0.0")
.capabilities(ServerCapabilities.builder().tools(true).build())
.tool(too, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false)))
.tool(too, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false)))
.build();

StepVerifier.create(mcpAsyncServer.removeTool(TEST_TOOL_NAME)).verifyComplete();
Expand Down Expand Up @@ -173,7 +173,7 @@ void testNotifyToolsListChanged() {
var mcpAsyncServer = McpServer.async(createMcpTransportProvider())
.serverInfo("test-server", "1.0.0")
.capabilities(ServerCapabilities.builder().tools(true).build())
.tool(too, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false)))
.tool(too, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false)))
.build();

StepVerifier.create(mcpAsyncServer.notifyToolsListChanged()).verifyComplete();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ void testAddTool() {

Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema);
assertThatCode(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(newTool,
(exchange, args) -> new CallToolResult(List.of(), false))))
(exchange, request) -> new CallToolResult(List.of(), false))))
.doesNotThrowAnyException();

assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException();
Expand All @@ -130,11 +130,11 @@ void testAddDuplicateTool() {
var mcpSyncServer = McpServer.sync(createMcpTransportProvider())
.serverInfo("test-server", "1.0.0")
.capabilities(ServerCapabilities.builder().tools(true).build())
.tool(duplicateTool, (exchange, args) -> new CallToolResult(List.of(), false))
.tool(duplicateTool, (exchange, request) -> new CallToolResult(List.of(), false))
.build();

assertThatThrownBy(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(duplicateTool,
(exchange, args) -> new CallToolResult(List.of(), false))))
(exchange, request) -> new CallToolResult(List.of(), false))))
.isInstanceOf(McpError.class)
.hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists");

Expand All @@ -148,7 +148,7 @@ void testRemoveTool() {
var mcpSyncServer = McpServer.sync(createMcpTransportProvider())
.serverInfo("test-server", "1.0.0")
.capabilities(ServerCapabilities.builder().tools(true).build())
.tool(tool, (exchange, args) -> new CallToolResult(List.of(), false))
.tool(tool, (exchange, request) -> new CallToolResult(List.of(), false))
.build();

assertThatCode(() -> mcpSyncServer.removeTool(TEST_TOOL_NAME)).doesNotThrowAnyException();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.spec.McpClientSession;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpRequest;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
import io.modelcontextprotocol.spec.McpSchema.LoggingLevel;
Expand Down Expand Up @@ -292,7 +293,7 @@ private static class AsyncServerImpl extends McpAsyncServer {
// Initialize request handlers for standard MCP methods

// Ping MUST respond with an empty data, but not NULL response.
requestHandlers.put(McpSchema.METHOD_PING, (exchange, params) -> Mono.just(Map.of()));
requestHandlers.put(McpSchema.METHOD_PING, (exchange, request) -> Mono.just(Map.of()));

// Add tools API handlers if the tool capability is enabled
if (this.serverCapabilities.tools() != null) {
Expand Down Expand Up @@ -472,17 +473,17 @@ public Mono<Void> notifyToolsListChanged() {
}

private McpServerSession.RequestHandler<McpSchema.ListToolsResult> toolsListRequestHandler() {
return (exchange, params) -> {
return (exchange, request) -> {
List<Tool> tools = this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList();

return Mono.just(new McpSchema.ListToolsResult(tools, null));
};
}

private McpServerSession.RequestHandler<CallToolResult> toolsCallRequestHandler() {
return (exchange, params) -> {
McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(params,
new TypeReference<McpSchema.CallToolRequest>() {
return (exchange, request) -> {
McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(request.params(),
new TypeReference<>() {
});

Optional<McpServerFeatures.AsyncToolSpecification> toolSpecification = this.tools.stream()
Expand All @@ -493,7 +494,9 @@ private McpServerSession.RequestHandler<CallToolResult> toolsCallRequestHandler(
return Mono.error(new McpError("Tool not found: " + callToolRequest.name()));
}

return toolSpecification.map(tool -> tool.call().apply(exchange, callToolRequest.arguments()))
return toolSpecification
.map(tool -> tool.call()
.apply(exchange, new McpRequest(callToolRequest.arguments(), request.context())))
.orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name())));
};
}
Expand Down Expand Up @@ -553,7 +556,7 @@ public Mono<Void> notifyResourcesListChanged() {
}

private McpServerSession.RequestHandler<McpSchema.ListResourcesResult> resourcesListRequestHandler() {
return (exchange, params) -> {
return (exchange, request) -> {
var resourceList = this.resources.values()
.stream()
.map(McpServerFeatures.AsyncResourceSpecification::resource)
Expand All @@ -563,14 +566,14 @@ private McpServerSession.RequestHandler<McpSchema.ListResourcesResult> resources
}

private McpServerSession.RequestHandler<McpSchema.ListResourceTemplatesResult> resourceTemplateListRequestHandler() {
return (exchange, params) -> Mono
return (exchange, request) -> Mono
.just(new McpSchema.ListResourceTemplatesResult(this.resourceTemplates, null));

}

private McpServerSession.RequestHandler<McpSchema.ReadResourceResult> resourcesReadRequestHandler() {
return (exchange, params) -> {
McpSchema.ReadResourceRequest resourceRequest = objectMapper.convertValue(params,
return (exchange, request) -> {
McpSchema.ReadResourceRequest resourceRequest = objectMapper.convertValue(request.params(),
new TypeReference<McpSchema.ReadResourceRequest>() {
});
var resourceUri = resourceRequest.uri();
Expand Down Expand Up @@ -646,7 +649,7 @@ public Mono<Void> notifyPromptsListChanged() {
}

private McpServerSession.RequestHandler<McpSchema.ListPromptsResult> promptsListRequestHandler() {
return (exchange, params) -> {
return (exchange, request) -> {
// TODO: Implement pagination
// McpSchema.PaginatedRequest request = objectMapper.convertValue(params,
// new TypeReference<McpSchema.PaginatedRequest>() {
Expand All @@ -662,8 +665,8 @@ private McpServerSession.RequestHandler<McpSchema.ListPromptsResult> promptsList
}

private McpServerSession.RequestHandler<McpSchema.GetPromptResult> promptsGetRequestHandler() {
return (exchange, params) -> {
McpSchema.GetPromptRequest promptRequest = objectMapper.convertValue(params,
return (exchange, request) -> {
McpSchema.GetPromptRequest promptRequest = objectMapper.convertValue(request.params(),
new TypeReference<McpSchema.GetPromptRequest>() {
});

Expand Down Expand Up @@ -697,10 +700,10 @@ public Mono<Void> loggingNotification(LoggingMessageNotification loggingMessageN
}

private McpServerSession.RequestHandler<Object> setLoggerRequestHandler() {
return (exchange, params) -> {
return (exchange, request) -> {
return Mono.defer(() -> {

SetLevelRequest newMinLoggingLevel = objectMapper.convertValue(params,
SetLevelRequest newMinLoggingLevel = objectMapper.convertValue(request.params(),
new TypeReference<SetLevelRequest>() {
});

Expand All @@ -716,8 +719,8 @@ private McpServerSession.RequestHandler<Object> setLoggerRequestHandler() {
}

private McpServerSession.RequestHandler<McpSchema.CompleteResult> completionCompleteRequestHandler() {
return (exchange, params) -> {
McpSchema.CompleteRequest request = parseCompletionParams(params);
return (exchange, req) -> {
McpSchema.CompleteRequest request = parseCompletionParams(req.params());

if (request.ref() == null) {
return Mono.error(new McpError("ref must not be null"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
import java.util.function.BiFunction;

import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.spec.McpRequest;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate;
import io.modelcontextprotocol.spec.McpServerTransportProvider;
import io.modelcontextprotocol.spec.RequestContext;
import io.modelcontextprotocol.util.Assert;
import reactor.core.publisher.Mono;

Expand Down Expand Up @@ -306,7 +308,7 @@ public AsyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabi
* @throws IllegalArgumentException if tool or handler is null
*/
public AsyncSpecification tool(McpSchema.Tool tool,
BiFunction<McpAsyncServerExchange, Map<String, Object>, Mono<CallToolResult>> handler) {
BiFunction<McpAsyncServerExchange, McpRequest, Mono<CallToolResult>> handler) {
Assert.notNull(tool, "Tool must not be null");
Assert.notNull(handler, "Handler must not be null");

Expand Down Expand Up @@ -751,7 +753,7 @@ public SyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabil
* @throws IllegalArgumentException if tool or handler is null
*/
public SyncSpecification tool(McpSchema.Tool tool,
BiFunction<McpSyncServerExchange, Map<String, Object>, McpSchema.CallToolResult> handler) {
BiFunction<McpSyncServerExchange, McpRequest, McpSchema.CallToolResult> handler) {
Assert.notNull(tool, "Tool must not be null");
Assert.notNull(handler, "Handler must not be null");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,21 @@

package io.modelcontextprotocol.server;

import io.modelcontextprotocol.spec.McpRequest;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.RequestContext;
import io.modelcontextprotocol.util.Assert;
import io.modelcontextprotocol.util.Utils;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;

import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.util.Assert;
import io.modelcontextprotocol.util.Utils;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

/**
* MCP server features specification that a particular server can choose to support.
*
Expand Down Expand Up @@ -73,9 +75,9 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s
: new McpSchema.ServerCapabilities(null, // completions
null, // experimental
new McpSchema.ServerCapabilities.LoggingCapabilities(), // Enable
// logging
// by
// default
// logging
// by
// default
!Utils.isEmpty(prompts) ? new McpSchema.ServerCapabilities.PromptCapabilities(false) : null,
!Utils.isEmpty(resources)
? new McpSchema.ServerCapabilities.ResourceCapabilities(false, false) : null,
Expand Down Expand Up @@ -181,9 +183,9 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se
: new McpSchema.ServerCapabilities(null, // completions
null, // experimental
new McpSchema.ServerCapabilities.LoggingCapabilities(), // Enable
// logging
// by
// default
// logging
// by
// default
!Utils.isEmpty(prompts) ? new McpSchema.ServerCapabilities.PromptCapabilities(false) : null,
!Utils.isEmpty(resources)
? new McpSchema.ServerCapabilities.ResourceCapabilities(false, false) : null,
Expand Down Expand Up @@ -237,16 +239,16 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se
* connected client. The second arguments is a map of tool arguments.
*/
public record AsyncToolSpecification(McpSchema.Tool tool,
BiFunction<McpAsyncServerExchange, Map<String, Object>, Mono<McpSchema.CallToolResult>> call) {
BiFunction<McpAsyncServerExchange, McpRequest, Mono<McpSchema.CallToolResult>> call) {

static AsyncToolSpecification fromSync(SyncToolSpecification tool) {
// FIXME: This is temporary, proper validation should be implemented
if (tool == null) {
return null;
}
return new AsyncToolSpecification(tool.tool(),
(exchange, map) -> Mono
.fromCallable(() -> tool.call().apply(new McpSyncServerExchange(exchange), map))
(exchange, request) -> Mono
.fromCallable(() -> tool.call().apply(new McpSyncServerExchange(exchange), request))
.subscribeOn(Schedulers.boundedElastic()));
}
}
Expand Down Expand Up @@ -413,7 +415,8 @@ static AsyncCompletionSpecification fromSync(SyncCompletionSpecification complet
* client. The second arguments is a map of arguments passed to the tool.
*/
public record SyncToolSpecification(McpSchema.Tool tool,
BiFunction<McpSyncServerExchange, Map<String, Object>, McpSchema.CallToolResult> call) {
BiFunction<McpSyncServerExchange, McpRequest, McpSchema.CallToolResult> call) {

}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body.toString());

// Process the message through the session's handle method
session.handle(message).block(); // Block for Servlet compatibility
session.handle(request, message).block(); // Block for Servlet compatibility

response.setStatus(HttpServletResponse.SC_OK);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ private void initProcessing() {
}

private void handleIncomingMessages() {
this.inboundSink.asFlux().flatMap(message -> session.handle(message)).doOnTerminate(() -> {
this.inboundSink.asFlux().flatMap(message -> session.handle(null, message)).doOnTerminate(() -> {
// The outbound processing will dispose its scheduler upon completion
this.outboundSink.tryEmitComplete();
this.inboundScheduler.dispose();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package io.modelcontextprotocol.spec;

/**
* @param params the parameters of the request.
* @param context the request context
* @author taobaorun
*/
public record McpRequest(Object params, RequestContext context) {
}
Loading