From 371abe08b5308dea5a59d67da5907d6e3cffd712 Mon Sep 17 00:00:00 2001
From: Alex Sherstinsky
Date: Thu, 29 Aug 2024 00:08:30 -0700
Subject: [PATCH 1/3] Fixing a bug in the way "attention_factor" is validated
 in ROPE utilities.

---
 src/transformers/modeling_rope_utils.py | 9 +++++----
 tests/utils/test_modeling_rope_utils.py | 9 +++++++++
 2 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/src/transformers/modeling_rope_utils.py b/src/transformers/modeling_rope_utils.py
index c09664d688c3b1..5788238b58b3db 100644
--- a/src/transformers/modeling_rope_utils.py
+++ b/src/transformers/modeling_rope_utils.py
@@ -487,10 +487,11 @@ def _validate_longrope_parameters(config: PretrainedConfig):
             logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
 
         attention_factor = rope_scaling.get("attention_factor")
-        if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0:
-            logger.warning(
-                f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
-            )
+        if attention_factor is not None:
+            if not isinstance(attention_factor, float) or attention_factor < 0.0:
+                logger.warning(
+                    f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
+                )
 
 
 def _validate_llama3_parameters(config: PretrainedConfig):
diff --git a/tests/utils/test_modeling_rope_utils.py b/tests/utils/test_modeling_rope_utils.py
index cfc648a71d2ecb..437c68014723a2 100644
--- a/tests/utils/test_modeling_rope_utils.py
+++ b/tests/utils/test_modeling_rope_utils.py
@@ -330,6 +330,15 @@ def test_longrope_rope_numerically(self):
         _, attention_scale = rope_fn(config=config, device=torch_device, seq_len=1)
         self.assertEqual(attention_scale, 0.5)
 
+        config.rope_scaling = {
+            "rope_type": "longrope",
+            "factor": factor,
+            "short_factor": short_factor,
+            "long_factor": long_factor,
+        }
+        self.assertEqual(config.rope_scaling.get("attention_factor"), None)
+        rope_config_validation(config)
+
         # Check 2: Factor == 1.0 -> short factor is applied to the default frequencies
         factor = 1.0
         config.rope_scaling = {

From 087e800256786f75d7392f947c17dfb5fac2dc3b Mon Sep 17 00:00:00 2001
From: Alex Sherstinsky
Date: Thu, 29 Aug 2024 00:10:50 -0700
Subject: [PATCH 2/3] Fixing a bug in the way "attention_factor" is validated
 in ROPE utilities.

---
 tests/utils/test_modeling_rope_utils.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/utils/test_modeling_rope_utils.py b/tests/utils/test_modeling_rope_utils.py
index 437c68014723a2..b25794d7c39b6e 100644
--- a/tests/utils/test_modeling_rope_utils.py
+++ b/tests/utils/test_modeling_rope_utils.py
@@ -337,6 +337,7 @@ def test_longrope_rope_numerically(self):
             "long_factor": long_factor,
         }
         self.assertEqual(config.rope_scaling.get("attention_factor"), None)
+        # Verify that "TypeError: '<' not supported between instances of 'NoneType' and 'int'"" is not raised.
         rope_config_validation(config)
 
         # Check 2: Factor == 1.0 -> short factor is applied to the default frequencies

From 1c41e8b2a6b6d88a25b12a51cc98c13ee10d0a27 Mon Sep 17 00:00:00 2001
From: Alex Sherstinsky
Date: Thu, 29 Aug 2024 00:22:48 -0700
Subject: [PATCH 3/3] Fixing a bug in the way "attention_factor" is validated
 in ROPE utilities.

---
 tests/utils/test_modeling_rope_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/utils/test_modeling_rope_utils.py b/tests/utils/test_modeling_rope_utils.py
index b25794d7c39b6e..a1d1fd6b922ab3 100644
--- a/tests/utils/test_modeling_rope_utils.py
+++ b/tests/utils/test_modeling_rope_utils.py
@@ -337,7 +337,7 @@ def test_longrope_rope_numerically(self):
             "long_factor": long_factor,
         }
         self.assertEqual(config.rope_scaling.get("attention_factor"), None)
-        # Verify that "TypeError: '<' not supported between instances of 'NoneType' and 'int'"" is not raised.
+        # Verify that "TypeError: '<' not supported between instances of 'NoneType' and 'int'" is not raised.
         rope_config_validation(config)
 
         # Check 2: Factor == 1.0 -> short factor is applied to the default frequencies
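
For reference, the root cause fixed in PATCH 1/3 is Python operator precedence: `and` binds more tightly than `or`, so the original condition parsed as `(attention_factor is not None and not isinstance(attention_factor, float)) or (attention_factor < 0)`. When `attention_factor` was `None`, the left clause was `False`, so Python still evaluated `None < 0` and raised the `TypeError` quoted in the test comment. Below is a minimal standalone sketch of the failure mode and the fix; the `validate_buggy`/`validate_fixed` helper names are illustrative only, not part of the transformers API.

```python
def validate_buggy(attention_factor):
    # Parses as: (is not None and not isinstance(...)) or (attention_factor < 0).
    # With attention_factor=None the left clause is False, so `None < 0` runs
    # and raises TypeError.
    if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0:
        print(f"invalid attention_factor: {attention_factor}")


def validate_fixed(attention_factor):
    # Nesting the checks ensures the `< 0.0` comparison only runs when a value
    # is actually present, mirroring the patch above.
    if attention_factor is not None:
        if not isinstance(attention_factor, float) or attention_factor < 0.0:
            print(f"invalid attention_factor: {attention_factor}")


validate_fixed(None)   # no output: a missing attention_factor is accepted
validate_fixed(-1.0)   # prints the warning, as intended
try:
    validate_buggy(None)
except TypeError as e:
    print(e)  # '<' not supported between instances of 'NoneType' and 'int'
```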