@@ -1669,9 +1669,7 @@ pub async fn run(
16691669 ApiDoc :: openapi ( )
16701670 } ;
16711671
1672- // Create router
1673- let mut app = Router :: new ( )
1674- . merge ( SwaggerUi :: new ( "/docs" ) . url ( "/api-doc/openapi.json" , doc) )
1672+ let mut routes = Router :: new ( )
16751673 // Base routes
16761674 . route ( "/info" , get ( get_model_info) )
16771675 . route ( "/embed" , post ( embed) )
@@ -1686,74 +1684,72 @@ pub async fn run(
16861684 . route ( "/embeddings" , post ( openai_embed) )
16871685 . route ( "/v1/embeddings" , post ( openai_embed) )
16881686 // Vertex compat route
1689- . route ( "/vertex" , post ( vertex_compatibility) )
1687+ . route ( "/vertex" , post ( vertex_compatibility) ) ;
1688+
1689+ #[ allow( unused_mut) ]
1690+ let mut public_routes = Router :: new ( )
16901691 // Base Health route
16911692 . route ( "/health" , get ( health) )
16921693 // Inference API health route
16931694 . route ( "/" , get ( health) )
16941695 // AWS Sagemaker health route
16951696 . route ( "/ping" , get ( health) )
16961697 // Prometheus metrics route
1697- . route ( "/metrics" , get ( metrics) )
1698- // Update payload limit
1699- . layer ( DefaultBodyLimit :: max ( payload_limit) ) ;
1698+ . route ( "/metrics" , get ( metrics) ) ;
17001699
17011700 #[ cfg( feature = "google" ) ]
17021701 {
17031702 tracing:: info!( "Built with `google` feature" ) ;
17041703
17051704 if let Ok ( env_predict_route) = std:: env:: var ( "AIP_PREDICT_ROUTE" ) {
17061705 tracing:: info!( "Serving Vertex compatible route on {env_predict_route}" ) ;
1707- app = app . route ( & env_predict_route, post ( vertex_compatibility) ) ;
1706+ routes = routes . route ( & env_predict_route, post ( vertex_compatibility) ) ;
17081707 }
17091708
17101709 if let Ok ( env_health_route) = std:: env:: var ( "AIP_HEALTH_ROUTE" ) {
17111710 tracing:: info!( "Serving Vertex compatible health route on {env_health_route}" ) ;
1712- app = app . route ( & env_health_route, get ( health) ) ;
1711+ public_routes = public_routes . route ( & env_health_route, get ( health) ) ;
17131712 }
17141713 }
17151714 #[ cfg( not( feature = "google" ) ) ]
17161715 {
17171716 // Set default routes
1718- app = match & info. model_type {
1717+ routes = match & info. model_type {
17191718 ModelType :: Classifier ( _) => {
1720- app. route ( "/" , post ( predict) )
1719+ routes
1720+ . route ( "/" , post ( predict) )
17211721 // AWS Sagemaker route
17221722 . route ( "/invocations" , post ( predict) )
17231723 }
17241724 ModelType :: Reranker ( _) => {
1725- app. route ( "/" , post ( rerank) )
1725+ routes
1726+ . route ( "/" , post ( rerank) )
17261727 // AWS Sagemaker route
17271728 . route ( "/invocations" , post ( rerank) )
17281729 }
17291730 ModelType :: Embedding ( model) => {
17301731 if std:: env:: var ( "TASK" ) . ok ( ) == Some ( "sentence-similarity" . to_string ( ) ) {
1731- app. route ( "/" , post ( similarity) )
1732+ routes
1733+ . route ( "/" , post ( similarity) )
17321734 // AWS Sagemaker route
17331735 . route ( "/invocations" , post ( similarity) )
17341736 } else if model. pooling == "splade" {
1735- app. route ( "/" , post ( embed_sparse) )
1737+ routes
1738+ . route ( "/" , post ( embed_sparse) )
17361739 // AWS Sagemaker route
17371740 . route ( "/invocations" , post ( embed_sparse) )
17381741 } else {
1739- app. route ( "/" , post ( embed) )
1742+ routes
1743+ . route ( "/" , post ( embed) )
17401744 // AWS Sagemaker route
17411745 . route ( "/invocations" , post ( embed) )
17421746 }
17431747 }
17441748 } ;
17451749 }
17461750
1747- app = app
1748- . layer ( Extension ( infer) )
1749- . layer ( Extension ( info) )
1750- . layer ( Extension ( prom_handle. clone ( ) ) )
1751- . layer ( OtelAxumLayer :: default ( ) )
1752- . layer ( cors_layer) ;
1753-
17541751 if let Some ( api_key) = api_key {
1755- let mut prefix = "Bearer " . to_string ( ) ;
1756- prefix. push_str ( & api_key) ;
1752+ let prefix = format ! ( "Bearer {}" , api_key) ;
17571753
17581754 // Leak to allow FnMut
17591755 let api_key: & ' static str = prefix. leak ( ) ;
@@ -1770,9 +1766,20 @@ pub async fn run(
17701766 }
17711767 } ;
17721768
1773- app = app . layer ( axum:: middleware:: from_fn ( auth) ) ;
1769+ routes = routes . layer ( axum:: middleware:: from_fn ( auth) ) ;
17741770 }
17751771
1772+ let app = Router :: new ( )
1773+ . merge ( SwaggerUi :: new ( "/docs" ) . url ( "/api-doc/openapi.json" , doc) )
1774+ . merge ( routes)
1775+ . merge ( public_routes)
1776+ . layer ( Extension ( infer) )
1777+ . layer ( Extension ( info) )
1778+ . layer ( Extension ( prom_handle. clone ( ) ) )
1779+ . layer ( OtelAxumLayer :: default ( ) )
1780+ . layer ( DefaultBodyLimit :: max ( payload_limit) )
1781+ . layer ( cors_layer) ;
1782+
17761783 // Run server
17771784 let listener = tokio:: net:: TcpListener :: bind ( & addr)
17781785 . await
0 commit comments