diff --git a/src/a2a/server/apps/starlette_app.py b/src/a2a/server/apps/starlette_app.py index 84ef7577..fb7a9908 100644 --- a/src/a2a/server/apps/starlette_app.py +++ b/src/a2a/server/apps/starlette_app.py @@ -45,8 +45,20 @@ logger = logging.getLogger(__name__) -# Register Starlette User as an implementation of a2a.auth.user.User -A2AUser.register(BaseUser) + +class StarletteUserProxy(A2AUser): + """Adapts the Starlette User class to the A2A user representation.""" + + def __init__(self, user: BaseUser): + self._user = user + + @property + def is_authenticated(self): + return self._user.is_authenticated + + @property + def user_name(self): + return self._user.display_name class CallContextBuilder(ABC): @@ -64,7 +76,7 @@ def build(self, request: Request) -> ServerCallContext: user = UnauthenticatedUser() state = {} with contextlib.suppress(Exception): - user = request.user + user = StarletteUserProxy(request.user) state['auth'] = request.auth return ServerCallContext(user=user, state=state) diff --git a/tests/server/test_integration.py b/tests/server/test_integration.py index 4493d300..9f45e816 100644 --- a/tests/server/test_integration.py +++ b/tests/server/test_integration.py @@ -5,6 +5,15 @@ import pytest +from starlette.authentication import ( + AuthCredentials, + AuthenticationBackend, + BaseUser, + SimpleUser, +) +from starlette.middleware import Middleware +from starlette.middleware.authentication import AuthenticationMiddleware +from starlette.requests import HTTPConnection from starlette.responses import JSONResponse from starlette.routing import Route from starlette.testclient import TestClient @@ -18,8 +27,12 @@ InternalError, InvalidRequestError, JSONParseError, + Message, Part, PushNotificationConfig, + Role, + SendMessageResponse, + SendMessageSuccessResponse, Task, TaskArtifactUpdateEvent, TaskPushNotificationConfig, @@ -121,9 +134,9 @@ def app(agent_card: AgentCard, handler: mock.AsyncMock): @pytest.fixture -def client(app: A2AStarletteApplication): +def client(app: A2AStarletteApplication, **kwargs): """Create a test client with the app.""" - return TestClient(app.build()) + return TestClient(app.build(**kwargs)) # === BASIC FUNCTIONALITY TESTS === @@ -249,7 +262,6 @@ def test_send_message(client: TestClient, handler: mock.AsyncMock): mock_task = Task( id='task1', contextId='session-xyz', - state='completed', status=task_status, ) handler.on_message_send.return_value = mock_task @@ -418,6 +430,67 @@ def test_get_push_notification_config( handler.on_get_task_push_notification_config.assert_awaited_once() +def test_server_auth(app: A2AStarletteApplication, handler: mock.AsyncMock): + class TestAuthMiddleware(AuthenticationBackend): + async def authenticate( + self, conn: HTTPConnection + ) -> tuple[AuthCredentials, BaseUser] | None: + # For the purposes of this test, all requests are authenticated! + return (AuthCredentials(['authenticated']), SimpleUser('test_user')) + + client = TestClient( + app.build( + middleware=[ + Middleware( + AuthenticationMiddleware, backend=TestAuthMiddleware() + ) + ] + ) + ) + + # Set the output message to be the authenticated user name + handler.on_message_send.side_effect = lambda params, context: Message( + contextId='session-xyz', + messageId='112', + role=Role.agent, + parts=[ + Part(TextPart(text=context.user.user_name)), + ], + ) + + # Send request + response = client.post( + '/', + json={ + 'jsonrpc': '2.0', + 'id': '123', + 'method': 'message/send', + 'params': { + 'message': { + 'role': 'agent', + 'parts': [{'kind': 'text', 'text': 'Hello'}], + 'messageId': '111', + 'kind': 'message', + 'taskId': 'task1', + 'contextId': 'session-xyz', + } + }, + }, + ) + + # Verify response + assert response.status_code == 200 + result = SendMessageResponse.model_validate(response.json()) + assert isinstance(result.root, SendMessageSuccessResponse) + assert isinstance(result.root.result, Message) + message = result.root.result + assert isinstance(message.parts[0].root, TextPart) + assert message.parts[0].root.text == 'test_user' + + # Verify handler was called + handler.on_message_send.assert_awaited_once() + + # === STREAMING TESTS ===