Skip to content

Commit

Permalink
fix load partition
Browse files Browse the repository at this point in the history
  • Loading branch information
MayDomine committed May 11, 2024
1 parent 0305c59 commit 0e22e96
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions bmtrain/block_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,9 +369,9 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
offset_st = max(storage_st - param_st, 0)
to_offset_st = offset_st + param_st - storage_st
if not config['load_param_gather']:
partition_numel= len(contiguous_param)
partition_numel= contiguous_param.numel()
torch.tensor([], dtype=d_dtype, device=d_device).set_(self._storage_params[kw_name].storage(), to_offset_st, (partition_numel,))[:] = \
contiguous_param[:]
torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_param.storage(), 0, (partition_numel,))[:]
continue

tp_split_dim = param._tp_split_dim
Expand Down Expand Up @@ -771,4 +771,4 @@ def add_tail(self, module, use_checkpoint=False):
return DummyForward
else:
self._add_tail(module)
return module
return module

0 comments on commit 0e22e96

Please sign in to comment.