1818 download_safetensors_index_file_from_hf , download_weights_from_hf ,
1919 fastsafetensors_weights_iterator , filter_duplicate_safetensors_files ,
2020 filter_files_not_needed_for_inference , maybe_download_from_modelscope ,
21- np_cache_weights_iterator , pt_weights_iterator ,
22- safetensors_weights_iterator )
21+ multi_thread_pt_weights_iterator ,
22+ multi_thread_safetensors_weights_iterator , np_cache_weights_iterator ,
23+ pt_weights_iterator , safetensors_weights_iterator )
2324from vllm .platforms import current_platform
2425
2526logger = init_logger (__name__ )
2829class DefaultModelLoader (BaseModelLoader ):
2930 """Model loader that can load different file types from disk."""
3031
32+ # Default number of threads to use when multi-threaded weight loading is enabled.
33+ DEFAULT_NUM_THREADS = 8
34+
3135 @dataclasses .dataclass
3236 class Source :
3337 """A source for weights."""
@@ -52,9 +56,15 @@ class Source:
5256
5357 def __init__ (self , load_config : LoadConfig ):
5458 super ().__init__ (load_config )
55- if load_config .model_loader_extra_config :
56- raise ValueError (f"Model loader extra config is not supported for "
57- f"load format { load_config .load_format } " )
59+
60+ extra_config = load_config .model_loader_extra_config
61+ allowed_keys = {"enable_multithread_load" , "num_threads" }
62+ unexpected_keys = set (extra_config .keys ()) - allowed_keys
63+
64+ if unexpected_keys :
65+ raise ValueError (f"Unexpected extra config keys for load format "
66+ f"{ load_config .load_format } : "
67+ f"{ unexpected_keys } " )
5868
5969 def _prepare_weights (
6070 self ,
@@ -145,6 +155,7 @@ def _get_weights_iterator(
145155 self , source : "Source"
146156 ) -> Generator [tuple [str , torch .Tensor ], None , None ]:
147157 """Get an iterator for the model weights based on the load format."""
158+ extra_config = self .load_config .model_loader_extra_config
148159 hf_folder , hf_weights_files , use_safetensors = self ._prepare_weights (
149160 source .model_or_path , source .revision , source .fall_back_to_pt ,
150161 source .allow_patterns_overrides )
@@ -165,16 +176,34 @@ def _get_weights_iterator(
165176 self .load_config .use_tqdm_on_load ,
166177 )
167178 else :
168- weights_iterator = safetensors_weights_iterator (
179+ if extra_config .get ("enable_multithread_load" ):
180+ weights_iterator = (
181+ multi_thread_safetensors_weights_iterator (
182+ hf_weights_files ,
183+ self .load_config .use_tqdm_on_load ,
184+ max_workers = extra_config .get (
185+ "num_threads" , self .DEFAULT_NUM_THREADS ),
186+ ))
187+ else :
188+ weights_iterator = safetensors_weights_iterator (
189+ hf_weights_files ,
190+ self .load_config .use_tqdm_on_load ,
191+ )
192+ else :
193+ if extra_config .get ("enable_multithread_load" ):
194+ weights_iterator = multi_thread_pt_weights_iterator (
169195 hf_weights_files ,
170196 self .load_config .use_tqdm_on_load ,
197+ self .load_config .pt_load_map_location ,
198+ max_workers = extra_config .get ("num_threads" ,
199+ self .DEFAULT_NUM_THREADS ),
200+ )
201+ else :
202+ weights_iterator = pt_weights_iterator (
203+ hf_weights_files ,
204+ self .load_config .use_tqdm_on_load ,
205+ self .load_config .pt_load_map_location ,
171206 )
172- else :
173- weights_iterator = pt_weights_iterator (
174- hf_weights_files ,
175- self .load_config .use_tqdm_on_load ,
176- self .load_config .pt_load_map_location ,
177- )
178207
179208 if current_platform .is_tpu ():
180209 from vllm .platforms .tpu import USE_TPU_COMMONS
0 commit comments