
fix: RoPE embeddings do not preserve dtypes #836

Merged: 7 commits into patrick-kidger:dev on Sep 10, 2024

Conversation

@knyazer (Contributor) commented Sep 7, 2024

RoPE embeddings use complex numbers as the underlying representation for the frequencies, which forces the output to float32 whenever the embeddings are applied to a float type lower in the promotion chain than float32 (e.g. float16/bfloat16). This PR fixes the issue by adding a new keyword argument, dtype, and using frequencies of the correct dtype throughout the implementation.
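
As a minimal illustration of the promotion behaviour described above (illustrative only, not the module's actual code): multiplying a bfloat16 array by complex64 frequencies promotes the result to complex64, so taking the real part yields float32.

import jax.numpy as jnp

x = jnp.ones((8,), dtype=jnp.bfloat16)
freqs = jnp.exp(1j * jnp.arange(8))   # complex64 frequencies
y = (x * freqs).real                  # promoted through complex64
print(y.dtype)                        # float32, not bfloat16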

@knyazer (Contributor, Author) commented Sep 7, 2024

Btw, another issue: the order of parameters is annoyingly the opposite of the order in nn.Embedding, but we cannot change it now because of backwards compatibility 😭

Though I suppose there is about one person who was using RoPE anyway (since transformers in JAX are unusable due to the lack of decent flash attention), so maybe it is not critical. Idk. WDYT?

@knyazer marked this pull request as ready for review on September 7, 2024, 19:14
@knyazer (Contributor, Author) commented Sep 7, 2024

@Artur-Galstyan Do you mind also taking a look? It was your implementation, after all. I'd appreciate it!

@knyazer changed the title from "wip: fixing RoPE not preserving dtypes" to "fix: RoPE embeddings do not preserve dtypes" on Sep 7, 2024
@patrick-kidger (Owner) left a comment

Nit aside, this LGTM! Thank you for the fix.

(Side note: I'm not sure what the minimum useful precision for RoPE is. Given the high-frequency behaviour, I'm wondering if we should expect to always compute part of this in higher precision, e.g. like how softmax should use float32 at minimum.)
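
As a rough illustration of the precision concern (not from the PR itself): in the usual RoPE parameterization the highest-frequency feature pair rotates by roughly the raw position in radians, and bfloat16 cannot resolve large positions accurately.

import jax.numpy as jnp

pos = jnp.asarray(10_000.0, dtype=jnp.float32)       # a plausible sequence position
# bfloat16 keeps only ~8 bits of mantissa, so the angle is already off by
# ~16 radians (several full turns) before any rotation is applied:
print(pos.astype(jnp.bfloat16).astype(jnp.float32))  # 9984.0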

@knyazer (Contributor, Author) commented Sep 8, 2024

> Given the high-frequency behaviour, I'm wondering if we should expect to always compute part of this in higher precision, e.g. like how softmax should use float32 at minimum.

Actually that's a very good point! I also wanted to discuss this, but totally forgot lol.

In general, indeed, I don't think it makes sense to use low-precision RoPE: the quality degrades quite a bit, and applying RoPE is not compute-heavy enough for the savings to matter. However, giving the user control over the internal dtype feels important.

But yeah, I would expect the default way to use RoPE to look something like:

rope = eqx.nn.RotaryPositionalEmbedding(42, dtype=jnp.float32)
query = jnp.array([ ... ], dtype=jnp.bfloat16)
query = rope(query).astype(jnp.bfloat16)
...

@knyazer (Contributor, Author) commented Sep 8, 2024

Alright, so I looked at my last comment and figured I could just convert the output back to the old dtype, similarly to how it is done in the normalization modules. So now we avoid the manual conversion.
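
Very roughly, that pattern looks like the sketch below (hypothetical helper name and trivial "frequency transform", not the actual module code): compute in the configured dtype, then cast the result back to the input's dtype.

import jax.numpy as jnp

def apply_rope_like(x, freqs, compute_dtype=jnp.float32):
    # Hypothetical sketch: upcast, apply the frequency transform,
    # then restore the caller's dtype, mirroring how the normalization
    # modules cast their outputs back.
    out = x.astype(compute_dtype) * freqs.astype(compute_dtype)
    return out.astype(x.dtype)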

@Artur-Galstyan (Contributor) commented

@knyazer it LGTM too. The reason I implemented this using complex numbers was that it was actually slightly faster (though not by much, if I remember correctly; there should be a speed-comparison table somewhere in one of the closed PRs).
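
For reference, a rough sketch of the complex-number formulation being referred to (a paraphrase assuming the common interleaved pairing and base 10,000, not the module's actual code): adjacent feature pairs are viewed as complex numbers and rotated by position-dependent angles.

import jax.numpy as jnp

def rope_complex_sketch(x, base=10_000.0):
    # x: (seq_len, dim) with dim even; pair up features as complex numbers
    seq_len, dim = x.shape
    inv_freq = base ** (-jnp.arange(0, dim, 2) / dim)            # (dim/2,)
    angles = jnp.arange(seq_len)[:, None] * inv_freq[None, :]    # (seq_len, dim/2)
    pairs = x[:, ::2] + 1j * x[:, 1::2]                          # complex view of pairs
    rotated = pairs * jnp.exp(1j * angles)                       # rotate each pair
    return jnp.stack([rotated.real, rotated.imag], axis=-1).reshape(seq_len, dim)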

FYI, there is a subtle bug in the way it's currently implemented. I have PR #799 open to fix that, but it's not 100% done yet. It'd be nice to get this one merged so that I can rebase my own PR on it.

@knyazer (Contributor, Author) commented Sep 8, 2024

Thanks for the feedback, @Artur-Galstyan! Sorry about your PR; I didn't check beforehand, so you will probably have a ton of conflicts 😅

About the speed: yeah, I would have expected complex to be slightly faster for a bunch of reasons, but I'm pretty sure using a smaller dtype leads to a more significant performance gain.

@knyazer (Contributor, Author) commented Sep 10, 2024

@patrick-kidger PTAL

@patrick-kidger changed the base branch from main to dev on September 10, 2024, 17:09
@patrick-kidger merged commit 9dc93e2 into patrick-kidger:dev on Sep 10, 2024
2 checks passed
@patrick-kidger (Owner) commented

Awesome! Thanks for the fix, this LGTM. :)

patrick-kidger pushed a commit that referenced this pull request Sep 14, 2024
* wip: fixing RoPE not preserving dtypes

* feat: implement dtype support for RoPE and improve docs

* fix: docos

* fix: remove a forgotten astype

* refactor: naming convention + better control flow

* fix: actually assign the dtype of the input in the end

* fix: address the review
@patrick-kidger mentioned this pull request on Sep 14, 2024