Skip to content

Commit

Permalink
better fix
Browse files Browse the repository at this point in the history
  • Loading branch information
SunMarc authored and dacorvo committed Mar 19, 2024
1 parent ef455a5 commit 29772f2
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion quanto/nn/qmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def deserialize_tensor_subclass(t, state_dict, prefix):
# FIXME: here we should copy frozen weights into frozen module, but this leads to grad error
self.weight = torch.nn.Parameter(deserialized_weight.to(device))
# this is needed because we can't load it correctly when the bias is on the meta device
if state_dict.get(prefix + "bias", False):
if prefix + "bias" in state_dict:
self.bias = torch.nn.Parameter(state_dict.pop(prefix + "bias"))
super()._load_from_state_dict(
state_dict, prefix, local_metadata, False, missing_keys, unexpected_keys, error_msgs
Expand Down

0 comments on commit 29772f2

Please sign in to comment.