|
10 | 10 | import itertools |
11 | 11 | import math |
12 | 12 | import os |
| 13 | +import time |
13 | 14 | import warnings |
14 | 15 | from abc import ABC, abstractmethod |
15 | 16 | from contextlib import contextmanager |
@@ -216,6 +217,9 @@ class Source: |
216 | 217 | allow_patterns_overrides: Optional[list[str]] = None |
217 | 218 | """If defined, weights will load exclusively using these patterns.""" |
218 | 219 |
|
| 220 | + counter_before_loading_weights: float = 0.0 |
| 221 | + counter_after_loading_weights: float = 0.0 |
| 222 | + |
219 | 223 | def __init__(self, load_config: LoadConfig): |
220 | 224 | super().__init__(load_config) |
221 | 225 | if load_config.model_loader_extra_config: |
@@ -368,6 +372,8 @@ def _xla_weights_iterator(iterator: Generator): |
368 | 372 |
|
369 | 373 | weights_iterator = _xla_weights_iterator(weights_iterator) |
370 | 374 |
|
| 375 | + if self.counter_before_loading_weights == 0.0: |
| 376 | + self.counter_before_loading_weights = time.perf_counter() |
371 | 377 | # Apply the prefix. |
372 | 378 | return ((source.prefix + name, tensor) |
373 | 379 | for (name, tensor) in weights_iterator) |
@@ -412,6 +418,11 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module: |
412 | 418 | weights_to_load = {name for name, _ in model.named_parameters()} |
413 | 419 | loaded_weights = model.load_weights( |
414 | 420 | self._get_all_weights(model_config, model)) |
| 421 | + self.counter_after_loading_weights = time.perf_counter() |
| 422 | + logger.info( |
| 423 | + "Loading weights took %.2f seconds", |
| 424 | + self.counter_after_loading_weights - |
| 425 | + self.counter_before_loading_weights) |
415 | 426 | # We only enable strict check for non-quantized models |
416 | 427 | # that have loaded weights tracking currently. |
417 | 428 | if model_config.quantization is None and loaded_weights is not None: |
|
0 commit comments