Skip to content

Commit

Permalink
[Bugfix] Fix torch dynamo fixes caused by replace_parameters (vllm-…
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasWilkinson authored Sep 24, 2024
1 parent 2529d09 commit 72fc97a
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions vllm/model_executor/layers/quantization/utils/layer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,17 @@ def replace_parameter(mod: torch.nn.Module, name: str,
new: Union[torch.Tensor, torch.nn.Parameter]) -> None:

old = getattr(mod, name)
if old.dtype == new.dtype and \
if type(old) is type(new) and old.dtype == new.dtype and \
old.untyped_storage().nbytes() == new.untyped_storage().nbytes():
# If we can just update in-place to avoid re-registering
# can be faster if the underlying storage is the same
update_tensor_inplace(old, new)
else:
# Fallback re-register parameter
# Fallback re-register parameter, convert to Parameter if necessary
# this not only ensures we don't register a tensor as a parameter, but
# also ensures that all parameter subclasses get re-registered as
# parameters for `torch.compile` compatibility
if not isinstance(new, torch.nn.Parameter):
new = torch.nn.Parameter(new)
mod.register_parameter(name, torch.nn.Parameter(new))
new = torch.nn.Parameter(new, requires_grad=False)
mod.register_parameter(name,
torch.nn.Parameter(new, requires_grad=False))

0 comments on commit 72fc97a

Please sign in to comment.