Commit 875493f

manoelmarques, hmellor, and ProExpertProg authored and committed
Generate _ModelInfo properties file when loading to improve loading speed (vllm-project#23558)
Signed-off-by: Manoel Marques <manoel.marques@ibm.com>
Signed-off-by: Manoel Marques <manoelmrqs@gmail.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent 0d23d00 commit 875493f

File tree

4 files changed: +167 −3 lines

vllm/logging_utils/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -2,7 +2,9 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from vllm.logging_utils.formatter import NewLineFormatter
+from vllm.logging_utils.log_time import logtime
 
 __all__ = [
     "NewLineFormatter",
+    "logtime",
 ]

vllm/logging_utils/log_time.py

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Provides a timeslice logging decorator
+"""
+
+import functools
+import time
+
+
+def logtime(logger, msg=None):
+    """
+    Logs the execution time of the decorated function.
+    Always place it beneath other decorators.
+    """
+
+    def _inner(func):
+
+        @functools.wraps(func)
+        def _wrapper(*args, **kwargs):
+            start = time.perf_counter()
+            result = func(*args, **kwargs)
+            elapsed = time.perf_counter() - start
+
+            prefix = f"Function '{func.__module__}.{func.__qualname__}'" \
+                if msg is None else msg
+            logger.debug("%s: Elapsed time %.7f secs", prefix, elapsed)
+            return result
+
+        return _wrapper
+
+    return _inner

vllm/model_executor/model_loader/weight_utils.py

Lines changed: 44 additions & 0 deletions
@@ -11,6 +11,7 @@
 import time
 from collections import defaultdict
 from collections.abc import Generator
+from contextlib import contextmanager
 from pathlib import Path
 from typing import Any, Callable, Optional, Union
 
@@ -98,6 +99,49 @@ def get_lock(model_name_or_path: Union[str, Path],
     return lock
 
 
+@contextmanager
+def atomic_writer(filepath: Union[str, Path],
+                  mode: str = 'w',
+                  encoding: Optional[str] = None):
+    """
+    Context manager that provides an atomic file writing routine.
+
+    The context manager writes to a temporary file and, if successful,
+    atomically replaces the original file.
+
+    Args:
+        filepath (str or Path): The path to the file to write.
+        mode (str): The file mode for the temporary file (e.g., 'w', 'wb').
+        encoding (str): The encoding for text mode.
+
+    Yields:
+        file object: A handle to the temporary file.
+    """
+    # Create a temporary file in the same directory as the target file
+    # to ensure it's on the same filesystem for an atomic replace.
+    temp_dir = os.path.dirname(filepath)
+    temp_fd, temp_path = tempfile.mkstemp(dir=temp_dir)
+
+    try:
+        # Open the temporary file for writing
+        with os.fdopen(temp_fd, mode=mode, encoding=encoding) as temp_file:
+            yield temp_file
+
+        # If the 'with' block completes successfully,
+        # perform the atomic replace.
+        os.replace(temp_path, filepath)
+
+    except Exception:
+        logger.exception(
+            "Error during atomic write. Original file '%s' not modified",
+            filepath)
+        raise
+    finally:
+        # Clean up the temporary file if it still exists.
+        if os.path.exists(temp_path):
+            os.remove(temp_path)
+
+
 def maybe_download_from_modelscope(
     model: str,
     revision: Optional[str] = None,
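A short usage sketch for the new helper; the path and payload below are hypothetical examples, not part of the diff:

import json

from vllm.model_executor.model_loader.weight_utils import atomic_writer

# Example payload mirroring the cache layout introduced in registry.py.
payload = {"hash": "abc123", "modelinfo": {"architecture": "ExampleLM"}}

with atomic_writer("/tmp/example-modelinfo.json", encoding="utf-8") as f:
    json.dump(payload, f, indent=2)
# If json.dump raises, os.replace never runs, so readers either see the
# old complete file or the new complete file, never a partial write.

Creating the temporary file in the same directory as the target is what makes this safe: os.replace is only atomic when source and destination live on the same filesystem.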

vllm/model_executor/models/registry.py

Lines changed: 89 additions & 3 deletions
@@ -4,24 +4,29 @@
 Whenever you add an architecture to this page, please also update
 `tests/models/registry.py` with example HuggingFace models for it.
 """
+import hashlib
 import importlib
+import json
 import os
 import pickle
 import subprocess
 import sys
 import tempfile
 from abc import ABC, abstractmethod
 from collections.abc import Set
-from dataclasses import dataclass, field
+from dataclasses import asdict, dataclass, field
 from functools import lru_cache
+from pathlib import Path
 from typing import Callable, Optional, TypeVar, Union
 
 import torch.nn as nn
 import transformers
 
+from vllm import envs
 from vllm.config import (ModelConfig, iter_architecture_defaults,
                          try_match_architecture_defaults)
 from vllm.logger import init_logger
+from vllm.logging_utils import logtime
 from vllm.transformers_utils.dynamic_module import (
     try_get_class_from_dynamic_module)
 
@@ -421,10 +426,91 @@ class _LazyRegisteredModel(_BaseRegisteredModel):
     module_name: str
     class_name: str
 
-    # Performed in another process to avoid initializing CUDA
+    @staticmethod
+    def _get_cache_dir() -> Path:
+        return Path(envs.VLLM_CACHE_ROOT) / "modelinfos"
+
+    def _get_cache_filename(self) -> str:
+        cls_name = f"{self.module_name}-{self.class_name}".replace(".", "-")
+        return f"{cls_name}.json"
+
+    def _load_modelinfo_from_cache(self,
+                                   module_hash: str) -> _ModelInfo | None:
+        try:
+            try:
+                modelinfo_path = self._get_cache_dir(
+                ) / self._get_cache_filename()
+                with open(modelinfo_path, encoding="utf-8") as file:
+                    mi_dict = json.load(file)
+            except FileNotFoundError:
+                logger.debug(("Cached model info file "
+                              "for class %s.%s not found"), self.module_name,
+                             self.class_name)
+                return None
+
+            if mi_dict["hash"] != module_hash:
+                logger.debug(("Cached model info file "
+                              "for class %s.%s is stale"), self.module_name,
+                             self.class_name)
+                return None
+
+            # file not changed, use cached _ModelInfo properties
+            return _ModelInfo(**mi_dict["modelinfo"])
+        except Exception:
+            logger.exception(("Cached model info "
+                              "for class %s.%s error. "), self.module_name,
+                             self.class_name)
+            return None
+
+    def _save_modelinfo_to_cache(self, mi: _ModelInfo,
+                                 module_hash: str) -> None:
+        """save dictionary json file to cache"""
+        from vllm.model_executor.model_loader.weight_utils import atomic_writer
+        try:
+            modelinfo_dict = {
+                "hash": module_hash,
+                "modelinfo": asdict(mi),
+            }
+            cache_dir = self._get_cache_dir()
+            cache_dir.mkdir(parents=True, exist_ok=True)
+            modelinfo_path = cache_dir / self._get_cache_filename()
+            with atomic_writer(modelinfo_path, encoding='utf-8') as f:
+                json.dump(modelinfo_dict, f, indent=2)
+        except Exception:
+            logger.exception("Error saving model info cache.")
+
+    @logtime(logger=logger, msg="Registry inspect model class")
     def inspect_model_cls(self) -> _ModelInfo:
-        return _run_in_subprocess(
+        model_path = Path(
+            __file__).parent / f"{self.module_name.split('.')[-1]}.py"
+
+        assert model_path.exists(), \
+            f"Model {self.module_name} expected to be on path {model_path}"
+        with open(model_path, "rb") as f:
+            module_hash = hashlib.md5(f.read()).hexdigest()
+
+        mi = self._load_modelinfo_from_cache(module_hash)
+        if mi is not None:
+            logger.debug(("Loaded model info "
+                          "for class %s.%s from cache"), self.module_name,
+                         self.class_name)
+            return mi
+        else:
+            logger.debug(("Cache model info "
+                          "for class %s.%s miss. "
+                          "Loading model instead."), self.module_name,
+                         self.class_name)
+
+        # Performed in another process to avoid initializing CUDA
+        mi = _run_in_subprocess(
             lambda: _ModelInfo.from_model_cls(self.load_model_cls()))
+        logger.debug("Loaded model info for class %s.%s", self.module_name,
+                     self.class_name)
+
+        # save cache file
+        self._save_modelinfo_to_cache(mi, module_hash)
+
+        return mi
 
     def load_model_cls(self) -> type[nn.Module]:
         mod = importlib.import_module(self.module_name)
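Taken together, the caching protocol is: hash the model's source file with MD5, look up a JSON cache entry keyed by module and class name, and accept the cached _ModelInfo only if the stored hash matches; otherwise fall back to the subprocess inspection and rewrite the entry. A self-contained sketch of that pattern, with a generic dataclass standing in for _ModelInfo (names and fields here are illustrative, not vLLM's API):

import hashlib
import json
from dataclasses import asdict, dataclass
from pathlib import Path


@dataclass
class Info:  # stand-in for _ModelInfo; fields are illustrative
    architecture: str


def file_md5(path: Path) -> str:
    # Hash of the source file; a changed file invalidates the cache entry.
    return hashlib.md5(path.read_bytes()).hexdigest()


def load_cached(cache_path: Path, source_hash: str) -> Info | None:
    try:
        entry = json.loads(cache_path.read_text(encoding="utf-8"))
    except FileNotFoundError:
        return None  # cache miss
    if entry["hash"] != source_hash:
        return None  # stale: source file changed since the entry was written
    return Info(**entry["modelinfo"])


def save_cached(cache_path: Path, info: Info, source_hash: str) -> None:
    entry = {"hash": source_hash, "modelinfo": asdict(info)}
    cache_path.write_text(json.dumps(entry, indent=2), encoding="utf-8")

The committed version additionally routes the save through atomic_writer so a crash mid-write cannot leave a corrupt cache file, and treats any error while reading the cache as a miss rather than a failure.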
