11"""Test that cancelled requests don't cause double responses.""" 
22
3- import  asyncio 
4- from  unittest .mock  import  MagicMock 
5- 
3+ import  anyio 
64import  pytest 
75
86import  mcp .types  as  types 
97from  mcp .server .lowlevel .server  import  Server 
10- from  mcp .types  import  PingRequest 
11- 
12- 
13- # Shared mock class 
14- class  MockRequestResponder :
15-     def  __init__ (self ):
16-         self .request_id  =  "test-123" 
17-         self ._responded  =  False 
18-         self .request_meta  =  {}
19-         self .message_metadata  =  None 
20- 
21-     async  def  send (self , response ):
22-         if  self ._responded :
23-             raise  AssertionError (f"Request { self .request_id }   already responded to" )
24-         self ._responded  =  True 
25- 
26-     async  def  respond (self , response ):
27-         await  self .send (response )
28- 
29-     def  cancel (self ):
30-         """Simulate the cancel() method sending an error response.""" 
31-         asyncio .create_task (self .send (types .ErrorData (code = - 32800 , message = "Request cancelled" )))
8+ from  mcp .shared .exceptions  import  McpError 
9+ from  mcp .shared .memory  import  create_connected_server_and_client_session 
10+ from  mcp .types  import  (
11+     CallToolRequest ,
12+     CallToolRequestParams ,
13+     CallToolResult ,
14+     CancelledNotification ,
15+     CancelledNotificationParams ,
16+     ClientNotification ,
17+     ClientRequest ,
18+     Tool ,
19+ )
3220
3321
3422@pytest .mark .anyio  
3523async  def  test_cancelled_request_no_double_response ():
3624    """Verify server handles cancelled requests without double response.""" 
3725
38-     # Create a server instance  
26+     # Create server with a slow tool  
3927    server  =  Server ("test-server" )
4028
41-     # Track if multiple responses are attempted 
42-     response_count  =  0 
43- 
44-     # Override the send method to track calls 
45-     mock_message  =  MockRequestResponder ()
46-     original_send  =  mock_message .send 
47- 
48-     async  def  tracked_send (response ):
49-         nonlocal  response_count 
50-         response_count  +=  1 
51-         await  original_send (response )
52- 
53-     mock_message .send  =  tracked_send 
54- 
55-     # Create a slow handler that will be cancelled 
56-     async  def  slow_handler (req ):
57-         await  asyncio .sleep (10 )
58-         return  types .ServerResult (types .EmptyResult ())
59- 
60-     # Use PingRequest as it's a valid request type 
61-     server .request_handlers [types .PingRequest ] =  slow_handler 
62- 
63-     # Create mock message and session 
64-     mock_req  =  PingRequest (method = "ping" )
65-     mock_session  =  MagicMock ()
66-     mock_context  =  None 
67- 
68-     # Start the request 
69-     handle_task  =  asyncio .create_task (
70-         server ._handle_request (mock_message , mock_req , mock_session , mock_context , raise_exceptions = False )  # type: ignore 
71-     )
72- 
73-     # Give it time to start 
74-     await  asyncio .sleep (0.1 )
75- 
76-     # Simulate cancellation 
77-     mock_message .cancel ()
78-     handle_task .cancel ()
79- 
80-     # Wait for cancellation to propagate 
81-     try :
82-         await  handle_task 
83-     except  asyncio .CancelledError :
84-         pass 
85- 
86-     # Give time for any duplicate response attempts 
87-     await  asyncio .sleep (0.1 )
88- 
89-     # Should only have one response (from cancel()) 
90-     assert  response_count  ==  1 , f"Expected 1 response, got { response_count }  " 
29+     # Track when tool is called 
30+     ev_tool_called  =  anyio .Event ()
31+     request_id  =  None 
32+ 
33+     @server .list_tools () 
34+     async  def  handle_list_tools () ->  list [Tool ]:
35+         return  [
36+             Tool (
37+                 name = "slow_tool" ,
38+                 description = "A slow tool for testing cancellation" ,
39+                 inputSchema = {},
40+             )
41+         ]
42+ 
43+     @server .call_tool () 
44+     async  def  handle_call_tool (name : str , arguments : dict  |  None ) ->  list :
45+         nonlocal  request_id 
46+         if  name  ==  "slow_tool" :
47+             request_id  =  server .request_context .request_id 
48+             ev_tool_called .set ()
49+             await  anyio .sleep (10 )  # Long running operation 
50+             return  [types .TextContent (type = "text" , text = "Tool called" )]
51+         raise  ValueError (f"Unknown tool: { name }  " )
52+ 
53+     # Connect client to server 
54+     async  with  create_connected_server_and_client_session (server ) as  client :
55+         # Start the slow tool call in a separate task 
56+         async  def  make_request ():
57+             try :
58+                 await  client .send_request (
59+                     ClientRequest (
60+                         CallToolRequest (
61+                             method = "tools/call" ,
62+                             params = CallToolRequestParams (name = "slow_tool" , arguments = {}),
63+                         )
64+                     ),
65+                     CallToolResult ,
66+                 )
67+                 pytest .fail ("Request should have been cancelled" )
68+             except  McpError  as  e :
69+                 # Expected - request was cancelled 
70+                 assert  e .error .code  ==  0   # Request cancelled error code 
71+ 
72+         # Start the request 
73+         request_task  =  anyio .create_task_group ()
74+         async  with  request_task :
75+             request_task .start_soon (make_request )
76+ 
77+             # Wait for tool to start executing 
78+             await  ev_tool_called .wait ()
79+ 
80+             # Send cancellation notification 
81+             assert  request_id  is  not   None 
82+             await  client .send_notification (
83+                 ClientNotification (
84+                     CancelledNotification (
85+                         method = "notifications/cancelled" ,
86+                         params = CancelledNotificationParams (
87+                             requestId = request_id ,
88+                             reason = "Test cancellation" ,
89+                         ),
90+                     )
91+                 )
92+             )
93+ 
94+             # The request should be cancelled and raise McpError 
9195
9296
9397@pytest .mark .anyio  
@@ -96,43 +100,87 @@ async def test_server_remains_functional_after_cancel():
96100
97101    server  =  Server ("test-server" )
98102
99-     # Add handlers 
100-     async  def  slow_handler (req ):
101-         await  asyncio .sleep (5 )
102-         return  types .ServerResult (types .EmptyResult ())
103- 
104-     async  def  fast_handler (req ):
105-         return  types .ServerResult (types .EmptyResult ())
106- 
107-     # Override ping handler for our test 
108-     server .request_handlers [types .PingRequest ] =  slow_handler 
109- 
110-     # First request (will be cancelled) 
111-     mock_message1  =  MockRequestResponder ()
112-     mock_req1  =  PingRequest (method = "ping" )
113- 
114-     handle_task  =  asyncio .create_task (
115-         server ._handle_request (mock_message1 , mock_req1 , MagicMock (), None , raise_exceptions = False )  # type: ignore 
116-     )
117- 
118-     await  asyncio .sleep (0.1 )
119-     mock_message1 .cancel ()
120-     handle_task .cancel ()
121- 
122-     try :
123-         await  handle_task 
124-     except  asyncio .CancelledError :
125-         pass 
126- 
127-     # Change handler to fast one 
128-     server .request_handlers [types .PingRequest ] =  fast_handler 
129- 
130-     # Second request (should work normally) 
131-     mock_message2  =  MockRequestResponder ()
132-     mock_req2  =  PingRequest (method = "ping" )
133- 
134-     # This should complete successfully 
135-     await  server ._handle_request (mock_message2 , mock_req2 , MagicMock (), None , raise_exceptions = False )  # type: ignore 
136- 
137-     # Server handled the second request successfully 
138-     assert  mock_message2 ._responded 
103+     # Track tool calls 
104+     call_count  =  0 
105+     ev_first_call  =  anyio .Event ()
106+     first_request_id  =  None 
107+ 
108+     @server .list_tools () 
109+     async  def  handle_list_tools () ->  list [Tool ]:
110+         return  [
111+             Tool (
112+                 name = "test_tool" ,
113+                 description = "Tool for testing" ,
114+                 inputSchema = {},
115+             )
116+         ]
117+ 
118+     @server .call_tool () 
119+     async  def  handle_call_tool (name : str , arguments : dict  |  None ) ->  list :
120+         nonlocal  call_count , first_request_id 
121+         if  name  ==  "test_tool" :
122+             call_count  +=  1 
123+             if  call_count  ==  1 :
124+                 first_request_id  =  server .request_context .request_id 
125+                 ev_first_call .set ()
126+                 await  anyio .sleep (5 )  # First call is slow 
127+             return  [types .TextContent (type = "text" , text = f"Call number: { call_count }  " )]
128+         raise  ValueError (f"Unknown tool: { name }  " )
129+ 
130+     async  with  create_connected_server_and_client_session (server ) as  client :
131+         # First request (will be cancelled) 
132+         async  def  first_request ():
133+             try :
134+                 await  client .send_request (
135+                     ClientRequest (
136+                         CallToolRequest (
137+                             method = "tools/call" ,
138+                             params = CallToolRequestParams (name = "test_tool" , arguments = {}),
139+                         )
140+                     ),
141+                     CallToolResult ,
142+                 )
143+                 pytest .fail ("First request should have been cancelled" )
144+             except  McpError :
145+                 pass   # Expected 
146+ 
147+         # Start first request 
148+         async  with  anyio .create_task_group () as  tg :
149+             tg .start_soon (first_request )
150+ 
151+             # Wait for it to start 
152+             await  ev_first_call .wait ()
153+ 
154+             # Cancel it 
155+             assert  first_request_id  is  not   None 
156+             await  client .send_notification (
157+                 ClientNotification (
158+                     CancelledNotification (
159+                         method = "notifications/cancelled" ,
160+                         params = CancelledNotificationParams (
161+                             requestId = first_request_id ,
162+                             reason = "Testing server recovery" ,
163+                         ),
164+                     )
165+                 )
166+             )
167+ 
168+         # Second request (should work normally) 
169+         result  =  await  client .send_request (
170+             ClientRequest (
171+                 CallToolRequest (
172+                     method = "tools/call" ,
173+                     params = CallToolRequestParams (name = "test_tool" , arguments = {}),
174+                 )
175+             ),
176+             CallToolResult ,
177+         )
178+ 
179+         # Verify second request completed successfully 
180+         assert  len (result .content ) ==  1 
181+         # Type narrowing for pyright 
182+         content  =  result .content [0 ]
183+         assert  content .type  ==  "text" 
184+         assert  isinstance (content , types .TextContent )
185+         assert  content .text  ==  "Call number: 2" 
186+         assert  call_count  ==  2 
0 commit comments