3 changes: 2 additions & 1 deletion launch/dynamo-run/src/lib.rs
@@ -40,7 +40,8 @@ pub async fn run(
         .http_port(Some(flags.http_port))
         .router_config(Some(flags.router_config()))
         .request_template(flags.request_template.clone())
-        .migration_limit(flags.migration_limit);
+        .migration_limit(flags.migration_limit)
+        .is_mocker(matches!(out_opt, Some(Output::Mocker)));

     // TODO: old, address this later:
     // If `in=dyn` we want the trtllm/sglang/vllm subprocess to listen on that endpoint.
3 changes: 2 additions & 1 deletion lib/bindings/python/rust/llm/entrypoint.rs
@@ -156,7 +156,8 @@ pub fn make_engine<'p>(
         .request_template(args.template_file.clone())
         .kv_cache_block_size(args.kv_cache_block_size)
         .router_config(args.router_config.clone().map(|rc| rc.into()))
-        .http_port(args.http_port);
+        .http_port(args.http_port)
+        .is_mocker(matches!(args.engine_type, EngineType::Mocker));
     pyo3_async_runtimes::tokio::future_into_py(py, async move {
         let local_model = builder.build().await.map_err(to_pyerr)?;
         let inner = select_engine(distributed_runtime, args, local_model)
25 changes: 23 additions & 2 deletions lib/llm/src/hub.rs
@@ -27,9 +27,19 @@ const IGNORED: [&str; 5] = [

 const HF_TOKEN_ENV_VAR: &str = "HF_TOKEN";

+/// Checks if a file is a model weight file
+fn is_weight_file(filename: &str) -> bool {
+    filename.ends_with(".bin")
+        || filename.ends_with(".safetensors")
+        || filename.ends_with(".h5")
+        || filename.ends_with(".msgpack")
+        || filename.ends_with(".ckpt.index")
+}
+
 /// Attempt to download a model from Hugging Face
 /// Returns the directory it is in
-pub async fn from_hf(name: impl AsRef<Path>) -> anyhow::Result<PathBuf> {
+/// If ignore_weights is true, model weight files will be skipped
+pub async fn from_hf(name: impl AsRef<Path>, ignore_weights: bool) -> anyhow::Result<PathBuf> {
     let name = name.as_ref();
     let token = env::var(HF_TOKEN_ENV_VAR).ok();
     let api = ApiBuilder::new()
@@ -66,6 +76,11 @@ pub async fn from_hf(name: impl AsRef<Path>) -> anyhow::Result<PathBuf> {
             continue;
         }

+        // If ignore_weights is true, skip weight files
+        if ignore_weights && is_weight_file(&sib.rfilename) {
+            continue;
+        }
+
         match repo.get(&sib.rfilename).await {
             Ok(path) => {
                 p = path;
@@ -83,8 +98,14 @@ pub async fn from_hf(name: impl AsRef<Path>) -> anyhow::Result<PathBuf> {
     }

     if !files_downloaded {
+        let file_type = if ignore_weights {
+            "non-weight"
+        } else {
+            "valid"
+        };
         return Err(anyhow::anyhow!(
-            "No valid files found for model '{}'.",
+            "No {} files found for model '{}'.",
+            file_type,
             model_name
         ));
     }
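The new is_weight_file helper and the ignore_weights flag let callers fetch only tokenizer and config files from the Hub. As a quick illustration, here is a hypothetical unit test (not part of this PR) for the suffix filter; it assumes it lives inside hub.rs so the private function is in scope, and the filenames are just typical Hugging Face examples:

    // Hypothetical test module, added here only to illustrate the filter.
    #[cfg(test)]
    mod tests {
        use super::is_weight_file;

        #[test]
        fn weight_files_are_detected_by_suffix() {
            // Weight shards and checkpoints are skipped when ignore_weights is set.
            assert!(is_weight_file("pytorch_model.bin"));
            assert!(is_weight_file("model-00001-of-00002.safetensors"));
            // Tokenizer and config files are still downloaded.
            assert!(!is_weight_file("config.json"));
            assert!(!is_weight_file("tokenizer.json"));
        }
    }

Callers that want a weights-free download pass true for the new parameter, i.e. from_hf(path, true), which is exactly what local_model.rs does for mocker engines below.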
9 changes: 8 additions & 1 deletion lib/llm/src/local_model.rs
@@ -47,6 +47,7 @@ pub struct LocalModelBuilder {
     kv_cache_block_size: u32,
     http_port: u16,
     migration_limit: u32,
+    is_mocker: bool,
 }

 impl Default for LocalModelBuilder {
@@ -62,6 +63,7 @@ impl Default for LocalModelBuilder {
             template_file: Default::default(),
             router_config: Default::default(),
             migration_limit: Default::default(),
+            is_mocker: Default::default(),
         }
     }
 }
@@ -119,6 +121,11 @@ impl LocalModelBuilder {
         self
     }

+    pub fn is_mocker(&mut self, is_mocker: bool) -> &mut Self {
+        self.is_mocker = is_mocker;
+        self
+    }
+
     /// Make an LLM ready for use:
     /// - Download it from Hugging Face (and NGC in future) if necessary
     /// - Resolve the path
@@ -169,7 +176,7 @@ impl LocalModelBuilder {
         let relative_path = model_path.trim_start_matches(HF_SCHEME);
         let full_path = if is_hf_repo {
             // HF download if necessary
-            super::hub::from_hf(relative_path).await?
+            super::hub::from_hf(relative_path, self.is_mocker).await?
         } else {
             fs::canonicalize(relative_path)?
         };
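For completeness, a minimal sketch of how the flag is expected to be set and consumed. Only the builder methods visible in this diff are used; the port and limit values are illustrative assumptions, not defaults from the codebase:

    // Hypothetical call site, not taken from the repository.
    let mut builder = LocalModelBuilder::default();
    builder
        .http_port(8080)
        .migration_limit(0)
        .is_mocker(true); // mocker engines only need tokenizer/config files
    // build() resolves the model path; for HF repos it calls
    // from_hf(relative_path, self.is_mocker), so weight files are skipped here.
    let model = builder.build().await?;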
60 changes: 48 additions & 12 deletions tests/conftest.py
@@ -41,11 +41,12 @@
 ]


-def download_models(model_list=None):
+def download_models(model_list=None, ignore_weights=False):
     """Download models - can be called directly or via fixture

     Args:
         model_list: List of model IDs to download. If None, downloads TEST_MODELS.
+        ignore_weights: If True, skips downloading model weight files. Default is False.
     """
     if model_list is None:
         model_list = TEST_MODELS
@@ -65,15 +66,33 @@ def download_models(model_list=None):
     from huggingface_hub import snapshot_download

     for model_id in model_list:
-        logging.info(f"Pre-downloading model: {model_id}")
+        logging.info(
+            f"Pre-downloading {'model (no weights)' if ignore_weights else 'model'}: {model_id}"
+        )

         try:
-            # Download the full model snapshot (includes all files)
-            # HuggingFace will handle caching automatically
-            snapshot_download(
-                repo_id=model_id,
-                token=hf_token,
-            )
+            if ignore_weights:
+                # Weight file patterns to exclude (based on hub.rs implementation)
+                weight_patterns = [
+                    "*.bin",
+                    "*.safetensors",
+                    "*.h5",
+                    "*.msgpack",
+                    "*.ckpt.index",
+                ]
+
+                # Download everything except weight files
+                snapshot_download(
+                    repo_id=model_id,
+                    token=hf_token,
+                    ignore_patterns=weight_patterns,
+                )
+            else:
+                # Download the full model snapshot (includes all files)
+                snapshot_download(
+                    repo_id=model_id,
+                    token=hf_token,
+                )
             logging.info(f"Successfully pre-downloaded: {model_id}")

         except Exception as e:
@@ -94,6 +113,13 @@ def predownload_models():
     yield


+@pytest.fixture(scope="session")
+def predownload_tokenizers():
+    """Fixture wrapper around download_models for all TEST_MODELS"""
+    download_models(ignore_weights=True)
+    yield
+
+
 @pytest.fixture(autouse=True)
 def logger(request):
     log_path = os.path.join(request.node.name, "test.log.txt")
@@ -127,14 +153,24 @@ def pytest_collection_modifyitems(config, items):
         # Auto-inject predownload_models fixture for serve tests only (not router tests)
         # Skip items that don't have fixturenames (like MypyFileItem)
         if hasattr(item, "fixturenames"):
-            # Only apply to tests in the serve directory
+            # Guard clause: skip if already has the fixtures
             if (
-                ("serve" in str(item.path))
-                and ("predownload_models" not in item.fixturenames)
-                and (not item.get_closest_marker("skip_model_download"))
+                "predownload_models" in item.fixturenames
+                or "predownload_tokenizers" in item.fixturenames
             ):
+                continue
+
+            # Guard clause: skip if marked with skip_model_download
+            if item.get_closest_marker("skip_model_download"):
+                continue
+
+            # Add appropriate fixture based on test path
+            if "serve" in str(item.path):
                 item.fixturenames = list(item.fixturenames)
                 item.fixturenames.append("predownload_models")
+            elif "router" in str(item.path):
+                item.fixturenames = list(item.fixturenames)
+                item.fixturenames.append("predownload_tokenizers")


 class EtcdServer(ManagedProcess):
4 changes: 0 additions & 4 deletions tests/router/test_router_e2e_with_mockers.py
@@ -9,7 +9,6 @@
 import aiohttp
 import pytest

-from tests.conftest import download_models
 from tests.utils.managed_process import ManagedProcess

 pytestmark = pytest.mark.pre_merge
@@ -96,9 +95,6 @@ def test_mocker_kv_router(request, runtime_services):
     This test doesn't require GPUs and runs quickly for pre-merge validation.
     """

-    # Download only the Qwen model for this test
-    download_models([MODEL_NAME])
-
     # runtime_services starts etcd and nats
     logger.info("Starting mocker KV router test")
