diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py
index cd792a88..46c47572 100644
--- a/bmtrain/block_layer.py
+++ b/bmtrain/block_layer.py
@@ -1,4 +1,4 @@
-from typing import Dict, Iterable, Iterator, Tuple, Union
+from typing import Dict, Iterable, Iterator, Union

 from .global_var import config

@@ -8,7 +8,6 @@
 from .parameter import DistributedParameter, OpAllGather
 from .checkpointing import ScopedTensorInspectorContext
 from . import debug
-from torch.nn.modules.module import _addindent
 import copy

 def round_up(x, d):
@@ -331,7 +330,8 @@ def __init__(self, inner_module : torch.nn.Module):

         # calc total number of parameters
         for name, param in ordered_parameters:
-            assert isinstance(param, DistributedParameter), "All parameters in checkpoint block must be DistributedParameter."
+            if not isinstance(param, DistributedParameter):
+                raise ValueError("All parameters in checkpoint block must be DistributedParameter.")

             storage_type = storage_type_cuda(param.storage_type())
             kw_name = _get_param_kw(param)
@@ -464,7 +464,7 @@ def state_dict(self, destination=None, prefix='', keep_vars=False):
         # gather here
         with torch.no_grad():
             with CheckpointBlockContext(self):
-                return self._module.state_dict(destination, prefix, keep_vars)
+                return self._module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)

     def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                               missing_keys, unexpected_keys, error_msgs):
diff --git a/bmtrain/distributed/ops.py b/bmtrain/distributed/ops.py
index 4088f4ec..9cfff1b9 100644
--- a/bmtrain/distributed/ops.py
+++ b/bmtrain/distributed/ops.py
@@ -32,13 +32,16 @@ def all_gather(x : torch.Tensor):
     Returns:
         torch.Tensor: The gathered tensor of shape (world_size, ...).
     """
+    if not config["initialized"]:
+        raise RuntimeError("BMTrain is not initialized")
+
     assert x.is_cuda
     return OpAllGather.apply(x)

 class OpAllReduce(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input : torch.Tensor, op : str):
-        if not input.contiguous():
+        if not input.is_contiguous():
             input = input.contiguous()
         if input.storage_offset() != 0 or input.storage().size() != input.numel():
             input = input.clone()
@@ -82,6 +85,9 @@ def all_reduce(x : torch.Tensor, op : str = "sum"):
         torch.Tensor: The reduced tensor of shape (...).

     """
+    if not config["initialized"]:
+        raise RuntimeError("BMTrain is not initialized")
+
     assert x.is_cuda
     return OpAllReduce.apply(x, op)

diff --git a/bmtrain/global_var.py b/bmtrain/global_var.py
index 909711f5..67a54cf7 100644
--- a/bmtrain/global_var.py
+++ b/bmtrain/global_var.py
@@ -14,10 +14,11 @@ class ConfigMap(TypedDict):
     loss_scale_steps : int

     gradient_inspect : bool
+    initialized : bool

     comm : 'NCCLCommunicator'

