Skip to content

Commit 5823145

Browse files
pwschuurman, omer-dayan, DarkLight1337
authored and committed
[Bugfix] Update Run:AI Model Streamer Loading Integration (vllm-project#23845)
Signed-off-by: Omer Dayan (SW-GPU) <omer@run.ai> Signed-off-by: Peter Schuurman <psch@google.com> Co-authored-by: Omer Dayan (SW-GPU) <omer@run.ai> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk> Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent 67de548 commit 5823145

File tree

7 files changed

+188
-123
lines changed

7 files changed

+188
-123
lines changed

setup.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -656,8 +656,10 @@ def _read_requirements(filename: str) -> list[str]:
656656
"bench": ["pandas", "datasets"],
657657
"tensorizer": ["tensorizer==2.10.1"],
658658
"fastsafetensors": ["fastsafetensors >= 0.1.10"],
659-
"runai":
660-
["runai-model-streamer >= 0.13.3", "runai-model-streamer-s3", "boto3"],
659+
"runai": [
660+
"runai-model-streamer >= 0.14.0", "runai-model-streamer-gcs",
661+
"google-cloud-storage", "runai-model-streamer-s3", "boto3"
662+
],
661663
"audio": ["librosa", "soundfile",
662664
"mistral_common[audio]"], # Required for audio processing
663665
"video": [], # Kept for backwards compatibility
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import glob
5+
import os
6+
import tempfile
7+
8+
import huggingface_hub.constants
9+
10+
from vllm.model_executor.model_loader.weight_utils import (
11+
download_weights_from_hf)
12+
from vllm.transformers_utils.runai_utils import (is_runai_obj_uri,
13+
list_safetensors)
14+
15+
16+
def test_is_runai_obj_uri():
    """Check that gs:// and s3:// URIs are accepted and nfs:// is not."""
    accepted_uris = (
        "gs://some-gcs-bucket/path",
        "s3://some-s3-bucket/path",
    )
    for uri in accepted_uris:
        assert is_runai_obj_uri(uri)

    assert not is_runai_obj_uri("nfs://some-nfs-path")
20+
21+
22+
def test_runai_list_safetensors_local():
    """Check that list_safetensors finds the downloaded safetensors files.

    Downloads the ``openai-community/gpt2`` safetensors/json files into a
    temporary cache directory, then compares the number of ``*.safetensors``
    files found by a recursive glob with the number of entries returned by
    ``list_safetensors`` for the directory holding the first match.
    """
    # HF_HUB_OFFLINE is process-wide state: remember it and restore it
    # afterwards so forcing online mode here does not leak into other tests.
    previous_offline = huggingface_hub.constants.HF_HUB_OFFLINE
    huggingface_hub.constants.HF_HUB_OFFLINE = False
    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            download_weights_from_hf("openai-community/gpt2",
                                     allow_patterns=["*.safetensors", "*.json"],
                                     cache_dir=tmpdir)
            safetensors = glob.glob(f"{tmpdir}/**/*.safetensors",
                                    recursive=True)
            assert len(safetensors) > 0
            parentdir = os.path.dirname(safetensors[0])
            files = list_safetensors(parentdir)
            assert len(safetensors) == len(files)
    finally:
        huggingface_hub.constants.HF_HUB_OFFLINE = previous_offline
35+
36+
37+
if __name__ == "__main__":
38+
test_is_runai_obj_uri()
39+
test_runai_list_safetensors_local()

vllm/config/__init__.py

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,9 @@
4848
is_interleaved, maybe_override_with_speculators_target_model,
4949
try_get_generation_config, try_get_safetensors_metadata,
5050
try_get_tokenizer_config, uses_mrope)
51-
from vllm.transformers_utils.s3_utils import S3Model
52-
from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
51+
from vllm.transformers_utils.runai_utils import (ObjectStorageModel,
52+
is_runai_obj_uri)
53+
from vllm.transformers_utils.utils import maybe_model_redirect
5354
from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS,
5455
STR_DUAL_CHUNK_FLASH_ATTN_VAL, LayerBlockType,
5556
LazyLoader, common_broadcastable_dtype, random_uuid)
@@ -556,15 +557,6 @@ def __post_init__(self) -> None:
556557
"affect the random state of the Python process that "
557558
"launched vLLM.", self.seed)
558559

