@@ -72,11 +72,12 @@ async def main():
7272import warnings
7373from collections .abc import AsyncIterator , Awaitable , Callable , Iterable
7474from contextlib import AbstractAsyncContextManager , AsyncExitStack , asynccontextmanager
75- from typing import Any , Generic , TypeVar
75+ from typing import Any , Generic
7676
7777import anyio
7878from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
7979from pydantic import AnyUrl
80+ from typing_extensions import TypeVar
8081
8182import mcp .types as types
8283from mcp .server .lowlevel .helper_types import ReadResourceContents
@@ -85,15 +86,16 @@ async def main():
8586from mcp .server .stdio import stdio_server as stdio_server
8687from mcp .shared .context import RequestContext
8788from mcp .shared .exceptions import McpError
88- from mcp .shared .message import SessionMessage
89+ from mcp .shared .message import ServerMessageMetadata , SessionMessage
8990from mcp .shared .session import RequestResponder
9091
9192logger = logging .getLogger (__name__ )
9293
9394LifespanResultT = TypeVar ("LifespanResultT" )
95+ RequestT = TypeVar ("RequestT" , default = Any )
9496
9597# This will be properly typed in each Server instance's context
96- request_ctx : contextvars .ContextVar [RequestContext [ServerSession , Any ]] = (
98+ request_ctx : contextvars .ContextVar [RequestContext [ServerSession , Any , Any ]] = (
9799 contextvars .ContextVar ("request_ctx" )
98100)
99101
@@ -111,7 +113,7 @@ def __init__(
111113
112114
113115@asynccontextmanager
114- async def lifespan (server : Server [LifespanResultT ]) -> AsyncIterator [object ]:
116+ async def lifespan (server : Server [LifespanResultT , RequestT ]) -> AsyncIterator [object ]:
115117 """Default lifespan context manager that does nothing.
116118
117119 Args:
@@ -123,14 +125,15 @@ async def lifespan(server: Server[LifespanResultT]) -> AsyncIterator[object]:
123125 yield {}
124126
125127
126- class Server (Generic [LifespanResultT ]):
128+ class Server (Generic [LifespanResultT , RequestT ]):
127129 def __init__ (
128130 self ,
129131 name : str ,
130132 version : str | None = None ,
131133 instructions : str | None = None ,
132134 lifespan : Callable [
133- [Server [LifespanResultT ]], AbstractAsyncContextManager [LifespanResultT ]
135+ [Server [LifespanResultT , RequestT ]],
136+ AbstractAsyncContextManager [LifespanResultT ],
134137 ] = lifespan ,
135138 ):
136139 self .name = name
@@ -215,7 +218,9 @@ def get_capabilities(
215218 )
216219
217220 @property
218- def request_context (self ) -> RequestContext [ServerSession , LifespanResultT ]:
221+ def request_context (
222+ self ,
223+ ) -> RequestContext [ServerSession , LifespanResultT , RequestT ]:
219224 """If called outside of a request context, this will raise a LookupError."""
220225 return request_ctx .get ()
221226
@@ -555,6 +560,13 @@ async def _handle_request(
555560
556561 token = None
557562 try :
563+ # Extract request context from message metadata
564+ request_data = None
565+ if message .message_metadata is not None and isinstance (
566+ message .message_metadata , ServerMessageMetadata
567+ ):
568+ request_data = message .message_metadata .request_context
569+
558570 # Set our global state that can be retrieved via
559571 # app.get_request_context()
560572 token = request_ctx .set (
@@ -563,6 +575,7 @@ async def _handle_request(
563575 message .request_meta ,
564576 session ,
565577 lifespan_context ,
578+ request = request_data ,
566579 )
567580 )
568581 response = await handler (req )
0 commit comments