Enable flag to not pass PAD tokens in ffwd #775
Conversation
Can you run main vs. your branch with the use_pad_tok_in_ffwd flag vs. your branch without the use_pad_tok_in_ffwd flag?
Can you add a test that checks numerical equivalence of the computation with and without the flag? It might be off by a bit because of numerics, but let's see.
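A rough sketch of what such a check could look like (illustrative only: a standalone toy ffwd rather than the repo's model or test harness), comparing the padded path and the packed path only at non-PAD positions within a small tolerance:

```python
import torch
import torch.nn as nn


def test_ffwd_pad_equivalence():
    """Dropping PAD tokens before the ffwd should not change the outputs
    at non-PAD positions (up to floating-point tolerance)."""
    torch.manual_seed(0)
    b, s, d = 2, 8, 16
    ffwd = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))
    x = torch.randn(b, s, d)
    # 1 = real token, 0 = PAD; pad out the tail of each sequence.
    attention_mask = torch.tensor([[1] * 6 + [0] * 2, [1] * 3 + [0] * 5])

    # Path 1: run the ffwd on every position, PAD included.
    out_full = ffwd(x)

    # Path 2: run the ffwd only on real tokens, then scatter back.
    keep = attention_mask.reshape(-1).bool()
    flat = x.reshape(b * s, d)
    out_packed = torch.zeros_like(flat)
    out_packed[keep] = ffwd(flat[keep])
    out_packed = out_packed.reshape(b, s, d)

    # Compare only the non-PAD positions; PAD outputs are discarded anyway.
    mask = attention_mask.bool().unsqueeze(-1)
    assert torch.allclose(out_full.masked_select(mask),
                          out_packed.masked_select(mask),
                          rtol=1e-5, atol=1e-6)
```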
Minor nits
Co-authored-by: Mihir Patel <mihir.v.patel7@gmail.com>
Thanks!
Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>
This PR does two things:

1. Changes the `attn_bias` function to always return the `attention_mask`.
2. Adds a `use_pad_tok_in_ffwd` flag so that, when disabled, PAD tokens are removed before the `forward` call on the ffwd network and the PAD positions are re-added afterwards.

Loss curves on a fully randomly initialized network: [loss-curve figure omitted]

We also get slightly higher throughput from this when there are PAD tokens in our dataset (and no degradation when compared to `main` with `attn_impl: triton`).

wandb: https://wandb.ai/mosaic-ml/padding_check?workspace=user-bcui
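For intuition, here is a minimal sketch of the idea, not the code added in this PR (the helper name, signature, and zero-filling of PAD positions are assumptions): because the ffwd is applied position-wise, the outputs at non-PAD positions are the same whether or not PAD tokens pass through it, so the real tokens can be packed together for a smaller matmul and the PAD positions re-inserted afterwards.

```python
import torch
import torch.nn as nn


def apply_ffwd(ffwd: nn.Module, x: torch.Tensor, attention_mask: torch.Tensor,
               use_pad_tok_in_ffwd: bool = True) -> torch.Tensor:
    """Illustrative sketch of the flag's effect.

    x: (batch, seq, d_model); attention_mask: (batch, seq) with 1 for real
    tokens and 0 for PAD.
    """
    if use_pad_tok_in_ffwd:
        # Default behaviour: the position-wise ffwd also runs on PAD tokens.
        return ffwd(x)
    # Flag disabled: pack the real tokens, run the ffwd on them only, then
    # re-insert the PAD positions (zero-filled here) so shapes are unchanged.
    b, s, d = x.shape
    keep = attention_mask.reshape(b * s).bool()
    flat = x.reshape(b * s, d)
    out = torch.zeros_like(flat)
    out[keep] = ffwd(flat[keep])
    return out.reshape(b, s, d)
```

The throughput gain comes from the smaller effective token count in the two ffwd matmuls when the batch contains a lot of padding; the values re-inserted at PAD positions are typically masked out of the loss anyway.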