
Commit f0fcd0b

GuanLuo, oandreeva-nv, and rmccorm4 authored and committed
feat: tensor type for generic inference. (#2746)
Signed-off-by: Guan Luo <gluo@nvidia.com>
Signed-off-by: GuanLuo <41310872+GuanLuo@users.noreply.github.com>
Co-authored-by: Olga Andreeva <124622579+oandreeva-nv@users.noreply.github.com>
Co-authored-by: Ryan McCormick <rmccormick@nvidia.com>
Signed-off-by: Jason Zhou <jasonzho@nvidia.com>
1 parent 33c7171 commit f0fcd0b


18 files changed: +2024 additions, −296 deletions


lib/bindings/python/rust/lib.rs

Lines changed: 6 additions & 0 deletions
@@ -165,6 +165,7 @@ fn register_llm<'p>(
     let model_input = match model_input {
         ModelInput::Text => llm_rs::model_type::ModelInput::Text,
         ModelInput::Tokens => llm_rs::model_type::ModelInput::Tokens,
+        ModelInput::Tensor => llm_rs::model_type::ModelInput::Tensor,
     };

     let model_type_obj = model_type.inner;
@@ -298,6 +299,10 @@ impl ModelType {
     const Embedding: Self = ModelType {
         inner: llm_rs::model_type::ModelType::Embedding,
     };
+    #[classattr]
+    const TensorBased: Self = ModelType {
+        inner: llm_rs::model_type::ModelType::TensorBased,
+    };

     fn __or__(&self, other: &Self) -> Self {
         ModelType {
@@ -315,6 +320,7 @@ impl ModelType {
 enum ModelInput {
     Text = 1,
     Tokens = 2,
+    Tensor = 3,
 }

 #[pymethods]
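
These bindings surface the new variants to Python as `ModelInput.Tensor` and `ModelType.TensorBased`. A minimal sketch of a registration call, assuming only what this diff and the test file below show (the endpoint setup is elided and the model names are placeholders):

from dynamo.llm import ModelInput, ModelType, register_llm


async def register_tensor_model(endpoint):
    # The new variants plug into the existing register_llm call. ModelType
    # values still compose via __or__, e.g. ModelType.Chat | ModelType.TensorBased
    # (a hypothetical combination), since TensorBased is just another flag.
    await register_llm(
        ModelInput.Tensor,      # requests arrive as raw tensors, not text/tokens
        ModelType.TensorBased,  # engine speaks the generic tensor protocol
        endpoint,
        "Qwen/Qwen3-0.6B",      # placeholder model card (see FIXME in the test)
        "tensor",               # name the model is served under
    )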

lib/bindings/python/rust/llm/local_model.rs

Lines changed: 21 additions & 0 deletions
@@ -52,6 +52,27 @@ impl ModelRuntimeConfig {
         Ok(())
     }

+    fn set_tensor_model_config(
+        &mut self,
+        _py: Python<'_>,
+        tensor_model_config: &Bound<'_, PyDict>,
+    ) -> PyResult<()> {
+        let tensor_model_config = pythonize::depythonize(tensor_model_config).map_err(|err| {
+            PyErr::new::<PyException, _>(format!("Failed to convert tensor_model_config: {}", err))
+        })?;
+        self.inner.tensor_model_config = Some(tensor_model_config);
+        Ok(())
+    }
+
+    fn get_tensor_model_config(&self, _py: Python<'_>) -> PyResult<Option<PyObject>> {
+        if let Some(tensor_model_config) = &self.inner.tensor_model_config {
+            let py_obj = pythonize::pythonize(_py, tensor_model_config).map_err(to_pyerr)?;
+            Ok(Some(py_obj.unbind()))
+        } else {
+            Ok(None)
+        }
+    }
+
     #[getter]
     fn total_kv_blocks(&self) -> Option<u64> {
         self.inner.total_kv_blocks
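
Because the setter goes through `pythonize::depythonize` and the getter back through `pythonize`, an arbitrary Python dict round-trips unchanged — the assert in the test file below relies on exactly this. A minimal sketch, reusing the config shape from that test:

from dynamo.llm import ModelRuntimeConfig

runtime_config = ModelRuntimeConfig()
runtime_config.set_tensor_model_config({
    "name": "tensor",
    "inputs": [{"name": "input_text", "data_type": "Bytes", "shape": [-1]}],
    "outputs": [{"name": "output_text", "data_type": "Bytes", "shape": [-1]}],
})

# The dict structure is preserved through the Rust round trip.
assert runtime_config.get_tensor_model_config()["name"] == "tensor"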

lib/bindings/python/src/dynamo/_core.pyi

Lines changed: 2 additions & 2 deletions
@@ -849,11 +849,11 @@ class HttpAsyncEngine:
     ...

 class ModelInput:
-    """What type of request this model needs: Text or Tokens"""
+    """What type of request this model needs: Text, Tokens or Tensor"""
     ...

 class ModelType:
-    """What type of request this model needs: Chat, Completions or Embedding"""
+    """What type of request this model needs: Chat, Completions, Embedding or Tensor"""
     ...

 class RouterMode:
test_tensor.py (new file)

Lines changed: 88 additions & 0 deletions

@@ -0,0 +1,88 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+# Usage: `TEST_END_TO_END=1 python test_tensor.py` to run this worker as tensor based echo worker.
+
+import os
+
+import uvloop
+
+from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_llm
+from dynamo.runtime import DistributedRuntime, dynamo_worker
+
+TEST_END_TO_END = os.environ.get("TEST_END_TO_END", 0)
+
+
+@dynamo_worker(static=False)
+async def test_register(runtime: DistributedRuntime):
+    component = runtime.namespace("test").component("tensor")
+    await component.create_service()
+
+    endpoint = component.endpoint("generate")
+
+    model_config = {
+        "name": "tensor",
+        "inputs": [
+            {"name": "input_text", "data_type": "Bytes", "shape": [-1]},
+            {"name": "custom", "data_type": "Bytes", "shape": [-1]},
+            {"name": "streaming", "data_type": "Bool", "shape": [1]},
+        ],
+        "outputs": [{"name": "output_text", "data_type": "Bytes", "shape": [-1]}],
+    }
+    runtime_config = ModelRuntimeConfig()
+    runtime_config.set_tensor_model_config(model_config)
+
+    assert model_config == runtime_config.get_tensor_model_config()
+
+    # [gluo FIXME] register_llm will attempt to load a LLM model,
+    # which is not well-defined for Tensor yet. Currently provide
+    # a valid model name to pass the registration.
+    await register_llm(
+        ModelInput.Tensor,
+        ModelType.TensorBased,
+        endpoint,
+        "Qwen/Qwen3-0.6B",
+        "tensor",
+        runtime_config=runtime_config,
+    )
+
+    if TEST_END_TO_END:
+        await endpoint.serve_endpoint(generate)
+
+
+async def generate(request, context):
+    print(f"Received request: {request}")
+    # Echo input_text in output_text
+    output_text = None
+    streaming = False
+    for tensor in request["tensors"]:
+        if tensor["metadata"]["name"] == "input_text":
+            input_text_str = "".join(map(chr, tensor["data"]["values"][0]))
+            print(f"Input text: {input_text_str}")
+            output_text = tensor
+            output_text["metadata"]["name"] = "output_text"
+        if tensor["metadata"]["name"] == "streaming":
+            streaming = tensor["data"]["values"][0]
+    if output_text is None:
+        raise ValueError("input_text tensor not found in request")
+    if streaming:
+        for i in range(len(output_text["data"]["values"][0])):
+            chunk = {
+                "model": request["model"],
+                "tensors": [
+                    {
+                        "metadata": output_text["metadata"],
+                        "data": {
+                            "data_type": output_text["data"]["data_type"],
+                            "values": [[output_text["data"]["values"][0][i]]],
+                        },
+                    }
+                ],
+            }
+            yield chunk
+    else:
+        yield {"model": request["model"], "tensors": [output_text]}
+
+
+if __name__ == "__main__":
+    uvloop.run(test_register())
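
Reading `generate()` backwards gives the request shape a client must send: a model name plus a list of tensors, each carrying `metadata` (the tensor name) and `data` (a data_type plus nested values). Below is a hypothetical payload matching the echo worker's schema; how it is submitted to the endpoint is outside this diff:

# "Hi" as byte values in input_text, streaming disabled; the worker echoes
# the same tensor back under the name output_text.
request = {
    "model": "tensor",
    "tensors": [
        {
            "metadata": {"name": "input_text"},
            "data": {"data_type": "Bytes", "values": [[ord(c) for c in "Hi"]]},
        },
        {
            "metadata": {"name": "streaming"},
            "data": {"data_type": "Bool", "values": [False]},
        },
    ],
}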

lib/llm/src/discovery/model_manager.rs

Lines changed: 33 additions & 0 deletions
@@ -15,6 +15,7 @@ use crate::discovery::{KV_ROUTERS_ROOT_PATH, ModelEntry};
 use crate::kv_router::{KvRouterConfig, scheduler::DefaultWorkerSelector};
 use crate::{
     kv_router::KvRouter,
+    types::generic::tensor::TensorStreamingEngine,
     types::openai::{
         chat_completions::OpenAIChatCompletionsStreamingEngine,
         completions::OpenAICompletionsStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine,
@@ -36,6 +37,7 @@ pub struct ModelManager {
     completion_engines: RwLock<ModelEngines<OpenAICompletionsStreamingEngine>>,
     chat_completion_engines: RwLock<ModelEngines<OpenAIChatCompletionsStreamingEngine>>,
     embeddings_engines: RwLock<ModelEngines<OpenAIEmbeddingsStreamingEngine>>,
+    tensor_engines: RwLock<ModelEngines<TensorStreamingEngine>>,

     // These two are Mutex because we read and write rarely and equally
     entries: Mutex<HashMap<String, ModelEntry>>,
@@ -54,6 +56,7 @@ impl ModelManager {
             completion_engines: RwLock::new(ModelEngines::default()),
             chat_completion_engines: RwLock::new(ModelEngines::default()),
             embeddings_engines: RwLock::new(ModelEngines::default()),
+            tensor_engines: RwLock::new(ModelEngines::default()),
             entries: Mutex::new(HashMap::new()),
             kv_choosers: Mutex::new(HashMap::new()),
         }
@@ -73,6 +76,7 @@ impl ModelManager {
             .into_iter()
             .chain(self.list_completions_models())
             .chain(self.list_embeddings_models())
+            .chain(self.list_tensor_models())
             .collect()
     }

@@ -88,6 +92,10 @@ impl ModelManager {
         self.embeddings_engines.read().list()
     }

+    pub fn list_tensor_models(&self) -> Vec<String> {
+        self.tensor_engines.read().list()
+    }
+
     pub fn add_completions_model(
         &self,
         model: &str,
@@ -115,6 +123,15 @@ impl ModelManager {
         clients.add(model, engine)
     }

+    pub fn add_tensor_model(
+        &self,
+        model: &str,
+        engine: TensorStreamingEngine,
+    ) -> Result<(), ModelManagerError> {
+        let mut clients = self.tensor_engines.write();
+        clients.add(model, engine)
+    }
+
     pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
         let mut clients = self.completion_engines.write();
         clients.remove(model)
@@ -130,6 +147,11 @@ impl ModelManager {
         clients.remove(model)
     }

+    pub fn remove_tensor_model(&self, model: &str) -> Result<(), ModelManagerError> {
+        let mut clients = self.tensor_engines.write();
+        clients.remove(model)
+    }
+
     pub fn get_embeddings_engine(
         &self,
         model: &str,
@@ -163,6 +185,17 @@ impl ModelManager {
             .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
     }

+    pub fn get_tensor_engine(
+        &self,
+        model: &str,
+    ) -> Result<TensorStreamingEngine, ModelManagerError> {
+        self.tensor_engines
+            .read()
+            .get(model)
+            .cloned()
+            .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
+    }
+
     /// Save a ModelEntry under an instance's etcd `models/` key so we can fetch it later when the key is
     /// deleted from etcd.
     pub fn save_model_entry(&self, key: &str, entry: ModelEntry) {

lib/llm/src/discovery/watcher.rs

Lines changed: 31 additions & 5 deletions
@@ -33,6 +33,7 @@ use crate::{
             completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
             embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
         },
+        tensor::{NvCreateTensorRequest, NvCreateTensorResponse},
     },
 };

@@ -59,6 +60,7 @@ const ALL_MODEL_TYPES: &[ModelType] = &[
     ModelType::Chat,
     ModelType::Completions,
     ModelType::Embedding,
+    ModelType::TensorBased,
 ];

 impl ModelWatcher {
@@ -213,10 +215,12 @@ impl ModelWatcher {
         let chat_model_remove_err = self.manager.remove_chat_completions_model(&model_name);
         let completions_model_remove_err = self.manager.remove_completions_model(&model_name);
         let embeddings_model_remove_err = self.manager.remove_embeddings_model(&model_name);
+        let tensor_model_remove_err = self.manager.remove_tensor_model(&model_name);

         let mut chat_model_removed = false;
         let mut completions_model_removed = false;
         let mut embeddings_model_removed = false;
+        let mut tensor_model_removed = false;

         if chat_model_remove_err.is_ok() && self.manager.list_chat_completions_models().is_empty() {
             chat_model_removed = true;
@@ -228,20 +232,29 @@ impl ModelWatcher {
         if embeddings_model_remove_err.is_ok() && self.manager.list_embeddings_models().is_empty() {
             embeddings_model_removed = true;
         }
+        if tensor_model_remove_err.is_ok() && self.manager.list_tensor_models().is_empty() {
+            tensor_model_removed = true;
+        }

-        if !chat_model_removed && !completions_model_removed && !embeddings_model_removed {
+        if !chat_model_removed
+            && !completions_model_removed
+            && !embeddings_model_removed
+            && !tensor_model_removed
+        {
             tracing::debug!(
-                "No updates to send for model {}: chat_model_removed: {}, completions_model_removed: {}, embeddings_model_removed: {}",
+                "No updates to send for model {}: chat_model_removed: {}, completions_model_removed: {}, embeddings_model_removed: {}, tensor_model_removed: {}",
                 model_name,
                 chat_model_removed,
                 completions_model_removed,
-                embeddings_model_removed
+                embeddings_model_removed,
+                tensor_model_removed
             );
         } else {
             for model_type in ALL_MODEL_TYPES {
                 if ((chat_model_removed && *model_type == ModelType::Chat)
                     || (completions_model_removed && *model_type == ModelType::Completions)
-                    || (embeddings_model_removed && *model_type == ModelType::Embedding))
+                    || (embeddings_model_removed && *model_type == ModelType::Embedding)
+                    || (tensor_model_removed && *model_type == ModelType::TensorBased))
                     && let Some(tx) = &self.model_update_tx
                 {
                     tx.send(ModelUpdate::Removed(*model_type)).await.ok();
@@ -421,11 +434,24 @@ impl ModelWatcher {

             self.manager
                 .add_embeddings_model(&model_entry.name, embedding_engine)?;
+        } else if model_entry.model_input == ModelInput::Tensor
+            && model_entry.model_type.supports_tensor()
+        {
+            // Case 5: Tensor + Tensor (non-LLM)
+            let push_router = PushRouter::<
+                NvCreateTensorRequest,
+                Annotated<NvCreateTensorResponse>,
+            >::from_client_with_threshold(
+                client, self.router_mode, self.busy_threshold
+            )
+            .await?;
+            let engine = Arc::new(push_router);
+            self.manager.add_tensor_model(&model_entry.name, engine)?;
         } else {
             // Reject unsupported combinations
             anyhow::bail!(
                 "Unsupported model configuration: {} with {} input. Supported combinations: \
-                Tokens+(Chat|Completions), Text+Chat, Text+Completions, Tokens+Embeddings",
+                Tokens+(Chat|Completions), Text+Chat, Text+Completions, Tokens+Embeddings, Tensor+TensorBased",
                 model_entry.model_type,
                 model_entry.model_input.as_str()
             );
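
The new branch (Case 5) only fires when a discovered ModelEntry pairs ModelInput::Tensor with a model type whose supports_tensor() returns true; every other pairing now falls through to the extended bail! message. Seen from the Python bindings, a sketch of what passes and what the frontend's watcher rejects (endpoint setup elided; the rejection happens at discovery time, not at the register_llm call itself):

from dynamo.llm import ModelInput, ModelType, register_llm


async def examples(endpoint):
    # Accepted: matches Case 5 (Tensor input + TensorBased type).
    await register_llm(ModelInput.Tensor, ModelType.TensorBased, endpoint,
                       "Qwen/Qwen3-0.6B", "tensor")

    # Registers locally, but the frontend's watcher rejects the entry on
    # discovery: Text input has no tensor engine path, so it bails with
    # "Unsupported model configuration: ...".
    await register_llm(ModelInput.Text, ModelType.TensorBased, endpoint,
                       "Qwen/Qwen3-0.6B", "tensor")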

lib/llm/src/grpc/service.rs

Lines changed: 1 addition & 0 deletions
@@ -3,3 +3,4 @@

 pub mod kserve;
 pub mod openai;
+pub mod tensor;
