     set_default_torch_dtype)
 from vllm.model_executor.model_loader.weight_utils import (
     download_safetensors_index_file_from_hf, download_weights_from_hf,
-    filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
-    get_gguf_extra_tensor_names, get_lock, gguf_quant_weights_iterator,
-    initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
+    fastsafetensors_weights_iterator, filter_duplicate_safetensors_files,
+    filter_files_not_needed_for_inference, get_gguf_extra_tensor_names,
+    get_lock, gguf_quant_weights_iterator, initialize_dummy_weights,
+    np_cache_weights_iterator, pt_weights_iterator,
     runai_safetensors_weights_iterator, safetensors_weights_iterator)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
@@ -275,7 +276,8 @@ def _prepare_weights( |
         # Some quantized models use .pt files for storing the weights.
         if load_format == LoadFormat.AUTO:
             allow_patterns = ["*.safetensors", "*.bin"]
-        elif load_format == LoadFormat.SAFETENSORS:
+        elif (load_format == LoadFormat.SAFETENSORS
+              or load_format == LoadFormat.FASTSAFETENSORS):
             use_safetensors = True
             allow_patterns = ["*.safetensors"]
         elif load_format == LoadFormat.MISTRAL:
@@ -357,10 +359,16 @@ def _get_weights_iterator( |
                 self.load_config.use_tqdm_on_load,
             )
         elif use_safetensors:
-            weights_iterator = safetensors_weights_iterator(
-                hf_weights_files,
-                self.load_config.use_tqdm_on_load,
-            )
+            if self.load_config.load_format == LoadFormat.FASTSAFETENSORS:
+                weights_iterator = fastsafetensors_weights_iterator(
+                    hf_weights_files,
+                    self.load_config.use_tqdm_on_load,
+                )
+            else:
+                weights_iterator = safetensors_weights_iterator(
+                    hf_weights_files,
+                    self.load_config.use_tqdm_on_load,
+                )
         else:
             weights_iterator = pt_weights_iterator(
                 hf_weights_files,
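
Taken together, the three hunks import the new iterator, let `LoadFormat.FASTSAFETENSORS` reuse the existing `*.safetensors` file discovery in `_prepare_weights`, and dispatch to `fastsafetensors_weights_iterator` in `_get_weights_iterator` when that format is selected. As a minimal usage sketch (not part of this diff): the string value `"fastsafetensors"` for `load_format` and the model name below are assumptions, and the fastsafetensors package is assumed to be installed.

```python
# Hedged sketch: exercise the new code path end to end, assuming
# "fastsafetensors" is the string value behind LoadFormat.FASTSAFETENSORS.
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",      # any checkpoint shipped as *.safetensors
    load_format="fastsafetensors",  # routes _get_weights_iterator to
                                    # fastsafetensors_weights_iterator
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```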