Llama3.2 knowledge distillation config #1690

Merged · 4 commits · Sep 27, 2024
132 changes: 132 additions & 0 deletions recipes/configs/llama3_2/knowledge_distillation_single_device.yaml
@@ -0,0 +1,132 @@
# Config for single device knowledge distillation (KD) in knowledge_distillation_single_device.py
# using a Llama3.1 8B teacher and a Llama3.2 1B student model
#
# This config assumes that you've run the following commands before launching KD:
# First download the student and teacher models
# tune download meta-llama/Llama-3.2-1B-Instruct --output-dir /tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth"
# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth"
#
# You get better results using KD if the teacher model has already been fine-tuned on the target dataset:
# tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device
#
# To launch on a single device, run the following command from root:
# tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device
#
# This config works only for training on a single device.


# Model Arguments
model:
  _component_: torchtune.models.llama3_2.lora_llama3_2_1b
  lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
  apply_lora_to_mlp: True
  apply_lora_to_output: False
  lora_rank: 64
  lora_alpha: 128
  lora_dropout: 0.0

teacher_model:
  _component_: torchtune.models.llama3_1.llama3_1_8b
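# Note: the teacher is instantiated as a full (non-LoRA) Llama3.1 8B model and is only
# used for inference during KD; only the student's LoRA parameters above are trained.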

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Llama-3.2-1B-Instruct/original/tokenizer.model
  max_seq_len: null

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-3.2-1B-Instruct/
  checkpoint_files: [
    model.safetensors
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Llama-3.2-1B-Instruct/
  model_type: LLAMA3
resume_from_checkpoint: False
save_adapter_weights_only: False

# Teacher checkpoint
teacher_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
  checkpoint_files: [
    hf_model_0001_0.pt,
    hf_model_0002_0.pt,
    hf_model_0003_0.pt,
    hf_model_0004_0.pt
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
  model_type: LLAMA3
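# The hf_model_000{1-4}_0.pt files above follow torchtune's checkpoint naming (shard
# index, then epoch), i.e. they are expected to come from a prior fine-tune of the
# teacher such as the lora_finetune_single_device command suggested in the header.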

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
seed: null
shuffle: True
batch_size: 4

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  fused: True
  weight_decay: 0.01
  lr: 3e-4
lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 100
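# With get_cosine_schedule_with_warmup, the learning rate ramps up linearly over the
# first 100 optimizer steps and then follows a cosine decay for the rest of training.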

loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

kd_loss:
  _component_: torchtune.modules.loss.ForwardKLWithChunkedOutputLoss
kd_ratio: 0.5
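# The recipe blends the two losses using kd_ratio, roughly:
#   loss = (1 - kd_ratio) * ce_loss + kd_ratio * kd_loss
# so kd_ratio: 0.5 weights the hard-label CE loss and the forward-KL loss equally.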

# Training
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 32
compile: False
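# Effective batch size = batch_size * gradient_accumulation_steps = 4 * 32 = 128
# samples per optimizer step.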

# Logging
output_dir: /tmp/kd_output
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}
log_every_n_steps: 1
log_peak_memory_stats: False

# Environment
device: cuda
dtype: bf16

# Activations Memory
enable_activation_checkpointing: False
enable_activation_offloading: False

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # Trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 5
  active_steps: 2
  num_cycles: 1
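For reference, here is a minimal, self-contained sketch of how the two chunked losses above can be combined with kd_ratio. It is illustrative only, not the recipe code: the tensor sizes are toy values and the exact signatures of the chunked losses are assumptions based on the components named in this config.

# Illustrative sketch: combining the hard-label CE loss and the forward-KL
# distillation loss with kd_ratio (assumed signatures, toy tensor sizes).
import torch
from torchtune.modules.loss import CEWithChunkedOutputLoss, ForwardKLWithChunkedOutputLoss

batch, seq_len, vocab = 2, 8, 128      # toy sizes, not the real Llama3 vocab
num_chunks = 2
kd_ratio = 0.5                         # mirrors kd_ratio in the config above

ce_loss_fn = CEWithChunkedOutputLoss(num_output_chunks=num_chunks)
kd_loss_fn = ForwardKLWithChunkedOutputLoss(num_output_chunks=num_chunks)

# Both losses take the output logits pre-split into chunks along the sequence dim.
student_logits = list(torch.randn(batch, seq_len, vocab).chunk(num_chunks, dim=1))
teacher_logits = list(torch.randn(batch, seq_len, vocab).chunk(num_chunks, dim=1))
labels = torch.randint(0, vocab, (batch, seq_len))

class_loss = ce_loss_fn(student_logits, labels)               # standard next-token CE
kd_loss = kd_loss_fn(student_logits, teacher_logits, labels)  # forward KL to the teacher
loss = (1 - kd_ratio) * class_loss + kd_ratio * kd_loss
print(loss)

The chunked loss variants exist mainly to reduce peak memory: they process the logits one sequence chunk at a time instead of materializing the full (seq_len x vocab) tensor at once.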
4 changes: 4 additions & 0 deletions torchtune/_recipe_registry.py
@@ -352,6 +352,10 @@ class Recipe:
            name="qwen2/knowledge_distillation_single_device",
            file_path="qwen2/knowledge_distillation_single_device.yaml",
        ),
        Config(
            name="llama3_2/knowledge_distillation_single_device",
            file_path="llama3_2/knowledge_distillation_single_device.yaml",
        ),
    ],
    supports_distributed=False,
),
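With this Config entry registered, the torchtune CLI resolves the new config by name, so the launch command from the YAML header works directly:

tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device

It should also appear in `tune ls` and be copyable with `tune cp llama3_2/knowledge_distillation_single_device <destination>` (standard CLI behavior, assumed here; the destination path is a placeholder).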