559-
if self.runner != "draft":
560-
# If we're not running the draft model, check for speculators config
561-
# If speculators config, set model / tokenizer to be target model
562-
self.model, self.tokenizer = maybe_override_with_speculators_target_model( # noqa: E501
563-
model=self.model,
564-
tokenizer=self.tokenizer,
565-
revision=self.revision,
566-
trust_remote_code=self.trust_remote_code)
567-
568560
# Keep set served_model_name before maybe_model_redirect(self.model)
569561
self.served_model_name = get_served_model_name(self.model,
570562
self.served_model_name)
@@ -603,7 +595,16 @@ def __post_init__(self) -> None:
603595
f"'Please instead use `--hf-overrides '{hf_overrides_str}'`")
604596
warnings.warn(DeprecationWarning(msg), stacklevel=2)
605597

606-
self.maybe_pull_model_tokenizer_for_s3(self.model, self.tokenizer)
598+
self.maybe_pull_model_tokenizer_for_runai(self.model, self.tokenizer)
599+
600+
if self.runner != "draft":
601+
# If we're not running the draft model, check for speculators config
602+
# If speculators config, set model / tokenizer to be target model
603+
self.model, self.tokenizer = maybe_override_with_speculators_target_model( # noqa: E501
604+
model=self.model,
605+
tokenizer=self.tokenizer,
606+
revision=self.revision,
607+
trust_remote_code=self.trust_remote_code)
607608

608609
if (backend := envs.VLLM_ATTENTION_BACKEND
609610
) and backend == "FLASHINFER" and find_spec("flashinfer") is None:
@@ -832,41 +833,42 @@ def architecture(self) -> str:
832833
"""The architecture vllm actually used."""
833834
return self._architecture
834835

835-
def maybe_pull_model_tokenizer_for_runai(self, model: str,
                                         tokenizer: str) -> None:
    """Pull model/tokenizer from Object Storage to temporary
    directory when needed.

    When ``model`` is an object storage URI, its ``*.model``/``*.py``/
    ``*.json`` files are pulled into a temporary directory, ``self.model``
    is pointed at that directory and the original URI is kept in
    ``self.model_weights``. When ``tokenizer`` is an object storage URI,
    its non-weight files are pulled as well and ``self.tokenizer`` is
    pointed at the local copy.

    Args:
        model: Model name or path
        tokenizer: Tokenizer name or path
    """
    if not (is_runai_obj_uri(model) or is_runai_obj_uri(tokenizer)):
        return

    if is_runai_obj_uri(model):
        object_storage_model = ObjectStorageModel()
        object_storage_model.pull_files(
            model, allow_pattern=["*.model", "*.py", "*.json"])
        self.model_weights = model
        self.model = object_storage_model.dir

        # If tokenizer is same as model, download to same directory
        if model == tokenizer:
            object_storage_model.pull_files(model,
                                            ignore_pattern=[
                                                "*.pt", "*.safetensors",
                                                "*.bin", "*.tensors"
                                            ])
            self.tokenizer = object_storage_model.dir
            return

    # Only download tokenizer if needed and not already handled
    if is_runai_obj_uri(tokenizer):
        object_storage_tokenizer = ObjectStorageModel()
        # Pull from the tokenizer location, not the model location: this
        # branch is reached only when the two differ (or when the model is
        # not an object storage URI at all).
        object_storage_tokenizer.pull_files(
            tokenizer,
            ignore_pattern=["*.pt", "*.safetensors", "*.bin", "*.tensors"])
        self.tokenizer = object_storage_tokenizer.dir
870872

871873
def _init_multimodal_config(self) -> Optional["MultiModalConfig"]:
872874
if self._model_info.supports_multimodal:

