enable LoRA + FSDP2 #855
@@ -0,0 +1,86 @@
# Config for multi-device LoRA in lora_finetune_distributed.py
# using a Llama2 13B model
Review comment: Can update this header to mention that this config is for the recipe using FSDP2 (I know the config file is the same, but nice visibility to just explicitly call it out at the top of the file)

Reply: good catch! updated to `lora_finetune_fsdp2` and mentioned FSDP2
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Llama-2-13b-hf --output-dir /tmp/Llama-2-13b-hf --hf-token <HF_TOKEN>
#
# To launch on 4 devices, run the following command from root:
# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config llama2/13B_lora
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config llama2/13B_lora checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
Review comment: Probably need to do a find and replace of
#
# This config works best when the model is being fine-tuned on 2+ GPUs.
# For single device LoRA finetuning please use 7B_lora_single_device.yaml
# or 7B_qlora_single_device.yaml and update the model and checkpoints to
# the 13B model.


# Model Arguments
model:
  _component_: torchtune.models.llama2.lora_llama2_13b
  lora_attn_modules: ['q_proj', 'v_proj', 'k_proj']
  apply_lora_to_mlp: True
  apply_lora_to_output: True
  lora_rank: 8
  lora_alpha: 16

checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-2-13b-hf/
  checkpoint_files: [
    pytorch_model-00001-of-00003.bin,
    pytorch_model-00002-of-00003.bin,
    pytorch_model-00003-of-00003.bin
  ]
  adapter_checkpoint: null
  recipe_checkpoint: null
  output_dir: /tmp/Llama-2-13b-hf/
  model_type: LLAMA2
resume_from_checkpoint: False

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama2.llama2_tokenizer
  path: /tmp/Llama-2-13b-hf/tokenizer.model

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  train_on_input: True
seed: null
shuffle: True
batch_size: 2

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  weight_decay: 0.01
  lr: 2e-4
lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torch.nn.CrossEntropyLoss

# Training
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 16

# Logging
output_dir: /tmp/lora_finetune_output
metric_logger:
  _component_: torchtune.utils.metric_logging.DiskLogger
  log_dir: ${output_dir}
log_every_n_steps: 1
log_peak_memory_stats: False

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: False
@@ -0,0 +1,86 @@
# Config for multi-device LoRA in lora_finetune_distributed.py
# using a Llama2 70B model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Llama-2-70b-hf --output-dir /tmp/Llama-2-70b-hf --hf-token <HF_TOKEN>
#
# This config needs 8 GPUs to run:
# tune run --nproc_per_node 8 lora_finetune_distributed --config llama2/70B_lora
#

# Model Arguments
model:
  _component_: torchtune.models.llama2.lora_llama2_70b
  lora_attn_modules: ['q_proj', 'v_proj', 'k_proj']
  apply_lora_to_mlp: False
  apply_lora_to_output: False
  lora_rank: 16
  lora_alpha: 32

tokenizer:
  _component_: torchtune.models.llama2.llama2_tokenizer
  path: /tmp/Llama-2-70b-hf/tokenizer.model

checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-2-70b-hf
  checkpoint_files: [
    pytorch_model-00001-of-00015.bin,
    pytorch_model-00002-of-00015.bin,
    pytorch_model-00003-of-00015.bin,
    pytorch_model-00004-of-00015.bin,
    pytorch_model-00005-of-00015.bin,
    pytorch_model-00006-of-00015.bin,
    pytorch_model-00007-of-00015.bin,
    pytorch_model-00008-of-00015.bin,
    pytorch_model-00009-of-00015.bin,
    pytorch_model-00010-of-00015.bin,
    pytorch_model-00011-of-00015.bin,
    pytorch_model-00012-of-00015.bin,
    pytorch_model-00013-of-00015.bin,
    pytorch_model-00014-of-00015.bin,
    pytorch_model-00015-of-00015.bin,
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Llama-2-70b-hf
  model_type: LLAMA2
resume_from_checkpoint: False

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  train_on_input: True
seed: null
shuffle: True
batch_size: 2

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  weight_decay: 0.01
  lr: 3e-4
lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torch.nn.CrossEntropyLoss

# Training
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 1

# Logging
output_dir: /tmp/lora_finetune_output
metric_logger:
  _component_: torchtune.utils.metric_logging.DiskLogger
  log_dir: ${output_dir}
log_every_n_steps: 1
log_peak_memory_stats: False

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: True
@@ -0,0 +1,83 @@
# Config for multi-device LoRA finetuning in lora_finetune_distributed.py
# using a Llama2 7B model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Llama-2-7b-hf --output-dir /tmp/Llama-2-7b-hf --hf-token <HF_TOKEN>
#
# To launch on 2 devices, run the following command from root:
# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama2/7B_lora
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama2/7B_lora checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works best when the model is being fine-tuned on 2+ GPUs.
# For single device LoRA finetuning please use 7B_lora_single_device.yaml
# or 7B_qlora_single_device.yaml


# Model Arguments
model:
  _component_: torchtune.models.llama2.lora_llama2_7b
  lora_attn_modules: ['q_proj', 'v_proj']
  apply_lora_to_mlp: False
  apply_lora_to_output: False
  lora_rank: 8
  lora_alpha: 16

tokenizer:
  _component_: torchtune.models.llama2.llama2_tokenizer
  path: /tmp/Llama-2-7b-hf/tokenizer.model

checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-2-7b-hf
  checkpoint_files: [
    pytorch_model-00001-of-00002.bin,
    pytorch_model-00002-of-00002.bin
  ]
  adapter_checkpoint: null
  recipe_checkpoint: null
  output_dir: /tmp/Llama-2-7b-hf
  model_type: LLAMA2
resume_from_checkpoint: False

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  train_on_input: True
seed: null
shuffle: True
batch_size: 2

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  weight_decay: 0.01
  lr: 3e-4
lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torch.nn.CrossEntropyLoss

# Training
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 32

# Logging
output_dir: /tmp/lora_finetune_output
metric_logger:
  _component_: torchtune.utils.metric_logging.DiskLogger
  log_dir: ${output_dir}
log_every_n_steps: 1
log_peak_memory_stats: False

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: False
Review comment: `from torch.testing._internal.common_utils import run_tests` has a dependency on `pytest==7.4.0` and `expecttest`, borrowed from the pytorch repo.

Reply: Is `run_tests` strictly required for the usage of `FSDPTest`, or is it more used for convenience? (Either way not a huge issue)

Reply: it's strictly required for the usage of `FSDPTest`.
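For context, here is a minimal sketch of how `FSDPTest` and `run_tests` fit together. This is illustrative only, not the PR's actual test file; the class name, `world_size`, and test body below are hypothetical.

```python
# Minimal sketch of an FSDPTest-based multi-process test.
# Assumption: class and test names are hypothetical, not from this PR.
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests


class TestLoRAFSDP2(FSDPTest):
    @property
    def world_size(self) -> int:
        # Number of ranks the harness spawns for each test method.
        return 2

    def test_placeholder(self):
        # Each rank runs this body inside an initialized process group.
        self.assertTrue(True)


if __name__ == "__main__":
    # run_tests() discovers and launches the FSDPTest cases, which is why it
    # (and therefore its pytest/expecttest dependencies) is required.
    run_tests()
```

This also illustrates why simply importing `FSDPTest` is not enough: the test entry point goes through `run_tests()`, so the extra test dependencies come along with it.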