Set RotaryEmbedding base frequency from config (NVIDIA#7734)
* set rope base frequency from model config

Signed-off-by: Shantanu Acharya <shantanua@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* rename base to rotary_base

Signed-off-by: Shantanu Acharya <shantanua@nvidia.com>

---------

Signed-off-by: Shantanu Acharya <shantanua@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Piotr Żelasko <petezor@gmail.com>
2 people authored and pzelasko committed Jan 3, 2024
1 parent 2dfc004 commit 52cfa4d
Showing 4 changed files with 16 additions and 3 deletions.

Changed file 1 of 4:

@@ -167,6 +167,7 @@ def __init__(
         ub_tp_comm_overlap=False,
         use_flash_attention=False,
         seq_len_interpolation_factor=None,
+        rotary_base=10000,
     ):
         super(GPTModel, self).__init__(config=config, share_token_embeddings=share_embeddings_and_output_weights)

@@ -248,6 +249,7 @@ def __init__(
             ub_tp_comm_overlap=ub_tp_comm_overlap,
             use_flash_attention=use_flash_attention,
             seq_len_interpolation_factor=seq_len_interpolation_factor,
+            rotary_base=rotary_base,
         )

         if self.share_embeddings_and_output_weights:

Changed file 2 of 4:

@@ -384,6 +384,7 @@ def model_provider_func(self, pre_process, post_process):
             use_flash_attention=self.cfg.get('use_flash_attention', False),
             megatron_legacy=self.cfg.get('megatron_legacy', False),
             seq_len_interpolation_factor=self.cfg.get('seq_len_interpolation_factor', None),
+            rotary_base=self.cfg.get('rotary_base', 10000),
         )
         return model
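
The lookup above falls back to 10000 when the key is absent, so configs written before this change keep the original RoPE behaviour. A minimal sketch of that pattern, assuming an OmegaConf-style config (the config fragment and the 500000 value are illustrative, not taken from the commit):

from omegaconf import OmegaConf

# Hypothetical fragment of a model config; only the key added by this commit is shown.
cfg = OmegaConf.create({'rotary_base': 500000})

# Mirrors the lookup in model_provider_func: use the configured value, else the old default.
rotary_base = cfg.get('rotary_base', 10000)
print(rotary_base)  # 500000 here; 10000 for a config that does not set the key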

Changed file 3 of 4:

@@ -126,6 +126,7 @@ def get_language_model(
     ub_tp_comm_overlap=False,
     use_flash_attention=False,
     seq_len_interpolation_factor=None,
+    rotary_base=10000,
 ):
     """Build language model and return along with the key to save."""

@@ -202,6 +203,7 @@ def get_language_model(
         ub_tp_comm_overlap=ub_tp_comm_overlap,
         use_flash_attention=use_flash_attention,
         seq_len_interpolation_factor=seq_len_interpolation_factor,
+        rotary_base=rotary_base,
     )
     # key used for checkpoints.
     language_model_key = 'language_model'

@@ -502,6 +504,7 @@ def __init__(
         ub_tp_comm_overlap=False,
         use_flash_attention=False,
         seq_len_interpolation_factor=None,
+        rotary_base=10000,
     ):
         super(TransformerLanguageModel, self).__init__(
             config=config, share_token_embeddings=share_embeddings_and_output_weights

@@ -557,6 +560,7 @@ def __init__(
                 rotary_dim,
                 seq_len_interpolation_factor=seq_len_interpolation_factor,
                 pretrained_max_position_embeddings=max_position_embeddings,
+                rotary_base=rotary_base,
             )

         elif position_embedding_type == 'alibi':

Changed file 4 of 4:

@@ -26,19 +26,25 @@ class RotaryEmbedding(nn.Module):
     """

     def __init__(
-        self, dim: int, seq_len_interpolation_factor: int = None, pretrained_max_position_embeddings: int = None
+        self,
+        dim: int,
+        seq_len_interpolation_factor: int = None,
+        rotary_base: int = 10000,
+        pretrained_max_position_embeddings: int = None,
     ):
         """
         Args:
             dim (int): rotary embedding dimension
             seq_len_interpolation_factor (int): if not None, discrete positions will be interpolated
-            by this factor via the trick in https://arxiv.org/abs/2306.15595.
+                by this factor via the trick in https://arxiv.org/abs/2306.15595.
+            rotary_base (int): rotary_base for the positional frequency (default: 10000)
            pretrained_max_position_embeddings (int): pre-trained max_position_embeddings before position interpolation.
         """
         super().__init__()
         self.seq_len_interpolation_factor = seq_len_interpolation_factor
-        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+        self.rotary_base = rotary_base
+        inv_freq = 1.0 / (self.rotary_base ** (torch.arange(0, dim, 2).float() / dim))
         self.register_buffer('inv_freq', inv_freq)
         self.pretrained_max_position_embeddings = pretrained_max_position_embeddings
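
The only functional change is that the inverse frequencies are now derived from the configurable rotary_base instead of the hard-coded 10000. A self-contained sketch of that formula, with an illustrative dim=8 and an alternative base of 500000 chosen for the comparison (not values from the commit):

import torch

def rope_inv_freq(dim: int, rotary_base: float) -> torch.Tensor:
    # Same formula as RotaryEmbedding.__init__: one inverse frequency per pair of channels.
    return 1.0 / (rotary_base ** (torch.arange(0, dim, 2).float() / dim))

dim = 8
print(rope_inv_freq(dim, 10000))   # 1.0, 0.1, 0.01, 0.001
print(rope_inv_freq(dim, 500000))  # lower tail frequencies, i.e. longer wavelengths;
                                   # a larger rotary_base stretches the positional encoding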
