diff --git a/crates/rmcp/src/service/client.rs b/crates/rmcp/src/service/client.rs index dba7d048..6ac4140d 100644 --- a/crates/rmcp/src/service/client.rs +++ b/crates/rmcp/src/service/client.rs @@ -75,18 +75,59 @@ where } /// Helper function to expect a response from the stream -async fn expect_response( +async fn expect_response( transport: &mut T, context: &str, + service: &S, + peer: Peer, ) -> Result<(ServerResult, RequestId), ClientInitializeError> where T: Transport, + S: Service, { - let msg = expect_next_message(transport, context).await?; - - match msg { - ServerJsonRpcMessage::Response(JsonRpcResponse { id, result, .. }) => Ok((result, id)), - _ => Err(ClientInitializeError::ExpectedInitResponse(Some(msg))), + loop { + let message = expect_next_message(transport, context).await?; + match message { + // Expected message to complete the initialization + ServerJsonRpcMessage::Response(JsonRpcResponse { id, result, .. }) => { + break Ok((result, id)); + } + // Server could send logging messages before handshake + ServerJsonRpcMessage::Notification(mut notification) => { + let ServerNotification::LoggingMessageNotification(logging) = + &mut notification.notification + else { + tracing::warn!(?notification, "Received unexpected message"); + continue; + }; + + let mut context = NotificationContext { + peer: peer.clone(), + meta: Meta::default(), + extensions: Extensions::default(), + }; + + if let Some(meta) = logging.extensions.get_mut::() { + std::mem::swap(&mut context.meta, meta); + } + std::mem::swap(&mut context.extensions, &mut logging.extensions); + + if let Err(error) = service + .handle_notification(notification.notification, context) + .await + { + tracing::warn!(?error, "Handle logging before handshake failed."); + } + } + // Server could send pings before handshake + ServerJsonRpcMessage::Request(ref request) + if matches!(request.request, ServerRequest::PingRequest(_)) => + { + tracing::trace!("Received ping request. Ignored.") + } + // Server SHOULD NOT send any other messages before handshake. We ignore them anyway + _ => tracing::warn!(?message, "Received unexpected message"), + } } } @@ -183,7 +224,15 @@ where context: "send initialize request".into(), })?; - let (response, response_id) = expect_response(&mut transport, "initialize response").await?; + let (peer, peer_rx) = Peer::new(id_provider, None); + + let (response, response_id) = expect_response( + &mut transport, + "initialize response", + &service, + peer.clone(), + ) + .await?; if id != response_id { return Err(ClientInitializeError::ConflictInitResponseId( @@ -195,6 +244,7 @@ where let ServerResult::InitializeResult(initialize_result) = response else { return Err(ClientInitializeError::ExpectedInitResult(Some(response))); }; + peer.set_peer_info(initialize_result); // send notification let notification = ClientJsonRpcMessage::notification( @@ -206,7 +256,6 @@ where transport.send(notification).await.map_err(|error| { ClientInitializeError::transport::(error, "send initialized notification") })?; - let (peer, peer_rx) = Peer::new(id_provider, Some(initialize_result)); Ok(serve_inner(service, transport, peer, peer_rx, ct)) }