Hi. I am trying to share attention weights across layers, following the test case in shared_layers_test.py:
def testSharedTemplateLayer(self):
  sub_params = pax_fiddle.Config(
      linears.FeedForward, input_dims=8, output_dims=8
  )
  # Only share the linear projection, not the entire FeedForward layer.
  sub_params.linear_tpl.shared_weight_layer_id = 'shared_weight'
  test_layer_p = pax_fiddle.Config(
      SimpleShared01,
      name='test',
      sub1_tpl=sub_params.clone(),
      sub2_tpl=sub_params.clone(),
  )
  x_in = jnp.ones([2, 8])
  with base_layer.JaxContext.new_context():
    prng_key = jax.random.PRNGKey(1234)
    layer = base_layer.instantiate(test_layer_p)
    init_vars = layer.init(prng_key, x_in)
But it fails to share the weights, because different scopes are used when the cache entry is set and when it is looked up:
def lookup_shared_layer(
    self, root_scope: flax_core.Scope, shared_layer_id: str
) -> _SharedLayerCacheEntry | None:
  logging.info('lookup_shared_layer called with id: %s in the scope of %s',
               shared_layer_id, root_scope)
  return self._root_scope_to_shared_layers_map[root_scope][shared_layer_id]

def set_shared_layer(self, root_scope: flax_core.Scope, shared_layer_id: str,
                     wrapper: _WrapperLayer, layer_hparams):
  logging.info('set_shared_layer called with id: %s in the scope of %s',
               shared_layer_id, root_scope)
  existing = self.lookup_shared_layer(root_scope, shared_layer_id)
  assert existing is None
  self._root_scope_to_shared_layers_map[root_scope][
      shared_layer_id] = _SharedLayerCacheEntry(
          layer=wrapper.cld, hparams=layer_hparams.clone(), wrapper=wrapper)
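In my run, set_shared_layer and lookup_shared_layer log different root_scope objects for the same shared_layer_id, so the lookup misses and every layer ends up building its own attention weights. The toy below is my own illustration (not Praxis code; FakeScope and the function names are made up) of the cache above: it is keyed on the root scope object, so sharing breaks as soon as the set and the lookup happen under different roots, which is what I suspect the remat wrapping causes.

# Toy model of the cache above (illustration only, not Praxis code).
# FakeScope stands in for flax_core.Scope; hashing defaults to object
# identity, so two distinct root scope objects never map to the same entry.
from collections import defaultdict


class FakeScope:
  pass


cache = defaultdict(dict)  # root_scope -> {shared_layer_id: entry}


def set_shared(root_scope, shared_layer_id, entry):
  assert shared_layer_id not in cache[root_scope]
  cache[root_scope][shared_layer_id] = entry


def lookup_shared(root_scope, shared_layer_id):
  return cache[root_scope].get(shared_layer_id)


root_at_set = FakeScope()     # root scope seen when layer_0 registers itself
root_at_lookup = FakeScope()  # different root object seen at a later layer

set_shared(root_at_set, 'shared_attn_0', 'layer_0 attention wrapper')
print(lookup_shared(root_at_set, 'shared_attn_0'))     # hit: sharing works
print(lookup_shared(root_at_lookup, 'shared_attn_0'))  # None: new weights are created

Because the miss only means "not registered yet under this root", there is no error; the sharing just silently degrades to per-layer weights.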
Specifically, I implement a 24-layer Llama with StackedTransformer (not StackedTransformerRepeated) and set shared_weight_layer_id in an interleaved pattern with an interval of 6 inside the setup function of StackedTransformer, so that layers i, i+6, i+12, and i+18 share the same attention weights. The main differences from the stock implementation are the new share_interval field and the two lines in _layer_params that assign shared_weight_layer_id (they were bolded in my original post). I also set remat=True and checkpoint_policy=AutodiffCheckpointType.SAVE_NOTHING on the StackedTransformer.
class StackedTransformer(base_layer.BaseLayer):
use_cross_attention: bool = False
mask_self_attention: bool = False
num_layers: int = 0
model_dims: int = 0
hidden_dims: int = 0
num_heads: int = 0
dim_per_head: int | None = None
dropout_prob: float = 0.0
atten_dropout_prob: float | None = None
residual_dropout_prob: float | None = None
relu_dropout_prob: float | None = None
residual_droppath_prob: float = 0.0
input_dropout_prob: float = 0.0
gating_func: str = 'top2'
unadjusted_expert_capacity_factor: float = 2.0
transformer_layer_params_tpl: LayerTpl | Sequence[LayerTpl] = template_field(
Transformer
)
packed_input: bool = False
fold_padding_with_segment_mask: bool = False
moe_layer_tpl: LayerTpl | None = template_field(TransformerFeedForwardMoe)
num_experts: int = 0
num_groups: int = 1
min_group_size: int | None = None
moe_layers: Sequence[int] | None = ()
ngrammer_tpls: Sequence[LayerTpl] | None = template_field(None)
remat: bool = False
share_interval: int = 6
checkpoint_policy: AutodiffCheckpointType = (
AutodiffCheckpointType.SAVE_DOT_EXCEPT_LOGITS_FFN1
)
def _clone_layer_params(self, layer_tpl: LayerTpl) -> LayerTpl:
"""Useful to let subclasses switch the class (e.g. Streaming version)."""
return layer_tpl.clone()
def setup(self) -> None:
assert self.num_layers > 0
assert self.model_dims > 0
assert self.hidden_dims > 0
assert self.num_heads > 0
assert 0.0 <= self.dropout_prob < 1.0
assert 0.0 <= self.input_dropout_prob < 1.0
def _layer_params(i):
"""Construct i-th layer params."""
if isinstance(self.transformer_layer_params_tpl, Sequence):
factor = self.num_layers // len(self.transformer_layer_params_tpl)
ii = i // factor
p_i = self._clone_layer_params(self.transformer_layer_params_tpl[ii])
else:
p_i = self._clone_layer_params(self.transformer_layer_params_tpl)
p_i.name = f'layer_{i}'
ii = i % self.share_interval # ii is in the range [0,5] when share_interval = 6
p_i.tr_atten_tpl.shared_weight_layer_id = f'shared_attn_{ii}'
p_i.use_cross_attention = self.use_cross_attention
p_i.num_heads = self.num_heads
p_i.dim_per_head = self.dim_per_head
p_i.input_dims = self.model_dims
p_i.packed_input = self.packed_input
p_i.atten_dropout_prob = self.atten_dropout_prob or self.dropout_prob
p_i.residual_dropout_prob = (
self.residual_dropout_prob or self.dropout_prob
)
p_i.relu_dropout_prob = self.relu_dropout_prob or self.dropout_prob
p_i.hidden_dims = self.hidden_dims
if self.residual_droppath_prob > 0.0:
p_i.residual_droppath_prob = (
self.residual_droppath_prob * i / max(1, self.num_layers)
)
if self.moe_layers and i in self.moe_layers:
assert self.num_experts > 0
assert self.moe_layer_tpl is not None
moe_p = self.moe_layer_tpl.clone()
moe_p.num_experts = self.num_experts
moe_p.num_groups = self.num_groups
moe_p.min_group_size = self.min_group_size
moe_p.gating_func = self.gating_func
if moe_p.hidden_dims:
# MoE hidden_dims could be different from FFN hidden_dims
p_i.hidden_dims = moe_p.hidden_dims
p_i.tr_fflayer_tpl = moe_p
if self.ngrammer_tpls is not None:
if self.ngrammer_tpls[i] is not None:
p_i.ngrammer_tpl = self.ngrammer_tpls[i]
return p_i
if isinstance(self.transformer_layer_params_tpl, (list, tuple)):
if self.num_layers % len(self.transformer_layer_params_tpl):
raise ValueError(
'num_layers should be divisible by transformer_layer_params_tpl'
)
layer_params = [_layer_params(i) for i in range(self.num_layers)]
self.create_children('x_layers', layer_params)
if self.input_dropout_prob > 0.0:
self.create_child(
'input_dropout',
pax_fiddle.Config(
stochastics.Dropout, keep_prob=1.0 - self.input_dropout_prob
),
)
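For reference, this is roughly how I instantiate the stack (a minimal sketch; the name, dims, and head count below are illustrative placeholders rather than my exact Llama settings):

stacked_p = pax_fiddle.Config(
    StackedTransformer,
    name='llama_stack',      # placeholder name
    num_layers=24,
    model_dims=4096,         # placeholder dims, not my exact config
    hidden_dims=11008,
    num_heads=32,
    dim_per_head=128,
    mask_self_attention=True,
    packed_input=False,
    share_interval=6,        # layers i, i+6, i+12, i+18 share attention weights
    remat=True,
    checkpoint_policy=AutodiffCheckpointType.SAVE_NOTHING,
)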
Could you explain why the scopes differ when sharing attention weights across layers? Is this related to layer-wise checkpointing (remat)?
I would be grateful for a demonstration of how to share attention weights, or any other advice you might offer.