Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 57 additions & 8 deletions crates/rmcp/src/service/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,18 +75,59 @@ where
}

/// Helper function to expect a response from the stream
async fn expect_response<T>(
async fn expect_response<T, S>(
transport: &mut T,
context: &str,
service: &S,
peer: Peer<RoleClient>,
) -> Result<(ServerResult, RequestId), ClientInitializeError>
where
T: Transport<RoleClient>,
S: Service<RoleClient>,
{
let msg = expect_next_message(transport, context).await?;

match msg {
ServerJsonRpcMessage::Response(JsonRpcResponse { id, result, .. }) => Ok((result, id)),
_ => Err(ClientInitializeError::ExpectedInitResponse(Some(msg))),
loop {
let message = expect_next_message(transport, context).await?;
match message {
// Expected message to complete the initialization
ServerJsonRpcMessage::Response(JsonRpcResponse { id, result, .. }) => {
break Ok((result, id));
}
// Server could send logging messages before handshake
ServerJsonRpcMessage::Notification(mut notification) => {
let ServerNotification::LoggingMessageNotification(logging) =
&mut notification.notification
else {
tracing::warn!(?notification, "Received unexpected message");
continue;
};

let mut context = NotificationContext {
peer: peer.clone(),
meta: Meta::default(),
extensions: Extensions::default(),
};

if let Some(meta) = logging.extensions.get_mut::<Meta>() {
std::mem::swap(&mut context.meta, meta);
}
std::mem::swap(&mut context.extensions, &mut logging.extensions);

if let Err(error) = service
.handle_notification(notification.notification, context)
.await
{
tracing::warn!(?error, "Handle logging before handshake failed.");
}
}
// Server could send pings before handshake
ServerJsonRpcMessage::Request(ref request)
if matches!(request.request, ServerRequest::PingRequest(_)) =>
{
tracing::trace!("Received ping request. Ignored.")
}
// Server SHOULD NOT send any other messages before handshake. We ignore them anyway
_ => tracing::warn!(?message, "Received unexpected message"),
}
}
}

Expand Down Expand Up @@ -183,7 +224,15 @@ where
context: "send initialize request".into(),
})?;

let (response, response_id) = expect_response(&mut transport, "initialize response").await?;
let (peer, peer_rx) = Peer::new(id_provider, None);

let (response, response_id) = expect_response(
&mut transport,
"initialize response",
&service,
peer.clone(),
)
.await?;

if id != response_id {
return Err(ClientInitializeError::ConflictInitResponseId(
Expand All @@ -195,6 +244,7 @@ where
let ServerResult::InitializeResult(initialize_result) = response else {
return Err(ClientInitializeError::ExpectedInitResult(Some(response)));
};
peer.set_peer_info(initialize_result);

// send notification
let notification = ClientJsonRpcMessage::notification(
Expand All @@ -206,7 +256,6 @@ where
transport.send(notification).await.map_err(|error| {
ClientInitializeError::transport::<T>(error, "send initialized notification")
})?;
let (peer, peer_rx) = Peer::new(id_provider, Some(initialize_result));
Ok(serve_inner(service, transport, peer, peer_rx, ct))
}

Expand Down