11// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22// SPDX-License-Identifier: Apache-2.0
33
4- use dynamo_runtime:: component:: Component ;
5- use dynamo_runtime:: prelude:: DistributedRuntimeProvider ;
6- use dynamo_runtime:: slug:: Slug ;
4+ use std:: {
5+ collections:: { HashMap , HashSet } ,
6+ sync:: { Arc , Mutex , RwLock } ,
7+ } ;
78
8- use crate :: discovery:: ModelEntry ;
9+ use anyhow:: Context ;
10+ use dynamo_runtime:: { component:: Component , prelude:: DistributedRuntimeProvider , slug:: Slug } ;
911
10- use crate :: kv_router:: { KvRouterConfig , scheduler:: DefaultWorkerSelector } ;
1112use crate :: {
12- kv_router:: KvRouter ,
13+ discovery:: ModelEntry ,
14+ kv_router:: { KvRouter , KvRouterConfig , scheduler:: DefaultWorkerSelector } ,
1315 types:: openai:: {
1416 chat_completions:: OpenAIChatCompletionsStreamingEngine ,
1517 completions:: OpenAICompletionsStreamingEngine , embeddings:: OpenAIEmbeddingsStreamingEngine ,
1618 } ,
1719} ;
18- use std:: collections:: HashSet ;
19- use std:: sync:: RwLock ;
20- use std:: {
21- collections:: HashMap ,
22- sync:: { Arc , Mutex } ,
23- } ;
2420
2521#[ derive( Debug , thiserror:: Error ) ]
2622pub enum ModelManagerError {
@@ -29,6 +25,9 @@ pub enum ModelManagerError {
2925
3026 #[ error( "Model already exists: {0}" ) ]
3127 ModelAlreadyExists ( String ) ,
28+
29+ #[ error( "Lock poisoned: {0}" ) ]
30+ LockPoisoned ( & ' static str ) ,
3231}
3332
3433// Don't implement Clone for this, put it in an Arc instead.
@@ -60,41 +59,69 @@ impl ModelManager {
6059 }
6160 }
6261
63- pub fn get_model_entries ( & self ) -> Vec < ModelEntry > {
64- self . entries . lock ( ) . unwrap ( ) . values ( ) . cloned ( ) . collect ( )
62+ pub fn get_model_entries ( & self ) -> Result < Vec < ModelEntry > , ModelManagerError > {
63+ let guard = self
64+ . entries
65+ . lock ( )
66+ . map_err ( |_| ModelManagerError :: LockPoisoned ( "entries" ) ) ?;
67+ Ok ( guard. values ( ) . cloned ( ) . collect ( ) )
6568 }
6669
67- pub fn has_model_any ( & self , model : & str ) -> bool {
68- self . chat_completion_engines . read ( ) . unwrap ( ) . contains ( model)
69- || self . completion_engines . read ( ) . unwrap ( ) . contains ( model)
70+ pub fn has_model_any ( & self , model : & str ) -> Result < bool , ModelManagerError > {
71+ let chat = self
72+ . chat_completion_engines
73+ . read ( )
74+ . map_err ( |_| ModelManagerError :: LockPoisoned ( "chat_completion_engines" ) ) ?;
75+ if chat. contains ( model) {
76+ return Ok ( true ) ;
77+ }
78+ let comp = self
79+ . completion_engines
80+ . read ( )
81+ . map_err ( |_| ModelManagerError :: LockPoisoned ( "completion_engines" ) ) ?;
82+ Ok ( comp. contains ( model) )
7083 }
7184
72- pub fn model_display_names ( & self ) -> HashSet < String > {
73- self . list_chat_completions_models ( )
74- . into_iter ( )
75- . chain ( self . list_completions_models ( ) )
76- . chain ( self . list_embeddings_models ( ) )
77- . collect ( )
85+ pub fn model_display_names ( & self ) -> Result < HashSet < String > , ModelManagerError > {
86+ let chat = self . list_chat_completions_models ( ) ?;
87+ let comp = self . list_completions_models ( ) ?;
88+ let embed = self . list_embeddings_models ( ) ?;
89+ Ok ( chat. into_iter ( ) . chain ( comp) . chain ( embed) . collect ( ) )
7890 }
7991
80- pub fn list_chat_completions_models ( & self ) -> Vec < String > {
81- self . chat_completion_engines . read ( ) . unwrap ( ) . list ( )
92+ pub fn list_chat_completions_models ( & self ) -> Result < Vec < String > , ModelManagerError > {
93+ let guard = self
94+ . chat_completion_engines
95+ . read ( )
96+ . map_err ( |_| ModelManagerError :: LockPoisoned ( "chat_completion_engines" ) ) ?;
97+ Ok ( guard. list ( ) )
8298 }
8399
84- pub fn list_completions_models ( & self ) -> Vec < String > {
85- self . completion_engines . read ( ) . unwrap ( ) . list ( )
100+ pub fn list_completions_models ( & self ) -> Result < Vec < String > , ModelManagerError > {
101+ let guard = self
102+ . completion_engines
103+ . read ( )
104+ . map_err ( |_| ModelManagerError :: LockPoisoned ( "completion_engines" ) ) ?;
105+ Ok ( guard. list ( ) )
86106 }
87107
88- pub fn list_embeddings_models ( & self ) -> Vec < String > {
89- self . embeddings_engines . read ( ) . unwrap ( ) . list ( )
108+ pub fn list_embeddings_models ( & self ) -> Result < Vec < String > , ModelManagerError > {
109+ let guard = self
110+ . embeddings_engines
111+ . read ( )
112+ . map_err ( |_| ModelManagerError :: LockPoisoned ( "embeddings_engines" ) ) ?;
113+ Ok ( guard. list ( ) )
90114 }
91115
92116 pub fn add_completions_model (
93117 & self ,
94118 model : & str ,
95119 engine : OpenAICompletionsStreamingEngine ,
96120 ) -> Result < ( ) , ModelManagerError > {
97- let mut clients = self . completion_engines . write ( ) . unwrap ( ) ;
121+ let mut clients = self
122+ . completion_engines
123+ . write ( )
124+ . map_err ( |_| ModelManagerError :: LockPoisoned ( "completion_engines" ) ) ?;
98125 clients. add ( model, engine)
99126 }
100127
@@ -103,7 +130,10 @@ impl ModelManager {
103130 model : & str ,
104131 engine : OpenAIChatCompletionsStreamingEngine ,
105132 ) -> Result < ( ) , ModelManagerError > {
106- let mut clients = self . chat_completion_engines . write ( ) . unwrap ( ) ;
133+ let mut clients = self
134+ . chat_completion_engines
135+ . write ( )
136+ . map_err ( |_| ModelManagerError :: LockPoisoned ( "chat_completion_engines" ) ) ?;
107137 clients. add ( model, engine)
108138 }
109139
@@ -112,22 +142,34 @@ impl ModelManager {
112142 model : & str ,
113143 engine : OpenAIEmbeddingsStreamingEngine ,
114144 ) -> Result < ( ) , ModelManagerError > {
115- let mut clients = self . embeddings_engines . write ( ) . unwrap ( ) ;
145+ let mut clients = self
146+ . embeddings_engines
147+ . write ( )
148+ . map_err ( |_| ModelManagerError :: LockPoisoned ( "embeddings_engines" ) ) ?;
116149 clients. add ( model, engine)
117150 }
118151
119152 pub fn remove_completions_model ( & self , model : & str ) -> Result < ( ) , ModelManagerError > {
120- let mut clients = self . completion_engines . write ( ) . unwrap ( ) ;
153+ let mut clients = self
154+ . completion_engines
155+ . write ( )
156+ . map_err ( |_| ModelManagerError :: LockPoisoned ( "completion_engines" ) ) ?;
121157 clients. remove ( model)
122158 }
123159
124160 pub fn remove_chat_completions_model ( & self , model : & str ) -> Result < ( ) , ModelManagerError > {
125- let mut clients = self . chat_completion_engines . write ( ) . unwrap ( ) ;
161+ let mut clients = self
162+ . chat_completion_engines
163+ . write ( )
164+ . map_err ( |_| ModelManagerError :: LockPoisoned ( "chat_completion_engines" ) ) ?;
126165 clients. remove ( model)
127166 }
128167
129168 pub fn remove_embeddings_model ( & self , model : & str ) -> Result < ( ) , ModelManagerError > {
130- let mut clients = self . embeddings_engines . write ( ) . unwrap ( ) ;
169+ let mut clients = self
170+ . embeddings_engines
171+ . write ( )
172+ . map_err ( |_| ModelManagerError :: LockPoisoned ( "embeddings_engines" ) ) ?;
131173 clients. remove ( model)
132174 }
133175
@@ -137,7 +179,7 @@ impl ModelManager {
137179 ) -> Result < OpenAIEmbeddingsStreamingEngine , ModelManagerError > {
138180 self . embeddings_engines
139181 . read ( )
140- . unwrap ( )
182+ . map_err ( |_| ModelManagerError :: LockPoisoned ( "embeddings_engines" ) ) ?
141183 . get ( model)
142184 . cloned ( )
143185 . ok_or ( ModelManagerError :: ModelNotFound ( model. to_string ( ) ) )
@@ -149,7 +191,7 @@ impl ModelManager {
149191 ) -> Result < OpenAICompletionsStreamingEngine , ModelManagerError > {
150192 self . completion_engines
151193 . read ( )
152- . unwrap ( )
194+ . map_err ( |_| ModelManagerError :: LockPoisoned ( "completion_engines" ) ) ?
153195 . get ( model)
154196 . cloned ( )
155197 . ok_or ( ModelManagerError :: ModelNotFound ( model. to_string ( ) ) )
@@ -161,21 +203,30 @@ impl ModelManager {
161203 ) -> Result < OpenAIChatCompletionsStreamingEngine , ModelManagerError > {
162204 self . chat_completion_engines
163205 . read ( )
164- . unwrap ( )
206+ . map_err ( |_| ModelManagerError :: LockPoisoned ( "chat_completion_engines" ) ) ?
165207 . get ( model)
166208 . cloned ( )
167209 . ok_or ( ModelManagerError :: ModelNotFound ( model. to_string ( ) ) )
168210 }
169211
170212 /// Save a ModelEntry under an instance's etcd `models/` key so we can fetch it later when the key is
171213 /// deleted from etcd.
172- pub fn save_model_entry ( & self , key : & str , entry : ModelEntry ) {
173- self . entries . lock ( ) . unwrap ( ) . insert ( key. to_string ( ) , entry) ;
214+ pub fn save_model_entry ( & self , key : & str , entry : ModelEntry ) -> Result < ( ) , ModelManagerError > {
215+ let mut guard = self
216+ . entries
217+ . lock ( )
218+ . map_err ( |_| ModelManagerError :: LockPoisoned ( "entries" ) ) ?;
219+ guard. insert ( key. to_string ( ) , entry) ;
220+ Ok ( ( ) )
174221 }
175222
176223 /// Remove and return model entry for this instance's etcd key. We do this when the instance stops.
177- pub fn remove_model_entry ( & self , key : & str ) -> Option < ModelEntry > {
178- self . entries . lock ( ) . unwrap ( ) . remove ( key)
224+ pub fn remove_model_entry ( & self , key : & str ) -> Result < Option < ModelEntry > , ModelManagerError > {
225+ let mut guard = self
226+ . entries
227+ . lock ( )
228+ . map_err ( |_| ModelManagerError :: LockPoisoned ( "entries" ) ) ?;
229+ Ok ( guard. remove ( key) )
179230 }
180231
181232 pub async fn kv_chooser_for (
@@ -185,7 +236,10 @@ impl ModelManager {
185236 kv_cache_block_size : u32 ,
186237 kv_router_config : Option < KvRouterConfig > ,
187238 ) -> anyhow:: Result < Arc < KvRouter > > {
188- if let Some ( kv_chooser) = self . get_kv_chooser ( model_name) {
239+ if let Some ( kv_chooser) = self
240+ . get_kv_chooser ( model_name)
241+ . map_err ( |e| anyhow:: anyhow!( e. to_string( ) ) ) ?
242+ {
189243 // Check if the existing router has a different block size
190244 if kv_chooser. block_size ( ) != kv_cache_block_size {
191245 tracing:: warn!(
@@ -202,8 +256,12 @@ impl ModelManager {
202256 . await
203257 }
204258
205- fn get_kv_chooser ( & self , model_name : & str ) -> Option < Arc < KvRouter > > {
206- self . kv_choosers . lock ( ) . unwrap ( ) . get ( model_name) . cloned ( )
259+ fn get_kv_chooser ( & self , model_name : & str ) -> Result < Option < Arc < KvRouter > > , ModelManagerError > {
260+ let guard = self
261+ . kv_choosers
262+ . lock ( )
263+ . map_err ( |_| ModelManagerError :: LockPoisoned ( "kv_choosers" ) ) ?;
264+ Ok ( guard. get ( model_name) . cloned ( ) )
207265 }
208266
209267 /// Create and return a KV chooser for this component and model
@@ -242,7 +300,8 @@ impl ModelManager {
242300 let new_kv_chooser = Arc :: new ( chooser) ;
243301 self . kv_choosers
244302 . lock ( )
245- . unwrap ( )
303+ . map_err ( |_| ModelManagerError :: LockPoisoned ( "kv_choosers" ) )
304+ . context ( "failed to acquire kv_choosers lock for insert" ) ?
246305 . insert ( model_name. to_string ( ) , new_kv_chooser. clone ( ) ) ;
247306 Ok ( new_kv_chooser)
248307 }
0 commit comments