Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Add support for optimum-quanto #2000

Open
wants to merge 10 commits into
base: main
Choose a base branch
from

Conversation

BenjaminBossan
Copy link
Member

@BenjaminBossan BenjaminBossan commented Aug 9, 2024

This is unfinished, only pure implementations are provided.

Resolves #1997

TODOs:

  • Documentation
  • Tests (should work on CPU!)
  • Install optimum-quanto for CI (awaits quanto release that contains fix to persistence)
  • Verify that QuantoLoraConv2d works
  • Optional: DoRA support
  • Optional: Mixed adapter batches support
  • Cleaner implementation? For now, uses private attributes _data and _scales, overriding .data did not have any effect.

State of unit tests

Since quanto layers are subclasses of their respective torch equivalents, they will generally work with PEFT methods, even if not supported explicitly. E.g. BOFT will "just work" with quanto. However, some merging etc. won't work properly, as this requires special handling for quanto. Therefore, these tests are skipped.

It could be argued that we should explicitly raise when trying to use a non-supported method with quanto. However, we don't do that in general, as we assume that a subclass relationship should mean that the method works with that module. We could do strict checking of type (not subclass), but who knows how much existing code would break for no reason because of that.

Merging tests had to be relaxed, torch.allclose would require quite a high tolerance to pass. Therefore, instead now measure that correlation is > 0.97, which is more robust to outliers.

Moreover, a bunch of tests needed to be skipped, e.g. because quanto does not support deepcopy-ing, and either the PEFT functionaliy (layer replication) or the test itself depends on copying. Also, quanto does not allow to convert the dtype (like calling model.half()).

This is unfinished, only pure implementations are provided.

TODOs:

- [  ] Documentation
- [  ] Tests (should work on CPU!)
- [  ] Whether Conv2d works is not verified yet
- [  ] Optional: DoRA support
- [  ] Optional: Mixed adapter batches support
@BenjaminBossan
Copy link
Member Author

This is what I used for "testing" so far and the results look correct:

import torch
from peft import LoraConfig, set_peft_model_state_dict, get_peft_model
from optimum.quanto import quantize, freeze, qint8
from transformers import AutoModelForCausalLM

torch.manual_seed(0)
inputs = torch.arange(5).view(-1, 1)
print("loading model")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m").eval()
with torch.inference_mode():
    output_base = model(inputs).logits
print("output_base")
print(output_base[0, 0, :5])

# Step 3: Quantize the Model
print("quantizing model")
quantize(model, weights=qint8)
print("freezing model")
freeze(model)

with torch.inference_mode():
    output_quantized = model(inputs).logits
print("output_quantized")
print(output_quantized[0, 0, :5])

config = LoraConfig(r=64, lora_alpha=1280, lora_dropout=0.1, init_lora_weights=False)
print("adding adapter (random)")
model = get_peft_model(model, config)
model.eval()

with torch.inference_mode():
    output_lora = model(inputs).logits
    print("output_lora")
    print(output_lora[0, 0, :5])

    with model.disable_adapter():
        output_disabled = model(inputs).logits
        print("output_disabled")
        print(output_disabled[0, 0, :5])

    output_after_disabled = model(inputs).logits
    print("output_after_disabled")
    print(output_after_disabled[0, 0, :5])

model.merge_adapter()
with torch.inference_mode():
    output_merged = model(inputs).logits
print("output_merged")
print(output_merged[0, 0, :5])

model.unmerge_adapter()
with torch.inference_mode():
    output_unmerged = model(inputs).logits
print("output_unmerged")
print(output_unmerged[0, 0, :5])

unloaded = model.merge_and_unload()
with torch.inference_mode():
    output_unloaded = unloaded(inputs).logits
print("output_unloaded")
print(output_unloaded[0, 0, :5])

If someone wants to test this, they can checkout this branch or they can copy-paste the layer definitions and then dynamically dispatch to the new layers using the normal PEFT release:

from optimum.quanto import QConv2d, QLinear

# copy code for QuantoLoraLinear and QuantoLoraConv2d

custom_module_mapping = {QConv2d: QuantoLoraConv2d, QLinear: QuantoLoraLinear}
config = LoraConfig(...)
config._register_custom_module(custom_module_mapping)

@bghira
Copy link

bghira commented Aug 13, 2024

2024-08-12 19:29:58,243 [INFO] (SaveHookManager) Loading LoRA weights from Path: /Users/bghira/Training/flux/models/checkpoint-10
'time_text_embed.timestep_embedder.linear_1.weight._data'
Traceback (most recent call last):
  File "/Users/bghira/src/SimpleTuner/train.py", line 2761, in <module>
    main()
  File "/Users/bghira/src/SimpleTuner/train.py", line 1566, in main
    accelerator.load_state(os.path.join(args.output_dir, path))
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 3098, in load_state
    hook(models, input_dir)
  File "/Users/bghira/src/SimpleTuner/helpers/training/save_hooks.py", line 416, in load_model_hook
    self._load_lora(models=models, input_dir=input_dir)
  File "/Users/bghira/src/SimpleTuner/helpers/training/save_hooks.py", line 335, in _load_lora
    incompatible_keys = set_peft_model_state_dict(
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/peft/utils/save_and_load.py", line 397, in set_peft_model_state_dict
    load_result = model.load_state_dict(peft_model_state_dict, strict=False)
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2201, in load_state_dict
    load(self, state_dict)
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2189, in load
    load(child, child_state_dict, child_prefix)  # noqa: F821
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2189, in load
    load(child, child_state_dict, child_prefix)  # noqa: F821
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2189, in load
    load(child, child_state_dict, child_prefix)  # noqa: F821
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2183, in load
    module._load_from_state_dict(
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/optimum/quanto/nn/qmodule.py", line 159, in _load_from_state_dict
    deserialized_weight = QBytesTensor.load_from_state_dict(
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/optimum/quanto/tensor/qbytes.py", line 90, in load_from_state_dict
    inner_tensors_dict[name] = state_dict.pop(prefix + name)
KeyError: 'time_text_embed.timestep_embedder.linear_1.weight._data'

i pulled the new peft build from your branch and applied the mapping to the LoraConfig, but I still see this error when it comes time to loading the state dict. I think the problem is on the Diffusers side here. the set_peft_model_state_dict

@BenjaminBossan
Copy link
Member Author

Thanks for reporting this @bghira. I think it's an issue with optimum-quanto. I already reported this here.

@BenjaminBossan BenjaminBossan changed the title [WIP][FEAT] Add support for optimum-quanto [FEAT] Add support for optimum-quanto Sep 17, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Support optimum-quanto
2 participants