Skip to content

Commit 8745e68

Browse files
committed
Serialize mcp tests
1 parent 050919d commit 8745e68

File tree

1 file changed

+36
-37
lines changed

1 file changed

+36
-37
lines changed

encoderfile-core/tests/test_mcp.rs

Lines changed: 36 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,10 @@
1-
const LOCALHOST: &str = "localhost";
2-
31
use anyhow::Result;
4-
use encoderfile_core::{
5-
AppState,
6-
dev_utils::embedding_state,
7-
transport::mcp,
8-
};
9-
use rmcp::{
10-
ServiceExt,
11-
model::{CallToolRequestParam, ClientCapabilities, ClientInfo, Implementation},
12-
transport::StreamableHttpClientTransport,
13-
};
2+
use encoderfile_core::{AppState, transport::mcp};
143
use tokio::net::TcpListener;
15-
use tower_http::trace::DefaultOnResponse;
164
use tokio::sync::oneshot;
5+
use tower_http::trace::DefaultOnResponse;
176

18-
async fn run_mcp(addr: String, state: AppState, receiver: oneshot::Receiver<()>, done_sender: oneshot::Sender<()>) -> Result<()> {
7+
async fn run_mcp(addr: String, state: AppState, receiver: oneshot::Receiver<()>) -> Result<()> {
198
let model_type = state.model_type.clone();
209
let router = mcp::make_router(state).layer(
2110
tower_http::trace::TraceLayer::new_for_http()
@@ -26,45 +15,43 @@ async fn run_mcp(addr: String, state: AppState, receiver: oneshot::Receiver<()>,
2615
tracing::info!("Running {:?} MCP server on {}", model_type, &addr);
2716
let listener = TcpListener::bind(addr).await?;
2817
axum::serve(listener, router)
29-
.with_graceful_shutdown(
30-
async {
31-
receiver.await;
32-
tracing::info!("Received shutdown signal, shutting down");
33-
done_sender.send(());
34-
()
35-
})
36-
.await;
18+
.with_graceful_shutdown(async {
19+
receiver.await.ok();
20+
tracing::info!("Received shutdown signal, shutting down");
21+
()
22+
})
23+
.await
24+
.expect("Error while shutting down server");
3725
Ok(())
3826
}
3927

4028
macro_rules! test_mcp_server_impl {
4129
($mod_name:ident, $state_func:ident, $req_type:ident, $resp_type:ident) => {
42-
mod $mod_name {
30+
pub mod $mod_name {
4331
use encoderfile_core::{
4432
common::{$req_type, $resp_type},
4533
dev_utils::$state_func,
4634
};
4735
use rmcp::{
4836
ServiceExt,
49-
transport::StreamableHttpClientTransport,
5037
model::{CallToolRequestParam, ClientCapabilities, ClientInfo, Implementation},
38+
transport::StreamableHttpClientTransport,
5139
};
5240
use tokio::sync::oneshot;
5341

5442
const LOCALHOST: &str = "localhost";
5543
const PORT: i32 = 9100;
5644

57-
#[tokio::test]
58-
#[test_log::test]
59-
async fn $mod_name() {
45+
pub async fn $mod_name() {
6046
let addr = format!("{}:{}", LOCALHOST, PORT);
6147
let dummy_state = $state_func();
6248
let (sender, receiver) = oneshot::channel();
63-
let (done_sender, done_receiver) = oneshot::channel();
64-
let mcp_server = tokio::spawn(super::run_mcp(addr, dummy_state, receiver, done_sender));
49+
let _mcp_server = tokio::spawn(super::run_mcp(addr, dummy_state, receiver));
6550
// Client usage copied over from https://github.com/modelcontextprotocol/rust-sdk/blob/main/examples/clients/src/streamable_http.rs
66-
let client_transport =
67-
StreamableHttpClientTransport::from_uri(format!("http://{}:{}/mcp", LOCALHOST, PORT));
51+
let client_transport = StreamableHttpClientTransport::from_uri(format!(
52+
"http://{}:{}/mcp",
53+
LOCALHOST, PORT
54+
));
6855
let client_info = ClientInfo {
6956
protocol_version: Default::default(),
7057
capabilities: ClientCapabilities::default(),
@@ -119,15 +106,13 @@ macro_rules! test_mcp_server_impl {
119106
)
120107
.expect("failed to parse tool result");
121108
assert_eq!(embeddings_response.results.len(), 2);
122-
client.cancel().await;
123-
sender.send(());
124-
done_receiver.await;
109+
client.cancel().await.expect("Error cancelling the agent");
110+
sender.send(()).expect("Error sending end of test signal");
125111
}
126112
}
127-
}
113+
};
128114
}
129115

130-
131116
test_mcp_server_impl!(
132117
test_mcp_embedding,
133118
embedding_state,
@@ -148,9 +133,23 @@ test_mcp_server_impl!(
148133
TokenClassificationRequest,
149134
TokenClassificationResponse
150135
);
136+
151137
test_mcp_server_impl!(
152138
test_mcp_sequence_classification,
153139
sequence_classification_state,
154140
SequenceClassificationRequest,
155141
SequenceClassificationResponse
156-
);
142+
);
143+
144+
#[tokio::test]
145+
#[test_log::test]
146+
async fn test_mcp_servers() {
147+
self::test_mcp_embedding::test_mcp_embedding().await;
148+
tracing::info!("Testing embedding");
149+
self::test_mcp_sentence_embedding::test_mcp_sentence_embedding().await;
150+
tracing::info!("Testing sentence embedding");
151+
self::test_mcp_token_classification::test_mcp_token_classification().await;
152+
tracing::info!("Testing token classification");
153+
self::test_mcp_sequence_classification::test_mcp_sequence_classification().await;
154+
tracing::info!("Testing sequence classification");
155+
}

0 commit comments

Comments
 (0)