Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Async save state_dict to file #171

Merged
merged 2 commits into from
Sep 27, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions bmtrain/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from . import nccl
import io, pickle
from typing import Mapping
import threading

def _save_to_state_dict(model : torch.nn.Module, destination, prefix):
if isinstance(model, Block):
Expand Down Expand Up @@ -40,23 +41,40 @@ def _save_to_rank0(model : torch.nn.Module, destination=None, prefix=''):
return destination


def async_save_to_file(state_dict, file_path):
torch.save(state_dict, file_path)
config['finish_save'] = True
print("finish save state_dict to ", file_path)

def save(model : torch.nn.Module, file_name : str):
def save(model : torch.nn.Module, file_name : str, non_blocking : bool=True):
"""Saves the model to the file.

Similar to torch.save, but it used for distributed modules.

Args:
model (torch.nn.Module): The model to be saved.
file_name (str): The file name of the checkpoint.
non_blocking (bool): Whether to asynchronously save state_dict to file


Examples:
>>> bmtrain.save(model, "model.pt")
"""
torch.cuda.synchronize()
state_dict = _save_to_rank0(model)
if config["rank"] == 0:
torch.save(state_dict, file_name)
if non_blocking is False:
torch.save(state_dict, file_name)
else:
if 'finish_save' not in config:
config['finish_save'] = True

if config['finish_save'] is False:
config['save_thread'].join()

config['finish_save'] = False
config['save_thread'] = threading.Thread(target=async_save_to_file, args=(state_dict, file_name))
config['save_thread'].start()

DTYPE_LIST = [
torch.float64,
Expand Down