Splitting tokens when routing #326

Open
wants to merge 1 commit into base: master
339 changes: 191 additions & 148 deletions mesh_tensorflow/transformer/moe.py
@@ -61,7 +61,9 @@ def __init__(self,
ntlb_top_k=4,
output_dim=None,
use_experts_attention=False,
z_loss=None):
z_loss=None,
num_hidden_splits=None,
split_hidden_before_routing=False):
self._hparams = HParams(
moe_gating=moe_gating,
moe_num_experts=num_experts,
@@ -85,7 +87,9 @@ def __init__(self,
moe_output_dim=output_dim,
moe_ntlb_top_k=ntlb_top_k,
moe_use_experts_attention=use_experts_attention,
moe_z_loss=z_loss)
moe_z_loss=z_loss,
moe_num_hidden_splits=num_hidden_splits,
moe_split_hidden_before_routing=split_hidden_before_routing)
self._activation = activation

def call(self, context, x, losses=None):
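For context, a minimal usage sketch of the two new constructor arguments. It assumes this __init__ belongs to moe.MoE1D (the class name is not visible in this hunk) and that the remaining keyword arguments keep their existing defaults:

# Hypothetical usage of the new token-splitting options added in this PR.
# The class name MoE1D and the other keyword arguments are assumptions;
# only num_hidden_splits and split_hidden_before_routing come from this diff.
from mesh_tensorflow.transformer import moe

moe_layer = moe.MoE1D(
    num_experts=32,
    hidden_size=4096,
    num_hidden_splits=2,               # split each token into 2 sub-tokens
    split_hidden_before_routing=True)  # split before the router (not after)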
@@ -327,8 +331,8 @@ def transformer_moe_layer_v1(
# We "cheat" here and look at the mesh shape and layout. This is to ensure
# that the number of groups is a multiple of the mesh dimension
# over which those groups are split.
batch_and_length_dims, input_dim = (orig_inputs.shape.dims[:-1],
orig_inputs.shape.dims[-1])
batch_and_length_dims, orig_input_dim = (
orig_inputs.shape.dims[:-1], orig_inputs.shape.dims[-1])
# Hack: we assume that
# "outer_batch" == replication of experts
# mesh_dim_size can be derived from mesh_shape and orig_batch_dim
@@ -348,16 +352,57 @@

n = n // outer_batch_dim.size

mesh_dim_size = mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape,
orig_batch_dim)
num_groups, group_size = _split_into_groups(n, hparams.moe_group_size,
mesh_dim_size)
group_size_dim = mtf.Dimension("group", group_size)
num_groups_dim = mtf.Dimension(orig_batch_dim.name, num_groups)
# Create num_groups and group_size dimensions
mesh_dim_size = mtf.tensor_dim_to_mesh_dim_size(
layout, mesh_shape, orig_batch_dim)
num_groups, group_size = _split_into_groups(
n, hparams.moe_group_size, mesh_dim_size)
orig_group_size_dim = mtf.Dimension("group", group_size)
orig_num_groups_dim = mtf.Dimension(orig_batch_dim.name, num_groups)

# The original dimensions correspond to those before splitting tokens
# into subtokens
group_size_dim = orig_group_size_dim
num_groups_dim = orig_num_groups_dim
input_dim = orig_input_dim

split_hidden_before_routing = False
split_hidden_after_routing = False
if hparams.moe_num_hidden_splits is not None:
if orig_input_dim.size % hparams.moe_num_hidden_splits:
raise ValueError("num_hidden_splits {} must divide input_dim {}".format(
hparams.moe_num_hidden_splits, orig_input_dim.size))
if output_dim.size % hparams.moe_num_hidden_splits:
raise ValueError("num_hidden_splits {} must divide output_dim {}".format(
hparams.moe_num_hidden_splits, output_dim.size))
split_hidden_before_routing = hparams.moe_split_hidden_before_routing
split_hidden_after_routing = not hparams.moe_split_hidden_before_routing
hidden_dim = mtf.Dimension(
"expert_hidden",
hparams.moe_hidden_size // hparams.moe_num_hidden_splits)
sub_output_dim = mtf.Dimension(
output_dim.name, output_dim.size // hparams.moe_num_hidden_splits)
num_splits_dim = mtf.Dimension(
"num_splits", hparams.moe_num_hidden_splits)

if split_hidden_before_routing:
input_dim = mtf.Dimension(
input_dim.name, input_dim.size // hparams.moe_num_hidden_splits)

# Split into groups and subtokens
inputs = mtf.reshape(
inputs, [outer_batch_dim, num_groups_dim, group_size_dim,
num_splits_dim, input_dim])

group_size_dim = mtf.Dimension("group", group_size)
num_groups_dim = mtf.Dimension(orig_batch_dim.name, num_groups)
inputs = mtf.transpose(
inputs, [outer_batch_dim, num_groups_dim, num_splits_dim,
group_size_dim, input_dim])

num_groups_dim = mtf.Dimension(
orig_batch_dim.name, num_groups * hparams.moe_num_hidden_splits)

# [outer_batch_dim, num_groups_dim.B, group_size_dim, input_dim]
moe_input_dims = [outer_batch_dim, num_groups_dim, group_size_dim, input_dim]
# OGSM Tensor
inputs = mtf.reshape(inputs, moe_input_dims)
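A shape-level NumPy sketch (illustrative sizes, not taken from the diff) of the split-before-routing reshape/transpose above: each d_model-sized token is cut into num_hidden_splits sub-tokens, and the group dimension grows by the same factor so every sub-token is routed on its own:

# Stand-in for the mtf.reshape/mtf.transpose sequence above; sizes are examples.
import numpy as np

outer_batch, num_groups, group_size, d_model = 1, 4, 8, 512
num_hidden_splits = 2
sub_d_model = d_model // num_hidden_splits  # 256

x = np.zeros((outer_batch, num_groups, group_size, d_model))
# [outer_batch, num_groups, group_size, num_splits, sub_d_model]
x = x.reshape(outer_batch, num_groups, group_size, num_hidden_splits, sub_d_model)
# [outer_batch, num_groups, num_splits, group_size, sub_d_model]
x = x.transpose(0, 1, 3, 2, 4)
# Fold num_splits into the group dimension: twice as many groups of sub-tokens.
x = x.reshape(outer_batch, num_groups * num_hidden_splits, group_size, sub_d_model)
assert x.shape == (1, 8, 8, 256)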

# Each sequence sends expert_capacity positions to each expert.
@@ -373,156 +418,138 @@
expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)
experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
batch_dim_unsplit = mtf.Dimension("batch_unsplit", num_groups_dim.size)

if nonpadding is not None:
nonpadding = mtf.zeros(
inputs.mesh, batch_and_length_dims, dtype=inputs.dtype) + nonpadding

if split_hidden_before_routing:
nonpadding = mtf.reshape(
nonpadding,
[outer_batch_dim, orig_num_groups_dim, orig_group_size_dim])

# Tile num_hidden_splits times with an einsum
tiling_tensor = mtf.ones(inputs.mesh, [num_splits_dim])
nonpadding = mtf.einsum(
[nonpadding, tiling_tensor],
output_shape=[outer_batch_dim, orig_num_groups_dim, num_splits_dim,
orig_group_size_dim])

nonpadding = mtf.reshape(nonpadding, moe_input_dims[:-1])
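The nonpadding tiling above can be read as a broadcast against a ones vector; a small NumPy equivalent (sizes are illustrative):

# Stand-in for the tiling einsum above: the [outer_batch, num_groups, group_size]
# padding mask is copied once per sub-token via a ones vector of length num_splits.
import numpy as np

nonpadding = np.random.randint(0, 2, size=(1, 4, 8)).astype(np.float32)
tiling = np.ones(2)  # num_hidden_splits = 2
tiled = np.einsum("ogs,n->ogns", nonpadding, tiling)
assert tiled.shape == (1, 4, 2, 8)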
if hparams.moe_gating == "top_2":
# combine_tensor,
# dispatch_tensor OG`SEC Tensors
# (G is generally split along mesh dim)
dispatch_tensor, combine_tensor, loss = _top_2_gating(
inputs=inputs,
outer_expert_dims=None,
experts_dim=experts_dim_unsplit,
expert_capacity_dim=expert_capacity_dim,
hparams=hparams,
train=train,
variable_dtype=variable_dtype,
importance=nonpadding,
num_microbatches=num_microbatches)
elif hparams.moe_gating == "switch":
dispatch_tensor, combine_tensor, loss = _switch_gating(
inputs=inputs,
outer_expert_dims=None,
experts_dim=experts_dim_unsplit,
expert_capacity_dim=expert_capacity_dim,
hparams=hparams,
train=train,
variable_dtype=variable_dtype,
importance=nonpadding,
num_microbatches=num_microbatches)
elif hparams.moe_gating == "ntlb":
dispatch_tensor, combine_tensor, loss = _ntlb_gating(
inputs=inputs,
outer_expert_dims=None,
experts_dim=experts_dim_unsplit,
expert_capacity_dim=expert_capacity_dim,
hparams=hparams,
train=train,
variable_dtype=variable_dtype,
importance=nonpadding,
num_microbatches=num_microbatches)
elif hparams.moe_gating == "switch_max":
dispatch_tensor, combine_tensor, loss = _switch_max_gating(
inputs=inputs,
outer_expert_dims=None,
experts_dim=experts_dim_unsplit,
expert_capacity_dim=expert_capacity_dim,
hparams=hparams,
train=train,
variable_dtype=variable_dtype,
importance=nonpadding,
num_microbatches=num_microbatches)
elif hparams.moe_gating == "expert_selection":
dispatch_tensor, combine_tensor, loss = _expert_selection_gating(
inputs=inputs,
outer_expert_dims=None,
experts_dim=experts_dim_unsplit,
group_size_dim=group_size_dim,
expert_capacity_dim=expert_capacity_dim,
hparams=hparams,
train=train,
variable_dtype=variable_dtype,
importance=nonpadding,
name="expert_selection_gating",
num_microbatches=num_microbatches)
else:
raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

expert_inputs = mtf.einsum([inputs, dispatch_tensor],
mtf.Shape([
outer_batch_dim, experts_dim_unsplit,
num_groups_dim, expert_capacity_dim, input_dim
]))
# [outer_batch_dim, num_groups_dim.B, group_size_dim,
# experts_dim_unsplit, expert_capacity_dim]
gating_fn = get_gating_fn(hparams.moe_gating)
dispatch_tensor, combine_tensor, loss = gating_fn(
inputs=inputs,
outer_expert_dims=None,
experts_dim=experts_dim_unsplit,
expert_capacity_dim=expert_capacity_dim,
hparams=hparams,
train=train,
variable_dtype=variable_dtype,
importance=nonpadding,
num_microbatches=num_microbatches)

# Dispatch to the experts by reducing group_size_dim
# inputs: [outer_batch_dim, num_groups_dim.B, group_size_dim, input_dim]
# dispatch_tensor: [outer_batch_dim, num_groups_dim.B, group_size_dim,
# experts_dim_unsplit, expert_capacity_dim]
# expert_inputs: [outer_batch_dim, experts_dim_unsplit, num_groups_dim.B,
# expert_capacity_dim, input_dim]
expert_inputs_shape = [
outer_batch_dim, experts_dim_unsplit, num_groups_dim,
expert_capacity_dim, input_dim]
expert_inputs = mtf.einsum([inputs, dispatch_tensor], expert_inputs_shape)
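A NumPy sketch of the dispatch einsum (illustrative sizes): summing over group_size copies each token's vector into the (expert, capacity) slot selected by the router:

# Stand-in for the mtf.einsum above with a hand-built one-hot dispatch tensor.
import numpy as np

outer_batch, num_groups, group_size, d_model = 1, 2, 4, 8
num_experts, capacity = 2, 2
inputs = np.random.randn(outer_batch, num_groups, group_size, d_model)
dispatch = np.zeros((outer_batch, num_groups, group_size, num_experts, capacity))
dispatch[0, :, 0, 0, 0] = 1.0  # token 0 of every group -> expert 0, slot 0
expert_inputs = np.einsum("ogsm,ogsec->oegcm", inputs, dispatch)
assert expert_inputs.shape == (outer_batch, num_experts, num_groups, capacity, d_model)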

# Split over batch -> split over experts
# Extra reshape reduces communication cost for model-parallel versions.
# For model-parallel versions, this reshape causes an mtf.slice and for non-
# model-parallel versions, this has no effect.
# expert_inputs: [outer_batch_dim, experts_dim.B, batch_dim_unsplit,
# expert_capacity_dim, input_dim or input_dim.M]
d_model_split_dim = mtf.Dimension("d_model_split", input_dim.size)
expert_inputs = mtf.reshape(
expert_inputs,
mtf.Shape([
outer_batch_dim, experts_dim, batch_dim_unsplit, expert_capacity_dim,
d_model_split_dim
]))

# Split over batch -> split over experts
expert_inputs = mtf.reshape(
expert_inputs,
mtf.Shape([
outer_batch_dim, experts_dim, batch_dim_unsplit, expert_capacity_dim,
input_dim
]))

# Now feed the expert inputs through the experts.
h = mtf.layers.dense_product(
expert_inputs,
reduced_dims=expert_inputs.shape.dims[-1:],
new_dims=[hidden_dim],
expert_dims=[experts_dim],
activation_functions=activation, use_bias=False,
variable_dtype=variable_dtype, name="wi")

if hparams.moe_dropout_rate != 0.0:
h = mtf.dropout(h, is_training=train,
keep_prob=1.0 - hparams.moe_dropout_rate)

def _compute_output(hidden, layer_name):
  """Compute the output of the attention layer from the hidden vector."""
  expert_output = mtf.layers.dense(
      hidden, output_dim, expert_dims=[experts_dim], use_bias=False,
      reduced_dims=hidden.shape.dims[-1:], variable_dtype=variable_dtype,
      name=layer_name)

  # Extra reshape reduces communication cost for model-parallel versions.
  # For model-parallel versions, this reshape causes an mtf.slice and for non-
  # model-parallel versions, this has no effect.
  expert_output = mtf.reshape(
      expert_output,
      mtf.Shape([
          outer_batch_dim, experts_dim_unsplit, num_groups_dim,
          expert_capacity_dim, d_model_split_dim
      ]))

  # Split over experts -> split over batch
  expert_output = mtf.reshape(
      expert_output,
      mtf.Shape([
          outer_batch_dim,
          experts_dim_unsplit,
          num_groups_dim,
          expert_capacity_dim,
          output_dim,
      ]))
  moe_output_dims = moe_input_dims[:-1] + [output_dim]
  output = mtf.einsum([expert_output, combine_tensor],
                      mtf.Shape(moe_output_dims))
  output = mtf.reshape(output, batch_and_length_dims + [output_dim])
  return output

if hparams.moe_use_experts_attention:
  # We share k_h and v_h with no degradation in performance
  q_h, k_h = h, h
  outputs = []
  q = _compute_output(q_h, layer_name="q_wo")
  k = _compute_output(k_h, layer_name="k_wo")
  outputs.append(q)
  outputs.append(k)
  return outputs, loss * hparams.moe_loss_coef
else:
  output = _compute_output(h, layer_name="wo")
  return output, loss * hparams.moe_loss_coef

expert_inputs_shape = [
    outer_batch_dim, experts_dim, batch_dim_unsplit,
    expert_capacity_dim, d_model_split_dim]
expert_inputs = mtf.reshape(expert_inputs, expert_inputs_shape)

expert_inputs_shape = [
    outer_batch_dim, experts_dim, batch_dim_unsplit,
    expert_capacity_dim, input_dim]
expert_inputs = mtf.reshape(expert_inputs, expert_inputs_shape)

def _apply_experts(x, output_dim, hidden_dim):
  # x: [outer_batch_dim, experts_dim.B, batch_dim_unsplit,
  #     expert_capacity_dim, input_dim]
  h = mtf.layers.dense_product(
      x,
      reduced_dims=x.shape.dims[-1:],
      new_dims=[hidden_dim],
      expert_dims=[experts_dim],
      activation_functions=activation, use_bias=False,
      variable_dtype=variable_dtype, name="wi")

  if hparams.moe_dropout_rate != 0.0:
    h = mtf.dropout(h, is_training=train,
                    keep_prob=1.0 - hparams.moe_dropout_rate)
  expert_output = mtf.layers.dense(
      h, output_dim, expert_dims=[experts_dim], use_bias=False,
      reduced_dims=h.shape.dims[-1:], variable_dtype=variable_dtype,
      name="wo")

  return expert_output

if split_hidden_after_routing:
  input_dim = mtf.Dimension(
      input_dim.name, input_dim.size // hparams.moe_num_hidden_splits)
  expert_inputs = mtf.reshape(
      expert_inputs, expert_inputs.shape[:-1] + [num_splits_dim, input_dim])
  expert_output = _apply_experts(expert_inputs, sub_output_dim, hidden_dim)
  # Concat sub_tokens into tokens
  expert_output = mtf.reshape(
      expert_output, expert_output.shape[:-2] + [output_dim])
elif split_hidden_before_routing:
  expert_output = _apply_experts(expert_inputs, sub_output_dim, hidden_dim)
else:
  expert_output = _apply_experts(expert_inputs, output_dim, hidden_dim)

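For comparison, a NumPy sketch of the split_hidden_after_routing path defined above (illustrative sizes): routing sees full tokens, each dispatched token is then cut into sub-tokens, run through a narrower expert, and the sub-outputs are concatenated back to full width:

# Stand-in for the after-routing branch; the expert itself is replaced by a
# zeros placeholder of the right shape, since only the shapes are of interest.
import numpy as np

outer_batch, num_experts, num_groups, capacity, d_model = 1, 2, 4, 2, 512
num_splits, d_out = 2, 512
expert_inputs = np.zeros((outer_batch, num_experts, num_groups, capacity, d_model))
# Cut the model dimension into sub-tokens.
expert_inputs = expert_inputs.reshape(
    outer_batch, num_experts, num_groups, capacity, num_splits,
    d_model // num_splits)
# Placeholder for _apply_experts: each sub-token maps to d_out // num_splits features.
expert_output = np.zeros(expert_inputs.shape[:-1] + (d_out // num_splits,))
# Concat sub-token outputs back into full tokens.
expert_output = expert_output.reshape(
    outer_batch, num_experts, num_groups, capacity, d_out)
assert expert_output.shape == (1, 2, 4, 2, 512)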
# Extra reshape reduces communication cost for model-parallel versions.
# For model-parallel versions, this reshape causes an mtf.slice and for non-
# model-parallel versions, this has no effect.
expert_output_shape = [
outer_batch_dim, experts_dim_unsplit, num_groups_dim,
expert_capacity_dim, d_model_split_dim]
expert_output = mtf.reshape(expert_output, expert_output_shape)

# Split over experts -> split over batch
expert_output_shape = [
outer_batch_dim, experts_dim_unsplit, num_groups_dim,
expert_capacity_dim, expert_output.shape[-1]]
expert_output = mtf.reshape(expert_output, expert_output_shape)

# Combine by reducing experts_dim_unsplit and expert_capacity_dim
# expert_output: [outer_batch_dim, experts_dim_unsplit, num_groups_dim,
# expert_capacity_dim, output_dim]
# combine_tensor: [outer_batch_dim, num_groups_dim.B, group_size_dim,
# experts_dim_unsplit, expert_capacity_dim]
# output: [outer_batch_dim, num_groups_dim.B, group_size_dim, input_dim]
moe_output_dims = moe_input_dims[:-1] + [expert_output.shape[-1]]
output = mtf.einsum([expert_output, combine_tensor], moe_output_dims)

if split_hidden_before_routing:
output = mtf.reshape(
output, [output.shape[0], orig_num_groups_dim, num_splits_dim] + (
output.shape[-2:]))
output = mtf.transpose(
output, output.shape[:2] + [
group_size_dim, num_splits_dim, output.shape[-1]])
output = mtf.reshape(output, output.shape[:3] + [output_dim])

output = mtf.reshape(output, batch_and_length_dims + [output_dim])

return output, loss * hparams.moe_loss_coef
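A NumPy sketch of the final un-split for the before-routing case (illustrative sizes), mirroring the reshape/transpose/reshape just above: sub-token outputs are regrouped per token and concatenated back to the full output width:

# Stand-in for undoing the before-routing split; sizes are examples only.
import numpy as np

outer_batch, num_groups, group_size, num_splits, sub_d_out = 1, 4, 8, 2, 256
out = np.zeros((outer_batch, num_groups * num_splits, group_size, sub_d_out))
# Recover the [groups, splits] factorisation of the enlarged group dimension.
out = out.reshape(outer_batch, num_groups, num_splits, group_size, sub_d_out)
# Put the split axis next to the feature axis, then concatenate sub-outputs.
out = out.transpose(0, 1, 3, 2, 4)
out = out.reshape(outer_batch, num_groups, group_size, num_splits * sub_d_out)
assert out.shape == (1, 4, 8, 512)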


def transformer_moe_layer_v2(
Expand Down Expand Up @@ -801,6 +828,22 @@ def transformer_moe_layer_v2(
return output, (loss_outer + loss_inner) * hparams.moe_loss_coef


def get_gating_fn(moe_gating):
"""Factory for gating functions."""
if moe_gating == "top_2":
return _top_2_gating
elif moe_gating == "switch":
return _switch_gating
elif moe_gating == "ntlb":
return _ntlb_gating
elif moe_gating == "switch_max":
return _switch_max_gating
elif moe_gating == "expert_selection":
return _expert_selection_gating
else:
raise ValueError("unknown hparams.moe_gating=%s" % moe_gating)


def _ntlb_gating(inputs,
outer_expert_dims,
experts_dim,