Skip to content

Commit

Permalink
Update _embedding.py RopeEmbeddings (#735)
Browse files Browse the repository at this point in the history
* Update _embedding.py RopeEmbeddings

It seems we forgot to make the `theta` variable public. This way, if you wanted to implement LLaMa3, you'd have a problem because they use a `theta` value of 50000.0.

* pch

* default args

* theta was actually used; fixed now
  • Loading branch information
Artur-Galstyan authored May 23, 2024
1 parent 2e5a25d commit 1f7908b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
11 changes: 8 additions & 3 deletions equinox/nn/_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def process_heads(
"""

embedding_size: int = field(static=True)
theta: float = field(static=True, default=10_000.0)

def __check_init__(self):
if self.embedding_size < 0:
Expand All @@ -175,7 +176,7 @@ def rotate_half(x: Float[Array, "seq_length embedding_size"]):

@staticmethod
def precompute_freqs_cis(
embedding_size: int, end: int, theta: float = 10000.0
embedding_size: int, end: int, theta: float
) -> Complex[Array, "end half_of_embedding_size"]:
freqs = 1.0 / (
theta
Expand Down Expand Up @@ -220,12 +221,16 @@ def __call__(
freqs_cis = internal_rope_embedding_cache[embedding_size]
freqs_cis_seq_len, _ = freqs_cis.shape
if seq_len > freqs_cis_seq_len:
freqs_cis = self.precompute_freqs_cis(embedding_size, seq_len)
freqs_cis = self.precompute_freqs_cis(
embedding_size, seq_len, self.theta
)
internal_rope_embedding_cache[embedding_size] = freqs_cis
else:
freqs_cis = freqs_cis[:seq_len]
else:
freqs_cis = self.precompute_freqs_cis(embedding_size, seq_len)
freqs_cis = self.precompute_freqs_cis(
embedding_size, seq_len, self.theta
)
internal_rope_embedding_cache[embedding_size] = freqs_cis

freqs_real = jnp.tile(freqs_cis.real, (1, 2))
Expand Down
3 changes: 2 additions & 1 deletion tests/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1358,6 +1358,7 @@ def test_rope_embeddings_freqs_cis():
# values are generated using
# Metas Rope embedding code. See this gist which generates these
# expected values: https://gist.github.com/Artur-Galstyan/8d0bb5743f00650aa6c0d7427595a0ff
theta = 10_000.0
expected_freqs_cis = jnp.array(
[
[1.0000 + 0.0000j, 1.0000 + 0.0000j, 1.0000 + 0.0000j, 1.0000 + 0.0000j],
Expand All @@ -1381,7 +1382,7 @@ def test_rope_embeddings_freqs_cis():
embedding_size = 8
seq_length = 16
freqs_cis = eqx.nn.RotaryPositionalEmbedding.precompute_freqs_cis(
embedding_size, seq_length
embedding_size, seq_length, theta
)
assert jnp.allclose(freqs_cis, expected_freqs_cis, atol=1e-4)

Expand Down

0 comments on commit 1f7908b

Please sign in to comment.