29 changes: 28 additions & 1 deletion src/mcp/server/sse.py
@@ -75,10 +75,37 @@ class SseServerTransport:
def __init__(self, endpoint: str) -> None:
"""
Creates a new SSE server transport, which will direct the client to POST
messages to the relative or absolute URL given.
messages to the relative path given.

Args:
endpoint: A relative path where messages should be posted
(e.g., "/messages/").

Note:
We use relative paths instead of full URLs for several reasons:
1. Security: Prevents cross-origin requests by ensuring clients only connect
to the same origin they established the SSE connection with
2. Flexibility: The server can be mounted at any path without needing to
know its full URL
3. Portability: The same endpoint configuration works across different
environments (development, staging, production)

Raises:
ValueError: If the endpoint is a full URL instead of a relative path
"""

super().__init__()

# Validate that endpoint is a relative path and not a full URL
if "://" in endpoint or endpoint.startswith("//"):
raise ValueError(
"Endpoint must be a relative path (e.g., '/messages/'), not a full URL."
)

# Ensure endpoint starts with a forward slash
if not endpoint.startswith("/"):
endpoint = "/" + endpoint

self._endpoint = endpoint
self._read_stream_writers = {}
logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}")
5 changes: 1 addition & 4 deletions tests/client/test_auth.py
@@ -10,7 +10,6 @@

import httpx
import pytest
from inline_snapshot import snapshot
from pydantic import AnyHttpUrl

from mcp.client.auth import OAuthClientProvider
@@ -968,8 +967,7 @@ def test_build_metadata(
revocation_options=RevocationOptions(enabled=True),
)

assert metadata == snapshot(
OAuthMetadata(
assert metadata == OAuthMetadata(
Member:
What is the error you were having?

Contributor Author:
When I run uv run pytest I get the following error:
[screenshot of the error]

As a fix I tried running uv run pytest --inline-snapshot=fix tests/client/test_auth.py, but got the same error again. Alternatively, I tried using Is for the parametrised variables, but that did not seem fully correct either.

Contributor Author:
A few observations:

  1. This test passes in the GitHub workflow.
  2. Locally, when I delete __pycache__ and run uv run pytest tests/client/test_auth.py, it passes.
  3. Locally, when I run uv run pytest, it fails again.

Since there does not seem to be an issue in the workflow, I have reverted this change; alternatively, we could create an explicit snapshot for each parameter, similar to what is mentioned here.

issuer=AnyHttpUrl(issuer_url),
authorization_endpoint=AnyHttpUrl(authorization_endpoint),
token_endpoint=AnyHttpUrl(token_endpoint),
@@ -982,4 +980,3 @@ def test_build_metadata(
revocation_endpoint_auth_methods_supported=["client_secret_post"],
code_challenge_methods_supported=["S256"],
)
)
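A hedged sketch of the Is-based alternative discussed in the comments above, assuming inline_snapshot exposes Is for wrapping runtime values inside a snapshot; the test name, parameter, and metadata fields here are illustrative only, not the repository's actual test:

```python
import pytest
from inline_snapshot import Is, snapshot


@pytest.mark.parametrize("issuer_url", ["https://auth.example.com"])
def test_metadata_snapshot(issuer_url):
    metadata = {
        "issuer": issuer_url,
        "code_challenge_methods_supported": ["S256"],
    }
    # Is() marks the value as dynamic, so the stored snapshot stays valid
    # across parametrized runs instead of being rewritten for each parameter.
    assert metadata == snapshot(
        {
            "issuer": Is(issuer_url),
            "code_challenge_methods_supported": ["S256"],
        }
    )
```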
23 changes: 16 additions & 7 deletions tests/issues/test_188_concurrency.py
@@ -14,29 +14,38 @@
@pytest.mark.anyio
async def test_messages_are_executed_concurrently():
server = FastMCP("test")

call_timestamps = []

@server.tool("sleep")
async def sleep_tool():
call_timestamps.append(("tool_start_time", anyio.current_time()))
await anyio.sleep(_sleep_time_seconds)
call_timestamps.append(("tool_end_time", anyio.current_time()))
return "done"

@server.resource(_resource_name)
async def slow_resource():
call_timestamps.append(("resource_start_time", anyio.current_time()))
await anyio.sleep(_sleep_time_seconds)
call_timestamps.append(("resource_end_time", anyio.current_time()))
return "slow"

async with create_session(server._mcp_server) as client_session:
start_time = anyio.current_time()
async with anyio.create_task_group() as tg:
for _ in range(10):
tg.start_soon(client_session.call_tool, "sleep")
tg.start_soon(client_session.read_resource, AnyUrl(_resource_name))

end_time = anyio.current_time()

duration = end_time - start_time
assert duration < 6 * _sleep_time_seconds
print(duration)
active_calls = 0
max_concurrent_calls = 0
for call_type, _ in sorted(call_timestamps, key=lambda x: x[1]):
if "start" in call_type:
active_calls += 1
max_concurrent_calls = max(max_concurrent_calls, active_calls)
else:
active_calls -= 1
print(f"Max concurrent calls: {max_concurrent_calls}")
assert max_concurrent_calls > 1, "No concurrent calls were executed"


def main():
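The assertion in the rewritten concurrency test is essentially a sweep over the sorted (event, timestamp) pairs; a standalone sketch of that logic with illustrative timestamps:

```python
def max_concurrency(call_timestamps):
    """Return the peak number of overlapping calls given (event, timestamp) pairs."""
    active = 0
    peak = 0
    for event, _ in sorted(call_timestamps, key=lambda pair: pair[1]):
        if "start" in event:
            active += 1
            peak = max(peak, active)
        else:
            active -= 1
    return peak


# Illustrative data: the resource read starts before the tool call finishes,
# so the two calls overlap and the peak concurrency is 2.
events = [
    ("tool_start_time", 0.00),
    ("resource_start_time", 0.05),
    ("tool_end_time", 1.00),
    ("resource_end_time", 1.05),
]
assert max_concurrency(events) == 2
```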