Description
System Info
peft 0.10.0, transformers 4.40.1, Python 3.10 on Ubuntu 22.04
Who can help?
No response
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder
- My own task or dataset (give details below)
Reproduction
Fine-tuning a language model with QLoRA plus DoRA, e.g. fine-tuning Meta-Llama-3-70B with https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/supervised_finetuning.py and target_modules set to include all linear layers, uses much more GPU VRAM than training with ordinary LoRA.
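For reference, a minimal sketch of the kind of setup that triggers this for me (the model id, rank/alpha values, and module list below are illustrative placeholders, not the exact arguments used by the script above):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

# 4-bit NF4 quantization with bf16 compute, as in a typical QLoRA run
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-70B",  # placeholder model id
    quantization_config=bnb_config,
    device_map="auto",
)

# DoRA adapters on all linear layers (attention + MLP)
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    use_dora=True,
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
```

With `use_dora=False` and everything else unchanged, memory usage stays in the expected LoRA range.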
Expected behavior
Fine-tuning a QLoRA language model with DoRA, with adapters applied to all linear layers, uses much more GPU VRAM than ordinary LoRA and OOMs my machine. I think the issue is this line:
`peft/src/peft/tuners/lora/layer.py`, line 238 at commit 608a90d:

```python
mag_norm_scale = (magnitude / weight_norm).view(1, -1)
```
It looks like `magnitude` is in fp32, so the input vector `x` is upcast to fp32 when it is returned as `result_dora`. If both MLP and attention layers are included in `target_modules`, that fp32 output then causes the next DoRA module (in the MLP layer) to receive an fp32 input vector, which in turn causes the dequantized weight matrix to be upcast to fp32:
`peft/src/peft/tuners/lora/layer.py`, line 229 at commit 608a90d:

```python
weight = weight.to(x.dtype)
```
which means the algebra in `_get_weight_norm` is done in fp32:

`peft/src/peft/tuners/lora/layer.py`, line 176 at commit 608a90d:

```python
def _get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
```
which OOMs my machine. Adding a cast back to `x.dtype` here:

`peft/src/peft/tuners/lora/layer.py`, line 253 at commit 608a90d:

```python
return result_dora
```
fixes the problem. (I also wrote a custom Triton kernel for `_get_weight_norm()`, but that's probably not necessary for most purposes.)
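For completeness, the dtype promotion that drives the upcast can be reproduced in isolation (a minimal check, nothing PEFT-specific):

```python
import torch

# A bf16 activation combined with an fp32 parameter is promoted to fp32,
# which is what pushes result_dora (and every downstream activation) to fp32.
x = torch.randn(8, dtype=torch.bfloat16)
magnitude = torch.randn(8, dtype=torch.float32)
print((magnitude * x).dtype)  # torch.float32
```

and the local workaround is essentially just a cast at the `return result_dora` shown above (a sketch of what I changed in my checkout, not a proposed final patch):

```python
# Cast the DoRA output back to the incoming activation dtype so the next
# adapted layer does not receive an fp32 input.
return result_dora.to(x.dtype)
```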