Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use torch.repeat instead of expand on key & value in Triton MQA to prevent NaNs with certain h_dims #442

Merged
merged 3 commits into from
Jul 8, 2023

Conversation

sashaDoubov
Copy link
Contributor

@sashaDoubov sashaDoubov commented Jul 7, 2023

For h_dim=8, we see NaNs due to a .expand of the key value tensors, which can be resolved with .repeat. While h_dim=8 is an edge case, we are not sure if there are other cases of h_dims for which this might be problematic, or if there might be silent failures.

Note that this does come at a performance hit, see MFU for 7b, so it may be desirable to revert this change in the future.

(blue curve: .expand() green curve .repeat() red curve .expand().clone())
image

@sashaDoubov sashaDoubov merged commit 86a99e2 into mosaicml:main Jul 8, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants