diff --git a/tests/models/falcon/test_modeling_falcon.py b/tests/models/falcon/test_modeling_falcon.py
index fa7ea2af816cb0..17b3dc42cf5684 100644
--- a/tests/models/falcon/test_modeling_falcon.py
+++ b/tests/models/falcon/test_modeling_falcon.py
@@ -45,6 +45,11 @@
         FalconForTokenClassification,
         FalconModel,
     )
+    from transformers.models.falcon.modeling_falcon import (
+        FalconDynamicNTKScalingRotaryEmbedding,
+        FalconLinearScalingRotaryEmbedding,
+        FalconRotaryEmbedding,
+    )


 class FalconModelTester:
@@ -408,7 +413,8 @@ def test_past_key_values_format(self):
         )

     @parameterized.expand([("linear",), ("dynamic",)])
-    def test_model_rope_scaling(self, scaling_type):
+    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->Falcon
+    def test_model_rope_scaling_from_config(self, scaling_type):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
         short_input = ids_tensor([1, 10], config.vocab_size)
         long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
@@ -438,6 +444,65 @@ def test_model_rope_scaling(self, scaling_type):
         # The output should be different for long inputs
         self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))

+    def test_model_rope_scaling(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        hidden_size = config.hidden_size
+        num_heads = config.num_attention_heads
+        head_dim = hidden_size // num_heads
+        scaling_factor = 10
+        short_input_length = 10
+        long_input_length = int(config.max_position_embeddings * 1.5)
+
+        # Inputs
+        x = torch.randn(1, dtype=torch.float32, device=torch_device)  # used exclusively to get the dtype and the device
+
+        # Sanity check original RoPE
+        original_rope = FalconRotaryEmbedding(
+            head_dim,
+            max_position_embeddings=config.max_position_embeddings,
+            base=config.rope_theta,
+        ).to(torch_device)
+        original_cos_short, original_sin_short = original_rope(x, short_input_length)
+        original_cos_long, original_sin_long = original_rope(x, long_input_length)
+        torch.testing.assert_close(original_cos_short, original_cos_long[:short_input_length, :])
+        torch.testing.assert_close(original_sin_short, original_sin_long[:short_input_length, :])
+
+        # Sanity check linear RoPE scaling
+        # New position "x" should match original position with index "x/scaling_factor"
+        linear_scaling_rope = FalconLinearScalingRotaryEmbedding(
+            head_dim,
+            max_position_embeddings=config.max_position_embeddings,
+            base=config.rope_theta,
+            scaling_factor=scaling_factor,
+        ).to(torch_device)
+        linear_cos_short, linear_sin_short = linear_scaling_rope(x, short_input_length)
+        linear_cos_long, linear_sin_long = linear_scaling_rope(x, long_input_length)
+        torch.testing.assert_close(linear_cos_short, linear_cos_long[:short_input_length, :])
+        torch.testing.assert_close(linear_sin_short, linear_sin_long[:short_input_length, :])
+        for new_position in range(0, long_input_length, scaling_factor):
+            original_position = int(new_position // scaling_factor)
+            torch.testing.assert_close(linear_cos_long[new_position, :], original_cos_long[original_position, :])
+            torch.testing.assert_close(linear_sin_long[new_position, :], original_sin_long[original_position, :])
+
+        # Sanity check Dynamic NTK RoPE scaling
+        # Scaling should only be observed after a long input is fed. We can observe that the frequencies decrease
+        # with scaling_factor (or that `inv_freq` decreases)
+        ntk_scaling_rope = FalconDynamicNTKScalingRotaryEmbedding(
+            head_dim,
+            max_position_embeddings=config.max_position_embeddings,
+            base=config.rope_theta,
+            scaling_factor=scaling_factor,
+        ).to(torch_device)
+        ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, short_input_length)
+        ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, long_input_length)
+        torch.testing.assert_close(ntk_cos_short, original_cos_short)
+        torch.testing.assert_close(ntk_sin_short, original_sin_short)
+        with self.assertRaises(AssertionError):
+            torch.testing.assert_close(ntk_cos_long, original_cos_long)
+        with self.assertRaises(AssertionError):
+            torch.testing.assert_close(ntk_sin_long, original_sin_long)
+        self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
+
     @require_torch_sdpa
     @slow
     def test_eager_matches_sdpa_generate(self):
