33from contextlib import AsyncExitStack
44from datetime import timedelta
55from types import TracebackType
6- from typing import Any , Generic , TypeVar
6+ from typing import Any , Generic , Protocol , TypeVar
77
88import anyio
99import httpx
2424 JSONRPCNotification ,
2525 JSONRPCRequest ,
2626 JSONRPCResponse ,
27+ ProgressNotification ,
2728 RequestParams ,
2829 ServerNotification ,
2930 ServerRequest ,
4243RequestId = str | int
4344
4445
46+ class ProgressFnT (Protocol ):
47+ """Protocol for progress notification callbacks."""
48+
49+ async def __call__ (
50+ self , progress : float , total : float | None , message : str | None
51+ ) -> None : ...
52+
53+
4554class RequestResponder (Generic [ReceiveRequestT , SendResultT ]):
4655 """Handles responding to MCP requests and manages request lifecycle.
4756
@@ -169,6 +178,7 @@ class BaseSession(
169178 ]
170179 _request_id : int
171180 _in_flight : dict [RequestId , RequestResponder [ReceiveRequestT , SendResultT ]]
181+ _progress_callbacks : dict [RequestId , ProgressFnT ]
172182
173183 def __init__ (
174184 self ,
@@ -187,6 +197,7 @@ def __init__(
187197 self ._receive_notification_type = receive_notification_type
188198 self ._session_read_timeout_seconds = read_timeout_seconds
189199 self ._in_flight = {}
200+ self ._progress_callbacks = {}
190201 self ._exit_stack = AsyncExitStack ()
191202
192203 async def __aenter__ (self ) -> Self :
@@ -214,6 +225,7 @@ async def send_request(
214225 result_type : type [ReceiveResultT ],
215226 request_read_timeout_seconds : timedelta | None = None ,
216227 metadata : MessageMetadata = None ,
228+ progress_callback : ProgressFnT | None = None ,
217229 ) -> ReceiveResultT :
218230 """
219231 Sends a request and wait for a response. Raises an McpError if the
@@ -231,15 +243,25 @@ async def send_request(
231243 ](1 )
232244 self ._response_streams [request_id ] = response_stream
233245
246+ # Set up progress token if progress callback is provided
247+ request_data = request .model_dump (by_alias = True , mode = "json" , exclude_none = True )
248+ if progress_callback is not None :
249+ # Use request_id as progress token
250+ if "params" not in request_data :
251+ request_data ["params" ] = {}
252+ if "_meta" not in request_data ["params" ]:
253+ request_data ["params" ]["_meta" ] = {}
254+ request_data ["params" ]["_meta" ]["progressToken" ] = request_id
255+ # Store the callback for this request
256+ self ._progress_callbacks [request_id ] = progress_callback
257+
234258 try :
235259 jsonrpc_request = JSONRPCRequest (
236260 jsonrpc = "2.0" ,
237261 id = request_id ,
238- ** request . model_dump ( by_alias = True , mode = "json" , exclude_none = True ) ,
262+ ** request_data ,
239263 )
240264
241- # TODO: Support progress callbacks
242-
243265 await self ._write_stream .send (
244266 SessionMessage (
245267 message = JSONRPCMessage (jsonrpc_request ), metadata = metadata
@@ -275,6 +297,7 @@ async def send_request(
275297
276298 finally :
277299 self ._response_streams .pop (request_id , None )
300+ self ._progress_callbacks .pop (request_id , None )
278301 await response_stream .aclose ()
279302 await response_stream_reader .aclose ()
280303
@@ -333,7 +356,6 @@ async def _receive_loop(self) -> None:
333356 by_alias = True , mode = "json" , exclude_none = True
334357 )
335358 )
336-
337359 responder = RequestResponder (
338360 request_id = message .message .root .id ,
339361 request_meta = validated_request .root .params .meta
@@ -363,6 +385,18 @@ async def _receive_loop(self) -> None:
363385 if cancelled_id in self ._in_flight :
364386 await self ._in_flight [cancelled_id ].cancel ()
365387 else :
388+ # Handle progress notifications callback
389+ if isinstance (notification .root , ProgressNotification ):
390+ progress_token = notification .root .params .progressToken
391+ # If there is a progress callback for this token,
392+ # call it with the progress information
393+ if progress_token in self ._progress_callbacks :
394+ callback = self ._progress_callbacks [progress_token ]
395+ await callback (
396+ notification .root .params .progress ,
397+ notification .root .params .total ,
398+ notification .root .params .message ,
399+ )
366400 await self ._received_notification (notification )
367401 await self ._handle_incoming (notification )
368402 except Exception as e :
0 commit comments