Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion recipes/configs/llama2/7B_qat_full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True # True reduces memory
memory_efficient_fsdp_wrap: False
enable_activation_offloading: False # True reduces memory

# Reduced precision
dtype: bf16
Expand Down
9 changes: 5 additions & 4 deletions recipes/configs/llama3/8B_qat_full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@ resume_from_checkpoint: False
# Fine-tuning arguments
batch_size: 2
epochs: 3
compile: False # pytorch compile, set to true for better perf/memory
optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1

# QAT arguments
quantizer:
Expand All @@ -60,13 +58,16 @@ loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1 # Use to increase virtual batch size
compile: False # pytorch compile, set to true for better perf/memory
optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True # True reduces memory
memory_efficient_fsdp_wrap: True
enable_activation_offloading: False # True reduces memory
custom_sharded_layers: ['tok_embeddings', 'output'] # Layers to shard separately (useful for large vocab size models). Lower Memory, but lower speed.

# Reduced precision
dtype: bf16
Expand All @@ -75,7 +76,7 @@ dtype: bf16
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
log_dir: ${output_dir}
output_dir: /tmp/alpaca-llama3-finetune
output_dir: /tmp/full-llama3-finetune
log_every_n_steps: 1
log_peak_memory_stats: True

Expand Down
Loading
Loading