Skip to content

Commit b26328d

Browse files
committed
add load_format verification in VllmConfig
Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>
1 parent 28f1cee commit b26328d

File tree

3 files changed

+12
-30
lines changed

3 files changed

+12
-30
lines changed

tests/engine/test_arg_utils.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -332,24 +332,3 @@ def test_human_readable_model_len():
332332
for invalid in ["1a", "pwd", "10.24", "1.23M"]:
333333
with pytest.raises(ArgumentError):
334334
args = parser.parse_args(["--max-model-len", invalid])
335-
336-
337-
def test_load_format():
338-
args = EngineArgs(model="s3://model/Qwen/Qwen3-0.6B")
339-
args.create_model_config()
340-
assert args.load_format == "runai_streamer"
341-
342-
args = EngineArgs(model="s3://model/Qwen/Qwen3-0.6B",
343-
load_format="runai_streamer")
344-
args.create_model_config()
345-
assert args.load_format == "runai_streamer"
346-
347-
try:
348-
args = EngineArgs(model="s3://model/Qwen/Qwen3-0.6B",
349-
load_format="gguf")
350-
args.create_model_config()
351-
except Exception as e:
352-
assert isinstance(e, ValueError)
353-
assert str(e) == ("To load a model from S3, "
354-
"'load_format' must be 'runai_streamer', "
355-
"but got 'gguf'. Model: s3://model/Qwen/Qwen3-0.6B")

vllm/config/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3025,6 +3025,16 @@ def try_verify_and_update_config(self):
30253025
SequenceClassificationConfig)
30263026
SequenceClassificationConfig.verify_and_update_config(self)
30273027

3028+
if hasattr(self.model_config, "model_weights") and is_runai_obj_uri(
3029+
self.model_config.model_weights):
3030+
if self.load_config.load_format == "auto":
3031+
self.load_config.load_format = "runai_streamer"
3032+
elif self.load_config.load_format != "runai_streamer":
3033+
raise ValueError(f"To load a model from S3, 'load_format' "
3034+
f"must be 'runai_streamer', "
3035+
f"but got '{self.load_config.load_format}'. "
3036+
f"Model: {self.model_config.model}")
3037+
30283038
def __str__(self):
30293039
return (
30303040
f"model={self.model_config.model!r}, "

vllm/engine/arg_utils.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
from vllm.reasoning import ReasoningParserManager
4444
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
4545
from vllm.transformers_utils.config import get_model_path, is_interleaved
46-
from vllm.transformers_utils.utils import check_gguf_file, is_s3
46+
from vllm.transformers_utils.utils import check_gguf_file
4747
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
4848
GiB_bytes, get_ip, is_in_ray_actor)
4949
from vllm.v1.sample.logits_processor import LogitsProcessor
@@ -491,6 +491,7 @@ def __post_init__(self):
491491
# Setup plugins
492492
from vllm.plugins import load_general_plugins
493493
load_general_plugins()
494+
# When HF Hub offline mode is enabled, replace the model id with the local model path
494495
if huggingface_hub.constants.HF_HUB_OFFLINE:
495496
model_id = self.model
496497
self.model = get_model_path(self.model, self.revision)
@@ -959,14 +960,6 @@ def create_model_config(self) -> ModelConfig:
959960
and self.model in MODELS_ON_S3 and self.load_format == "auto"):
960961
self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}"
961962

962-
if is_s3(self.model):
963-
if self.load_format == "auto":
964-
self.load_format = "runai_streamer"
965-
elif self.load_format != "runai_streamer":
966-
raise ValueError(
967-
f"To load a model from S3, 'load_format' "
968-
f"must be 'runai_streamer', "
969-
f"but got '{self.load_format}'. Model: {self.model}")
970963
if self.disable_mm_preprocessor_cache:
971964
logger.warning(
972965
"`--disable-mm-preprocessor-cache` is deprecated "

0 commit comments

Comments
 (0)