-
-
Notifications
You must be signed in to change notification settings - Fork 150
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: RoPE embeddings do not preserve dtypes #836
Conversation
Btw, another issue: the order of parameters is annoyingly the opposite of the order in Though I suppose that there is about one person who were using RoPE anyways (since transformers in JAX are unusable due to the lack of decent flash attention), so maybe it is not critical. Idk. WDYT? |
@Artur-Galstyan Do you mind to also take a look? It was your implementation, after all. Would appreciate it! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit aside this LGTM! Thank you for the fix..
(Side note, I'm not sure what the minimum useful precision for RoPE is? Given the high-frequency behaviour, I'm wondering if we should expect to always compute part of this in higher precision, e.g. like how softmax should use float32 at minimum.)
Actually that's a very good point! I also wanted to discuss this, but totally forgot lol. In general, indeed, I don't think it makes sense to use low-precision RoPE: the quality degrades quite a bit, and applying RoPE is not really that compute-heavy to care about it. However, giving the user control over what is the internal type feels important. But yeah, I would expect the default way to use RoPE to look like rope = eqx.RotaryPositionalEmbedding(42, dtype=jnp.float32)
query = jnp.array([ ... ], dtype=jnp.bfloat16)
query = rope(query).astype(jnp.bfloat16)
... |
Alright, so I looked at my last comment, and I figured that I can just straight up convert the output to the old dtype, similarly to how it is done in normalization modules. So now we avoid manual conversion. |
@knyazer it LGTM too. The reason I implemented this using complex numbers was because it was actually slightly faster (though not by much if I remember correctly - there should be somewhere a speed comparison table in one of the closed PRs). FYI, there is currently a subtle bug in the way it's currently implemented. I have a PR open #799 to fix that, but it's not 100% done yet. It'd be nice to get this merged so that I can rebase on it in my own PR. |
Thanks for the feedback, @Artur-Galstyan! Sorry about your PR, I didn't check beforehand, so you will probably have a ton of conflicts 😅 About the speed: yea, I would have expected complex to be slightly faster for a bunch of reasons, but I'm pretty sure using a smaller dtype leads to a more significant performance gain. |
@patrick-kidger PTAL |
Awesome! Thanks for the fix, this LGTM. :) |
* wip: fixing RoPE not preserving dtypes * feat: implement dtype support for RoPE and improve docs * fix: docos * fix: remove a forgotten astype * refactor: naming convention + better control flow * fix: actually assign the dtype of the input in the end * fix: address the review
RoPE embeddings use complex as an underlying representation for the frequencies, which leads to it enforcing
float32
dtype when applying to any float type that is lower in the promotion chain thanfloat32
(e.g.float16
/bfloat16
). This PR fixes the issue by adding a new keyword argument,dtype
, and using the frequencies with the correctdtype
throughout the implementation.