@@ -3,7 +3,6 @@ const LOCALHOST: &str = "localhost";
33use anyhow:: Result ;
44use encoderfile_core:: {
55 AppState ,
6- common:: { EmbeddingRequest , EmbeddingResponse , ModelType } ,
76 dev_utils:: embedding_state,
87 transport:: mcp,
98} ;
@@ -14,8 +13,9 @@ use rmcp::{
1413} ;
1514use tokio:: net:: TcpListener ;
1615use tower_http:: trace:: DefaultOnResponse ;
16+ use tokio:: sync:: oneshot;
1717
18- async fn run_mcp ( addr : String , state : AppState ) -> Result < ( ) > {
18+ async fn run_mcp ( addr : String , state : AppState , receiver : oneshot :: Receiver < ( ) > , done_sender : oneshot :: Sender < ( ) > ) -> Result < ( ) > {
1919 let model_type = state. model_type . clone ( ) ;
2020 let router = mcp:: make_router ( state) . layer (
2121 tower_http:: trace:: TraceLayer :: new_for_http ( )
@@ -25,74 +25,132 @@ async fn run_mcp(addr: String, state: AppState) -> Result<()> {
2525 ) ;
2626 tracing:: info!( "Running {:?} MCP server on {}" , model_type, & addr) ;
2727 let listener = TcpListener :: bind ( addr) . await ?;
28- axum:: serve ( listener, router) . await ?;
28+ axum:: serve ( listener, router)
29+ . with_graceful_shutdown (
30+ async {
31+ receiver. await ;
32+ tracing:: info!( "Received shutdown signal, shutting down" ) ;
33+ done_sender. send ( ( ) ) ;
34+ ( )
35+ } )
36+ . await ;
2937 Ok ( ( ) )
3038}
3139
32- #[ tokio:: test]
33- #[ test_log:: test]
34- async fn test_mcp ( ) {
35- let port = 9100 ;
36- let addr = format ! ( "{}:{}" , LOCALHOST , port) ;
37- let dummy_state = embedding_state ( ) ;
38- let mcp_server = tokio:: spawn ( run_mcp ( addr, dummy_state) ) ;
39- // Client usage copied over from https://github.com/modelcontextprotocol/rust-sdk/blob/main/examples/clients/src/streamable_http.rs
40- let client_transport =
41- StreamableHttpClientTransport :: from_uri ( format ! ( "http://{}:{}/mcp" , LOCALHOST , port) ) ;
42- let client_info = ClientInfo {
43- protocol_version : Default :: default ( ) ,
44- capabilities : ClientCapabilities :: default ( ) ,
45- client_info : Implementation {
46- name : "test sse client" . to_string ( ) ,
47- title : None ,
48- version : "0.0.1" . to_string ( ) ,
49- website_url : None ,
50- icons : None ,
51- } ,
52- } ;
53- let client = client_info
54- . serve ( client_transport)
55- . await
56- . inspect_err ( |e| {
57- tracing:: error!( "client error: {:?}" , e) ;
58- } )
59- . unwrap ( ) ;
60- // Initialize
61- let server_info = client. peer_info ( ) ;
62- tracing:: info!( "Connected to server: {server_info:#?}" ) ;
40+ macro_rules! test_mcp_server_impl {
41+ ( $mod_name: ident, $state_func: ident, $req_type: ident, $resp_type: ident) => {
42+ mod $mod_name {
43+ use encoderfile_core:: {
44+ common:: { $req_type, $resp_type} ,
45+ dev_utils:: $state_func,
46+ } ;
47+ use rmcp:: {
48+ ServiceExt ,
49+ transport:: StreamableHttpClientTransport ,
50+ model:: { CallToolRequestParam , ClientCapabilities , ClientInfo , Implementation } ,
51+ } ;
52+ use tokio:: sync:: oneshot;
6353
64- // List tools
65- let tools = client
66- . list_tools ( Default :: default ( ) )
67- . await
68- . expect ( "list tools failed" ) ;
69- tracing:: info!( "Available tools: {tools:#?}" ) ;
54+ const LOCALHOST : & str = "localhost" ;
55+ const PORT : i32 = 9100 ;
7056
71- assert_eq ! ( tools. tools. len( ) , 1 ) ;
72- assert_eq ! ( tools. tools[ 0 ] . name, "run_encoder" ) ;
57+ #[ tokio:: test]
58+ #[ test_log:: test]
59+ async fn $mod_name( ) {
60+ let addr = format!( "{}:{}" , LOCALHOST , PORT ) ;
61+ let dummy_state = $state_func( ) ;
62+ let ( sender, receiver) = oneshot:: channel( ) ;
63+ let ( done_sender, done_receiver) = oneshot:: channel( ) ;
64+ let mcp_server = tokio:: spawn( super :: run_mcp( addr, dummy_state, receiver, done_sender) ) ;
65+ // Client usage copied over from https://github.com/modelcontextprotocol/rust-sdk/blob/main/examples/clients/src/streamable_http.rs
66+ let client_transport =
67+ StreamableHttpClientTransport :: from_uri( format!( "http://{}:{}/mcp" , LOCALHOST , PORT ) ) ;
68+ let client_info = ClientInfo {
69+ protocol_version: Default :: default ( ) ,
70+ capabilities: ClientCapabilities :: default ( ) ,
71+ client_info: Implementation {
72+ name: "test sse client" . to_string( ) ,
73+ title: None ,
74+ version: "0.0.1" . to_string( ) ,
75+ website_url: None ,
76+ icons: None ,
77+ } ,
78+ } ;
79+ let client = client_info
80+ . serve( client_transport)
81+ . await
82+ . inspect_err( |e| {
83+ tracing:: error!( "client error: {:?}" , e) ;
84+ } )
85+ . unwrap( ) ;
86+ // Initialize
87+ let server_info = client. peer_info( ) ;
88+ tracing:: info!( "Connected to server: {server_info:#?}" ) ;
7389
74- let test_params = EmbeddingRequest {
75- inputs : vec ! [
76- "This is a test." . to_string( ) ,
77- "This is another test." . to_string( ) ,
78- ] ,
79- metadata : None ,
80- } ;
81- let tool_result = client
82- . call_tool ( CallToolRequestParam {
83- name : "run_encoder" . into ( ) ,
84- arguments : serde_json:: json!( test_params) . as_object ( ) . cloned ( ) ,
85- } )
86- . await
87- . expect ( "call tool failed" ) ;
88- tracing:: info!( "Tool result: {tool_result:#?}" ) ;
89- let embeddings_response: EmbeddingResponse = serde_json:: from_value (
90- tool_result
91- . structured_content
92- . expect ( "No structured content found" ) ,
93- )
94- . expect ( "failed to parse tool result" ) ;
95- assert_eq ! ( embeddings_response. results. len( ) , 2 ) ;
96- client. cancel ( ) . await . unwrap ( ) ;
97- mcp_server. abort ( ) ;
90+ // List tools
91+ let tools = client
92+ . list_tools( Default :: default ( ) )
93+ . await
94+ . expect( "list tools failed" ) ;
95+ tracing:: info!( "Available tools: {tools:#?}" ) ;
96+
97+ assert_eq!( tools. tools. len( ) , 1 ) ;
98+ assert_eq!( tools. tools[ 0 ] . name, "run_encoder" ) ;
99+
100+ let test_params = $req_type {
101+ inputs: vec![
102+ "This is a test." . to_string( ) ,
103+ "This is another test." . to_string( ) ,
104+ ] ,
105+ metadata: None ,
106+ } ;
107+ let tool_result = client
108+ . call_tool( CallToolRequestParam {
109+ name: "run_encoder" . into( ) ,
110+ arguments: serde_json:: json!( test_params) . as_object( ) . cloned( ) ,
111+ } )
112+ . await
113+ . expect( "call tool failed" ) ;
114+ tracing:: info!( "Tool result: {tool_result:#?}" ) ;
115+ let embeddings_response: $resp_type = serde_json:: from_value(
116+ tool_result
117+ . structured_content
118+ . expect( "No structured content found" ) ,
119+ )
120+ . expect( "failed to parse tool result" ) ;
121+ assert_eq!( embeddings_response. results. len( ) , 2 ) ;
122+ client. cancel( ) . await ;
123+ sender. send( ( ) ) ;
124+ done_receiver. await ;
125+ }
126+ }
127+ }
98128}
129+
130+
131+ test_mcp_server_impl ! (
132+ test_mcp_embedding,
133+ embedding_state,
134+ EmbeddingRequest ,
135+ EmbeddingResponse
136+ ) ;
137+
138+ test_mcp_server_impl ! (
139+ test_mcp_sentence_embedding,
140+ sentence_embedding_state,
141+ SentenceEmbeddingRequest ,
142+ SentenceEmbeddingResponse
143+ ) ;
144+
145+ test_mcp_server_impl ! (
146+ test_mcp_token_classification,
147+ token_classification_state,
148+ TokenClassificationRequest ,
149+ TokenClassificationResponse
150+ ) ;
151+ test_mcp_server_impl ! (
152+ test_mcp_sequence_classification,
153+ sequence_classification_state,
154+ SequenceClassificationRequest ,
155+ SequenceClassificationResponse
156+ ) ;
0 commit comments