From 30a5cccfe89d7d430d0ee819e4d5cd3e789feb80 Mon Sep 17 00:00:00 2001
From: mori360
Date: Wed, 30 Oct 2024 15:20:23 -0700
Subject: [PATCH 1/5] correct model_state_dict dtype, introduce model into
 optimizer state dict API

---
 recipes/full_finetune_distributed.py          |  5 ++-
 recipes/lora_dpo_distributed.py               |  8 ++++-
 recipes/lora_finetune_distributed.py          |  8 ++++-
 recipes/qat_distributed.py                    |  8 ++++-
 tests/torchtune/training/test_distributed.py  |  2 ++
 torchtune/training/_distributed.py            | 38 ++++++++++++--------
 6 files changed, 51 insertions(+), 18 deletions(-)

diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py
index 6cda652fd3..75b6f6503b 100644
--- a/recipes/full_finetune_distributed.py
+++ b/recipes/full_finetune_distributed.py
@@ -478,6 +478,7 @@ def _setup_optimizer(
                 for param in opt_state_dict.keys():
                     try:
                         training.load_from_full_optimizer_state_dict(
+                            self._model,
                             self._optim_ckpt_wrapper.state_dict()[param],
                             opt_state_dict[param],
                             self._device,
@@ -494,6 +495,7 @@ def _setup_optimizer(
             optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
             if opt_state_dict:
                 training.load_from_full_optimizer_state_dict(
+                    self._model,
                     optimizer,
                     opt_state_dict,
                     self._device,
@@ -602,6 +604,7 @@ def save_checkpoint(
                 log.info("Getting optimizer state dict...")
             if not self._optimizer_in_bwd:
                 opt_state_dict = training.get_full_optimizer_state_dict(
+                    self._model,
                     self._optimizer,
                     self._is_rank_zero,
                     device=self._device,
@@ -610,7 +613,7 @@ def save_checkpoint(
                 opt_state_dict = {}
                 for param, opt in self._optim_ckpt_wrapper.optim_map.items():
                     opt_state_dict[param] = training.get_full_optimizer_state_dict(
-                        opt, self._is_rank_zero, device=self._device
+                        self._model, opt, self._is_rank_zero, device=self._device
                     )
             if self._is_rank_zero:
                 log.info(
diff --git a/recipes/lora_dpo_distributed.py b/recipes/lora_dpo_distributed.py
index 1ab88deaf8..23fcf66a60 100644
--- a/recipes/lora_dpo_distributed.py
+++ b/recipes/lora_dpo_distributed.py
@@ -244,6 +244,7 @@ def setup(self, cfg: DictConfig) -> None:
         self._tokenizer = config.instantiate(cfg.tokenizer)

         self._optimizer = self._setup_optimizer(
+            model=self._model,
             cfg_optimizer=cfg.optimizer,
             opt_state_dict=(
                 checkpoint_dict[training.OPT_KEY]
@@ -409,11 +410,15 @@ def _setup_model(
         return model

     def _setup_optimizer(
-        self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None
+        self,
+        model: nn.Module,
+        cfg_optimizer: DictConfig,
+        opt_state_dict: Optional[Dict[str, Any]] = None,
     ) -> Optimizer:
         optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
         if opt_state_dict:
             training.load_from_full_optimizer_state_dict(
+                model,
                 optimizer,
                 opt_state_dict,
                 self._device,
@@ -511,6 +516,7 @@ def save_checkpoint(
         )
         if intermediate_checkpoint:
             opt_state_dict = training.get_full_optimizer_state_dict(
+                self._model,
                 self._optimizer,
                 self._is_rank_zero,
                 device=self._device,
diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py
index 6f760dd16b..6c01675176 100644
--- a/recipes/lora_finetune_distributed.py
+++ b/recipes/lora_finetune_distributed.py
@@ -275,6 +275,7 @@ def setup(self, cfg: DictConfig) -> None:
         self._tokenizer = config.instantiate(cfg.tokenizer)

         self._optimizer = self._setup_optimizer(
+            model=self._model,
             cfg_optimizer=cfg.optimizer,
             opt_state_dict=(
                 checkpoint_dict[training.OPT_KEY]
@@ -549,11 +550,15 @@ def _setup_model(
         return model

     def _setup_optimizer(
-        self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None
+        self,
+        model: nn.Module,
+        cfg_optimizer: DictConfig,
+        opt_state_dict: Optional[Dict[str, Any]] = None,
     ) -> Optimizer:
         optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
         if opt_state_dict:
             training.load_from_full_optimizer_state_dict(
+                model,
                 optimizer,
                 opt_state_dict,
                 self._device,
@@ -679,6 +684,7 @@ def save_checkpoint(
             if self._is_rank_zero:
                 log.info("Retrieving optimizer state dict...")
             opt_state_dict = training.get_full_optimizer_state_dict(
+                self._model,
                 self._optimizer,
                 self._is_rank_zero,
                 device=self._device,
diff --git a/recipes/qat_distributed.py b/recipes/qat_distributed.py
index afb8e8d0e8..f05908f90f 100644
--- a/recipes/qat_distributed.py
+++ b/recipes/qat_distributed.py
@@ -238,6 +238,7 @@ def setup(self, cfg: DictConfig) -> None:
         self._tokenizer = config.instantiate(cfg.tokenizer)

         self._optimizer = self._setup_optimizer(
+            model=self._model,
             cfg_optimizer=cfg.optimizer,
             opt_state_dict=(
                 checkpoint_dict[training.OPT_KEY]
@@ -470,11 +471,15 @@ def _setup_model(
         return model

     def _setup_optimizer(
-        self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None
+        self,
+        model: nn.Module,
+        cfg_optimizer: DictConfig,
+        opt_state_dict: Optional[Dict[str, Any]] = None,
     ) -> Optimizer:
         optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
         if opt_state_dict:
             training.load_from_full_optimizer_state_dict(
+                model,
                 optimizer,
                 opt_state_dict,
                 self._device,
@@ -562,6 +567,7 @@ def save_checkpoint(

         if intermediate_checkpoint:
             opt_state_dict = training.get_full_optimizer_state_dict(
+                self._model,
                 self._optimizer,
                 self._is_rank_zero,
             )
diff --git a/tests/torchtune/training/test_distributed.py b/tests/torchtune/training/test_distributed.py
index 1f4b92b4de..1397dc6067 100644
--- a/tests/torchtune/training/test_distributed.py
+++ b/tests/torchtune/training/test_distributed.py
@@ -316,6 +316,7 @@ def test_lora_state_dict(self):
             fsdp_model_to_save, is_rank_zero
         )
         optim_full_sd = training.get_full_optimizer_state_dict(
+            fsdp_model_to_save,
             fsdp_optim_to_save,
             is_rank_zero,
         )
@@ -371,6 +372,7 @@ def test_lora_state_dict(self):
             fsdp_model_to_load.parameters(), weight_decay=0.01, lr=0.01
         )
         training.load_from_full_optimizer_state_dict(
+            fsdp_model_to_load,
             fsdp_optim_to_load,
             # mimic mmap=True where every rank see full SD
             copy.deepcopy(self._broadcast_full_state_dict(optim_full_sd)),
diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py
index aa68364f88..a4dd1db2bc 100644
--- a/torchtune/training/_distributed.py
+++ b/torchtune/training/_distributed.py
@@ -297,10 +297,18 @@ def load_from_full_model_state_dict(
     """
     meta_sharded_sd = model.state_dict()
     sharded_sd = {}
-    for param_name, full_tensor in full_sd.items():
+    contain_nf4 = False
+    for param_name in full_sd.keys():
         sharded_meta_param = meta_sharded_sd.get(param_name)
-        full_tensor = full_tensor.to(sharded_meta_param.dtype).to(device)
+        full_sd[param_name] = (
+            full_sd[param_name].to(sharded_meta_param.dtype).to(device)
+        )
+        # assumme that all the tensors are NF4Tensor or not, and process accordingly
         if isinstance(sharded_meta_param._local_tensor, NF4Tensor):
+            contain_nf4 = True
+
+    if contain_nf4:
+        for param_name, full_tensor in full_sd.items():
             full_tensor = to_nf4(full_tensor)
             # replicating logic from `_fsdp_param.py`` `_init_sharded_param`
             # otherwise `distribute_tensor(DTensor(local=NF4))`
@@ -332,18 +340,18 @@ def load_from_full_model_state_dict(
                 ),
                 requires_grad=sharded_meta_param.requires_grad,
             )
-
-        else:
-            sharded_tensor = distribute_tensor(
-                full_tensor,
-                sharded_meta_param.device_mesh,
-                sharded_meta_param.placements,
-            )
-        if cpu_offload:
-            sharded_tensor = sharded_tensor.cpu()
-        sharded_sd[param_name] = nn.Parameter(sharded_tensor)
-    # choose `assign=True` since we cannot call `copy_` on meta tensor
-    return model.load_state_dict(sharded_sd, strict=strict, assign=True)
+            if cpu_offload:
+                sharded_tensor = sharded_tensor.cpu()
+            sharded_sd[param_name] = nn.Parameter(sharded_tensor)
+        return model.load_state_dict(sharded_sd, strict=strict, assign=True)
+    else:
+        # choose `assign=True` since we cannot call `copy_` on meta tensor
+        options = torch.distributed.checkpoint.state_dict.StateDictOptions(
+            full_state_dict=True, broadcast_from_rank0=True, strict=strict
+        )
+        torch.distributed.checkpoint.state_dict.set_model_state_dict(
+            model=model, model_state_dict=full_sd, options=options
+        )


 def get_full_model_state_dict(
@@ -432,6 +440,7 @@ def get_full_model_state_dict(


 def get_full_optimizer_state_dict(
+    model: "FSDPModule",  # noqa
     opt: Optimizer,
     is_rank_zero: bool,
     device: Optional[torch.device] = None,
@@ -481,6 +490,7 @@ def get_full_optimizer_state_dict(


 def load_from_full_optimizer_state_dict(
+    model: "FSDPModule",  # noqa
     opt: Optimizer,
     full_sd: Dict[str, Any],
     device: torch.device,
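Note on the change in patch 1 above: both optimizer-state helpers in torchtune.training now take the FSDP-wrapped model as their first argument, and each recipe threads `self._model` (or the new `model` parameter of `_setup_optimizer`) through to them. A minimal sketch of the resulting call pattern follows; the surrounding recipe wiring (checkpointer, device setup) is assumed rather than shown.

    # Sketch only: assumes `model` is the FSDP2-sharded module and `optimizer` its optimizer,
    # mirroring how the recipes call the helpers after this patch.
    from torchtune import training

    def save_opt_state(model, optimizer, is_rank_zero, device):
        # Full (unsharded) optimizer state dict; rank zero receives the materialized copy.
        return training.get_full_optimizer_state_dict(
            model, optimizer, is_rank_zero, device=device
        )

    def restore_opt_state(model, optimizer, full_opt_sd, device):
        # Load a full optimizer state dict back into the sharded optimizer.
        training.load_from_full_optimizer_state_dict(model, optimizer, full_opt_sd, device)
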
From d4713f8aea05efaf7d42371f40e908a984ffb2c4 Mon Sep 17 00:00:00 2001
From: mori360
Date: Wed, 30 Oct 2024 16:12:01 -0700
Subject: [PATCH 2/5] get has_nf4

---
 torchtune/training/_distributed.py | 32 ++++++++++--------------------
 1 file changed, 11 insertions(+), 21 deletions(-)

diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py
index a4dd1db2bc..8a4b0db5d0 100644
--- a/torchtune/training/_distributed.py
+++ b/torchtune/training/_distributed.py
@@ -297,17 +297,16 @@ def load_from_full_model_state_dict(
     """
     meta_sharded_sd = model.state_dict()
     sharded_sd = {}
-    contain_nf4 = False
+    has_nf4 = any(
+        isinstance(param._local_tensor, NF4Tensor) for param in model.parameters()
+    )
     for param_name in full_sd.keys():
         sharded_meta_param = meta_sharded_sd.get(param_name)
         full_sd[param_name] = (
             full_sd[param_name].to(sharded_meta_param.dtype).to(device)
         )
-        # assumme that all the tensors are NF4Tensor or not, and process accordingly
-        if isinstance(sharded_meta_param._local_tensor, NF4Tensor):
-            contain_nf4 = True

-    if contain_nf4:
+    if has_nf4:
         for param_name, full_tensor in full_sd.items():
             full_tensor = to_nf4(full_tensor)
             # replicating logic from `_fsdp_param.py`` `_init_sharded_param`
             # otherwise `distribute_tensor(DTensor(local=NF4))`
@@ -343,9 +342,9 @@ def load_from_full_model_state_dict(
             if cpu_offload:
                 sharded_tensor = sharded_tensor.cpu()
             sharded_sd[param_name] = nn.Parameter(sharded_tensor)
+        # choose `assign=True` since we cannot call `copy_` on meta tensor
         return model.load_state_dict(sharded_sd, strict=strict, assign=True)
     else:
-        # choose `assign=True` since we cannot call `copy_` on meta tensor
         options = torch.distributed.checkpoint.state_dict.StateDictOptions(
             full_state_dict=True, broadcast_from_rank0=True, strict=strict
         )
@@ -421,21 +420,12 @@ def get_full_model_state_dict(
                 cpu_state_dict[full_fqn] = param.cpu()
             module.reshard()
     else:
-        for param_name, sharded_param in sharded_sd.items():
-            # without this, it may hang forever for +70B models.
-            torch.distributed.barrier()
-            if sharded_param.is_cpu:
-                assert device is not None and device.type == "cuda", (
-                    f"Expect cuda but got device={device}. "
-                    "Please call get_full_model_state_dict(..., device=self._device),"
-                    " so DTensor can communicate over NCCL."
-                )
-                sharded_param = sharded_param.to(device)
-            full_param = sharded_param.full_tensor()
-            if is_rank_zero:
-                cpu_state_dict[param_name] = full_param.cpu()
-            else:
-                del full_param
+        options = torch.distributed.checkpoint.state_dict.StateDictOptions(
+            full_state_dict=True, broadcast_from_rank0=True
+        )
+        cpu_state_dict = torch.distributed.checkpoint.state_dict.get_model_state_dict(
+            model=model, options=options
+        )

     return cpu_state_dict

From d680f47ea0025741b43d8b7f4be6d55051635a5c Mon Sep 17 00:00:00 2001
From: mori360
Date: Thu, 7 Nov 2024 19:25:30 -0800
Subject: [PATCH 3/5] mitigate get/load model state dict api

---
 recipes/full_finetune_distributed.py |  2 +-
 torchtune/training/_distributed.py   | 26 ++++++++++++------------
 2 files changed, 15 insertions(+), 13 deletions(-)

diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py
index 75b6f6503b..aa45417a04 100644
--- a/recipes/full_finetune_distributed.py
+++ b/recipes/full_finetune_distributed.py
@@ -830,7 +830,7 @@ def recipe_main(cfg: DictConfig) -> None:
         "Distributed finetune recipe should be run via a distributed launcher."
         "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
     )
-    init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
+    init_process_group(backend="cuda:nccl,cpu:gloo")
     if cfg.get("fsdp_cpu_offload", False):
         # Utilize all available CPU cores for intra-op parallelism. This provides ~2x
         # speed up when benchmarking fused AdamW on CPU
diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py
index 8a4b0db5d0..9c8acf03e4 100644
--- a/torchtune/training/_distributed.py
+++ b/torchtune/training/_distributed.py
@@ -20,7 +20,12 @@
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     _CHECKPOINT_WRAPPED_MODULE,
 )
-from torch.distributed.checkpoint.state_dict import _init_optim_state
+from torch.distributed.checkpoint.state_dict import (
+    _init_optim_state,
+    get_model_state_dict,
+    set_model_state_dict,
+    StateDictOptions,
+)
 from torch.distributed.fsdp import ShardingStrategy
 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
 from torch.optim import Optimizer
@@ -345,12 +350,13 @@ def load_from_full_model_state_dict(
         # choose `assign=True` since we cannot call `copy_` on meta tensor
         return model.load_state_dict(sharded_sd, strict=strict, assign=True)
     else:
-        options = torch.distributed.checkpoint.state_dict.StateDictOptions(
-            full_state_dict=True, broadcast_from_rank0=True, strict=strict
-        )
-        torch.distributed.checkpoint.state_dict.set_model_state_dict(
-            model=model, model_state_dict=full_sd, options=options
+        options = StateDictOptions(
+            full_state_dict=True,
+            broadcast_from_rank0=False,
+            strict=strict,
+            cpu_offload=cpu_offload,
         )
+        set_model_state_dict(model=model, model_state_dict=full_sd, options=options)


 def get_full_model_state_dict(
@@ -420,12 +426,8 @@ def get_full_model_state_dict(
             cpu_state_dict[full_fqn] = param.cpu()
         module.reshard()
     else:
-        options = torch.distributed.checkpoint.state_dict.StateDictOptions(
-            full_state_dict=True, broadcast_from_rank0=True
-        )
-        cpu_state_dict = torch.distributed.checkpoint.state_dict.get_model_state_dict(
-            model=model, options=options
-        )
+        options = StateDictOptions(full_state_dict=True, broadcast_from_rank0=True)
+        cpu_state_dict = get_model_state_dict(model=model, options=options)

     return cpu_state_dict
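For orientation before the next patch: patches 2 and 3 route the non-NF4 path through the distributed checkpoint (DCP) state-dict helpers instead of hand-rolled DTensor gathering, and switch the recipe to a combined per-device backend. A minimal sketch of those calls, mirroring the options used in the diffs above; the FSDP-wrapped `model` and a full state dict `full_sd` are assumed to come from the surrounding recipe.

    import torch.distributed as dist
    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        get_model_state_dict,
        set_model_state_dict,
    )

    # One process group serving both devices: NCCL for CUDA tensors, Gloo for CPU.
    dist.init_process_group(backend="cuda:nccl,cpu:gloo")

    # Load a full (unsharded) state dict into the sharded model, as in patch 3.
    load_opts = StateDictOptions(
        full_state_dict=True,
        broadcast_from_rank0=False,  # patch 3 keeps every rank reading its own copy
        strict=True,
        cpu_offload=False,
    )
    set_model_state_dict(model=model, model_state_dict=full_sd, options=load_opts)

    # Gather a full state dict back out, as in the get_full_model_state_dict change.
    save_opts = StateDictOptions(full_state_dict=True, broadcast_from_rank0=True)
    cpu_state_dict = get_model_state_dict(model=model, options=save_opts)
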
From 7ffa80b56a921d6283fe9db2f9d2369aa235d6e8 Mon Sep 17 00:00:00 2001
From: mori360
Date: Thu, 7 Nov 2024 20:27:51 -0800
Subject: [PATCH 4/5] mitigate get/load optimizer state dict api

---
 torchtune/training/_distributed.py | 81 +++---------------------------
 1 file changed, 8 insertions(+), 73 deletions(-)

diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py
index 9c8acf03e4..8b23927c6f 100644
--- a/torchtune/training/_distributed.py
+++ b/torchtune/training/_distributed.py
@@ -15,15 +15,16 @@
 from torch import nn
 from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard
-from torch.distributed._tensor import distribute_tensor, DTensor
+from torch.distributed._tensor import DTensor
 from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     _CHECKPOINT_WRAPPED_MODULE,
 )
 from torch.distributed.checkpoint.state_dict import (
-    _init_optim_state,
     get_model_state_dict,
+    get_optimizer_state_dict,
     set_model_state_dict,
+    set_optimizer_state_dict,
     StateDictOptions,
 )
 from torch.distributed.fsdp import ShardingStrategy
 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
 from torch.optim import Optimizer
@@ -443,42 +444,8 @@ def get_full_optimizer_state_dict(
     "exp_avg.full_tensor()" converts it to plain tensor on rank 0
     Returning non-empty cpu state dict on rank 0
     """
-    sharded_sd = opt.state_dict()
-    sharded_state = sharded_sd["state"]
-    full_state = {}
-    for group_id, sharded_group in sharded_state.items():
-        group_state = {}
-        for attr, sharded_tensor in sharded_group.items():
-            # without this, it may hang forever for +70B models.
-            torch.distributed.barrier()
-            # "exp_avg" in AdamW is `DTensor`
-            if isinstance(sharded_tensor, DTensor):
-                if sharded_tensor.is_cpu:
-                    assert device is not None and device.type == "cuda", (
-                        f"Expect cuda but got device={device}. "
-                        "Please call get_full_optimizer_state_dict(..., device=self._device),"
-                        " so DTensor can communicate over NCCL."
-                    )
-                    sharded_tensor = sharded_tensor.to(device)
-                full_tensor = sharded_tensor.full_tensor()
-            else:
-                # "step" in AdamW is plain tensor
-                full_tensor = sharded_tensor
-            if is_rank_zero:
-                group_state[attr] = full_tensor.cpu()
-            else:
-                del full_tensor
-        if is_rank_zero:
-            full_state[group_id] = group_state
-        else:
-            del group_state
-    if is_rank_zero:
-        return {
-            "param_groups": sharded_sd["param_groups"],
-            "state": full_state,
-        }
-    else:
-        return {}
+    options = StateDictOptions(full_state_dict=True, cpu_offload=True)
+    return get_optimizer_state_dict(model=model, optimizers=opt, options=options)


 def load_from_full_optimizer_state_dict(
@@ -491,41 +458,9 @@ def load_from_full_optimizer_state_dict(
     Converting full optimizer state to sharded state dict
     and loading it into optimizer
     """
-    PARAMS = "params"  # noqa: N806
-    _init_optim_state(opt)
-    param_groups = opt.state_dict()["param_groups"]
-    state = opt.state_dict()["state"]
-
-    full_param_groups = full_sd["param_groups"]
-    full_state = full_sd["state"]
-
-    for param_group, full_param_group in zip(param_groups, full_param_groups):
-        for key, value in full_param_group.items():
-            if key == PARAMS:
-                continue
-            param_group[key] = value
-        for pid, full_pid in zip(param_group[PARAMS], full_param_group[PARAMS]):
-            if pid not in state:
-                continue
-            param_state = state[pid]
-            full_param_state = full_state[full_pid]
-            for attr, full_tensor in full_param_state.items():
-                sharded_tensor = param_state[attr]
-                if isinstance(sharded_tensor, DTensor):
-                    # exp_avg is DTensor
-                    param_state[attr] = distribute_tensor(
-                        full_tensor,
-                        sharded_tensor.device_mesh,
-                        sharded_tensor.placements,
-                    )
-                else:
-                    # step is plain tensor
-                    param_state[attr] = full_tensor
-    opt.load_state_dict(
-        {
-            "param_groups": param_groups,
-            "state": state,
-        }
+    options = StateDictOptions(full_state_dict=True, broadcast_from_rank0=True)
+    set_optimizer_state_dict(
+        model=model, optimizers=opt, optim_state_dict=full_sd, options=options
     )

From 1ad30263c317e03a10c497d996e7490f603a9aad Mon Sep 17 00:00:00 2001
From: mori360
Date: Fri, 8 Nov 2024 09:58:36 -0800
Subject: [PATCH 5/5] lint error

---
 torchtune/training/_distributed.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py
index fefe0863c5..87a205a3c3 100644
--- a/torchtune/training/_distributed.py
+++ b/torchtune/training/_distributed.py
@@ -304,7 +304,7 @@ def load_from_full_model_state_dict(
     meta_sharded_sd = model.state_dict()
     sharded_sd = {}
     has_nf4 = any(
-        hasattr(param, "_local_tensor") and isinstance(param._local_tensor, NF4Tensor)
+        hasattr(param, "_local_tensor") and isinstance(param._local_tensor, NF4Tensor)
         for param in model.parameters()
     )
     for param_name in full_sd.keys():
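Taken together, patch 4 collapses the hand-written DTensor gather/scatter into two calls on the distributed-checkpoint API (patch 5 is a whitespace-only lint fix). A minimal sketch of the resulting optimizer-state round trip, mirroring the options in the diff above; the sharded `model`/`opt` pair and the saved `full_sd` are assumed to exist in the caller.

    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        get_optimizer_state_dict,
        set_optimizer_state_dict,
    )

    # Save path: full, CPU-offloaded optimizer state (plain tensors, no DTensors).
    save_opts = StateDictOptions(full_state_dict=True, cpu_offload=True)
    full_sd = get_optimizer_state_dict(model=model, optimizers=opt, options=save_opts)

    # Load path: broadcast the rank-0 copy and reshard it into each rank's optimizer.
    load_opts = StateDictOptions(full_state_dict=True, broadcast_from_rank0=True)
    set_optimizer_state_dict(
        model=model, optimizers=opt, optim_state_dict=full_sd, options=load_opts
    )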