diff --git a/tests/models/gpt_neox/test_modeling_gpt_neox.py b/tests/models/gpt_neox/test_modeling_gpt_neox.py
index 19e3db2a61fb91..92d130b35101bb 100644
--- a/tests/models/gpt_neox/test_modeling_gpt_neox.py
+++ b/tests/models/gpt_neox/test_modeling_gpt_neox.py
@@ -38,6 +38,11 @@
         GPTNeoXForTokenClassification,
         GPTNeoXModel,
     )
+    from transformers.models.gpt_neox.modeling_gpt_neox import (
+        GPTNeoXDynamicNTKScalingRotaryEmbedding,
+        GPTNeoXLinearScalingRotaryEmbedding,
+        GPTNeoXRotaryEmbedding,
+    )


 class GPTNeoXModelTester:
@@ -301,7 +306,8 @@ def test_feed_forward_chunking(self):
         pass

     @parameterized.expand([("linear",), ("dynamic",)])
-    def test_model_rope_scaling(self, scaling_type):
+    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->GPTNeoX
+    def test_model_rope_scaling_from_config(self, scaling_type):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
         short_input = ids_tensor([1, 10], config.vocab_size)
         long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
@@ -331,6 +337,66 @@ def test_model_rope_scaling(self, scaling_type):
         # The output should be different for long inputs
         self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))

+    # Copied from tests.models.falcon.test_modeling_falcon.FalconModelTest.test_model_rope_scaling with Falcon->GPTNeoX, rope_theta->rotary_emb_base
+    def test_model_rope_scaling(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        hidden_size = config.hidden_size
+        num_heads = config.num_attention_heads
+        head_dim = hidden_size // num_heads
+        scaling_factor = 10
+        short_input_length = 10
+        long_input_length = int(config.max_position_embeddings * 1.5)
+
+        # Inputs
+        x = torch.randn(1, dtype=torch.float32, device=torch_device)  # used exclusively to get the dtype and the device
+
+        # Sanity check original RoPE
+        original_rope = GPTNeoXRotaryEmbedding(
+            head_dim,
+            max_position_embeddings=config.max_position_embeddings,
+            base=config.rotary_emb_base,
+        ).to(torch_device)
+        original_cos_short, original_sin_short = original_rope(x, short_input_length)
+        original_cos_long, original_sin_long = original_rope(x, long_input_length)
+        torch.testing.assert_close(original_cos_short, original_cos_long[:short_input_length, :])
+        torch.testing.assert_close(original_sin_short, original_sin_long[:short_input_length, :])
+
+        # Sanity check linear RoPE scaling
+        # New position "x" should match original position with index "x/scaling_factor"
+        linear_scaling_rope = GPTNeoXLinearScalingRotaryEmbedding(
+            head_dim,
+            max_position_embeddings=config.max_position_embeddings,
+            base=config.rotary_emb_base,
+            scaling_factor=scaling_factor,
+        ).to(torch_device)
+        linear_cos_short, linear_sin_short = linear_scaling_rope(x, short_input_length)
+        linear_cos_long, linear_sin_long = linear_scaling_rope(x, long_input_length)
+        torch.testing.assert_close(linear_cos_short, linear_cos_long[:short_input_length, :])
+        torch.testing.assert_close(linear_sin_short, linear_sin_long[:short_input_length, :])
+        for new_position in range(0, long_input_length, scaling_factor):
+            original_position = int(new_position // scaling_factor)
+            torch.testing.assert_close(linear_cos_long[new_position, :], original_cos_long[original_position, :])
+            torch.testing.assert_close(linear_sin_long[new_position, :], original_sin_long[original_position, :])
+
+        # Sanity check Dynamic NTK RoPE scaling
+        # Scaling should only be observed after a long input is fed. We can observe that the frequencies decrease
+        # with scaling_factor (or that `inv_freq` decreases)
+        ntk_scaling_rope = GPTNeoXDynamicNTKScalingRotaryEmbedding(
+            head_dim,
+            max_position_embeddings=config.max_position_embeddings,
+            base=config.rotary_emb_base,
+            scaling_factor=scaling_factor,
+        ).to(torch_device)
+        ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, short_input_length)
+        ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, long_input_length)
+        torch.testing.assert_close(ntk_cos_short, original_cos_short)
+        torch.testing.assert_close(ntk_sin_short, original_sin_short)
+        with self.assertRaises(AssertionError):
+            torch.testing.assert_close(ntk_cos_long, original_cos_long)
+        with self.assertRaises(AssertionError):
+            torch.testing.assert_close(ntk_sin_long, original_sin_long)
+        self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
+

 @require_torch
 class GPTNeoXLanguageGenerationTest(unittest.TestCase):
diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py
index 36dc8d6bcdf4e8..e0a3990bd8de30 100644
--- a/tests/models/llama/test_modeling_llama.py
+++ b/tests/models/llama/test_modeling_llama.py
@@ -51,6 +51,11 @@
         LlamaModel,
         LlamaTokenizer,
     )
+    from transformers.models.llama.modeling_llama import (
+        LlamaDynamicNTKScalingRotaryEmbedding,
+        LlamaLinearScalingRotaryEmbedding,
+        LlamaRotaryEmbedding,
+    )


 class LlamaModelTester:
@@ -370,7 +375,7 @@ def test_save_load_fast_init_from_base(self):
         pass

     @parameterized.expand([("linear",), ("dynamic",)])
-    def test_model_rope_scaling(self, scaling_type):
+    def test_model_rope_scaling_from_config(self, scaling_type):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
         short_input = ids_tensor([1, 10], config.vocab_size)
         long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
@@ -400,6 +405,69 @@ def test_model_rope_scaling(self, scaling_type):
         # The output should be different for long inputs
         self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))

+    def test_model_rope_scaling(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        hidden_size = config.hidden_size
+        num_heads = config.num_attention_heads
+        head_dim = hidden_size // num_heads
+        scaling_factor = 10
+        short_input_length = 10
+        long_input_length = int(config.max_position_embeddings * 1.5)
+
+        # Inputs
+        x = torch.randn(1, dtype=torch.float32, device=torch_device)  # used exclusively to get the dtype and the device
+        position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
+        position_ids_short = position_ids_short.unsqueeze(0)
+        position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
+        position_ids_long = position_ids_long.unsqueeze(0)
+
+        # Sanity check original RoPE
+        original_rope = LlamaRotaryEmbedding(
+            head_dim,
+            max_position_embeddings=config.max_position_embeddings,
+            base=config.rope_theta,
+        ).to(torch_device)
+        original_cos_short, original_sin_short = original_rope(x, position_ids_short)
+        original_cos_long, original_sin_long = original_rope(x, position_ids_long)
+        torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
+        torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :])
+
+        # Sanity check linear RoPE scaling
+        # New position "x" should match original position with index "x/scaling_factor"
+        linear_scaling_rope = LlamaLinearScalingRotaryEmbedding(
+            head_dim,
+            max_position_embeddings=config.max_position_embeddings,
+            base=config.rope_theta,
+            scaling_factor=scaling_factor,
+        ).to(torch_device)
+        linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
+        linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
+        torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
+        torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :])
+        for new_position in range(0, long_input_length, scaling_factor):
+            original_position = int(new_position // scaling_factor)
+            torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :])
+            torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :])
+
+        # Sanity check Dynamic NTK RoPE scaling
+        # Scaling should only be observed after a long input is fed. We can observe that the frequencies decrease
+        # with scaling_factor (or that `inv_freq` decreases)
+        ntk_scaling_rope = LlamaDynamicNTKScalingRotaryEmbedding(
+            head_dim,
+            max_position_embeddings=config.max_position_embeddings,
+            base=config.rope_theta,
+            scaling_factor=scaling_factor,
+        ).to(torch_device)
+        ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
+        ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
+        torch.testing.assert_close(ntk_cos_short, original_cos_short)
+        torch.testing.assert_close(ntk_sin_short, original_sin_short)
+        with self.assertRaises(AssertionError):
+            torch.testing.assert_close(ntk_cos_long, original_cos_long)
+        with self.assertRaises(AssertionError):
+            torch.testing.assert_close(ntk_sin_long, original_sin_long)
+        self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
+
     @require_flash_attn
     @require_torch_gpu
     @require_bitsandbytes
