@@ -15,12 +15,33 @@ use tokio::sync::Mutex;
1515
1616use dynamo_runtime:: {
1717 self as rs, logging,
18- pipeline:: { EngineStream , ManyOut , SingleIn } ,
18+ pipeline:: {
19+ network:: egress:: push_router:: RouterMode as RsRouterMode , EngineStream , ManyOut , SingleIn ,
20+ } ,
1921 protocols:: annotated:: Annotated as RsAnnotated ,
2022 traits:: DistributedRuntimeProvider ,
2123} ;
2224
2325use dynamo_llm:: { self as llm_rs} ;
26+ use dynamo_llm:: { entrypoint:: RouterConfig , kv_router:: KvRouterConfig } ;
27+
28+ #[ pyclass( eq, eq_int) ]
29+ #[ derive( Clone , Debug , PartialEq ) ]
30+ pub enum RouterMode {
31+ RoundRobin ,
32+ Random ,
33+ KV ,
34+ }
35+
36+ impl From < RouterMode > for RsRouterMode {
37+ fn from ( mode : RouterMode ) -> Self {
38+ match mode {
39+ RouterMode :: RoundRobin => Self :: RoundRobin ,
40+ RouterMode :: Random => Self :: Random ,
41+ RouterMode :: KV => Self :: KV ,
42+ }
43+ }
44+ }
2445
2546mod engine;
2647mod http;
@@ -75,6 +96,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
7596 m. add_class :: < http:: HttpAsyncEngine > ( ) ?;
7697 m. add_class :: < EtcdKvCache > ( ) ?;
7798 m. add_class :: < ModelType > ( ) ?;
99+ m. add_class :: < RouterMode > ( ) ?;
78100
79101 engine:: add_to_module ( m) ?;
80102
@@ -99,7 +121,8 @@ fn log_message(level: &str, message: &str, module: &str, file: &str, line: u32)
99121}
100122
101123#[ pyfunction]
102- #[ pyo3( signature = ( model_type, endpoint, model_path, model_name=None , context_length=None , kv_cache_block_size=None ) ) ]
124+ #[ pyo3( signature = ( model_type, endpoint, model_path, model_name=None , context_length=None , kv_cache_block_size=None , router_mode=None ) ) ]
125+ #[ allow( clippy:: too_many_arguments) ]
103126fn register_llm < ' p > (
104127 py : Python < ' p > ,
105128 model_type : ModelType ,
@@ -108,6 +131,7 @@ fn register_llm<'p>(
108131 model_name : Option < & str > ,
109132 context_length : Option < u32 > ,
110133 kv_cache_block_size : Option < u32 > ,
134+ router_mode : Option < RouterMode > ,
111135) -> PyResult < Bound < ' p , PyAny > > {
112136 let model_type_obj = match model_type {
113137 ModelType :: Chat => llm_rs:: model_type:: ModelType :: Chat ,
@@ -118,13 +142,17 @@ fn register_llm<'p>(
118142
119143 let inner_path = model_path. to_string ( ) ;
120144 let model_name = model_name. map ( |n| n. to_string ( ) ) ;
145+ let router_mode = router_mode. unwrap_or ( RouterMode :: RoundRobin ) ;
146+ let router_config = RouterConfig :: new ( router_mode. into ( ) , KvRouterConfig :: default ( ) ) ;
147+
121148 pyo3_async_runtimes:: tokio:: future_into_py ( py, async move {
122149 let mut builder = dynamo_llm:: local_model:: LocalModelBuilder :: default ( ) ;
123150 builder
124151 . model_path ( Some ( PathBuf :: from ( inner_path) ) )
125152 . model_name ( model_name)
126153 . context_length ( context_length)
127- . kv_cache_block_size ( kv_cache_block_size) ;
154+ . kv_cache_block_size ( kv_cache_block_size)
155+ . router_config ( router_config) ;
128156 // Download from HF, load the ModelDeploymentCard
129157 let mut local_model = builder. build ( ) . await . map_err ( to_pyerr) ?;
130158 // Advertise ourself on etcd so ingress can find us
0 commit comments