
Commit 3d7131b: fix neo
rasbt committed Oct 2, 2024
1 parent: 6883fdc
Showing 1 changed file: tests/test_rope.py (15 additions, 16 deletions).
@@ -10,27 +10,26 @@
 from litgpt.model import apply_rope, build_rope_cache


-# @torch.inference_mode()
-# def test_rope_gptneox():
-#     bs, seq_len, n_head, n_embed = 1, 6, 2, 8
-#     head_size = n_embed // n_head
-#     x = torch.randint(0, 10000, size=(bs, n_head, seq_len, head_size)).float()
-#     position_ids = torch.arange(seq_len).unsqueeze(0)
+@torch.inference_mode()
+def test_rope_gptneox():
+    bs, seq_len, n_head, n_embed = 1, 6, 2, 8
+    head_size = n_embed // n_head
+    x = torch.randint(0, 10000, size=(bs, n_head, seq_len, head_size)).float()
+    position_ids = torch.arange(seq_len).unsqueeze(0)

-#     theirs_rot_emb = GPTNeoXRotaryEmbedding(head_size, seq_len)
-#     theirs_cos, theirs_sin = theirs_rot_emb(x, position_ids)
+    theirs_rot_emb = GPTNeoXRotaryEmbedding(head_size, max_position_embeddings=seq_len)
+    theirs_cos, theirs_sin = theirs_rot_emb(x, seq_len=seq_len)

-#     ours_cos_cached, ours_sin_cached = build_rope_cache(seq_len, head_size, device=x.device)
-#     # their rope cache has 2 added dimensions and the cos/sin is duplicated
-#     torch.testing.assert_close(ours_cos_cached, theirs_cos.squeeze())
-#     torch.testing.assert_close(ours_sin_cached, theirs_sin.squeeze())
+    ours_cos_cached, ours_sin_cached = build_rope_cache(seq_len, head_size, device=x.device)
+    # their rope cache has 2 added dimensions and the cos/sin is duplicated
+    torch.testing.assert_close(ours_cos_cached, theirs_cos.squeeze())
+    torch.testing.assert_close(ours_sin_cached, theirs_sin.squeeze())

-#     ours_x_rope = apply_rope(x, ours_cos_cached, ours_sin_cached)
-#     theirs_x_rope, _ = apply_rotary_pos_emb_gptneo(x, x, theirs_cos, theirs_sin, position_ids)
-#     torch.testing.assert_close(ours_x_rope, theirs_x_rope)
+    ours_x_rope = apply_rope(x, ours_cos_cached, ours_sin_cached)
+    theirs_x_rope, _ = apply_rotary_pos_emb_gptneo(x, x, theirs_cos, theirs_sin, position_ids)
+    torch.testing.assert_close(ours_x_rope, theirs_x_rope)


 # See https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/config.json for settings
 @torch.inference_mode()
 def test_rope_llama_2():
     head_dim = 64
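Editor's note: the in-diff comment "their rope cache has 2 added dimensions and the cos/sin is duplicated" is the crux of the re-enabled test: the Hugging Face rotary embedding returns cos/sin caches with extra size-1 dimensions, while litgpt's build_rope_cache returns plain (seq_len, head_size) tensors, so the test squeezes theirs before comparing. The following standalone sketch (not part of the commit) illustrates that shape relationship and the standard GPT-NeoX "rotate half" application; the (1, 1, seq_len, head_size) layout and the half-width angle duplication are assumptions read off the in-diff comment, not verified against any particular transformers version.

import torch

seq_len, head_size, base = 6, 4, 10000

# Standard RoPE inverse frequencies for the even feature indices.
theta = 1.0 / (base ** (torch.arange(0, head_size, 2).float() / head_size))
# One angle per (position, frequency) pair: shape (seq_len, head_size // 2).
angles = torch.outer(torch.arange(seq_len).float(), theta)

# "Ours": duplicate the half-width angles across the last dimension,
# giving a (seq_len, head_size) cache, as the in-diff comment describes.
ours_cos = torch.cos(torch.cat((angles, angles), dim=-1))
ours_sin = torch.sin(torch.cat((angles, angles), dim=-1))

# "Theirs" (assumed layout): the same values with two extra size-1
# leading dimensions, i.e. (1, 1, seq_len, head_size).
theirs_cos = ours_cos[None, None]

# squeeze() drops the size-1 dimensions, so the comparison in the test
# reduces to an exact element-wise match.
torch.testing.assert_close(ours_cos, theirs_cos.squeeze())

# Applying the rotation with the "rotate half" convention used by GPT-NeoX
# (a sketch of what apply_rope / apply_rotary_pos_emb compute, not litgpt's code):
def rotate_half(t):
    half = t.shape[-1] // 2
    return torch.cat((-t[..., half:], t[..., :half]), dim=-1)

x = torch.randn(1, 2, seq_len, head_size)  # (bs, n_head, seq_len, head_size)
x_rope = x * ours_cos + rotate_half(x) * ours_sin  # caches broadcast over bs, n_head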
