Optimize MultiHeadAttention forward method #3550
Conversation
Avoid duplicated calculation of key-value pairs when incr_state is passed and state_kv is True
Hi @dejk-alg! Thank you for your pull request and welcome to our community.

Action Required
In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process
In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. If you have received this in error or have any questions, please contact us at cla@fb.com. Thanks!
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Facebook open source project. Thanks!
Hah, awesome! Thanks! Do you think you could add some comments explaining this logic?
Tests are flaky but just rerunning to make sure this isn't catching on something wild.
heads up @EricMichaelSmith, this might help with latency on [internal thing]
On a Volta with bs 1, with Blender 3B: [timing figures not recoverable from the page]. On CPU with bs 1, with a Taskmaster-2 BART model: [timing figures not recoverable from the page]. Real nice find!
benchmark script i used:

```python
#!/usr/bin/env python3
import timeit

import tqdm

from parlai.core.agents import create_agent_from_model_file

# Load the pretrained Blender 3B agent from the model zoo.
agent = create_agent_from_model_file('zoo:blender/blender_3B/model')

start = timeit.default_timer()
for i in tqdm.tqdm(range(50)):
    agent.observe(
        {'text': 'Can you find me a mexican restaurant for 2?', 'episode_done': True}
    )
    print(agent.act()['text'])
    agent.reset()
end = timeit.default_timer()

# Total wall-clock time for 50 observe/act cycles.
print(end - start)
```
@stephenroller Stephen, thanks for having a look at this! Glad that I could participate here in some way. The logic behind this fix is pretty simple, so let me expand a bit on what I said in the PR description.

The point of the incremental state is, of course, to save attention keys and values so they can be reused. When running inference with any encoder-decoder transformer, every decoding step needs both encoder-attention and self-attention keys and values for each decoder layer. For encoder attention specifically, the general logic of the module was totally fine: on the first decoding step, when incr_state is still None, the values are computed from the encoder_state tensors; on later steps, the ones computed before can simply be reused. But there was a bug in the implementation that forced MultiHeadAttention to recompute these values anyway, only to replace them with the identical values saved in incr_state right afterwards. In effect it went "let's project the encoder tensors... oh, never mind, we already had them."

It's quite an obvious thing, but it doesn't catch the eye unless you specifically walk through the model logic looking for ways to reduce latency.
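For readers less familiar with incremental decoding, here is a minimal sketch of the caching idea being described. The names (k_lin, v_lin, 'prev_key', 'prev_value') and the incr_state layout are illustrative assumptions, not ParlAI's actual API:

```python
import torch
import torch.nn as nn

dim = 16
k_lin, v_lin = nn.Linear(dim, dim), nn.Linear(dim, dim)
encoder_out = torch.randn(1, 12, dim)  # fixed for the whole generation

incr_state = None
for step in range(5):
    if incr_state is None:
        # First decoding step: project the encoder output once and cache it.
        incr_state = {
            'prev_key': k_lin(encoder_out),
            'prev_value': v_lin(encoder_out),
        }
    # Every step reads the cached projections; the bug was that the
    # projections were also recomputed on later steps and then discarded.
    k, v = incr_state['prev_key'], incr_state['prev_value']
    # ... encoder attention over (k, v) with the current decoder query ...
```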
So yeah, this change is probably more of a "hotfix" than an "optimization", strictly speaking.
If you are interested: I've spent a bit of time with the Blender model, and while the setting I work with is pretty different from what you work with here in the library, I think I could contribute a couple more optimization changes to the transformer modules. All of them, though, would be longer in terms of code and smaller in terms of latency reduction, so I'd love to know whether you think it's worth it, given that these changes would have to be reviewed by you or your team anyway.
Great find - thanks for changing!
I'm definitely interested. Lowering latency would be nice for sure. We do have maintainability and backwards compatibility to worry about, so I can't promise they would all be accepted. We can also discuss your proposals before you implement them. Feel free to file an issue.
Patch description
Slightly alters the MultiHeadAttention forward method.
The problem with the original code is that when an incr_state dict is passed to the model with state_kv set to True, the method computes k and v (the projected pre-attention key and value tensors) by passing the key / value inputs through the respective linear layers, despite replacing them with equivalent tensors extracted from incr_state just a few lines of code later.
The PR addresses this with a small change that avoids the excess computation by moving it into the conditional branches where the cached tensors are not available and the projection is actually required.
The PR should not break any ParlAI modules, as it does not alter the logic of MultiHeadAttention or its outputs in any way; the only change is that a redundant internal operation is no longer performed when it is not needed.
Proper benchmarking to showcase the difference this PR makes would be fairly involved, since the module's performance varies widely across the settings and models it is used in, but I have seen decoding-iteration latency drop by as much as a factor of two for the Blender model (on both CPU and GPU, and for both the 90M and 3B models). The average gain is probably lower, but it still seems significant.
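To make the described control-flow change concrete, here is a rough before/after sketch. It is a simplified approximation, not the actual ParlAI diff: the projection layers, the incr_state keys, and the state_kv argument handling are assumptions for illustration.

```python
import torch
import torch.nn as nn

dim = 8
k_lin, v_lin = nn.Linear(dim, dim), nn.Linear(dim, dim)

def kv_before(key, value, incr_state=None, state_kv=False):
    # Original pattern: the projections always run, then may be thrown away.
    k, v = k_lin(key), v_lin(value)
    if incr_state is not None and state_kv:
        k, v = incr_state['prev_key'], incr_state['prev_value']
    return k, v

def kv_after(key, value, incr_state=None, state_kv=False):
    # Patched pattern: the projections run only when the cache cannot be reused.
    if incr_state is not None and state_kv:
        k, v = incr_state['prev_key'], incr_state['prev_value']
    else:
        k, v = k_lin(key), v_lin(value)
    return k, v

# Both variants return identical tensors; only the wasted work differs.
enc = torch.randn(1, 5, dim)
state = {'prev_key': k_lin(enc), 'prev_value': v_lin(enc)}
assert torch.equal(kv_before(enc, enc, state, True)[0],
                   kv_after(enc, enc, state, True)[0])
```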