Skip to content

Commit 355271b

Browse files
committed
skip downloading model weights if using mocker (only tokenizer)
1 parent 97390ac commit 355271b

File tree

5 files changed

+39
-9
lines changed

5 files changed

+39
-9
lines changed

launch/dynamo-run/src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ pub async fn run(
4040
.http_port(Some(flags.http_port))
4141
.router_config(Some(flags.router_config()))
4242
.request_template(flags.request_template.clone())
43-
.migration_limit(flags.migration_limit);
43+
.migration_limit(flags.migration_limit)
44+
.is_mocker(matches!(out_opt, Some(Output::Mocker)));
4445

4546
// TODO: old, address this later:
4647
// If `in=dyn` we want the trtllm/sglang/vllm subprocess to listen on that endpoint.

lib/bindings/python/rust/llm/entrypoint.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,8 @@ pub fn make_engine<'p>(
156156
.request_template(args.template_file.clone())
157157
.kv_cache_block_size(args.kv_cache_block_size)
158158
.router_config(args.router_config.clone().map(|rc| rc.into()))
159-
.http_port(args.http_port);
159+
.http_port(args.http_port)
160+
.is_mocker(matches!(args.engine_type, EngineType::Mocker));
160161
pyo3_async_runtimes::tokio::future_into_py(py, async move {
161162
let local_model = builder.build().await.map_err(to_pyerr)?;
162163
let inner = select_engine(distributed_runtime, args, local_model)

lib/llm/src/hub.rs

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,19 @@ const IGNORED: [&str; 5] = [
2727

2828
const HF_TOKEN_ENV_VAR: &str = "HF_TOKEN";
2929

30+
/// Returns `true` when `filename` looks like a model weight (checkpoint) file.
///
/// The decision is made purely on the file-name suffix; the file contents are
/// never inspected. `from_hf` uses this to skip large weight downloads when it
/// is called with `ignore_weights = true`.
///
/// Besides the formats the Hugging Face `transformers` library writes itself
/// (`.bin`, `.safetensors`, `.h5`, `.msgpack`, `.ckpt.index`), raw PyTorch
/// checkpoints (`.pt`, `.pth`), generic `.ckpt` checkpoints, ONNX exports
/// (`.onnx`) and `rust_model.ot`-style files (`.ot`) are recognised, so
/// repositories that ship those do not pull their weights either.
///
/// Tokenizer and configuration files such as `tokenizer.json`,
/// `tokenizer.model`, `config.json` or `model.safetensors.index.json` do not
/// match and are therefore still downloaded.
fn is_weight_file(filename: &str) -> bool {
    // Suffixes are compared case-sensitively, matching how these files are
    // conventionally named on the Hugging Face Hub.
    const WEIGHT_SUFFIXES: [&str; 10] = [
        ".bin",
        ".safetensors",
        ".h5",
        ".msgpack",
        ".ckpt.index",
        ".ckpt",
        ".pt",
        ".pth",
        ".onnx",
        ".ot",
    ];

    WEIGHT_SUFFIXES
        .iter()
        .any(|suffix| filename.ends_with(suffix))
}
38+
3039
/// Attempt to download a model from Hugging Face
3140
/// Returns the directory it is in
32-
pub async fn from_hf(name: impl AsRef<Path>) -> anyhow::Result<PathBuf> {
41+
/// If ignore_weights is true, model weight files will be skipped
42+
pub async fn from_hf(name: impl AsRef<Path>, ignore_weights: bool) -> anyhow::Result<PathBuf> {
3343
let name = name.as_ref();
3444
let token = env::var(HF_TOKEN_ENV_VAR).ok();
3545
let api = ApiBuilder::new()
@@ -66,6 +76,11 @@ pub async fn from_hf(name: impl AsRef<Path>) -> anyhow::Result<PathBuf> {
6676
continue;
6777
}
6878

79+
// If ignore_weights is true, skip weight files
80+
if ignore_weights && is_weight_file(&sib.rfilename) {
81+
continue;
82+
}
83+
6984
match repo.get(&sib.rfilename).await {
7085
Ok(path) => {
7186
p = path;
@@ -83,8 +98,14 @@ pub async fn from_hf(name: impl AsRef<Path>) -> anyhow::Result<PathBuf> {
8398
}
8499

85100
if !files_downloaded {
101+
let file_type = if ignore_weights {
102+
"non-weight"
103+
} else {
104+
"valid"
105+
};
86106
return Err(anyhow::anyhow!(
87-
"No valid files found for model '{}'.",
107+
"No {} files found for model '{}'.",
108+
file_type,
88109
model_name
89110
));
90111
}

lib/llm/src/local_model.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ pub struct LocalModelBuilder {
4747
kv_cache_block_size: u32,
4848
http_port: u16,
4949
migration_limit: u32,
50+
is_mocker: bool,
5051
}
5152

5253
impl Default for LocalModelBuilder {
@@ -62,6 +63,7 @@ impl Default for LocalModelBuilder {
6263
template_file: Default::default(),
6364
router_config: Default::default(),
6465
migration_limit: Default::default(),
66+
is_mocker: Default::default(),
6567
}
6668
}
6769
}
@@ -119,6 +121,11 @@ impl LocalModelBuilder {
119121
self
120122
}
121123

124+
/// Records whether the model is being prepared for the mocker engine.
///
/// Stores `is_mocker` on the builder and returns `&mut Self` so the call can
/// be chained with the other builder setters. The stored flag is later passed
/// to `hub::from_hf` as its `ignore_weights` argument, so when it is `true`
/// model weight files are skipped during a Hugging Face download.
pub fn is_mocker(&mut self, is_mocker: bool) -> &mut Self {
125+
self.is_mocker = is_mocker;
126+
self
127+
}
128+
122129
/// Make an LLM ready for use:
123130
/// - Download it from Hugging Face (and NGC in future) if necessary
124131
/// - Resolve the path
@@ -169,7 +176,7 @@ impl LocalModelBuilder {
169176
let relative_path = model_path.trim_start_matches(HF_SCHEME);
170177
let full_path = if is_hf_repo {
171178
// HF download if necessary
172-
super::hub::from_hf(relative_path).await?
179+
super::hub::from_hf(relative_path, self.is_mocker).await?
173180
} else {
174181
fs::canonicalize(relative_path)?
175182
};

tests/router/test_router_e2e_with_mockers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
import json
66
import logging
77
import os
8+
import time
89

910
import aiohttp
1011
import pytest
1112

12-
from tests.conftest import download_models
1313
from tests.utils.managed_process import ManagedProcess
1414

1515
pytestmark = pytest.mark.pre_merge
@@ -96,9 +96,6 @@ def test_mocker_kv_router(request, runtime_services):
9696
This test doesn't require GPUs and runs quickly for pre-merge validation.
9797
"""
9898

99-
# Download only the Qwen model for this test
100-
download_models([MODEL_NAME])
101-
10299
# runtime_services starts etcd and nats
103100
logger.info("Starting mocker KV router test")
104101

@@ -132,6 +129,9 @@ def test_mocker_kv_router(request, runtime_services):
132129
for mocker in mocker_processes:
133130
mocker.__enter__()
134131

132+
# Wait 2 seconds for setup to finish (e.g. downloading the tokenizer)
133+
time.sleep(2)
134+
135135
# Send test requests
136136
test_payload = {
137137
"model": MODEL_NAME,

0 commit comments

Comments
 (0)