Skip to content

Commit

Permalink
Llama/GPTNeoX: add RoPE scaling (#24653)
Browse files Browse the repository at this point in the history
* add rope_scaling

* tmp commit

* add gptneox

* add tests

* GPTNeoX can now handle long inputs, so the pipeline test was wrong

* Update src/transformers/models/open_llama/configuration_open_llama.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* remove ntk

* remove redundant validation

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
  • Loading branch information
gante and amyeroberts authored Jul 13, 2023
1 parent 9342c8f commit 34d9409
Show file tree
Hide file tree
Showing 11 changed files with 484 additions and 67 deletions.
35 changes: 35 additions & 0 deletions src/transformers/models/gpt_neox/configuration_gpt_neox.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,15 @@ class GPTNeoXConfig(PretrainedConfig):
use_parallel_residual (`bool`, *optional*, defaults to `True`):
Whether to use a "parallel" formulation in each Transformer layer, which can provide a slight training
speedup at large scales (e.g. 20B).
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling
strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format
is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
these scaling strategies behave:
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
experimental feature, subject to breaking API changes in future versions.
Example:
```python
Expand Down Expand Up @@ -115,6 +124,7 @@ def __init__(
eos_token_id=2,
tie_word_embeddings=False,
use_parallel_residual=True,
rope_scaling=None,
**kwargs,
):
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
Expand All @@ -135,7 +145,32 @@ def __init__(
self.use_cache = use_cache
self.tie_word_embeddings = tie_word_embeddings
self.use_parallel_residual = use_parallel_residual
self.rope_scaling = rope_scaling
self._rope_scaling_validation()

if self.hidden_size % self.num_attention_heads != 0:
raise ValueError(
"The hidden size is not divisble by the number of attention heads! Make sure to update them!"
)

# Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return

if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, "
f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(
f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")
126 changes: 102 additions & 24 deletions src/transformers/models/gpt_neox/modeling_gpt_neox.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def _set_gradient_checkpointing(self, module, value=False):
class GPTNeoXAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.num_attention_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
if self.hidden_size % self.num_attention_heads != 0:
Expand All @@ -94,28 +95,56 @@ def __init__(self, config):
)
self.head_size = self.hidden_size // self.num_attention_heads
self.rotary_ndims = int(self.head_size * config.rotary_pct)
max_positions = config.max_position_embeddings
self.register_buffer(
"bias",
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
1, 1, max_positions, max_positions
),
persistent=False,
)
self._init_bias(config.max_position_embeddings)

self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
self.rotary_emb = RotaryEmbedding(
self.rotary_ndims, config.max_position_embeddings, base=config.rotary_emb_base
)
self._init_rope()

self.register_buffer(
"norm_factor",
torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()),
persistent=False,
)
self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size)
self.dense = nn.Linear(config.hidden_size, config.hidden_size)

self.attention_dropout = nn.Dropout(config.attention_dropout)

def _init_bias(self, max_positions, device=None):
self.register_buffer(
"bias",
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
1, 1, max_positions, max_positions
),
persistent=False,
)
if device is not None:
self.bias = self.bias.to(device)

def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = GPTNeoXRotaryEmbedding(
self.rotary_ndims, self.config.max_position_embeddings, base=self.config.rotary_emb_base
)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = GPTNeoXLinearScalingRotaryEmbedding(
self.rotary_ndims,
self.config.max_position_embeddings,
base=self.config.rotary_emb_base,
scaling_factor=scaling_factor,
)
elif scaling_type == "dynamic":
self.rotary_emb = GPTNeoXDynamicNTKScalingRotaryEmbedding(
self.rotary_ndims,
self.config.max_position_embeddings,
base=self.config.rotary_emb_base,
scaling_factor=scaling_factor,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

def forward(
self,
hidden_states: torch.FloatTensor,
Expand Down Expand Up @@ -210,6 +239,9 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
batch_size, num_attention_heads, query_length, attn_head_size = query.size()
key_length = key.size(-2)

# dynamically increase the causal mask with the key length, if needed.
if key_length > self.bias.shape[-1]:
self._init_bias(key_length, device=key.device)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]

query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
Expand Down Expand Up @@ -258,15 +290,23 @@ def attention_mask_func(attention_scores, ltor_mask):
return attention_scores


class RotaryEmbedding(torch.nn.Module):
class GPTNeoXRotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings, base=10000, device=None):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))

self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq)

# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device)

def _set_cos_sin_cache(self, seq_len, device):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
Expand All @@ -275,18 +315,56 @@ def __init__(self, dim, max_position_embeddings, base=10000, device=None):

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.cos_cached = emb.cos()[None, None, :, :]
self.sin_cached = emb.sin()[None, None, :, :]
self._set_cos_sin_cache(seq_len=seq_len, device=x.device)
return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device)


class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
"""GPTNeoXRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

def __init__(self, dim, max_position_embeddings, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)

def _set_cos_sin_cache(self, seq_len, device):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
t = t / self.scaling_factor

freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.cos_cached = emb.cos()[None, None, :, :]
self.sin_cached = emb.sin()[None, None, :, :]


class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
"""GPTNeoXRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

def __init__(self, dim, max_position_embeddings, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)

def _set_cos_sin_cache(self, seq_len, device):
self.max_seq_len_cached = seq_len

if seq_len > self.max_position_embeddings:
base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq)

t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.cos_cached = emb.cos()[None, None, :, :]
self.sin_cached = emb.sin()[None, None, :, :]


def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,16 +238,24 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
return attn_output, attn_weights


# Copied from transformers.models.gpt_neox.modeling_gpt_neox.RotaryEmbedding
# Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding
class RotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings, base=10000, device=None):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))

self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq)

# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device)

def _set_cos_sin_cache(self, seq_len, device):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
Expand All @@ -256,15 +264,8 @@ def __init__(self, dim, max_position_embeddings, base=10000, device=None):

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.cos_cached = emb.cos()[None, None, :, :]
self.sin_cached = emb.sin()[None, None, :, :]
self._set_cos_sin_cache(seq_len=seq_len, device=x.device)
return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device)


Expand Down
34 changes: 34 additions & 0 deletions src/transformers/models/llama/configuration_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ class LlamaConfig(PretrainedConfig):
relevant if `config.is_decoder=True`.
tie_word_embeddings(`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling
strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format
is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
these scaling strategies behave:
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
experimental feature, subject to breaking API changes in future versions.
Example:
```python
Expand Down Expand Up @@ -97,6 +106,7 @@ def __init__(
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
rope_scaling=None,
**kwargs,
):
self.vocab_size = vocab_size
Expand All @@ -109,10 +119,34 @@ def __init__(
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_scaling = rope_scaling
self._rope_scaling_validation()

super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return

if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, "
f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(
f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")
Loading

0 comments on commit 34d9409

Please sign in to comment.