[misc][distributed] add seed to dummy weights (vllm-project#6491)
youkaichao authored and fialhocoelho committed Jul 19, 2024
1 parent 572aff2 commit 7668b56
Showing 1 changed file with 11 additions and 2 deletions.
vllm/model_executor/model_loader/weight_utils.py
@@ -440,24 +440,33 @@ def initialize_dummy_weights(
     model: torch.nn.Module,
     low: float = -1e-3,
     high: float = 1e-3,
+    seed: int = 1234,
 ) -> None:
     """Initialize model weights with random values.

     The model weights must be randomly initialized for accurate performance
     measurements. Additionally, the model weights should not cause NaNs in the
     forward pass. We empirically found that initializing the weights with
     values between -1e-3 and 1e-3 works well for most models.
+
+    We use a per-parameter random seed, so that dummy weights are consistent,
+    even if the model is partitioned across multiple devices. When the seed
+    is fixed, the random values generated by this function only depend on
+    the parameter's number of elements and its data type.
     """
     for param in model.state_dict().values():
         if torch.is_floating_point(param):
+            generator = torch.Generator(device=param.data.device)
+            generator.manual_seed(seed)
             if torch.finfo(param.data.dtype).bits < 16:
                 # uniform_ doesn't support < 16-bit datatypes (FP8)
                 dtype = param.data.dtype
                 tmp_param = param.data.to(torch.float16)
-                tmp_param = tmp_param.uniform_(low, high).to(dtype)
+                tmp_param = tmp_param.uniform_(low, high,
+                                               generator=generator).to(dtype)
                 param.data.copy_(tmp_param)
             else:
-                param.uniform_(low, high)
+                param.uniform_(low, high, generator=generator)


 def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:

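For illustration, here is a minimal sketch (not part of the commit; the helper name seeded_uniform_ is hypothetical) of the determinism property the new seed parameter provides: because every parameter re-seeds its own torch.Generator, the values it receives depend only on its shape and dtype, not on global RNG state or on how many parameters were initialized before it.

import torch

def seeded_uniform_(tensor: torch.Tensor, low: float = -1e-3,
                    high: float = 1e-3, seed: int = 1234) -> torch.Tensor:
    # Fresh generator per tensor, always seeded the same way: the result is
    # independent of any other RNG activity in the process.
    generator = torch.Generator(device=tensor.device)
    generator.manual_seed(seed)
    return tensor.uniform_(low, high, generator=generator)

a = torch.empty(4, 8)
torch.rand(1000)  # unrelated RNG activity that would normally desync streams
b = torch.empty(4, 8)

seeded_uniform_(a)
seeded_uniform_(b)
assert torch.equal(a, b)  # same shape + dtype + seed -> identical values

This mirrors why the commit moves from the global RNG (param.uniform_(low, high)) to an explicitly seeded generator: as the new docstring notes, when the model is partitioned across multiple devices, each rank initializes a different subset of parameters, and a shared global RNG stream would make the dummy weights depend on initialization order.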