Skip to content

Commit 24098f2

Browse files
committed
implementations for McpTransportContextExtractor<ServerRequest>
1 parent 45ec4b8 commit 24098f2

14 files changed

+115
-43
lines changed

mcp-core/src/main/java/io/modelcontextprotocol/common/McpTransportContext.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,14 @@
44

55
package io.modelcontextprotocol.common;
66

7+
import io.modelcontextprotocol.spec.ProtocolVersions;
8+
79
import java.security.Principal;
810
import java.util.Collections;
11+
import java.util.HashMap;
912
import java.util.Map;
1013
import java.util.Optional;
14+
import java.util.function.Function;
1115

1216
/**
1317
* Context associated with the transport layer. It allows to add transport-level metadata
@@ -38,6 +42,25 @@ static McpTransportContext create(Map<String, Object> metadata) {
3842
return new DefaultMcpTransportContext(metadata);
3943
}
4044

45+
/**
46+
* Returns a Map with entries for MCP transport concepts such as Protocol version,
47+
* session ID and Last Event ID.
48+
* @param headers Function typically backed by an HTTP Request Headers implementation.
49+
* @return Map with entries for MCP transport concepts such as Protocol version,
50+
* session ID and Last Event ID.
51+
*/
52+
static Map<String, Object> createMetadata(Function<String, String> headers) {
53+
Map<String, Object> metadata = new HashMap<>(3);
54+
metadata.put(io.modelcontextprotocol.spec.HttpHeaders.PROTOCOL_VERSION,
55+
Optional.ofNullable(headers.apply(io.modelcontextprotocol.spec.HttpHeaders.PROTOCOL_VERSION))
56+
.orElse(ProtocolVersions.MCP_2025_03_26));
57+
Optional.ofNullable(headers.apply(io.modelcontextprotocol.spec.HttpHeaders.MCP_SESSION_ID))
58+
.ifPresent(v -> metadata.put(io.modelcontextprotocol.spec.HttpHeaders.MCP_SESSION_ID, v));
59+
Optional.ofNullable(headers.apply(io.modelcontextprotocol.spec.HttpHeaders.LAST_EVENT_ID))
60+
.ifPresent(v -> metadata.put(io.modelcontextprotocol.spec.HttpHeaders.LAST_EVENT_ID, v));
61+
return metadata;
62+
}
63+
4164
/**
4265
* Extract a value from the context.
4366
* @param key the key under the data is expected

mcp-core/src/main/java/io/modelcontextprotocol/server/servlet/HttpServletRequestMcpTransportContextExtractor.java

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import java.util.HashMap;
1212
import java.util.Map;
1313
import java.util.Optional;
14+
import java.util.function.Function;
1415

1516
/**
1617
* {@link McpTransportContextExtractor} implementation for {@link HttpServletRequest}.
@@ -20,23 +21,7 @@ public class HttpServletRequestMcpTransportContextExtractor
2021

2122
@Override
2223
public McpTransportContext extract(HttpServletRequest request) {
23-
return McpTransportContext.create(metadata(request));
24-
}
25-
26-
/**
27-
* @param request Servlet Request
28-
* @return Extracts Map for MCP Transport Context
29-
*/
30-
protected Map<String, Object> metadata(HttpServletRequest request) {
31-
Map<String, Object> metadata = new HashMap<>(3);
32-
metadata.put(io.modelcontextprotocol.spec.HttpHeaders.PROTOCOL_VERSION,
33-
Optional.ofNullable(request.getHeader(io.modelcontextprotocol.spec.HttpHeaders.PROTOCOL_VERSION))
34-
.orElse(ProtocolVersions.MCP_2025_03_26));
35-
Optional.ofNullable(request.getHeader(io.modelcontextprotocol.spec.HttpHeaders.MCP_SESSION_ID))
36-
.ifPresent(v -> metadata.put(io.modelcontextprotocol.spec.HttpHeaders.MCP_SESSION_ID, v));
37-
Optional.ofNullable(request.getHeader(io.modelcontextprotocol.spec.HttpHeaders.LAST_EVENT_ID))
38-
.ifPresent(v -> metadata.put(io.modelcontextprotocol.spec.HttpHeaders.LAST_EVENT_ID, v));
39-
return metadata;
24+
return McpTransportContext.create(McpTransportContext.createMetadata(request::getHeader));
4025
}
4126

