Skip to content

Commit 050919d

Browse files
committed
Add templates for all 4 servers (fails now)
1 parent 66c45af commit 050919d

File tree

6 files changed

+142
-74
lines changed

6 files changed

+142
-74
lines changed

Cargo.lock

Lines changed: 4 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

encoderfile-core/Cargo.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,14 @@ rand = "0.9.2"
7272
tower = "0.5.2"
7373
test-log = "0.2.18"
7474

75+
[dev-dependencies.hyper-util]
76+
version = "0.1.18"
77+
features = ["server-graceful"]
78+
79+
[dev-dependencies.hyper]
80+
version = "1.8.1"
81+
features = ["http1"]
82+
7583
[dev-dependencies.rmcp]
7684
version = "0.8.0"
7785
features = ["client", "transport-streamable-http-client-reqwest"]

encoderfile-core/src/common/sentence_embedding.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ pub struct SentenceEmbeddingRequest {
1010
pub metadata: Option<HashMap<String, String>>,
1111
}
1212

13-
#[derive(Debug, Serialize, ToSchema, JsonSchema, utoipa::ToResponse)]
13+
#[derive(Debug, Serialize, Deserialize, ToSchema, JsonSchema, utoipa::ToResponse)]
1414
pub struct SentenceEmbeddingResponse {
1515
pub results: Vec<SentenceEmbedding>,
1616
pub model_id: String,

encoderfile-core/src/common/sequence_classification.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,15 @@ pub struct SequenceClassificationRequest {
1010
pub metadata: Option<HashMap<String, String>>,
1111
}
1212

13-
#[derive(Debug, Serialize, ToSchema, JsonSchema, utoipa::ToResponse)]
13+
#[derive(Debug, Serialize, Deserialize, ToSchema, JsonSchema, utoipa::ToResponse)]
1414
pub struct SequenceClassificationResponse {
1515
pub results: Vec<SequenceClassificationResult>,
1616
pub model_id: String,
1717
#[serde(skip_serializing_if = "Option::is_none")]
1818
pub metadata: Option<HashMap<String, String>>,
1919
}
2020

21-
#[derive(Debug, Serialize, ToSchema, JsonSchema)]
21+
#[derive(Debug, Serialize, Deserialize, ToSchema, JsonSchema)]
2222
pub struct SequenceClassificationResult {
2323
pub logits: Vec<f32>,
2424
pub scores: Vec<f32>,

encoderfile-core/src/common/token_classification.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,20 @@ pub struct TokenClassificationRequest {
1010
pub metadata: Option<HashMap<String, String>>,
1111
}
1212

13-
#[derive(Debug, Serialize, ToSchema, JsonSchema, utoipa::ToResponse)]
13+
#[derive(Debug, Serialize, Deserialize, ToSchema, JsonSchema, utoipa::ToResponse)]
1414
pub struct TokenClassificationResponse {
1515
pub results: Vec<TokenClassificationResult>,
1616
pub model_id: String,
1717
#[serde(skip_serializing_if = "Option::is_none")]
1818
pub metadata: Option<HashMap<String, String>>,
1919
}
2020

21-
#[derive(Debug, Serialize, ToSchema, JsonSchema)]
21+
#[derive(Debug, Serialize, Deserialize, ToSchema, JsonSchema)]
2222
pub struct TokenClassificationResult {
2323
pub tokens: Vec<TokenClassification>,
2424
}
2525

26-
#[derive(Debug, Serialize, ToSchema, JsonSchema)]
26+
#[derive(Debug, Serialize, Deserialize, ToSchema, JsonSchema)]
2727
pub struct TokenClassification {
2828
pub token_info: super::token::TokenInfo,
2929
pub scores: Vec<f32>,

encoderfile-core/tests/test_mcp.rs

Lines changed: 124 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ const LOCALHOST: &str = "localhost";
33
use anyhow::Result;
44
use encoderfile_core::{
55
AppState,
6-
common::{EmbeddingRequest, EmbeddingResponse, ModelType},
76
dev_utils::embedding_state,
87
transport::mcp,
98
};
@@ -14,8 +13,9 @@ use rmcp::{
1413
};
1514
use tokio::net::TcpListener;
1615
use tower_http::trace::DefaultOnResponse;
16+
use tokio::sync::oneshot;
1717

18-
async fn run_mcp(addr: String, state: AppState) -> Result<()> {
18+
async fn run_mcp(addr: String, state: AppState, receiver: oneshot::Receiver<()>, done_sender: oneshot::Sender<()>) -> Result<()> {
1919
let model_type = state.model_type.clone();
2020
let router = mcp::make_router(state).layer(
2121
tower_http::trace::TraceLayer::new_for_http()
@@ -25,74 +25,132 @@ async fn run_mcp(addr: String, state: AppState) -> Result<()> {
2525
);
2626
tracing::info!("Running {:?} MCP server on {}", model_type, &addr);
2727
let listener = TcpListener::bind(addr).await?;
28-
axum::serve(listener, router).await?;
28+
axum::serve(listener, router)
29+
.with_graceful_shutdown(
30+
async {
31+
receiver.await;
32+
tracing::info!("Received shutdown signal, shutting down");
33+
done_sender.send(());
34+
()
35+
})
36+
.await;
2937
Ok(())
3038
}
3139

32-
#[tokio::test]
33-
#[test_log::test]
34-
async fn test_mcp() {
35-
let port = 9100;
36-
let addr = format!("{}:{}", LOCALHOST, port);
37-
let dummy_state = embedding_state();
38-
let mcp_server = tokio::spawn(run_mcp(addr, dummy_state));
39-
// Client usage copied over from https://github.com/modelcontextprotocol/rust-sdk/blob/main/examples/clients/src/streamable_http.rs
40-
let client_transport =
41-
StreamableHttpClientTransport::from_uri(format!("http://{}:{}/mcp", LOCALHOST, port));
42-
let client_info = ClientInfo {
43-
protocol_version: Default::default(),
44-
capabilities: ClientCapabilities::default(),
45-
client_info: Implementation {
46-
name: "test sse client".to_string(),
47-
title: None,
48-
version: "0.0.1".to_string(),
49-
website_url: None,
50-
icons: None,
51-
},
52-
};
53-
let client = client_info
54-
.serve(client_transport)
55-
.await
56-
.inspect_err(|e| {
57-
tracing::error!("client error: {:?}", e);
58-
})
59-
.unwrap();
60-
// Initialize
61-
let server_info = client.peer_info();
62-
tracing::info!("Connected to server: {server_info:#?}");
40+
macro_rules! test_mcp_server_impl {
41+
($mod_name:ident, $state_func:ident, $req_type:ident, $resp_type:ident) => {
42+
mod $mod_name {
43+
use encoderfile_core::{
44+
common::{$req_type, $resp_type},
45+
dev_utils::$state_func,
46+
};
47+
use rmcp::{
48+
ServiceExt,
49+
transport::StreamableHttpClientTransport,
50+
model::{CallToolRequestParam, ClientCapabilities, ClientInfo, Implementation},
51+
};
52+
use tokio::sync::oneshot;
6353

64-
// List tools
65-
let tools = client
66-
.list_tools(Default::default())
67-
.await
68-
.expect("list tools failed");
69-
tracing::info!("Available tools: {tools:#?}");
54+
const LOCALHOST: &str = "localhost";
55+
const PORT: i32 = 9100;
7056

71-
assert_eq!(tools.tools.len(), 1);
72-
assert_eq!(tools.tools[0].name, "run_encoder");
57+
#[tokio::test]
58+
#[test_log::test]
59+
async fn $mod_name() {
60+
let addr = format!("{}:{}", LOCALHOST, PORT);
61+
let dummy_state = $state_func();
62+
let (sender, receiver) = oneshot::channel();
63+
let (done_sender, done_receiver) = oneshot::channel();
64+
let mcp_server = tokio::spawn(super::run_mcp(addr, dummy_state, receiver, done_sender));
65+
// Client usage copied over from https://github.com/modelcontextprotocol/rust-sdk/blob/main/examples/clients/src/streamable_http.rs
66+
let client_transport =
67+
StreamableHttpClientTransport::from_uri(format!("http://{}:{}/mcp", LOCALHOST, PORT));
68+
let client_info = ClientInfo {
69+
protocol_version: Default::default(),
70+
capabilities: ClientCapabilities::default(),
71+
client_info: Implementation {
72+
name: "test sse client".to_string(),
73+
title: None,
74+
version: "0.0.1".to_string(),
75+
website_url: None,
76+
icons: None,
77+
},
78+
};
79+
let client = client_info
80+
.serve(client_transport)
81+
.await
82+
.inspect_err(|e| {
83+
tracing::error!("client error: {:?}", e);
84+
})
85+
.unwrap();
86+
// Initialize
87+
let server_info = client.peer_info();
88+
tracing::info!("Connected to server: {server_info:#?}");
7389

74-
let test_params = EmbeddingRequest {
75-
inputs: vec![
76-
"This is a test.".to_string(),
77-
"This is another test.".to_string(),
78-
],
79-
metadata: None,
80-
};
81-
let tool_result = client
82-
.call_tool(CallToolRequestParam {
83-
name: "run_encoder".into(),
84-
arguments: serde_json::json!(test_params).as_object().cloned(),
85-
})
86-
.await
87-
.expect("call tool failed");
88-
tracing::info!("Tool result: {tool_result:#?}");
89-
let embeddings_response: EmbeddingResponse = serde_json::from_value(
90-
tool_result
91-
.structured_content
92-
.expect("No structured content found"),
93-
)
94-
.expect("failed to parse tool result");
95-
assert_eq!(embeddings_response.results.len(), 2);
96-
client.cancel().await.unwrap();
97-
mcp_server.abort();
90+
// List tools
91+
let tools = client
92+
.list_tools(Default::default())
93+
.await
94+
.expect("list tools failed");
95+
tracing::info!("Available tools: {tools:#?}");
96+
97+
assert_eq!(tools.tools.len(), 1);
98+
assert_eq!(tools.tools[0].name, "run_encoder");
99+
100+
let test_params = $req_type {
101+
inputs: vec![
102+
"This is a test.".to_string(),
103+
"This is another test.".to_string(),
104+
],
105+
metadata: None,
106+
};
107+
let tool_result = client
108+
.call_tool(CallToolRequestParam {
109+
name: "run_encoder".into(),
110+
arguments: serde_json::json!(test_params).as_object().cloned(),
111+
})
112+
.await
113+
.expect("call tool failed");
114+
tracing::info!("Tool result: {tool_result:#?}");
115+
let embeddings_response: $resp_type = serde_json::from_value(
116+
tool_result
117+
.structured_content
118+
.expect("No structured content found"),
119+
)
120+
.expect("failed to parse tool result");
121+
assert_eq!(embeddings_response.results.len(), 2);
122+
client.cancel().await;
123+
sender.send(());
124+
done_receiver.await;
125+
}
126+
}
127+
}
98128
}
129+
130+
131+
test_mcp_server_impl!(
132+
test_mcp_embedding,
133+
embedding_state,
134+
EmbeddingRequest,
135+
EmbeddingResponse
136+
);
137+
138+
test_mcp_server_impl!(
139+
test_mcp_sentence_embedding,
140+
sentence_embedding_state,
141+
SentenceEmbeddingRequest,
142+
SentenceEmbeddingResponse
143+
);
144+
145+
test_mcp_server_impl!(
146+
test_mcp_token_classification,
147+
token_classification_state,
148+
TokenClassificationRequest,
149+
TokenClassificationResponse
150+
);
151+
test_mcp_server_impl!(
152+
test_mcp_sequence_classification,
153+
sequence_classification_state,
154+
SequenceClassificationRequest,
155+
SequenceClassificationResponse
156+
);

0 commit comments

Comments
 (0)