Megatron positional encoding alibi fix (#5808) (#5863)
* 1. Debugging.

* 1. Debugging.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* 1. Debugging.

* 1. Debugging.

* 1. Fixed initialization.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Debugging.

* 1. Debugging.

* 1. Debugging.

* 1. Debugging.

* 1. Debugging.

* 1. Debugging.

* 1. Debugging.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* 1. Debugging.

* 1. Removed scale from ALiBi.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Updated yaml and added support to control number of alibi heads.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* 1. Removed num_attention_heads_alibi from configs.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

Signed-off-by: Micha Livne <mlivne@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Micha Livne <mlivne@nvidia.com>

Signed-off-by: Micha Livne <mlivne@nvidia.com>
Co-authored-by: Micha Livne <michalivne@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Micha Livne <mlivne@nvidia.com>
4 people authored and ericharper committed Jan 31, 2023
1 parent eed8900 commit 1af2f52
Showing 3 changed files with 18 additions and 14 deletions.
@@ -6,7 +6,7 @@ init_method_std: 0.02 # Standard deviation of the zero mean normal distribution
 hidden_dropout: 0.1 # Dropout probability for hidden state transformer.
 attention_dropout: 0.1 # Dropout probability in the attention layer.
 ffn_dropout: 0.0 # Dropout probability in the feed-forward layer.
-position_embedding_type: 'learned_absolute' # Position embedding type. Options ['learned_absolute', 'relative']
+position_embedding_type: 'learned_absolute' # Position embedding type. Options ['learned_absolute', 'relative', 'alibi']
 relative_attention_num_buckets: 32 # Relative position number of buckets for computing the bias
 relative_attention_max_distance: 128 # max_distance to keep relative distance in the attention_num_buckets.
 relative_position_bias_self_attention_only: True # whether to only use relative position bias for self attention only.
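Note on the config hunk above: the YAML only documents the new value; a model opts into ALiBi by setting position_embedding_type to 'alibi'. A minimal, hypothetical sketch with OmegaConf follows; the key names come from the hunk, the values are purely illustrative.

from omegaconf import OmegaConf

# Hypothetical config fragment mirroring the YAML keys above; only
# position_embedding_type changes when opting into ALiBi.
cfg = OmegaConf.create(
    {
        "position_embedding_type": "alibi",  # 'learned_absolute' and 'relative' remain valid
        "relative_attention_num_buckets": 32,
        "relative_attention_max_distance": 128,
        "relative_position_bias_self_attention_only": True,
    }
)
assert cfg.position_embedding_type in ("learned_absolute", "relative", "alibi")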
@@ -38,11 +38,13 @@ def get_slopes_power_of_2(n):
     return slopes


-def build_slopes(num_attention_heads, alibi_num_heads):
+def build_slopes(num_attention_heads, num_attention_heads_alibi):
     """
     Builds a slopes tensor.
     """
-    slopes = torch.Tensor(get_slopes(alibi_num_heads) + [0] * (num_attention_heads - alibi_num_heads)).cuda()
+    slopes = torch.Tensor(
+        get_slopes(num_attention_heads_alibi) + [0] * (num_attention_heads - num_attention_heads_alibi)
+    ).cuda()
     return slopes.unsqueeze(-1).unsqueeze(-1)
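For orientation, the sketch below re-derives the slope construction that build_slopes relies on. get_slopes itself is not part of this hunk, so the version here is an assumption based on the ALiBi reference implementation (geometric slopes starting at 2^(-8/n) for a power-of-two head count), and .cuda() is dropped so the sketch runs on CPU.

import math

import torch


def get_slopes(n):
    # Assumed re-implementation of the ALiBi slope schedule; the helper of the
    # same name defined earlier in this file is not shown in the hunk above.
    def get_slopes_power_of_2(n):
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))  # equals 2 ** (-8 / n)
        return [start * start ** i for i in range(n)]

    if math.log2(n).is_integer():
        return get_slopes_power_of_2(n)
    # Non-power-of-two head counts: use the closest smaller power of two,
    # then interleave slopes from the next power of two (per the paper).
    closest = 2 ** math.floor(math.log2(n))
    return get_slopes_power_of_2(closest) + get_slopes(2 * closest)[0::2][: n - closest]


def build_slopes(num_attention_heads, num_attention_heads_alibi):
    # Heads beyond num_attention_heads_alibi get slope 0, i.e. no positional bias.
    slopes = torch.tensor(
        get_slopes(num_attention_heads_alibi) + [0.0] * (num_attention_heads - num_attention_heads_alibi)
    )
    return slopes.unsqueeze(-1).unsqueeze(-1)  # shape (num_attention_heads, 1, 1)


slopes = build_slopes(num_attention_heads=8, num_attention_heads_alibi=4)
print(slopes.shape)     # torch.Size([8, 1, 1])
print(slopes[:, 0, 0])  # first four heads: 0.25, 0.0625, ~0.0156, ~0.0039; last four: 0.0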


@@ -65,23 +67,25 @@ class ALiBiRelativePositionEmbedding(torch.nn.Module):
     Based on https://arxiv.org/abs/2108.12409
     """

-    def __init__(self, bidirectional, num_attention_heads, layer_type, alibi_num_heads=None, max_seq_len=512):
+    def __init__(
+        self, bidirectional, num_attention_heads, layer_type, num_attention_heads_alibi=None, max_seq_len=512
+    ):
         """
         Args:
             bidirectional: Whether to use bidirectional relative position embedding
             num_attention_heads: Number of attention heads
             layer_type: Layer type. Can be one of [LayerType.encoder or LayerType.decoder]. Will determine the bias construction
-            alibi_num_heads: Number of attention heads for which alibi bias will be used
+            num_attention_heads_alibi: Number of attention heads for which alibi bias will be used
             max_seq_len: Maximum sequence length for precomputed relative positions. Larger sizes will result in more memory usage by computing alibi mask on-the-fly.
         """
         super().__init__()

-        if alibi_num_heads is None:
-            alibi_num_heads = num_attention_heads
+        if (num_attention_heads_alibi is None) or (num_attention_heads_alibi <= 0):
+            num_attention_heads_alibi = num_attention_heads

-        if alibi_num_heads > num_attention_heads:
+        if num_attention_heads_alibi > num_attention_heads:
             raise ValueError(
-                f"alibi_num_heads ({alibi_num_heads}) cannot be larger than num_attention_heads ({num_attention_heads})"
+                f"num_attention_heads_alibi ({num_attention_heads_alibi}) cannot be larger than num_attention_heads ({num_attention_heads})"
             )

         self.bidirectional = bidirectional
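The widened guard above is a small behavior change: num_attention_heads_alibi values of 0 or below now fall back to biasing every head instead of reaching build_slopes. A standalone sketch of just that logic, where resolve_alibi_heads is a hypothetical helper and not part of NeMo:

def resolve_alibi_heads(num_attention_heads, num_attention_heads_alibi=None):
    # Mirrors the new check: None, zero, or a negative value means "use all heads".
    if (num_attention_heads_alibi is None) or (num_attention_heads_alibi <= 0):
        num_attention_heads_alibi = num_attention_heads
    if num_attention_heads_alibi > num_attention_heads:
        raise ValueError(
            f"num_attention_heads_alibi ({num_attention_heads_alibi}) cannot be larger "
            f"than num_attention_heads ({num_attention_heads})"
        )
    return num_attention_heads_alibi


assert resolve_alibi_heads(16, None) == 16
assert resolve_alibi_heads(16, 0) == 16
assert resolve_alibi_heads(16, 8) == 8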
@@ -90,12 +94,12 @@ def __init__(self, bidirectional, num_attention_heads, layer_type, alibi_num_hea
         self.layer_type = layer_type
         # define the size of pre-computed relative position slopes.
         # define the number of attention heads for which alibi mask will be pre-computed (the rest are disabled).
-        self.alibi_num_heads = alibi_num_heads
+        self.num_attention_heads_alibi = num_attention_heads_alibi
         # Larger sizes will result in more memory usage by computing alibi mask on-the-fly.
         self.max_seq_len = max_seq_len

         # cache the slopes
-        self.slopes = build_slopes(num_attention_heads, alibi_num_heads)
+        self.slopes = build_slopes(num_attention_heads, num_attention_heads_alibi)
         # cache the relative position bias. shape (num_attention_heads, max_seq_len, max_seq_len)
         self.relative_position = build_relative_position(max_seq_len, max_seq_len, num_attention_heads)

@@ -113,4 +117,4 @@ def forward(self, query_seq_length, key_seq_length):
             relative_position = torch.tril(relative_position)

         # shape (1, num_heads, query_length, key_length)
-        return relative_position.unsqueeze(0) * self.slopes
+        return -relative_position.unsqueeze(0) * self.slopes
@@ -162,7 +162,7 @@ def __init__(
                 bidirectional=True,
                 num_attention_heads=encoder_cfg.num_attention_heads,
                 layer_type=LayerType.encoder,
-                alibi_num_heads=None,
+                num_attention_heads_alibi=None,
                 max_seq_len=max_position_embeddings,
             )
             self._encoder_relative_position_embedding_key = "encoder_relative_position_embedding"
@@ -282,7 +282,7 @@ def __init__(
                 bidirectional=False,
                 num_attention_heads=decoder_cfg.num_attention_heads,
                 layer_type=LayerType.decoder,
-                alibi_num_heads=None,
+                num_attention_heads_alibi=None,
                 max_seq_len=max_position_embeddings,
             )
             self._decoder_relative_position_embedding_key = "decoder_relative_position_embedding"
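The last two hunks only rename the keyword argument at the encoder and decoder call sites. For completeness, a hedged sketch of how those embeddings end up being constructed and queried; the import paths and head count below are assumptions not shown in this diff, and since build_slopes calls .cuda(), running it requires a GPU.

# Import paths are assumptions and may differ between NeMo versions.
from nemo.collections.nlp.modules.common.megatron.alibi_relative_position_embedding import (
    ALiBiRelativePositionEmbedding,
)
from nemo.collections.nlp.modules.common.megatron.layer_type import LayerType

# Encoder side: bidirectional bias; num_attention_heads_alibi=None (or <= 0)
# now means "apply ALiBi to every head", per the new check in __init__.
encoder_alibi = ALiBiRelativePositionEmbedding(
    bidirectional=True,
    num_attention_heads=12,
    layer_type=LayerType.encoder,
    num_attention_heads_alibi=None,
    max_seq_len=512,
)

# Decoder side: causal bias (torch.tril in forward).
decoder_alibi = ALiBiRelativePositionEmbedding(
    bidirectional=False,
    num_attention_heads=12,
    layer_type=LayerType.decoder,
    num_attention_heads_alibi=None,
    max_seq_len=512,
)

# forward(query_seq_length, key_seq_length) returns the additive bias with
# shape (1, num_attention_heads, query_length, key_length).
bias = decoder_alibi(128, 128)
print(bias.shape)  # torch.Size([1, 12, 128, 128])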
