Add support for Qwen2-0.5B and Qwen2-1.5B. #1247
Conversation
# Conflicts:
#    torchtune/utils/_checkpointing/_checkpointer.py
Fix some bugs in weight converters.
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1247
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 856555b with merge base 9fd5d01.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@@ -0,0 +1,75 @@
# Config for multi-device full finetuning in full_finetune_distributed.py
# using a Qwen2 0.5B model
#
I feel a bit like we need to rethink `tune ls`. Here's the output just now:
RECIPE                         CONFIG
full_finetune_single_device    llama2/7B_full_low_memory
                               code_llama2/7B_full_low_memory
                               llama3/8B_full_single_device
                               llama3_1/8B_full_single_device
                               mistral/7B_full_low_memory
                               phi3/mini_full_low_memory
                               qwen2/7B_full_low_memory
full_finetune_distributed      llama2/7B_full
                               llama2/13B_full
                               llama3/8B_full
                               llama3_1/8B_full
                               llama3/70B_full
                               llama3_1/70B_full
                               mistral/7B_full
                               gemma/2B_full
                               gemma/7B_full
                               phi3/mini_full
                               qwen2/7B_full
lora_finetune_single_device    llama2/7B_lora_single_device
                               llama2/7B_qlora_single_device
                               code_llama2/7B_lora_single_device
                               code_llama2/7B_qlora_single_device
                               llama3/8B_lora_single_device
                               llama3_1/8B_lora_single_device
                               llama3/8B_qlora_single_device
                               llama3_1/8B_qlora_single_device
                               llama2/13B_qlora_single_device
                               mistral/7B_lora_single_device
                               mistral/7B_qlora_single_device
                               gemma/2B_lora_single_device
                               gemma/2B_qlora_single_device
                               gemma/7B_lora_single_device
                               gemma/7B_qlora_single_device
                               phi3/mini_lora_single_device
                               phi3/mini_qlora_single_device
                               qwen2/7B_lora_single_device
lora_dpo_single_device         llama2/7B_lora_dpo_single_device
lora_dpo_distributed           llama2/7B_lora_dpo
lora_finetune_distributed      llama2/7B_lora
                               llama2/13B_lora
                               llama2/70B_lora
                               llama3/70B_lora
                               llama3_1/70B_lora
                               llama3/8B_lora
                               llama3_1/8B_lora
                               mistral/7B_lora
                               gemma/2B_lora
                               gemma/7B_lora
                               phi3/mini_lora
                               qwen2/7B_lora
lora_finetune_fsdp2            llama2/7B_lora
                               llama2/13B_lora
                               llama2/70B_lora
                               llama2/7B_qlora
                               llama2/70B_qlora
generate                       generation
eleuther_eval                  eleuther_evaluation
quantize                       quantization
qat_distributed                llama2/7B_qat_full
                               llama3/8B_qat_full
Imo this isn't scaling well as we add more and more cool models and support configs for new techniques (imagine how many we'll have for multimodal!). It's getting unwieldy.
Not sure if this is already on your radar @joecummings @ebsmothers, but I have an idea or two (one radical, one not-so-radical) to address it. Happy to put an RFC up if there's consensus?
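One illustrative possibility (a sketch only; the `RECIPE_CONFIGS` mapping below is made up and is not torchtune's real recipe registry, nor necessarily the idea hinted at above) is to group the listing by model family instead of printing one long recipe-by-recipe table:

```python
from collections import defaultdict

# Made-up stand-in for the recipe -> configs registry; real data would come from torchtune.
RECIPE_CONFIGS = {
    "full_finetune_single_device": ["llama2/7B_full_low_memory", "qwen2/7B_full_low_memory"],
    "lora_finetune_single_device": ["llama2/7B_lora_single_device", "qwen2/7B_lora_single_device"],
}

def ls_grouped_by_model() -> None:
    # Invert the mapping so each model family lists its available recipes/configs.
    by_model = defaultdict(list)
    for recipe, configs in RECIPE_CONFIGS.items():
        for cfg in configs:
            by_model[cfg.split("/", 1)[0]].append(f"{cfg}  ({recipe})")
    for model in sorted(by_model):
        print(model)
        for entry in by_model[model]:
            print(f"  {entry}")

ls_grouped_by_model()
```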
Yeah this is a good point. I know @joecummings had some ideas on this so will defer to him, but I think a quick RFC on how to scale `tune ls` better would definitely be helpful.
Let's chat - this is definitely on my radar. Glad you caught it, too.
@@ -68,7 +68,7 @@ def qwen2_hf_to_tune(
     for key, value in state_dict.items():
         if (
-            tie_word_embeddings and QWEN2_TIED_KEY not in key
+            tie_word_embeddings and QWEN2_TIED_KEY in key
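For context, a minimal sketch of what the corrected condition does when converting a Hugging Face checkpoint with tied embeddings. This is not torchtune's actual converter, and the value of `QWEN2_TIED_KEY` below is assumed for illustration:

```python
# Hypothetical sketch: with the fix, the duplicated tied key is skipped,
# instead of (as with the buggy `not in`) skipping every key *except* it.
QWEN2_TIED_KEY = "lm_head.weight"  # assumed value for illustration

def hf_to_tune_sketch(state_dict: dict, tie_word_embeddings: bool = False) -> dict:
    converted = {}
    for key, value in state_dict.items():
        if tie_word_embeddings and QWEN2_TIED_KEY in key:
            # The output projection shares its weight with the token embedding,
            # so the redundant copy is dropped here.
            continue
        converted[key] = value
    return converted

# With tied embeddings, only the embedding copy of the shared weight survives.
sd = {"model.embed_tokens.weight": "W", "lm_head.weight": "W", "model.norm.weight": "g"}
print(list(hf_to_tune_sketch(sd, tie_word_embeddings=True)))
# ['model.embed_tokens.weight', 'model.norm.weight']
```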
Good point, happy to have this filed as a refactor cleanup to start grouping some of these together in a "TIED_MODEL" key or something.
Sorry if I'm missing something, but do we need a refactor here? I thought we ended up not needing any special checkpointing logic for Gemma, so it's just a matter of applying the same changes as in https://github.com/pytorch/torchtune/pull/1168/files: removing `qwen2/_convert_weights.py` and removing any special logic in the checkpointers (here, for example), so it goes through the default model save/load path.
Happy to do this in a follow-up, anyway, so this doesn't get blocked?
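A rough sketch of the direction being discussed, assuming a registry-style dispatch (the names and structure below are illustrative, not torchtune's actual checkpointer API): models with no special conversion logic simply fall through to the default save/load path.

```python
from typing import Callable, Dict

# Hypothetical registry of model-specific weight converters. Once a model needs
# no special handling (as discussed above for qwen2), its entry is removed and
# the shared default path is used instead.
CONVERTERS: Dict[str, Callable[[dict], dict]] = {
    # "qwen2": qwen2_hf_to_tune,  # removed once the default path is sufficient
}

def load_hf_checkpoint(model_type: str, hf_state_dict: dict) -> dict:
    convert = CONVERTERS.get(model_type)
    if convert is not None:
        return convert(hf_state_dict)
    # Default path: no model-specific remapping required.
    return dict(hf_state_dict)

print(load_hf_checkpoint("qwen2", {"model.norm.weight": "g"}))
```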
Thanks for the PR! A few small comments here, but this looks pretty good to me. Also, the wandb link in the PR summary is broken for me; I would definitely be interested to see the comparison with HF loss curves if that's possible.
Hi, I have updated the loss curves comparing torchtune and HF in the description.
Thanks! Also bumping this comment @fyabc. Mainly I don't think we should need stuff like
Thank you for your suggestions! I will update the related 0.5B and 1.5B recipes.
Hi @fyabc! Anything we can do to help finish this up? We'd love to feature this on our README and post about it on our Discord channel, as well.
Hi, I have updated this PR to resolve the review comments. I think the PR is ready to merge now; feel free to leave more suggestions.
Awesome work as always :)
Context
This PR adds support for Qwen2-0.5B and Qwen2-1.5B, and adds `TiedEmbeddingTransformerDecoder` to align with the default `TransformerDecoder`. Loss curves of Qwen2-0.5B and 1.5B full-parameter finetuning on the alpaca cleaned dataset have been uploaded to wandb (loss comparison curves are shown below).
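As background on the tied-embedding idea, here is a toy sketch of weight tying, assuming only the general technique rather than the actual `TiedEmbeddingTransformerDecoder` implementation: the output projection reuses the token-embedding matrix instead of a separate `lm_head`.

```python
import torch
import torch.nn as nn

# Toy illustration of tied word embeddings (an assumption about the general idea,
# not torchtune's actual implementation).
class TinyTiedDecoder(nn.Module):
    def __init__(self, vocab_size: int, dim: int):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.norm = nn.LayerNorm(dim)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        h = self.norm(self.tok_embeddings(tokens))
        # Tied output projection: logits = h @ embedding_weight.T, no separate lm_head.
        return nn.functional.linear(h, self.tok_embeddings.weight)

logits = TinyTiedDecoder(vocab_size=1000, dim=64)(torch.zeros(1, 4, dtype=torch.long))
print(logits.shape)  # torch.Size([1, 4, 1000])
```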
What is the purpose of this PR? Is it to add a new feature?
Please link to any issues this PR addresses.
Changelog
What are the changes made in this PR?
Test plan
Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help.)
- run pre-commit hooks and linters (make sure you've first installed via `pre-commit install`)
- run unit tests via `pytest tests`
- run recipe tests via `pytest tests -m integration_test`
Loss Curves
All loss curves are smoothed with `smooth_ratio=0.6`.
wandb link: https://api.wandb.ai/links/fyabc-123/jxbvzezh
wandb link: https://api.wandb.ai/links/fyabc-123/w2lqv1v6
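Assuming `smooth_ratio` refers to wandb-style exponential moving average smoothing (an assumption, not stated in this PR), a ratio of 0.6 mixes the running smoothed value with each new point like this:

```python
# Minimal sketch of EMA smoothing with the assumed semantics of smooth_ratio.
def smooth(values, ratio=0.6):
    smoothed, last = [], None
    for v in values:
        # Keep `ratio` of the history, blend in (1 - ratio) of the new point.
        last = v if last is None else ratio * last + (1 - ratio) * v
        smoothed.append(last)
    return smoothed

print(smooth([2.0, 1.5, 1.2, 1.0, 0.9]))
# approximately [2.0, 1.8, 1.56, 1.336, 1.1616]
```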