Bugfix/alexsherstinsky/fix none check for attention factor in rope scaling 2024 08 28 0 (huggingface#33188)

* Fixing a bug in the way "attention_factor" is validated in ROPE utilities.
alexsherstinsky authored and BernardZach committed Dec 6, 2024
1 parent c9def3b commit 166d86c
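The root cause is Python operator precedence: `and` binds more tightly than `or`, so the old condition parses as `(attention_factor is not None and not isinstance(attention_factor, float)) or attention_factor < 0`. When the optional "attention_factor" field is absent, the left operand of `or` is False and `None < 0` is still evaluated, raising a TypeError. A minimal standalone sketch of the pitfall and the fix (not the repository code itself):

    attention_factor = None  # the optional field is absent

    # Old condition: parses as (A and B) or C, so `attention_factor < 0`
    # is evaluated even when attention_factor is None.
    try:
        if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0:
            print("invalid attention_factor")
    except TypeError as err:
        print(err)  # '<' not supported between instances of 'NoneType' and 'int'

    # Fixed condition: the type/range check only runs when the field is set.
    if attention_factor is not None:
        if not isinstance(attention_factor, float) or attention_factor < 0.0:
            print("invalid attention_factor")

Nesting the `is not None` guard, rather than just parenthesizing the old expression, makes the short-circuit explicit.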
Showing 2 changed files with 15 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/transformers/modeling_rope_utils.py
@@ -487,10 +487,11 @@ def _validate_longrope_parameters(config: PretrainedConfig):
             logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
 
         attention_factor = rope_scaling.get("attention_factor")
-        if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0:
-            logger.warning(
-                f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
-            )
+        if attention_factor is not None:
+            if not isinstance(attention_factor, float) or attention_factor < 0.0:
+                logger.warning(
+                    f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
+                )
 
 
 def _validate_llama3_parameters(config: PretrainedConfig):

10 changes: 10 additions & 0 deletions tests/utils/test_modeling_rope_utils.py
@@ -330,6 +330,16 @@ def test_longrope_rope_numerically(self):
         _, attention_scale = rope_fn(config=config, device=torch_device, seq_len=1)
         self.assertEqual(attention_scale, 0.5)
 
+        config.rope_scaling = {
+            "rope_type": "longrope",
+            "factor": factor,
+            "short_factor": short_factor,
+            "long_factor": long_factor,
+        }
+        self.assertEqual(config.rope_scaling.get("attention_factor"), None)
+        # Verify that "TypeError: '<' not supported between instances of 'NoneType' and 'int'" is not raised.
+        rope_config_validation(config)
+
         # Check 2: Factor == 1.0 -> short factor is applied to the default frequencies
         factor = 1.0
         config.rope_scaling = {
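For reference, a hedged repro sketch of the user-facing failure the new test guards against; `LlamaConfig` and the factor-list lengths are illustrative assumptions, while `rope_config_validation` is the helper exercised by the test above:

    from transformers import LlamaConfig
    from transformers.modeling_rope_utils import rope_config_validation

    # A longrope config that omits the optional "attention_factor" field.
    config = LlamaConfig()
    config.rope_scaling = {
        "rope_type": "longrope",
        "factor": 2.0,
        "short_factor": [1.0] * 64,  # lengths chosen to match head_dim // 2 (an assumption)
        "long_factor": [2.0] * 64,
    }

    # Before this fix: TypeError: '<' not supported between instances of 'NoneType' and 'int'
    # After this fix: validation simply skips the absent attention_factor.
    rope_config_validation(config)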