4227
}

mcp-core/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
package io.modelcontextprotocol.common;
66

7+
import java.util.HashMap;
78
import java.util.Map;
89
import java.util.function.BiFunction;
910

@@ -94,8 +95,12 @@ public class AsyncServerMcpTransportContextIntegrationTests {
9495

9596
private final McpTransportContextExtractor<HttpServletRequest> serverContextExtractor = new HttpServletRequestMcpTransportContextExtractor() {
9697
@Override
97-
protected Map<String, Object> metadata(HttpServletRequest r) {
98-
Map<String, Object> m = super.metadata(r);
98+
public McpTransportContext extract(HttpServletRequest request) {
99+
return McpTransportContext.create(metadata(request));
100+
}
101+
102+
private Map<String, Object> metadata(HttpServletRequest r) {
103+
Map<String, Object> m = new HashMap<>(McpTransportContext.createMetadata(r::getHeader));
99104
var headerValue = r.getHeader(HEADER_NAME);
100105
if (headerValue != null) {
101106
m.put("server-side-header-value", headerValue);

mcp-core/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import io.modelcontextprotocol.spec.McpSchema;
2424
import jakarta.servlet.Servlet;
2525
import jakarta.servlet.http.HttpServletRequest;
26+
27+
import java.util.HashMap;
2628
import java.util.Map;
2729
import java.util.function.BiFunction;
2830
import java.util.function.Supplier;
@@ -74,8 +76,12 @@ public class SyncServerMcpTransportContextIntegrationTests {
7476

7577
private final McpTransportContextExtractor<HttpServletRequest> serverContextExtractor = new HttpServletRequestMcpTransportContextExtractor() {
7678
@Override
77-
protected Map<String, Object> metadata(HttpServletRequest r) {
78-
Map<String, Object> m = super.metadata(r);
79+
public McpTransportContext extract(HttpServletRequest request) {
80+
return McpTransportContext.create(metadata(request));
81+
}
82+
83+
private Map<String, Object> metadata(HttpServletRequest r) {
84+
Map<String, Object> m = new HashMap<>(McpTransportContext.createMetadata(r::getHeader));
7985
var headerValue = r.getHeader(HEADER_NAME);
8086
if (headerValue != null) {
8187
m.put("server-side-header-value", headerValue);

mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package io.modelcontextprotocol.server;
66

77
import java.time.Duration;
8+
import java.util.HashMap;
89
import java.util.Map;
910
import java.util.stream.Stream;
1011

@@ -101,8 +102,12 @@ protected void prepareClients(int port, String mcpEndpoint) {
101102

102103
static McpTransportContextExtractor<HttpServletRequest> TEST_CONTEXT_EXTRACTOR = new HttpServletRequestMcpTransportContextExtractor() {
103104
@Override
104-
protected Map<String, Object> metadata(HttpServletRequest r) {
105-
Map<String, Object> m = super.metadata(r);
105+
public McpTransportContext extract(HttpServletRequest request) {
106+
return McpTransportContext.create(metadata(request));
107+
}
108+
109+
Map<String, Object> metadata(HttpServletRequest r) {
110+
Map<String, Object> m = new HashMap<>(McpTransportContext.createMetadata(r::getHeader));
106111
m.put("important", "value");
107112
return m;
108113
}

mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
package io.modelcontextprotocol.server;
66

77
import java.time.Duration;
8+
import java.util.HashMap;
89
import java.util.Map;
910
import java.util.stream.Stream;
1011

1112
import io.modelcontextprotocol.client.McpClient;
1213
import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport;
14+
import io.modelcontextprotocol.common.McpTransportContext;
1315
import io.modelcontextprotocol.server.McpServer.AsyncSpecification;
1416
import io.modelcontextprotocol.server.McpServer.SyncSpecification;
1517
import io.modelcontextprotocol.server.servlet.HttpServletRequestMcpTransportContextExtractor;
@@ -98,8 +100,12 @@ protected void prepareClients(int port, String mcpEndpoint) {
98100

99101
static McpTransportContextExtractor<HttpServletRequest> TEST_CONTEXT_EXTRACTOR = new HttpServletRequestMcpTransportContextExtractor() {
100102
@Override
101-
protected Map<String, Object> metadata(HttpServletRequest r) {
102-
Map<String, Object> m = super.metadata(r);
103+
public McpTransportContext extract(HttpServletRequest request) {
104+
return McpTransportContext.create(metadata(request));
105+
}
106+
107+
private Map<String, Object> metadata(HttpServletRequest r) {
108+
Map<String, Object> m = new HashMap<>(McpTransportContext.createMetadata(r::getHeader));
103109
m.put("important", "value");
104110
return m;
105111
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
/*
2+
* Copyright 2024-2024 the original author or authors.
3+
*/
4+
package io.modelcontextprotocol.server.transport;
5+
6+
import io.modelcontextprotocol.common.McpTransportContext;
7+
import io.modelcontextprotocol.server.McpTransportContextExtractor;
8+
import org.springframework.web.reactive.function.server.ServerRequest;
9+
10+
/**
11+
* {@link McpTransportContextExtractor} implementation for {@link ServerRequest}.
12+
*/
13+
public class McpTransportContextExtractorServerRequest implements McpTransportContextExtractor<ServerRequest> {
14+
15+
@Override
16+
public McpTransportContext extract(ServerRequest request) {
17+
return McpTransportContext
18+
.create(McpTransportContext.createMetadata(headerName -> request.headers().firstHeader(headerName)));
19+
}
20+
21+
}

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -418,8 +418,7 @@ public static class Builder {
418418

419419
private Duration keepAliveInterval;
420420

421-
private McpTransportContextExtractor<ServerRequest> contextExtractor = (
422-
serverRequest) -> McpTransportContext.EMPTY;
421+
private McpTransportContextExtractor<ServerRequest> contextExtractor;
423422

424423
/**
425424
* Sets the McpJsonMapper to use for JSON serialization/deserialization of MCP
@@ -507,7 +506,8 @@ public Builder contextExtractor(McpTransportContextExtractor<ServerRequest> cont
507506
public WebFluxSseServerTransportProvider build() {
508507
Assert.notNull(messageEndpoint, "Message endpoint must be set");
509508
return new WebFluxSseServerTransportProvider(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper,
510-
baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval, contextExtractor);
509+
baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval,
510+
contextExtractor == null ? new McpTransportContextExtractorServerRequest() : contextExtractor);
511511
}
512512

513513
}

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStatelessServerTransport.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,7 @@ public static class Builder {
157157

158158
private String mcpEndpoint = "/mcp";
159159

160-
private McpTransportContextExtractor<ServerRequest> contextExtractor = (
161-
serverRequest) -> McpTransportContext.EMPTY;
160+
private McpTransportContextExtractor<ServerRequest> contextExtractor;
162161

163162
private Builder() {
164163
// used by a static method
@@ -214,7 +213,8 @@ public Builder contextExtractor(McpTransportContextExtractor<ServerRequest> cont
214213
public WebFluxStatelessServerTransport build() {
215214
Assert.notNull(mcpEndpoint, "Message endpoint must be set");
216215
return new WebFluxStatelessServerTransport(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper,
217-
mcpEndpoint, contextExtractor);
216+
mcpEndpoint,
217+
contextExtractor == null ? new McpTransportContextExtractorServerRequest() : contextExtractor);
218218
}
219219

220220
}

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -403,8 +403,7 @@ public static class Builder {
403403

404404
private String mcpEndpoint = "/mcp";
405405

406-
private McpTransportContextExtractor<ServerRequest> contextExtractor = (
407-
serverRequest) -> McpTransportContext.EMPTY;
406+
private McpTransportContextExtractor<ServerRequest> contextExtractor;
408407

409408
private boolean disallowDelete;
410409

@@ -486,7 +485,8 @@ public Builder keepAliveInterval(Duration keepAliveInterval) {
486485
public WebFluxStreamableServerTransportProvider build() {
487486
Assert.notNull(mcpEndpoint, "Message endpoint must be set");
488487
return new WebFluxStreamableServerTransportProvider(
489-
jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, mcpEndpoint, contextExtractor,
488+
jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, mcpEndpoint,
489+
contextExtractor == null ? new McpTransportContextExtractorServerRequest() : contextExtractor,
490490
disallowDelete, keepAliveInterval);
491491
}
492492

0 commit comments

Comments
 (0)