
Commit ad3d157

[RoPE] abstract dynamic RoPE update under a decorator ✨ (#37249)

* dynamic rope decorator
* longrope; shorter fwd pass
* proper docstring
* make fixup

1 parent 3d40bda commit ad3d157

52 files changed: +527 / -1833 lines

docs/source/en/internal/modeling_utils.md

Lines changed: 4 additions & 0 deletions

@@ -25,6 +25,10 @@ Most of those are only useful if you are studying the code of the models in the
 [[autodoc]] AttentionInterface
     - register
 
+## Rotary Position Embedding Functions
+
+[[autodoc]] dynamic_rope_update
+
 ## Pytorch custom modules
 
 [[autodoc]] pytorch_utils.Conv1D

src/transformers/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -1483,7 +1483,7 @@
     _import_structure["modeling_flash_attention_utils"] = []
     _import_structure["modeling_outputs"] = []
-    _import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS"]
+    _import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS", "dynamic_rope_update"]
     _import_structure["modeling_utils"] = ["PreTrainedModel", "AttentionInterface"]
 
     # PyTorch models structure
@@ -6762,7 +6762,7 @@
            model_addition_debugger,
            model_addition_debugger_context,
        )
-        from .modeling_rope_utils import ROPE_INIT_FUNCTIONS
+        from .modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
         from .modeling_utils import AttentionInterface, PreTrainedModel
         from .models.albert import (
            AlbertForMaskedLM,
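
With both the lazy `_import_structure` entry and the `TYPE_CHECKING` import updated, the decorator becomes part of the top-level `transformers` API. A quick sanity check (assuming a transformers build that includes this commit):

```python
# Sanity check: the decorator is exported next to ROPE_INIT_FUNCTIONS at the package root
# (assumes a transformers version that includes this commit).
from transformers import ROPE_INIT_FUNCTIONS, dynamic_rope_update

print(callable(dynamic_rope_update))       # True
print(sorted(ROPE_INIT_FUNCTIONS.keys()))  # available rope types, e.g. 'default', 'dynamic', 'longrope', ...
```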

src/transformers/modeling_rope_utils.py

Lines changed: 63 additions & 0 deletions

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import math
+from functools import wraps
 from typing import Optional
 
 from .configuration_utils import PretrainedConfig
@@ -26,6 +27,68 @@
     import torch
 
 
+def dynamic_rope_update(rope_forward):
+    """
+    Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
+    (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
+
+    Args:
+        rope_forward (Callable):
+            The forward pass of the RoPE implementation.
+
+    Returns:
+        The decorated forward pass.
+    """
+
+    def longrope_frequency_update(self, position_ids, device):
+        """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
+        seq_len = torch.max(position_ids) + 1
+        if hasattr(self.config, "original_max_position_embeddings"):
+            original_max_position_embeddings = self.config.original_max_position_embeddings
+        else:
+            original_max_position_embeddings = self.config.max_position_embeddings
+        if seq_len > original_max_position_embeddings:
+            if not hasattr(self, "long_inv_freq"):
+                self.long_inv_freq, _ = self.rope_init_fn(
+                    self.config, device, seq_len=original_max_position_embeddings + 1
+                )
+            self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
+        else:
+            # This .to() is needed if the model has been moved to a device after being initialized (because
+            # the buffer is automatically moved, but not the original copy)
+            self.original_inv_freq = self.original_inv_freq.to(device)
+            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+
+    def dynamic_frequency_update(self, position_ids, device):
+        """
+        dynamic RoPE layers should recompute `inv_freq` in the following situations:
+        1 - growing beyond the cached sequence length (allow scaling)
+        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
+        """
+        seq_len = torch.max(position_ids) + 1
+        if seq_len > self.max_seq_len_cached:  # growth
+            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
+            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
+            self.max_seq_len_cached = seq_len
+
+        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
+            # This .to() is needed if the model has been moved to a device after being initialized (because
+            # the buffer is automatically moved, but not the original copy)
+            self.original_inv_freq = self.original_inv_freq.to(device)
+            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+            self.max_seq_len_cached = self.original_max_seq_len
+
+    @wraps(rope_forward)
+    def wrapper(self, x, position_ids):
+        if "dynamic" in self.rope_type:
+            dynamic_frequency_update(self, position_ids, device=x.device)
+        elif self.rope_type == "longrope":
+            longrope_frequency_update(self, position_ids, device=x.device)
+        return rope_forward(self, x, position_ids)
+
+    return wrapper
+
+
 def _compute_default_rope_parameters(
     config: Optional[PretrainedConfig] = None,
     device: Optional["torch.device"] = None,
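
For orientation, here is a minimal sketch of the kind of module the decorator is meant to wrap. The attribute and buffer names it relies on (`rope_type`, `config`, `rope_init_fn`, `max_seq_len_cached`, `original_max_seq_len`, `original_inv_freq`, `attention_scaling`, and the `inv_freq` buffer) are the ones read in the code above; the class itself is illustrative, not a library class, and its `__init__` only approximates how the library's rotary embeddings set these attributes up.

```python
# Illustrative sketch only: a minimal module exposing the attributes that
# `dynamic_rope_update` reads. Not a class from transformers.
import torch
import torch.nn as nn

from transformers import LlamaConfig
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update


class SketchRotaryEmbedding(nn.Module):
    def __init__(self, config, device=None):
        super().__init__()
        # "dynamic" / "longrope" are the rope types the decorator acts on
        self.rope_type = config.rope_scaling["rope_type"] if getattr(config, "rope_scaling", None) else "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # recomputes `inv_freq` in the forward pass for dynamic/longrope RoPE
    def forward(self, x, position_ids):
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()
        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos() * self.attention_scaling, emb.sin() * self.attention_scaling


# Exceeding max_position_embeddings (128) with a dynamic rope type triggers the recompute path.
config = LlamaConfig(
    hidden_size=64, num_attention_heads=4, max_position_embeddings=128,
    rope_scaling={"rope_type": "dynamic", "factor": 2.0},
)
rope = SketchRotaryEmbedding(config)
x = torch.randn(1, 300, config.hidden_size)
position_ids = torch.arange(300).unsqueeze(0)
cos, sin = rope(x, position_ids)
print(cos.shape, sin.shape)  # torch.Size([1, 300, 16]) torch.Size([1, 300, 16])
```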

src/transformers/models/aria/modeling_aria.py

Lines changed: 9 additions & 38 deletions

@@ -28,7 +28,7 @@
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
-from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import (
@@ -752,47 +752,18 @@ def __init__(self, config: AriaTextConfig, device=None):
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.original_inv_freq = self.inv_freq
 
-    def _dynamic_frequency_update(self, position_ids, device):
-        """
-        dynamic RoPE layers should recompute `inv_freq` in the following situations:
-        1 - growing beyond the cached sequence length (allow scaling)
-        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
-        """
-        seq_len = torch.max(position_ids) + 1
-        if seq_len > self.max_seq_len_cached:  # growth
-            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
-            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
-            self.max_seq_len_cached = seq_len
-
-        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
-            # This .to() is needed if the model has been moved to a device after being initialized (because
-            # the buffer is automatically moved, but not the original copy)
-            self.original_inv_freq = self.original_inv_freq.to(device)
-            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
-            self.max_seq_len_cached = self.original_max_seq_len
-
     @torch.no_grad()
+    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
     def forward(self, x, position_ids):
-        if "dynamic" in self.rope_type:
-            self._dynamic_frequency_update(position_ids, device=x.device)
-
-        # Core RoPE block
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
         position_ids_expanded = position_ids[:, None, :].float()
-        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
-        device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):
-            freqs = (
-                inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
-            ).transpose(1, 2)
-            emb = torch.cat((freqs, freqs), dim=-1)
-            cos = emb.cos()
-            sin = emb.sin()
 
-        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
-        cos = cos * self.attention_scaling
-        sin = sin * self.attention_scaling
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
 
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
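
The refactored per-model forward keeps the float32 guard from before: the removed comment pointed to PR #29285, and the new code preserves the idea as `with torch.autocast(..., enabled=False):  # Force float32`. A small standalone illustration of what that guard does (hedged: autocast op coverage can vary slightly across PyTorch versions):

```python
# Illustration of the "Force float32" autocast guard, independent of the model code:
# under an autocast context, the frequency matmul would otherwise run in reduced precision.
import torch

inv_freq = (1.0 / (10000.0 ** (torch.arange(0, 8, 2).float() / 8)))[None, :, None]  # [1, 4, 1]
position_ids = torch.arange(16).float()[None, None, :]                              # [1, 1, 16]

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    downcast = inv_freq @ position_ids  # eligible for bfloat16 under autocast
    with torch.autocast(device_type="cpu", enabled=False):
        kept_fp32 = inv_freq.float() @ position_ids.float()  # autocast disabled: stays float32

print(downcast.dtype, kept_fp32.dtype)  # expected: torch.bfloat16 torch.float32
```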

src/transformers/models/bamba/modeling_bamba.py

Lines changed: 9 additions & 38 deletions

@@ -37,7 +37,7 @@
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import (
@@ -142,47 +142,18 @@ def __init__(self, config: BambaConfig, device=None):
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.original_inv_freq = self.inv_freq
 
-    def _dynamic_frequency_update(self, position_ids, device):
-        """
-        dynamic RoPE layers should recompute `inv_freq` in the following situations:
-        1 - growing beyond the cached sequence length (allow scaling)
-        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
-        """
-        seq_len = torch.max(position_ids) + 1
-        if seq_len > self.max_seq_len_cached:  # growth
-            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
-            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
-            self.max_seq_len_cached = seq_len
-
-        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
-            # This .to() is needed if the model has been moved to a device after being initialized (because
-            # the buffer is automatically moved, but not the original copy)
-            self.original_inv_freq = self.original_inv_freq.to(device)
-            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
-            self.max_seq_len_cached = self.original_max_seq_len
-
     @torch.no_grad()
+    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
     def forward(self, x, position_ids):
-        if "dynamic" in self.rope_type:
-            self._dynamic_frequency_update(position_ids, device=x.device)
-
-        # Core RoPE block
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
         position_ids_expanded = position_ids[:, None, :].float()
-        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
-        device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):
-            freqs = (
-                inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
-            ).transpose(1, 2)
-            emb = torch.cat((freqs, freqs), dim=-1)
-            cos = emb.cos()
-            sin = emb.sin()
 
-        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
-        cos = cos * self.attention_scaling
-        sin = sin * self.attention_scaling
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
 
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

src/transformers/models/cohere/modeling_cohere.py

Lines changed: 7 additions & 34 deletions

@@ -39,7 +39,7 @@
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import (
@@ -101,45 +101,18 @@ def __init__(self, config: CohereConfig, device=None):
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.original_inv_freq = self.inv_freq
 
-    def _dynamic_frequency_update(self, position_ids, device):
-        """
-        dynamic RoPE layers should recompute `inv_freq` in the following situations:
-        1 - growing beyond the cached sequence length (allow scaling)
-        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
-        """
-        seq_len = torch.max(position_ids) + 1
-        if seq_len > self.max_seq_len_cached:  # growth
-            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
-            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
-            self.max_seq_len_cached = seq_len
-
-        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
-            # This .to() is needed if the model has been moved to a device after being initialized (because
-            # the buffer is automatically moved, but not the original copy)
-            self.original_inv_freq = self.original_inv_freq.to(device)
-            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
-            self.max_seq_len_cached = self.original_max_seq_len
-
     @torch.no_grad()
+    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
     def forward(self, x, position_ids):
-        if "dynamic" in self.rope_type:
-            self._dynamic_frequency_update(position_ids, device=x.device)
-
-        # Core RoPE block
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
-        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
-        device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):
+
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.repeat_interleave(freqs, 2, dim=-1)  # diff from Llama: we interleave() instead of cat()
-            cos = emb.cos()
-            sin = emb.sin()
-
-        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
-        cos = cos * self.attention_scaling
-        sin = sin * self.attention_scaling
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
 
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

src/transformers/models/cohere/modular_cohere.py

Lines changed: 7 additions & 14 deletions

@@ -31,6 +31,7 @@
 from ...cache_utils import Cache
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_rope_utils import dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import Unpack
 from ...pytorch_utils import ALL_LAYERNORM_LAYERS
@@ -73,25 +74,17 @@ def forward(self, hidden_states):
 
 class CohereRotaryEmbedding(LlamaRotaryEmbedding):
     @torch.no_grad()
+    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
     def forward(self, x, position_ids):
-        if "dynamic" in self.rope_type:
-            self._dynamic_frequency_update(position_ids, device=x.device)
-
-        # Core RoPE block
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
-        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
-        device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):
+
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.repeat_interleave(freqs, 2, dim=-1)  # diff from Llama: we interleave() instead of cat()
-            cos = emb.cos()
-            sin = emb.sin()
-
-        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
-        cos = cos * self.attention_scaling
-        sin = sin * self.attention_scaling
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
 
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
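
The Cohere code path keeps its inline note that it interleaves rather than concatenates the frequencies ("diff from Llama: we interleave() instead of cat()"). A tiny illustration of the two layouts:

```python
# Layout difference referenced by the "interleave() instead of cat()" comment:
# Llama-style concatenation vs Cohere-style interleaving of the same frequencies.
import torch

freqs = torch.tensor([[0.1, 0.2, 0.3]])
print(torch.cat((freqs, freqs), dim=-1))          # [[0.1, 0.2, 0.3, 0.1, 0.2, 0.3]]
print(torch.repeat_interleave(freqs, 2, dim=-1))  # [[0.1, 0.1, 0.2, 0.2, 0.3, 0.3]]
```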
