diff --git a/bmtrain/store.py b/bmtrain/store.py index 88ed7305..ae1101f1 100644 --- a/bmtrain/store.py +++ b/bmtrain/store.py @@ -8,6 +8,7 @@ from . import nccl import io, pickle from typing import Mapping +import threading def _save_to_state_dict(model : torch.nn.Module, destination, prefix): if isinstance(model, Block): @@ -40,8 +41,12 @@ def _save_to_rank0(model : torch.nn.Module, destination=None, prefix=''): return destination +def async_save_to_file(state_dict, file_path): + torch.save(state_dict, file_path) + config['finish_save'] = True + print("finish save state_dict to ", file_path) -def save(model : torch.nn.Module, file_name : str): +def save(model : torch.nn.Module, file_name : str, non_blocking : bool=True): """Saves the model to the file. Similar to torch.save, but it used for distributed modules. @@ -49,6 +54,8 @@ def save(model : torch.nn.Module, file_name : str): Args: model (torch.nn.Module): The model to be saved. file_name (str): The file name of the checkpoint. + non_blocking (bool): Whether to asynchronously save state_dict to file + Examples: >>> bmtrain.save(model, "model.pt") @@ -56,7 +63,18 @@ def save(model : torch.nn.Module, file_name : str): torch.cuda.synchronize() state_dict = _save_to_rank0(model) if config["rank"] == 0: - torch.save(state_dict, file_name) + if non_blocking is False: + torch.save(state_dict, file_name) + else: + if 'finish_save' not in config: + config['finish_save'] = True + + if config['finish_save'] is False: + config['save_thread'].join() + + config['finish_save'] = False + config['save_thread'] = threading.Thread(target=async_save_to_file, args=(state_dict, file_name)) + config['save_thread'].start() DTYPE_LIST = [ torch.float64,