diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py
index 9d145de98499..e13eb1dea643 100644
--- a/src/transformers/models/modernbert/modeling_modernbert.py
+++ b/src/transformers/models/modernbert/modeling_modernbert.py
@@ -31,6 +31,7 @@
 from ...activations import ACT2FN
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
 from ...modeling_outputs import BaseModelOutput, MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
     add_code_sample_docstrings,
@@ -241,23 +242,47 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:


 class ModernBertRotaryEmbedding(nn.Module):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+    def __init__(self, config: ModernBertConfig, dim: int, base: float, device: Optional[torch.device] = None):
         super().__init__()
+        # BC: "rope_type" was originally "type"
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+        else:
+            self.rope_type = "default"
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
+
+        self.config = config
+        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        inv_freq, self.attention_scaling = self.rope_init_fn(None, device, dim=dim, base=base)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.original_inv_freq = self.inv_freq
+
+    def _dynamic_frequency_update(self, position_ids, device):
+        """
+        dynamic RoPE layers should recompute `inv_freq` in the following situations:
+        1 - growing beyond the cached sequence length (allow scaling)
+        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
+        """
+        seq_len = torch.max(position_ids) + 1
+        if seq_len > self.max_seq_len_cached:  # growth
+            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
+            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
+            self.max_seq_len_cached = seq_len

-        self.dim = dim
-        self.max_position_embeddings = max_position_embeddings
-        self.base = base
-        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
-        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
+        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
+            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+            self.max_seq_len_cached = self.original_max_seq_len

     @torch.no_grad()
-    def forward(self, x, position_ids, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        self.inv_freq.to(x.device)
+    def forward(self, x, position_ids):
+        if "dynamic" in self.rope_type:
+            self._dynamic_frequency_update(position_ids, device=x.device)
+
+        # Core RoPE block
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
-        # Force float32 since bfloat16 loses precision on long contexts
-        # See https://github.com/huggingface/transformers/pull/29285
+        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
         device_type = x.device.type
         device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):
@@ -265,6 +290,11 @@ def forward(self, x, position_ids, seq_len=None):
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos()
             sin = emb.sin()
+
+        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
+        cos = cos * self.attention_scaling
+        sin = sin * self.attention_scaling
+
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


@@ -468,9 +498,7 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None):
                 dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
             )
         else:
-            self.rotary_emb = ModernBertRotaryEmbedding(
-                dim=self.head_dim, max_position_embeddings=max_position_embeddings, base=rope_theta
-            )
+            self.rotary_emb = ModernBertRotaryEmbedding(config=config, dim=self.head_dim, base=rope_theta)

         self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
         self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py
index 05821973b147..edfdc94346bf 100644
--- a/src/transformers/models/modernbert/modular_modernbert.py
+++ b/src/transformers/models/modernbert/modular_modernbert.py
@@ -41,7 +41,7 @@
     logging,
 )
 from ...utils.import_utils import is_triton_available
-from ..gemma.modeling_gemma import apply_rotary_pos_emb
+from ..gemma.modeling_gemma import GemmaRotaryEmbedding, apply_rotary_pos_emb


 if is_flash_attn_2_available():
@@ -504,32 +504,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return self.Wo(self.drop(self.act(input) * gate))


-class ModernBertRotaryEmbedding(nn.Module):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
-        super().__init__()
-
-        self.dim = dim
-        self.max_position_embeddings = max_position_embeddings
-        self.base = base
-        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
-        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
-
-    @torch.no_grad()
-    def forward(self, x, position_ids, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        self.inv_freq.to(x.device)
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
-        position_ids_expanded = position_ids[:, None, :].float()
-        # Force float32 since bfloat16 loses precision on long contexts
-        # See https://github.com/huggingface/transformers/pull/29285
-        device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):
-            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-            emb = torch.cat((freqs, freqs), dim=-1)
-            cos = emb.cos()
-            sin = emb.sin()
-        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+class ModernBertRotaryEmbedding(GemmaRotaryEmbedding):
+    def __init__(self, config: ModernBertConfig, dim: int, base: float, device: Optional[torch.device] = None):
+        super().__init__(self, config=config, device=device)
+        inv_freq, self.attention_scaling = self.rope_init_fn(None, device, dim=dim, base=base)


 def eager_attention_forward(
@@ -698,9 +676,7 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None):
                 dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
             )
         else:
-            self.rotary_emb = ModernBertRotaryEmbedding(
-                dim=self.head_dim, max_position_embeddings=max_position_embeddings, base=rope_theta
-            )
+            self.rotary_emb = ModernBertRotaryEmbedding(config=config, dim=self.head_dim, base=rope_theta)

         self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
         self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
diff --git a/tests/models/modernbert/test_modeling_modernbert.py b/tests/models/modernbert/test_modeling_modernbert.py
index 08e77505b5b7..9f286cf3985f 100644
--- a/tests/models/modernbert/test_modeling_modernbert.py
+++ b/tests/models/modernbert/test_modeling_modernbert.py
@@ -16,8 +16,9 @@
 import unittest

 import pytest
+from packaging import version

-from transformers import ModernBertConfig, is_torch_available
+from transformers import AutoTokenizer, ModernBertConfig, is_torch_available
 from transformers.models.auto import get_values
 from transformers.testing_utils import (
     CaptureLogger,
@@ -362,6 +363,131 @@ def test_flash_attn_2_conversion(self):

 @require_torch
 class ModernBertModelIntegrationTest(unittest.TestCase):
-    """
-    These still need to be written, once public models are available.
-    """
+    @slow
+    def test_inference_masked_lm(self):
+        if version.parse(torch.__version__) < version.parse("2.4.0"):
+            self.skipTest(reason="This test requires torch >= 2.4 to run.")
+
+        model = ModernBertForMaskedLM.from_pretrained(
+            "answerdotai/ModernBERT-base", reference_compile=False, attn_implementation="sdpa"
+        )
+        tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
+
+        inputs = tokenizer("Hello World!", return_tensors="pt")
+        with torch.no_grad():
+            output = model(**inputs)[0]
+        expected_shape = torch.Size((1, 5, 50368))
+        self.assertEqual(output.shape, expected_shape)
+
+        # compare the actual values for a slice.
+        expected_slice = torch.tensor(
+            [[[3.8387, -0.2017, 12.2839], [3.6300, 0.6869, 14.7123], [-5.1137, -3.8122, 11.9874]]]
+        )
+        self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
+
+    @slow
+    def test_inference_no_head(self):
+        if version.parse(torch.__version__) < version.parse("2.4.0"):
+            self.skipTest(reason="This test requires torch >= 2.4 to run.")
+
+        model = ModernBertModel.from_pretrained(
+            "answerdotai/ModernBERT-base", reference_compile=False, attn_implementation="sdpa"
+        )
+        tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
+
+        inputs = tokenizer("Hello World!", return_tensors="pt")
+        with torch.no_grad():
+            output = model(**inputs)[0]
+        expected_shape = torch.Size((1, 5, 768))
+        self.assertEqual(output.shape, expected_shape)
+
+        # compare the actual values for a slice.
+        expected_slice = torch.tensor(
+            [[[0.3151, -0.6417, -0.7027], [-0.7834, -1.5810, 0.4576], [1.0614, -0.7268, -0.0871]]]
+        )
+        self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
+
+    @slow
+    def test_inference_token_classification(self):
+        if version.parse(torch.__version__) < version.parse("2.4.0"):
+            self.skipTest(reason="This test requires torch >= 2.4 to run.")
+
+        model = ModernBertForTokenClassification.from_pretrained(
+            "hf-internal-testing/tiny-random-ModernBertForTokenClassification",
+            reference_compile=False,
+            attn_implementation="sdpa",
+        )
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-ModernBertForTokenClassification")
+
+        inputs = tokenizer("Hello World!", return_tensors="pt")
+        with torch.no_grad():
+            output = model(**inputs)[0]
+        expected_shape = torch.Size((1, 5, 2))
+        self.assertEqual(output.shape, expected_shape)
+
+        expected = torch.tensor(
+            [[[2.0159, 4.6569], [-0.9430, 3.1595], [-3.8770, 3.2653], [1.5752, 4.5167], [-1.6939, 1.2524]]]
+        )
+        self.assertTrue(torch.allclose(output, expected, atol=1e-4))
+
+    @slow
+    def test_inference_sequence_classification(self):
+        if version.parse(torch.__version__) < version.parse("2.4.0"):
+            self.skipTest(reason="This test requires torch >= 2.4 to run.")
+
+        model = ModernBertForSequenceClassification.from_pretrained(
+            "hf-internal-testing/tiny-random-ModernBertForSequenceClassification",
+            reference_compile=False,
+            attn_implementation="sdpa",
+        )
+        tokenizer = AutoTokenizer.from_pretrained(
+            "hf-internal-testing/tiny-random-ModernBertForSequenceClassification"
+        )
+
+        inputs = tokenizer("Hello World!", return_tensors="pt")
+        with torch.no_grad():
+            output = model(**inputs)[0]
+        expected_shape = torch.Size((1, 2))
+        self.assertEqual(output.shape, expected_shape)
+
+        expected = torch.tensor([[1.6466, 4.5662]])
+        self.assertTrue(torch.allclose(output, expected, atol=1e-4))
+
+    @slow
+    def test_export(self):
+        if version.parse(torch.__version__) < version.parse("2.4.0"):
+            self.skipTest(reason="This test requires torch >= 2.4 to run.")
+
+        bert_model = "answerdotai/ModernBERT-base"
+        device = "cpu"
+        attn_implementation = "sdpa"
+        max_length = 512
+
+        tokenizer = AutoTokenizer.from_pretrained(bert_model)
+        inputs = tokenizer(
+            "the man worked as a [MASK].",
+            return_tensors="pt",
+            padding="max_length",
+            max_length=max_length,
+        )
+
+        model = ModernBertForMaskedLM.from_pretrained(
+            bert_model,
+            device_map=device,
+            attn_implementation=attn_implementation,
+        )
+
+        logits = model(**inputs).logits
+        eg_predicted_mask = tokenizer.decode(logits[0, 6].topk(5).indices)
+        self.assertEqual(eg_predicted_mask.split(), ["lawyer", "mechanic", "teacher", "doctor", "waiter"])
+
+        exported_program = torch.export.export(
+            model,
+            args=(inputs["input_ids"],),
+            kwargs={"attention_mask": inputs["attention_mask"]},
+            strict=True,
+        )
+
+        result = exported_program.module().forward(inputs["input_ids"], inputs["attention_mask"])
+        ep_predicted_mask = tokenizer.decode(result.logits[0, 6].topk(5).indices)
+        self.assertEqual(eg_predicted_mask, ep_predicted_mask)
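
Below is a minimal usage sketch, not part of the patch, exercising the refactored ModernBertRotaryEmbedding with its new config-driven constructor and the forward(x, position_ids) signature introduced above. The default ModernBertConfig values, the base of 10000.0, and the tensor sizes are illustrative assumptions, not values taken from the diff.

# Hedged sketch, not part of the diff: new constructor takes (config, dim, base) and
# no longer accepts max_position_embeddings directly.
import torch
from transformers import ModernBertConfig
from transformers.models.modernbert.modeling_modernbert import ModernBertRotaryEmbedding

config = ModernBertConfig()  # assumed defaults; no rope_scaling set, so rope_type falls back to "default"
head_dim = config.hidden_size // config.num_attention_heads

rope = ModernBertRotaryEmbedding(config=config, dim=head_dim, base=10000.0)  # base chosen for illustration

x = torch.randn(1, config.num_attention_heads, 10, head_dim)  # only dtype/device of x are used
position_ids = torch.arange(10).unsqueeze(0)
cos, sin = rope(x, position_ids)  # each of shape (1, 10, head_dim), already multiplied by attention_scaling (1.0 for "default")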
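
A second sketch, also hedged and not part of the patch, showing how the cos/sin pair would feed apply_rotary_pos_emb, the helper that modular_modernbert.py imports from the Gemma model; batch, sequence length, and the base value are illustrative assumptions.

# Hedged sketch, not part of the diff: rotating query/key states with the Gemma helper.
import torch
from transformers import ModernBertConfig
from transformers.models.gemma.modeling_gemma import apply_rotary_pos_emb
from transformers.models.modernbert.modeling_modernbert import ModernBertRotaryEmbedding

config = ModernBertConfig()
head_dim = config.hidden_size // config.num_attention_heads
rope = ModernBertRotaryEmbedding(config=config, dim=head_dim, base=10000.0)

batch, heads, seq = 1, config.num_attention_heads, 10
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
position_ids = torch.arange(seq).unsqueeze(0)

cos, sin = rope(q, position_ids)                     # (batch, seq, head_dim)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)  # rotates the last dimension of q and k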