
DoRA uses lots of GPU VRAM due to fp32 upcasting #1692

@rationalism

Description


System Info

peft 0.10.0, transformers 4.40.1, Python 3.10 on Ubuntu 22.04

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

Doing language model fine-tuning using QLoRA with DoRA, e.g. fine-tuning Meta-Llama-3-70B with https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/supervised_finetuning.py, with target_modules set to include all linear layers, uses much more GPU VRAM than training with ordinary LoRA.
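
For reference, a minimal sketch of the kind of adapter config involved (the hyperparameter values here are placeholders, not the exact settings from the script):

```python
from peft import LoraConfig

# Sketch only: DoRA enabled on top of LoRA, with adapters on every linear layer.
# r / lora_alpha / lora_dropout are placeholder values, not the script's settings.
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules="all-linear",  # apply adapters to all linear layers
    use_dora=True,                # enable the DoRA decomposition
    task_type="CAUSAL_LM",
)
```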

Expected behavior

Fine-tuning a QLoRA language model using DoRA, with adapters applied to all linear layers, uses much more GPU VRAM than ordinary LoRA and OOMs my machine. I think the issue is this line:

mag_norm_scale = (magnitude / weight_norm).view(1, -1)

It looks like magnitude is in fp32, so the input vector x is upcast to fp32 when it gets returned as result_dora. If both MLP and attention layers are added to target_modules, that fp32 output then causes the next DoRA module (in the MLP layer) to receive an fp32 vector as input. This in turn causes the dequantized weight matrix to be upcast to fp32:

weight = weight.to(x.dtype)

which means the algebra in _get_weight_norm is done in fp32:

def _get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:

which OOMs my machine. Adding a cast back to x.dtype here:

return result_dora

fixes the problem. (I also wrote a custom Triton kernel for _get_weight_norm(), but that's probably not necessary for most purposes.)
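
To illustrate the promotion chain, here is a standalone PyTorch sketch (not peft's actual forward code), assuming a bf16 activation and an fp32 magnitude, followed by the proposed cast back to x.dtype:

```python
import torch

# Standalone illustration: an fp32 magnitude / weight_norm scale multiplied into a
# bf16 activation promotes the result to fp32, so the next DoRA module sees an fp32
# input and ends up dequantizing its weight in fp32 as well.
x = torch.randn(4, 1024, dtype=torch.bfloat16)        # activation in bf16
magnitude = torch.randn(1024, dtype=torch.float32)    # DoRA magnitude kept in fp32
weight_norm = torch.rand(1024, dtype=torch.float32) + 1.0

mag_norm_scale = (magnitude / weight_norm).view(1, -1)
result_dora = mag_norm_scale * x
print(result_dora.dtype)  # torch.float32 -- upcast by type promotion

# Proposed fix: cast the DoRA output back to the activation dtype before returning,
# so downstream layers keep computing in bf16.
result_dora = result_dora.to(x.dtype)
print(result_dora.dtype)  # torch.bfloat16
```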
