add fused = true to adam, except pagedAdam #1575

Merged
merged 3 commits on Sep 14, 2024
Changes from all commits
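For context, fused: True is passed through torchtune's config instantiation to the torch.optim.AdamW constructor, selecting PyTorch's fused implementation of the optimizer step (a single fused kernel over the parameters) instead of the default per-parameter or foreach paths. Below is a minimal sketch of what the changed configs resolve to at runtime, assuming a CUDA device; the Linear module is a hypothetical stand-in for the fine-tuned model.

import torch

# Stand-in model; torchtune builds the real model from the config's _component_ paths.
model = torch.nn.Linear(4096, 4096, device="cuda")

# Rough equivalent of the YAML changed in this PR:
#   optimizer:
#     _component_: torch.optim.AdamW
#     fused: True
#     weight_decay: 0.01
#     lr: 3e-4
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01, fused=True)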
1 change: 1 addition & 0 deletions recipes/configs/code_llama2/7B_lora_single_device.yaml
@@ -60,6 +60,7 @@ batch_size: 2
 gradient_accumulation_steps: 16
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:

1 change: 1 addition & 0 deletions recipes/configs/code_llama2/7B_qlora_single_device.yaml
@@ -60,6 +60,7 @@ batch_size: 2
 gradient_accumulation_steps: 16
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:

3 changes: 1 addition & 2 deletions recipes/configs/dev/8B_full_experimental.yaml
@@ -52,8 +52,7 @@ epochs: 3
 optimizer:
   _component_: torch.optim.AdamW
   lr: 2e-5
-  foreach: False
-
+  fused: True
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
 max_steps_per_epoch: null

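The full-finetune configs in this PR previously pinned foreach: False; the diff above swaps that for fused: True. As a rough sketch (toy parameter list, assuming CUDA), PyTorch exposes three AdamW code paths:

import torch

params = [torch.nn.Parameter(torch.randn(1024, 1024, device="cuda"))]

# Per-parameter Python loop, which is what foreach: False selected in the old configs.
opt_loop = torch.optim.AdamW(params, lr=2e-5, foreach=False)

# Multi-tensor "foreach" path: batches all parameters into each elementwise op.
opt_foreach = torch.optim.AdamW(params, lr=2e-5, foreach=True)

# Fused path: the whole AdamW step runs as one kernel; this is what the configs now select.
opt_fused = torch.optim.AdamW(params, lr=2e-5, fused=True)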
1 change: 1 addition & 0 deletions recipes/configs/gemma/2B_full.yaml
@@ -48,6 +48,7 @@ batch_size: 2
 epochs: 3
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   lr: 2e-5
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

1 change: 1 addition & 0 deletions recipes/configs/gemma/2B_lora.yaml
@@ -51,6 +51,7 @@ save_adapter_weights_only: False
 
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   lr: 2e-5
 
 lr_scheduler:

1 change: 1 addition & 0 deletions recipes/configs/gemma/2B_lora_single_device.yaml
@@ -50,6 +50,7 @@ save_adapter_weights_only: False
 
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   lr: 2e-5
 
 lr_scheduler:

1 change: 1 addition & 0 deletions recipes/configs/gemma/2B_qlora_single_device.yaml
@@ -50,6 +50,7 @@ save_adapter_weights_only: False
 
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   lr: 2e-5
 
 lr_scheduler:

1 change: 1 addition & 0 deletions recipes/configs/gemma/7B_full.yaml
@@ -50,6 +50,7 @@ batch_size: 1
 epochs: 1
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   lr: 2e-5
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

1 change: 1 addition & 0 deletions recipes/configs/gemma/7B_lora.yaml
@@ -53,6 +53,7 @@ save_adapter_weights_only: False
 
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   lr: 2e-5
 
 lr_scheduler:

1 change: 1 addition & 0 deletions recipes/configs/gemma/7B_lora_single_device.yaml
@@ -52,6 +52,7 @@ save_adapter_weights_only: False
 
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   lr: 5e-5
 
 lr_scheduler:

1 change: 1 addition & 0 deletions recipes/configs/gemma/7B_qlora_single_device.yaml
@@ -52,6 +52,7 @@ save_adapter_weights_only: False
 
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   lr: 2e-5
 
 lr_scheduler:

1 change: 1 addition & 0 deletions recipes/configs/llama2/13B_full.yaml
@@ -52,6 +52,7 @@ batch_size: 2
 epochs: 3
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   lr: 2e-5
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

1 change: 1 addition & 0 deletions recipes/configs/llama2/13B_lora.yaml
@@ -60,6 +60,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 2e-4
 lr_scheduler:

1 change: 1 addition & 0 deletions recipes/configs/llama2/13B_qlora_single_device.yaml
@@ -55,6 +55,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:

1 change: 1 addition & 0 deletions recipes/configs/llama2/70B_lora.yaml
@@ -60,6 +60,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:

1 change: 1 addition & 0 deletions recipes/configs/llama2/70B_qlora.yaml
@@ -66,6 +66,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:

1 change: 1 addition & 0 deletions recipes/configs/llama2/7B_full.yaml
@@ -51,6 +51,7 @@ batch_size: 2
 epochs: 3
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   lr: 2e-5
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

1 change: 1 addition & 0 deletions recipes/configs/llama2/7B_lora.yaml
@@ -57,6 +57,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:

1 change: 1 addition & 0 deletions recipes/configs/llama2/7B_lora_dpo.yaml
@@ -55,6 +55,7 @@ batch_size: 4
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.05
   lr: 5e-4
 lr_scheduler:

1 change: 1 addition & 0 deletions recipes/configs/llama2/7B_lora_dpo_single_device.yaml
@@ -54,6 +54,7 @@ batch_size: 4
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.05
   lr: 5e-4
 lr_scheduler:

1 change: 1 addition & 0 deletions recipes/configs/llama2/7B_lora_single_device.yaml
@@ -55,6 +55,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:

1 change: 1 addition & 0 deletions recipes/configs/llama2/7B_qat_full.yaml
@@ -47,6 +47,7 @@ batch_size: 2
 epochs: 3
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   lr: 2e-5
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

1 change: 1 addition & 0 deletions recipes/configs/llama2/7B_qlora.yaml
@@ -57,6 +57,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:

1 change: 1 addition & 0 deletions recipes/configs/llama2/7B_qlora_single_device.yaml
@@ -54,6 +54,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:

3 changes: 1 addition & 2 deletions recipes/configs/llama3/70B_full.yaml
@@ -79,8 +79,7 @@ epochs: 3
 optimizer:
   _component_: torch.optim.AdamW
   lr: 2e-5
-  foreach: False
-  # Note: highly recommended to use fused=True optimizer flag
+  fused: True # Note: highly recommended to use fused=True optimizer flag
   # with CPU offload for faster optimizer step.
   fused: True
 

1 change: 1 addition & 0 deletions recipes/configs/llama3/70B_lora.yaml
@@ -75,6 +75,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:

1 change: 1 addition & 0 deletions recipes/configs/llama3/8B_dora.yaml
@@ -50,6 +50,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:

1 change: 1 addition & 0 deletions recipes/configs/llama3/8B_dora_single_device.yaml
@@ -52,6 +52,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:

3 changes: 1 addition & 2 deletions recipes/configs/llama3/8B_full.yaml
@@ -52,8 +52,7 @@ epochs: 3
 optimizer:
   _component_: torch.optim.AdamW
   lr: 2e-5
-  foreach: False
-
+  fused: True
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
 max_steps_per_epoch: null

1 change: 1 addition & 0 deletions recipes/configs/llama3/8B_lora.yaml
@@ -55,6 +55,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:

1 change: 1 addition & 0 deletions recipes/configs/llama3/8B_lora_single_device.yaml
@@ -54,6 +54,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:

3 changes: 1 addition & 2 deletions recipes/configs/llama3/8B_qat_full.yaml
@@ -52,8 +52,7 @@ quantizer:
 optimizer:
   _component_: torch.optim.AdamW
   lr: 2e-5
-  foreach: False
-
+  fused: True
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
 max_steps_per_epoch: null

1 change: 1 addition & 0 deletions recipes/configs/llama3/8B_qdora_single_device.yaml
@@ -53,6 +53,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:

1 change: 1 addition & 0 deletions recipes/configs/llama3/8B_qlora_single_device.yaml
@@ -53,6 +53,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:

1 change: 0 additions & 1 deletion recipes/configs/llama3_1/70B_full.yaml
@@ -79,7 +79,6 @@ epochs: 3
 optimizer:
   _component_: torch.optim.AdamW
   lr: 2e-5
-  foreach: False
   # Note: highly recommended to use fused=True optimizer flag
   # with CPU offload for faster optimizer step.
   fused: True

1 change: 1 addition & 0 deletions recipes/configs/llama3_1/70B_lora.yaml
@@ -74,6 +74,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:

3 changes: 1 addition & 2 deletions recipes/configs/llama3_1/8B_full.yaml
@@ -55,8 +55,7 @@ epochs: 3
 optimizer:
   _component_: torch.optim.AdamW
   lr: 2e-5
-  foreach: False
-
+  fused: True
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
 max_steps_per_epoch: null

1 change: 1 addition & 0 deletions recipes/configs/llama3_1/8B_lora.yaml
@@ -58,6 +58,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:

1 change: 1 addition & 0 deletions recipes/configs/llama3_1/8B_lora_single_device.yaml
@@ -57,6 +57,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:

1 change: 1 addition & 0 deletions recipes/configs/llama3_1/8B_qlora_single_device.yaml
@@ -56,6 +56,7 @@ batch_size: 2
 # Optimizer and Scheduler
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:

1 change: 1 addition & 0 deletions recipes/configs/mistral/7B_full.yaml
@@ -54,6 +54,7 @@ batch_size: 2
 epochs: 3
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   lr: 5e-6
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

1 change: 1 addition & 0 deletions recipes/configs/mistral/7B_lora.yaml
@@ -59,6 +59,7 @@ save_adapter_weights_only: False
 
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   lr: 2e-5
 
 lr_scheduler:

1 change: 1 addition & 0 deletions recipes/configs/mistral/7B_lora_single_device.yaml
@@ -56,6 +56,7 @@ save_adapter_weights_only: False
 
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   lr: 2e-5
 
 lr_scheduler:

1 change: 1 addition & 0 deletions recipes/configs/mistral/7B_qlora_single_device.yaml
@@ -57,6 +57,7 @@ save_adapter_weights_only: False
 
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   lr: 2e-5
 
 lr_scheduler:

1 change: 1 addition & 0 deletions recipes/configs/phi3/mini_full.yaml
@@ -53,6 +53,7 @@ batch_size: 2
 gradient_accumulation_steps: 16
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   lr: 5e-6
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

1 change: 1 addition & 0 deletions recipes/configs/phi3/mini_lora.yaml
@@ -60,6 +60,7 @@ batch_size: 2
 gradient_accumulation_steps: 16
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:

1 change: 1 addition & 0 deletions recipes/configs/phi3/mini_lora_single_device.yaml
@@ -58,6 +58,7 @@ batch_size: 2
 gradient_accumulation_steps: 16
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler:

1 change: 1 addition & 0 deletions recipes/configs/phi3/mini_qlora_single_device.yaml
@@ -58,6 +58,7 @@ batch_size: 2
 gradient_accumulation_steps: 16
 optimizer:
   _component_: torch.optim.AdamW
+  fused: True
   weight_decay: 0.01
   lr: 3e-4
 lr_scheduler: