Skip to content

feat: add message to ProgressNotification #435

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,11 @@ async def send_ping(self) -> types.EmptyResult:
)

async def send_progress_notification(
self, progress_token: str | int, progress: float, total: float | None = None
self,
progress_token: str | int,
progress: float,
total: float | None = None,
message: str | None = None,
) -> None:
"""Send a progress notification."""
await self.send_notification(
Expand All @@ -178,6 +182,7 @@ async def send_progress_notification(
progressToken=progress_token,
progress=progress,
total=total,
message=message,
),
),
)
Expand Down
8 changes: 6 additions & 2 deletions src/mcp/server/fastmcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,13 +621,14 @@ def request_context(self) -> RequestContext[ServerSessionT, LifespanContextT]:
return self._request_context

async def report_progress(
self, progress: float, total: float | None = None
self, progress: float, total: float | None = None, message: str | None = None
) -> None:
"""Report progress for the current operation.

Args:
progress: Current progress value e.g. 24
total: Optional total value e.g. 100
message: Optional message e.g. Starting render...
"""

progress_token = (
Expand All @@ -640,7 +641,10 @@ async def report_progress(
return

await self.request_context.session.send_progress_notification(
progress_token=progress_token, progress=progress, total=total
progress_token=progress_token,
progress=progress,
total=total,
message=message,
)

async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContents]:
Expand Down
12 changes: 9 additions & 3 deletions src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ async def handle_list_resource_templates() -> list[types.ResourceTemplate]:
3. Define notification handlers if needed:
@server.progress_notification()
async def handle_progress(
progress_token: str | int, progress: float, total: float | None
progress_token: str | int, progress: float, total: float | None,
message: str | None
) -> None:
# Implementation

Expand Down Expand Up @@ -426,13 +427,18 @@ async def handler(req: types.CallToolRequest):

def progress_notification(self):
def decorator(
func: Callable[[str | int, float, float | None], Awaitable[None]],
func: Callable[
[str | int, float, float | None, str | None], Awaitable[None]
],
):
logger.debug("Registering handler for ProgressNotification")

async def handler(req: types.ProgressNotification):
await func(
req.params.progressToken, req.params.progress, req.params.total
req.params.progressToken,
req.params.progress,
req.params.total,
req.params.message,
)

self.notification_handlers[types.ProgressNotification] = handler
Expand Down
7 changes: 6 additions & 1 deletion src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,11 @@ async def send_ping(self) -> types.EmptyResult:
)

