This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Optimize MultiHeadAttention forward method #3550

Merged
1 commit merged into facebookresearch:master on Mar 26, 2021

Conversation

dejk-alg
Contributor

@dejk-alg commented Mar 24, 2021

Avoid duplicated calculation of key-value pairs when incr_state is passed and static_kv is True

Patch description
Slightly alters the MultiHeadAttention forward method.

The problem with the original code is that when an incr_state dict is passed to the model with static_kv set to True, the method still computes k and v (the pre-attention projections of the key and value inputs) by passing them through their respective linear layers, only to replace them with equivalent tensors pulled from incr_state a few lines later.

The PR addresses this with a small change that avoids the redundant computation by moving it into the conditional branches where it is actually required (a simplified sketch of the control flow is below).

The PR should not break any ParlAI modules, since it does not alter the logic of MultiHeadAttention or its outputs in any way. The only actual change is that a redundant internal operation is no longer performed when it is not needed.

Proper benchmarking to showcase the difference this PR makes would be fairly involved, since the module's performance varies widely across the settings and models it is used in. That said, I've seen the latency of a decoding iteration drop by as much as a factor of two for the Blender model (for both CPU and GPU, and for both the 90M and 3B models). The average gain is probably lower, but it still seems quite significant.
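To make the change concrete, here is a hedged sketch of the control flow before and after the patch. It is not the actual ParlAI diff; the names (k_lin, v_lin, prev_key, prev_value) and the omission of the self-attention cache concatenation are simplifications for illustration:

import torch
import torch.nn as nn


def project_kv_before(key, value, k_lin, v_lin, incr_state=None, static_kv=False):
    # Original behaviour: k and v are always projected, even when the cached
    # tensors from incr_state are about to replace them a few lines later.
    k, v = k_lin(key), v_lin(value)
    if incr_state is not None and static_kv:
        k, v = incr_state['prev_key'], incr_state['prev_value']
    return k, v


def project_kv_after(key, value, k_lin, v_lin, incr_state=None, static_kv=False):
    # Patched behaviour: the projections run only in the branch that needs them.
    if incr_state is not None and static_kv:
        k, v = incr_state['prev_key'], incr_state['prev_value']
    else:
        k, v = k_lin(key), v_lin(value)
    return k, v


if __name__ == '__main__':
    k_lin, v_lin = nn.Linear(8, 8), nn.Linear(8, 8)
    enc = torch.randn(1, 4, 8)
    cache = {'prev_key': k_lin(enc), 'prev_value': v_lin(enc)}
    # Both versions return the same tensors; the second simply skips the two
    # unnecessary linear projections whenever the cache is used.
    project_kv_before(enc, enc, k_lin, v_lin, incr_state=cache, static_kv=True)
    project_kv_after(enc, enc, k_lin, v_lin, incr_state=cache, static_kv=True)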

Avoid duplicated calculation of key-value pairs when incr_state is passed and static_kv is True
@facebook-github-bot

Hi @dejk-alg!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@fb.com. Thanks!

@facebook-github-bot

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Facebook open source project. Thanks!

Contributor

@stephenroller left a comment


Hah, awesome! Thanks! Do you think you could add some comments explaining this logic?

Tests are flaky but just rerunning to make sure this isn't catching on something wild.

@stephenroller
Contributor

Heads up @EricMichaelSmith: this might help with latency on [internal thing]

@stephenroller
Contributor

On a Volta with bs 1.
With a taskmaster2 BART model:
master: 50 generations in 31.4s
this patch: 50 generations in 29.0s, or about a 7% speedup

With blender3b:
master: 50 generations in 74s
this patch: 50 generations in 67s, or about a 9% speedup

On CPU with bs 1 with the taskmaster2 BART model:
master: 50 generations in 59.0s
this patch: 50 generations in 54.2s, or about an 8% speedup

Real nice find!

@stephenroller
Contributor

Benchmark script I used:

#!/usr/bin/env python3
import timeit

import tqdm

from parlai.core.agents import create_agent_from_model_file

# Load the 3B Blender model from the ParlAI model zoo.
agent = create_agent_from_model_file('zoo:blender/blender_3B/model')

# Time 50 single-turn generations (batch size 1).
start = timeit.default_timer()
for _ in tqdm.tqdm(range(50)):
    agent.observe(
        {'text': 'Can you find me a mexican restaurant for 2?', 'episode_done': True}
    )
    print(agent.act()['text'])
    agent.reset()
end = timeit.default_timer()

print(end - start)

@dejk-alg
Contributor Author

dejk-alg commented Mar 26, 2021

@stephenroller Stephen, thanks for having a look at this! Glad that I could participate here in some way.

The logic behind this fix is pretty simple. Let me expand a bit on what I said in the PR description.

The point of using incremental state is, of course, to cache attention keys and values so they can be reused. When running inference with any encoder-decoder transformer, every decoding step needs both the encoder-attention and self-attention keys and values for each decoder layer.

For encoder attention specifically, the general logic of the module was perfectly fine: on the first decoding step, while incr_state is still None, the values are computed from the encoder_state tensors; on subsequent steps we can simply reuse the ones computed before. But an implementation bug forced MultiHeadAttention to recompute these values anyway, only to replace them with the identical values already saved in incr_state (a minimal sketch of this branching follows below).

So it went like "Let's get the processed encoder tensors... oh, never mind, we already had them."

It's quite an obvious thing, but it doesn't catch the eye unless you specifically go through the model logic looking for ways to reduce latency.
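
For illustration, here is a minimal, self-contained sketch of that branching. It is not the actual ParlAI MultiHeadAttention (there are no multiple heads, masks, or head reshaping here), and the names CachedAttention, q_lin/k_lin/v_lin, and prev_key/prev_value are illustrative assumptions:

import torch
import torch.nn as nn
import torch.nn.functional as F


class CachedAttention(nn.Module):
    """Single-head toy attention with an incremental-state cache."""

    def __init__(self, dim):
        super().__init__()
        self.q_lin = nn.Linear(dim, dim)
        self.k_lin = nn.Linear(dim, dim)
        self.v_lin = nn.Linear(dim, dim)

    def forward(self, query, key, value, incr_state=None, static_kv=False):
        q = self.q_lin(query)
        if incr_state is None:
            # First decoding step: project keys and values from the inputs.
            k, v = self.k_lin(key), self.v_lin(value)
        elif static_kv:
            # Encoder attention on later steps: the cached keys/values never
            # change, so reuse them directly instead of recomputing them and
            # immediately discarding the result (the redundancy this PR removes).
            k, v = incr_state['prev_key'], incr_state['prev_value']
        else:
            # Self-attention on later steps: project only the new token and
            # append it to the cache.
            k = torch.cat([incr_state['prev_key'], self.k_lin(key)], dim=1)
            v = torch.cat([incr_state['prev_value'], self.v_lin(value)], dim=1)
        new_incr_state = {'prev_key': k, 'prev_value': v}
        weights = F.softmax(q @ k.transpose(-2, -1) / k.size(-1) ** 0.5, dim=-1)
        return weights @ v, new_incr_state

On the first step incr_state is None for both attention types; on every later step, encoder attention takes the static_kv branch and performs no projections at all, which is where the latency win comes from.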

@dejk-alg
Contributor Author

So yeah, this change is probably more of a "hotfix" than an "optimization", strictly speaking.

@dejk-alg
Contributor Author

dejk-alg commented Mar 26, 2021

If you're interested: I've spent a bit of time with the Blender model, and while the setting I work with is pretty different from what you have here in the library, I think I could add a couple of optimization changes to the transformer modules. They would all generally be longer in terms of code and not that large in terms of latency decrease, though. So I'd love to know whether you think it's worth it, given that these changes would have to be reviewed by you or your team anyway.

Contributor

@EricMichaelSmith left a comment


Great find - thanks for changing!

@stephenroller
Contributor

If you're interested: I've spent a bit of time with the Blender model, and while the setting I work with is pretty different from what you have here in the library, I think I could add a couple of optimization changes to the transformer modules. They would all generally be longer in terms of code and not that large in terms of latency decrease, though. So I'd love to know whether you think it's worth it, given that these changes would have to be reviewed by you or your team anyway.

I'm definitely interested. Lowering latency would be nice for sure. We do have maintainability and backwards compatibility to worry about, so I can't promise they would all be accepted. We can also discuss your proposals before you implement them. Feel free to file an issue.

@stephenroller stephenroller merged commit bbd6e26 into facebookresearch:master Mar 26, 2021