Add KD distributed recipe #1631

Merged 50 commits into main from add-kd-distributed on Oct 29, 2024.

Commits (50)
2933cd6  sft recipes to eval kd (lindawangg, Sep 5, 2024)
bff065a  setup kd files (lindawangg, Sep 5, 2024)
9dd7b47  delete test config (lindawangg, Sep 5, 2024)
a39e99c  added student config (lindawangg, Sep 6, 2024)
6dbcd38  Merge branch 'main' into add-initial-kd-recipe (lindawangg, Sep 6, 2024)
0c4e4f9  added teacher model loading (lindawangg, Sep 6, 2024)
380f267  added loss (lindawangg, Sep 7, 2024)
da2b4bb  kd initial experiment config (lindawangg, Sep 10, 2024)
8beaca0  Merge branch 'main' into add-initial-kd-recipe (lindawangg, Sep 10, 2024)
b54929a  separated out loss func and added test (lindawangg, Sep 11, 2024)
b31c56d  added documentation (lindawangg, Sep 11, 2024)
fe5ed97  added prereq command to config (lindawangg, Sep 11, 2024)
8b9ea41  Merge branch 'main' into add-initial-kd-recipe (lindawangg, Sep 11, 2024)
3f7fe70  re-add 8B config (lindawangg, Sep 11, 2024)
a87aa0c  added kd ratio (lindawangg, Sep 11, 2024)
f5feac4  revert 8b config (lindawangg, Sep 11, 2024)
8c3c42a  add kd recipe test (lindawangg, Sep 12, 2024)
6ba0514  mark as integration test (lindawangg, Sep 12, 2024)
04ea649  add save and load weights test (lindawangg, Sep 12, 2024)
62faa1d  fix comments 1 (lindawangg, Sep 13, 2024)
bf15406  address kd loss test comments (lindawangg, Sep 13, 2024)
ac9eb0e  change to qwen2 (lindawangg, Sep 13, 2024)
87a80b6  addressing recipe comments (lindawangg, Sep 15, 2024)
1fc3f64  Merge branch 'main' into add-initial-kd-recipe (lindawangg, Sep 15, 2024)
106aa3e  distributed recipe (lindawangg, Sep 16, 2024)
0f4e922  remove todo comment and test activation checkpointing (lindawangg, Sep 16, 2024)
22fddca  Merge branch 'main' into add-initial-kd-recipe (lindawangg, Sep 16, 2024)
526a4dc  Merge branch 'add-initial-kd-recipe' into add-kd-distributed (lindawangg, Sep 16, 2024)
0bb49dc  qwen2 distributed recipe (lindawangg, Sep 17, 2024)
c73857d  added to recipe registry (lindawangg, Sep 17, 2024)
59eff44  fdsp teacher model (lindawangg, Sep 17, 2024)
85d76bb  added kd distributed test (lindawangg, Sep 18, 2024)
dba57c4  fixed command (lindawangg, Sep 18, 2024)
04e2282  Merge branch 'main' into add-kd-distributed (lindawangg, Sep 19, 2024)
44123b9  changed to knowledge_distillation (lindawangg, Sep 20, 2024)
a04244d  cleaned up tests (lindawangg, Sep 20, 2024)
1ff9934  added gpu test (lindawangg, Sep 20, 2024)
703e7dc  Merge branch 'main' into add-kd-distributed (lindawangg, Sep 24, 2024)
0031bfb  Merge branch 'main' into add-kd-distributed (lindawangg, Oct 15, 2024)
307791d  added llama3 config and addressed comments (lindawangg, Oct 15, 2024)
fefc24d  added custom sharding layers (lindawangg, Oct 15, 2024)
15c5be2  Merge branch 'main' into add-kd-distributed (lindawangg, Oct 21, 2024)
46473ee  add test_loss back in (lindawangg, Oct 22, 2024)
2e212ec  Merge branch 'main' into add-kd-distributed (lindawangg, Oct 24, 2024)
557396e  rebase (lindawangg, Oct 24, 2024)
4d376e3  Merge branch 'main' into add-kd-distributed (lindawangg, Oct 25, 2024)
53c47ba  grad accumulation changes (lindawangg, Oct 25, 2024)
f193d02  remove extra num_tokens (lindawangg, Oct 26, 2024)
227e69d  addressed comments (lindawangg, Oct 28, 2024)
cf5f01a  Merge branch 'main' into add-kd-distributed (lindawangg, Oct 28, 2024)
130 changes: 130 additions & 0 deletions recipes/configs/llama3_2/knowledge_distillation_distributed.yaml
@@ -0,0 +1,130 @@
# Config for multi-device knowledge distillation in knowledge_distillation_distributed.py
# using a teacher and student model
#
# This config assumes that you've run the following commands before launching KD:
# First download the student and teacher models
# tune download meta-llama/Llama-3.2-1B-Instruct --output-dir /tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth"
# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth"
Contributor:
Do you think it's also worthwhile to add a config with a 70B model size? (It doesn't necessarily have to be in this PR, but it'd be useful to have at least one config that strictly requires distributed training.)

Contributor:
As a follow-up to that, I wonder if we should include model sizes in the config names. I know it makes them a bit longer (and doesn't line up with what you did for the single-device configs), but otherwise we can't really distinguish between configs for distilling 70B -> 1B vs. 8B -> 1B. Similar to the other comment here, this is fine to save for a follow-up, though.

Contributor Author:
I tested this config on the 70B model and verified it works, but I think it needs more tuning. We can add the 70B model in a separate PR and figure out how to change the naming. There weren't many changes needed to add the 70B model, just the model target and checkpoint, since the tokenizer has to be the same right now.

#
# You get better results using KD if the teacher model has already been fine-tuned on the target dataset:
# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama3_1/8B_lora
#
# To launch on 2 devices, run the following command from root:
# tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed --config llama3_2/knowledge_distillation_distributed
#
# This config works best for distilling on 2+ devices.


# Model Arguments
model:
  _component_: torchtune.models.llama3_2.lora_llama3_2_1b
  lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
  apply_lora_to_mlp: True
  apply_lora_to_output: False
  lora_rank: 64
  lora_alpha: 128
  lora_dropout: 0.0

teacher_model:
  _component_: torchtune.models.llama3_1.llama3_1_8b

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Llama-3.2-1B-Instruct/original/tokenizer.model
  max_seq_len: null

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-3.2-1B-Instruct/
  checkpoint_files: [
    model.safetensors
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Llama-3.2-1B-Instruct/
  model_type: LLAMA3
resume_from_checkpoint: False
save_adapter_weights_only: False

# Teacher checkpoint
teacher_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
  checkpoint_files: [
    model-00001-of-00004.safetensors,
    model-00002-of-00004.safetensors,
    model-00003-of-00004.safetensors,
    model-00004-of-00004.safetensors
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
  model_type: LLAMA3

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
seed: null
shuffle: True
batch_size: 4

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  fused: True
  weight_decay: 0.01
  lr: 3e-4
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

kd_loss:
  _component_: torchtune.modules.loss.ForwardKLWithChunkedOutputLoss
kd_ratio: 0.5

# Training
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 32

# Logging
output_dir: /tmp/kd_output
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}
log_every_n_steps: 1
log_peak_memory_stats: False

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: False

# Showcase the usage of the PyTorch profiler
# Set enabled to False as it's only needed for debugging training
profiler:
  _component_: torchtune.training.setup_torch_profiler

  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # Trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 5
  active_steps: 2
  num_cycles: 1
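
For readers new to the KD settings above: kd_ratio blends the student's regular cross-entropy loss with the forward-KL distillation loss computed against the teacher's logits. The snippet below is a minimal plain-PyTorch sketch of that combination, not the recipe's actual implementation; the tensor shapes, the forward_kl helper, and the masking convention (ignore_index = -100) are illustrative assumptions.

import torch
import torch.nn.functional as F

def forward_kl(student_logits, teacher_logits, labels, ignore_index=-100):
    # Forward KL, i.e. KL(teacher || student), averaged over non-padding tokens.
    teacher_prob = F.softmax(teacher_logits, dim=-1)
    teacher_logprob = F.log_softmax(teacher_logits, dim=-1)
    student_logprob = F.log_softmax(student_logits, dim=-1)
    per_token = (teacher_prob * (teacher_logprob - student_logprob)).sum(-1)
    mask = (labels != ignore_index).float()
    return (per_token * mask).sum() / mask.sum().clamp(min=1)

def kd_step_loss(student_logits, teacher_logits, labels, kd_ratio=0.5):
    # Mirrors the intent of loss / kd_loss / kd_ratio in the config:
    # (1 - kd_ratio) * cross-entropy + kd_ratio * forward KL.
    ce = F.cross_entropy(
        student_logits.flatten(0, 1), labels.flatten(), ignore_index=-100
    )
    kd = forward_kl(student_logits, teacher_logits, labels)
    return (1 - kd_ratio) * ce + kd_ratio * kd

# Illustrative shapes only: batch=2, seq_len=8, vocab=128.
student_logits = torch.randn(2, 8, 128)
teacher_logits = torch.randn(2, 8, 128)
labels = torch.randint(0, 128, (2, 8))
print(kd_step_loss(student_logits, teacher_logits, labels, kd_ratio=0.5))

With kd_ratio: 0.5 both terms contribute equally; pushing the ratio toward 1.0 weights the teacher's distribution more heavily than the ground-truth labels.
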
123 changes: 123 additions & 0 deletions recipes/configs/qwen2/knowledge_distillation_distributed.yaml
@@ -0,0 +1,123 @@
# Config for multi-device knowledge distillation in knowledge_distillation_distributed.py
# using a teacher and student model
#
# This config assumes that you've run the following commands before launching KD:
# First download the student and teacher models
# tune download Qwen/Qwen2-0.5B-Instruct --output-dir /tmp/Qwen2-0.5B-Instruct --ignore-patterns None
# tune download Qwen/Qwen2-1.5B-Instruct --output-dir /tmp/Qwen2-1.5B-Instruct --ignore-patterns None
#
# You get better results using KD if the teacher model has already been fine-tuned on the target dataset:
# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config qwen2/1.5B_lora
#
# To launch on 2 devices, run the following command from root:
# tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed --config qwen2/knowledge_distillation_distributed
#
# This config works best for distilling on 2+ devices.


# Model Arguments
model:
  _component_: torchtune.models.qwen2.lora_qwen2_0_5b
  lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
  apply_lora_to_mlp: False
  lora_rank: 32
  lora_alpha: 64

teacher_model:
  _component_: torchtune.models.qwen2.qwen2_1_5b

tokenizer:
  _component_: torchtune.models.qwen2.qwen2_tokenizer
  path: /tmp/Qwen2-0.5B-Instruct/vocab.json
  merges_file: /tmp/Qwen2-0.5B-Instruct/merges.txt
  max_seq_len: null

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Qwen2-0.5B-Instruct
  checkpoint_files: [
    model.safetensors
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Qwen2-0.5B-Instruct-kd
  model_type: QWEN2

teacher_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Qwen2-1.5B-Instruct-lora-finetune
  checkpoint_files: [
    hf_model_0001_0.pt
Contributor:
Why is the default .pt? Is the assumption that we fine-tuned the model on a target dataset first? If so, the first example (Llama 3.2) is wrong, because safetensors files are only saved if the checkpointer specifies safe_serialization: True. We should be consistent across these defaults.

Contributor Author:
I used the default LoRA distributed finetune configs for qwen2 and llama3.1 8B. I'm not sure why qwen2/1.5B_lora outputs .pt whereas llama3_1/8B_lora outputs .safetensors.

  ]
  recipe_checkpoint: null
  output_dir: /tmp/Qwen2-1.5B-Instruct-lora-finetune
  model_type: QWEN2

resume_from_checkpoint: False

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
seed: null
shuffle: True
batch_size: 8

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  weight_decay: 0.01
  lr: 3e-4
lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

kd_loss:
  _component_: torchtune.modules.loss.ForwardKLWithChunkedOutputLoss
kd_ratio: 0.5

# Training
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 2

# Logging
output_dir: /tmp/qwen_kd
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}
log_every_n_steps: 1
log_peak_memory_stats: False

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: False

# Showcase the usage of the PyTorch profiler
# Set enabled to False as it's only needed for debugging training
profiler:
  _component_: torchtune.training.setup_torch_profiler

  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # Trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 5
  active_steps: 2
  num_cycles: 1
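
In both configs the teacher only provides logits: it stays frozen (and, per the "fdsp teacher model" commit, is sharded alongside the student in the distributed recipe) while only the student's LoRA parameters receive gradients. The hypothetical helper below is a rough sketch of what one update step could look like under those assumptions; kd_step_loss refers to the sketch after the Llama config above, both models are assumed to map token IDs directly to logits, and the simple loss scaling stands in for whatever normalization the recipe actually applies after its gradient-accumulation changes.

import torch

def kd_train_step(student, teacher, batch, optimizer, step, kd_step_loss,
                  kd_ratio=0.5, grad_accum_steps=2):
    # Hypothetical helper illustrating one KD update: the teacher only runs
    # inference, so its forward pass is wrapped in torch.no_grad().
    tokens, labels = batch["tokens"], batch["labels"]

    with torch.no_grad():
        teacher_logits = teacher(tokens)

    student_logits = student(tokens)
    loss = kd_step_loss(student_logits, teacher_logits, labels, kd_ratio)

    # Scale so gradients accumulated over grad_accum_steps micro-batches
    # average to one effective batch (cf. gradient_accumulation_steps above).
    (loss / grad_accum_steps).backward()

    # Only step the optimizer once a full accumulation window has been seen.
    if (step + 1) % grad_accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
    return loss.detach()
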