async def send_progress_notification(
self, progress_token: str | int, progress: float, total: float | None = None
self,
progress_token: str | int,
progress: float,
total: float | None = None,
message: str | None = None,
) -> None:
"""Send a progress notification."""
await self.send_notification(
Expand All @@ -272,6 +276,7 @@ async def send_progress_notification(
progressToken=progress_token,
progress=progress,
total=total,
message=message,
),
)
)
Expand Down
5 changes: 3 additions & 2 deletions src/mcp/shared/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,13 @@ class ProgressContext(
progress_token: ProgressToken
total: float | None
current: float = field(default=0.0, init=False)
message: str | None

async def progress(self, amount: float) -> None:
self.current += amount

await self.session.send_progress_notification(
self.progress_token, self.current, total=self.total
self.progress_token, self.current, total=self.total, message=self.message
)


Expand Down Expand Up @@ -77,7 +78,7 @@ def progress(
if ctx.meta is None or ctx.meta.progressToken is None:
raise ValueError("No progress token provided")

progress_ctx = ProgressContext(ctx.session, ctx.meta.progressToken, total)
progress_ctx = ProgressContext(ctx.session, ctx.meta.progressToken, total, None)
try:
yield progress_ctx
finally:
Expand Down
6 changes: 5 additions & 1 deletion src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,11 @@ async def _received_notification(self, notification: ReceiveNotificationT) -> No
"""

async def send_progress_notification(
self, progress_token: str | int, progress: float, total: float | None = None
self,
progress_token: str | int,
progress: float,
total: float | None = None,
message: str | None = None,
) -> None:
"""
Sends a progress notification for a request that is currently being
Expand Down
5 changes: 5 additions & 0 deletions src/mcp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,11 @@ class ProgressNotificationParams(NotificationParams):
total is unknown.
"""
total: float | None = None
"""
Message related to progress. This should provide relevant human readable
progress information.
"""
message: str | None = None
"""Total number of items to process (or total progress required), if known."""
model_config = ConfigDict(extra="allow")

Expand Down
6 changes: 3 additions & 3 deletions tests/issues/test_176_progress_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ async def test_progress_token_zero_first_call():
mock_session.send_progress_notification.call_count == 3
), "All progress notifications should be sent"
mock_session.send_progress_notification.assert_any_call(
progress_token=0, progress=0.0, total=10.0
progress_token=0, progress=0.0, total=10.0, message=None
)
mock_session.send_progress_notification.assert_any_call(
progress_token=0, progress=5.0, total=10.0
progress_token=0, progress=5.0, total=10.0, message=None
)
mock_session.send_progress_notification.assert_any_call(
progress_token=0, progress=10.0, total=10.0
progress_token=0, progress=10.0, total=10.0, message=None
)
214 changes: 214 additions & 0 deletions tests/shared/test_progress_notifications.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
import anyio
import pytest

import mcp.types as types
from mcp.client.session import ClientSession
from mcp.server import Server
from mcp.server.lowlevel import NotificationOptions
from mcp.server.models import InitializationOptions
from mcp.server.session import ServerSession
from mcp.shared.session import RequestResponder
from mcp.types import (
JSONRPCMessage,
)


@pytest.mark.anyio
async def test_bidirectional_progress_notifications():
"""Test that both client and server can send progress notifications."""
# Create memory streams for client/server
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
JSONRPCMessage
](5)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
JSONRPCMessage
](5)

# Run a server session so we can send progress updates in tool
async def run_server():
# Create a server session
async with ServerSession(
client_to_server_receive,
server_to_client_send,
InitializationOptions(
server_name="ProgressTestServer",
server_version="0.1.0",
capabilities=server.get_capabilities(NotificationOptions(), {}),
),
) as server_session:
global serv_sesh

serv_sesh = server_session
async for message in server_session.incoming_messages:
try:
await server._handle_message(message, server_session, ())
except Exception as e:
raise e

# Track progress updates
server_progress_updates = []
client_progress_updates = []

# Progress tokens
server_progress_token = "server_token_123"
client_progress_token = "client_token_456"

# Create a server with progress capability
server = Server(name="ProgressTestServer")

# Register progress handler
@server.progress_notification()
async def handle_progress(
progress_token: str | int,
progress: float,
total: float | None,
message: str | None,
):
server_progress_updates.append(
{
"token": progress_token,
"progress": progress,
"total": total,
"message": message,
}
)

# Register list tool handler
@server.list_tools()
async def handle_list_tools() -> list[types.Tool]:
return [
types.Tool(
name="test_tool",
description="A tool that sends progress notifications <o/",
inputSchema={},
)
]

# Register tool handler
@server.call_tool()
async def handle_call_tool(name: str, arguments: dict | None) -> list:
# Make sure we received a progress token
if name == "test_tool":
if arguments and "_meta" in arguments:
progressToken = arguments["_meta"]["progressToken"]

if not progressToken:
raise ValueError("Empty progress token received")

if progressToken != client_progress_token:
raise ValueError("Server sending back incorrect progressToken")

# Send progress notifications
await serv_sesh.send_progress_notification(
progress_token=progressToken,
progress=0.25,
total=1.0,
message="Server progress 25%",
)
await anyio.sleep(0.2)

await serv_sesh.send_progress_notification(
progress_token=progressToken,
progress=0.5,
total=1.0,
message="Server progress 50%",
)
await anyio.sleep(0.2)

await serv_sesh.send_progress_notification(
progress_token=progressToken,
progress=1.0,
total=1.0,
message="Server progress 100%",
)

else:
raise ValueError("Progress token not sent.")

return ["Tool executed successfully"]

raise ValueError(f"Unknown tool: {name}")

# Client message handler to store progress notifications
async def handle_client_message(
message: RequestResponder[types.ServerRequest, types.ClientResult]
| types.ServerNotification
| Exception,
) -> None:
if isinstance(message, Exception):
raise message

if isinstance(message, types.ServerNotification):
if isinstance(message.root, types.ProgressNotification):
params = message.root.params
client_progress_updates.append(
{
"token": params.progressToken,
"progress": params.progress,
"total": params.total,
"message": params.message,
}
)

# Test using client
async with (
ClientSession(
server_to_client_receive,
client_to_server_send,
message_handler=handle_client_message,
) as client_session,
anyio.create_task_group() as tg,
):
# Start the server in a background task
tg.start_soon(run_server)

# Initialize the client connection
await client_session.initialize()

# Call list_tools with progress token
await client_session.list_tools()

# Call test_tool with progress token
await client_session.call_tool(
"test_tool", {"_meta": {"progressToken": client_progress_token}}
)

# Send progress notifications from client to server
await client_session.send_progress_notification(
progress_token=server_progress_token,
progress=0.33,
total=1.0,
message="Client progress 33%",
)

await client_session.send_progress_notification(
progress_token=server_progress_token,
progress=0.66,
total=1.0,
message="Client progress 66%",
)

await client_session.send_progress_notification(
progress_token=server_progress_token,
progress=1.0,
total=1.0,
message="Client progress 100%",
)

# Wait and exit
await anyio.sleep(1.0)
tg.cancel_scope.cancel()

# Verify client received progress updates from server
assert len(client_progress_updates) == 3
assert client_progress_updates[0]["token"] == client_progress_token
assert client_progress_updates[0]["progress"] == 0.25
assert client_progress_updates[0]["message"] == "Server progress 25%"
assert client_progress_updates[2]["progress"] == 1.0

# Verify server received progress updates from client
assert len(server_progress_updates) == 3
assert server_progress_updates[0]["token"] == server_progress_token
assert server_progress_updates[0]["progress"] == 0.33
assert server_progress_updates[0]["message"] == "Client progress 33%"
assert server_progress_updates[2]["progress"] == 1.0
Loading