diff --git a/tests/models/persimmon/test_modeling_persimmon.py b/tests/models/persimmon/test_modeling_persimmon.py
index 864db992772772..79cee8a64863cb 100644
--- a/tests/models/persimmon/test_modeling_persimmon.py
+++ b/tests/models/persimmon/test_modeling_persimmon.py
@@ -45,6 +45,11 @@
         PersimmonForSequenceClassification,
         PersimmonModel,
     )
+    from transformers.models.persimmon.modeling_persimmon import (
+        PersimmonDynamicNTKScalingRotaryEmbedding,
+        PersimmonLinearScalingRotaryEmbedding,
+        PersimmonRotaryEmbedding,
+    )


 # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester with Llama->Persimmon
@@ -365,8 +370,8 @@ def test_save_load_fast_init_from_base(self):
         pass

     @parameterized.expand([("linear",), ("dynamic",)])
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling with Llama->Persimmon
-    def test_model_rope_scaling(self, scaling_type):
+    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->Persimmon
+    def test_model_rope_scaling_from_config(self, scaling_type):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
         short_input = ids_tensor([1, 10], config.vocab_size)
         long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
@@ -396,6 +401,66 @@ def test_model_rope_scaling(self, scaling_type):
         # The output should be different for long inputs
         self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))

+    # Copied from tests.models.falcon.test_modeling_falcon.FalconModelTest.test_model_rope_scaling with Falcon->Persimmon
+    def test_model_rope_scaling(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        hidden_size = config.hidden_size
+        num_heads = config.num_attention_heads
+        head_dim = hidden_size // num_heads
+        scaling_factor = 10
+        short_input_length = 10
+        long_input_length = int(config.max_position_embeddings * 1.5)
+
+        # Inputs
+        x = torch.randn(1, dtype=torch.float32, device=torch_device)  # used exclusively to get the dtype and the device
+
+        # Sanity check original RoPE
+        original_rope = PersimmonRotaryEmbedding(
+            head_dim,
+            max_position_embeddings=config.max_position_embeddings,
+            base=config.rope_theta,
+        ).to(torch_device)
+        original_cos_short, original_sin_short = original_rope(x, short_input_length)
+        original_cos_long, original_sin_long = original_rope(x, long_input_length)
+        torch.testing.assert_close(original_cos_short, original_cos_long[:short_input_length, :])
+        torch.testing.assert_close(original_sin_short, original_sin_long[:short_input_length, :])
+
+        # Sanity check linear RoPE scaling
+        # New position "x" should match original position with index "x/scaling_factor"
+        linear_scaling_rope = PersimmonLinearScalingRotaryEmbedding(
+            head_dim,
+            max_position_embeddings=config.max_position_embeddings,
+            base=config.rope_theta,
+            scaling_factor=scaling_factor,
+        ).to(torch_device)
+        linear_cos_short, linear_sin_short = linear_scaling_rope(x, short_input_length)
+        linear_cos_long, linear_sin_long = linear_scaling_rope(x, long_input_length)
+        torch.testing.assert_close(linear_cos_short, linear_cos_long[:short_input_length, :])
+        torch.testing.assert_close(linear_sin_short, linear_sin_long[:short_input_length, :])
+        for new_position in range(0, long_input_length, scaling_factor):
+            original_position = int(new_position // scaling_factor)
+            torch.testing.assert_close(linear_cos_long[new_position, :], original_cos_long[original_position, :])
+            torch.testing.assert_close(linear_sin_long[new_position, :], original_sin_long[original_position, :])
+
+        # Sanity check Dynamic NTK RoPE scaling
+        # Scaling should only be observed after a long input is fed. We can observe that the frequencies decrease
+        # with scaling_factor (or that `inv_freq` decreases)
+        ntk_scaling_rope = PersimmonDynamicNTKScalingRotaryEmbedding(
+            head_dim,
+            max_position_embeddings=config.max_position_embeddings,
+            base=config.rope_theta,
+            scaling_factor=scaling_factor,
+        ).to(torch_device)
+        ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, short_input_length)
+        ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, long_input_length)
+        torch.testing.assert_close(ntk_cos_short, original_cos_short)
+        torch.testing.assert_close(ntk_sin_short, original_sin_short)
+        with self.assertRaises(AssertionError):
+            torch.testing.assert_close(ntk_cos_long, original_cos_long)
+        with self.assertRaises(AssertionError):
+            torch.testing.assert_close(ntk_sin_long, original_sin_long)
+        self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
+

 @require_torch
 class PersimmonIntegrationTest(unittest.TestCase):
