Implement NTK-Aware scaled and dynamically scaled RoPE for PositionRotaryEmbedding #529
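For reference, a minimal sketch of the two schemes this PR wires in, written as standalone helpers. The formulas mirror the _get_inv_freq and dynamic-scaling code in the diff below; the helper names themselves are illustrative and not part of the PR.

    def ntk_scaled_base(base: float, dim: int, scale_factor: float) -> float:
        # NTK-aware scaling: stretch the RoPE base so high-frequency components
        # barely move while low-frequency components get interpolated.
        return base * scale_factor ** (dim / (dim - 2))

    def dynamic_scale_factor(scale_factor: float, seq_len: int, original_max_seq_len: int) -> float:
        # Dynamic variant: derive the effective scale factor from the current
        # sequence length instead of fixing it up front.
        return (scale_factor * seq_len / original_max_seq_len) - (scale_factor - 1)

    # Example with assumed values: head_size 128, base 10000, static scale factor 4
    # ntk_scaled_base(10000, 128, 4)  ->  ~40890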
First changed file (Llama model implementation):
@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
 import torch
 import torch.distributed
@@ -41,6 +42,12 @@
     TensorParallelHead,
 )

+ROPE_SCALE_FACTOR = int(os.getenv("ROPE_SCALE_FACTOR", 1))
+
+if os.getenv("ROPE_DYNAMIC_SCALING", "false").lower() == "true":
+    ROPE_DYNAMIC_SCALING = True
+else:
+    ROPE_DYNAMIC_SCALING = False
Reviewer: Nothing should be model specific.
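For illustration only (the values here are made up; only the variable names come from the diff), the flags above would be driven through the environment before this module is imported:

    import os

    os.environ["ROPE_SCALE_FACTOR"] = "4"        # static NTK-aware scaling with factor 4
    os.environ["ROPE_DYNAMIC_SCALING"] = "true"  # or opt into dynamic scaling instead

    # With the parsing above, this yields ROPE_SCALE_FACTOR == 4 and
    # ROPE_DYNAMIC_SCALING == True at import time.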

 class LlamaRMSNorm(nn.Module):
     def __init__(self, prefix, weights, eps=1e-6):
@@ -105,10 +112,18 @@ def __init__(
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
+        self.scale_factor = ROPE_SCALE_FACTOR
+        self.dynamic_scaling = ROPE_DYNAMIC_SCALING

-        self.rotary_emb = PositionRotaryEmbedding.load(
-            prefix=f"{prefix}.rotary_emb", weights=weights
-        )
+        if self.scale_factor > 1 or self.dynamic_scaling:
+            # Base before scaling is 10000 per the original RoPE paper
+            self.rotary_emb = PositionRotaryEmbedding.static(
+                self.head_size, 10000, weights.device, self.scale_factor, self.dynamic_scaling
+            )
+        else:
+            self.rotary_emb = PositionRotaryEmbedding.load(
+                prefix=f"{prefix}.rotary_emb", weights=weights
+            )

         self.softmax_scale = self.head_size**-0.5
Second changed file (PositionRotaryEmbedding implementation):
@@ -369,7 +369,7 @@ def forward(self, hidden_states, residual=None):
 import rotary_emb

 class PositionRotaryEmbedding(nn.Module):
-    def __init__(self, inv_freq):
+    def __init__(self, inv_freq, scale_factor=1, dynamic_scaling=False, max_seq_len=2048, dim=None, base=None):
Reviewer: Can we have at most 1 extra argument? A lot of information should be extractable directly from inv_freq.
Author: Yup, I can try to simplify this.
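As a rough illustration of that suggestion (not code from this PR; the helper name is hypothetical), dim and base are recoverable from the tensor itself, since inv_freq[i] = base ** (-2 * i / dim):

    import torch

    def infer_dim_and_base(inv_freq: torch.Tensor):
        # inv_freq holds dim / 2 entries, one per pair of rotary dimensions
        dim = 2 * inv_freq.numel()
        # invert inv_freq[1] = base ** (-2 / dim)
        base = float(inv_freq[1]) ** (-dim / 2)
        return dim, base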
         super().__init__()

         self.inv_freq = inv_freq
@@ -379,32 +379,61 @@ def __init__(self, inv_freq):
         self._cos_k_cached = None
         self._sin_k_cached = None

-    @classmethod
-    def static(cls, dim, base, device):
-        inv_freq = 1.0 / (
-            base
-            ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
-        )
-        return cls(inv_freq)
+        self.scale_factor = scale_factor
+        self.dynamic_scaling = dynamic_scaling
+        self.original_max_seq_len = max_seq_len
+        self.max_seq_len = max_seq_len * scale_factor
+        self.dim = dim
+        self.base = base
+
+    @classmethod
+    def static(cls, dim, base, device, scale_factor=1, dynamic_scaling=False, max_seq_len=2048):
+        inv_freq = cls._get_inv_freq(dim, base, device, scale_factor)
+        return cls(inv_freq, scale_factor, dynamic_scaling, max_seq_len, dim, base)

     @classmethod
     def load(cls, prefix, weights):
         # XXX: Always load this in float32 !
         dtype = weights.dtype
         weights.dtype = torch.float32

         inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
         weights.dtype = dtype
         return cls(inv_freq)

+    @staticmethod
+    def _get_inv_freq(dim, base, device, scale_factor=1):
+        base = base * scale_factor ** (dim / (dim-2))
+
+        inv_freq = 1.0 / (
+            base
+            ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
+        )
+
+        return inv_freq

     def _update_cos_sin_cache(self, dtype, device, seqlen):
         # Reset the tables if the sequence length has changed,
         # or if we're on a new device (possibly due to tracing for instance)

+        length = seqlen
+        max_seq_len = self.max_seq_len
+        inv_freq = self.inv_freq
+
+        if self.dynamic_scaling:
+            scale_factor = (self.scale_factor * length / self.original_max_seq_len) - (self.scale_factor - 1)
+            max_seq_len = self.original_max_seq_len * scale_factor
+            self.inv_freq = self._get_inv_freq(self.dim, self.base, inv_freq.device, scale_factor)
Reviewer: This is not really OK I think. You're ditching entirely the original inv_freq. Llama most notably has different saved inv_freq.
Author: Part of dynamic scaling is calculating the new inv_freq; looking at the dynamic scaling implementation in Transformers, I don't see them preserving this value either.
Author: What would you suggest alternatively?
Reviewer: I was thinking interpolation when I wrote this. Now that I reflect more, it would make the code even more complex, which is not the desired effect. Can we maybe move the scaling factor out of inv_freq?
Reviewer: And so let's keep rewriting inv_freq. It has some undesirable effects on those models, but the other way is even worse.
Author: That sounds reasonable, I'll make this change after work.
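To make the dynamic branch above concrete, here is how the recomputed scale factor behaves for a few sequence lengths, assuming (made-up) values scale_factor=2 and original_max_seq_len=2048:

    scale_factor, original_max_seq_len = 2, 2048
    for length in (2048, 4096, 8192):
        dynamic = (scale_factor * length / original_max_seq_len) - (scale_factor - 1)
        print(length, dynamic)
    # 2048 -> 1.0 (no extra scaling within the original window)
    # 4096 -> 3.0
    # 8192 -> 7.0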
+
+        if self.scale_factor > 1 and not self.dynamic_scaling:
+            length = max(seqlen, max_seq_len)
+
         if (
-            seqlen > self._seq_len_cached
+            length > self._seq_len_cached
             or self._cos_cached.device != device
             or self._cos_cached.dtype != dtype
         ):
-            self._seq_len_cached = seqlen
+            self._seq_len_cached = length
             t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
             # Don't do einsum, it converts fp32 to fp16
             # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
Reviewer: This does not belong in this PR. We can discuss changing the defaults, but it's a separate concern.
Author: Oh yup, fair. I don't want to change them, I meant to clean this out. I'll remove!
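Taken together, the new code path would presumably be exercised along these lines; a sketch only, assuming the module's dependencies (e.g. the rotary_emb kernel) are importable and using made-up numbers:

    import torch

    # head_size 128 and base 10000 as in the Llama constructor above; a static
    # scale factor of 4 stretches the cached tables to 4 * 2048 positions.
    rotary = PositionRotaryEmbedding.static(
        dim=128,
        base=10000,
        device=torch.device("cuda"),
        scale_factor=4,
        dynamic_scaling=False,
        max_seq_len=2048,
    )
    assert rotary.max_seq_len == 8192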