
Non-strict loading of the state dict #278

Closed
BenjaminBossan opened this issue Aug 12, 2024 · 12 comments · Fixed by #295

Comments

@BenjaminBossan
Member

Hi, I'm currently investigating the addition of optimum-quanto to PEFT. This mostly works already but I'm hitting a wall when it comes to loading the state_dict. When loading a PEFT adapter like LoRA, we typically assume that the base model weights are already correctly loaded and are thus only interested in loading the adapter weights. That's why we're calling model.load_state_dict(peft_model_state_dict, strict=False) and ignore the missing keys.

Now when I try to do that with a quanto-model, I get an error about missing keys despite having strict=False. Below is a reproducer that does not involve PEFT for simplification:

from transformers import AutoModelForCausalLM
from optimum.quanto import quantize, qint8

model_id = "facebook/opt-125m"

# FIRST WITHOUT QUANTO
model = AutoModelForCausalLM.from_pretrained(model_id)
sd = model.state_dict()
weight = sd.pop("model.decoder.layers.0.self_attn.k_proj.weight")  # delete one item
# try with strict=True
try:
    model.load_state_dict(sd)
except RuntimeError as e:
    print(e)
# as expected, prints:
# Error(s) in loading state_dict for OPTForCausalLM:
#	Missing key(s) in state_dict: "model.decoder.layers.0.self_attn.k_proj.weight".

# now strict=False
model.load_state_dict(sd, strict=False)
# passes and returns
# _IncompatibleKeys(missing_keys=['model.decoder.layers.0.self_attn.k_proj.weight'], unexpected_keys=[])

# SECOND WITH QUANTO
model = AutoModelForCausalLM.from_pretrained(model_id)
quantize(model, weights=qint8)
sd = model.state_dict()
weight = sd.pop("model.decoder.layers.0.self_attn.k_proj.weight")

# try with strict=True
try:
    model.load_state_dict(sd)
except KeyError as e:  # KeyError, not RuntimeError
    print(e)
# prints:
# 'model.decoder.layers.0.self_attn.k_proj.weight._data'

# now strict=False
model.load_state_dict(sd, strict=False)
# same KeyError as with strict=True

Full error:

KeyError                                  Traceback (most recent call last)
Cell In[15], line 1
----> 1 model.load_state_dict(sd, strict=False)

File ~/anaconda3/envs/peft/lib/python3.11/site-packages/torch/nn/modules/module.py:2201, in Module.load_state_dict(self, state_dict, strict, assign)
   2194         out = hook(module, incompatible_keys)
   2195         assert out is None, (
   2196             "Hooks registered with ``register_load_state_dict_post_hook`` are not"
   2197             "expected to return new values, if incompatible_keys need to be modified,"
   2198             "it should be done inplace."
   2199         )
-> 2201 load(self, state_dict)
   2202 del load
   2204 if strict:

File ~/anaconda3/envs/peft/lib/python3.11/site-packages/torch/nn/modules/module.py:2189, in Module.load_state_dict.<locals>.load(module, local_state_dict, prefix)
   2187         child_prefix = prefix + name + '.'
   2188         child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
-> 2189         load(child, child_state_dict, child_prefix)  # noqa: F821
   2191 # Note that the hook can modify missing_keys and unexpected_keys.
   2192 incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)

File ~/anaconda3/envs/peft/lib/python3.11/site-packages/torch/nn/modules/module.py:2189, in Module.load_state_dict.<locals>.load(module, local_state_dict, prefix)
   2187         child_prefix = prefix + name + '.'
   2188         child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
-> 2189         load(child, child_state_dict, child_prefix)  # noqa: F821
   2191 # Note that the hook can modify missing_keys and unexpected_keys.
   2192 incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)

    [... skipping similar frames: Module.load_state_dict.<locals>.load at line 2189 (3 times)]

File ~/anaconda3/envs/peft/lib/python3.11/site-packages/torch/nn/modules/module.py:2189, in Module.load_state_dict.<locals>.load(module, local_state_dict, prefix)
   2187         child_prefix = prefix + name + '.'
   2188         child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
-> 2189         load(child, child_state_dict, child_prefix)  # noqa: F821
   2191 # Note that the hook can modify missing_keys and unexpected_keys.
   2192 incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)

File ~/anaconda3/envs/peft/lib/python3.11/site-packages/torch/nn/modules/module.py:2183, in Module.load_state_dict.<locals>.load(module, local_state_dict, prefix)
   2181 if assign:
   2182     local_metadata['assign_to_params_buffers'] = assign
-> 2183 module._load_from_state_dict(
   2184     local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
   2185 for name, child in module._modules.items():
   2186     if child is not None:

File ~/anaconda3/envs/peft/lib/python3.11/site-packages/optimum/quanto/nn/qmodule.py:159, in QModuleMixin._load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
    157 weight_prefix = weight_name + "."
    158 if self.weight_qtype.bits == 8:
--> 159     deserialized_weight = QBytesTensor.load_from_state_dict(
    160         state_dict,
    161         weight_prefix,
    162         qtype=self.weight_qtype,
    163         axis=0,
    164         size=self.weight.size(),
    165         stride=self.weight.stride(),
    166     )
    167 else:
    168     deserialized_weight = QBitsTensor.load_from_state_dict(
    169         state_dict,
    170         weight_prefix,
   (...)
    175         stride=self.weight.stride(),
    176     )

File ~/anaconda3/envs/peft/lib/python3.11/site-packages/optimum/quanto/tensor/qbytes.py:90, in QBytesTensor.load_from_state_dict(state_dict, prefix, qtype, axis, size, stride)
     88 inner_tensors_dict = {}
     89 for name in ["_data", "_scale"]:
---> 90     inner_tensors_dict[name] = state_dict.pop(prefix + name)
     91 meta = {
     92     "qtype": qtype.name,
     93     "axis": str(axis),
     94     "size": str(list(size)),
     95     "stride": str(list(stride)),
     96 }
     97 return QBytesTensor.__tensor_unflatten__(inner_tensors_dict, meta, None, None)

KeyError: 'model.decoder.layers.0.self_attn.k_proj.weight._data'
@bghira

bghira commented Aug 13, 2024

I dug into this yesterday and it's a mystery to me how the state dict is empty at most stages of the module load in PyTorch.

@BenjaminBossan
Member Author

Yeah, I agree that load_state_dict can be difficult to understand at times. E.g. I was surprised that when the sub-modules are loaded recursively, the strict argument is not even passed down (module._load_from_state_dict is always called with strict=True). Instead, missing_keys and unexpected_keys are passed down and need to be mutated by the sub-modules by appending to the corresponding lists. But I could not figure out how to add this to load_from_state_dict in quanto.
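
The contract described above can be illustrated with a minimal pure-Python sketch (hypothetical, not the real PyTorch code): each module-level loader mutates shared missing_keys/unexpected_keys lists, and the strict flag is only consulted once, at the top level, after the recursion has finished.

```python
# Hypothetical sketch of PyTorch's load_state_dict contract: per-module
# loaders never see `strict`; they record problems in shared lists instead
# of raising, and strictness is enforced once at the end.

def load_from_state_dict(local_keys, state_dict, prefix, missing_keys, unexpected_keys):
    """Load one module's keys; record missing ones instead of raising."""
    for key in local_keys:
        full_key = prefix + key
        if full_key in state_dict:
            state_dict.pop(full_key)  # "load" the tensor
        else:
            missing_keys.append(full_key)

def load_state_dict(modules, state_dict, strict=True):
    """Top-level entry point: the recursion always runs non-strict internally."""
    missing_keys, unexpected_keys = [], []
    sd = dict(state_dict)  # work on a copy
    for prefix, local_keys in modules.items():
        load_from_state_dict(local_keys, sd, prefix, missing_keys, unexpected_keys)
    unexpected_keys.extend(sd)  # whatever was never consumed
    if strict and (missing_keys or unexpected_keys):
        raise RuntimeError(f"missing: {missing_keys}, unexpected: {unexpected_keys}")
    return missing_keys, unexpected_keys

modules = {"layer0.": ["weight", "bias"], "layer1.": ["weight", "bias"]}
sd = {"layer0.weight": 1, "layer0.bias": 2, "layer1.bias": 3}  # layer1.weight missing
missing, unexpected = load_state_dict(modules, sd, strict=False)
print(missing)  # ['layer1.weight']
```

A module that raises KeyError instead of appending to missing_keys (as quanto did) short-circuits this mechanism, which is why strict=False has no effect.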

@bghira

bghira commented Aug 13, 2024

yes, everything you just said, 100%

@bghira

bghira commented Aug 13, 2024

cc @sayakpaul and @dacorvo

@sayakpaul
Member

David would be the best to comment on the design and decision choices. But I did some micro investigations which I will note below.

If we do:

from transformers import AutoModel
from optimum.quanto import quantize, qint8, freeze

model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-BertModel")
quantize(model, weights=qint8)

for k in model.state_dict():
    if "pooler" in k:
        print(k)

We see:

pooler.dense.weight
pooler.dense.bias
pooler.dense.input_scale
pooler.dense.output_scale

However, when we do:

sd = model.state_dict()
weight = sd.pop("pooler.dense.weight")

try:
    model.load_state_dict(sd)
except RuntimeError as e:
    print(e)

It complains about:

size, stride)
     88         inner_tensors_dict = {}
     89         for name in ["_data", "_scale"]:
---> 90             inner_tensors_dict[name] = state_dict.pop(prefix + name)
     91         meta = {
     92             "qtype": qtype.name,

KeyError: 'pooler.dense.weight._data'

This complaint, IMO, is weird because *._data tensors haven't been brought into the state dict itself. We can confirm this with:

any("_data" in k for k in sd)

When we additionally freeze(model), the *._data and *._scale related tensors are made available in the model, while the actual *.weight tensors are removed. The weights are, in a sense, reconstructed from their associated *._data and *._scale values.
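
A hypothetical sketch of what "reconstructed from _data and _scale" means for a symmetric int8 quantized weight (this is not quanto's actual implementation, just the general idea): the frozen module stores integer _data and a float _scale, and the effective weight is recovered as _data * _scale.

```python
# Sketch of symmetric int8 quantization/dequantization for one weight row.
# Assumed names (quantize_row, dequantize_row) are illustrative only.

def quantize_row(row):
    """Quantize floats to int8 values plus a single float scale."""
    scale = max(abs(x) for x in row) / 127 or 1.0  # guard against all-zero rows
    data = [round(x / scale) for x in row]  # int8 "_data"
    return data, scale

def dequantize_row(data, scale):
    """Reconstruct the approximate weight: _data * _scale."""
    return [d * scale for d in data]

row = [0.5, -1.0, 0.25]
data, scale = quantize_row(row)
restored = dequantize_row(data, scale)
print(data, scale)
print(restored)  # close to the original row, up to quantization error
```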

Some keys don't go through quantization and we confirm that by:

for k in sd:
    if ".weight" in k and ("data" not in k and "scale" not in k):
        print(k)

For my tiny model, that prints:

embeddings.word_embeddings.weight
embeddings.position_embeddings.weight
embeddings.token_type_embeddings.weight
embeddings.LayerNorm.weight
encoder.layer.0.attention.output.LayerNorm.weight
encoder.layer.0.output.LayerNorm.weight
encoder.layer.1.attention.output.LayerNorm.weight
encoder.layer.1.output.LayerNorm.weight
encoder.layer.2.attention.output.LayerNorm.weight
encoder.layer.2.output.LayerNorm.weight
encoder.layer.3.attention.output.LayerNorm.weight
encoder.layer.3.output.LayerNorm.weight
encoder.layer.4.attention.output.LayerNorm.weight
encoder.layer.4.output.LayerNorm.weight

Not sure how useful this information is, but I just thought of noting it down here.

@BenjaminBossan
Member Author

BenjaminBossan commented Aug 14, 2024

This complaint, IMO, is weird because *._data tensors haven't been brought into the state dict itself.

I think it's more of a case of the error message being confusing, but I would expect an error here if not passing strict=False.

When we additionally freeze(model), the *._data and *._scale related tensors are made available in the model, while the actual *.weight tensors pop out.

That's really curious. I checked if it would help loading incomplete state dicts when passing strict=False but still no luck. I wonder how it works since freeze just calls this:

https://github.com/huggingface/optimum-quanto/blob/main/optimum/quanto/nn/qmodule.py#L264-L268

Edit: It looks like David is OOO for some time still, not sure if there is anyone else we can ask.

@sayakpaul
Member

That's really curious. I checked if it would help loading incomplete state dicts when passing strict=False but still no luck. I wonder how it works since freeze just calls this:

I would assume that materializes the quantization-affected parameters, i.e., _data and _scale.

@dacorvo
Collaborator

dacorvo commented Aug 14, 2024

@BenjaminBossan this branch should fix the very first issue (before calling freeze), but if you eventually need to omit some quantized weights, more development is needed (see my second commit for hints).
https://github.com/huggingface/optimum-quanto/tree/load_strict_false

@BenjaminBossan
Member Author

Thanks for the pointer @dacorvo. I made some changes to the 8bit code based on your comment using commit 832f7f5 as a basis (v0.2.4):

diff for tensor/qbytes.py:

87c87
<     def load_from_state_dict(state_dict, prefix, qtype, axis, size, stride):
---
>     def load_from_state_dict(state_dict, prefix, qtype, axis, size, stride, missing_keys):
88a89
>         missing = False
90c91,99
<             inner_tensors_dict[name] = state_dict.pop(prefix + name)
---
>             if prefix + name not in state_dict:
>                 missing_keys.append(prefix + name)
>                 missing = True
>             else:
>                 inner_tensors_dict[name] = state_dict.pop(prefix + name)
> 
>         if missing:  # could not deserialize
>             return None
> 

diff for nn/qmodule.py:

165a166
>                     missing_keys=missing_keys,
180c181
<             if assign_to_params_buffers:
---
>             if (deserialized_weight is not None) and assign_to_params_buffers:
182c183
<             else:
---
>             elif deserialized_weight is not None:

With these changes, my script above passes (except that KeyError has to be replaced with RuntimeError, which is consistent with PyTorch).

The error message could be cleaned up, as it currently is:

Error(s) in loading state_dict for OPTForCausalLM:
Missing key(s) in state_dict: "model.decoder.layers.0.self_attn.k_proj.weight._data", "model.decoder.layers.0.self_attn.k_proj.weight._scale".

which could be confusing. We could just remove the suffix from the key, but this would not account for the situation where only one of _data or _scale is missing, so I left it as is.
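
The logic of the diff above can be sketched in pure Python (hypothetical simplification; quanto's real code operates on tensors): instead of popping _data/_scale unconditionally, which raises KeyError, the loader records absent keys in missing_keys and returns None so the caller can skip deserialization for that weight.

```python
# Sketch of the patched load_from_state_dict: report missing inner tensors
# via the shared missing_keys list instead of raising KeyError.

def load_from_state_dict(state_dict, prefix, missing_keys):
    inner_tensors_dict = {}
    missing = False
    for name in ["_data", "_scale"]:
        key = prefix + name
        if key not in state_dict:
            missing_keys.append(key)  # report instead of raising
            missing = True
        else:
            inner_tensors_dict[name] = state_dict.pop(key)
    if missing:  # could not deserialize this weight
        return None
    return inner_tensors_dict

sd = {"k_proj.weight._scale": 0.01}  # _data is absent
missing_keys = []
result = load_from_state_dict(sd, "k_proj.weight.", missing_keys)
print(result)        # None
print(missing_keys)  # ['k_proj.weight._data']
```

The caller then skips assignment when the return value is None, which is what the qmodule.py side of the diff adds.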

@bghira

bghira commented Aug 21, 2024

@BenjaminBossan this works for me by the way

@BenjaminBossan
Member Author

@dacorvo LMK if you think this is the way to go and if yes, whether I should create a PR.

@dacorvo
Collaborator

dacorvo commented Aug 26, 2024

@BenjaminBossan yes this is what I had in mind: please submit a pull-request with your changes.

BenjaminBossan added a commit to BenjaminBossan/optimum-quanto that referenced this issue Aug 26, 2024
Resolves huggingface#278

PyTorch allows loading state dicts with the strict=False argument to
ignore missing keys. This is now also supported in optimum-quanto.
Before this fix, a KeyError would be raised.

One context where this is important is for parameter-efficient
fine-tuning adapters such as LoRA. There, we want to load only a small
subset of parameters and leave the other model weights untouched. This
requires non-strict loading.