Skip to content

Commit

Permalink
Splitting tokens when routing
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 378316002
  • Loading branch information
Mesh TensorFlow Team committed Jun 9, 2021
1 parent 54b01b4 commit 16ad1b2
Showing 1 changed file with 191 additions and 148 deletions.
339 changes: 191 additions & 148 deletions mesh_tensorflow/transformer/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ def __init__(self,
ntlb_top_k=4,
output_dim=None,
use_experts_attention=False,
z_loss=None):
z_loss=None,
num_hidden_splits=None,
split_hidden_before_routing=False):
self._hparams = HParams(
moe_gating=moe_gating,
moe_num_experts=num_experts,
Expand All @@ -85,7 +87,9 @@ def __init__(self,
moe_output_dim=output_dim,
moe_ntlb_top_k=ntlb_top_k,
moe_use_experts_attention=use_experts_attention,
moe_z_loss=z_loss)
moe_z_loss=z_loss,
moe_num_hidden_splits=num_hidden_splits,
moe_split_hidden_before_routing=split_hidden_before_routing)
self._activation = activation

def call(self, context, x, losses=None):
Expand Down Expand Up @@ -327,8 +331,8 @@ def transformer_moe_layer_v1(
# We "cheat" here and look at the mesh shape and layout. This is to ensure
# that the number of groups is a multiple of the mesh dimension
# over which those groups are split.
batch_and_length_dims, input_dim = (orig_inputs.shape.dims[:-1],
orig_inputs.shape.dims[-1])
batch_and_length_dims, orig_input_dim = (
orig_inputs.shape.dims[:-1], orig_inputs.shape.dims[-1])
# Hack: we assume that
# "outer_batch" == replication of experts
# mesh_dim_size can be derived from mesh_shape and orig_batch_dim
Expand All @@ -348,16 +352,57 @@ def transformer_moe_layer_v1(

n = n // outer_batch_dim.size

mesh_dim_size = mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape,
orig_batch_dim)
num_groups, group_size = _split_into_groups(n, hparams.moe_group_size,
mesh_dim_size)
# Create num_groups and group_size dimensions
mesh_dim_size = mtf.tensor_dim_to_mesh_dim_size(
layout, mesh_shape, orig_batch_dim)
num_groups, group_size = _split_into_groups(
n, hparams.moe_group_size, mesh_dim_size)
orig_group_size_dim = mtf.Dimension("group", group_size)
orig_num_groups_dim = mtf.Dimension(orig_batch_dim.name, num_groups)

# The original dimensions correspond to those before splitting tokens
# into subtokens
group_size_dim = orig_group_size_dim
num_groups_dim = orig_num_groups_dim
input_dim = orig_input_dim

split_hidden_before_routing = False
split_hidden_after_routing = False
if hparams.moe_num_hidden_splits is not None:
if orig_input_dim.size % hparams.moe_num_hidden_splits:
raise ValueError("num_hidden_splits {} must divide input_dim {}".format(
hparams.moe_num_hidden_splits, input_dim.size))
if output_dim.size % hparams.moe_num_hidden_splits:
raise ValueError("num_hidden_splits {} must divide input_dim {}".format(
hparams.moe_num_hidden_splits, input_dim.size))
split_hidden_before_routing = hparams.moe_split_hidden_before_routing
split_hidden_after_routing = not hparams.moe_split_hidden_before_routing
hidden_dim = mtf.Dimension(
"expert_hidden",
hparams.moe_hidden_size // hparams.moe_num_hidden_splits)
sub_output_dim = mtf.Dimension(
output_dim.name, output_dim.size // hparams.moe_num_hidden_splits)
num_splits_dim = mtf.Dimension(
"num_splits", hparams.moe_num_hidden_splits)

if split_hidden_before_routing:
input_dim = mtf.Dimension(
input_dim.name, input_dim.size // hparams.moe_num_hidden_splits)

# Split into groups and subtokens
inputs = mtf.reshape(
inputs, [outer_batch_dim, num_groups_dim, group_size_dim,
num_splits_dim, input_dim])

group_size_dim = mtf.Dimension("group", group_size)
num_groups_dim = mtf.Dimension(orig_batch_dim.name, num_groups)
inputs = mtf.transpose(
inputs, [outer_batch_dim, num_groups_dim, num_splits_dim,
group_size_dim, input_dim])

num_groups_dim = mtf.Dimension(
orig_batch_dim.name, num_groups * hparams.moe_num_hidden_splits)

# [outer_batch_dim, num_groups_dim.B, group_size_dim, input_dim]
moe_input_dims = [outer_batch_dim, num_groups_dim, group_size_dim, input_dim]
# OGSM Tensor
inputs = mtf.reshape(inputs, moe_input_dims)

# Each sequence sends expert_capacity positions to each expert.
Expand All @@ -373,156 +418,138 @@ def transformer_moe_layer_v1(
expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)
experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
batch_dim_unsplit = mtf.Dimension("batch_unsplit", num_groups_dim.size)

if nonpadding is not None:
nonpadding = mtf.zeros(
inputs.mesh, batch_and_length_dims, dtype=inputs.dtype) + nonpadding

if split_hidden_before_routing:
nonpadding = mtf.reshape(
nonpadding,
[outer_batch_dim, orig_num_groups_dim, orig_group_size_dim])

# Tile num_hidden_splits times with an einsum
tiling_tensor = mtf.ones(inputs.mesh, [num_splits_dim])
nonpadding = mtf.einsum(
[nonpadding, tiling_tensor],
output_shape=[outer_batch_dim, orig_num_groups_dim, num_splits_dim,
orig_group_size_dim])

nonpadding = mtf.reshape(nonpadding, moe_input_dims[:-1])
if hparams.moe_gating == "top_2":
# combine_tensor,
# dispatch_tensor OG`SEC Tensors
# (G is generally split along mesh dim)
dispatch_tensor, combine_tensor, loss = _top_2_gating(
inputs=inputs,
outer_expert_dims=None,
experts_dim=experts_dim_unsplit,
expert_capacity_dim=expert_capacity_dim,
hparams=hparams,
train=train,
variable_dtype=variable_dtype,
importance=nonpadding,
num_microbatches=num_microbatches)
elif hparams.moe_gating == "switch":
dispatch_tensor, combine_tensor, loss = _switch_gating(
inputs=inputs,
outer_expert_dims=None,
experts_dim=experts_dim_unsplit,
expert_capacity_dim=expert_capacity_dim,
hparams=hparams,
train=train,
variable_dtype=variable_dtype,
importance=nonpadding,
num_microbatches=num_microbatches)
elif hparams.moe_gating == "ntlb":
dispatch_tensor, combine_tensor, loss = _ntlb_gating(
inputs=inputs,
outer_expert_dims=None,
experts_dim=experts_dim_unsplit,
expert_capacity_dim=expert_capacity_dim,
hparams=hparams,
train=train,
variable_dtype=variable_dtype,
importance=nonpadding,
num_microbatches=num_microbatches)
elif hparams.moe_gating == "switch_max":
dispatch_tensor, combine_tensor, loss = _switch_max_gating(
inputs=inputs,
outer_expert_dims=None,
experts_dim=experts_dim_unsplit,
expert_capacity_dim=expert_capacity_dim,
hparams=hparams,
train=train,
variable_dtype=variable_dtype,
importance=nonpadding,
num_microbatches=num_microbatches)
elif hparams.moe_gating == "expert_selection":
dispatch_tensor, combine_tensor, loss = _expert_selection_gating(
inputs=inputs,
outer_expert_dims=None,
experts_dim=experts_dim_unsplit,
group_size_dim=group_size_dim,
expert_capacity_dim=expert_capacity_dim,
hparams=hparams,
train=train,
variable_dtype=variable_dtype,
importance=nonpadding,
name="expert_selection_gating",
num_microbatches=num_microbatches)
else:
raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

expert_inputs = mtf.einsum([inputs, dispatch_tensor],
mtf.Shape([
outer_batch_dim, experts_dim_unsplit,
num_groups_dim, expert_capacity_dim, input_dim
]))
# [outer_batch_dim, num_groups_dim.B, group_size_dim,
# experts_dim_unsplit, expert_capacity_dim]
gating_fn = get_gating_fn(hparams.moe_gating)
dispatch_tensor, combine_tensor, loss = gating_fn(
inputs=inputs,
outer_expert_dims=None,
experts_dim=experts_dim_unsplit,
expert_capacity_dim=expert_capacity_dim,
hparams=hparams,
train=train,
variable_dtype=variable_dtype,
importance=nonpadding,
num_microbatches=num_microbatches)

# Dispatch to the experts by reducing group_size_dim
# inputs: [outer_batch_dim, num_groups_dim.B, group_size_dim, input_dim]
# dispatch_tensor: [outer_batch_dim, num_groups_dim.B, group_size_dim,
# experts_dim_unsplit, expert_capacity_dim]
# expert_inputs: [outer_batch_dim, experts_dim_unsplit, num_groups_dim.B,
# expert_capacity_dim, input_dim]
expert_inputs_shape = [
outer_batch_dim, experts_dim_unsplit, num_groups_dim,
expert_capacity_dim, input_dim]
expert_inputs = mtf.einsum([inputs, dispatch_tensor], expert_inputs_shape)

# Split over batch -> split over experts
# Extra reshape reduces communication cost for model-parallel versions.
# For model-parallel versions, this reshape causes an mtf.slice and for non-
# model-parallel versions, this has no effect.
# expert_inputs: [outer_batch_dim, experts_dim.B, batch_dim_unsplit,
# expert_capacity_dim, input_dim or input_dim.M]
d_model_split_dim = mtf.Dimension("d_model_split", input_dim.size)
expert_inputs = mtf.reshape(
expert_inputs,
mtf.Shape([
outer_batch_dim, experts_dim, batch_dim_unsplit, expert_capacity_dim,
d_model_split_dim
]))

# Split over batch -> split over experts
expert_inputs = mtf.reshape(
expert_inputs,
mtf.Shape([
outer_batch_dim, experts_dim, batch_dim_unsplit, expert_capacity_dim,
input_dim
]))

# Now feed the expert inputs through the experts.
h = mtf.layers.dense_product(
expert_inputs,
reduced_dims=expert_inputs.shape.dims[-1:],
new_dims=[hidden_dim],
expert_dims=[experts_dim],
activation_functions=activation, use_bias=False,
variable_dtype=variable_dtype, name="wi")

if hparams.moe_dropout_rate != 0.0:
h = mtf.dropout(h, is_training=train,
keep_prob=1.0 - hparams.moe_dropout_rate)

def _compute_output(hidden, layer_name):
"""Compute the output of the attention layer from the hidden vector."""
expert_inputs_shape = [
outer_batch_dim, experts_dim, batch_dim_unsplit,
expert_capacity_dim, d_model_split_dim]
expert_inputs = mtf.reshape(expert_inputs, expert_inputs_shape)

expert_inputs_shape = [
outer_batch_dim, experts_dim, batch_dim_unsplit,
expert_capacity_dim, input_dim]
expert_inputs = mtf.reshape(expert_inputs, expert_inputs_shape)

def _apply_experts(x, output_dim, hidden_dim):
# x: [outer_batch_dim, experts_dim.B, batch_dim_unsplit,
# expert_capacity_dim, input_dim]
h = mtf.layers.dense_product(
x,
reduced_dims=x.shape.dims[-1:],
new_dims=[hidden_dim],
expert_dims=[experts_dim],
activation_functions=activation, use_bias=False,
variable_dtype=variable_dtype, name="wi")

if hparams.moe_dropout_rate != 0.0:
h = mtf.dropout(h, is_training=train,
keep_prob=1.0 - hparams.moe_dropout_rate)
expert_output = mtf.layers.dense(
hidden, output_dim, expert_dims=[experts_dim], use_bias=False,
reduced_dims=hidden.shape.dims[-1:], variable_dtype=variable_dtype,
name=layer_name)

# Extra reshape reduces communication cost for model-parallel versions.
# For model-parallel versions, this reshape causes an mtf.slice and for non-
# model-parallel versions, this has no effect.
expert_output = mtf.reshape(
expert_output,
mtf.Shape([
outer_batch_dim, experts_dim_unsplit, num_groups_dim,
expert_capacity_dim, d_model_split_dim
]))

# Split over experts -> split over batch
h, output_dim, expert_dims=[experts_dim], use_bias=False,
reduced_dims=h.shape.dims[-1:], variable_dtype=variable_dtype,
name="wo")

return expert_output

if split_hidden_after_routing:
input_dim = mtf.Dimension(
input_dim.name, input_dim.size // hparams.moe_num_hidden_splits)
expert_inputs = mtf.reshape(
expert_inputs, expert_inputs.shape[:-1] + [num_splits_dim, input_dim])
expert_output = _apply_experts(expert_inputs, sub_output_dim, hidden_dim)
# Concat sub_tokens into tokens
expert_output = mtf.reshape(
expert_output,
mtf.Shape([
outer_batch_dim,
experts_dim_unsplit,
num_groups_dim,
expert_capacity_dim,
output_dim,
]))
moe_output_dims = moe_input_dims[:-1] + [output_dim]
output = mtf.einsum([expert_output, combine_tensor],
mtf.Shape(moe_output_dims))
output = mtf.reshape(output, batch_and_length_dims + [output_dim])
return output

if hparams.moe_use_experts_attention:
# We share k_h and v_h with no degradation in performance
q_h, k_h = h, h
outputs = []
q = _compute_output(q_h, layer_name="q_wo")
k = _compute_output(k_h, layer_name="k_wo")
outputs.append(q)
outputs.append(k)
return outputs, loss * hparams.moe_loss_coef
expert_output, expert_output.shape[:-2] + [output_dim])
elif split_hidden_before_routing:
expert_output = _apply_experts(expert_inputs, sub_output_dim, hidden_dim)
else:
output = _compute_output(h, layer_name="wo")
return output, loss * hparams.moe_loss_coef
expert_output = _apply_experts(expert_inputs, output_dim, hidden_dim)

# Extra reshape reduces communication cost for model-parallel versions.
# For model-parallel versions, this reshape causes an mtf.slice and for non-
# model-parallel versions, this has no effect.
expert_output_shape = [
outer_batch_dim, experts_dim_unsplit, num_groups_dim,
expert_capacity_dim, d_model_split_dim]
expert_output = mtf.reshape(expert_output, expert_output_shape)

# Split over experts -> split over batch
expert_output_shape = [
outer_batch_dim, experts_dim_unsplit, num_groups_dim,
expert_capacity_dim, expert_output.shape[-1]]
expert_output = mtf.reshape(expert_output, expert_output_shape)

# Combine by reducing experts_dim_unsplit and expert_capacity_dim
# expert_output: [outer_batch_dim, experts_dim_unsplit, num_groups_dim,
# expert_capacity_dim, output_dim]
# combine_tensor: [outer_batch_dim, num_groups_dim.B, group_size_dim,
# experts_dim_unsplit, expert_capacity_dim]
# output: [outer_batch_dim, num_groups_dim.B, group_size_dim, input_dim]
moe_output_dims = moe_input_dims[:-1] + [expert_output.shape[-1]]
output = mtf.einsum([expert_output, combine_tensor], moe_output_dims)
# import pdb; pdb.set_trace() # pylint:disable=g-import-not-at-top

if split_hidden_before_routing:
output = mtf.reshape(
output, [output.shape[0], orig_num_groups_dim, num_splits_dim] + (
output.shape[-2:]))
output = mtf.transpose(
output, output.shape[:2] + [
group_size_dim, num_splits_dim, output.shape[-1]])
output = mtf.reshape(output, output.shape[:3] + [output_dim])

output = mtf.reshape(output, batch_and_length_dims + [output_dim])

return output, loss * hparams.moe_loss_coef


def transformer_moe_layer_v2(
Expand Down Expand Up @@ -801,6 +828,22 @@ def transformer_moe_layer_v2(
return output, (loss_outer + loss_inner) * hparams.moe_loss_coef


def get_gating_fn(moe_gating):
"""Factory for gating functions."""
if moe_gating == "top_2":
return _top_2_gating
elif moe_gating == "switch":
return _switch_gating
elif moe_gating == "ntlb":
return _ntlb_gating
elif moe_gating == "switch_max":
return _switch_max_gating
elif moe_gating == "expert_selection":
return _expert_selection_gating
else:
raise ValueError("unknown hparams.moe_gating=%s" % moe_gating)


def _ntlb_gating(inputs,
outer_expert_dims,
experts_dim,
Expand Down

0 comments on commit 16ad1b2

Please sign in to comment.