
Commit 43d9ad0

BraveY and simon-mo authored
[Model loader]: support multi-thread model weight loading (#23928)
Signed-off-by: Yang Kaiyong <yangkaiyong.yky@antgroup.com>
Signed-off-by: Simon Mo <simon.mo@hey.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
1 parent 7be141b commit 43d9ad0
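
A hedged usage sketch (not part of the commit itself): the two new options are read from `model_loader_extra_config`, so with a standard vLLM install the feature could plausibly be enabled from the offline API as below. The model name is only a placeholder.

```python
from vllm import LLM

# Hypothetical example: enable multi-threaded weight loading with 8 threads.
# The keys match the allowed_keys set introduced in DefaultModelLoader below.
llm = LLM(
    model="facebook/opt-125m",  # placeholder model
    model_loader_extra_config={
        "enable_multithread_load": True,
        "num_threads": 8,
    },
)
```

When `enable_multithread_load` is absent or falsy, the loader keeps using the existing single-threaded iterators, so default behavior is unchanged.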

2 files changed: 105 additions, 12 deletions

vllm/model_executor/model_loader/default_loader.py

Lines changed: 41 additions & 12 deletions
```diff
@@ -18,8 +18,9 @@
     download_safetensors_index_file_from_hf, download_weights_from_hf,
     fastsafetensors_weights_iterator, filter_duplicate_safetensors_files,
     filter_files_not_needed_for_inference, maybe_download_from_modelscope,
-    np_cache_weights_iterator, pt_weights_iterator,
-    safetensors_weights_iterator)
+    multi_thread_pt_weights_iterator,
+    multi_thread_safetensors_weights_iterator, np_cache_weights_iterator,
+    pt_weights_iterator, safetensors_weights_iterator)
 from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
```
```diff
@@ -28,6 +29,9 @@
 class DefaultModelLoader(BaseModelLoader):
     """Model loader that can load different file types from disk."""
 
+    # default number of thread when enable multithread weight loading
+    DEFAULT_NUM_THREADS = 8
+
     @dataclasses.dataclass
     class Source:
         """A source for weights."""
```
```diff
@@ -52,9 +56,15 @@ class Source:
 
     def __init__(self, load_config: LoadConfig):
         super().__init__(load_config)
-        if load_config.model_loader_extra_config:
-            raise ValueError(f"Model loader extra config is not supported for "
-                             f"load format {load_config.load_format}")
+
+        extra_config = load_config.model_loader_extra_config
+        allowed_keys = {"enable_multithread_load", "num_threads"}
+        unexpected_keys = set(extra_config.keys()) - allowed_keys
+
+        if unexpected_keys:
+            raise ValueError(f"Unexpected extra config keys for load format "
+                             f"{load_config.load_format}: "
+                             f"{unexpected_keys}")
 
     def _prepare_weights(
         self,
```
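
The constructor change above replaces the old blanket rejection of `model_loader_extra_config` with a whitelist of the two new keys. A standalone sketch of the same set-difference check, for illustration only (the helper name here is made up, not vLLM code):

```python
ALLOWED_KEYS = {"enable_multithread_load", "num_threads"}

def validate_extra_config(extra_config: dict, load_format: str = "auto") -> None:
    # Any key outside the whitelist is reported, mirroring the diff above.
    unexpected_keys = set(extra_config) - ALLOWED_KEYS
    if unexpected_keys:
        raise ValueError(f"Unexpected extra config keys for load format "
                         f"{load_format}: {unexpected_keys}")

validate_extra_config({"enable_multithread_load": True, "num_threads": 16})  # ok
try:
    validate_extra_config({"num_thread": 16})  # typo'd key
except ValueError as e:
    print(e)  # Unexpected extra config keys for load format auto: {'num_thread'}
```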
```diff
@@ -145,6 +155,7 @@ def _get_weights_iterator(
         self, source: "Source"
     ) -> Generator[tuple[str, torch.Tensor], None, None]:
         """Get an iterator for the model weights based on the load format."""
+        extra_config = self.load_config.model_loader_extra_config
         hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
             source.model_or_path, source.revision, source.fall_back_to_pt,
             source.allow_patterns_overrides)
```
```diff
@@ -165,16 +176,34 @@ def _get_weights_iterator(
                     self.load_config.use_tqdm_on_load,
                 )
             else:
-                weights_iterator = safetensors_weights_iterator(
+                if extra_config.get("enable_multithread_load"):
+                    weights_iterator = (
+                        multi_thread_safetensors_weights_iterator(
+                            hf_weights_files,
+                            self.load_config.use_tqdm_on_load,
+                            max_workers=extra_config.get(
+                                "num_threads", self.DEFAULT_NUM_THREADS),
+                        ))
+                else:
+                    weights_iterator = safetensors_weights_iterator(
+                        hf_weights_files,
+                        self.load_config.use_tqdm_on_load,
+                    )
+        else:
+            if extra_config.get("enable_multithread_load"):
+                weights_iterator = multi_thread_pt_weights_iterator(
                     hf_weights_files,
                     self.load_config.use_tqdm_on_load,
+                    self.load_config.pt_load_map_location,
+                    max_workers=extra_config.get("num_threads",
+                                                 self.DEFAULT_NUM_THREADS),
+                )
+            else:
+                weights_iterator = pt_weights_iterator(
+                    hf_weights_files,
+                    self.load_config.use_tqdm_on_load,
+                    self.load_config.pt_load_map_location,
                 )
-        else:
-            weights_iterator = pt_weights_iterator(
-                hf_weights_files,
-                self.load_config.use_tqdm_on_load,
-                self.load_config.pt_load_map_location,
-            )
 
         if current_platform.is_tpu():
             from vllm.platforms.tpu import USE_TPU_COMMONS
```
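
Stripped of the loader plumbing, the selection logic introduced above reduces to the nested conditions below. This is a simplified restatement for readability (it omits the npcache and fastsafetensors branches), not the exact vLLM source:

```python
def pick_iterator_name(use_safetensors: bool, extra_config: dict) -> str:
    """Return which weights iterator the default loader would pick."""
    multithread = bool(extra_config.get("enable_multithread_load"))
    if use_safetensors:
        return ("multi_thread_safetensors_weights_iterator" if multithread
                else "safetensors_weights_iterator")
    return ("multi_thread_pt_weights_iterator" if multithread
            else "pt_weights_iterator")

assert pick_iterator_name(True, {}) == "safetensors_weights_iterator"
assert pick_iterator_name(False, {"enable_multithread_load": True}) == (
    "multi_thread_pt_weights_iterator")
```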

vllm/model_executor/model_loader/weight_utils.py

Lines changed: 64 additions & 0 deletions
```diff
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Utilities for downloading and initializing model weights."""
+import concurrent.futures
 import fnmatch
 import glob
 import hashlib
```
```diff
@@ -531,6 +532,36 @@ def safetensors_weights_iterator(
         yield name, param
 
 
+def multi_thread_safetensors_weights_iterator(
+    hf_weights_files: list[str],
+    use_tqdm_on_load: bool,
+    max_workers: int = 4,
+) -> Generator[tuple[str, torch.Tensor], None, None]:
+    """Multi-Thread iterate over the weights in the model safetensor files."""
+
+    def _load_file(st_file: str):
+        result = load_file(st_file, device="cpu")
+        return result
+
+    with concurrent.futures.ThreadPoolExecutor(
+            max_workers=max_workers) as executor:
+        futures = [
+            executor.submit(_load_file, st_file)
+            for st_file in hf_weights_files
+        ]
+        futures_iter = tqdm(
+            concurrent.futures.as_completed(futures),
+            total=len(hf_weights_files),
+            desc="Multi-thread loading shards",
+            disable=not enable_tqdm(use_tqdm_on_load),
+            bar_format=_BAR_FORMAT,
+        )
+
+        for future in futures_iter:
+            state_dict = future.result()
+            yield from state_dict.items()
+
+
 def runai_safetensors_weights_iterator(
     hf_weights_files: list[str],
     use_tqdm_on_load: bool,
```
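
The new iterator submits one `load_file` call per safetensors shard and yields tensors as soon as each whole file has been read, so tensors from different shards arrive in completion order rather than file order. A self-contained sketch of that pattern, with fake in-memory "shards" standing in for safetensors files:

```python
import concurrent.futures
import random
import time

def _load_shard(name: str) -> dict[str, float]:
    # Stand-in for load_file(st_file, device="cpu"); sleep simulates disk I/O.
    time.sleep(random.uniform(0.05, 0.2))
    return {f"{name}.weight": 0.0, f"{name}.bias": 1.0}

shards = ["shard-0", "shard-1", "shard-2"]
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
    futures = [executor.submit(_load_shard, s) for s in shards]
    for future in concurrent.futures.as_completed(futures):
        # Whole-shard state dicts complete in arbitrary order; within a
        # shard, key order is preserved, just like the iterator above.
        for name, param in future.result().items():
            print(name, param)
```

Note that each completed future holds a full shard's state dict until it is consumed, so multi-threaded loading trades higher peak host memory for throughput.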
```diff
@@ -611,6 +642,39 @@ def pt_weights_iterator(
         del state
 
 
+def multi_thread_pt_weights_iterator(
+    hf_weights_files: list[str],
+    use_tqdm_on_load: bool,
+    pt_load_map_location: Union[str, dict[str, str]] = "cpu",
+    max_workers: int = 4,
+) -> Generator[tuple[str, torch.Tensor], None, None]:
+    """Multi-Thread iterate over the weights in the model bin/pt files."""
+
+    def _load_file(bin_file: str):
+        return torch.load(bin_file,
+                          map_location=pt_load_map_location,
+                          weights_only=True)
+
+    with concurrent.futures.ThreadPoolExecutor(
+            max_workers=max_workers) as executor:
+        futures = [
+            executor.submit(_load_file, bin_file)
+            for bin_file in hf_weights_files
+        ]
+        futures_iter = tqdm(
+            concurrent.futures.as_completed(futures),
+            total=len(hf_weights_files),
+            desc="Multi-thread loading pt checkpoint shards",
+            disable=not enable_tqdm(use_tqdm_on_load),
+            bar_format=_BAR_FORMAT,
+        )
+
+        for future in futures_iter:
+            state = future.result()
+            yield from state.items()
+            del state
+
+
 def get_gguf_extra_tensor_names(
         gguf_file: str, gguf_to_hf_name_map: dict[str, str]) -> list[str]:
     reader = gguf.GGUFReader(gguf_file)
```
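
If vLLM is installed, the new PyTorch-bin iterator can also be driven directly; a hypothetical usage sketch (the file paths are placeholders and would normally come from `_prepare_weights`):

```python
from vllm.model_executor.model_loader.weight_utils import (
    multi_thread_pt_weights_iterator)

# Placeholder shard paths for illustration only.
files = [
    "/models/opt/pytorch_model-00001-of-00002.bin",
    "/models/opt/pytorch_model-00002-of-00002.bin",
]

state_dict = {}
for name, tensor in multi_thread_pt_weights_iterator(
        files,
        use_tqdm_on_load=True,
        pt_load_map_location="cpu",
        max_workers=8):
    state_dict[name] = tensor

print(f"Loaded {len(state_dict)} tensors")
```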
