@@ -33,6 +33,7 @@ use crate::{
3333 completions:: { NvCreateCompletionRequest , NvCreateCompletionResponse } ,
3434 embeddings:: { NvCreateEmbeddingRequest , NvCreateEmbeddingResponse } ,
3535 } ,
36+ tensor:: { NvCreateTensorRequest , NvCreateTensorResponse } ,
3637 } ,
3738} ;
3839
@@ -59,6 +60,7 @@ const ALL_MODEL_TYPES: &[ModelType] = &[
5960 ModelType :: Chat ,
6061 ModelType :: Completions ,
6162 ModelType :: Embedding ,
63+ ModelType :: TensorBased ,
6264] ;
6365
6466impl ModelWatcher {
@@ -213,10 +215,12 @@ impl ModelWatcher {
213215 let chat_model_remove_err = self . manager . remove_chat_completions_model ( & model_name) ;
214216 let completions_model_remove_err = self . manager . remove_completions_model ( & model_name) ;
215217 let embeddings_model_remove_err = self . manager . remove_embeddings_model ( & model_name) ;
218+ let tensor_model_remove_err = self . manager . remove_tensor_model ( & model_name) ;
216219
217220 let mut chat_model_removed = false ;
218221 let mut completions_model_removed = false ;
219222 let mut embeddings_model_removed = false ;
223+ let mut tensor_model_removed = false ;
220224
221225 if chat_model_remove_err. is_ok ( ) && self . manager . list_chat_completions_models ( ) . is_empty ( ) {
222226 chat_model_removed = true ;
@@ -228,20 +232,29 @@ impl ModelWatcher {
228232 if embeddings_model_remove_err. is_ok ( ) && self . manager . list_embeddings_models ( ) . is_empty ( ) {
229233 embeddings_model_removed = true ;
230234 }
235+ if tensor_model_remove_err. is_ok ( ) && self . manager . list_tensor_models ( ) . is_empty ( ) {
236+ tensor_model_removed = true ;
237+ }
231238
232- if !chat_model_removed && !completions_model_removed && !embeddings_model_removed {
239+ if !chat_model_removed
240+ && !completions_model_removed
241+ && !embeddings_model_removed
242+ && !tensor_model_removed
243+ {
233244 tracing:: debug!(
234- "No updates to send for model {}: chat_model_removed: {}, completions_model_removed: {}, embeddings_model_removed: {}" ,
245+ "No updates to send for model {}: chat_model_removed: {}, completions_model_removed: {}, embeddings_model_removed: {}, tensor_model_removed: {} " ,
235246 model_name,
236247 chat_model_removed,
237248 completions_model_removed,
238- embeddings_model_removed
249+ embeddings_model_removed,
250+ tensor_model_removed
239251 ) ;
240252 } else {
241253 for model_type in ALL_MODEL_TYPES {
242254 if ( ( chat_model_removed && * model_type == ModelType :: Chat )
243255 || ( completions_model_removed && * model_type == ModelType :: Completions )
244- || ( embeddings_model_removed && * model_type == ModelType :: Embedding ) )
256+ || ( embeddings_model_removed && * model_type == ModelType :: Embedding )
257+ || ( tensor_model_removed && * model_type == ModelType :: TensorBased ) )
245258 && let Some ( tx) = & self . model_update_tx
246259 {
247260 tx. send ( ModelUpdate :: Removed ( * model_type) ) . await . ok ( ) ;
@@ -421,11 +434,24 @@ impl ModelWatcher {
421434
422435 self . manager
423436 . add_embeddings_model ( & model_entry. name , embedding_engine) ?;
437+ } else if model_entry. model_input == ModelInput :: Tensor
438+ && model_entry. model_type . supports_tensor ( )
439+ {
440+ // Case 5: Tensor input + a model type that supports tensor (non-LLM)
441+ let push_router = PushRouter :: <
442+ NvCreateTensorRequest ,
443+ Annotated < NvCreateTensorResponse > ,
444+ > :: from_client_with_threshold (
445+ client, self . router_mode , self . busy_threshold
446+ )
447+ . await ?;
448+ let engine = Arc :: new ( push_router) ;
449+ self . manager . add_tensor_model ( & model_entry. name , engine) ?;
424450 } else {
425451 // Reject unsupported combinations
426452 anyhow:: bail!(
427453 "Unsupported model configuration: {} with {} input. Supported combinations: \
428- Tokens+(Chat|Completions), Text+Chat, Text+Completions, Tokens+Embeddings",
454+ Tokens+(Chat|Completions), Text+Chat, Text+Completions, Tokens+Embeddings, Tensor+TensorBased ",
429455 model_entry. model_type,
430456 model_entry. model_input. as_str( )
431457 ) ;
0 commit comments