Commit

add fused = true to adam, except pagedAdam (#1575)
Co-authored-by: Felipe Mello <felipemello@fb.com>
felipemello1 and Felipe Mello authored Sep 14, 2024
1 parent 60cf96f commit cca50f0
Showing 58 changed files with 57 additions and 11 deletions.
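For context, the "fused: True" entries added below map directly onto the fused argument of torch.optim.AdamW when the config is instantiated. A minimal sketch of the equivalent direct PyTorch call (the model and learning rate here are placeholders for illustration, not taken from this commit; PyTorch's fused AdamW path requires the parameters to live on a CUDA device):

import torch
import torch.nn as nn

# Placeholder module standing in for a fine-tuning model (hypothetical).
model = nn.Linear(4096, 4096, device="cuda")

# Equivalent of a config block like:
#   optimizer:
#     _component_: torch.optim.AdamW
#     fused: True
#     lr: 2e-5
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, fused=True)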
1 change: 1 addition & 0 deletions recipes/configs/code_llama2/7B_lora_single_device.yaml
@@ -60,6 +60,7 @@ batch_size: 2
 gradient_accumulation_steps: 16
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:
1 change: 1 addition & 0 deletions recipes/configs/code_llama2/7B_qlora_single_device.yaml
@@ -60,6 +60,7 @@ batch_size: 2
 gradient_accumulation_steps: 16
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:
3 changes: 1 addition & 2 deletions recipes/configs/dev/8B_full_experimental.yaml
@@ -52,8 +52,7 @@ epochs: 3
 optimizer:
   _component_: torch.optim.AdamW
   lr: 2e-5
-  foreach: False
-
+  fused: True
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
 max_steps_per_epoch: null
1 change: 1 addition & 0 deletions recipes/configs/gemma/2B_full.yaml
@@ -48,6 +48,7 @@ batch_size: 2
 epochs: 3
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   lr: 2e-5
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
1 change: 1 addition & 0 deletions recipes/configs/gemma/2B_lora.yaml
@@ -51,6 +51,7 @@ save_adapter_weights_only: False

 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   lr: 2e-5

 lr_scheduler:
1 change: 1 addition & 0 deletions recipes/configs/gemma/2B_lora_single_device.yaml
@@ -50,6 +50,7 @@ save_adapter_weights_only: False

 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   lr: 2e-5

 lr_scheduler:
1 change: 1 addition & 0 deletions recipes/configs/gemma/2B_qlora_single_device.yaml
@@ -50,6 +50,7 @@ save_adapter_weights_only: False

 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   lr: 2e-5

 lr_scheduler:
1 change: 1 addition & 0 deletions recipes/configs/gemma/7B_full.yaml
@@ -50,6 +50,7 @@ batch_size: 1
 epochs: 1
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   lr: 2e-5
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
1 change: 1 addition & 0 deletions recipes/configs/gemma/7B_lora.yaml
@@ -53,6 +53,7 @@ save_adapter_weights_only: False

 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   lr: 2e-5

 lr_scheduler:
1 change: 1 addition & 0 deletions recipes/configs/gemma/7B_lora_single_device.yaml
@@ -52,6 +52,7 @@ save_adapter_weights_only: False

 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   lr: 5e-5

 lr_scheduler:
1 change: 1 addition & 0 deletions recipes/configs/gemma/7B_qlora_single_device.yaml
@@ -52,6 +52,7 @@ save_adapter_weights_only: False

 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   lr: 2e-5

 lr_scheduler:
1 change: 1 addition & 0 deletions recipes/configs/llama2/13B_full.yaml
@@ -52,6 +52,7 @@ batch_size: 2
 epochs: 3
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   lr: 2e-5
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
1 change: 1 addition & 0 deletions recipes/configs/llama2/13B_lora.yaml
@@ -60,6 +60,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 2e-4
 lr_scheduler:
1 change: 1 addition & 0 deletions recipes/configs/llama2/13B_qlora_single_device.yaml
@@ -55,6 +55,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:
1 change: 1 addition & 0 deletions recipes/configs/llama2/70B_lora.yaml
@@ -60,6 +60,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:
1 change: 1 addition & 0 deletions recipes/configs/llama2/70B_qlora.yaml
@@ -66,6 +66,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:
1 change: 1 addition & 0 deletions recipes/configs/llama2/7B_full.yaml
@@ -51,6 +51,7 @@ batch_size: 2
 epochs: 3
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   lr: 2e-5
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
1 change: 1 addition & 0 deletions recipes/configs/llama2/7B_lora.yaml
@@ -57,6 +57,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:
1 change: 1 addition & 0 deletions recipes/configs/llama2/7B_lora_dpo.yaml
@@ -54,6 +54,7 @@ batch_size: 4
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.05
   lr: 5e-4
 lr_scheduler:
1 change: 1 addition & 0 deletions recipes/configs/llama2/7B_lora_dpo_single_device.yaml
@@ -53,6 +53,7 @@ batch_size: 4
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.05
   lr: 5e-4
 lr_scheduler:
1 change: 1 addition & 0 deletions recipes/configs/llama2/7B_lora_single_device.yaml
@@ -55,6 +55,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:
1 change: 1 addition & 0 deletions recipes/configs/llama2/7B_qat_full.yaml
@@ -47,6 +47,7 @@ batch_size: 2
 epochs: 3
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   lr: 2e-5
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
1 change: 1 addition & 0 deletions recipes/configs/llama2/7B_qlora.yaml
@@ -57,6 +57,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:
1 change: 1 addition & 0 deletions recipes/configs/llama2/7B_qlora_single_device.yaml
@@ -54,6 +54,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:
3 changes: 1 addition & 2 deletions recipes/configs/llama3/70B_full.yaml
@@ -79,8 +79,7 @@ epochs: 3
 optimizer:
   _component_: torch.optim.AdamW
   lr: 2e-5
-  foreach: False
-  # Note: highly recommended to use fused=True optimizer flag
+  fused: True # Note: highly recommended to use fused=True optimizer flag
   # with CPU offload for faster optimizer step.
   fused: True

1 change: 1 addition & 0 deletions recipes/configs/llama3/70B_lora.yaml
@@ -75,6 +75,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:
1 change: 1 addition & 0 deletions recipes/configs/llama3/8B_dora.yaml
@@ -50,6 +50,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:
1 change: 1 addition & 0 deletions recipes/configs/llama3/8B_dora_single_device.yaml
@@ -52,6 +52,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:
3 changes: 1 addition & 2 deletions recipes/configs/llama3/8B_full.yaml
@@ -52,8 +52,7 @@ epochs: 3
 optimizer:
   _component_: torch.optim.AdamW
   lr: 2e-5
-  foreach: False
-
+  fused: True
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
 max_steps_per_epoch: null
1 change: 1 addition & 0 deletions recipes/configs/llama3/8B_lora.yaml
@@ -55,6 +55,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:
1 change: 1 addition & 0 deletions recipes/configs/llama3/8B_lora_single_device.yaml
@@ -54,6 +54,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:
3 changes: 1 addition & 2 deletions recipes/configs/llama3/8B_qat_full.yaml
@@ -52,8 +52,7 @@ quantizer:
 optimizer:
   _component_: torch.optim.AdamW
   lr: 2e-5
-  foreach: False
-
+  fused: True
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
 max_steps_per_epoch: null
1 change: 1 addition & 0 deletions recipes/configs/llama3/8B_qdora_single_device.yaml
@@ -53,6 +53,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:
1 change: 1 addition & 0 deletions recipes/configs/llama3/8B_qlora_single_device.yaml
@@ -53,6 +53,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:
1 change: 0 additions & 1 deletion recipes/configs/llama3_1/70B_full.yaml
@@ -79,7 +79,6 @@ epochs: 3
 optimizer:
   _component_: torch.optim.AdamW
   lr: 2e-5
-  foreach: False
   # Note: highly recommended to use fused=True optimizer flag
   # with CPU offload for faster optimizer step.
   fused: True
1 change: 1 addition & 0 deletions recipes/configs/llama3_1/70B_lora.yaml
@@ -74,6 +74,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:
3 changes: 1 addition & 2 deletions recipes/configs/llama3_1/8B_full.yaml
@@ -55,8 +55,7 @@ epochs: 3
 optimizer:
   _component_: torch.optim.AdamW
   lr: 2e-5
-  foreach: False
-
+  fused: True
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
 max_steps_per_epoch: null
1 change: 1 addition & 0 deletions recipes/configs/llama3_1/8B_lora.yaml
@@ -58,6 +58,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:
1 change: 1 addition & 0 deletions recipes/configs/llama3_1/8B_lora_single_device.yaml
@@ -57,6 +57,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:
1 change: 1 addition & 0 deletions recipes/configs/llama3_1/8B_qlora_single_device.yaml
@@ -56,6 +56,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:
1 change: 1 addition & 0 deletions recipes/configs/mistral/7B_full.yaml
@@ -54,6 +54,7 @@ batch_size: 2
 epochs: 3
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   lr: 5e-6
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
1 change: 1 addition & 0 deletions recipes/configs/mistral/7B_lora.yaml
@@ -59,6 +59,7 @@ save_adapter_weights_only: False

 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   lr: 2e-5

 lr_scheduler:
1 change: 1 addition & 0 deletions recipes/configs/mistral/7B_lora_single_device.yaml
@@ -56,6 +56,7 @@ save_adapter_weights_only: False

 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   lr: 2e-5

 lr_scheduler:
1 change: 1 addition & 0 deletions recipes/configs/mistral/7B_qlora_single_device.yaml
@@ -57,6 +57,7 @@ save_adapter_weights_only: False

 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   lr: 2e-5

 lr_scheduler:
1 change: 1 addition & 0 deletions recipes/configs/phi3/mini_full.yaml
@@ -53,6 +53,7 @@ batch_size: 2
 gradient_accumulation_steps: 16
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   lr: 5e-6
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
1 change: 1 addition & 0 deletions recipes/configs/phi3/mini_lora.yaml
@@ -60,6 +60,7 @@ batch_size: 2
 gradient_accumulation_steps: 16
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:
1 change: 1 addition & 0 deletions recipes/configs/phi3/mini_lora_single_device.yaml
@@ -58,6 +58,7 @@ batch_size: 2
 gradient_accumulation_steps: 16
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:
1 change: 1 addition & 0 deletions recipes/configs/phi3/mini_qlora_single_device.yaml
@@ -58,6 +58,7 @@ batch_size: 2
 gradient_accumulation_steps: 16
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler: