diff --git a/src/transformers/modeling_rope_utils.py b/src/transformers/modeling_rope_utils.py
index c09664d688c3b1..5788238b58b3db 100644
--- a/src/transformers/modeling_rope_utils.py
+++ b/src/transformers/modeling_rope_utils.py
@@ -487,10 +487,11 @@ def _validate_longrope_parameters(config: PretrainedConfig):
             logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
 
         attention_factor = rope_scaling.get("attention_factor")
-        if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0:
-            logger.warning(
-                f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
-            )
+        if attention_factor is not None:
+            if not isinstance(attention_factor, float) or attention_factor < 0.0:
+                logger.warning(
+                    f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
+                )
 
 
 def _validate_llama3_parameters(config: PretrainedConfig):
diff --git a/tests/utils/test_modeling_rope_utils.py b/tests/utils/test_modeling_rope_utils.py
index cfc648a71d2ecb..a1d1fd6b922ab3 100644
--- a/tests/utils/test_modeling_rope_utils.py
+++ b/tests/utils/test_modeling_rope_utils.py
@@ -330,6 +330,16 @@ def test_longrope_rope_numerically(self):
         _, attention_scale = rope_fn(config=config, device=torch_device, seq_len=1)
         self.assertEqual(attention_scale, 0.5)
 
+        config.rope_scaling = {
+            "rope_type": "longrope",
+            "factor": factor,
+            "short_factor": short_factor,
+            "long_factor": long_factor,
+        }
+        self.assertEqual(config.rope_scaling.get("attention_factor"), None)
+        # Verify that "TypeError: '<' not supported between instances of 'NoneType' and 'int'" is not raised.
+        rope_config_validation(config)
+
         # Check 2: Factor == 1.0 -> short factor is applied to the default frequencies
         factor = 1.0
         config.rope_scaling = {
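
For context, the bug fixed above comes from Python operator precedence: `and` binds tighter than `or`, so the old condition parsed as `(attention_factor is not None and not isinstance(...)) or attention_factor < 0`, and the comparison ran even when `attention_factor` was `None`. The following is a minimal standalone sketch of the pitfall (not transformers code; `broken_check`/`fixed_check` are illustrative names):

```python
# Minimal sketch of the precedence pitfall fixed in the diff above.
# `A and B or C` parses as `(A and B) or C`, so C is evaluated even when A is False.


def broken_check(attention_factor):
    # Mirrors the old condition: raises TypeError for None, because once the
    # `and` clause short-circuits to False, `None < 0` is still evaluated.
    return attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0


def fixed_check(attention_factor):
    # Mirrors the new nested form: the comparison only runs for non-None values.
    if attention_factor is not None:
        return not isinstance(attention_factor, float) or attention_factor < 0.0
    return False


if __name__ == "__main__":
    assert fixed_check(None) is False    # None is allowed, no TypeError
    assert fixed_check(-1.0) is True     # negative float -> would trigger the warning
    assert fixed_check("0.5") is True    # wrong type -> would trigger the warning
    try:
        broken_check(None)
    except TypeError as exc:
        print(f"old condition raised: {exc}")  # '<' not supported between 'NoneType' and 'int'
```

The added test exercises the same path through `rope_config_validation` by building a `longrope` config with no `attention_factor`, which previously hit this TypeError.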