From 1b68328de51d7947bafa7645d76f6d2270bf0682 Mon Sep 17 00:00:00 2001
From: Thien Tran
Date: Sat, 24 Aug 2024 07:20:15 +0800
Subject: [PATCH] add default

---
 .../prototype/low_bit_optim/cpu_offload.py | 19 ++++++++++++++++---
 1 file changed, 16 insertions(+), 3 deletions(-)

diff --git a/torchao/prototype/low_bit_optim/cpu_offload.py b/torchao/prototype/low_bit_optim/cpu_offload.py
index 69ee4c240..1a4ed5816 100644
--- a/torchao/prototype/low_bit_optim/cpu_offload.py
+++ b/torchao/prototype/low_bit_optim/cpu_offload.py
@@ -1,20 +1,33 @@
 from typing import Type
 
 import torch
-from torch.optim.optimizer import Optimizer
+from torch.optim.optimizer import Optimizer, ParamsT
+
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
 
 
 class CPUOffloadOptimizer:
-    def __init__(self, params, optimizer_class: Type[Optimizer], *, offload_gradients: bool = False, **kwargs) -> None:
+    def __init__(
+        self,
+        params: ParamsT,
+        optimizer_class: Type[Optimizer] = torch.optim.AdamW,
+        *,
+        offload_gradients: bool = False,
+        **kwargs,
+    ) -> None:
         """Offload optimizer to CPU for single-GPU training. This will reduce GPU memory by the size of optimizer state.
         Optimizer step will be done on CPU.
 
         Args
             params: a list of parameters or parameter groups.
-            optimizer_class: constructor of the base optimizer.
+            optimizer_class: constructor of the base optimizer. Defaults to :class:`torch.optim.AdamW`.
             offload_gradients: free GPU gradients once they are moved to CPU. Not compatible with gradient accumulation.
             kwargs: other keyword arguments to be passed to the base optimizer e.g. `lr`, `weight_decay`.
         """
+        # default to fused CPU AdamW
+        if optimizer_class is torch.optim.AdamW and TORCH_VERSION_AT_LEAST_2_4 and "fused" not in kwargs:
+            kwargs.update(fused=True)
+
         param_groups = list(params)
         if len(param_groups) == 0:
             raise ValueError("optimizer got an empty parameter list")
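
A minimal usage sketch (not part of the patch): with this change, omitting `optimizer_class` selects `torch.optim.AdamW`, and on PyTorch >= 2.4 the fused CPU AdamW path is enabled automatically unless a `fused` kwarg is passed explicitly. The import path and the `model` below are assumptions for illustration only.

    import torch
    from torchao.prototype.low_bit_optim import CPUOffloadOptimizer

    model = torch.nn.Linear(1024, 1024).cuda()  # hypothetical model

    # optimizer_class defaults to torch.optim.AdamW; with PyTorch >= 2.4 and no
    # explicit "fused" kwarg, fused=True is injected so the CPU step uses the
    # fused AdamW implementation.
    optim = CPUOffloadOptimizer(model.parameters(), lr=1e-4, weight_decay=0.01)

    # Opt out of the default by passing fused=False or another base optimizer:
    # optim = CPUOffloadOptimizer(model.parameters(), torch.optim.SGD, lr=1e-4)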