Llama3 qlora load_state_dict takes forever #1246

Closed
l-berg opened this issue Jul 30, 2024 · 4 comments

l-berg commented Jul 30, 2024

While working on my customized LoRAFinetuneRecipeSingleDevice recipe, I upgraded from torchtune 0.1.1 to 0.2.1 and torchao 0.1 to 0.3.1, and noticed that model loading times went up dramatically when using QLoRA. Loading Llama3-8B now takes about 5 minutes, whereas it used to take only a few seconds with version 0.1.1. I was able to pinpoint the slowdown to the call model.load_state_dict(base_model_state_dict, strict=False).
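
For reference, a minimal timing sketch of how the call can be isolated (hypothetical: model and base_model_state_dict are assumed to be set up exactly as in the recipe):

import time

# Time only the state dict load, with model and base_model_state_dict
# already constructed as in LoRAFinetuneRecipeSingleDevice.
t0 = time.perf_counter()
model.load_state_dict(base_model_state_dict, strict=False)
print(f"load_state_dict took {time.perf_counter() - t0:.1f}s")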

Here are the steps I took to reproduce the issue in a fresh conda environment (PyTorch 2.4):

conda create --name tune python=3.11
conda activate tune
conda install pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia
pip install torchtune
tune download meta-llama/Meta-Llama-3-8B-Instruct --output-dir . --hf-token <token>
tune run lora_finetune_single_device --config llama3/8B_qlora_single_device checkpointer.checkpoint_dir=$CHECKPOINT_DIR tokenizer.path=$CHECKPOINT_DIR/tokenizer.model checkpointer.output_dir=$CHECKPOINT_DIR output_dir=$CHECKPOINT_DIR

outputs

...
DEBUG:torchtune.utils.logging:Setting manual seed to local seed 4061266822. Local seed is seed + rank = 4061266822 + 0
Writing logs to /.../testtune/instruct/original/log_1722353553.txt
<WAITING FOR 5 MINUTES, RAM usage slowly rising to 15GB>
INFO:torchtune.utils.logging:Model is initialized with precision torch.bfloat16.
INFO:torchtune.utils.logging:Memory stats after model init:
        GPU peak memory allocation: 6.79 GB
        GPU peak memory reserved: 6.97 GB
        GPU peak memory active: 6.79 GB
...

When I downgrade to version 0.1.1, it's fast again:

pip install torchtune==0.1.1
tune run lora_finetune_single_device --config llama3/8B_qlora_single_device checkpointer.checkpoint_dir=$CHECKPOINT_DIR tokenizer.path=$CHECKPOINT_DIR/tokenizer.model checkpointer.output_dir=$CHECKPOINT_DIR output_dir=$CHECKPOINT_DIR

outputs

DEBUG:torchtune.utils.logging:Setting manual seed to local seed 1435681341. Local seed is seed + rank = 1435681341 + 0
Writing logs to /.../testtune/instruct/original/log_1722354102.txt
INFO:torchtune.utils.logging:Model is initialized with precision torch.bfloat16.
INFO:torchtune.utils.logging:Memory Stats after model init:
{'peak_memory_active': 10.129489408, 'peak_memory_alloc': 10.129489408, 'peak_memory_reserved': 12.639535104}

after only a few seconds.

I am on a Slurm machine with 4 CPU threads and an NVIDIA RTX 3090 (24 GB).

Any ideas what might be causing this? The LoRA recipes without quantization work just fine.

@joecummings joecummings self-assigned this Aug 1, 2024
gau-nernst (Contributor) commented

Not a torchtune author/contributor, but from the memory usage, I'm guessing that the old version performs NF4 quantization on GPU, while the new version performs it on CPU.
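
A rough way to sanity-check this, assuming torchao exposes to_nf4 under torchao.dtypes (the weight shape and timing harness below are just illustrative, not what torchtune actually does):

import time
import torch
from torchao.dtypes import to_nf4  # assumed import path; adjust if it differs

# One Llama3-sized linear weight in bf16
w = torch.randn(4096, 4096, dtype=torch.bfloat16)

t0 = time.perf_counter()
to_nf4(w)                      # NF4 quantization on CPU
cpu_s = time.perf_counter() - t0

w_gpu = w.cuda()
torch.cuda.synchronize()
t0 = time.perf_counter()
to_nf4(w_gpu)                  # NF4 quantization on GPU
torch.cuda.synchronize()
gpu_s = time.perf_counter() - t0

print(f"CPU: {cpu_s:.2f}s  GPU: {gpu_s:.2f}s")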

joecummings (Contributor) commented

> Not a torchtune author/contributor, but from the memory usage, I'm guessing that the old version performs NF4 quantization on GPU, while the new version performs it on CPU.

Makes sense; this was suggested by @msaroufim too. I will confirm.

@l-berg Apologies for the late response - do you notice the slowdown on 0.2.0 as well? This will help me narrow down where these changes could be coming from.

l-berg commented Aug 8, 2024

Yes, upgrading from 0.1.1 to 0.2.0 results in the same increase from ~10s to >5min loading time on my machine.

joecummings (Contributor) commented

Hi @l-berg - thanks for bringing this to our attention! The AO folks dug into this and found that a version-guarded inplace_copy function was the culprit. You can read more about it here: pytorch/ao#642.

This will be fixed by @ebsmothers in #1294
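
For anyone unfamiliar with the term, here is a generic illustration of a version-guarded function; this is not torchao's actual implementation (see pytorch/ao#642 for that), only a sketch of how a stale guard can silently route every copy through a slow fallback:

import torch
from packaging import version

# Illustration only, not torchao code: which branch runs depends on the
# installed torch version, so an outdated guard can send every
# load_state_dict copy down the slow path.
_TORCH_GE_2_4 = version.parse(torch.__version__).release >= (2, 4)

def inplace_copy(dst: torch.Tensor, src: torch.Tensor) -> None:
    if _TORCH_GE_2_4:
        dst.copy_(src)                       # intended fast path
    else:
        dst.copy_(src.cpu().to(dst.device))  # fallback: round trip through host memory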
