diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py
index 74047cf8..44b47753 100644
--- a/bmtrain/block_layer.py
+++ b/bmtrain/block_layer.py
@@ -399,7 +399,12 @@ def __init__(self, inner_module : torch.nn.Module):
                 to_offset_end = offset_end + param_st - storage_st
 
                 # copy to buffer
-                self._storage_params[kw_name].storage()[to_offset_st: to_offset_end].copy_(contiguous_param.storage()[offset_st: offset_end])
+                # PyTorch 1.11 changed the API of storage.__getitem__
+                d_dtype = self._storage_params[kw_name].dtype
+                d_device = self._storage_params[kw_name].device
+                torch.tensor([], dtype=d_dtype, device=d_device).set_(self._storage_params[kw_name].storage(), to_offset_st, (to_offset_end - to_offset_st,))[:] = \
+                    torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_param.storage(), offset_st, (offset_end - offset_st,))[:]
+                # self._storage_params[kw_name].storage()[to_offset_st: to_offset_end].copy_(contiguous_param.storage()[offset_st: offset_end])
                 del contiguous_param
 
             # clear parameter data, but keep the dtype and device
@@ -472,7 +477,12 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                 to_offset_end = offset_end + param_st - storage_st
 
                 # copy to buffer
-                self._storage_params[kw_name].storage()[to_offset_st: to_offset_end].copy_(contiguous_param.storage()[offset_st: offset_end])
+                # PyTorch 1.11 changed the API of storage.__getitem__
+                d_dtype = self._storage_params[kw_name].dtype
+                d_device = self._storage_params[kw_name].device
+                torch.tensor([], dtype=d_dtype, device=d_device).set_(self._storage_params[kw_name].storage(), to_offset_st, (to_offset_end - to_offset_st,))[:] = \
+                    torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_param.storage(), offset_st, (offset_end - offset_st,))[:]
+                # self._storage_params[kw_name].storage()[to_offset_st: to_offset_end].copy_(contiguous_param.storage()[offset_st: offset_end])
                 del contiguous_param
             elif strict:
                 missing_keys.append(key)
@@ -527,7 +537,12 @@ def init_parameters(self):
                 to_offset_end = offset_end + param_st - storage_st
 
                 # copy to buffer
-                self._storage_params[kw_name].storage()[to_offset_st: to_offset_end].copy_(tmp_tensor.storage()[offset_st: offset_end])
+                # PyTorch 1.11 changed the API of storage.__getitem__
+                d_dtype = self._storage_params[kw_name].dtype
+                d_device = self._storage_params[kw_name].device
+                torch.tensor([], dtype=d_dtype, device=d_device).set_(self._storage_params[kw_name].storage(), to_offset_st, (to_offset_end - to_offset_st,))[:] = \
+                    torch.tensor([], dtype=d_dtype, device=d_device).set_(tmp_tensor.storage(), offset_st, (offset_end - offset_st,))[:]
+                # self._storage_params[kw_name].storage()[to_offset_st: to_offset_end].copy_(tmp_tensor.storage()[offset_st: offset_end])
                 del tmp_tensor
 
     def _named_members(self, get_members_fn, prefix='', recurse=True):
diff --git a/bmtrain/param_init.py b/bmtrain/param_init.py
index 5da53239..f5d1ea37 100644
--- a/bmtrain/param_init.py
+++ b/bmtrain/param_init.py
@@ -22,7 +22,9 @@ def init_distributed_parameter(params : Iterable[torch.nn.Parameter]):
 
             param._init_method(tmp_tensor)
 
-            param.storage().copy_(tmp_storage[partition_size * config['rank'] : partition_size * (config['rank'] + 1)])
+            # Pytorch 1.11 changed the API of storage.__getitem__
+            param[:] = torch.tensor([], dtype=param.dtype, device=param.device).set_(tmp_storage)[partition_size * config['rank'] : partition_size * (config['rank'] + 1)]
+            # param.storage().copy_(tmp_storage[partition_size * config['rank'] : partition_size * (config['rank'] + 1)])
 
 def iterate_parameters(model : torch.nn.Module):
     for kw, val in model._parameters.items():
diff --git a/setup.py b/setup.py
index 283d4b73..977e855f 100644
--- a/setup.py
+++ b/setup.py
@@ -11,16 +11,39 @@ def get_avx_flags():
     else:
         return ["-march=native"]
 
+def get_device_cc():
+    try:
+        CC_SET = set()
+        for i in range(torch.cuda.device_count()):
+            CC_SET.add(torch.cuda.get_device_capability(i))
+
+        if len(CC_SET) == 0:
+            return None
+
+        ret = ""
+        for it in CC_SET:
+            if len(ret) > 0:
+                ret = ret + " "
+            ret = ret + ("%d.%d" % it)
+        return ret
+    except RuntimeError:
+        return None
 avx_flag = get_avx_flags()
-
-if not torch.cuda.is_available():
-    os.environ["TORCH_CUDA_ARCH_LIST"] = os.environ.get("TORCH_CUDA_ARCH_LIST", "6.0 6.1 7.0 7.5 8.0+PTX")
-else:
-    if torch.version.cuda.startswith("10"):
-        os.environ["TORCH_CUDA_ARCH_LIST"] = os.environ.get("TORCH_CUDA_ARCH_LIST", "6.0 6.1 7.0 7.5+PTX")
+device_cc = get_device_cc()
+if device_cc is None:
+    if not torch.cuda.is_available():
+        os.environ["TORCH_CUDA_ARCH_LIST"] = os.environ.get("TORCH_CUDA_ARCH_LIST", "6.0 6.1 7.0 7.5 8.0+PTX")
     else:
-        os.environ["TORCH_CUDA_ARCH_LIST"] = os.environ.get("TORCH_CUDA_ARCH_LIST", "6.0 6.1 7.0 7.5 8.0 8.6+PTX")
+        if torch.version.cuda.startswith("10"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = os.environ.get("TORCH_CUDA_ARCH_LIST", "6.0 6.1 7.0 7.5+PTX")
+        else:
+            if not torch.version.cuda.startswith("11.0"):
+                os.environ["TORCH_CUDA_ARCH_LIST"] = os.environ.get("TORCH_CUDA_ARCH_LIST", "6.0 6.1 7.0 7.5 8.0 8.6+PTX")
+            else:
+                os.environ["TORCH_CUDA_ARCH_LIST"] = os.environ.get("TORCH_CUDA_ARCH_LIST", "6.0 6.1 7.0 7.5 8.0+PTX")
+else:
+    os.environ["TORCH_CUDA_ARCH_LIST"] = os.environ.get("TORCH_CUDA_ARCH_LIST", device_cc)
 
 setup(
     name='bmtrain',
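
For reference only (not part of the patch): a minimal, self-contained sketch of the storage-slice workaround the block_layer.py hunks rely on, assuming PyTorch 1.11 or later. The buffer sizes and offsets here are made up for illustration; only the set_()-based view-and-copy pattern mirrors the patch.

    import torch

    # Two flat float32 buffers standing in for the parameter storages in the patch.
    src = torch.arange(8, dtype=torch.float32)
    dst = torch.zeros(8, dtype=torch.float32)

    offset_st, offset_end = 2, 6      # slice of the source storage to copy
    to_offset_st = 1                  # where that slice lands in the destination
    to_offset_end = to_offset_st + (offset_end - offset_st)

    # storage()[a:b] no longer slices as it did before 1.11, so build tensor
    # views over the storages with set_() and copy through the views instead.
    dst_view = torch.tensor([], dtype=dst.dtype, device=dst.device).set_(
        dst.storage(), to_offset_st, (to_offset_end - to_offset_st,))
    src_view = torch.tensor([], dtype=src.dtype, device=src.device).set_(
        src.storage(), offset_st, (offset_end - offset_st,))
    dst_view[:] = src_view

    print(dst)  # tensor([0., 2., 3., 4., 5., 0., 0., 0.])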