diff --git a/tests/models/phi/test_modeling_phi.py b/tests/models/phi/test_modeling_phi.py
index d69bbb32c1a682..e3c145bfa268ca 100644
--- a/tests/models/phi/test_modeling_phi.py
+++ b/tests/models/phi/test_modeling_phi.py
@@ -19,8 +19,9 @@
 import unittest

 import pytest
+from parameterized import parameterized

-from transformers import PhiConfig, is_torch_available
+from transformers import PhiConfig, is_torch_available, set_seed
 from transformers.testing_utils import (
     require_bitsandbytes,
     require_flash_attn,
@@ -46,6 +47,11 @@
         PhiForTokenClassification,
         PhiModel,
     )
+    from transformers.models.phi.modeling_phi import (
+        PhiDynamicNTKScalingRotaryEmbedding,
+        PhiLinearScalingRotaryEmbedding,
+        PhiRotaryEmbedding,
+    )


 class PhiModelTester:
@@ -360,6 +366,98 @@ def test_phi_sequence_classification_model_for_multi_label(self):
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
         self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

+    @parameterized.expand([("linear",), ("dynamic",)])
+    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->Phi
+    def test_model_rope_scaling_from_config(self, scaling_type):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        short_input = ids_tensor([1, 10], config.vocab_size)
+        long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
+
+        set_seed(42)  # Fixed seed at init time so the two models get the same random weights
+        original_model = PhiModel(config)
+        original_model.to(torch_device)
+        original_model.eval()
+        original_short_output = original_model(short_input).last_hidden_state
+        original_long_output = original_model(long_input).last_hidden_state
+
+        set_seed(42)  # Fixed seed at init time so the two models get the same random weights
+        config.rope_scaling = {"type": scaling_type, "factor": 10.0}
+        scaled_model = PhiModel(config)
+        scaled_model.to(torch_device)
+        scaled_model.eval()
+        scaled_short_output = scaled_model(short_input).last_hidden_state
+        scaled_long_output = scaled_model(long_input).last_hidden_state
+
+        # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
+        # maximum sequence length, so the outputs for the short input should match.
+        if scaling_type == "dynamic":
+            self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
+        else:
+            self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
+
+        # The output should be different for long inputs
+        self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
+
+    # Copied from tests.models.falcon.test_modeling_falcon.FalconModelTest.test_model_rope_scaling with Falcon->Phi
+    def test_model_rope_scaling(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        hidden_size = config.hidden_size
+        num_heads = config.num_attention_heads
+        head_dim = hidden_size // num_heads
+        scaling_factor = 10
+        short_input_length = 10
+        long_input_length = int(config.max_position_embeddings * 1.5)
+
+        # Inputs
+        x = torch.randn(1, dtype=torch.float32, device=torch_device)  # used exclusively to get the dtype and the device
+
+        # Sanity check original RoPE
+        original_rope = PhiRotaryEmbedding(
+            head_dim,
+            max_position_embeddings=config.max_position_embeddings,
+            base=config.rope_theta,
+        ).to(torch_device)
+        original_cos_short, original_sin_short = original_rope(x, short_input_length)
+        original_cos_long, original_sin_long = original_rope(x, long_input_length)
+        torch.testing.assert_close(original_cos_short, original_cos_long[:short_input_length, :])
+        torch.testing.assert_close(original_sin_short, original_sin_long[:short_input_length, :])
+
+        # Sanity check linear RoPE scaling
+        # New position "x" should match original position with index "x/scaling_factor"
+        linear_scaling_rope = PhiLinearScalingRotaryEmbedding(
+            head_dim,
+            max_position_embeddings=config.max_position_embeddings,
+            base=config.rope_theta,
+            scaling_factor=scaling_factor,
+        ).to(torch_device)
+        linear_cos_short, linear_sin_short = linear_scaling_rope(x, short_input_length)
+        linear_cos_long, linear_sin_long = linear_scaling_rope(x, long_input_length)
+        torch.testing.assert_close(linear_cos_short, linear_cos_long[:short_input_length, :])
+        torch.testing.assert_close(linear_sin_short, linear_sin_long[:short_input_length, :])
+        for new_position in range(0, long_input_length, scaling_factor):
+            original_position = int(new_position // scaling_factor)
+            torch.testing.assert_close(linear_cos_long[new_position, :], original_cos_long[original_position, :])
+            torch.testing.assert_close(linear_sin_long[new_position, :], original_sin_long[original_position, :])
+
+        # Sanity check Dynamic NTK RoPE scaling
+        # Scaling should only be observed after a long input is fed. We can observe that the frequencies decrease
+        # with scaling_factor (or that `inv_freq` decreases)
+        ntk_scaling_rope = PhiDynamicNTKScalingRotaryEmbedding(
+            head_dim,
+            max_position_embeddings=config.max_position_embeddings,
+            base=config.rope_theta,
+            scaling_factor=scaling_factor,
+        ).to(torch_device)
+        ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, short_input_length)
+        ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, long_input_length)
+        torch.testing.assert_close(ntk_cos_short, original_cos_short)
+        torch.testing.assert_close(ntk_sin_short, original_sin_short)
+        with self.assertRaises(AssertionError):
+            torch.testing.assert_close(ntk_cos_long, original_cos_long)
+        with self.assertRaises(AssertionError):
+            torch.testing.assert_close(ntk_sin_long, original_sin_long)
+        self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
+
     @require_flash_attn
     @require_torch_gpu
     @require_bitsandbytes
