Non-strict loading of the state dict #278
I dug into this yesterday and it's a mystery to me how the state dict is empty at most stages of the module load in PyTorch.
Yeah, I agree that …
Yes, everything you just said, 100%.
cc @sayakpaul and @dacorvo
David would be the best to comment on the design and decision choices. But I did some micro-investigations which I will note below. If we do:

```python
from optimum.quanto import quantize, qint8, freeze
from transformers import AutoModel

model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-BertModel")
quantize(model, weights=qint8)

for k in model.state_dict():
    if "pooler" in k:
        print(k)
```

we see:

```
pooler.dense.weight
pooler.dense.bias
pooler.dense.input_scale
pooler.dense.output_scale
```

However, when we do:

```python
sd = model.state_dict()
weight = sd.pop("pooler.dense.weight")
try:
    model.load_state_dict(sd)
except RuntimeError as e:
    print(e)
```

it complains about:

```
size, stride)
     88     inner_tensors_dict = {}
     89     for name in ["_data", "_scale"]:
---> 90         inner_tensors_dict[name] = state_dict.pop(prefix + name)
     91     meta = {
     92         "qtype": qtype.name,

KeyError: 'pooler.dense.weight._data'
```

This complaint, IMO, is weird because `any("_data" in k for k in sd)` returns `False`. When we additionally …

Some keys don't go through quantization, and we confirm that by:

```python
for k in sd:
    if ".weight" in k and ("data" not in k and "scale" not in k):
        print(k)
```

For my tiny model, that prints:

```
embeddings.word_embeddings.weight
embeddings.position_embeddings.weight
embeddings.token_type_embeddings.weight
embeddings.LayerNorm.weight
encoder.layer.0.attention.output.LayerNorm.weight
encoder.layer.0.output.LayerNorm.weight
encoder.layer.1.attention.output.LayerNorm.weight
encoder.layer.1.output.LayerNorm.weight
encoder.layer.2.attention.output.LayerNorm.weight
encoder.layer.2.output.LayerNorm.weight
encoder.layer.3.attention.output.LayerNorm.weight
encoder.layer.3.output.LayerNorm.weight
encoder.layer.4.attention.output.LayerNorm.weight
encoder.layer.4.output.LayerNorm.weight
```

Not sure how useful this information is, but I just thought of noting it down here.
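One way to cross-check which weights are skipped is to look at which modules `quantize` actually swapped out. Here is a small sketch relying only on class-name inspection; the `Q` prefix convention is an assumption about quanto's replacement module names (e.g. `QLinear`):

```python
# run after quantize(model, weights=qint8)
for name, module in model.named_modules():
    cls = module.__class__.__name__
    if cls.startswith("Q"):  # quanto's quantized replacements, e.g. QLinear
        print(name, "->", cls)
# modules whose weights stay un-quantized (the embeddings and LayerNorms
# listed above) keep plain .weight entries, with no _data/_scale split
```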
I think it's more a case of the error message being confusing, but I would expect an error here if not passing `strict=False`.
That's really curious. I checked if it would help loading incomplete state dicts when passing … Edit: It looks like David is OOO for still some time; not sure if there is anyone else we can ask.
I would assume that materializes the quantization-affected parameters, i.e., `_data` and `_scale`.
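If that assumption holds, the key layout after freezing should expose the inner tensors directly. A quick sketch; the exact key names are an assumption, inferred from the loader's `prefix + "_data"` / `prefix + "_scale"` lookups in the traceback above:

```python
from optimum.quanto import quantize, freeze, qint8
from transformers import AutoModel

model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-BertModel")
quantize(model, weights=qint8)
freeze(model)  # materialize the quantized weights

for k in model.state_dict():
    if "pooler" in k:
        print(k)
# expected (assumed) output:
# pooler.dense.weight._data
# pooler.dense.weight._scale
# pooler.dense.bias
# pooler.dense.input_scale
# pooler.dense.output_scale
```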
@BenjaminBossan this branch should fix the very first issue (before calling `freeze`), but eventually, if you need to omit some quantized weights, more development is needed (see my second commit with hints).
Thanks for the pointer @dacorvo. I made some changes to the 8-bit code based on your comment, using commit 832f7f5 as a basis (v0.2.4):

diff for …:

```diff
87c87
< def load_from_state_dict(state_dict, prefix, qtype, axis, size, stride):
---
> def load_from_state_dict(state_dict, prefix, qtype, axis, size, stride, missing_keys):
88a89
>     missing = False
90c91,99
<         inner_tensors_dict[name] = state_dict.pop(prefix + name)
---
>         if prefix + name not in state_dict:
>             missing_keys.append(prefix + name)
>             missing = True
>         else:
>             inner_tensors_dict[name] = state_dict.pop(prefix + name)
>
>     if missing:  # could not deserialize
>         return None
>
```

diff for …:

```diff
165a166
>     missing_keys=missing_keys,
180c181
<     if assign_to_params_buffers:
---
>     if (deserialized_weight is not None) and assign_to_params_buffers:
182c183
<     else:
---
>     elif deserialized_weight is not None:
```

With these changes, my script above passes (except that …). The error message could be cleaned up, as it currently is:

…

which could be confusing. We could just remove the suffix from the key, but this would not account for the situation of only one of `_data` and `_scale` missing.
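For readability, here is roughly what the patched loader looks like after the first hunk. This is a sketch reconstructed from the diff above, not the exact source; surrounding code such as the rest of the `meta` dict is abbreviated:

```python
def load_from_state_dict(state_dict, prefix, qtype, axis, size, stride, missing_keys):
    inner_tensors_dict = {}
    missing = False
    for name in ["_data", "_scale"]:
        if prefix + name not in state_dict:
            # record the key instead of raising, so that
            # load_state_dict(..., strict=False) can report it
            missing_keys.append(prefix + name)
            missing = True
        else:
            inner_tensors_dict[name] = state_dict.pop(prefix + name)

    if missing:  # could not deserialize
        return None

    meta = {
        "qtype": qtype.name,
        # ... remaining metadata and tensor reconstruction as before
    }
```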
@BenjaminBossan this works for me, by the way.
@dacorvo LMK if you think this is the way to go and, if yes, whether I should create a PR.
@BenjaminBossan yes, this is what I had in mind: please submit a pull request with your changes.
Resolves huggingface#278. PyTorch allows loading state dicts with the `strict=False` argument to ignore missing keys. This is now also supported in optimum-quanto; before this fix, a `KeyError` would be raised. One context where this is important is parameter-efficient fine-tuning adapters such as LoRA: there, we want to load only a small subset of parameters and leave the other model weights untouched, which requires non-strict loading.
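A sketch of what the fix enables; the popped key name is assumed from the traceback earlier in the thread:

```python
from optimum.quanto import quantize, freeze, qint8
from transformers import AutoModel

model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-BertModel")
quantize(model, weights=qint8)
freeze(model)

sd = model.state_dict()
sd.pop("pooler.dense.weight._data")  # drop one quantized inner tensor

# with the fix, non-strict loading reports the key instead of raising
result = model.load_state_dict(sd, strict=False)
print(result.missing_keys)
```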
Hi, I'm currently investigating the addition of optimum-quanto to PEFT. This mostly works already, but I'm hitting a wall when it comes to loading the `state_dict`. When loading a PEFT adapter like LoRA, we typically assume that the base model weights are already correctly loaded and are thus only interested in loading the adapter weights. That's why we're calling `model.load_state_dict(peft_model_state_dict, strict=False)` and ignoring the missing keys.

Now, when I try to do that with a quanto model, I get an error about missing keys despite having `strict=False`. Below is a reproducer that does not involve PEFT, for simplification:

…

Full error:

…
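A minimal sketch of such a reproducer, consistent with the investigation earlier in the thread; the model name and popped key are assumptions, not the issue author's original code:

```python
from optimum.quanto import quantize, qint8
from transformers import AutoModel

model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-BertModel")
quantize(model, weights=qint8)

sd = model.state_dict()
sd.pop("pooler.dense.weight")  # simulate a checkpoint that lacks some base weights

# strict=False should tolerate missing keys, but before the fix this raised
# KeyError: 'pooler.dense.weight._data' instead of reporting the key
model.load_state_dict(sd, strict=False)
```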