Upd 0708 #38

Merged
merged 4 commits on Jul 8, 2022
8 changes: 4 additions & 4 deletions bmtrain/block_layer.py
@@ -1,4 +1,4 @@
from typing import Dict, Iterable, Iterator, Tuple, Union
from typing import Dict, Iterable, Iterator, Union


from .global_var import config
@@ -8,7 +8,6 @@
from .parameter import DistributedParameter, OpAllGather
from .checkpointing import ScopedTensorInspectorContext
from . import debug
from torch.nn.modules.module import _addindent
import copy

def round_up(x, d):
@@ -331,7 +330,8 @@ def __init__(self, inner_module : torch.nn.Module):

# calc total number of parameters
for name, param in ordered_parameters:
assert isinstance(param, DistributedParameter), "All parameters in checkpoint block must be DistributedParameter."
if not isinstance(param, DistributedParameter):
raise ValueError("All parameters in checkpoint block must be DistributedParameter.")

storage_type = storage_type_cuda(param.storage_type())
kw_name = _get_param_kw(param)
@@ -464,7 +464,7 @@ def state_dict(self, destination=None, prefix='', keep_vars=False):
# gather here
with torch.no_grad():
with CheckpointBlockContext(self):
return self._module.state_dict(destination, prefix, keep_vars)
return self._module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)

def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
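A note on the state_dict change above: newer PyTorch releases (around 1.13) deprecate positional arguments to torch.nn.Module.state_dict, so forwarding destination/prefix/keep_vars as keywords keeps CheckpointBlock compatible going forward. A minimal sketch of the calling convention on a plain module (the Linear layer is illustrative only, not part of this PR):

    import torch

    layer = torch.nn.Linear(4, 4)

    # Old positional form (deprecated in newer PyTorch):
    #   layer.state_dict(destination, prefix, keep_vars)
    # Keyword form matching the change above:
    sd = layer.state_dict(destination=None, prefix="layer.", keep_vars=False)
    print(list(sd.keys()))  # ['layer.weight', 'layer.bias']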
8 changes: 7 additions & 1 deletion bmtrain/distributed/ops.py
@@ -32,13 +32,16 @@ def all_gather(x : torch.Tensor):
Returns:
torch.Tensor: The gathered tensor of shape (world_size, ...).
"""
if not config["initialized"]:
raise RuntimeError("BMTrain is not initialized")

assert x.is_cuda
return OpAllGather.apply(x)

class OpAllReduce(torch.autograd.Function):
@staticmethod
def forward(ctx, input : torch.Tensor, op : str):
if not input.contiguous():
if not input.is_contiguous():
input = input.contiguous()
if input.storage_offset() != 0 or input.storage().size() != input.numel():
input = input.clone()
@@ -82,6 +85,9 @@ def all_reduce(x : torch.Tensor, op : str = "sum"):
torch.Tensor: The reduced tensor of shape (...).

"""
if not config["initialized"]:
raise RuntimeError("BMTrain is not initialized")

assert x.is_cuda
return OpAllReduce.apply(x, op)

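For context, a minimal usage sketch of the two guarded collectives. The public paths bmtrain.distributed.all_gather / all_reduce match the deprecation messages later in this PR; the tensor shape and values are arbitrary:

    import torch
    import bmtrain as bmt

    bmt.init_distributed()  # without this, the new RuntimeError("BMTrain is not initialized") is raised

    x = torch.full((4,), float(bmt.rank()), device="cuda")
    gathered = bmt.distributed.all_gather(x)      # shape: (world_size, 4)
    total = bmt.distributed.all_reduce(x, "sum")  # shape: (4,)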
3 changes: 2 additions & 1 deletion bmtrain/global_var.py
@@ -14,10 +14,11 @@ class ConfigMap(TypedDict):
loss_scale_steps : int

gradient_inspect : bool
initialized : bool

comm : 'NCCLCommunicator'

config = ConfigMap()
config = ConfigMap(rank=0, local_rank=0, world_size=1, initialized=False)

def rank():
"""
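The defaults above give rank, local_rank and world_size sensible single-process values before init_distributed() runs. A small sketch of what they imply, assuming init_distributed() has not been called yet:

    from bmtrain.global_var import config

    # Pre-initialization defaults introduced by this PR:
    assert config["initialized"] is False
    assert config["rank"] == 0 and config["local_rank"] == 0 and config["world_size"] == 1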
5 changes: 4 additions & 1 deletion bmtrain/init.py
@@ -6,7 +6,6 @@
from .utils import print_dict
from .global_var import config
from . import nccl
import time
from .synchronize import synchronize
def init_distributed(
init_method : str = "env://",
@@ -57,6 +56,7 @@ def init_distributed(
store = dist.PrefixStore("bmtrain", store)
torch.cuda.set_device(local_rank)

config["initialized"] = True
config["local_rank"] = local_rank
config["local_size"] = local_size
config["rank"] = rank
@@ -110,3 +110,6 @@ def init_distributed(
"cpus": cpus_this_worker
})
synchronize()

def is_initialized() -> bool:
return config["initialized"]
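The new helper makes the guard explicit in user code. A small sketch; whether it is also re-exported as bmt.is_initialized is not visible in this diff, so it is imported from bmtrain.init here as an assumption:

    import bmtrain as bmt
    from bmtrain.init import is_initialized  # assumption: package-level re-export is not shown in this diff

    if not is_initialized():
        bmt.init_distributed()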
3 changes: 3 additions & 0 deletions bmtrain/parameter.py
@@ -33,6 +33,9 @@ def __new__(cls,
init_method : Optional[Callable[['DistributedParameter'], None]] = None,
group : Optional[str] = None
):
if not config["initialized"]:
raise RuntimeError("BMTrain is not initialized")

num_of_elements = data.numel()

cuda_tensor = torch.tensor([], dtype=data.dtype, device="cuda")
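The practical effect of this check is an ordering rule: init_distributed() must run before any DistributedParameter is constructed, and hence before building a model. A minimal sketch, with the parameter shape and dtype purely illustrative:

    import torch
    import bmtrain as bmt

    bmt.init_distributed()  # must come first

    # Creating this before init_distributed() now fails fast with
    # RuntimeError("BMTrain is not initialized") instead of breaking later.
    weight = bmt.DistributedParameter(torch.empty(1024, 1024, dtype=torch.half))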
11 changes: 8 additions & 3 deletions bmtrain/store.py
@@ -5,13 +5,14 @@
from .block_layer import CheckpointBlock
from . import nccl
import io, pickle
from typing import Mapping

def _save_to_state_dict(model : torch.nn.Module, destination, prefix):
if isinstance(model, CheckpointBlock):
if config['rank'] != 0:
destination = OrderedDict() # creates a temporary ordered dict
destination._metadata = OrderedDict()
model.state_dict(destination, prefix, False)
model.state_dict(destination=destination, prefix=prefix, keep_vars=False)
else:
if config['rank'] != 0:
destination = OrderedDict() # creates a temporary ordered dict
@@ -109,8 +110,8 @@ def broadcast_object(obj):
obj = _unpickler(io.BytesIO(buf)).load()
return obj


class DistributedStateDictWrapper:
# Must be a Mapping after pytorch 1.12.0
class DistributedStateDictWrapper(Mapping):
def __init__(self, state_dict : Dict) -> None:
self._state_dict = state_dict
self._metadata = broadcast_object(getattr(state_dict, "_metadata", None))
@@ -176,6 +177,10 @@ def __contains__(self, key : str):
def keys(self):
return broadcast_object(list(self._state_dict.keys()))

def __iter__(self):
# pytorch 1.12.0 updated the load_state_dict method, which needs the state_dict to be a `Mapping`.
return iter(self.keys())

def load(model : torch.nn.Module, file_name : str, strict : bool = True):
"""Loads the model from the file.

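Background for the Mapping change: as the new comments note, load_state_dict in PyTorch 1.12+ expects the incoming state dict to behave like a collections.abc.Mapping, so the wrapper now subclasses Mapping and supplies __iter__. A generic sketch of that contract, unrelated to bmtrain's internals:

    from collections.abc import Mapping

    class DictView(Mapping):
        """Minimal Mapping: __getitem__, __iter__ and __len__ are the required abstract methods."""
        def __init__(self, data):
            self._data = data
        def __getitem__(self, key):
            return self._data[key]
        def __iter__(self):
            return iter(self._data)
        def __len__(self):
            return len(self._data)

    view = DictView({"weight": 1})
    assert "weight" in view and list(view) == ["weight"]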
28 changes: 18 additions & 10 deletions bmtrain/synchronize.py
@@ -1,17 +1,24 @@
import torch
from . import nccl
from . import distributed, nccl
from .global_var import config
import warnings

def synchronize():
"""
Synchronize all the workers across all nodes. (both CPU and GPU are synchronized)
"""
if not config["initialized"]:
raise RuntimeError("BMTrain is not initialized")

with torch.cuda.stream(config['barrier_stream']):
barrier = torch.cuda.FloatTensor([1])
nccl.allReduce(barrier.storage(), barrier.storage(), 'sum', config['comm'])
config['barrier_stream'].synchronize()

def wait_loader():
if not config["initialized"]:
raise RuntimeError("BMTrain is not initialized")

# wait for the latest loader event, and set a new one
config['load_event'].synchronize()
config['calc_stream'].record_event(config['load_event'])
@@ -23,22 +30,23 @@ def sum_loss(loss : torch.Tensor):

This is a helper function to reduce the loss across all workers.
"""
ret = torch.empty_like(loss)
nccl.allReduce(
loss.storage(),
ret.storage(),
'avg',
config['comm']
)
return ret
warnings.warn("bmtrain.sum_loss is deprecated and will be removed in later version. Use bmtrain.distributed.all_reduce instead.", DeprecationWarning)
return distributed.all_reduce(loss, "avg")

def gather_result(result: torch.Tensor):
warnings.warn("bmtrain.gather_result is deprecated and will be removed in later version. Use bmtrain.distributed.all_gather instead.", DeprecationWarning)

output_cuda = True
if not result.is_cuda:
result = result.cuda()
output_cuda = False
ret = torch.empty((result.shape[0]*config['world_size'], *list(result.shape[1:])), device=result.device, dtype=result.dtype)
nccl.allGather(
result.storage(),
ret.storage(),
config['comm']
)
return ret
if output_cuda:
return ret
else:
return ret.cpu()
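Migration sketch for the two deprecated helpers; the loss and result tensors are placeholders, and note that all_gather (unlike the old gather_result) asserts CUDA input and does not move the output back to the CPU:

    import torch
    import bmtrain as bmt

    bmt.init_distributed()

    loss = torch.tensor(1.0, device="cuda")
    # old: loss = bmt.sum_loss(loss)
    loss = bmt.distributed.all_reduce(loss, "avg")

    result = torch.randn(2, 3, device="cuda")
    # old: result = bmt.gather_result(result)
    result = bmt.distributed.all_gather(result)  # shape: (world_size, 2, 3)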