From 16ad1b2d368b092c0179d8e8fb7865f763fc6382 Mon Sep 17 00:00:00 2001 From: Mesh TensorFlow Team Date: Tue, 8 Jun 2021 21:17:36 -0700 Subject: [PATCH] Splitting tokens when routing PiperOrigin-RevId: 378316002 --- mesh_tensorflow/transformer/moe.py | 339 ++++++++++++++++------------- 1 file changed, 191 insertions(+), 148 deletions(-) diff --git a/mesh_tensorflow/transformer/moe.py b/mesh_tensorflow/transformer/moe.py index 7e6f784e..2d83913b 100644 --- a/mesh_tensorflow/transformer/moe.py +++ b/mesh_tensorflow/transformer/moe.py @@ -61,7 +61,9 @@ def __init__(self, ntlb_top_k=4, output_dim=None, use_experts_attention=False, - z_loss=None): + z_loss=None, + num_hidden_splits=None, + split_hidden_before_routing=False): self._hparams = HParams( moe_gating=moe_gating, moe_num_experts=num_experts, @@ -85,7 +87,9 @@ def __init__(self, moe_output_dim=output_dim, moe_ntlb_top_k=ntlb_top_k, moe_use_experts_attention=use_experts_attention, - moe_z_loss=z_loss) + moe_z_loss=z_loss, + moe_num_hidden_splits=num_hidden_splits, + moe_split_hidden_before_routing=split_hidden_before_routing) self._activation = activation def call(self, context, x, losses=None): @@ -327,8 +331,8 @@ def transformer_moe_layer_v1( # We "cheat" here and look at the mesh shape and layout. This is to ensure # that the number of groups is a multiple of the mesh dimension # over which those groups are split. - batch_and_length_dims, input_dim = (orig_inputs.shape.dims[:-1], - orig_inputs.shape.dims[-1]) + batch_and_length_dims, orig_input_dim = ( + orig_inputs.shape.dims[:-1], orig_inputs.shape.dims[-1]) # Hack: we assume that # "outer_batch" == replication of experts # mesh_dim_size can be derived from mesh_shape and orig_batch_dim @@ -348,16 +352,57 @@ def transformer_moe_layer_v1( n = n // outer_batch_dim.size - mesh_dim_size = mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape, - orig_batch_dim) - num_groups, group_size = _split_into_groups(n, hparams.moe_group_size, - mesh_dim_size) + # Create num_groups and group_size dimensions + mesh_dim_size = mtf.tensor_dim_to_mesh_dim_size( + layout, mesh_shape, orig_batch_dim) + num_groups, group_size = _split_into_groups( + n, hparams.moe_group_size, mesh_dim_size) + orig_group_size_dim = mtf.Dimension("group", group_size) + orig_num_groups_dim = mtf.Dimension(orig_batch_dim.name, num_groups) + + # The original dimensions correspond to those before splitting tokens + # into subtokens + group_size_dim = orig_group_size_dim + num_groups_dim = orig_num_groups_dim + input_dim = orig_input_dim + + split_hidden_before_routing = False + split_hidden_after_routing = False + if hparams.moe_num_hidden_splits is not None: + if orig_input_dim.size % hparams.moe_num_hidden_splits: + raise ValueError("num_hidden_splits {} must divide input_dim {}".format( + hparams.moe_num_hidden_splits, input_dim.size)) + if output_dim.size % hparams.moe_num_hidden_splits: + raise ValueError("num_hidden_splits {} must divide input_dim {}".format( + hparams.moe_num_hidden_splits, input_dim.size)) + split_hidden_before_routing = hparams.moe_split_hidden_before_routing + split_hidden_after_routing = not hparams.moe_split_hidden_before_routing + hidden_dim = mtf.Dimension( + "expert_hidden", + hparams.moe_hidden_size // hparams.moe_num_hidden_splits) + sub_output_dim = mtf.Dimension( + output_dim.name, output_dim.size // hparams.moe_num_hidden_splits) + num_splits_dim = mtf.Dimension( + "num_splits", hparams.moe_num_hidden_splits) + + if split_hidden_before_routing: + input_dim = mtf.Dimension( + input_dim.name, input_dim.size // hparams.moe_num_hidden_splits) + + # Split into groups and subtokens + inputs = mtf.reshape( + inputs, [outer_batch_dim, num_groups_dim, group_size_dim, + num_splits_dim, input_dim]) - group_size_dim = mtf.Dimension("group", group_size) - num_groups_dim = mtf.Dimension(orig_batch_dim.name, num_groups) + inputs = mtf.transpose( + inputs, [outer_batch_dim, num_groups_dim, num_splits_dim, + group_size_dim, input_dim]) + num_groups_dim = mtf.Dimension( + orig_batch_dim.name, num_groups * hparams.moe_num_hidden_splits) + + # [outer_batch_dim, num_groups_dim.B, group_size_dim, input_dim] moe_input_dims = [outer_batch_dim, num_groups_dim, group_size_dim, input_dim] - # OGSM Tensor inputs = mtf.reshape(inputs, moe_input_dims) # Each sequence sends expert_capacity positions to each expert. @@ -373,156 +418,138 @@ def transformer_moe_layer_v1( expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity) experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size) batch_dim_unsplit = mtf.Dimension("batch_unsplit", num_groups_dim.size) + if nonpadding is not None: nonpadding = mtf.zeros( inputs.mesh, batch_and_length_dims, dtype=inputs.dtype) + nonpadding + + if split_hidden_before_routing: + nonpadding = mtf.reshape( + nonpadding, + [outer_batch_dim, orig_num_groups_dim, orig_group_size_dim]) + + # Tile num_hidden_splits times with an einsum + tiling_tensor = mtf.ones(inputs.mesh, [num_splits_dim]) + nonpadding = mtf.einsum( + [nonpadding, tiling_tensor], + output_shape=[outer_batch_dim, orig_num_groups_dim, num_splits_dim, + orig_group_size_dim]) + nonpadding = mtf.reshape(nonpadding, moe_input_dims[:-1]) - if hparams.moe_gating == "top_2": - # combine_tensor, - # dispatch_tensor OG`SEC Tensors - # (G is generally split along mesh dim) - dispatch_tensor, combine_tensor, loss = _top_2_gating( - inputs=inputs, - outer_expert_dims=None, - experts_dim=experts_dim_unsplit, - expert_capacity_dim=expert_capacity_dim, - hparams=hparams, - train=train, - variable_dtype=variable_dtype, - importance=nonpadding, - num_microbatches=num_microbatches) - elif hparams.moe_gating == "switch": - dispatch_tensor, combine_tensor, loss = _switch_gating( - inputs=inputs, - outer_expert_dims=None, - experts_dim=experts_dim_unsplit, - expert_capacity_dim=expert_capacity_dim, - hparams=hparams, - train=train, - variable_dtype=variable_dtype, - importance=nonpadding, - num_microbatches=num_microbatches) - elif hparams.moe_gating == "ntlb": - dispatch_tensor, combine_tensor, loss = _ntlb_gating( - inputs=inputs, - outer_expert_dims=None, - experts_dim=experts_dim_unsplit, - expert_capacity_dim=expert_capacity_dim, - hparams=hparams, - train=train, - variable_dtype=variable_dtype, - importance=nonpadding, - num_microbatches=num_microbatches) - elif hparams.moe_gating == "switch_max": - dispatch_tensor, combine_tensor, loss = _switch_max_gating( - inputs=inputs, - outer_expert_dims=None, - experts_dim=experts_dim_unsplit, - expert_capacity_dim=expert_capacity_dim, - hparams=hparams, - train=train, - variable_dtype=variable_dtype, - importance=nonpadding, - num_microbatches=num_microbatches) - elif hparams.moe_gating == "expert_selection": - dispatch_tensor, combine_tensor, loss = _expert_selection_gating( - inputs=inputs, - outer_expert_dims=None, - experts_dim=experts_dim_unsplit, - group_size_dim=group_size_dim, - expert_capacity_dim=expert_capacity_dim, - hparams=hparams, - train=train, - variable_dtype=variable_dtype, - importance=nonpadding, - name="expert_selection_gating", - num_microbatches=num_microbatches) - else: - raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating) - expert_inputs = mtf.einsum([inputs, dispatch_tensor], - mtf.Shape([ - outer_batch_dim, experts_dim_unsplit, - num_groups_dim, expert_capacity_dim, input_dim - ])) + # [outer_batch_dim, num_groups_dim.B, group_size_dim, + # experts_dim_unsplit, expert_capacity_dim] + gating_fn = get_gating_fn(hparams.moe_gating) + dispatch_tensor, combine_tensor, loss = gating_fn( + inputs=inputs, + outer_expert_dims=None, + experts_dim=experts_dim_unsplit, + expert_capacity_dim=expert_capacity_dim, + hparams=hparams, + train=train, + variable_dtype=variable_dtype, + importance=nonpadding, + num_microbatches=num_microbatches) + + # Dispatch to the experts by reducing group_size_dim + # inputs: [outer_batch_dim, num_groups_dim.B, group_size_dim, input_dim] + # dispatch_tensor: [outer_batch_dim, num_groups_dim.B, group_size_dim, + # experts_dim_unsplit, expert_capacity_dim] + # expert_inputs: [outer_batch_dim, experts_dim_unsplit, num_groups_dim.B, + # expert_capacity_dim, input_dim] + expert_inputs_shape = [ + outer_batch_dim, experts_dim_unsplit, num_groups_dim, + expert_capacity_dim, input_dim] + expert_inputs = mtf.einsum([inputs, dispatch_tensor], expert_inputs_shape) + # Split over batch -> split over experts # Extra reshape reduces communication cost for model-parallel versions. # For model-parallel versions, this reshape causes an mtf.slice and for non- # model-parallel versions, this has no effect. + # expert_inputs: [outer_batch_dim, experts_dim.B, batch_dim_unsplit, + # expert_capacity_dim, input_dim or input_dim.M] d_model_split_dim = mtf.Dimension("d_model_split", input_dim.size) - expert_inputs = mtf.reshape( - expert_inputs, - mtf.Shape([ - outer_batch_dim, experts_dim, batch_dim_unsplit, expert_capacity_dim, - d_model_split_dim - ])) - - # Split over batch -> split over experts - expert_inputs = mtf.reshape( - expert_inputs, - mtf.Shape([ - outer_batch_dim, experts_dim, batch_dim_unsplit, expert_capacity_dim, - input_dim - ])) - - # Now feed the expert inputs through the experts. - h = mtf.layers.dense_product( - expert_inputs, - reduced_dims=expert_inputs.shape.dims[-1:], - new_dims=[hidden_dim], - expert_dims=[experts_dim], - activation_functions=activation, use_bias=False, - variable_dtype=variable_dtype, name="wi") - - if hparams.moe_dropout_rate != 0.0: - h = mtf.dropout(h, is_training=train, - keep_prob=1.0 - hparams.moe_dropout_rate) - - def _compute_output(hidden, layer_name): - """Compute the output of the attention layer from the hidden vector.""" + expert_inputs_shape = [ + outer_batch_dim, experts_dim, batch_dim_unsplit, + expert_capacity_dim, d_model_split_dim] + expert_inputs = mtf.reshape(expert_inputs, expert_inputs_shape) + + expert_inputs_shape = [ + outer_batch_dim, experts_dim, batch_dim_unsplit, + expert_capacity_dim, input_dim] + expert_inputs = mtf.reshape(expert_inputs, expert_inputs_shape) + + def _apply_experts(x, output_dim, hidden_dim): + # x: [outer_batch_dim, experts_dim.B, batch_dim_unsplit, + # expert_capacity_dim, input_dim] + h = mtf.layers.dense_product( + x, + reduced_dims=x.shape.dims[-1:], + new_dims=[hidden_dim], + expert_dims=[experts_dim], + activation_functions=activation, use_bias=False, + variable_dtype=variable_dtype, name="wi") + + if hparams.moe_dropout_rate != 0.0: + h = mtf.dropout(h, is_training=train, + keep_prob=1.0 - hparams.moe_dropout_rate) expert_output = mtf.layers.dense( - hidden, output_dim, expert_dims=[experts_dim], use_bias=False, - reduced_dims=hidden.shape.dims[-1:], variable_dtype=variable_dtype, - name=layer_name) - - # Extra reshape reduces communication cost for model-parallel versions. - # For model-parallel versions, this reshape causes an mtf.slice and for non- - # model-parallel versions, this has no effect. - expert_output = mtf.reshape( - expert_output, - mtf.Shape([ - outer_batch_dim, experts_dim_unsplit, num_groups_dim, - expert_capacity_dim, d_model_split_dim - ])) - - # Split over experts -> split over batch + h, output_dim, expert_dims=[experts_dim], use_bias=False, + reduced_dims=h.shape.dims[-1:], variable_dtype=variable_dtype, + name="wo") + + return expert_output + + if split_hidden_after_routing: + input_dim = mtf.Dimension( + input_dim.name, input_dim.size // hparams.moe_num_hidden_splits) + expert_inputs = mtf.reshape( + expert_inputs, expert_inputs.shape[:-1] + [num_splits_dim, input_dim]) + expert_output = _apply_experts(expert_inputs, sub_output_dim, hidden_dim) + # Concat sub_tokens into tokens expert_output = mtf.reshape( - expert_output, - mtf.Shape([ - outer_batch_dim, - experts_dim_unsplit, - num_groups_dim, - expert_capacity_dim, - output_dim, - ])) - moe_output_dims = moe_input_dims[:-1] + [output_dim] - output = mtf.einsum([expert_output, combine_tensor], - mtf.Shape(moe_output_dims)) - output = mtf.reshape(output, batch_and_length_dims + [output_dim]) - return output - - if hparams.moe_use_experts_attention: - # We share k_h and v_h with no degradation in performance - q_h, k_h = h, h - outputs = [] - q = _compute_output(q_h, layer_name="q_wo") - k = _compute_output(k_h, layer_name="k_wo") - outputs.append(q) - outputs.append(k) - return outputs, loss * hparams.moe_loss_coef + expert_output, expert_output.shape[:-2] + [output_dim]) + elif split_hidden_before_routing: + expert_output = _apply_experts(expert_inputs, sub_output_dim, hidden_dim) else: - output = _compute_output(h, layer_name="wo") - return output, loss * hparams.moe_loss_coef + expert_output = _apply_experts(expert_inputs, output_dim, hidden_dim) + + # Extra reshape reduces communication cost for model-parallel versions. + # For model-parallel versions, this reshape causes an mtf.slice and for non- + # model-parallel versions, this has no effect. + expert_output_shape = [ + outer_batch_dim, experts_dim_unsplit, num_groups_dim, + expert_capacity_dim, d_model_split_dim] + expert_output = mtf.reshape(expert_output, expert_output_shape) + + # Split over experts -> split over batch + expert_output_shape = [ + outer_batch_dim, experts_dim_unsplit, num_groups_dim, + expert_capacity_dim, expert_output.shape[-1]] + expert_output = mtf.reshape(expert_output, expert_output_shape) + + # Combine by reducing experts_dim_unsplit and expert_capacity_dim + # expert_output: [outer_batch_dim, experts_dim_unsplit, num_groups_dim, + # expert_capacity_dim, output_dim] + # combine_tensor: [outer_batch_dim, num_groups_dim.B, group_size_dim, + # experts_dim_unsplit, expert_capacity_dim] + # output: [outer_batch_dim, num_groups_dim.B, group_size_dim, input_dim] + moe_output_dims = moe_input_dims[:-1] + [expert_output.shape[-1]] + output = mtf.einsum([expert_output, combine_tensor], moe_output_dims) + # import pdb; pdb.set_trace() # pylint:disable=g-import-not-at-top + + if split_hidden_before_routing: + output = mtf.reshape( + output, [output.shape[0], orig_num_groups_dim, num_splits_dim] + ( + output.shape[-2:])) + output = mtf.transpose( + output, output.shape[:2] + [ + group_size_dim, num_splits_dim, output.shape[-1]]) + output = mtf.reshape(output, output.shape[:3] + [output_dim]) + + output = mtf.reshape(output, batch_and_length_dims + [output_dim]) + + return output, loss * hparams.moe_loss_coef def transformer_moe_layer_v2( @@ -801,6 +828,22 @@ def transformer_moe_layer_v2( return output, (loss_outer + loss_inner) * hparams.moe_loss_coef +def get_gating_fn(moe_gating): + """Factory for gating functions.""" + if moe_gating == "top_2": + return _top_2_gating + elif moe_gating == "switch": + return _switch_gating + elif moe_gating == "ntlb": + return _ntlb_gating + elif moe_gating == "switch_max": + return _switch_max_gating + elif moe_gating == "expert_selection": + return _expert_selection_gating + else: + raise ValueError("unknown hparams.moe_gating=%s" % moe_gating) + + def _ntlb_gating(inputs, outer_expert_dims, experts_dim,