Skip to content

Commit a319de6

Browse files
committed
refactor: replace unwrap in model_manager with explicit error handling
1 parent 05913af commit a319de6

File tree

4 files changed

+122
-60
lines changed

4 files changed

+122
-60
lines changed

lib/llm/src/discovery/model_manager.rs

Lines changed: 105 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -1,26 +1,22 @@
11
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
// SPDX-License-Identifier: Apache-2.0
33

4-
use dynamo_runtime::component::Component;
5-
use dynamo_runtime::prelude::DistributedRuntimeProvider;
6-
use dynamo_runtime::slug::Slug;
4+
use std::{
5+
collections::{HashMap, HashSet},
6+
sync::{Arc, Mutex, RwLock},
7+
};
78

8-
use crate::discovery::ModelEntry;
9+
use anyhow::Context;
10+
use dynamo_runtime::{component::Component, prelude::DistributedRuntimeProvider, slug::Slug};
911

10-
use crate::kv_router::{KvRouterConfig, scheduler::DefaultWorkerSelector};
1112
use crate::{
12-
kv_router::KvRouter,
13+
discovery::ModelEntry,
14+
kv_router::{KvRouter, KvRouterConfig, scheduler::DefaultWorkerSelector},
1315
types::openai::{
1416
chat_completions::OpenAIChatCompletionsStreamingEngine,
1517
completions::OpenAICompletionsStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine,
1618
},
1719
};
18-
use std::collections::HashSet;
19-
use std::sync::RwLock;
20-
use std::{
21-
collections::HashMap,
22-
sync::{Arc, Mutex},
23-
};
2420

2521
#[derive(Debug, thiserror::Error)]
2622
pub enum ModelManagerError {
@@ -29,6 +25,9 @@ pub enum ModelManagerError {
2925

3026
#[error("Model already exists: {0}")]
3127
ModelAlreadyExists(String),
28+
29+
#[error("Lock poisoned: {0}")]
30+
LockPoisoned(&'static str),
3231
}
3332

3433
// Don't implement Clone for this, put it in an Arc instead.
@@ -60,41 +59,69 @@ impl ModelManager {
6059
}
6160
}
6261

63-
pub fn get_model_entries(&self) -> Vec<ModelEntry> {
64-
self.entries.lock().unwrap().values().cloned().collect()
62+
pub fn get_model_entries(&self) -> Result<Vec<ModelEntry>, ModelManagerError> {
63+
let guard = self
64+
.entries
65+
.lock()
66+
.map_err(|_| ModelManagerError::LockPoisoned("entries"))?;
67+
Ok(guard.values().cloned().collect())
6568
}
6669

67-
pub fn has_model_any(&self, model: &str) -> bool {
68-
self.chat_completion_engines.read().unwrap().contains(model)
69-
|| self.completion_engines.read().unwrap().contains(model)
70+
pub fn has_model_any(&self, model: &str) -> Result<bool, ModelManagerError> {
71+
let chat = self
72+
.chat_completion_engines
73+
.read()
74+
.map_err(|_| ModelManagerError::LockPoisoned("chat_completion_engines"))?;
75+
if chat.contains(model) {
76+
return Ok(true);
77+
}
78+
let comp = self
79+
.completion_engines
80+
.read()
81+
.map_err(|_| ModelManagerError::LockPoisoned("completion_engines"))?;
82+
Ok(comp.contains(model))
7083
}
7184

72-
pub fn model_display_names(&self) -> HashSet<String> {
73-
self.list_chat_completions_models()
74-
.into_iter()
75-
.chain(self.list_completions_models())
76-
.chain(self.list_embeddings_models())
77-
.collect()
85+
pub fn model_display_names(&self) -> Result<HashSet<String>, ModelManagerError> {
86+
let chat = self.list_chat_completions_models()?;
87+
let comp = self.list_completions_models()?;
88+
let embed = self.list_embeddings_models()?;
89+
Ok(chat.into_iter().chain(comp).chain(embed).collect())
7890
}
7991

80-
pub fn list_chat_completions_models(&self) -> Vec<String> {
81-
self.chat_completion_engines.read().unwrap().list()
92+
pub fn list_chat_completions_models(&self) -> Result<Vec<String>, ModelManagerError> {
93+
let guard = self
94+
.chat_completion_engines
95+
.read()
96+
.map_err(|_| ModelManagerError::LockPoisoned("chat_completion_engines"))?;
97+
Ok(guard.list())
8298
}
8399

84-
pub fn list_completions_models(&self) -> Vec<String> {
85-
self.completion_engines.read().unwrap().list()
100+
pub fn list_completions_models(&self) -> Result<Vec<String>, ModelManagerError> {
101+
let guard = self
102+
.completion_engines
103+
.read()
104+
.map_err(|_| ModelManagerError::LockPoisoned("completion_engines"))?;
105+
Ok(guard.list())
86106
}
87107

88-
pub fn list_embeddings_models(&self) -> Vec<String> {
89-
self.embeddings_engines.read().unwrap().list()
108+
pub fn list_embeddings_models(&self) -> Result<Vec<String>, ModelManagerError> {
109+
let guard = self
110+
.embeddings_engines
111+
.read()
112+
.map_err(|_| ModelManagerError::LockPoisoned("embeddings_engines"))?;
113+
Ok(guard.list())
90114
}
91115

92116
pub fn add_completions_model(
93117
&self,
94118
model: &str,
95119
engine: OpenAICompletionsStreamingEngine,
96120
) -> Result<(), ModelManagerError> {
97-
let mut clients = self.completion_engines.write().unwrap();
121+
let mut clients = self
122+
.completion_engines
123+
.write()
124+
.map_err(|_| ModelManagerError::LockPoisoned("completion_engines"))?;
98125
clients.add(model, engine)
99126
}
100127

@@ -103,7 +130,10 @@ impl ModelManager {
103130
model: &str,
104131
engine: OpenAIChatCompletionsStreamingEngine,
105132
) -> Result<(), ModelManagerError> {
106-
let mut clients = self.chat_completion_engines.write().unwrap();
133+
let mut clients = self
134+
.chat_completion_engines
135+
.write()
136+
.map_err(|_| ModelManagerError::LockPoisoned("chat_completion_engines"))?;
107137
clients.add(model, engine)
108138
}
109139

@@ -112,22 +142,34 @@ impl ModelManager {
112142
model: &str,
113143
engine: OpenAIEmbeddingsStreamingEngine,
114144
) -> Result<(), ModelManagerError> {
115-
let mut clients = self.embeddings_engines.write().unwrap();
145+
let mut clients = self
146+
.embeddings_engines
147+
.write()
148+
.map_err(|_| ModelManagerError::LockPoisoned("embeddings_engines"))?;
116149
clients.add(model, engine)
117150
}
118151

119152
pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
120-
let mut clients = self.completion_engines.write().unwrap();
153+
let mut clients = self
154+
.completion_engines
155+
.write()
156+
.map_err(|_| ModelManagerError::LockPoisoned("completion_engines"))?;
121157
clients.remove(model)
122158
}
123159

124160
pub fn remove_chat_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
125-
let mut clients = self.chat_completion_engines.write().unwrap();
161+
let mut clients = self
162+
.chat_completion_engines
163+
.write()
164+
.map_err(|_| ModelManagerError::LockPoisoned("chat_completion_engines"))?;
126165
clients.remove(model)
127166
}
128167

129168
pub fn remove_embeddings_model(&self, model: &str) -> Result<(), ModelManagerError> {
130-
let mut clients = self.embeddings_engines.write().unwrap();
169+
let mut clients = self
170+
.embeddings_engines
171+
.write()
172+
.map_err(|_| ModelManagerError::LockPoisoned("embeddings_engines"))?;
131173
clients.remove(model)
132174
}
133175

@@ -137,7 +179,7 @@ impl ModelManager {
137179
) -> Result<OpenAIEmbeddingsStreamingEngine, ModelManagerError> {
138180
self.embeddings_engines
139181
.read()
140-
.unwrap()
182+
.map_err(|_| ModelManagerError::LockPoisoned("embeddings_engines"))?
141183
.get(model)
142184
.cloned()
143185
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
@@ -149,7 +191,7 @@ impl ModelManager {
149191
) -> Result<OpenAICompletionsStreamingEngine, ModelManagerError> {
150192
self.completion_engines
151193
.read()
152-
.unwrap()
194+
.map_err(|_| ModelManagerError::LockPoisoned("completion_engines"))?
153195
.get(model)
154196
.cloned()
155197
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
@@ -161,21 +203,30 @@ impl ModelManager {
161203
) -> Result<OpenAIChatCompletionsStreamingEngine, ModelManagerError> {
162204
self.chat_completion_engines
163205
.read()
164-
.unwrap()
206+
.map_err(|_| ModelManagerError::LockPoisoned("chat_completion_engines"))?
165207
.get(model)
166208
.cloned()
167209
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
168210
}
169211

170212
/// Save a ModelEntry under an instance's etcd `models/` key so we can fetch it later when the key is
171213
/// deleted from etcd.
172-
pub fn save_model_entry(&self, key: &str, entry: ModelEntry) {
173-
self.entries.lock().unwrap().insert(key.to_string(), entry);
214+
pub fn save_model_entry(&self, key: &str, entry: ModelEntry) -> Result<(), ModelManagerError> {
215+
let mut guard = self
216+
.entries
217+
.lock()
218+
.map_err(|_| ModelManagerError::LockPoisoned("entries"))?;
219+
guard.insert(key.to_string(), entry);
220+
Ok(())
174221
}
175222

176223
/// Remove and return model entry for this instance's etcd key. We do this when the instance stops.
177-
pub fn remove_model_entry(&self, key: &str) -> Option<ModelEntry> {
178-
self.entries.lock().unwrap().remove(key)
224+
pub fn remove_model_entry(&self, key: &str) -> Result<Option<ModelEntry>, ModelManagerError> {
225+
let mut guard = self
226+
.entries
227+
.lock()
228+
.map_err(|_| ModelManagerError::LockPoisoned("entries"))?;
229+
Ok(guard.remove(key))
179230
}
180231

181232
pub async fn kv_chooser_for(
@@ -185,7 +236,10 @@ impl ModelManager {
185236
kv_cache_block_size: u32,
186237
kv_router_config: Option<KvRouterConfig>,
187238
) -> anyhow::Result<Arc<KvRouter>> {
188-
if let Some(kv_chooser) = self.get_kv_chooser(model_name) {
239+
if let Some(kv_chooser) = self
240+
.get_kv_chooser(model_name)
241+
.map_err(|e| anyhow::anyhow!(e.to_string()))?
242+
{
189243
// Check if the existing router has a different block size
190244
if kv_chooser.block_size() != kv_cache_block_size {
191245
tracing::warn!(
@@ -202,8 +256,12 @@ impl ModelManager {
202256
.await
203257
}
204258

205-
fn get_kv_chooser(&self, model_name: &str) -> Option<Arc<KvRouter>> {
206-
self.kv_choosers.lock().unwrap().get(model_name).cloned()
259+
fn get_kv_chooser(&self, model_name: &str) -> Result<Option<Arc<KvRouter>>, ModelManagerError> {
260+
let guard = self
261+
.kv_choosers
262+
.lock()
263+
.map_err(|_| ModelManagerError::LockPoisoned("kv_choosers"))?;
264+
Ok(guard.get(model_name).cloned())
207265
}
208266

209267
/// Create and return a KV chooser for this component and model
@@ -242,7 +300,8 @@ impl ModelManager {
242300
let new_kv_chooser = Arc::new(chooser);
243301
self.kv_choosers
244302
.lock()
245-
.unwrap()
303+
.map_err(|_| ModelManagerError::LockPoisoned("kv_choosers"))
304+
.context("failed to acquire kv_choosers lock for insert")?
246305
.insert(model_name.to_string(), new_kv_chooser.clone());
247306
Ok(new_kv_chooser)
248307
}

lib/llm/src/discovery/watcher.rs

Lines changed: 15 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -83,7 +83,7 @@ impl ModelWatcher {
8383
pub async fn wait_for_chat_model(&self) -> String {
8484
// Loop in case it gets added and immediately deleted
8585
loop {
86-
if let Some(model_name) = self.manager.list_chat_completions_models().first() {
86+
if let Some(model_name) = self.manager.list_chat_completions_models().unwrap().first() {
8787
return model_name.to_owned();
8888
}
8989
self.notify_on_model.notified().await
@@ -117,15 +117,15 @@ impl ModelWatcher {
117117
continue;
118118
}
119119
};
120-
self.manager.save_model_entry(key, model_entry.clone());
120+
let _ = self.manager.save_model_entry(key, model_entry.clone());
121121

122122
if let Some(tx) = &self.model_update_tx {
123123
tx.send(ModelUpdate::Added(model_entry.model_type))
124124
.await
125125
.ok();
126126
}
127127

128-
if self.manager.has_model_any(&model_entry.name) {
128+
if self.manager.has_model_any(&model_entry.name).unwrap() {
129129
tracing::trace!(name = model_entry.name, "New endpoint for existing model");
130130
self.notify_on_model.notify_waiters();
131131
continue;
@@ -164,7 +164,7 @@ impl ModelWatcher {
164164
/// Returns the name of the model we just deleted, if any.
165165
async fn handle_delete(&self, kv: &KeyValue) -> anyhow::Result<Option<String>> {
166166
let key = kv.key_str()?;
167-
let model_entry = match self.manager.remove_model_entry(key) {
167+
let model_entry = match self.manager.remove_model_entry(key)? {
168168
Some(entry) => entry,
169169
None => {
170170
anyhow::bail!("Missing ModelEntry for {key}");
@@ -179,26 +179,26 @@ impl ModelWatcher {
179179
let mut update_tx = true;
180180
let mut model_type: ModelType = model_entry.model_type;
181181
if model_entry.model_type == ModelType::Chat
182-
&& self.manager.list_chat_completions_models().is_empty()
182+
&& self.manager.list_chat_completions_models()?.is_empty()
183183
{
184184
self.manager.remove_chat_completions_model(&model_name).ok();
185185
model_type = ModelType::Chat;
186186
} else if model_entry.model_type == ModelType::Completion
187-
&& self.manager.list_completions_models().is_empty()
187+
&& self.manager.list_completions_models()?.is_empty()
188188
{
189189
self.manager.remove_completions_model(&model_name).ok();
190190
model_type = ModelType::Completion;
191191
} else if model_entry.model_type == ModelType::Embedding
192-
&& self.manager.list_embeddings_models().is_empty()
192+
&& self.manager.list_embeddings_models()?.is_empty()
193193
{
194194
self.manager.remove_embeddings_model(&model_name).ok();
195195
model_type = ModelType::Embedding;
196196
} else if model_entry.model_type == ModelType::Backend {
197-
if self.manager.list_chat_completions_models().is_empty() {
197+
if self.manager.list_chat_completions_models()?.is_empty() {
198198
self.manager.remove_chat_completions_model(&model_name).ok();
199199
model_type = ModelType::Chat;
200200
}
201-
if self.manager.list_completions_models().is_empty() {
201+
if self.manager.list_completions_models()?.is_empty() {
202202
self.manager.remove_completions_model(&model_name).ok();
203203
if model_type == ModelType::Chat {
204204
model_type = ModelType::Backend;
@@ -228,14 +228,17 @@ impl ModelWatcher {
228228
let mut completions_model_removed = false;
229229
let mut embeddings_model_removed = false;
230230

231-
if chat_model_remove_err.is_ok() && self.manager.list_chat_completions_models().is_empty() {
231+
if chat_model_remove_err.is_ok() && self.manager.list_chat_completions_models()?.is_empty()
232+
{
232233
chat_model_removed = true;
233234
}
234-
if completions_model_remove_err.is_ok() && self.manager.list_completions_models().is_empty()
235+
if completions_model_remove_err.is_ok()
236+
&& self.manager.list_completions_models()?.is_empty()
235237
{
236238
completions_model_removed = true;
237239
}
238-
if embeddings_model_remove_err.is_ok() && self.manager.list_embeddings_models().is_empty() {
240+
if embeddings_model_remove_err.is_ok() && self.manager.list_embeddings_models()?.is_empty()
241+
{
239242
embeddings_model_removed = true;
240243
}
241244

lib/llm/src/http/service/health.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -52,7 +52,7 @@ async fn live_handler(
5252
async fn health_handler(
5353
axum::extract::State(state): axum::extract::State<Arc<service_v2::State>>,
5454
) -> impl IntoResponse {
55-
let model_entries = state.manager().get_model_entries();
55+
let model_entries = state.manager().get_model_entries().unwrap();
5656
let instances = if let Some(etcd_client) = state.etcd_client() {
5757
match list_all_instances(etcd_client).await {
5858
Ok(instances) => instances,

lib/llm/src/http/service/openai.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -927,7 +927,7 @@ async fn list_models_openai(
927927
.as_secs();
928928
let mut data = Vec::new();
929929

930-
let models: HashSet<String> = state.manager().model_display_names();
930+
let models: HashSet<String> = state.manager().model_display_names().unwrap();
931931
for model_name in models {
932932
data.push(ModelListing {
933933
id: model_name.clone(),

0 commit comments

Comments
 (0)