Llama 3.1: Fix incorrect inv_freq assignment (#32330)
fix 💩
gante authored Jul 31, 2024
1 parent 7f552e2 commit b75ad56
Showing 2 changed files with 33 additions and 5 deletions.
src/transformers/modeling_rope_utils.py (4 additions, 4 deletions)
@@ -328,14 +328,14 @@ def _compute_llama3_parameters(
     wavelen = 2 * math.pi / inv_freq
     # wavelen < high_freq_wavelen: do nothing
     # wavelen > low_freq_wavelen: divide by factor
-    inv_freq_new = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
+    inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
     # otherwise: interpolate between the two, using a smooth factor
     smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
-    smoothed_inv_freq = (1 - smooth_factor) * inv_freq_new / factor + smooth_factor * inv_freq_new
+    smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
     is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
-    inv_freq_new = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_new)
+    inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
 
-    return inv_freq, attention_factor
+    return inv_freq_llama, attention_factor
 
 
 # This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
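Below is a minimal standalone sketch (not part of the commit) of the computation the hunk above fixes. The scaling values mirror Llama 3.1's published rope_scaling config and are used purely for illustration. It shows that the rescaled tensor (inv_freq_llama) genuinely differs from the base inv_freq, which is why returning inv_freq turned the llama3 scaling into a silent no-op.

import math

import torch

base, dim = 500000.0, 128  # illustrative: Llama 3.1 8B rope_theta and head dim
factor, low_freq_factor, high_freq_factor, old_context_len = 8.0, 1.0, 4.0, 8192

# base RoPE inverse frequencies
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
wavelen = 2 * math.pi / inv_freq

# long wavelengths: divide by factor; short wavelengths: leave untouched
inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
# medium wavelengths: interpolate between the two regimes with a smooth factor
smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)

print(f"{(inv_freq_llama != inv_freq).sum().item()} of {inv_freq.numel()} frequencies are rescaled")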
tests/models/llama/test_modeling_llama.py (29 additions, 1 deletion)
@@ -22,7 +22,7 @@
 from packaging import version
 from parameterized import parameterized
 
-from transformers import LlamaConfig, StaticCache, is_torch_available, set_seed
+from transformers import AutoTokenizer, LlamaConfig, StaticCache, is_torch_available, set_seed
 from transformers.testing_utils import (
     require_bitsandbytes,
     require_flash_attn,
@@ -718,6 +718,34 @@ def setUpClass(cls):
         # 8 is for A100 / A10 and 7 for T4
         cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
 
+    @slow
+    @require_read_token
+    def test_llama_3_1_hard(self):
+        """
+        An integration test for llama 3.1. It tests against a long output to ensure the subtle numerical differences
+        from llama 3.1's RoPE can be detected.
+        """
+        EXPECTED_TEXT = (
+            "Tell me about the french revolution. The french revolution was a period of radical social and political "
+            "upheaval in France that lasted from 1789 until 1799. It was a time of great change and upheaval, marked "
+            "by the overthrow of the monarchy, the rise of the middle class, and the eventual establishment of the "
+            "First French Republic.\nThe revolution began in 1789 with the Estates-General, a representative "
+            "assembly that had not met since 1614. The Third Estate, which represented the common people, "
+            "demanded greater representation and eventually broke away to form the National Assembly. This marked "
+            "the beginning of the end of the absolute monarchy and the rise of the middle class.\n"
+        )
+
+        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
+        model = LlamaForCausalLM.from_pretrained(
+            "meta-llama/Meta-Llama-3.1-8B-Instruct", device_map="auto", torch_dtype=torch.bfloat16
+        )
+        input_text = ["Tell me about the french revolution."]
+        model_inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
+
+        generated_ids = model.generate(**model_inputs, max_new_tokens=128, do_sample=False)
+        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+        self.assertEqual(generated_text, EXPECTED_TEXT)
+
     @slow
     @require_read_token
     def test_model_7b_logits_bf16(self):

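For a lighter-weight sanity check than the gated-model integration test above, one could call the RoPE parameter functions directly and assert that the "llama3" rule now returns frequencies that differ from the "default" ones; before this fix the two were identical. This is a sketch, not part of the commit: it assumes the internal ROPE_INIT_FUNCTIONS mapping in transformers.modeling_rope_utils and its (config, device) call signature, which may change between versions, and the rope_scaling values mirror Llama 3.1's config.

import torch

from transformers import LlamaConfig
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

# Small Llama-style config carrying Llama 3.1's rope_scaling block (illustrative values)
config = LlamaConfig(
    hidden_size=4096,
    num_attention_heads=32,
    rope_theta=500000.0,
    max_position_embeddings=131072,
    rope_scaling={
        "rope_type": "llama3",
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
    },
)

inv_freq_default, _ = ROPE_INIT_FUNCTIONS["default"](config, device="cpu")
inv_freq_llama3, _ = ROPE_INIT_FUNCTIONS["llama3"](config, device="cpu")

# Before this commit the llama3 branch returned the unscaled tensor, so these matched exactly.
assert not torch.allclose(inv_freq_default, inv_freq_llama3)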