
Commit 076972e
Merge branch 'huvu/t5_fixes_updates' into 'main'
Updating T5 code to fix bugs

See merge request ADLR/megatron-lm!2471
ericharper committed Dec 30, 2024
2 parents 2da43ef + 48103f4 commit 076972e
Showing 3 changed files with 10 additions and 4 deletions.
10 changes: 6 additions & 4 deletions megatron/core/models/T5/t5_model.py
@@ -315,7 +315,7 @@ def forward(
         rotary_pos_emb = None
         if self.position_embedding_type == 'rope':
             rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
-                inference_params, self.encoder, encoder_input, self.config, packed_seq_params
+                inference_params, self.decoder, decoder_input, self.config, packed_seq_params
             )
             rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)

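The hunk above points the rotary sequence-length calculation at the decoder instead of the encoder. As a rough illustration of why that matters, here is a minimal sketch (a hypothetical helper, not Megatron's actual get_rotary_seq_len) of how a rotary sequence length is typically derived from the stack whose positions are being embedded; feeding encoder_input here would size the decoder's RoPE table to the encoder's sequence length:

import torch

def rotary_seq_len_sketch(inference_params, transformer_input: torch.Tensor) -> int:
    """Toy stand-in for a rotary-seq-len helper (argument and attribute names assumed)."""
    if inference_params is not None:
        # During incremental decoding, size the RoPE table to the KV-cache length
        # (attribute name assumed for illustration).
        return inference_params.max_sequence_length
    # Training/eval: use the sequence length of the stack's own input, which for
    # the T5 decoder is decoder_input with [s, b, h] layout.
    return transformer_input.size(0)
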
@@ -330,9 +330,11 @@ def forward(
         )

         if self.post_process:
-            lm_logits = self.lm_head(
-                decoder_hidden_states, self.shared_embedding_or_output_weight()
-            )
+            output_weight = None
+            if self.share_embeddings_and_output_weights:
+                output_weight = self.shared_embedding_or_output_weight()
+            lm_logits = self.lm_head(decoder_hidden_states, word_embeddings_weight=output_weight)

             if lm_labels is None:
                 # [s b h] => [b s h]
                 return lm_logits.transpose(0, 1).contiguous()
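The replacement only looks up the shared embedding matrix when weight tying is enabled, and otherwise lets the head fall back to its own weight. A minimal sketch of that pattern in plain PyTorch (illustrative names, not the Megatron modules):

import torch.nn.functional as F

def lm_head_sketch(decoder_hidden_states, head_weight, shared_embedding_weight,
                   share_embeddings_and_output_weights):
    # Mirror the patched logic: fetch the shared weight only when tying is on.
    output_weight = None
    if share_embeddings_and_output_weights:
        output_weight = shared_embedding_weight
    weight = output_weight if output_weight is not None else head_weight
    # [s, b, h] x [v, h]^T -> [s, b, v] logits
    return F.linear(decoder_hidden_states, weight)
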
1 change: 1 addition & 0 deletions megatron/core/transformer/attention.py
@@ -110,6 +110,7 @@ def __init__(
             attn_mask_type=self.attn_mask_type,
             attention_type=self.attention_type,
             cp_comm_type=cp_comm_type,
+            softmax_scale=self.config.softmax_scale,
         )

         self.checkpoint_core_attention = self.config.recompute_granularity == 'selective'
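Forwarding softmax_scale lets a configured value override the usual 1/sqrt(head_dim) factor in the core attention. A self-contained sketch of where such a scale enters dot-product attention (toy reference code, not the fused kernels Megatron dispatches to):

import math
import torch

def attention_with_scale(q, k, v, softmax_scale=None):
    # q, k, v: [batch, heads, seq, head_dim]
    head_dim = q.size(-1)
    # softmax_scale=None (the config default) falls back to 1/sqrt(head_dim).
    scale = softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(head_dim)
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, v)
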
3 changes: 3 additions & 0 deletions megatron/core/transformer/transformer_config.py
@@ -45,6 +45,9 @@ class TransformerConfig(ModelParallelConfig):
     If attention backend is local we use the local pytorch implementation in mcore.
     Users can specify exact backend by changing this config. """

+    softmax_scale: float = None
+    """Softmax scale for attention scaling."""
+
     num_query_groups: int = None
     """Number of query groups for group query attention. If None, normal attention is used."""

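With the new field in place, the scale can be set on the config like any other attribute. A usage sketch with hypothetical values; the other constructor fields shown are common required settings and should be checked against your Megatron-LM version:

from megatron.core.transformer.transformer_config import TransformerConfig

config = TransformerConfig(
    num_layers=12,
    hidden_size=768,
    num_attention_heads=12,
    softmax_scale=1.0,  # override the default 1/sqrt(head_dim) attention scaling
)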
