Commit 3df9045

Update QAT: add grad clipping, torch.compile, collate fn
**Summary:** Update the qat_distributed recipe to match the full_finetune_distributed recipe. This commit adds features to QAT such as gradient clipping, torch.compile, and a user-configurable collate function for data pre-processing.

**Test Plan:** TBD
1 parent: f560cbb · commit: 3df9045
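The three additions are standard PyTorch mechanisms. Below is a minimal sketch of how a training recipe can wire them together; the parameter names (e.g. clip_grad_norm) and batch shapes are assumptions based on the full_finetune_distributed recipe this commit mirrors, not the recipe's actual code.

```python
# Hedged sketch of the three features added by this commit; names and
# shapes are illustrative, not torchtune's actual qat_distributed code.
import torch
from torch.utils.data import DataLoader


def train_epoch(model, dataset, collate_fn, loss_fn, optimizer,
                compile_model=False, clip_grad_norm=None):
    if compile_model:
        # torch.compile: longer first-step warmup, faster steady-state steps.
        model = torch.compile(model)

    # User-configurable collate function for data pre-processing.
    loader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)

    for batch in loader:
        optimizer.zero_grad()
        loss = loss_fn(model(batch["tokens"]), batch["labels"])
        loss.backward()
        if clip_grad_norm is not None:
            # Gradient clipping: cap the global gradient norm before stepping.
            torch.nn.utils.clip_grad_norm_(
                model.parameters(), max_norm=float(clip_grad_norm)
            )
        optimizer.step()
```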

File tree

3 files changed: +236 -77 lines

recipes/configs/llama2/7B_qat_full.yaml

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ device: cuda
 
 # Memory management
 enable_activation_checkpointing: True
-memory_efficient_fsdp_wrap: False
+enable_activation_offloading: False # True reduces memory
 
 # Reduced precision
 dtype: bf16
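The dropped memory_efficient_fsdp_wrap flag is superseded here by activation offloading. Conceptually the feature resembles plain PyTorch's save_on_cpu saved-tensor hook, shown below as a hedged illustration; torchtune's actual implementation differs in detail.

```python
# Illustration of the activation-offloading trade-off using stock
# torch.autograd.graph.save_on_cpu; not torchtune's implementation.
import torch

model = torch.nn.Linear(1024, 1024).cuda()
x = torch.randn(8, 1024, device="cuda", requires_grad=True)

with torch.autograd.graph.save_on_cpu(pin_memory=True):
    # Tensors saved for backward are kept in pinned host memory,
    # trading GPU memory for host<->device copy time.
    y = model(x).relu().sum()

y.backward()  # saved activations are copied back to the GPU as needed
```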

recipes/configs/llama3/8B_qat_full.yaml

Lines changed: 4 additions & 3 deletions
@@ -44,7 +44,6 @@ resume_from_checkpoint: False
 # Fine-tuning arguments
 batch_size: 2
 epochs: 3
-compile: False
 
 # QAT arguments
 quantizer:
@@ -59,13 +58,15 @@ loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
 max_steps_per_epoch: null
 gradient_accumulation_steps: 1
+compile: False
 
 # Training env
 device: cuda
 
 # Memory management
 enable_activation_checkpointing: True
-memory_efficient_fsdp_wrap: True
+enable_activation_offloading: False # True reduces memory
+custom_sharded_layers: ['tok_embeddings', 'output']
 
 # Reduced precision
 dtype: bf16
@@ -74,6 +75,6 @@ dtype: bf16
 metric_logger:
   _component_: torchtune.training.metric_logging.DiskLogger
   log_dir: ${output_dir}
-output_dir: /tmp/alpaca-llama3-finetune
+output_dir: /tmp/full-llama3-finetune
 log_every_n_steps: 1
 log_peak_memory_stats: True
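custom_sharded_layers is new in this config: it names modules to shard as their own FSDP units in addition to the per-layer wrapping, which helps with very large embedding and output matrices. A hedged sketch of the idea using FSDP2's fully_shard follows; the wiring is illustrative, not torchtune's actual sharding code, and it assumes torch.distributed is already initialized.

```python
# Hedged sketch: give the modules listed in custom_sharded_layers
# (here tok_embeddings and output) their own FSDP2 shard group so their
# parameters are gathered and freed independently of the decoder layers.
# Import path is the PyTorch 2.4/2.5-era location of FSDP2.
from torch.distributed._composable.fsdp import fully_shard


def shard(model, custom_sharded_layers):
    for name, module in model.named_modules():
        if name in custom_sharded_layers:
            fully_shard(module)  # e.g. tok_embeddings, output
    for layer in model.layers:  # assumes a .layers ModuleList of decoder blocks
        fully_shard(layer)
    fully_shard(model)  # root wrap for everything left over
```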
