Add KD distributed recipe #1631
New file: llama3_2/knowledge_distillation_distributed.yaml (+130 lines)
# Config for multi-device knowledge distillation in knowledge_distillation_distributed.py
# using a teacher and student model
#
# This config assumes that you've run the following commands before launching KD:
# First download the student and teacher models
# tune download meta-llama/Llama-3.2-1B-Instruct --output-dir /tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth"
# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth"
#
# You get better results using KD if the teacher model has already been fine-tuned on the target dataset:
# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama3_1/8B_lora
#
# To launch on 2 devices, run the following command from root:
# tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed --config llama3_2/knowledge_distillation_distributed
#
# This config works best for distilling on 2+ devices.


# Model Arguments
model:
  _component_: torchtune.models.llama3_2.lora_llama3_2_1b
  lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
  apply_lora_to_mlp: True
  apply_lora_to_output: False
  lora_rank: 64
  lora_alpha: 128
  lora_dropout: 0.0

teacher_model:
  _component_: torchtune.models.llama3_1.llama3_1_8b

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Llama-3.2-1B-Instruct/original/tokenizer.model
  max_seq_len: null

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-3.2-1B-Instruct/
  checkpoint_files: [
    model.safetensors
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Llama-3.2-1B-Instruct/
  model_type: LLAMA3
resume_from_checkpoint: False
save_adapter_weights_only: False

# Teacher checkpoint
teacher_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
  checkpoint_files: [
    model-00001-of-00004.safetensors,
    model-00002-of-00004.safetensors,
    model-00003-of-00004.safetensors,
    model-00004-of-00004.safetensors
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
  model_type: LLAMA3

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
seed: null
shuffle: True
batch_size: 4

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  fused: True
  weight_decay: 0.01
  lr: 3e-4
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

kd_loss:
  _component_: torchtune.modules.loss.ForwardKLWithChunkedOutputLoss
kd_ratio: 0.5

# Training
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 32

# Logging
output_dir: /tmp/kd_output
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}
log_every_n_steps: 1
log_peak_memory_stats: False

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: False

# Showcase the usage of the PyTorch profiler
# Set enabled to False as it's only needed for debugging training
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # Trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 5
  active_steps: 2
  num_cycles: 1
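
For reference, the `loss`, `kd_loss`, and `kd_ratio` fields above control how the recipe balances the student's ordinary cross-entropy loss against the forward-KL distillation loss computed against the teacher. The following is a rough, self-contained sketch of that blend in plain PyTorch; it uses unchunked logits and no ignore-index masking, unlike the chunked loss classes the config actually points at, and the helper name is just illustrative.

```python
import torch
import torch.nn.functional as F

def blended_kd_loss(student_logits, teacher_logits, labels, kd_ratio=0.5):
    """Sketch of how `kd_ratio` weights the `loss` (CE) and `kd_loss` (forward KL) terms."""
    vocab = student_logits.size(-1)
    # Standard next-token cross-entropy against the ground-truth labels (the `loss` term).
    class_loss = F.cross_entropy(student_logits.reshape(-1, vocab), labels.reshape(-1))
    # Forward KL: teacher probabilities as the target, student log-probs as input (the `kd_loss` term).
    kd = F.kl_div(
        F.log_softmax(student_logits, dim=-1),
        F.softmax(teacher_logits, dim=-1),
        reduction="batchmean",
    )
    # kd_ratio weights the two terms; 0.5 weights them equally.
    return (1 - kd_ratio) * class_loss + kd_ratio * kd

# Toy shapes: batch of 2, sequence of 8, vocab of 128.
student = torch.randn(2, 8, 128)
teacher = torch.randn(2, 8, 128)
labels = torch.randint(0, 128, (2, 8))
print(blended_kd_loss(student, teacher, labels, kd_ratio=0.5))
```

With `kd_ratio: 0.5` the student is pulled equally toward the hard labels and toward the teacher's output distribution; raising the ratio shifts more weight onto the distillation term.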
New file: qwen2/knowledge_distillation_distributed.yaml (+123 lines)
# Config for multi-device knowledge distillation in knowledge_distillation_distributed.py
# using a teacher and student model
#
# This config assumes that you've run the following commands before launching KD:
# First download the student and teacher models
# tune download Qwen/Qwen2-0.5B-Instruct --output-dir /tmp/Qwen2-0.5B-Instruct --ignore-patterns None
# tune download Qwen/Qwen2-1.5B-Instruct --output-dir /tmp/Qwen2-1.5B-Instruct --ignore-patterns None
#
# You get better results using KD if the teacher model has already been fine-tuned on the target dataset:
# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config qwen2/1.5B_lora
#
# To launch on 2 devices, run the following command from root:
# tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed --config qwen2/knowledge_distillation_distributed
#
# This config works best for distilling on 2+ devices.


# Model Arguments
model:
  _component_: torchtune.models.qwen2.lora_qwen2_0_5b
  lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
  apply_lora_to_mlp: False
  lora_rank: 32
  lora_alpha: 64

teacher_model:
  _component_: torchtune.models.qwen2.qwen2_1_5b

tokenizer:
  _component_: torchtune.models.qwen2.qwen2_tokenizer
  path: /tmp/Qwen2-0.5B-Instruct/vocab.json
  merges_file: /tmp/Qwen2-0.5B-Instruct/merges.txt
  max_seq_len: null

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Qwen2-0.5B-Instruct
  checkpoint_files: [
    model.safetensors
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Qwen2-0.5B-Instruct-kd
  model_type: QWEN2

teacher_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Qwen2-1.5B-Instruct-lora-finetune
  checkpoint_files: [
    hf_model_0001_0.pt

> Review thread on `hf_model_0001_0.pt`:
> Why is the default […]? If so, the first example (Llama 3.2) is wrong, because the default files are safetensors, which are only saved if the checkpointer specifies […].
>
> Reply: I used the default LoRA distributed finetune configs for qwen2 and llama3.1 8b. I'm not sure why qwen2/1.5B_lora outputs […].

  ]
  recipe_checkpoint: null
  output_dir: /tmp/Qwen2-1.5B-Instruct-lora-finetune
  model_type: QWEN2

resume_from_checkpoint: False

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
seed: null
shuffle: True
batch_size: 8

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  weight_decay: 0.01
  lr: 3e-4
lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

kd_loss:
  _component_: torchtune.modules.loss.ForwardKLWithChunkedOutputLoss
kd_ratio: 0.5

# Training
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 2

# Logging
output_dir: /tmp/qwen_kd
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}
log_every_n_steps: 1
log_peak_memory_stats: False

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: False

# Showcase the usage of the PyTorch profiler
# Set enabled to False as it's only needed for debugging training
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # Trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 5
  active_steps: 2
  num_cycles: 1
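
Both configs end with the same profiler section, and the comment above spells out how its fields map onto `torch.profiler.schedule`. As a rough illustration of what `torchtune.training.setup_torch_profiler` assembles from those fields (the loop and output path below are illustrative stand-ins, not torchtune's implementation):

```python
import torch
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler

# wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
prof_schedule = schedule(wait=5, warmup=5, active=2, repeat=1)

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],  # cpu: True, cuda: True
    schedule=prof_schedule,
    profile_memory=False,   # profile_memory: False
    with_stack=False,       # with_stack: False
    record_shapes=True,     # record_shapes: True
    with_flops=False,       # with_flops: False
    # output_dir: ${output_dir}/profiling_outputs
    on_trace_ready=tensorboard_trace_handler("/tmp/qwen_kd/profiling_outputs"),
) as prof:
    for step in range(14):  # enough steps to cover one wait + warmup + active cycle
        x = torch.randn(64, 64) @ torch.randn(64, 64)  # stand-in for a training step
        prof.step()
```

With `enabled: False` none of this runs during normal training; flipping it on produces traces only for the two `active_steps` of each cycle, which keeps the overhead bounded while debugging.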
Review comment: Do you think it's also worthwhile to add a config with a 70B model? (It doesn't necessarily have to be in this PR, but it'd be useful to have at least one config that strictly requires distributed training.)
Review comment: As a follow-up to that, I wonder if we should include model sizes in the config names. I know it makes the names a bit longer (and doesn't line up with what you did for the single-device configs), but otherwise we can't really distinguish between configs for distilling 70B -> 1B vs. 8B -> 1B. Similar to the other comment here, this is fine to save for a follow-up, though.
Author reply: I tested this config on the 70B model and verified that it works, but I think it needs more tuning. We can add the 70B model in a separate PR and figure out how to change the naming. There weren't many changes needed to add the 70B model, just the model target and checkpointer, since the tokenizer has to be the same right now.
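
The note above about only swapping the model target and checkpointer holds because the recipe builds each object from the `_component_` entries in the config. A rough sketch of that pattern using torchtune's config utilities follows; the config path is an assumption about the repo layout, and loading it by hand like this (rather than through `tune run`) is purely for demonstration.

```python
from omegaconf import OmegaConf
from torchtune import config

# Load one of the new KD configs; path assumes torchtune's recipes/configs layout.
cfg = OmegaConf.load("recipes/configs/qwen2/knowledge_distillation_distributed.yaml")

# Each `_component_` entry names a builder that instantiate() calls with the sibling fields.
tokenizer = config.instantiate(cfg.tokenizer)    # qwen2_tokenizer(path=..., merges_file=..., ...)
student = config.instantiate(cfg.model)          # lora_qwen2_0_5b(lora_rank=32, lora_alpha=64, ...)
teacher = config.instantiate(cfg.teacher_model)  # qwen2_1_5b()
kd_loss_fn = config.instantiate(cfg.kd_loss)     # ForwardKLWithChunkedOutputLoss()

# Pointing cfg.teacher_model and cfg.teacher_checkpointer at a 70B builder and its
# checkpoint files is all the swap would require, as long as the tokenizer stays the same.
```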