Commit a6149aa

[OOT] Support sync_model_loading for OOT (#25126)
Signed-off-by: Chendi Xue <Chendi.Xue@intel.com>
1 parent 6c8a3c0 commit a6149aa

4 files changed (+33, -17 lines changed)


vllm/model_executor/parameter.py

Lines changed: 3 additions & 3 deletions
@@ -12,7 +12,6 @@
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.logger import init_logger
-from vllm.model_executor.utils import _make_synced_weight_loader

 __all__ = [
     "BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter",
@@ -53,8 +52,9 @@ def __init__(self, data: torch.Tensor, weight_loader: Callable):
         # This sometimes causes OOM errors during model loading. To avoid this,
         # we sync the param tensor after its weight loader is called.
         from vllm.platforms import current_platform
-        if current_platform.is_tpu():
-            weight_loader = _make_synced_weight_loader(weight_loader)
+        if current_platform.use_sync_weight_loader():
+            weight_loader = current_platform.make_synced_weight_loader(
+                weight_loader)

         self._weight_loader = weight_loader
         self.tp_rank = get_tensor_model_parallel_rank()

vllm/model_executor/utils.py

Lines changed: 3 additions & 14 deletions
@@ -44,23 +44,12 @@ def set_weight_attrs(
         # TODO(woosuk): Remove this hack once we have a better solution.
         from vllm.platforms import current_platform

-        if current_platform.is_tpu() and key == "weight_loader":
-            value = _make_synced_weight_loader(value)
+        if current_platform.use_sync_weight_loader(
+        ) and key == "weight_loader":
+            value = current_platform.make_synced_weight_loader(value)
         setattr(weight, key, value)


-def _make_synced_weight_loader(original_weight_loader):
-
-    def _synced_weight_loader(param, *args, **kwargs):
-        out = original_weight_loader(param, *args, **kwargs)
-        # torch._sync doesn't support, is not needed for CPU tensors.
-        if param.device != torch.device("cpu"):
-            torch._sync(param)
-        return out
-
-    return _synced_weight_loader
-
-
 def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
     parent_map = getattr(model, "packed_modules_mapping", None)
     parent_map = copy.deepcopy(parent_map) if parent_map is not None else {}

vllm/platforms/interface.py

Lines changed: 23 additions & 0 deletions
@@ -594,6 +594,29 @@ def support_hybrid_kv_cache(cls) -> bool:
         """
         return False

+    @classmethod
+    def use_sync_weight_loader(cls) -> bool:
+        """
+        Returns if the current platform needs to sync weight loader.
+        """
+        return False
+
+    @classmethod
+    def make_synced_weight_loader(cls, original_weight_loader):
+        """
+        Wrap the original weight loader to make it synced.
+        """
+        if not cls.use_sync_weight_loader():
+            return original_weight_loader
+
+        def _synced_weight_loader(param, *args, **kwargs):
+            out = original_weight_loader(param, *args, **kwargs)
+            if param.device != torch.device("cpu"):
+                torch._sync(param)
+            return out
+
+        return _synced_weight_loader
+

 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
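
Taken together, the two new Platform classmethods work like this: use_sync_weight_loader() defaults to False, in which case make_synced_weight_loader() hands back the original callable untouched; a platform that returns True gets a wrapper that calls torch._sync(param) after every non-CPU weight load. A minimal usage sketch, not part of the commit (toy_weight_loader is a hypothetical loader used only for illustration):

import torch

from vllm.platforms import current_platform


def toy_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    # Hypothetical loader: copy the checkpoint tensor into the parameter.
    param.data.copy_(loaded_weight)


# On most platforms use_sync_weight_loader() is False, so this is a no-op
# pass-through; on platforms that return True (e.g. TPU, below), the returned
# wrapper syncs the param after each load.
loader = current_platform.make_synced_weight_loader(toy_weight_loader)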

vllm/platforms/tpu.py

Lines changed: 4 additions & 0 deletions
@@ -226,6 +226,10 @@ def swap_out_blocks_to_host(
         torch.ops.xla.dynamo_set_buffer_donor_(src_cache, True)
         dst_cache[dst_block_indices] = src_cache[src_block_indices].cpu()

+    @classmethod
+    def use_sync_weight_loader(cls) -> bool:
+        return True
+

 try:
     from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
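
The TPU override above is also the template for the out-of-tree case the commit title targets: an OOT platform plugin only needs to return True from use_sync_weight_loader() to have its weight loaders wrapped, and may additionally override make_synced_weight_loader() if it needs a different sync primitive than torch._sync(). A hypothetical sketch (MyOOTPlatform is an illustrative name, not a class in vLLM or in this commit):

from vllm.platforms.interface import Platform, PlatformEnum


class MyOOTPlatform(Platform):
    # Hypothetical out-of-tree platform, registered through vLLM's platform
    # plugin entry point as usual.
    _enum = PlatformEnum.OOT

    @classmethod
    def use_sync_weight_loader(cls) -> bool:
        # Opting in is enough: BasevLLMParameter and set_weight_attrs now
        # route through current_platform.make_synced_weight_loader(), so the
        # inherited wrapper adds the sync after each weight load.
        return True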