diff --git a/tests/models/stablelm/test_modeling_stablelm.py b/tests/models/stablelm/test_modeling_stablelm.py
index 2497dfc3eee6c4..64f828825c44fa 100644
--- a/tests/models/stablelm/test_modeling_stablelm.py
+++ b/tests/models/stablelm/test_modeling_stablelm.py
@@ -44,6 +44,11 @@
         StableLmForSequenceClassification,
         StableLmModel,
     )
+    from transformers.models.stablelm.modeling_stablelm import (
+        StableLmDynamicNTKScalingRotaryEmbedding,
+        StableLmLinearScalingRotaryEmbedding,
+        StableLmRotaryEmbedding,
+    )


 # Copied from transformers.tests.models.persimmon.test_modeling_persimmon.PersimmonModelTester with Persimmon -> StableLm
@@ -351,7 +356,8 @@ def test_stablelm_sequence_classification_model_for_multi_label(self):
         self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

     @parameterized.expand([("linear",), ("dynamic",)])
-    def test_model_rope_scaling(self, scaling_type):
+    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->StableLm
+    def test_model_rope_scaling_from_config(self, scaling_type):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
         short_input = ids_tensor([1, 10], config.vocab_size)
         long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
@@ -381,6 +387,66 @@ def test_model_rope_scaling(self, scaling_type):
         # The output should be different for long inputs
         self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))