-config = ConfigMap()
+config = ConfigMap(rank=0, local_rank=0, world_size=1, initialized=False)

 def rank():
     """
diff --git a/bmtrain/init.py b/bmtrain/init.py
index 92b8099f..42186923 100644
--- a/bmtrain/init.py
+++ b/bmtrain/init.py
@@ -6,7 +6,6 @@
 from .utils import print_dict
 from .global_var import config
 from . import nccl
-import time
 from .synchronize import synchronize
 def init_distributed(
         init_method : str = "env://",
@@ -57,6 +56,7 @@ def init_distributed(
     store = dist.PrefixStore("bmtrain", store)
     torch.cuda.set_device(local_rank)

+    config["initialized"] = True
     config["local_rank"] = local_rank
     config["local_size"] = local_size
     config["rank"] = rank
@@ -110,3 +110,6 @@ def init_distributed(
         "cpus": cpus_this_worker
     })
     synchronize()
+
+def is_initialized() -> bool:
+    return config["initialized"]
diff --git a/bmtrain/parameter.py b/bmtrain/parameter.py
index 8e868522..f965cdac 100644
--- a/bmtrain/parameter.py
+++ b/bmtrain/parameter.py
@@ -33,6 +33,9 @@ def __new__(cls,
         init_method : Optional[Callable[['DistributedParameter'], None]] = None,
         group : Optional[str] = None
     ):
+        if not config["initialized"]:
+            raise RuntimeError("BMTrain is not initialized")
+
         num_of_elements = data.numel()

         cuda_tensor = torch.tensor([], dtype=data.dtype, device="cuda")
diff --git a/bmtrain/store.py b/bmtrain/store.py
index 95ce29e6..bf1c00f7 100644
--- a/bmtrain/store.py
+++ b/bmtrain/store.py
@@ -5,13 +5,14 @@
 from .block_layer import CheckpointBlock
 from . import nccl
 import io, pickle
+from typing import Mapping

 def _save_to_state_dict(model : torch.nn.Module, destination, prefix):
     if isinstance(model, CheckpointBlock):
         if config['rank'] != 0:
             destination = OrderedDict() # creates an temporary ordered dict
             destination._metadata = OrderedDict()
-        model.state_dict(destination, prefix, False)
+        model.state_dict(destination=destination, prefix=prefix, keep_vars=False)
     else:
         if config['rank'] != 0:
             destination = OrderedDict() # creates an temporary ordered dict
@@ -109,8 +110,8 @@ def broadcast_object(obj):
         obj = _unpickler(io.BytesIO(buf)).load()
     return obj

-
-class DistributedStateDictWrapper:
+# Must be a Mapping after pytorch 1.12.0
+class DistributedStateDictWrapper(Mapping):
     def __init__(self, state_dict : Dict) -> None:
         self._state_dict = state_dict
         self._metadata = broadcast_object(getattr(state_dict, "_metadata", None))
@@ -176,6 +177,10 @@ def __contains__(self, key : str):

     def keys(self):
         return broadcast_object(list(self._state_dict.keys()))
+
+    def __iter__(self):
+        # pytorch 1.12.0 updated the load_state_dict method, which needs the state_dict to be a `Mapping`.
+        return iter(self.keys())

 def load(model : torch.nn.Module, file_name : str, strict : bool = True):
     """Loads the model from the file.
diff --git a/bmtrain/synchronize.py b/bmtrain/synchronize.py
index f4f9b782..f704d159 100644
--- a/bmtrain/synchronize.py
+++ b/bmtrain/synchronize.py
@@ -1,17 +1,24 @@
 import torch
-from . import nccl
+from . import distributed, nccl
 from .global_var import config
+import warnings

 def synchronize():
     """
     Synchronize all the workers across all nodes. (both CPU and GPU are synchronized)
     """
+    if not config["initialized"]:
+        raise RuntimeError("BMTrain is not initialized")
+
     with torch.cuda.stream(config['barrier_stream']):
         barrier = torch.cuda.FloatTensor([1])
         nccl.allReduce(barrier.storage(), barrier.storage(), 'sum', config['comm'])
     config['barrier_stream'].synchronize()

 def wait_loader():
+    if not config["initialized"]:
+        raise RuntimeError("BMTrain is not initialized")
+
     # wait lastest loader event, and set a new one
     config['load_event'].synchronize()
     config['calc_stream'].record_event(config['load_event'])
@@ -23,22 +30,23 @@ def sum_loss(loss : torch.Tensor):

     This is a helper function to reduce the loss across all workers.
     """
-    ret = torch.empty_like(loss)
-    nccl.allReduce(
-        loss.storage(),
-        ret.storage(),
-        'avg',
-        config['comm']
-    )
-    return ret
+    warnings.warn("bmtrain.sum_loss is deprecated and will be removed in later version. Use bmtrain.distributed.all_reduce instead.", DeprecationWarning)
+    return distributed.all_reduce(loss, "avg")

 def gather_result(result: torch.Tensor):
+    warnings.warn("bmtrain.gather_result is deprecated and will be removed in later version. Use bmtrain.distributed.all_gather instead.", DeprecationWarning)
+
+    output_cuda = True
     if not result.is_cuda:
         result = result.cuda()
+        output_cuda = False
     ret = torch.empty((result.shape[0]*config['world_size'], *list(result.shape[1:])), device=result.device, dtype=result.dtype)
     nccl.allGather(
         result.storage(),
         ret.storage(),
         config['comm']
     )
-    return ret
\ No newline at end of file
+    if output_cuda:
+        return ret
+    else:
+        return ret.cpu()