28 | 28 | from ...modeling_attn_mask_utils import AttentionMaskConverter |
29 | 29 | from ...modeling_flash_attention_utils import FlashAttentionKwargs |
30 | 30 | from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput |
31 | | -from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS |
| 31 | +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update |
32 | 32 | from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel |
33 | 33 | from ...processing_utils import Unpack |
34 | 34 | from ...utils import ( |
@@ -752,47 +752,18 @@ def __init__(self, config: AriaTextConfig, device=None): |
752 | 752 | self.register_buffer("inv_freq", inv_freq, persistent=False) |
753 | 753 | self.original_inv_freq = self.inv_freq |
754 | 754 |
755 | | - def _dynamic_frequency_update(self, position_ids, device): |
756 | | - """ |
757 | | - dynamic RoPE layers should recompute `inv_freq` in the following situations: |
758 | | - 1 - growing beyond the cached sequence length (allow scaling) |
759 | | - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) |
760 | | - """ |
761 | | - seq_len = torch.max(position_ids) + 1 |
762 | | - if seq_len > self.max_seq_len_cached: # growth |
763 | | - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) |
764 | | - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation |
765 | | - self.max_seq_len_cached = seq_len |
766 | | - |
767 | | - if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset |
768 | | - # This .to() is needed if the model has been moved to a device after being initialized (because |
769 | | - # the buffer is automatically moved, but not the original copy) |
770 | | - self.original_inv_freq = self.original_inv_freq.to(device) |
771 | | - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) |
772 | | - self.max_seq_len_cached = self.original_max_seq_len |
773 | | - |
774 | 755 | @torch.no_grad() |
| 756 | + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) |
775 | 757 | def forward(self, x, position_ids): |
776 | | - if "dynamic" in self.rope_type: |
777 | | - self._dynamic_frequency_update(position_ids, device=x.device) |
778 | | - |
779 | | - # Core RoPE block |
780 | | - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) |
| 758 | + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) |
781 | 759 | position_ids_expanded = position_ids[:, None, :].float() |
782 | | - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) |
783 | | - device_type = x.device.type |
784 | | - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" |
785 | | - with torch.autocast(device_type=device_type, enabled=False): |
786 | | - freqs = ( |
787 | | - inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() |
788 | | - ).transpose(1, 2) |
789 | | - emb = torch.cat((freqs, freqs), dim=-1) |
790 | | - cos = emb.cos() |
791 | | - sin = emb.sin() |
792 | 760 |
793 | | - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention |
794 | | - cos = cos * self.attention_scaling |
795 | | - sin = sin * self.attention_scaling |
| 761 | + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" |
| 762 | + with torch.autocast(device_type=device_type, enabled=False): # Force float32 |
| 763 | + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) |
| 764 | + emb = torch.cat((freqs, freqs), dim=-1) |
| 765 | + cos = emb.cos() * self.attention_scaling |
| 766 | + sin = emb.sin() * self.attention_scaling |
796 | 767 |
797 | 768 | return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) |
798 | 769 |
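
The deleted `_dynamic_frequency_update` logic is not dropped, it is centralized behind the shared `@dynamic_rope_update` decorator imported above. Below is a minimal sketch of the behaviour being centralized, reconstructed from the removed method; the wrapper itself is illustrative, not the actual transformers implementation.

```python
# Illustrative only: the behaviour @dynamic_rope_update is assumed to centralize,
# reconstructed from the _dynamic_frequency_update method removed above.
from functools import wraps

import torch


def dynamic_rope_update_sketch(forward):
    @wraps(forward)
    def wrapper(self, x, position_ids):
        if "dynamic" in self.rope_type:
            seq_len = torch.max(position_ids) + 1
            if seq_len > self.max_seq_len_cached:  # growth: recompute scaled inv_freq
                inv_freq, self.attention_scaling = self.rope_init_fn(self.config, x.device, seq_len=seq_len)
                self.register_buffer("inv_freq", inv_freq, persistent=False)
                self.max_seq_len_cached = seq_len
            if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:
                # reset: back inside the original window, restore the unscaled frequencies
                self.original_inv_freq = self.original_inv_freq.to(x.device)
                self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
                self.max_seq_len_cached = self.original_max_seq_len
        return forward(self, x, position_ids)

    return wrapper
```

Growth beyond the cached length recomputes `inv_freq` with the configured scaling, and dropping back inside the original window restores the cached unscaled frequencies, which are exactly the two cases the removed docstring describes.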
|
|
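To sanity-check the simplified `forward`, the rotary embedding can be driven directly. The sketch below assumes the class and config are exposed as `AriaTextRotaryEmbedding` and `AriaTextConfig` under the module paths shown (only the `config: AriaTextConfig` hint is visible in the diff), assumes the config exposes `hidden_size` and `num_attention_heads`, and re-implements the standard `rotate_half` helper locally rather than importing it.

```python
# Sketch, assuming AriaTextRotaryEmbedding / AriaTextConfig are importable as below.
import torch

from transformers.models.aria.configuration_aria import AriaTextConfig
from transformers.models.aria.modeling_aria import AriaTextRotaryEmbedding

config = AriaTextConfig()
rope = AriaTextRotaryEmbedding(config)

batch, seq_len = 2, 16
x = torch.randn(batch, seq_len, config.hidden_size)  # only dtype/device are read from x
position_ids = torch.arange(seq_len)[None, :].expand(batch, -1)

cos, sin = rope(x, position_ids)  # each (batch, seq_len, head_dim), already scaled by attention_scaling


def rotate_half(t):
    # standard RoPE helper: split the last dim and map (x1, x2) -> (-x2, x1)
    x1, x2 = t.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


q = torch.randn(batch, config.num_attention_heads, seq_len, cos.shape[-1])
q_rot = q * cos.unsqueeze(1) + rotate_half(q) * sin.unsqueeze(1)  # unsqueeze broadcasts over the heads dim
```

Note that `x` is only consulted for its device and dtype; the positions come entirely from `position_ids`, which is why the decorator can key its recompute decision on `position_ids` alone.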