Migrate state dict API to DSD #1930

Draft · wants to merge 7 commits into base: main
7 changes: 5 additions & 2 deletions recipes/full_finetune_distributed.py
@@ -538,6 +538,7 @@ def _setup_optimizer(
for param in opt_state_dict.keys():
try:
training.load_from_full_optimizer_state_dict(
self._model,
self._optim_ckpt_wrapper.state_dict()[param],
opt_state_dict[param],
self._device,
@@ -554,6 +555,7 @@ def _setup_optimizer(
optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
if opt_state_dict:
training.load_from_full_optimizer_state_dict(
self._model,
optimizer,
opt_state_dict,
self._device,
@@ -662,6 +664,7 @@ def save_checkpoint(
log.info("Getting optimizer state dict...")
if not self._optimizer_in_bwd:
opt_state_dict = training.get_full_optimizer_state_dict(
self._model,
self._optimizer,
self._is_rank_zero,
device=self._device,
@@ -670,7 +673,7 @@
opt_state_dict = {}
for param, opt in self._optim_ckpt_wrapper.optim_map.items():
opt_state_dict[param] = training.get_full_optimizer_state_dict(
opt, self._is_rank_zero, device=self._device
self._model, opt, self._is_rank_zero, device=self._device
)
if self._is_rank_zero:
log.info(
@@ -900,7 +903,7 @@ def recipe_main(cfg: DictConfig) -> None:
"Distributed finetune recipe should be run via a distributed launcher."
"If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
)
init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
init_process_group(backend="cuda:nccl,cpu:gloo")
if cfg.get("fsdp_cpu_offload", False):
# Utilize all available CPU cores for intra-op parallelism. This provides ~2x
# speed up when benchmarking fused AdamW on CPU
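Besides threading `self._model` into the optimizer state-dict helpers, this recipe switches `init_process_group` to a per-device-type backend string. A minimal sketch of that initialization, assuming a standard torchrun-style launch (not the recipe's actual startup path):

import torch.distributed as dist

# "cuda:nccl,cpu:gloo" maps each device type to a backend, so GPU collectives
# run over NCCL while CPU tensors (e.g. when fsdp_cpu_offload is enabled) can
# still communicate over Gloo, without branching on cfg.device.
if not dist.is_initialized():
    dist.init_process_group(backend="cuda:nccl,cpu:gloo")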
8 changes: 7 additions & 1 deletion recipes/lora_dpo_distributed.py
@@ -244,6 +244,7 @@ def setup(self, cfg: DictConfig) -> None:
self._tokenizer = config.instantiate(cfg.tokenizer)

self._optimizer = self._setup_optimizer(
model=self._model,
cfg_optimizer=cfg.optimizer,
opt_state_dict=(
checkpoint_dict[training.OPT_KEY]
@@ -409,11 +410,15 @@ def _setup_model(
return model

def _setup_optimizer(
self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None
self,
model: nn.Module,
cfg_optimizer: DictConfig,
opt_state_dict: Optional[Dict[str, Any]] = None,
) -> Optimizer:
optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
if opt_state_dict:
training.load_from_full_optimizer_state_dict(
model,
optimizer,
opt_state_dict,
self._device,
@@ -511,6 +516,7 @@ def save_checkpoint(
)
if intermediate_checkpoint:
opt_state_dict = training.get_full_optimizer_state_dict(
self._model,
self._optimizer,
self._is_rank_zero,
device=self._device,
8 changes: 7 additions & 1 deletion recipes/lora_finetune_distributed.py
@@ -286,6 +286,7 @@ def setup(self, cfg: DictConfig) -> None:
self._tokenizer = config.instantiate(cfg.tokenizer)

self._optimizer = self._setup_optimizer(
model=self._model,
cfg_optimizer=cfg.optimizer,
opt_state_dict=(
checkpoint_dict[training.OPT_KEY]
@@ -549,11 +550,15 @@ def _setup_model(
return model

def _setup_optimizer(
self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None
self,
model: nn.Module,
cfg_optimizer: DictConfig,
opt_state_dict: Optional[Dict[str, Any]] = None,
) -> Optimizer:
optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
if opt_state_dict:
training.load_from_full_optimizer_state_dict(
model,
optimizer,
opt_state_dict,
self._device,
@@ -679,6 +684,7 @@ def save_checkpoint(
if self._is_rank_zero:
log.info("Retrieving optimizer state dict...")
opt_state_dict = training.get_full_optimizer_state_dict(
self._model,
self._optimizer,
self._is_rank_zero,
device=self._device,
8 changes: 7 additions & 1 deletion recipes/qat_distributed.py
@@ -238,6 +238,7 @@ def setup(self, cfg: DictConfig) -> None:
self._tokenizer = config.instantiate(cfg.tokenizer)

self._optimizer = self._setup_optimizer(
model=self._model,
cfg_optimizer=cfg.optimizer,
opt_state_dict=(
checkpoint_dict[training.OPT_KEY]
@@ -470,11 +471,15 @@ def _setup_model(
return model

def _setup_optimizer(
self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None
self,
model: nn.Module,
cfg_optimizer: DictConfig,
opt_state_dict: Optional[Dict[str, Any]] = None,
) -> Optimizer:
optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
if opt_state_dict:
training.load_from_full_optimizer_state_dict(
model,
optimizer,
opt_state_dict,
self._device,
Expand Down Expand Up @@ -562,6 +567,7 @@ def save_checkpoint(

if intermediate_checkpoint:
opt_state_dict = training.get_full_optimizer_state_dict(
self._model,
self._optimizer,
self._is_rank_zero,
)
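The `_setup_optimizer` change is the same across the LoRA DPO, LoRA finetune, and QAT recipes above: the method now accepts the FSDP-wrapped model and passes it to the state-dict helper. A condensed, free-standing sketch of that pattern (the function name and explicit arguments are illustrative; the real recipes keep these on `self`):

from typing import Any, Dict, Optional

import torch
from torch import nn
from torch.optim import Optimizer

from torchtune import training


def setup_optimizer(
    model: nn.Module,
    optimizer: Optimizer,
    device: torch.device,
    opt_state_dict: Optional[Dict[str, Any]] = None,
) -> Optimizer:
    # The model is needed so the DSD-backed helper can match entries in the
    # full optimizer state dict against the model's sharded parameters.
    if opt_state_dict:
        training.load_from_full_optimizer_state_dict(
            model, optimizer, opt_state_dict, device
        )
    return optimizer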
2 changes: 2 additions & 0 deletions tests/torchtune/training/test_distributed.py
@@ -316,6 +316,7 @@ def test_lora_state_dict(self):
fsdp_model_to_save, is_rank_zero
)
optim_full_sd = training.get_full_optimizer_state_dict(
fsdp_model_to_save,
fsdp_optim_to_save,
is_rank_zero,
)
@@ -371,6 +372,7 @@ def test_lora_state_dict(self):
fsdp_model_to_load.parameters(), weight_decay=0.01, lr=0.01
)
training.load_from_full_optimizer_state_dict(
fsdp_model_to_load,
fsdp_optim_to_load,
# mimic mmap=True where every rank see full SD
copy.deepcopy(self._broadcast_full_state_dict(optim_full_sd)),
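For context, the round trip this test exercises under the new signatures, condensed with placeholder names (`saved_model`/`saved_optim`, `loaded_model`/`loaded_optim`); the actual test also broadcasts the rank-0 state dict so every rank holds the full copy before loading:

# Placeholder-name sketch of the gather/load round trip tested above.
full_osd = training.get_full_optimizer_state_dict(
    saved_model, saved_optim, is_rank_zero
)
training.load_from_full_optimizer_state_dict(
    loaded_model, loaded_optim, full_osd, device
)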
146 changes: 41 additions & 105 deletions torchtune/training/_distributed.py
@@ -15,12 +15,18 @@
from torch import nn

from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard
from torch.distributed._tensor import distribute_tensor, DTensor
from torch.distributed._tensor import DTensor
from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
_CHECKPOINT_WRAPPED_MODULE,
)
from torch.distributed.checkpoint.state_dict import _init_optim_state
from torch.distributed.checkpoint.state_dict import (
get_model_state_dict,
get_optimizer_state_dict,
set_model_state_dict,
set_optimizer_state_dict,
StateDictOptions,
)
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.optim import Optimizer
@@ -297,12 +303,18 @@ def load_from_full_model_state_dict(
"""
meta_sharded_sd = model.state_dict()
sharded_sd = {}
for param_name, full_tensor in full_sd.items():
has_nf4 = any(
hasattr(param, "_local_tensor") and isinstance(param._local_tensor, NF4Tensor)
for param in model.parameters()
)
for param_name in full_sd.keys():
sharded_meta_param = meta_sharded_sd.get(param_name)
full_tensor = full_tensor.to(sharded_meta_param.dtype).to(device)
if hasattr(sharded_meta_param, "_local_tensor") and isinstance(
sharded_meta_param._local_tensor, NF4Tensor
):
full_sd[param_name] = (
full_sd[param_name].to(sharded_meta_param.dtype).to(device)
)

if has_nf4:
for param_name, full_tensor in full_sd.items():
full_tensor = to_nf4(full_tensor)
# replicating logic from `_fsdp_param.py`` `_init_sharded_param`
# otherwise `distribute_tensor(DTensor(local=NF4))`
@@ -334,18 +346,19 @@ def load_from_full_model_state_dict(
),
requires_grad=sharded_meta_param.requires_grad,
)

else:
sharded_tensor = distribute_tensor(
full_tensor,
sharded_meta_param.device_mesh,
sharded_meta_param.placements,
)
if cpu_offload:
sharded_tensor = sharded_tensor.cpu()
sharded_sd[param_name] = nn.Parameter(sharded_tensor)
# choose `assign=True` since we cannot call `copy_` on meta tensor
return model.load_state_dict(sharded_sd, strict=strict, assign=True)
if cpu_offload:
sharded_tensor = sharded_tensor.cpu()
sharded_sd[param_name] = nn.Parameter(sharded_tensor)
# choose `assign=True` since we cannot call `copy_` on meta tensor
return model.load_state_dict(sharded_sd, strict=strict, assign=True)
else:
options = StateDictOptions(
full_state_dict=True,
broadcast_from_rank0=False,
strict=strict,
cpu_offload=cpu_offload,
)
set_model_state_dict(model=model, model_state_dict=full_sd, options=options)


def get_full_model_state_dict(
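For reference, a minimal standalone sketch of the DSD load path that the non-NF4 branch above now delegates to, assuming `fsdp_model` is an FSDP2-sharded module and `full_sd` holds full (unsharded) tensors on every rank:

from torch.distributed.checkpoint.state_dict import (
    set_model_state_dict,
    StateDictOptions,
)

# full_state_dict=True tells DSD the incoming dict holds unsharded tensors;
# set_model_state_dict re-shards them onto each rank's DTensor placements.
options = StateDictOptions(
    full_state_dict=True,
    broadcast_from_rank0=False,  # every rank already has the full dict
    strict=True,
    cpu_offload=False,
)
set_model_state_dict(model=fsdp_model, model_state_dict=full_sd, options=options)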
@@ -415,25 +428,13 @@ def get_full_model_state_dict(
cpu_state_dict[full_fqn] = param.cpu()
module.reshard()
else:
for param_name, sharded_param in sharded_sd.items():
# without this, it may hang forever for +70B models.
torch.distributed.barrier()
if sharded_param.is_cpu:
assert device is not None and device.type == "cuda", (
f"Expect cuda but got device={device}. "
"Please call get_full_model_state_dict(..., device=self._device),"
" so DTensor can communicate over NCCL."
)
sharded_param = sharded_param.to(device)
full_param = sharded_param.full_tensor()
if is_rank_zero:
cpu_state_dict[param_name] = full_param.cpu()
else:
del full_param
options = StateDictOptions(full_state_dict=True, broadcast_from_rank0=True)
cpu_state_dict = get_model_state_dict(model=model, options=options)
return cpu_state_dict


def get_full_optimizer_state_dict(
model: "FSDPModule", # noqa
opt: Optimizer,
is_rank_zero: bool,
device: Optional[torch.device] = None,
@@ -444,45 +445,12 @@ def get_full_optimizer_state_dict(
"exp_avg.full_tensor()" converts it to plain tensor on rank 0
Returning non-empty cpu state dict on rank 0
"""
sharded_sd = opt.state_dict()
sharded_state = sharded_sd["state"]
full_state = {}
for group_id, sharded_group in sharded_state.items():
group_state = {}
for attr, sharded_tensor in sharded_group.items():
# without this, it may hang forever for +70B models.
torch.distributed.barrier()
# "exp_avg" in AdamW is `DTensor`
if isinstance(sharded_tensor, DTensor):
if sharded_tensor.is_cpu:
assert device is not None and device.type == "cuda", (
f"Expect cuda but got device={device}. "
"Please call get_full_optimizer_state_dict(..., device=self._device),"
" so DTensor can communicate over NCCL."
)
sharded_tensor = sharded_tensor.to(device)
full_tensor = sharded_tensor.full_tensor()
else:
# "step" in AdamW is plain tensor
full_tensor = sharded_tensor
if is_rank_zero:
group_state[attr] = full_tensor.cpu()
else:
del full_tensor
if is_rank_zero:
full_state[group_id] = group_state
else:
del group_state
if is_rank_zero:
return {
"param_groups": sharded_sd["param_groups"],
"state": full_state,
}
else:
return {}
options = StateDictOptions(full_state_dict=True, cpu_offload=True)
return get_optimizer_state_dict(model=model, optimizers=opt, options=options)


def load_from_full_optimizer_state_dict(
model: "FSDPModule", # noqa
opt: Optimizer,
full_sd: Dict[str, Any],
device: torch.device,
@@ -491,41 +459,9 @@
Converting full optimizer state to sharded state dict
and loading it into optimizer
"""
PARAMS = "params" # noqa: N806
_init_optim_state(opt)
param_groups = opt.state_dict()["param_groups"]
state = opt.state_dict()["state"]

full_param_groups = full_sd["param_groups"]
full_state = full_sd["state"]

for param_group, full_param_group in zip(param_groups, full_param_groups):
for key, value in full_param_group.items():
if key == PARAMS:
continue
param_group[key] = value
for pid, full_pid in zip(param_group[PARAMS], full_param_group[PARAMS]):
if pid not in state:
continue
param_state = state[pid]
full_param_state = full_state[full_pid]
for attr, full_tensor in full_param_state.items():
sharded_tensor = param_state[attr]
if isinstance(sharded_tensor, DTensor):
# exp_avg is DTensor
param_state[attr] = distribute_tensor(
full_tensor,
sharded_tensor.device_mesh,
sharded_tensor.placements,
)
else:
# step is plain tensor
param_state[attr] = full_tensor
opt.load_state_dict(
{
"param_groups": param_groups,
"state": state,
}
options = StateDictOptions(full_state_dict=True, broadcast_from_rank0=True)
set_optimizer_state_dict(
model=model, optimizers=opt, optim_state_dict=full_sd, options=options
)


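Taken together, the migration replaces the hand-rolled `full_tensor()` / `distribute_tensor` loops with PyTorch's distributed state dict (DSD) entry points. A minimal end-to-end sketch of the optimizer round trip on those APIs, with `model` and `optim` standing in for an FSDP2-sharded module and its optimizer:

from torch.distributed.checkpoint.state_dict import (
    get_optimizer_state_dict,
    set_optimizer_state_dict,
    StateDictOptions,
)

# Save: gather a full, CPU-resident optimizer state dict. With full_state_dict=True
# and cpu_offload=True, rank 0 receives the populated dict and other ranks an empty one.
save_options = StateDictOptions(full_state_dict=True, cpu_offload=True)
full_osd = get_optimizer_state_dict(model=model, optimizers=optim, options=save_options)

# Load: rank 0 supplies the full dict; broadcast_from_rank0=True fans the values out
# to the other ranks and re-shards them to match the model's parameters.
load_options = StateDictOptions(full_state_dict=True, broadcast_from_rank0=True)
set_optimizer_state_dict(
    model=model, optimizers=optim, optim_state_dict=full_osd, options=load_options
)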