[ckpt] Add async ckpt api #6136

Merged 12 commits on Nov 15, 2024
Changes from 9 commits
2 changes: 1 addition & 1 deletion .github/workflows/build_on_pr.yml
@@ -117,7 +117,7 @@ jobs:
           cd TensorNVMe
           conda install cmake
           pip install -r requirements.txt
-          DISABLE_URING=1 pip install -v .
+          DISABLE_URING=1 pip install -v --no-cache-dir .
 
       - name: Store TensorNVMe Cache
         run: |
1 change: 1 addition & 0 deletions colossalai/booster/booster.py
@@ -325,6 +325,7 @@ def save_model(
                 names to compose the keys in state_dict. Defaults to None.
             size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
             use_safetensors (bool, optional): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved in safetensors format.
+            use_async (bool, optional): whether to save the state_dict of model asynchronously. Default: False.
         """
         self.checkpoint_io.save_model(
             model,
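For orientation, here is a minimal usage sketch of the new flag through the public `Booster` API (the model/optimizer setup is illustrative; only `use_async=True` is new in this PR):

```python
import colossalai
import torch.nn as nn
from torch.optim import Adam

from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

colossalai.launch_from_torch()
booster = Booster(plugin=GeminiPlugin())

model = nn.Linear(8, 8)
optimizer = Adam(model.parameters())
model, optimizer, *_ = booster.boost(model, optimizer)

# New in this PR: queue the checkpoint write instead of blocking on it.
booster.save_model(model, "./ckpt", shard=True, use_async=True)
```

The async path registers its background writers on the `CheckpointIO` instance (`self.async_writers` in the hunks below); synchronizing them before exit is not shown in this excerpt.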
57 changes: 37 additions & 20 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -65,7 +65,14 @@ def __init__(self) -> None:
         self.coordinator = DistCoordinator()
         self.logger = get_dist_logger()
 
-    def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+    def save_unsharded_model(
+        self,
+        model: GeminiDDP,
+        checkpoint: str,
+        gather_dtensor: bool,
+        use_safetensors: bool,
+        use_async: bool = False,
+    ):
         """
         Save sharded model to checkpoint but only on master process.
         The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
@@ -74,7 +81,10 @@ def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor
         assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
         state_dict = model.state_dict(only_rank_0=True)
         if self.coordinator.is_master():
-            save_state_dict(state_dict, checkpoint, use_safetensors)
+            if use_async:
+                super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
+            else:
+                save_state_dict(state_dict, checkpoint, use_safetensors)
 
     def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
         """
@@ -112,6 +122,7 @@ def save_sharded_model(
         prefix: Optional[str] = None,
         max_shard_size: int = 1024,
         use_safetensors: bool = False,
+        use_async: bool = False,
     ):
         """
         Save sharded model.
@@ -130,27 +141,33 @@

         # Save shards of optimizer states.
         is_master = self.coordinator.is_master()
-        total_size = save_state_dict_shards(
-            sharded_state_dict=state_dict_shard,
-            checkpoint=checkpoint_path,
-            index_file=index_file,
-            base_filename=weights_name,
-            is_master=is_master,
-            use_safetensors=use_safetensors,
-        )
+        if use_async:
+            super().save_sharded_model(
+                model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors, use_async
+            )
 
-        # only save the index file on the master rank
-        if self.coordinator.is_master():
-            index_file.append_meta_data("total_size", total_size)
-            index_file.write_index_file(save_index_file)
-            save_config_file(model.unwrap(), checkpoint_path)
-            self.logger.info(
-                f"The model is split into checkpoint shards. "
-                f"You can find where each parameters has been saved in the "
-                f"index located at {save_index_file}.",
-                ranks=[0],
-            )
+        else:
+            total_size = save_state_dict_shards(
+                sharded_state_dict=state_dict_shard,
+                checkpoint=checkpoint_path,
+                index_file=index_file,
+                base_filename=weights_name,
+                is_master=is_master,
+                use_safetensors=use_safetensors,
+            )
+
+            # only save the index file on the master rank
+            if self.coordinator.is_master():
+                index_file.append_meta_data("total_size", total_size)
+                index_file.write_index_file(save_index_file)
+                save_config_file(model.unwrap(), checkpoint_path)
+                self.logger.info(
+                    f"The model is split into checkpoint shards. "
+                    f"You can find where each parameters has been saved in the "
+                    f"index located at {save_index_file}.",
+                    ranks=[0],
+                )
 
     def load_sharded_model(
         self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False
     ):
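Note the delegation pattern in both hunks above: when `use_async=True`, the Gemini implementation hands off to the `GeneralCheckpointIO` base class (shown later in `general_checkpoint_io.py`) rather than writing the locally built shards itself. Continuing the booster sketch from earlier, a hypothetical call that exercises the unsharded path:

```python
# `model` is assumed to be a booster-wrapped GeminiDDP instance (see the
# sketch above). Only the master rank performs the actual handoff.
booster.save_model(model, "model.safetensors", shard=False, use_async=True)
```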
8 changes: 7 additions & 1 deletion colossalai/booster/plugin/torch_fsdp_plugin.py
@@ -54,7 +54,13 @@ def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path
         sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model)
         optimizer.load_state_dict(sharded_osd)
 
-    def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+    def save_unsharded_model(
+        self,
+        model: ModelWrapper,
+        checkpoint: str,
+        gather_dtensor: bool,
+        use_safetensors: bool,
+    ):
         """
         Save model to checkpoint but only on master process.
         """
4 changes: 2 additions & 2 deletions colossalai/checkpoint_io/checkpoint_io_base.py
@@ -176,10 +176,10 @@ def save_model(

         if shard:
             self.save_sharded_model(
-                model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors, use_async=use_async
+                model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors, use_async
             )
         else:
-            self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)
+            self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
 
     def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024):
         """
2 changes: 1 addition & 1 deletion colossalai/checkpoint_io/general_checkpoint_io.py
@@ -61,8 +61,8 @@ def save_unsharded_model(
                 self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
             self.async_writers.append(writer)
             move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
-        else:
 
+        else:
             # save the checkpoint
             save_state_dict(state_dict, checkpoint, use_safetensors)

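The three lines above are the heart of the async path: tensors are copied into cached page-locked (pinned) host buffers, and an `AsyncFileWriter` then drains them to disk in the background. A self-contained sketch of that technique, using only a plain thread in place of TensorNVMe (illustrative, not the library's implementation):

```python
import threading

import torch


def async_save(state_dict: dict, path: str, pinned: dict) -> threading.Thread:
    """Copy tensors into pinned CPU buffers, then write them in the background."""
    # 1) Device-to-host copies into page-locked buffers; pinned memory makes the
    #    copy fast and non-blocking, so the training loop can move on quickly.
    #    (Assumes a CUDA build of PyTorch; pin_memory needs one.)
    for name, tensor in state_dict.items():
        if name not in pinned:
            pinned[name] = torch.empty(tensor.shape, dtype=tensor.dtype, pin_memory=True)
        pinned[name].copy_(tensor, non_blocking=True)
    torch.cuda.synchronize()  # copies must finish before the writer reads the buffers

    # 2) Hand the stable pinned copy to a background thread for the disk write.
    writer = threading.Thread(target=torch.save, args=(pinned, path))
    writer.start()
    return writer  # caller must join() before trusting the file on disk
```

Caching the pinned dict per model (keyed by `id(model)` above) means the expensive pinned allocation happens once, and subsequent saves only pay for the copy.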
77 changes: 55 additions & 22 deletions colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -27,6 +27,8 @@
 from .index_file import CheckpointIndexFile
 from .utils import (
     StateDictSharder,
+    async_save_state_dict_shards,
+    create_pinned_state_dict,
     gather_distributed_param,
     get_model_base_filenames,
     get_optimizer_base_filenames,
@@ -177,6 +179,7 @@ def save_sharded_model(
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
         use_safetensors: bool = False,
+        use_async: bool = False,
     ) -> None:
         """
         Save sharded model checkpoint under the given checkpointing path.
@@ -194,6 +197,7 @@
             prefix (str, optional): Prefix of file to save. Defaults to None.
             size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
             use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
+            use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False.
         """
 
         assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
@@ -219,24 +223,27 @@

         if self.pp_size == 1:
             # When pipeline is not used, save the model shards as in general checkpointIO
-            total_size = save_state_dict_shards(
-                sharded_state_dict=state_dict_shard,
-                checkpoint=checkpoint,
-                index_file=index_file,
-                base_filename=weights_name,
-                is_master=control_saving,
-                use_safetensors=use_safetensors,
-            )
-            if control_saving:
-                index_file.append_meta_data("total_size", total_size)
-                index_file.write_index_file(save_index_file)
-                save_config_file(model, checkpoint)
-                if self.verbose and self.coordinator.is_master():
-                    logging.info(
-                        f"The model is split into checkpoint shards. "
-                        f"You can find where each parameters has been saved in the "
-                        f"index located at {save_index_file}."
-                    )
+            if use_async:
+                super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)
+            else:
+                total_size = save_state_dict_shards(
+                    sharded_state_dict=state_dict_shard,
+                    checkpoint=checkpoint,
+                    index_file=index_file,
+                    base_filename=weights_name,
+                    is_master=control_saving,
+                    use_safetensors=use_safetensors,
+                )
+                if control_saving:
+                    index_file.append_meta_data("total_size", total_size)
+                    index_file.write_index_file(save_index_file)
+                    save_config_file(model, checkpoint)
+                    if self.verbose and self.coordinator.is_master():
+                        logging.info(
+                            f"The model is split into checkpoint shards. "
+                            f"You can find where each parameters has been saved in the "
+                            f"index located at {save_index_file}."
+                        )
 
         else:
             # When pipeline is used, each stage produces its own shard files and index files.
@@ -251,7 +258,16 @@
             weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors")
             save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
             save_index_file = os.path.join("tmp_index_files", save_index_file)
-
+            if use_async:
+                total_size, returned_state_dict, writers = async_save_state_dict_shards(
+                    sharded_state_dict=state_dict_shard,
+                    checkpoint=checkpoint,
+                    index_file=index_file,
+                    base_filename=weights_name,
+                    is_master=control_saving,
+                    use_pp_format=True,
+                    n_write_entries=191,
+                )
             total_size = save_state_dict_shards(
                 sharded_state_dict=state_dict_shard,
                 checkpoint=checkpoint,
@@ -626,7 +642,9 @@ def _get_param_id_from_optimizer_param(
         if self.verbose and self.coordinator.is_master():
             logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
 
-    def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+    def save_unsharded_model(
+        self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
+    ):
         """
         Save model state dict to a single file with given checkpointing path.
 
@@ -635,6 +653,7 @@ def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dten
             checkpoint (str): Checkpointing path which should be a file path. Can be absolute or relative path.
             gather_dtensor (bool, optional): Whether to gather dtensor, currently not used. Defaults to True.
             use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
+            use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False.
         """
         if self.coordinator.is_master():
             logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
@@ -651,7 +670,10 @@
         if self.pp_size == 1:
             # When pipeline is not used, let master rank directly save the collected state_dict.
             if self.tp_rank == 0:
-                save_state_dict(state_dict, checkpoint, use_safetensors)
+                if use_async:
+                    super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
+                else:
+                    save_state_dict(state_dict, checkpoint, use_safetensors)
         else:
             # When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
             state_dict_list = [None for _ in range(self.pp_size)]
@@ -662,7 +684,18 @@
             complete_state_dict = dict()
             for _state_dict in state_dict_list:
                 complete_state_dict.update(_state_dict)
-            save_state_dict(complete_state_dict, checkpoint, use_safetensors)
+            if use_async:
+                from tensornvme.async_file_io import AsyncFileWriter
+
+                from colossalai.utils.safetensors import move_and_save
+
+                writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread")
+                if id(model) not in self.pinned_state_dicts:
+                    self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
+                self.async_writers.append(writer)
+                move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
+            else:
+                save_state_dict(complete_state_dict, checkpoint, use_safetensors)
 
     def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = False):
         """
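Two details in the pipeline branch above are worth noting. First, assembling the full checkpoint is an object all-gather over the pipeline group; in isolation the pattern looks like this (a sketch; `pp_group` stands in for the IO object's pipeline process group):

```python
import torch.distributed as dist


def gather_full_state_dict(local_state_dict: dict, pp_group) -> dict:
    """Merge per-stage state dicts; pipeline stages own disjoint keys."""
    gathered = [None] * dist.get_world_size(group=pp_group)
    dist.all_gather_object(gathered, local_state_dict, group=pp_group)
    full = {}
    for part in gathered:
        full.update(part)
    return full
```

Second, the async branch passes the stage-local `state_dict` to `move_and_save` while the synchronous branch writes the assembled `complete_state_dict`; reviewers may want to confirm that asymmetry is intentional.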
52 changes: 33 additions & 19 deletions colossalai/checkpoint_io/moe_checkpoint.py
@@ -117,6 +117,7 @@ def save_sharded_model(
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
         use_safetensors: bool = False,
+        use_async: bool = False,
     ) -> None:
         """
         Save sharded model checkpoint under the given checkpointing path.
@@ -161,24 +162,27 @@

         if self.pp_size == 1 and self.ep_size == 1:
             # When pipeline is not used, save the model shards as in general checkpointIO
-            total_size = save_state_dict_shards(
-                sharded_state_dict=state_dict_shard,
-                checkpoint=checkpoint,
-                index_file=index_file,
-                base_filename=weights_name,
-                is_master=control_saving,
-                use_safetensors=use_safetensors,
-            )
-            if control_saving:
-                index_file.append_meta_data("total_size", total_size)
-                index_file.write_index_file(save_index_file)
-                save_config_file(model, checkpoint)
-                if self.verbose and self.coordinator.is_master():
-                    logging.info(
-                        f"The model is split into checkpoint shards. "
-                        f"You can find where each parameters has been saved in the "
-                        f"index located at {save_index_file}."
-                    )
+            if use_async:
+                super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)
+            else:
+                total_size = save_state_dict_shards(
+                    sharded_state_dict=state_dict_shard,
+                    checkpoint=checkpoint,
+                    index_file=index_file,
+                    base_filename=weights_name,
+                    is_master=control_saving,
+                    use_safetensors=use_safetensors,
+                )
+                if control_saving:
+                    index_file.append_meta_data("total_size", total_size)
+                    index_file.write_index_file(save_index_file)
+                    save_config_file(model, checkpoint)
+                    if self.verbose and self.coordinator.is_master():
+                        logging.info(
+                            f"The model is split into checkpoint shards. "
+                            f"You can find where each parameters has been saved in the "
+                            f"index located at {save_index_file}."
+                        )
 
             dist.barrier()
         else:
@@ -708,10 +712,20 @@ def save_unsharded_model(
         checkpoint: str,
         gather_dtensor: bool,
         use_safetensors: bool,
+        use_async: bool = False,
     ):
         state_dict = self.pre_save_model(model)
         if dist.get_rank() == 0:
-            torch.save(state_dict, checkpoint)
+            if use_async:
+                super().save_unsharded_model(
+                    model=model,
+                    checkpoint=checkpoint,
+                    gather_dtensor=gather_dtensor,
+                    use_safetensors=use_safetensors,
+                    use_async=use_async,
+                )
+            else:
+                torch.save(state_dict, checkpoint)
         dist.barrier()
 
     # Copied from colossalai.moe
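One caveat with the MoE path above: `dist.barrier()` runs as soon as the async write has been queued, so it synchronizes ranks but not the outstanding file write. Draining the writers might look like the sketch below (`synchronize()` is assumed from TensorNVMe's `AsyncFileWriter`; verify against the actual API):

```python
def finish_async_saves(checkpoint_io) -> None:
    """Block until every queued checkpoint write has hit disk."""
    for writer in checkpoint_io.async_writers:  # populated by the async branches above
        writer.synchronize()  # assumed AsyncFileWriter method; confirm against TensorNVMe
    checkpoint_io.async_writers.clear()
```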
13 changes: 5 additions & 8 deletions colossalai/checkpoint_io/utils.py
@@ -371,7 +371,11 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
 # ======================================
 
 
-def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None:
+def save_state_dict(
+    state_dict: dict,
+    checkpoint_file_path: str,
+    use_safetensors: bool,
+) -> None:
     """
     Save state dict to checkpoint.
 
@@ -581,14 +585,7 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
         raise Exception("load the model using `safetensors`, but no file endwith .safetensors")
     if use_safetensors:
         from safetensors.torch import load_file as safe_load_file
-        from safetensors.torch import safe_open
 
-        with safe_open(checkpoint_file, framework="pt") as f:
-            metadata = f.metadata()
-            if metadata["format"] != "pt":
-                raise NotImplementedError(
-                    f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet."
-                )
         return safe_load_file(checkpoint_file)
     else:
         return torch.load(checkpoint_file, map_location=torch.device("cpu"))
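With the `safe_open` metadata probe removed, `load_shard_state_dict` now loads any `.safetensors` file directly instead of first rejecting archives whose `format` metadata is not `"pt"`. A round-trip sketch of the two helpers in this file (hypothetical filename):

```python
from pathlib import Path

import torch

from colossalai.checkpoint_io.utils import load_shard_state_dict, save_state_dict

state_dict = {"weight": torch.randn(4, 4)}
save_state_dict(state_dict, "shard-00001.safetensors", use_safetensors=True)

loaded = load_shard_state_dict(Path("shard-00001.safetensors"), use_safetensors=True)
assert torch.equal(loaded["weight"], state_dict["weight"])
```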