vllm/engine/arg_utils.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1053,9 +1053,10 @@ def create_speculative_config(
10531053
SpeculatorsConfig)
10541054

10551055
if self.speculative_config is None:
1056-
hf_config = get_config(self.hf_config_path or self.model,
1057-
self.trust_remote_code, self.revision,
1058-
self.code_revision, self.config_format)
1056+
hf_config = get_config(
1057+
self.hf_config_path or target_model_config.model,
1058+
self.trust_remote_code, self.revision, self.code_revision,
1059+
self.config_format)
10591060

10601061
# if loading a SpeculatorsConfig, load the speculative_config
10611062
# details from the config directly
@@ -1065,7 +1066,7 @@ def create_speculative_config(
10651066
self.speculative_config = {}
10661067
self.speculative_config[
10671068
"num_speculative_tokens"] = hf_config.num_lookahead_tokens
1068-
self.speculative_config["model"] = self.model
1069+
self.speculative_config["model"] = target_model_config.model
10691070
self.speculative_config["method"] = hf_config.method
10701071
else:
10711072
return None

vllm/model_executor/model_loader/runai_streamer_loader.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
# ruff: noqa: SIM117
4-
import glob
54
import os
65
from collections.abc import Generator
76
from typing import Optional
@@ -15,8 +14,8 @@
1514
from vllm.model_executor.model_loader.weight_utils import (
1615
download_safetensors_index_file_from_hf, download_weights_from_hf,
1716
runai_safetensors_weights_iterator)
18-
from vllm.transformers_utils.s3_utils import glob as s3_glob
19-
from vllm.transformers_utils.utils import is_s3
17+
from vllm.transformers_utils.runai_utils import (is_runai_obj_uri,
18+
list_safetensors)
2019

2120

2221
class RunaiModelStreamerLoader(BaseModelLoader):
@@ -53,27 +52,22 @@ def _prepare_weights(self, model_name_or_path: str,
5352
5453
If the model is not local, it will be downloaded."""
5554

56-
is_s3_path = is_s3(model_name_or_path)
55+
is_object_storage_path = is_runai_obj_uri(model_name_or_path)
5756
is_local = os.path.isdir(model_name_or_path)
5857
safetensors_pattern = "*.safetensors"
5958
index_file = SAFE_WEIGHTS_INDEX_NAME
6059

61-
hf_folder = (model_name_or_path if
62-
(is_local or is_s3_path) else download_weights_from_hf(
60+
hf_folder = (model_name_or_path if (is_local or is_object_storage_path)
61+
else download_weights_from_hf(
6362
model_name_or_path,
6463
self.load_config.download_dir,
6564
[safetensors_pattern],
6665
revision,
6766
ignore_patterns=self.load_config.ignore_patterns,
6867
))
69-
if is_s3_path:
70-
hf_weights_files = s3_glob(path=hf_folder,
71-
allow_pattern=[safetensors_pattern])
72-
else:
73-
hf_weights_files = glob.glob(
74-
os.path.join(hf_folder, safetensors_pattern))
75-
76-
if not is_local and not is_s3_path:
68+
hf_weights_files = list_safetensors(path=hf_folder)
69+
70+
if not is_local and not is_object_storage_path:
7771
download_safetensors_index_file_from_hf(
7872
model_name_or_path, index_file, self.load_config.download_dir,
7973
revision)
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import os
5+
import shutil
6+
import signal
7+
import tempfile
8+
from typing import Optional
9+
10+
from vllm.logger import init_logger
11+
from vllm.utils import PlaceholderModule
12+
13+
logger = init_logger(__name__)
14+
15+
SUPPORTED_SCHEMES = ['s3://', 'gs://']
16+
17+
try:
18+
from runai_model_streamer import list_safetensors as runai_list_safetensors
19+
from runai_model_streamer import pull_files as runai_pull_files
20+
except (ImportError, OSError):
21+
# see https://github.com/run-ai/runai-model-streamer/issues/26
22+
# OSError will be raised on arm64 platform
23+
runai_model_streamer = PlaceholderModule(
24+
"runai_model_streamer") # type: ignore[assignment]
25+
runai_pull_files = runai_model_streamer.placeholder_attr("pull_files")
26+
runai_list_safetensors = runai_model_streamer.placeholder_attr(
27+
"list_safetensors")
28+
29+
30+
def list_safetensors(path: str = "") -> list[str]:
    """
    List the safetensors files found under a path.

    Thin wrapper that forwards ``path`` unchanged to
    ``runai_model_streamer.list_safetensors`` (or to its placeholder when
    the package could not be imported). No pattern filtering is done here.

    Args:
        path: The location to list from.

    Returns:
        list[str]: The value returned by the streamer; presumably the full
            paths of the safetensors files (confirm against the
            runai-model-streamer documentation).
    """
    return runai_list_safetensors(path)
42+
43+
44+
def is_runai_obj_uri(model_or_path: str) -> bool:
    """Return True if the value starts with one of SUPPORTED_SCHEMES.

    The value is lower-cased before the prefix comparison.
    """
    lowered = model_or_path.lower()
    return any(lowered.startswith(scheme) for scheme in SUPPORTED_SCHEMES)
46+
47+
48+
class ObjectStorageModel:
    """
    A class representing an ObjectStorage model mirrored into a
    temporary directory.

    The directory is removed when the instance is garbage collected or
    when SIGINT/SIGTERM is delivered to the process.

    Attributes:
        dir: The temporary created directory.

    Methods:
        pull_files(): Pull model from object storage to the temporary
            directory.
    """

    def __init__(self) -> None:
        # Create the directory first so that _close() (reachable from
        # __del__ and from the signal handlers) always finds `dir`, even if
        # installing the handlers below raises.
        self.dir = tempfile.mkdtemp()

        # Chain a cleanup handler in front of whatever was installed before.
        # The handler closes over `self`, so the instance stays alive for as
        # long as the handler remains installed.
        for sig in (signal.SIGINT, signal.SIGTERM):
            existing_handler = signal.getsignal(sig)
            signal.signal(sig, self._close_by_signal(existing_handler))

    def __del__(self):
        self._close()

    def _close(self) -> None:
        """Remove the temporary directory if it still exists."""
        # `dir` is missing when the instance was never fully initialised.
        directory = getattr(self, "dir", None)
        if directory and os.path.exists(directory):
            shutil.rmtree(directory)

    def _close_by_signal(self, existing_handler=None):
        """Build a signal handler that cleans up, then defers to the
        previous disposition.

        Args:
            existing_handler: The value previously returned by
                ``signal.getsignal``: a callable, ``signal.SIG_DFL``,
                ``signal.SIG_IGN`` or ``None``.

        Returns:
            A ``(signum, frame)`` handler suitable for ``signal.signal``.
        """

        def new_handler(signum, frame):
            self._close()
            if callable(existing_handler):
                existing_handler(signum, frame)
            elif existing_handler == signal.SIG_DFL:
                # The default action cannot be called directly: restore it
                # and re-deliver the signal so the process still reacts to
                # it (e.g. terminates on SIGTERM) after the cleanup.
                signal.signal(signum, signal.SIG_DFL)
                signal.raise_signal(signum)
            # SIG_IGN / None: the signal was ignored (or not handled from
            # Python) before, so there is nothing further to do.

        return new_handler

    def pull_files(self,
                   model_path: str = "",
                   allow_pattern: Optional[list[str]] = None,
                   ignore_pattern: Optional[list[str]] = None) -> None:
        """
        Pull files from object storage into the temporary directory.

        Args:
            model_path: The object storage path of the model.
            allow_pattern: A list of patterns of which files to pull.
            ignore_pattern: A list of patterns of which files not to pull.

        """
        # The path is always handed to the streamer with a trailing slash.
        if not model_path.endswith("/"):
            model_path = model_path + "/"
        runai_pull_files(model_path, self.dir, allow_pattern, ignore_pattern)

0 commit comments

Comments
 (0)