Cross-layer attention weight sharing fails in different scopes #52

Open
mqyqlx opened this issue Mar 19, 2024 · 3 comments
mqyqlx commented Mar 19, 2024

Hi. I am trying to share attention weights across layers, following the test case in shared_layers_test.py.

  def testSharedTemplateLayer(self):
    sub_params = pax_fiddle.Config(
        linears.FeedForward, input_dims=8, output_dims=8
    )
    # Only share the linear projection, not the entire FeedForward layer.
    sub_params.linear_tpl.shared_weight_layer_id = 'shared_weight'
    test_layer_p = pax_fiddle.Config(
        SimpleShared01,
        name='test',
        sub1_tpl=sub_params.clone(),
        sub2_tpl=sub_params.clone(),
    )
    x_in = jnp.ones([2, 8])
    with base_layer.JaxContext.new_context():
      prng_key = jax.random.PRNGKey(1234)
      layer = base_layer.instantiate(test_layer_p)
      init_vars = layer.init(prng_key, x_in)
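
For context, SimpleShared01 is a small test layer defined in shared_layers_test.py with two template fields, sub1_tpl and sub2_tpl. Roughly (the field names come from the config above, but the body is my sketch rather than the exact test class):

class SimpleShared01(base_layer.BaseLayer):
  sub1_tpl: pax_fiddle.Config = base_layer.template_field(linears.FeedForward)
  sub2_tpl: pax_fiddle.Config = base_layer.template_field(linears.FeedForward)

  def setup(self) -> None:
    # Both children are built from templates whose linear_tpl carries the same
    # shared_weight_layer_id, so their inner projections should share weights.
    self.create_child('sub1', self.sub1_tpl)
    self.create_child('sub2', self.sub2_tpl)

  def __call__(self, x):
    return self.sub2(self.sub1(x))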

But it failed to share the weights, because different scopes are used when setting and looking up the cache entry.

  def lookup_shared_layer(
      self, root_scope: flax_core.Scope, shared_layer_id: str
  ) -> _SharedLayerCacheEntry | None:
    logging.info('lookup_shared_layer called with id: %s in the scope of %s',
                 shared_layer_id, root_scope)
    return self._root_scope_to_shared_layers_map[root_scope][shared_layer_id]

  def set_shared_layer(self, root_scope: flax_core.Scope, shared_layer_id: str,
                       wrapper: _WrapperLayer, layer_hparams):
    logging.info('set_shared_layer called with id: %s in the scope of %s',
                 shared_layer_id, root_scope)
    existing = self.lookup_shared_layer(root_scope, shared_layer_id)
    assert existing is None
    self._root_scope_to_shared_layers_map[root_scope][
        shared_layer_id] = _SharedLayerCacheEntry(
            layer=wrapper.cld, hparams=layer_hparams.clone(), wrapper=wrapper)
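
To make the failure mode concrete, here is a toy illustration (plain Python, not the Praxis code above) of how a cache keyed on root_scope misses when the producing and consuming layers see different root scopes, whatever the reason for the scope difference turns out to be:

from collections import defaultdict

# Toy stand-in for _root_scope_to_shared_layers_map: root_scope -> id -> entry.
cache = defaultdict(dict)

def set_shared(root_scope, shared_layer_id, entry):
  cache[root_scope][shared_layer_id] = entry

def lookup_shared(root_scope, shared_layer_id):
  return cache[root_scope].get(shared_layer_id)

scope_a = object()  # scope under which layer_0 registers 'shared_attn_0'
scope_b = object()  # different scope seen by layer_6 at lookup time

set_shared(scope_a, 'shared_attn_0', entry='attention params')
# The id matches but the scope key does not, so the lookup misses and layer_6
# ends up creating its own unshared copy of the attention weights.
assert lookup_shared(scope_b, 'shared_attn_0') is None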

Specifically, I implement a 24-layer Llama with StackedTransformer (not StackedTransformerRepeated) and set shared_weight_layer_id in an interleaved pattern with an interval of 6 inside _layer_params in the setup function of StackedTransformer. The main differences from the original code are the added share_interval attribute and the lines that set shared_weight_layer_id (marked with comments in the block below). I also set remat=True and checkpoint_policy=AutodiffCheckpointType.SAVE_NOTHING on StackedTransformer.


class StackedTransformer(base_layer.BaseLayer):
  use_cross_attention: bool = False
  mask_self_attention: bool = False
  num_layers: int = 0
  model_dims: int = 0
  hidden_dims: int = 0
  num_heads: int = 0
  dim_per_head: int | None = None
  dropout_prob: float = 0.0
  atten_dropout_prob: float | None = None
  residual_dropout_prob: float | None = None
  relu_dropout_prob: float | None = None
  residual_droppath_prob: float = 0.0
  input_dropout_prob: float = 0.0
  gating_func: str = 'top2'
  unadjusted_expert_capacity_factor: float = 2.0
  transformer_layer_params_tpl: LayerTpl | Sequence[LayerTpl] = template_field(
      Transformer
  )
  packed_input: bool = False
  fold_padding_with_segment_mask: bool = False
  moe_layer_tpl: LayerTpl | None = template_field(TransformerFeedForwardMoe)
  num_experts: int = 0
  num_groups: int = 1
  min_group_size: int | None = None
  moe_layers: Sequence[int] | None = ()
  ngrammer_tpls: Sequence[LayerTpl] | None = template_field(None)
  remat: bool = False
  share_interval: int = 6  # Added: attention weights are shared every share_interval layers.
  checkpoint_policy: AutodiffCheckpointType = (
      AutodiffCheckpointType.SAVE_DOT_EXCEPT_LOGITS_FFN1
  )

  def _clone_layer_params(self, layer_tpl: LayerTpl) -> LayerTpl:
    """Useful to let subclasses switch the class (e.g. Streaming version)."""
    return layer_tpl.clone()

  def setup(self) -> None:
    assert self.num_layers > 0
    assert self.model_dims > 0
    assert self.hidden_dims > 0
    assert self.num_heads > 0
    assert 0.0 <= self.dropout_prob < 1.0
    assert 0.0 <= self.input_dropout_prob < 1.0
    def _layer_params(i):
      """Construct i-th layer params."""
      if isinstance(self.transformer_layer_params_tpl, Sequence):
        factor = self.num_layers // len(self.transformer_layer_params_tpl)
        ii = i // factor
        p_i = self._clone_layer_params(self.transformer_layer_params_tpl[ii])
      else:
        p_i = self._clone_layer_params(self.transformer_layer_params_tpl)
      p_i.name = f'layer_{i}'
  
      # Added: layer i shares attention weights with every other layer that has
      # the same i % share_interval, i.e. ii is in [0, 5] when share_interval = 6.
      ii = i % self.share_interval
      p_i.tr_atten_tpl.shared_weight_layer_id = f'shared_attn_{ii}'
      
      p_i.use_cross_attention = self.use_cross_attention
      p_i.num_heads = self.num_heads
      p_i.dim_per_head = self.dim_per_head
      p_i.input_dims = self.model_dims
      p_i.packed_input = self.packed_input
      p_i.atten_dropout_prob = self.atten_dropout_prob or self.dropout_prob
      p_i.residual_dropout_prob = (
          self.residual_dropout_prob or self.dropout_prob
      )
      p_i.relu_dropout_prob = self.relu_dropout_prob or self.dropout_prob
      p_i.hidden_dims = self.hidden_dims
      if self.residual_droppath_prob > 0.0:
        p_i.residual_droppath_prob = (
            self.residual_droppath_prob * i / max(1, self.num_layers)
        )
      if self.moe_layers and i in self.moe_layers:
        assert self.num_experts > 0
        assert self.moe_layer_tpl is not None
        moe_p = self.moe_layer_tpl.clone()
        moe_p.num_experts = self.num_experts
        moe_p.num_groups = self.num_groups
        moe_p.min_group_size = self.min_group_size
        moe_p.gating_func = self.gating_func
        if moe_p.hidden_dims:
          # MoE hidden_dims could be different from FFN hidden_dims
          p_i.hidden_dims = moe_p.hidden_dims
        p_i.tr_fflayer_tpl = moe_p
      if self.ngrammer_tpls is not None:
        if self.ngrammer_tpls[i] is not None:
          p_i.ngrammer_tpl = self.ngrammer_tpls[i]
      return p_i

    if isinstance(self.transformer_layer_params_tpl, (list, tuple)):
      if self.num_layers % len(self.transformer_layer_params_tpl):
        raise ValueError(
            'num_layers should be divisible by transformer_layer_params_tpl'
        )

    layer_params = [_layer_params(i) for i in range(self.num_layers)]
    self.create_children('x_layers', layer_params)

    if self.input_dropout_prob > 0.0:
      self.create_child(
          'input_dropout',
          pax_fiddle.Config(
              stochastics.Dropout, keep_prob=1.0 - self.input_dropout_prob
          ),
      )
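
For completeness, this is roughly how the modified stack is configured in my experiment (the dims below are placeholders rather than the exact Llama values):

stacked_p = pax_fiddle.Config(
    StackedTransformer,
    name='llama_stack',
    num_layers=24,
    model_dims=2048,   # placeholder
    hidden_dims=5632,  # placeholder
    num_heads=16,
    share_interval=6,  # reuse shared_attn_0 .. shared_attn_5 every 6 layers
    remat=True,
    checkpoint_policy=AutodiffCheckpointType.SAVE_NOTHING,
)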

Could you explain why the scopes are different when sharing attention weights across layers? Is it related to layer-wise checkpointing?
I would be grateful for a demonstration of how to share attention weights, or any other advice you might offer.

