Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adds softmax_scale to flash attention #209

Merged
merged 1 commit into from
Mar 2, 2023

Conversation

codestar12
Copy link
Contributor

extends the option to scale softmax to flash attention in addition to triton flash attention

Copy link
Contributor

@vchiley vchiley left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

Hopefully you resolve the GPU count issue described offline

Copy link
Contributor

@bcui19 bcui19 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks for adding this!

@codestar12 codestar12 merged commit 0a800ab into mosaicml:main Mar 2, 2023
self.d_model = cfg.d_model
self.n_heads = cfg.n_heads

if self.attn_qk_ln or self.clip_qkv:
if self.attn_qk_ln or self.clip_qkv or self.softmax_scale:
self.W_qkv = nn.Linear(self.d_model,
3 * self.d_model,
bias=True,
device=device)
self.inner_attn = FlashAttention(attention_dropout=cfg.attn_pdrop,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see we can now scale the attention by any custom value...what is the best value and how should we keep track of it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1/srqt(d/n_heads) is standard.
1/(d/n_heads) is the recommended muP mod.
Since its part of the model config, it'll get dumped into wandb if we ever need to check what we used.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants