Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 13 additions & 16 deletions lib/llm/src/discovery/model_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

use std::{
collections::{HashMap, HashSet},
sync::{Arc, RwLock},
sync::Arc,
};

use parking_lot::Mutex;
use parking_lot::{Mutex, RwLock};

use dynamo_runtime::component::Component;
use dynamo_runtime::prelude::DistributedRuntimeProvider;
Expand Down Expand Up @@ -64,8 +64,8 @@ impl ModelManager {
}

pub fn has_model_any(&self, model: &str) -> bool {
self.chat_completion_engines.read().unwrap().contains(model)
|| self.completion_engines.read().unwrap().contains(model)
self.chat_completion_engines.read().contains(model)
|| self.completion_engines.read().contains(model)
}

pub fn model_display_names(&self) -> HashSet<String> {
Expand All @@ -77,23 +77,23 @@ impl ModelManager {
}

pub fn list_chat_completions_models(&self) -> Vec<String> {
self.chat_completion_engines.read().unwrap().list()
self.chat_completion_engines.read().list()
}

pub fn list_completions_models(&self) -> Vec<String> {
self.completion_engines.read().unwrap().list()
self.completion_engines.read().list()
}

pub fn list_embeddings_models(&self) -> Vec<String> {
self.embeddings_engines.read().unwrap().list()
self.embeddings_engines.read().list()
}

pub fn add_completions_model(
&self,
model: &str,
engine: OpenAICompletionsStreamingEngine,
) -> Result<(), ModelManagerError> {
let mut clients = self.completion_engines.write().unwrap();
let mut clients = self.completion_engines.write();
clients.add(model, engine)
}

Expand All @@ -102,7 +102,7 @@ impl ModelManager {
model: &str,
engine: OpenAIChatCompletionsStreamingEngine,
) -> Result<(), ModelManagerError> {
let mut clients = self.chat_completion_engines.write().unwrap();
let mut clients = self.chat_completion_engines.write();
clients.add(model, engine)
}

Expand All @@ -111,22 +111,22 @@ impl ModelManager {
model: &str,
engine: OpenAIEmbeddingsStreamingEngine,
) -> Result<(), ModelManagerError> {
let mut clients = self.embeddings_engines.write().unwrap();
let mut clients = self.embeddings_engines.write();
clients.add(model, engine)
}

pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
let mut clients = self.completion_engines.write().unwrap();
let mut clients = self.completion_engines.write();
clients.remove(model)
}

pub fn remove_chat_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
let mut clients = self.chat_completion_engines.write().unwrap();
let mut clients = self.chat_completion_engines.write();
clients.remove(model)
}

pub fn remove_embeddings_model(&self, model: &str) -> Result<(), ModelManagerError> {
let mut clients = self.embeddings_engines.write().unwrap();
let mut clients = self.embeddings_engines.write();
clients.remove(model)
}

Expand All @@ -136,7 +136,6 @@ impl ModelManager {
) -> Result<OpenAIEmbeddingsStreamingEngine, ModelManagerError> {
self.embeddings_engines
.read()
.unwrap()
.get(model)
.cloned()
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
Expand All @@ -148,7 +147,6 @@ impl ModelManager {
) -> Result<OpenAICompletionsStreamingEngine, ModelManagerError> {
self.completion_engines
.read()
.unwrap()
.get(model)
.cloned()
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
Expand All @@ -160,7 +158,6 @@ impl ModelManager {
) -> Result<OpenAIChatCompletionsStreamingEngine, ModelManagerError> {
self.chat_completion_engines
.read()
.unwrap()
.get(model)
.cloned()
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
Expand Down
Loading