10 changes: 5 additions & 5 deletions deepspeed/module_inject/auto_tp.py
@@ -328,7 +328,7 @@ def _replace(self, child, name, conv_linear_layer):
             if self.conv_linear_layer:
                 child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
             data = child.weight.data.split(get_shard_size_list(
-                weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size),
+                weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size, name),
                                            dim=1)
             data_dc = move(data[mp_replace.gpu_index], get_accelerator().current_device_name()).detach()
             del data
@@ -360,14 +360,14 @@ def _replace(self, child, name, conv_linear_layer):
                         prepare_tp_fused_qkvw(self.module, child.bias.data, self.mp_size, mp_replace.gpu_index),
                         get_accelerator().current_device_name())
             else:
-                data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size),
+                data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size, name),
                                                dim=1 if self.conv_linear_layer else 0)
                 data_dc = move(data[mp_replace.gpu_index], get_accelerator().current_device_name()).detach()
                 del data

                 if child.bias is not None:
                     bias_data = child.bias.data.split(get_shard_size_list(
-                        weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size),
+                        weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size, name),
                                                       dim=0)
                     bias_data = move(bias_data[mp_replace.gpu_index], get_accelerator().current_device_name())
                     bias_data_dc = torch.nn.parameter.Parameter(bias_data, requires_grad=False)
@@ -386,11 +386,11 @@ def _slice_embedding(self, child, name, conv_linear_layer):
         if hasattr(child.weight, 'ds_tensor'):
             data = child.weight.ds_tensor.data.split(get_shard_size_list(child.weight.shape[1], self.mp_size), dim=1)
         else:
-            data = child.weight.data.split(get_shard_size_list(child.weight.shape[1], self.mp_size), dim=1)
+            data = child.weight.data.split(get_shard_size_list(child.weight.shape[1], self.mp_size, name), dim=1)
         data = data[mp_replace.gpu_index].to(get_accelerator().current_device_name())
         data = torch.nn.parameter.Parameter(data, requires_grad=False)

-        new_embedding = nn.Embedding(child.weight.shape[0], get_shard_size(child.weight.shape[1], self.mp_size))
+        new_embedding = nn.Embedding(child.weight.shape[0], get_shard_size(child.weight.shape[1], self.mp_size, name))
         new_embedding.weight.data.copy_(data)
         setattr(child, "replaced", True)
         return new_embedding
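The split calls above pass a per-rank size list to torch.Tensor.split, which is what lets the shards be uneven. A minimal standalone sketch of that mechanism (sizes are illustrative assumptions, not values from the diff):

import torch

# Hypothetical output of get_shard_size_list for a 1024-wide projection on 3 ranks
# (8 KV heads of width 128, split 3/3/2): 384 + 384 + 256 = 1024.
shard_sizes = [384, 384, 256]
weight = torch.randn(1024, 4096)

# Splitting with a list of sizes yields one (possibly uneven) shard per rank along dim=0;
# rank r keeps shards[r], analogous to data[mp_replace.gpu_index] above.
shards = weight.split(shard_sizes, dim=0)
assert [s.shape[0] for s in shards] == shard_sizes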
4 changes: 2 additions & 2 deletions deepspeed/module_inject/layers.py
@@ -48,8 +48,8 @@ def __init__(
         self.world_size = world_size

     def forward(self, input):
-        input_shard_size = get_shard_size(input.shape[-1], self.world_size)
-        input_shard_offset = sum(get_shard_size_list(input.shape[-1], self.world_size)[0:self.rank])
+        input_shard_size = get_shard_size(input.shape[-1], self.world_size, "lm_head")
+        input_shard_offset = sum(get_shard_size_list(input.shape[-1], self.world_size, "lm_head")[0:self.rank])
         output = torch.matmul(input[:, :, input_shard_offset:input_shard_offset + input_shard_size],
                               self.weight.transpose(-1, -2))
         if self.mp_group is not None:
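The forward change above keeps the lm_head input slice aligned with the (possibly uneven) weight shards: each rank multiplies only its own column range before the all-reduce. A rough sketch of the offset arithmetic, with made-up shard sizes:

# Illustrative shard sizes such as get_shard_size_list(hidden_size, world_size, "lm_head")
# might return for hidden_size=4096 and world_size=3 under the 64-grain rule.
shard_sizes = [1408, 1344, 1344]
rank = 1

input_shard_offset = sum(shard_sizes[:rank])  # 1408: total width of all earlier shards
input_shard_size = shard_sizes[rank]          # 1344: this rank's slice width
# input[:, :, offset:offset + size] is matmul'ed with this rank's weight shard,
# and the partial outputs are summed by the all-reduce.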
23 changes: 13 additions & 10 deletions deepspeed/module_inject/tp_shard.py
@@ -22,28 +22,31 @@ def get_num_kv_heads():
     return num_kv_heads


-def get_shard_size(total_size, mp_size, rank=None):
+def get_shard_size(total_size, mp_size, name=None, rank=None):
     global num_kv_heads
-    # When we have num_kv_heads defined, uneven division is possible, otherwise enforce even division
-    if num_kv_heads is not None:
-        if rank is None:
-            rank = dist.get_rank()
+    last_linear = ["lm_head", "embed_out"]
+    # When we have num_kv_heads defined, uneven division is possible, otherwise enforce near even division
+    if rank == None:
+        rank = dist.get_rank()
+    if num_kv_heads != None and total_size % num_kv_heads == 0 and "mlp" not in str(name) and str(
+            name) not in last_linear:
         my_slices = (num_kv_heads // mp_size) + (1 if rank < (num_kv_heads % mp_size) else 0)
         return total_size * my_slices // num_kv_heads
     else:
-        if total_size % mp_size == 0:
-            return total_size // mp_size
+        if total_size >= 64:
+            grain_size = total_size // 64
+            return (grain_size // mp_size + (1 if rank < (grain_size % mp_size) else 0)) * 64
         else:
-            assert False, f"Number of attention heads ({total_size}) must be divisible by mp_size ({mp_size})"
+            return total_size // mp_size + (1 if rank < (total_size % mp_size) else 0)


 def get_n_embd():
     global n_embd
     return n_embd


-def get_shard_size_list(total_size, mp_size):
+def get_shard_size_list(total_size, mp_size, name=None):
     shard_sizes = []
     for i in range(mp_size):
-        shard_sizes.append(get_shard_size(total_size, mp_size, i))
+        shard_sizes.append(get_shard_size(total_size, mp_size, name, i))
     return shard_sizes
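A standalone sketch of the new sizing rule, re-implemented here without torch.distributed so it can be run anywhere; the module names and sizes below are illustrative assumptions, not values from the diff:

def shard_sizes(total_size, mp_size, num_kv_heads=None, name=None):
    # Mirrors get_shard_size / get_shard_size_list above, iterating rank locally.
    sizes = []
    last_linear = ["lm_head", "embed_out"]
    for rank in range(mp_size):
        if (num_kv_heads is not None and total_size % num_kv_heads == 0
                and "mlp" not in str(name) and str(name) not in last_linear):
            # Attention projections: spread KV heads as evenly as possible,
            # with lower ranks taking one extra head when mp_size does not divide them.
            my_slices = num_kv_heads // mp_size + (1 if rank < num_kv_heads % mp_size else 0)
            sizes.append(total_size * my_slices // num_kv_heads)
        elif total_size >= 64:
            # MLP / lm_head / embed_out: shard in grains of 64 columns.
            grain_size = total_size // 64
            sizes.append((grain_size // mp_size + (1 if rank < grain_size % mp_size else 0)) * 64)
        else:
            sizes.append(total_size // mp_size + (1 if rank < total_size % mp_size else 0))
    return sizes

# 8 KV heads * head_dim 128 across 3 ranks -> [384, 384, 256] (3/3/2 heads)
print(shard_sizes(1024, 3, num_kv_heads=8, name="self_attn.k_proj"))
# 11008-wide MLP projection across 3 ranks -> [3712, 3648, 3648] (64-grain split)
print(shard_sizes(11008, 3, num_kv_heads=8, name="mlp.gate_proj"))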