1- const LOCALHOST : & str = "localhost" ;
2-
31use anyhow:: Result ;
4- use encoderfile_core:: {
5- AppState ,
6- dev_utils:: embedding_state,
7- transport:: mcp,
8- } ;
9- use rmcp:: {
10- ServiceExt ,
11- model:: { CallToolRequestParam , ClientCapabilities , ClientInfo , Implementation } ,
12- transport:: StreamableHttpClientTransport ,
13- } ;
2+ use encoderfile_core:: { AppState , transport:: mcp} ;
143use tokio:: net:: TcpListener ;
15- use tower_http:: trace:: DefaultOnResponse ;
164use tokio:: sync:: oneshot;
5+ use tower_http:: trace:: DefaultOnResponse ;
176
18- async fn run_mcp ( addr : String , state : AppState , receiver : oneshot:: Receiver < ( ) > , done_sender : oneshot :: Sender < ( ) > ) -> Result < ( ) > {
7+ async fn run_mcp ( addr : String , state : AppState , receiver : oneshot:: Receiver < ( ) > ) -> Result < ( ) > {
198 let model_type = state. model_type . clone ( ) ;
209 let router = mcp:: make_router ( state) . layer (
2110 tower_http:: trace:: TraceLayer :: new_for_http ( )
@@ -26,45 +15,43 @@ async fn run_mcp(addr: String, state: AppState, receiver: oneshot::Receiver<()>,
2615 tracing:: info!( "Running {:?} MCP server on {}" , model_type, & addr) ;
2716 let listener = TcpListener :: bind ( addr) . await ?;
2817 axum:: serve ( listener, router)
29- . with_graceful_shutdown (
30- async {
31- receiver. await ;
32- tracing:: info!( "Received shutdown signal, shutting down" ) ;
33- done_sender. send ( ( ) ) ;
34- ( )
35- } )
36- . await ;
18+ . with_graceful_shutdown ( async {
19+ receiver. await . ok ( ) ;
20+ tracing:: info!( "Received shutdown signal, shutting down" ) ;
21+ ( )
22+ } )
23+ . await
24+ . expect ( "Error while shutting down server" ) ;
3725 Ok ( ( ) )
3826}
3927
4028macro_rules! test_mcp_server_impl {
4129 ( $mod_name: ident, $state_func: ident, $req_type: ident, $resp_type: ident) => {
42- mod $mod_name {
30+ pub mod $mod_name {
4331 use encoderfile_core:: {
4432 common:: { $req_type, $resp_type} ,
4533 dev_utils:: $state_func,
4634 } ;
4735 use rmcp:: {
4836 ServiceExt ,
49- transport:: StreamableHttpClientTransport ,
5037 model:: { CallToolRequestParam , ClientCapabilities , ClientInfo , Implementation } ,
38+ transport:: StreamableHttpClientTransport ,
5139 } ;
5240 use tokio:: sync:: oneshot;
5341
5442 const LOCALHOST : & str = "localhost" ;
5543 const PORT : i32 = 9100 ;
5644
57- #[ tokio:: test]
58- #[ test_log:: test]
59- async fn $mod_name( ) {
45+ pub async fn $mod_name( ) {
6046 let addr = format!( "{}:{}" , LOCALHOST , PORT ) ;
6147 let dummy_state = $state_func( ) ;
6248 let ( sender, receiver) = oneshot:: channel( ) ;
63- let ( done_sender, done_receiver) = oneshot:: channel( ) ;
64- let mcp_server = tokio:: spawn( super :: run_mcp( addr, dummy_state, receiver, done_sender) ) ;
49+ let _mcp_server = tokio:: spawn( super :: run_mcp( addr, dummy_state, receiver) ) ;
6550 // Client usage copied over from https://github.com/modelcontextprotocol/rust-sdk/blob/main/examples/clients/src/streamable_http.rs
66- let client_transport =
67- StreamableHttpClientTransport :: from_uri( format!( "http://{}:{}/mcp" , LOCALHOST , PORT ) ) ;
51+ let client_transport = StreamableHttpClientTransport :: from_uri( format!(
52+ "http://{}:{}/mcp" ,
53+ LOCALHOST , PORT
54+ ) ) ;
6855 let client_info = ClientInfo {
6956 protocol_version: Default :: default ( ) ,
7057 capabilities: ClientCapabilities :: default ( ) ,
@@ -119,15 +106,13 @@ macro_rules! test_mcp_server_impl {
119106 )
120107 . expect( "failed to parse tool result" ) ;
121108 assert_eq!( embeddings_response. results. len( ) , 2 ) ;
122- client. cancel( ) . await ;
123- sender. send( ( ) ) ;
124- done_receiver. await ;
109+ client. cancel( ) . await . expect( "Error cancelling the agent" ) ;
110+ sender. send( ( ) ) . expect( "Error sending end of test signal" ) ;
125111 }
126112 }
127- }
113+ } ;
128114}
129115
130-
131116test_mcp_server_impl ! (
132117 test_mcp_embedding,
133118 embedding_state,
@@ -148,9 +133,23 @@ test_mcp_server_impl!(
148133 TokenClassificationRequest ,
149134 TokenClassificationResponse
150135) ;
136+
151137test_mcp_server_impl ! (
152138 test_mcp_sequence_classification,
153139 sequence_classification_state,
154140 SequenceClassificationRequest ,
155141 SequenceClassificationResponse
156- ) ;
142+ ) ;
143+
144+ #[ tokio:: test]
145+ #[ test_log:: test]
146+ async fn test_mcp_servers ( ) {
147+ self :: test_mcp_embedding:: test_mcp_embedding ( ) . await ;
148+ tracing:: info!( "Testing embedding" ) ;
149+ self :: test_mcp_sentence_embedding:: test_mcp_sentence_embedding ( ) . await ;
150+ tracing:: info!( "Testing sentence embedding" ) ;
151+ self :: test_mcp_token_classification:: test_mcp_token_classification ( ) . await ;
152+ tracing:: info!( "Testing token classification" ) ;
153+ self :: test_mcp_sequence_classification:: test_mcp_sequence_classification ( ) . await ;
154+ tracing:: info!( "Testing sequence classification" ) ;
155+ }
0 commit comments