+    # Copied from tests.models.falcon.test_modeling_falcon.FalconModelTest.test_model_rope_scaling with Falcon->StableLm
+    def test_model_rope_scaling(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        hidden_size = config.hidden_size
+        num_heads = config.num_attention_heads
+        head_dim = hidden_size // num_heads
+        scaling_factor = 10
+        short_input_length = 10
+        long_input_length = int(config.max_position_embeddings * 1.5)
+
+        # Inputs
+        x = torch.randn(1, dtype=torch.float32, device=torch_device)  # used exclusively to get the dtype and the device
+
+        # Sanity check original RoPE
+        original_rope = StableLmRotaryEmbedding(
+            head_dim,
+            max_position_embeddings=config.max_position_embeddings,
+            base=config.rope_theta,
+        ).to(torch_device)
+        original_cos_short, original_sin_short = original_rope(x, short_input_length)
+        original_cos_long, original_sin_long = original_rope(x, long_input_length)
+        torch.testing.assert_close(original_cos_short, original_cos_long[:short_input_length, :])
+        torch.testing.assert_close(original_sin_short, original_sin_long[:short_input_length, :])
+
+        # Sanity check linear RoPE scaling
+        # New position "x" should match original position with index "x/scaling_factor"
+        linear_scaling_rope = StableLmLinearScalingRotaryEmbedding(
+            head_dim,
+            max_position_embeddings=config.max_position_embeddings,
+            base=config.rope_theta,
+            scaling_factor=scaling_factor,
+        ).to(torch_device)
+        linear_cos_short, linear_sin_short = linear_scaling_rope(x, short_input_length)
+        linear_cos_long, linear_sin_long = linear_scaling_rope(x, long_input_length)
+        torch.testing.assert_close(linear_cos_short, linear_cos_long[:short_input_length, :])
+        torch.testing.assert_close(linear_sin_short, linear_sin_long[:short_input_length, :])
+        for new_position in range(0, long_input_length, scaling_factor):
+            original_position = int(new_position // scaling_factor)
+            torch.testing.assert_close(linear_cos_long[new_position, :], original_cos_long[original_position, :])
+            torch.testing.assert_close(linear_sin_long[new_position, :], original_sin_long[original_position, :])
+
+        # Sanity check Dynamic NTK RoPE scaling
+        # Scaling should only be observed after a long input is fed. We can observe that the frequencies decrease
+        # with scaling_factor (or that `inv_freq` decreases)
+        ntk_scaling_rope = StableLmDynamicNTKScalingRotaryEmbedding(
+            head_dim,
+            max_position_embeddings=config.max_position_embeddings,
+            base=config.rope_theta,
+            scaling_factor=scaling_factor,
+        ).to(torch_device)
+        ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, short_input_length)
+        ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, long_input_length)
+        torch.testing.assert_close(ntk_cos_short, original_cos_short)
+        torch.testing.assert_close(ntk_sin_short, original_sin_short)
+        with self.assertRaises(AssertionError):
+            torch.testing.assert_close(ntk_cos_long, original_cos_long)
+        with self.assertRaises(AssertionError):
+            torch.testing.assert_close(ntk_sin_long, original_sin_long)
+        self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
+

 @require_torch
 class StableLmModelIntegrationTest(unittest.TestCase):