Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Megatron KERPLE positional embeddings #6478

Merged
merged 25 commits into from
Apr 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
b9a9c40
[TTS] FastPitch adapter fine-tune and conditional layer normalization…
hsiehjackson Apr 17, 2023
14e9668
[TTS] whitelist broken path fix. (#6412)
XuesongYang Apr 17, 2023
536ee62
[TTS] FastPitch speaker encoder (#6417)
hsiehjackson Apr 18, 2023
ceb539f
Sharded manifests for tarred datasets (#6395)
bmwshop Apr 18, 2023
499a3b2
Update wfst_text_normalization.rst (#6374)
jimregan Apr 18, 2023
a365879
Support Swiglu in TP PP Conversion (#6437) (#6451)
github-actions[bot] Apr 19, 2023
be711c9
Update NeMo_TTS_Primer.ipynb (#6436)
pythinker Apr 19, 2023
9e72326
add rampup batch size support for Megatron GPT (#6424)
dimapihtar Apr 20, 2023
41fcf4d
Meagtron encoder decoder fix for empty validation outputs (#6459) (#6…
github-actions[bot] Apr 20, 2023
77f0959
Code-Switching dataset creation - upgrading to aggregate tokenizer ma…
KunalDhawan Apr 21, 2023
2822ff3
Added/updated new Conformer configs (#6426) (#6467)
github-actions[bot] Apr 21, 2023
244ba8d
Update script for ngram rnnt and hat beam search decoding (#6370)
andrusenkoau Apr 21, 2023
094cbae
BERT pre-training mp fork to spawn (#6442) (#6454)
github-actions[bot] Apr 22, 2023
daa9744
fix replace_bos_with_pad not found (#6443) (#6450)
github-actions[bot] Apr 22, 2023
557c4b7
reduce workers on NMT CI (#6472) (#6474)
github-actions[bot] Apr 22, 2023
690742b
1. Added KERPLE positional embeddings to encoder-decoder.
michalivne Apr 23, 2023
8664d09
Merge branch 'main' into megatron-kerple-positional-embeddings
michalivne Apr 23, 2023
ed4c373
1. Added a missing file.
michalivne Apr 23, 2023
e3ca438
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 23, 2023
c6fa1a9
1. Fixing commits.
michalivne Apr 23, 2023
f482074
Merge branch 'megatron-kerple-positional-embeddings' of github.com:NV…
michalivne Apr 23, 2023
f6ed850
1. Debugging.
michalivne Apr 23, 2023
27cf8de
1. Debugging.
michalivne Apr 23, 2023
0f593b8
1. Debugging.
michalivne Apr 23, 2023
9e84e42
1. Debugging.
michalivne Apr 23, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -3806,6 +3806,102 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
sh "rm -rf examples/nlp/language_modeling/t5_index_mappings"
}
}
stage('L2: Megatron T5 with KERPLE Pretraining and Resume Training TP=2') {
when {
anyOf {
branch 'r1.18.0'
changeRequest target: 'r1.18.0'
}
}
failFast true
steps {
sh "python examples/nlp/language_modeling/megatron_t5_pretraining.py \
trainer.devices=2 \
trainer.accelerator=gpu \
trainer.log_every_n_steps=1 \
trainer.val_check_interval=10 \
trainer.limit_val_batches=2 \
trainer.accumulate_grad_batches=1 \
trainer.max_steps=10 \
trainer.precision=16 \
trainer.gradient_clip_val=1.0 \
exp_manager.exp_dir=examples/nlp/language_modeling/t5_pretrain_results \
model.tensor_model_parallel_size=2 \
model.seq_length=128 \
model.encoder.num_layers=4 \
model.encoder.hidden_size=64 \
model.encoder.num_attention_heads=8 \
model.encoder.activation='swiglu' \
model.encoder.masked_softmax_fusion=False \
model.encoder.bias_activation_fusion=False \
model.encoder.activations_checkpoint_method='block' \
model.encoder.activations_checkpoint_num_layers=1 \
model.encoder.position_embedding_type=kerple \
model.decoder.num_layers=2 \
model.decoder.hidden_size=64 \
model.decoder.num_attention_heads=8 \
model.decoder.activation='swiglu' \
model.decoder.masked_softmax_fusion=False \
model.decoder.bias_activation_fusion=False \
model.decoder.activations_checkpoint_method='block' \
model.decoder.activations_checkpoint_num_layers=1 \
model.encoder.transformer_block_type='pre_ln' \
model.decoder.transformer_block_type='pre_ln' \
model.data.data_prefix=[.5,/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src,.5,/home/TestData/nlp/nmt/toy_data/wmt14-de-en.ref] \
model.data.index_mapping_dir=examples/nlp/language_modeling/t5_index_mappings \
model.data.data_impl=text_mmap \
+model.data.data_impl_kwargs.newline_int=10 \
+model.data.data_impl_kwargs.header_lines=0 \
+model.data.data_impl_kwargs.workers=null \
+model.data.data_impl_kwargs.sort_dataset_paths=False \
model.share_token_embeddings=False \
model.share_decoder_tokens_head_embeddings=False"
sh "python examples/nlp/language_modeling/megatron_t5_pretraining.py \
trainer.devices=2 \
trainer.accelerator=gpu \
trainer.log_every_n_steps=1 \
trainer.val_check_interval=10 \
trainer.limit_val_batches=2 \
trainer.accumulate_grad_batches=1 \
trainer.max_steps=10 \
trainer.precision=16 \
trainer.gradient_clip_val=1.0 \
exp_manager.exp_dir=examples/nlp/language_modeling/t5_pretrain_results \
exp_manager.resume_if_exists=True \
model.tensor_model_parallel_size=2 \
model.seq_length=128 \
model.encoder.num_layers=4 \
model.encoder.hidden_size=64 \
model.encoder.num_attention_heads=8 \
model.encoder.activation='swiglu' \
model.encoder.masked_softmax_fusion=False \
model.encoder.bias_activation_fusion=False \
model.encoder.activations_checkpoint_method='block' \
model.encoder.activations_checkpoint_num_layers=1 \
model.encoder.position_embedding_type=kerple \
model.decoder.num_layers=2 \
model.decoder.hidden_size=64 \
model.decoder.num_attention_heads=8 \
model.decoder.activation='swiglu' \
model.decoder.masked_softmax_fusion=False \
model.decoder.bias_activation_fusion=False \
model.decoder.activations_checkpoint_method='block' \
model.decoder.activations_checkpoint_num_layers=1 \
model.encoder.transformer_block_type='pre_ln' \
model.decoder.transformer_block_type='pre_ln' \
model.data.data_prefix=[.5,/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src,.5,/home/TestData/nlp/nmt/toy_data/wmt14-de-en.ref] \
model.data.index_mapping_dir=examples/nlp/language_modeling/t5_index_mappings \
model.data.data_impl=text_mmap \
+model.data.data_impl_kwargs.newline_int=10 \
+model.data.data_impl_kwargs.header_lines=0 \
+model.data.data_impl_kwargs.workers=null \
+model.data.data_impl_kwargs.sort_dataset_paths=False \
model.share_token_embeddings=False \
model.share_decoder_tokens_head_embeddings=False"
sh "rm -rf examples/nlp/language_modeling/t5_pretrain_results"
sh "rm -rf examples/nlp/language_modeling/t5_index_mappings"
}
}
stage('L2: Megatron T5 Pretraining and Resume Training PP=2') {
when {
anyOf {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ init_method_std: 0.02 # Standard deviation of the zero mean normal distribution
hidden_dropout: 0.1 # Dropout probability for hidden state transformer.
attention_dropout: 0.1 # Dropout probability in the attention layer.
ffn_dropout: 0.0 # Dropout probability in the feed-forward layer.
position_embedding_type: 'learned_absolute' # Position embedding type. Options ['learned_absolute', 'relative', 'alibi']
position_embedding_type: 'learned_absolute' # Position embedding type. Options ['learned_absolute', 'relative', 'alibi', 'kerple']
relative_attention_num_buckets: 32 # Relative position number of buckets for computing the bias
relative_attention_max_distance: 128 # max_distance to keep relative distance in the attention_num_buckets.
relative_position_bias_self_attention_only: True # whether to only use relative position bias for self attention only.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import torch

from nemo.collections.nlp.modules.common.megatron.alibi_relative_position_embedding import (
build_relative_position,
build_slopes,
)

__all__ = ['KERPLERelativePositionEmbedding']


class KERPLERelativePositionEmbedding(torch.nn.Module):
"""
kerple (Attention with Linear Biases) relative position embedding for auto-regressive decoder
and joint encoder (symmetric for forward and backward distance).
Based on https://arxiv.org/bas/2108.12409
"""

def __init__(
self, bidirectional, num_attention_heads, layer_type, num_attention_heads_kerple=None, max_seq_len=512
):
"""
Args:
bidirectional: Whether to use bidirectional relative position embedding
num_attention_heads: Number of attention heads
layer_type: Layer type. Can be one of [LayerType.encoder or LayerType.decoder]. Willdetermine the bias construction
num_attention_heads_kerple: Number of attention heads for which kerple bias will be used
max_seq_len: Maximum sequence length for precomputed relative positions. Larger sizes will result in more memory usage by computing kerple mask on-the-fly.
"""
super().__init__()

if (num_attention_heads_kerple is None) or (num_attention_heads_kerple <= 0):
num_attention_heads_kerple = num_attention_heads

if num_attention_heads_kerple > num_attention_heads:
raise ValueError(
f"num_attention_heads_kerple ({num_attention_heads_kerple}) cannot be larger than num_attention_heads ({num_attention_heads})"
)

self.bidirectional = bidirectional
self.num_attention_heads = num_attention_heads
# LayerType.encoder or LayerType.decoder. Is only needed to determine the group for the all_reduce
self.layer_type = layer_type
# define the size of pre-computed relative position slopes.
# define the number of attention heads for which kerple mask will be pre-computed (the rest are disabled).
self.num_attention_heads_kerple = num_attention_heads_kerple
# Larger sizes will result in more memory usage by computing kerple mask on-the-fly.
self.max_seq_len = max_seq_len

# initialize the slopes
self.kerple_b = torch.nn.Parameter(build_slopes(num_attention_heads, num_attention_heads_kerple))
self.kerple_a = torch.zeros_like(self.kerple_b)
self.kerple_p = torch.ones_like(self.kerple_b)

# cache the relative position bias. shape (num_attention_heads, max_seq_len, max_seq_len)
self.relative_position = build_relative_position(max_seq_len, max_seq_len, num_attention_heads)

def forward(self, query_seq_length, key_seq_length):
# used cached relative position if possible
max_seq_len = max(query_seq_length, key_seq_length)
if max_seq_len > self.max_seq_len:
relative_position = build_relative_position(max_seq_len, max_seq_len, self.num_attention_heads)
else:
relative_position = self.relative_position
# shape (num_attention_heads, query_seq_length, key_seq_length)
relative_position = relative_position[:, :query_seq_length, :key_seq_length]
# if not bidirectional, mask out the future positions
if not self.bidirectional:
relative_position = torch.tril(relative_position)

# shape (1, num_heads, query_length, key_length)
return -self.kerple_b * torch.log(1 + self.kerple_a * relative_position.unsqueeze(0).pow(self.kerple_p))
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from nemo.collections.nlp.modules.common.megatron.alibi_relative_position_embedding import (
ALiBiRelativePositionEmbedding,
)
from nemo.collections.nlp.modules.common.megatron.kerple_relative_position_embedding import (
KERPLERelativePositionEmbedding,
)
from nemo.collections.nlp.modules.common.megatron.language_model import Embedding
from nemo.collections.nlp.modules.common.megatron.layer_type import LayerType
from nemo.collections.nlp.modules.common.megatron.megatron_decoders import get_decoder_model
Expand Down Expand Up @@ -176,7 +179,16 @@ def __init__(
num_attention_heads_alibi=None,
max_seq_len=max_position_embeddings,
)
self._encoder_relative_position_embedding_key = "encoder_relative_position_embedding"
self._encoder_relative_position_embedding_key = "encoder_alibi_position_embedding"
elif self.encoder_cfg.get('position_embedding_type', 'learned_absolute') == 'kerple':
self.encoder_relative_position_embedding = KERPLERelativePositionEmbedding(
bidirectional=True,
num_attention_heads=encoder_cfg.num_attention_heads,
layer_type=LayerType.encoder,
num_attention_heads_kerple=None,
max_seq_len=max_position_embeddings,
)
self._encoder_relative_position_embedding_key = "encoder_kerple_position_embedding"
else:
self.encoder_relative_position_embedding = None

Expand Down Expand Up @@ -296,7 +308,16 @@ def __init__(
num_attention_heads_alibi=None,
max_seq_len=max_position_embeddings,
)
self._decoder_relative_position_embedding_key = "decoder_relative_position_embedding"
self._decoder_relative_position_embedding_key = "decoder_alibi_position_embedding"
elif self.decoder_cfg.get('position_embedding_type', 'learned_absolute') == 'kerple':
self.decoder_relative_position_embedding = KERPLERelativePositionEmbedding(
bidirectional=False,
num_attention_heads=decoder_cfg.num_attention_heads,
layer_type=LayerType.decoder,
num_attention_heads_kerple=None,
max_seq_len=max_position_embeddings,
)
self._decoder_relative_position_embedding_key = "decoder_kerple_position_embedding"
else:
self.decoder_relative_position_embedding = None

Expand Down