-
Notifications
You must be signed in to change notification settings - Fork 637
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[fix][speed] Better projections - correctness + speed #119
Conversation
@@ -258,7 +258,7 @@ def matmul_with_mask(self, a, b): | |||
column_indices = self.column_indices | |||
out = _sddmm.apply( | |||
a, | |||
b.transpose(-2, -1), | |||
b.transpose(-2, -1).contiguous(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
contiguous is only needed here and down, we were forcing this all the time before. Note that with a good projection kernel this could come for free, could be nice
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe you should leave a comment/breadcrumb trail saying that in the code?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
will do, or maybe just file an issue, basically the projection in the beginning of MHA could be hardened a little for speed
qkv.split(self.out_features, dim=-1), | ||
) | ||
return q, k, v | ||
qkv = qkv.split(self.out_features, -1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this was actually slow (memcopy) and not needed for non-sparse attention
|
||
def __init__(self, layer: nn.Module): | ||
super().__init__() | ||
self.layer = layer | ||
|
||
def forward(self, inputs: Union[torch.Tensor, List[torch.Tensor]], *args, **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this was super error prone (@stephenroller spotted that long ago, I should have caught that then), since inputs and args can mix
def forward(self, *args, **kwargs): | ||
# Could be that the same tensor has been passed multiple times | ||
# in that case we'll just normalize once | ||
list_ids = [id(inp) for inp in args] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
seems small, but here if the same tensor was actually passed multiple times (self attention), we would normalize 3 times and loose the same id(), which in turn means that in the attention layer we would not optimize for self-attention
@@ -364,8 +364,8 @@ def forward( | |||
else: | |||
target_q, target_k, target_v = target, target, target | |||
|
|||
x = self.wrap_att([target_q, target_k, target_v], att_mask=decoder_att_mask) | |||
x = self.wrap_cross([x, memory, memory], att_mask=encoder_att_mask) | |||
x = self.wrap_att(target_q, target_k, target_v, att_mask=decoder_att_mask) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
related to the "input<>args" cleanup
@fmassa @dianaml0 I think that the code in block_factory and residual is probably way to complicated and error prone, could be worth another pass if deemed important (this is orthogonal to the parts zoo approach). In any case with these changes a default xformers is competitive in terms of speed vs. Timm ViT |
3f35d57
to
36d64f2
Compare
Codecov Report
@@ Coverage Diff @@
## vit_comp_bench #119 +/- ##
==================================================
- Coverage 87.22% 87.21% -0.02%
==================================================
Files 49 49
Lines 2490 2487 -3
==================================================
- Hits 2172 2169 -3
Misses 318 318
Flags with carried forward coverage won't be shown. Click here to find out more.
Continue to review full report at Codecov.
|
Ping reviewers (and PR below), this significantly speeds up a code path used in the examples and by some folks around |
Seems like much of the complication was introduced for reversible layers to work? But agree probably a good idea to see if there's a cleaner way to do it. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is really great!! LGTM!
Yes for reversible, I agree, it makes the signal paths a lot more complex unfortunately:( no strong opinion on that, I don't think that this was benchmarked enough to have a definitive answer. Probably possible to have a better code and keep reversible though |
…, not attentions (facebookresearch#119) * follow up from facebookresearch#117, macro blocks mask inputs, not attentions * matching unit test
What does this PR do?
With these changes/fixes, vanilla xformers is 5-10% faster than timm on a vanilla ViT (up from being 10% slower) as per the bench from the other PR, and microGPT trains something decent in 20 minutes on a laptop (3080 / fp16)
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.