Commit d0e2baa

support general param_id
flybird11111 committed Dec 17, 2024
1 parent 3bafe2a commit d0e2baa
Showing 8 changed files with 162 additions and 42 deletions.
2 changes: 1 addition & 1 deletion colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1488,7 +1488,7 @@ def seed_worker(worker_id):
)

def get_checkpoint_io(self) -> CheckpointIO:
return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.sp_group, self.zero_stage)

def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]:
assert (
2 changes: 1 addition & 1 deletion colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -404,7 +404,7 @@ def __init__(

def get_checkpoint_io(self) -> MoECheckpointIO:
return MoECheckpointIO(
self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
self.dp_group, self.pp_group, self.tp_group, self.sp_group, self.ep_group, self.moe_dp_group, self.zero_stage
)

def configure(
176 changes: 145 additions & 31 deletions colossalai/booster/plugin/torch_fsdp_plugin.py
@@ -1,6 +1,5 @@
import os
from pathlib import Path
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Any

import torch
import torch.nn as nn
@@ -52,10 +51,40 @@ def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: boo
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path):
assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before loading!"
if checkpoint.endswith(".safetensors"):
checkpoint = load_flat(checkpoint, seperator="-")
checkpoint = load_flat(checkpoint, seperator=".")
else:
checkpoint = utils.load_state_dict(checkpoint)

fsdp_model = optimizer.unwrap_model()
full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=False)
start_index = 0
id2name = {}
def get_index_mapping(group: Dict[str, Any]) -> Dict[str, Any]:
nonlocal start_index
start_num = len(id2name)
id2name.update(
{
i: p
for i, p in enumerate(group["params"], start_index)
if i not in id2name
}
)
end_num = len(id2name)
start_index += end_num - start_num

for g in full_optimizer_state["param_groups"]:
get_index_mapping(g)

new_state = {}
for key, value in checkpoint["state"].items():
new_state[id2name[int(key)]] = value
checkpoint["state"] = new_state
for g in checkpoint["param_groups"]:
new_group = []
for param_id in g["params"]:
new_group.append(id2name[param_id])
g["params"] = new_group

sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model)
optimizer.load_state_dict(sharded_osd)

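To make the remapping in this hunk concrete, here is a small, self-contained sketch with made-up parameter names (not code from the repository): it rebuilds the id-to-name table from FSDP-style param groups and re-keys an integer-indexed checkpoint by fully qualified names.

from typing import Any, Dict, List

def build_id2name(param_groups: List[Dict[str, Any]]) -> Dict[int, str]:
    # Assign consecutive indices to parameters in the order they appear in
    # the param groups returned by FSDP.full_optim_state_dict (FQN strings).
    id2name: Dict[int, str] = {}
    start_index = 0
    for group in param_groups:
        before = len(id2name)
        id2name.update(
            {i: p for i, p in enumerate(group["params"], start_index) if i not in id2name}
        )
        start_index += len(id2name) - before
    return id2name

# Toy FSDP-style param groups (FQN-keyed) and a checkpoint saved with integer ids.
fsdp_groups = [{"lr": 1e-3, "params": ["net.0.weight", "net.0.bias"]}]
ckpt = {
    "state": {"0": {"step": 3}, "1": {"step": 3}},
    "param_groups": [{"lr": 1e-3, "params": [0, 1]}],
}

id2name = build_id2name(fsdp_groups)   # {0: "net.0.weight", 1: "net.0.bias"}
ckpt["state"] = {id2name[int(k)]: v for k, v in ckpt["state"].items()}
for g in ckpt["param_groups"]:
    g["params"] = [id2name[i] for i in g["params"]]
# ckpt is now keyed by FQNs, which scatter_full_optim_state_dict expects.
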
@@ -70,18 +99,19 @@ def save_unsharded_model(
cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
full_model_state = model.state_dict()
if use_async:
from colossalai.utils.safetensors import save

if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(full_model_state)
for k, v in full_model_state.items():
self.pinned_state_dicts[id(model)][k].copy_(v)
full_model_state[k] = self.pinned_state_dicts[id(model)][k]
writer = save(checkpoint, full_model_state)
self.async_writers.append(writer)
else:
utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors)
if self.coordinator.is_master():
if use_async:
from colossalai.utils.safetensors import save

if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(full_model_state)
for k, v in full_model_state.items():
self.pinned_state_dicts[id(model)][k].copy_(v)
full_model_state[k] = self.pinned_state_dicts[id(model)][k]
writer = save(checkpoint, full_model_state)
self.async_writers.append(writer)
else:
utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors)

def save_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
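The unsharded model save above copies tensors into cached pinned host buffers and hands them to an asynchronous writer, and only the master rank touches the filesystem. A schematic analogue using plain torch and safetensors (not the ColossalAI helpers create_pinned_state_dict/save, whose internals differ) might look like:

import threading
from typing import Dict, Optional

import torch
from safetensors.torch import save_file

_pinned_buffers: Dict[str, torch.Tensor] = {}

def async_save(state_dict: Dict[str, torch.Tensor], path: str, is_master: bool) -> Optional[threading.Thread]:
    # Only one process per checkpoint writes; everyone else returns early.
    if not is_master:
        return None
    # Reuse pinned CPU buffers across saves so device-to-host copies can overlap.
    for name, tensor in state_dict.items():
        buf = _pinned_buffers.setdefault(
            name,
            torch.empty(tensor.shape, dtype=tensor.dtype, device="cpu",
                        pin_memory=torch.cuda.is_available()),
        )
        buf.copy_(tensor, non_blocking=True)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # make sure all device-to-host copies have landed
    writer = threading.Thread(target=save_file, args=(dict(_pinned_buffers), path))
    writer.start()
    return writer  # the caller keeps the handle and joins before exiting

In the real plugin the writer handles are collected in self.async_writers so that _sync_d2h()/_sync_io() can wait on all pending writes, as the updated test below does.
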
@@ -91,20 +121,48 @@ def save_unsharded_optimizer(
"""
assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!"
fsdp_model = optimizer.unwrap_model()

full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
if use_async:
from colossalai.utils.safetensors import _flatten_optim_state_dict, save

flatten_state_dict, metadata = _flatten_optim_state_dict(full_optimizer_state, seperator="-")
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict)
for k, v in flatten_state_dict.items():
self.pinned_state_dicts[id(optimizer)][k].copy_(v)
flatten_state_dict[k] = self.pinned_state_dicts[id(optimizer)][k]
writer = save(checkpoint, state_dict=flatten_state_dict, metadata=metadata)
self.async_writers.append(writer)
else:
utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)

if self.coordinator.is_master():

# Save order indices instead of Tensors
name2id: Dict[str, int] = {}
start_index = 0

def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
nonlocal start_index
packed = {k: v for k, v in group.items() if k != "params"}
name2id.update(
{
p: i
for i, p in enumerate(group["params"], start_index)
if p not in name2id
}
)
packed["params"] = [name2id[p] for p in group["params"]]
start_index += len(packed["params"])
return packed

param_groups = [pack_group(g) for g in full_optimizer_state["param_groups"]]
full_optimizer_state["param_groups"] = param_groups
new_state = {}
for key, value in full_optimizer_state["state"].items():
new_state[name2id[key]] = value
full_optimizer_state["state"] = new_state

if use_async:
from colossalai.utils.safetensors import _flatten_optim_state_dict, save
flatten_state_dict, metadata = _flatten_optim_state_dict(full_optimizer_state, seperator=".")
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict)
for k, v in flatten_state_dict.items():
self.pinned_state_dicts[id(optimizer)][k].copy_(v)
flatten_state_dict[k] = self.pinned_state_dicts[id(optimizer)][k]
writer = save(checkpoint, state_dict=flatten_state_dict, metadata=metadata)
self.async_writers.append(writer)
else:
utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)

def save_sharded_model(
self,
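pack_group mirrors what torch.optim.Optimizer.state_dict does: parameter references in "params" are replaced by consecutive integer indices and the per-parameter state is re-keyed to those indices, so the on-disk format no longer assumes any particular param_id scheme. A toy, self-contained version (names invented for illustration):

from typing import Any, Dict, List, Tuple

def pack(param_groups: List[Dict[str, Any]], state: Dict[str, Any]) -> Tuple[List[Dict[str, Any]], Dict[int, Any]]:
    # Map each FQN to a running index, then rewrite both param_groups and state.
    name2id: Dict[str, int] = {}
    start_index = 0
    packed_groups = []
    for group in param_groups:
        packed = {k: v for k, v in group.items() if k != "params"}
        name2id.update(
            {p: i for i, p in enumerate(group["params"], start_index) if p not in name2id}
        )
        packed["params"] = [name2id[p] for p in group["params"]]
        start_index += len(packed["params"])
        packed_groups.append(packed)
    packed_state = {name2id[name]: value for name, value in state.items()}
    return packed_groups, packed_state

groups = [{"lr": 1e-3, "params": ["net.0.weight", "net.0.bias"]}]
state = {"net.0.weight": {"step": 3}, "net.0.bias": {"step": 3}}
packed_groups, packed_state = pack(groups, state)
# packed_groups == [{"lr": 1e-3, "params": [0, 1]}]
# packed_state  == {0: {"step": 3}, 1: {"step": 3}}
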
@@ -150,7 +208,7 @@ def save_sharded_model(
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=True,
is_master=self.coordinator.is_master(),
)
self.async_writers.extend(writers)
else:
@@ -234,6 +292,32 @@ def save_sharded_optimizer(
)

if self.coordinator.is_master():

# Save order indices instead of Tensors
name2id: Dict[str, int] = {}
start_index = 0

def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
nonlocal start_index
packed = {k: v for k, v in group.items() if k != "params"}
name2id.update(
{
p: i
for i, p in enumerate(group["params"], start_index)
if p not in name2id
}
)
packed["params"] = [name2id[p] for p in group["params"]]
start_index += len(packed["params"])
return packed

param_groups = [pack_group(g) for g in fsdp_optim_state["param_groups"]]
fsdp_optim_state["param_groups"] = param_groups
new_state = {}
for key, value in fsdp_optim_state["state"].items():
new_state[name2id[key]] = value
fsdp_optim_state["state"] = new_state

# Preparing file paths and index file.
states_name, save_index_file, param_group_file = utils.get_optimizer_base_filenames(
prefix, use_safetensors=use_async
@@ -261,7 +345,7 @@ def save_sharded_optimizer(
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=True,
is_master=self.coordinator.is_master(),
state_preprocess=True,
)
self.async_writers.extend(writers)
@@ -306,13 +390,43 @@ def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, siz
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
for shard_file in checkpoint_files:
if shard_file.endswith(".safetensors"):
state_dict_shard = load_flat(shard_file, seperator="-")
state_dict_shard = load_flat(shard_file, seperator=".")
else:
state_dict_shard = utils.load_shard_state_dict(Path(shard_file), use_safetensors=False)
fsdp_optim_state.update(state_dict_shard)

fsdp_optim_dict = dict(state=fsdp_optim_state, param_groups=saved_param_groups)

fsdp_model = optimizer.unwrap_model()
full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model.unwrap(), optim=optimizer, rank0_only=False)
start_index = 0
id2name = {}
def get_index_mapping(group: Dict[str, Any]) -> Dict[str, Any]:
nonlocal start_index
start_num = len(id2name)
id2name.update(
{
i: p
for i, p in enumerate(group["params"], start_index)
if i not in id2name
}
)
end_num = len(id2name)
start_index += end_num - start_num

for g in full_optimizer_state["param_groups"]:
get_index_mapping(g)

new_state = {}
for key, value in fsdp_optim_dict["state"].items():
new_state[id2name[int(key)]] = value
fsdp_optim_dict["state"] = new_state
for g in fsdp_optim_dict["param_groups"]:
new_group = []
for param_id in g["params"]:
new_group.append(id2name[param_id])
g["params"] = new_group

with FSDP.state_dict_type(optimizer.unwrap_model().unwrap(), StateDictType.FULL_STATE_DICT):
fsdp_state = FSDP.optim_state_dict_to_load(
model=optimizer.unwrap_model().unwrap(), optim=optimizer, optim_state_dict=fsdp_optim_dict
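Taken together, the save-side packing and the load-side id2name lookup form an identity round trip: any checkpoint whose param ids are the consecutive indices derived from the param-group order maps back onto the same FQNs, which is what "support general param_id" buys. A compact illustration with toy names (not repository code):

# Packing at save time, unpacking at load time: the round trip is the identity
# as long as the model/optimizer construction (and hence the FQN order) matches.
names = ["net.0.weight", "net.0.bias", "net.2.weight"]
state_by_name = {n: {"step": 7} for n in names}

name2id = {n: i for i, n in enumerate(names)}            # save path
saved_state = {name2id[n]: s for n, s in state_by_name.items()}
saved_groups = [{"lr": 1e-3, "params": [name2id[n] for n in names]}]

id2name = {i: n for i, n in enumerate(names)}            # load path
restored_state = {id2name[i]: s for i, s in saved_state.items()}
restored_params = [id2name[i] for i in saved_groups[0]["params"]]
assert restored_state == state_by_name and restored_params == names
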
11 changes: 8 additions & 3 deletions colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -70,16 +70,19 @@ def __init__(
dp_group: ProcessGroup,
pp_group: ProcessGroup,
tp_group: ProcessGroup,
sp_group: ProcessGroup,
zero_stage: int,
verbose: bool = True,
) -> None:
super().__init__()
self.global_dp_group = dp_group
self.pp_group = pp_group
self.tp_group = tp_group
self.sp_group = sp_group
self.dp_rank = dist.get_rank(self.global_dp_group)
self.tp_rank = dist.get_rank(self.tp_group)
self.pp_rank = dist.get_rank(self.pp_group)
self.sp_rank = dist.get_rank(self.sp_group)
self.global_dp_size = dist.get_world_size(dp_group)
self.pp_size = dist.get_world_size(pp_group)
self.tp_size = dist.get_world_size(tp_group)
@@ -490,7 +493,7 @@ def save_sharded_optimizer(

# Then collect the sharded states along dp_group(if using zero)/tp_group.
# Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
control_saving = self.dp_rank == 0 and self.tp_rank == 0
control_saving = self.dp_rank == 0 and self.tp_rank == 0 and self.sp_rank == 0

if use_async and control_saving:
if id(optimizer) not in self.pinned_state_dicts:
@@ -560,8 +563,10 @@ def save_sharded_optimizer(
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)

# Manage filenames of sharded weights and index file for each pipeline stage.
states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
states_name = states_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors")
if not use_async:
states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
else:
states_name = states_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors")
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)

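Two small behavioural points in this file: with sequence parallelism in play, only the rank at position 0 along dp, tp and sp writes optimizer shards, and the per-stage shard name now matches the chosen on-disk format. Roughly, with placeholder rank coordinates and base filenames (in the real code the ranks come from dist.get_rank(...) and the names from get_optimizer_base_filenames(...)):

dp_rank, tp_rank, sp_rank, pp_rank = 0, 0, 0, 1   # hypothetical coordinates
use_async = True

control_saving = dp_rank == 0 and tp_rank == 0 and sp_rank == 0   # this rank writes

states_name = "optimizer.safetensors" if use_async else "optimizer.bin"
if not use_async:
    states_name = states_name.replace(".bin", f"-stage-{pp_rank+1:05d}-shard.bin")
else:
    states_name = states_name.replace(".safetensors", f"-stage-{pp_rank+1:05d}-shard.safetensors")
# -> "optimizer-stage-00002-shard.safetensors" for this hypothetical rank
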
3 changes: 2 additions & 1 deletion colossalai/checkpoint_io/moe_checkpoint.py
@@ -44,12 +44,13 @@ def __init__(
global_dp_group: ProcessGroup,
pp_group: ProcessGroup,
tp_group: ProcessGroup,
sp_group: ProcessGroup,
ep_group: ProcessGroup,
moe_dp_group: ProcessGroup,
zero_stage: int,
verbose: bool = True,
) -> None:
super().__init__(global_dp_group, pp_group, tp_group, zero_stage, verbose)
super().__init__(global_dp_group, pp_group, tp_group, sp_group, zero_stage, verbose)
self.global_dp_group = global_dp_group
self.global_dp_rank = dist.get_rank(global_dp_group)
self.global_dp_size = dist.get_world_size(global_dp_group)
2 changes: 1 addition & 1 deletion colossalai/checkpoint_io/utils.py
@@ -307,7 +307,7 @@ def async_save_state_dict_shards(
checkpoint_file_path = os.path.join(checkpoint, shard_file)

if state_preprocess:
state_dict, _ = _flatten_optim_state_dict(state_dict=shard, seperator="-")
state_dict, _ = _flatten_optim_state_dict(state_dict=shard, seperator=".")
else:
state_dict = shard

4 changes: 2 additions & 2 deletions colossalai/utils/safetensors.py
@@ -59,7 +59,7 @@ def _cast_to_object(tensor: torch.Tensor):
return _tensor_to_object(tensor, tensor.numel() * tensor.element_size())


def _flatten_optim_state_dict(state_dict: dict, seperator: str = "-") -> Tuple[dict, Optional[dict]]:
def _flatten_optim_state_dict(state_dict: dict, seperator: str = ".") -> Tuple[dict, Optional[dict]]:
flat_dict = {}
non_tensor_keys = []
if "state" in state_dict:
@@ -196,7 +196,7 @@ def move_and_save(
return f_writer


def load_flat(checkpoint_path, seperator: str = "-"):
def load_flat(checkpoint_path, seperator: str = "."):
with safe_open(checkpoint_path, framework="pt") as f:
metadata = f.metadata()
state_dict_load = load_file(checkpoint_path)
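The default flattened-key separator changes from "-" to "." in both the save and load helpers. As a purely illustrative sketch (this is not the colossalai implementation, only the general idea of flattening nested optimizer state into flat keys plus enough metadata to invert it; the keyword is spelled seperator to match the library's parameter name):

from typing import Dict, Tuple

def flatten(state: Dict[int, Dict[str, int]], seperator: str = ".") -> Tuple[Dict[str, int], Dict[str, str]]:
    # Store each per-parameter entry under "param_id<sep>state_name".
    flat: Dict[str, int] = {}
    for param_id, per_param in state.items():
        for name, value in per_param.items():
            flat[f"{param_id}{seperator}{name}"] = value
    return flat, {"seperator": seperator}

def unflatten(flat: Dict[str, int], metadata: Dict[str, str]) -> Dict[int, Dict[str, int]]:
    sep = metadata["seperator"]
    nested: Dict[int, Dict[str, int]] = {}
    for key, value in flat.items():
        param_id, name = key.rsplit(sep, 1)   # split on the last separator only
        nested.setdefault(int(param_id), {})[name] = value
    return nested

flat, meta = flatten({0: {"step": 3}, 1: {"step": 3}})
assert unflatten(flat, meta) == {0: {"step": 3}, 1: {"step": 3}}
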
4 changes: 2 additions & 2 deletions tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py
@@ -114,8 +114,8 @@ def run_model():

run_model()

booster.save_model(fsdp_model, model_ckpt_path, shard=True, use_async=False)
booster.save_optimizer(optimizer, optim_ckpt_path, shard=True, use_async=True)
booster.save_model(fsdp_model, model_ckpt_path, shard=True, use_async=use_async)
booster.save_optimizer(optimizer, optim_ckpt_path, shard=True, use_async=use_async)

booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()
