[fix][speed] Better projections - correctness + speed #119
```diff
@@ -153,17 +153,14 @@ def forward(
         if self.in_proj_weight is not None:
             if id(query) == id(key):
                 # Self attention, get all the projected values at once
                 # we compute everything transposed, so that q,k,v stay contiguous after splitting
                 # NOTE: the resulting buffers could be contiguous out of the box with a custom kernel
                 qkv = query @ self.in_proj_weight.transpose(-2, -1)

                 if self.in_proj_bias is not None:
                     qkv += self.in_proj_bias

-                q, k, v = map(
-                    lambda x: x.contiguous(),
-                    qkv.split(self.out_features, dim=-1),
-                )
-                return q, k, v
+                qkv = qkv.split(self.out_features, -1)
+                return qkv[0], qkv[1], qkv[2]
             else:
                 # Not self attention
```

> Review comment on `qkv = qkv.split(self.out_features, -1)`: this was actually slow (memcopy) and not needed for non-sparse attention

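To make the fast path concrete, here is a minimal, self-contained sketch of the projection pattern the new code lands on. The helper name `project_qkv` and the toy shapes are illustrative only, not part of the xformers API; the point is that `split()` returns views into the fused buffer, so dropping the forced `.contiguous()` avoids three memcopies on the dense path.

```python
import torch

def project_qkv(query, key, in_proj_weight, in_proj_bias, out_features):
    # Hypothetical standalone helper mirroring the diff above.
    # When query and key are literally the same tensor we are doing
    # self-attention, so one matmul against the stacked q/k/v weight suffices.
    if id(query) == id(key):
        qkv = query @ in_proj_weight.transpose(-2, -1)  # (..., 3 * out_features)
        if in_proj_bias is not None:
            qkv += in_proj_bias
        # split() returns views into qkv: no memcopy, unlike mapping .contiguous()
        q, k, v = qkv.split(out_features, dim=-1)
        return q, k, v
    # Cross-attention would project query/key/value separately (not sketched here).
    raise NotImplementedError

# Toy shapes for illustration
B, S, E = 4, 128, 256
x = torch.randn(B, S, E)
w = torch.randn(3 * E, E)   # stacked q/k/v projection weight
b = torch.zeros(3 * E)
q, k, v = project_qkv(x, x, w, b, out_features=E)
print(q.shape, q.is_contiguous())  # torch.Size([4, 128, 256]) False -> views, no copy
```

Per the review discussion, this is fine for non-sparse attention; a path that does need contiguous buffers can still make the copy at the point of use.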
```diff
@@ -5,9 +5,8 @@
 from enum import Enum
-from typing import List, Union
+from typing import Union

 import torch
 import torch.nn as nn

 from xformers import _is_triton_available
```
```diff
@@ -16,14 +15,6 @@
 from xformers.triton.layer_norm import FusedLayerNorm


-def _to_tensor_list(
-    inputs: Union[torch.Tensor, List[torch.Tensor]]
-) -> List[torch.Tensor]:
-    if not isinstance(inputs, list):
-        inputs = [inputs]
-    return inputs
-
-
 class LayerNormStyle(str, Enum):
     """Support different layer norm styles.
     See "On Layer Normalization in the Transformer Architecture",
```
```diff
@@ -36,16 +27,23 @@ class LayerNormStyle(str, Enum):

 # CREDITS: the following is inspired by FastAI's Transformer implementation
 class Residual(nn.Module):
-    """Object-oriented handling of the residual path"""
+    """Object-oriented handling of the residual path
+
+    .. warning: by convention, if multiple tensors are being passed in,
+        the first one is used for the residual path
+    """

     def __init__(self, layer: nn.Module):
         super().__init__()
         self.layer = layer

-    def forward(self, inputs: Union[torch.Tensor, List[torch.Tensor]], *args, **kwargs):
-        inputs = _to_tensor_list(inputs)
-
-        return inputs[0] + self.layer(*inputs, *args, **kwargs)
+    def forward(
+        self,
+        *args,
+        **kwargs,
+    ):
+        residual = args[0]
+        return residual + self.layer(*args, **kwargs)


 class PreNorm(nn.Module):
```

> Review comment on the new `forward` signature: this was super error prone (@stephenroller spotted that long ago, I should have caught that then), since inputs and args can mix

```diff
@@ -62,11 +60,16 @@ def __init__(self, d_model: int, sublayer: nn.Module, use_triton: bool = True):

         self.sublayer = sublayer

-    def forward(self, inputs: Union[torch.Tensor, List[torch.Tensor]], *args, **kwargs):
-        inputs = _to_tensor_list(inputs)
-
-        x_norm = [self.norm(x_) for x_ in inputs]
-        return self.sublayer(*x_norm, *args, **kwargs)
+    def forward(self, *args, **kwargs):
+        # Could be that the same tensor has been passed multiple times
+        # in that case we'll just normalize once
+        list_ids = [id(inp) for inp in args]
+        if list_ids.count(list_ids[0]) == len(list_ids):
+            normalized_input = self.norm(args[0])
+            sublayer_inputs = [normalized_input for _ in args]
+        else:
+            sublayer_inputs = [self.norm(x_) for x_ in args]
+        return self.sublayer(*sublayer_inputs, **kwargs)


 class PostNorm(nn.Module):
```

> Review comment on the `list_ids` check: seems small, but here if the same tensor was actually passed multiple times (self-attention), we would normalize it three times and lose the shared id(), which in turn means that in the attention layer we would not optimize for self-attention

```diff
@@ -81,8 +84,6 @@ def __init__(self, d_model: int, sublayer: nn.Module, use_triton: bool = True):

         self.sublayer = sublayer

-    def forward(self, inputs: Union[torch.Tensor, List[torch.Tensor]], *args, **kwargs):
-        inputs = _to_tensor_list(inputs)
-
-        x = self.sublayer(*inputs, *args, **kwargs)
+    def forward(self, *args, **kwargs):
+        x = self.sublayer(*args, **kwargs)
         return self.norm(x)
```
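Outside of the diff, the new wrapper convention is easier to see end to end. Below is a small sketch: the class bodies mirror the diff above, but the `IdentityProbe` module and the toy dimensions are made up for illustration. It shows that when the same tensor is passed three times, `PreNorm` normalizes it once and the downstream layer still sees a single object, so an `id(q) == id(k)` check can fire.

```python
import torch
import torch.nn as nn

class Residual(nn.Module):
    def __init__(self, layer: nn.Module):
        super().__init__()
        self.layer = layer

    def forward(self, *args, **kwargs):
        residual = args[0]  # by convention, the first tensor carries the residual path
        return residual + self.layer(*args, **kwargs)


class PreNorm(nn.Module):
    def __init__(self, d_model: int, sublayer: nn.Module):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.sublayer = sublayer

    def forward(self, *args, **kwargs):
        ids = [id(a) for a in args]
        if ids.count(ids[0]) == len(ids):
            # same tensor passed several times (self-attention):
            # normalize once and hand the *same* object to every slot
            x = self.norm(args[0])
            inputs = [x] * len(args)
        else:
            inputs = [self.norm(a) for a in args]
        return self.sublayer(*inputs, **kwargs)


class IdentityProbe(nn.Module):
    # stand-in for an attention layer, only checks that identity survived
    def forward(self, q, k, v):
        assert id(q) == id(k) == id(v)  # the self-attention fast path can trigger
        return q

block = Residual(PreNorm(16, IdentityProbe()))
t = torch.randn(2, 8, 16)
out = block(t, t, t)
print(out.shape)  # torch.Size([2, 8, 16])
```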
```diff
@@ -364,8 +364,8 @@ def forward(
         else:
             target_q, target_k, target_v = target, target, target

-        x = self.wrap_att([target_q, target_k, target_v], att_mask=decoder_att_mask)
-        x = self.wrap_cross([x, memory, memory], att_mask=encoder_att_mask)
+        x = self.wrap_att(target_q, target_k, target_v, att_mask=decoder_att_mask)
+        x = self.wrap_cross(x, memory, memory, att_mask=encoder_att_mask)
         x = self.wrap_ff(x)

         return x
```

> Review comment on the `wrap_att` call: related to the "input<>args" cleanup

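For the call sites, here is a hedged usage sketch of the new convention; `DummyAttention` and the shapes are stand-ins, not the xformers decoder block. Tensors go in positionally and extra arguments as keywords, so the `*args`/`**kwargs` wrappers forward everything untouched instead of unpacking a list.

```python
import torch
import torch.nn as nn

class DummyAttention(nn.Module):
    # toy scaled dot-product attention, just to have something callable
    def forward(self, q, k, v, att_mask=None):
        scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
        if att_mask is not None:
            scores = scores.masked_fill(~att_mask, float("-inf"))
        return scores.softmax(dim=-1) @ v

wrap_att = DummyAttention()  # stand-in for the Residual/PreNorm-wrapped block

target = torch.randn(2, 8, 16)
mask = torch.ones(8, 8, dtype=torch.bool)

# old style (removed): wrap_att([target, target, target], att_mask=mask)
x = wrap_att(target, target, target, att_mask=mask)  # new style: positional tensors
print(x.shape)  # torch.Size([2, 8, 16])
```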
> contiguous is only needed here and down, we were forcing this all the time before. Note that with a good projection kernel this could come for free, could be nice

> maybe you should leave a comment/breadcrumb trail saying that in the code?

> will do, or maybe just file an issue, basically the projection in the beginning of MHA could be hardened a little for speed