
Commit 712e99e

Update vLLM to use latest version of Run:AI Model Streamer
Signed-off-by: Peter Schuurman <psch@google.com>
1 parent baa3e38 commit 712e99e

6 files changed (+174, −110 lines)

setup.py

Lines changed: 4 additions & 2 deletions
@@ -688,8 +688,10 @@ def _read_requirements(filename: str) -> list[str]:
         "bench": ["pandas", "datasets"],
         "tensorizer": ["tensorizer==2.10.1"],
         "fastsafetensors": ["fastsafetensors >= 0.1.10"],
-        "runai":
-        ["runai-model-streamer >= 0.13.3", "runai-model-streamer-s3", "boto3"],
+        "runai": [
+            "runai-model-streamer >= 0.14.0", "runai-model-streamer-gcs",
+            "google-cloud-storage", "runai-model-streamer-s3", "boto3"
+        ],
         "audio": ["librosa", "soundfile",
                   "mistral_common[audio]"],  # Required for audio processing
         "video": [],  # Kept for backwards compatibility
New test file

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import glob
+import os
+import tempfile
+
+import huggingface_hub.constants
+
+from vllm.model_executor.model_loader.weight_utils import (
+    download_weights_from_hf)
+from vllm.transformers_utils.runai_utils import (is_runai_obj_uri,
+                                                 list_safetensors)
+
+
+def test_is_runai_obj_uri():
+    assert is_runai_obj_uri("gs://some-gcs-bucket/path")
+    assert is_runai_obj_uri("s3://some-s3-bucket/path")
+    assert not is_runai_obj_uri("nfs://some-nfs-path")
+
+
+def test_runai_list_safetensors_local():
+    with tempfile.TemporaryDirectory() as tmpdir:
+        huggingface_hub.constants.HF_HUB_OFFLINE = False
+        download_weights_from_hf("openai-community/gpt2",
+                                 allow_patterns=["*.safetensors", "*.json"],
+                                 cache_dir=tmpdir)
+        safetensors = glob.glob(f"{tmpdir}/**/*.safetensors", recursive=True)
+        assert len(safetensors) > 0
+        parentdir = [
+            os.path.dirname(safetensor) for safetensor in safetensors
+        ][0]
+        files = list_safetensors(parentdir)
+        assert len(safetensors) == len(files)
+
+
+if __name__ == "__main__":
+    test_is_runai_obj_uri()
+    test_runai_list_safetensors_local()

vllm/config/__init__.py

Lines changed: 24 additions & 22 deletions
@@ -47,8 +47,9 @@
     is_interleaved, maybe_override_with_speculators_target_model,
     try_get_generation_config, try_get_safetensors_metadata,
     try_get_tokenizer_config, uses_mrope)
-from vllm.transformers_utils.s3_utils import S3Model
-from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
+from vllm.transformers_utils.runai_utils import (ObjectStorageModel,
+                                                 is_runai_obj_uri)
+from vllm.transformers_utils.utils import maybe_model_redirect
 from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS,
                         STR_DUAL_CHUNK_FLASH_ATTN_VAL, LayerBlockType,
                         LazyLoader, common_broadcastable_dtype, random_uuid)
@@ -598,7 +599,7 @@ def __post_init__(self) -> None:
                 f"'Please instead use `--hf-overrides '{hf_overrides_str}'`")
             warnings.warn(DeprecationWarning(msg), stacklevel=2)
 
-        self.maybe_pull_model_tokenizer_for_s3(self.model, self.tokenizer)
+        self.maybe_pull_model_tokenizer_for_runai(self.model, self.tokenizer)
 
         if self.runner != "draft":
             # If we're not running the draft model, check for speculators config
@@ -840,41 +841,42 @@ def architecture(self) -> str:
         """The architecture vllm actually used."""
         return self._architecture
 
-    def maybe_pull_model_tokenizer_for_s3(self, model: str,
-                                          tokenizer: str) -> None:
-        """Pull model/tokenizer from S3 to temporary directory when needed.
+    def maybe_pull_model_tokenizer_for_runai(self, model: str,
+                                             tokenizer: str) -> None:
+        """Pull model/tokenizer from Object Storage to temporary
+        directory when needed.
 
         Args:
             model: Model name or path
             tokenizer: Tokenizer name or path
         """
-        if not (is_s3(model) or is_s3(tokenizer)):
+        if not (is_runai_obj_uri(model) or is_runai_obj_uri(tokenizer)):
             return
 
-        if is_s3(model):
-            s3_model = S3Model()
-            s3_model.pull_files(model,
-                                allow_pattern=["*.model", "*.py", "*.json"])
+        if is_runai_obj_uri(model):
+            object_storage_model = ObjectStorageModel()
+            object_storage_model.pull_files(
+                model, allow_pattern=["*.model", "*.py", "*.json"])
             self.model_weights = model
-            self.model = s3_model.dir
+            self.model = object_storage_model.dir
 
             # If tokenizer is same as model, download to same directory
             if model == tokenizer:
-                s3_model.pull_files(model,
-                                    ignore_pattern=[
-                                        "*.pt", "*.safetensors", "*.bin",
-                                        "*.tensors"
-                                    ])
-                self.tokenizer = s3_model.dir
+                object_storage_model.pull_files(model,
+                                                ignore_pattern=[
+                                                    "*.pt", "*.safetensors",
+                                                    "*.bin", "*.tensors"
+                                                ])
+                self.tokenizer = object_storage_model.dir
                 return
 
         # Only download tokenizer if needed and not already handled
-        if is_s3(tokenizer):
-            s3_tokenizer = S3Model()
-            s3_tokenizer.pull_files(
+        if is_runai_obj_uri(tokenizer):
+            object_storage_tokenizer = ObjectStorageModel()
+            object_storage_tokenizer.pull_files(
                 model,
                 ignore_pattern=["*.pt", "*.safetensors", "*.bin", "*.tensors"])
-            self.tokenizer = s3_tokenizer.dir
+            self.tokenizer = object_storage_tokenizer.dir
 
     def _init_multimodal_config(self) -> Optional["MultiModalConfig"]:
         if self._model_info.supports_multimodal:
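For reference, a minimal sketch of what the renamed helper now does for any supported object-storage URI (the bucket path below is hypothetical): only small config/tokenizer-style files are mirrored into a temporary directory, while the weights stay remote for the streamer to read, which is why model_weights keeps the original URI.

# Sketch mirroring maybe_pull_model_tokenizer_for_runai for a hypothetical URI.
from vllm.transformers_utils.runai_utils import (ObjectStorageModel,
                                                 is_runai_obj_uri)

model_uri = "gs://my-bucket/llama-3.1-8b"  # hypothetical; s3:// works the same way

if is_runai_obj_uri(model_uri):
    staging = ObjectStorageModel()
    # Pull only metadata files; *.safetensors weights are left in the bucket.
    staging.pull_files(model_uri, allow_pattern=["*.model", "*.py", "*.json"])
    local_model_dir = staging.dir  # what ModelConfig.model is pointed at
    remote_weights = model_uri     # what ModelConfig.model_weights keeps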

vllm/model_executor/model_loader/runai_streamer_loader.py

Lines changed: 8 additions & 14 deletions
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 # ruff: noqa: SIM117
-import glob
 import os
 from collections.abc import Generator
 from typing import Optional
@@ -15,8 +14,8 @@
 from vllm.model_executor.model_loader.weight_utils import (
     download_safetensors_index_file_from_hf, download_weights_from_hf,
     runai_safetensors_weights_iterator)
-from vllm.transformers_utils.s3_utils import glob as s3_glob
-from vllm.transformers_utils.utils import is_s3
+from vllm.transformers_utils.runai_utils import (is_runai_obj_uri,
+                                                 list_safetensors)
 
 
 class RunaiModelStreamerLoader(BaseModelLoader):
@@ -53,27 +52,22 @@ def _prepare_weights(self, model_name_or_path: str,
 
         If the model is not local, it will be downloaded."""
 
-        is_s3_path = is_s3(model_name_or_path)
+        is_object_storage_path = is_runai_obj_uri(model_name_or_path)
         is_local = os.path.isdir(model_name_or_path)
         safetensors_pattern = "*.safetensors"
         index_file = SAFE_WEIGHTS_INDEX_NAME
 
-        hf_folder = (model_name_or_path if
-                     (is_local or is_s3_path) else download_weights_from_hf(
+        hf_folder = (model_name_or_path if (is_local or is_object_storage_path)
+                     else download_weights_from_hf(
                          model_name_or_path,
                          self.load_config.download_dir,
                          [safetensors_pattern],
                          revision,
                          ignore_patterns=self.load_config.ignore_patterns,
                      ))
-        if is_s3_path:
-            hf_weights_files = s3_glob(path=hf_folder,
-                                       allow_pattern=[safetensors_pattern])
-        else:
-            hf_weights_files = glob.glob(
-                os.path.join(hf_folder, safetensors_pattern))
-
-        if not is_local and not is_s3_path:
+        hf_weights_files = list_safetensors(path=hf_folder)
+
+        if not is_local and not is_object_storage_path:
             download_safetensors_index_file_from_hf(
                 model_name_or_path, index_file, self.load_config.download_dir,
                 revision)
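The old glob/s3_glob branch collapses into a single call because the new list_safetensors helper handles local directories as well as s3:// and gs:// prefixes (the test added in this commit exercises the local-directory case). A small sketch with hypothetical paths:

# Sketch: one listing call for local and object-storage checkpoints alike.
# Both paths below are hypothetical.
from vllm.transformers_utils.runai_utils import list_safetensors

local_files = list_safetensors("/models/llama-3.1-8b")         # local directory
remote_files = list_safetensors("gs://my-bucket/llama-3.1-8b")  # object storage
print(len(local_files), len(remote_files))  # full paths to *.safetensors files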
vllm/transformers_utils/runai_utils.py

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import os
+import shutil
+import signal
+import tempfile
+from typing import Optional
+
+from vllm.logger import init_logger
+from vllm.utils import PlaceholderModule
+
+logger = init_logger(__name__)
+
+SUPPORTED_SCHEMES = ['s3://', 'gs://']
+
+try:
+    from runai_model_streamer import list_safetensors as runai_list_safetensors
+    from runai_model_streamer import pull_files as runai_pull_files
+except (ImportError, OSError):
+    # see https://github.com/run-ai/runai-model-streamer/issues/26
+    # OSError will be raised on arm64 platform
+    runai_model_streamer = PlaceholderModule(
+        "runai_model_streamer")  # type: ignore[assignment]
+    runai_pull_files = runai_model_streamer.placeholder_attr("pull_files")
+    runai_list_safetensors = runai_model_streamer.placeholder_attr(
+        "list_safetensors")
+
+
+def list_safetensors(path: str = "") -> list[str]:
+    """
+    List full file names from object path and filter by allow pattern.
+
+    Args:
+        path: The object storage path to list from.
+        allow_pattern: A list of patterns of which files to pull.
+
+    Returns:
+        list[str]: List of full object storage paths allowed by the pattern
+    """
+    return runai_list_safetensors(path)
+
+
+def is_runai_obj_uri(model_or_path: str) -> bool:
+    return model_or_path.lower().startswith(tuple(SUPPORTED_SCHEMES))
+
+
+class ObjectStorageModel:
+    """
+    A class representing an ObjectStorage model mirrored into a
+    temporary directory.
+
+    Attributes:
+        dir: The temporary created directory.
+
+    Methods:
+        pull_files(): Pull model from object storage to the temporary
+                      directory.
+    """
+
+    def __init__(self) -> None:
+        for sig in (signal.SIGINT, signal.SIGTERM):
+            existing_handler = signal.getsignal(sig)
+            signal.signal(sig, self._close_by_signal(existing_handler))
+
+        self.dir = tempfile.mkdtemp()
+
+    def __del__(self):
+        self._close()
+
+    def _close(self) -> None:
+        if os.path.exists(self.dir):
+            shutil.rmtree(self.dir)
+
+    def _close_by_signal(self, existing_handler=None):
+
+        def new_handler(signum, frame):
+            self._close()
+            if existing_handler:
+                existing_handler(signum, frame)
+
+        return new_handler
+
+    def pull_files(self,
+                   model_path: str = "",
+                   allow_pattern: Optional[list[str]] = None,
+                   ignore_pattern: Optional[list[str]] = None) -> None:
+        """
+        Pull files from object storage into the temporary directory.
+
+        Args:
+            model_path: The object storage path of the model.
+            allow_pattern: A list of patterns of which files to pull.
+            ignore_pattern: A list of patterns of which files not to pull.
+
+        """
+        if not model_path.endswith("/"):
+            model_path = model_path + "/"
+        runai_pull_files(model_path, self.dir, allow_pattern, ignore_pattern)
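A short usage sketch for ObjectStorageModel, here pulling everything except weight files, as the config code above does for tokenizers; the bucket path is hypothetical. The temporary directory is removed when the object is garbage-collected or when SIGINT/SIGTERM is received.

# Sketch: mirror tokenizer/config files from a hypothetical bucket into a temp dir.
from vllm.transformers_utils.runai_utils import ObjectStorageModel

tok = ObjectStorageModel()
tok.pull_files("s3://my-bucket/llama-3.1-8b",
               ignore_pattern=["*.pt", "*.safetensors", "*.bin", "*.tensors"])
print(tok.dir)  # temporary directory now holding the non-weight files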

vllm/transformers_utils/s3_utils.py

Lines changed: 0 additions & 72 deletions
@@ -2,11 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import fnmatch
-import os
-import shutil
-import signal
-import tempfile
-from pathlib import Path
 from typing import Optional
 
 from vllm.utils import PlaceholderModule
@@ -93,70 +88,3 @@ def list_files(
     paths = _filter_ignore(paths, ignore_pattern)
 
     return bucket_name, prefix, paths
-
-
-class S3Model:
-    """
-    A class representing a S3 model mirrored into a temporary directory.
-
-    Attributes:
-        s3: S3 client.
-        dir: The temporary created directory.
-
-    Methods:
-        pull_files(): Pull model from S3 to the temporary directory.
-    """
-
-    def __init__(self) -> None:
-        self.s3 = boto3.client('s3')
-        for sig in (signal.SIGINT, signal.SIGTERM):
-            existing_handler = signal.getsignal(sig)
-            signal.signal(sig, self._close_by_signal(existing_handler))
-
-        self.dir = tempfile.mkdtemp()
-
-    def __del__(self):
-        self._close()
-
-    def _close(self) -> None:
-        if os.path.exists(self.dir):
-            shutil.rmtree(self.dir)
-
-    def _close_by_signal(self, existing_handler=None):
-
-        def new_handler(signum, frame):
-            self._close()
-            if existing_handler:
-                existing_handler(signum, frame)
-
-        return new_handler
-
-    def pull_files(self,
-                   s3_model_path: str = "",
-                   allow_pattern: Optional[list[str]] = None,
-                   ignore_pattern: Optional[list[str]] = None) -> None:
-        """
-        Pull files from S3 storage into the temporary directory.
-
-        Args:
-            s3_model_path: The S3 path of the model.
-            allow_pattern: A list of patterns of which files to pull.
-            ignore_pattern: A list of patterns of which files not to pull.
-
-        """
-        if not s3_model_path.endswith("/"):
-            s3_model_path = s3_model_path + "/"
-
-        bucket_name, base_dir, files = list_files(self.s3, s3_model_path,
-                                                  allow_pattern,
-                                                  ignore_pattern)
-        if len(files) == 0:
-            return
-
-        for file in files:
-            destination_file = os.path.join(
-                self.dir,
-                file.removeprefix(base_dir).lstrip("/"))
-            local_dir = Path(destination_file).parent
-            os.makedirs(local_dir, exist_ok=True)
-            self.s3.download_file(bucket_name, file, destination_file)
