Skip to content

Commit

Permalink
fixing the self attention optimization, broken in the factory
Browse files Browse the repository at this point in the history
  • Loading branch information
blefaudeux committed Nov 20, 2021

Verified

This commit was signed with the committer’s verified signature.
itegulov Daniyar Itegulov
1 parent 9499a7a commit 3f35d57
Showing 6 changed files with 35 additions and 37 deletions.
4 changes: 2 additions & 2 deletions examples/microGPT.py
Original file line number Diff line number Diff line change
@@ -271,7 +271,7 @@ def top_k_logits(logits, k):
if __name__ == "__main__":
seed_everything(42)
REF_BATCH = 512
BATCH = 256 # adjust depending on the available memory on your machine
BATCH = 512 # adjust depending on the available memory on your machine
WORKERS = 8
EPOCHS = 1
BLOCK = 128
@@ -299,7 +299,7 @@ def top_k_logits(logits, k):
model = GPT(
vocab_size=train_dataset.vocab_size,
block_size=train_dataset.block_size,
attention="nystrom",
attention="scaled_dot_product",
warmup_tokens=REF_BATCH * WARMUP,
learning_rate=LR,
final_tokens=EPOCHS * len(train_dataset) * BLOCK,
2 changes: 1 addition & 1 deletion xformers/components/attention/_sputnik_sparse.py
Original file line number Diff line number Diff line change
@@ -258,7 +258,7 @@ def matmul_with_mask(self, a, b):
column_indices = self.column_indices
out = _sddmm.apply(
a,
b.transpose(-2, -1),
b.transpose(-2, -1).contiguous(),
row_indices,
row_offsets,
column_indices,
4 changes: 2 additions & 2 deletions xformers/components/attention/core.py
Original file line number Diff line number Diff line change
@@ -143,9 +143,9 @@ def _sparse_bmm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
def bmm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
if _is_sparse_available:
if isinstance(a, SparseCS):
return a.spmm(b)
return a.spmm(b.contiguous())
if a.is_sparse:
return _sparse_bmm(a, b)
return _sparse_bmm(a, b.contiguous())
return a @ b


9 changes: 3 additions & 6 deletions xformers/components/in_proj_container.py
Original file line number Diff line number Diff line change
@@ -153,17 +153,14 @@ def forward(
if self.in_proj_weight is not None:
if id(query) == id(key):
# Self attention, get all the projected values at once
# we compute everything transposed, so that q,k,v stay contiguous after splitting
# NOTE: the resulting buffers could be contiguous out of the box with a custom kernel
qkv = query @ self.in_proj_weight.transpose(-2, -1)

if self.in_proj_bias is not None:
qkv += self.in_proj_bias

q, k, v = map(
lambda x: x.contiguous(),
qkv.split(self.out_features, dim=-1),
)
return q, k, v
qkv = qkv.split(self.out_features, -1)
return qkv[0], qkv[1], qkv[2]

else:
# Not self attention
49 changes: 25 additions & 24 deletions xformers/components/residual.py
Original file line number Diff line number Diff line change
@@ -5,9 +5,8 @@


from enum import Enum
from typing import List, Union
from typing import Union

import torch
import torch.nn as nn

from xformers import _is_triton_available
@@ -16,14 +15,6 @@
from xformers.triton.layer_norm import FusedLayerNorm


def _to_tensor_list(
inputs: Union[torch.Tensor, List[torch.Tensor]]
) -> List[torch.Tensor]:
if not isinstance(inputs, list):
inputs = [inputs]
return inputs


class LayerNormStyle(str, Enum):
"""Support different layer norm styles.
See "On Layer Normalization in the Transformer Architecture",
@@ -36,16 +27,23 @@ class LayerNormStyle(str, Enum):

# CREDITS: the following is inspired by FastAI's Transformer implementation
class Residual(nn.Module):
"""Object-oriented handling of the residual path"""
"""Object-oriented handling of the residual path
.. warning:: by convention, if multiple tensors are being passed in,
the first one is used for the residual path
"""

def __init__(self, layer: nn.Module):
super().__init__()
self.layer = layer

def forward(self, inputs: Union[torch.Tensor, List[torch.Tensor]], *args, **kwargs):
inputs = _to_tensor_list(inputs)

return inputs[0] + self.layer(*inputs, *args, **kwargs)
def forward(
self,
*args,
**kwargs,
):
residual = args[0]
return residual + self.layer(*args, **kwargs)


class PreNorm(nn.Module):
@@ -62,11 +60,16 @@ def __init__(self, d_model: int, sublayer: nn.Module, use_triton: bool = True):

self.sublayer = sublayer

def forward(self, inputs: Union[torch.Tensor, List[torch.Tensor]], *args, **kwargs):
inputs = _to_tensor_list(inputs)

x_norm = [self.norm(x_) for x_ in inputs]
return self.sublayer(*x_norm, *args, **kwargs)
def forward(self, *args, **kwargs):
# The same tensor may have been passed multiple times;
# in that case we normalize it only once
list_ids = [id(inp) for inp in args]
if list_ids.count(list_ids[0]) == len(list_ids):
normalized_input = self.norm(args[0])
sublayer_inputs = [normalized_input for _ in args]
else:
sublayer_inputs = [self.norm(x_) for x_ in args]
return self.sublayer(*sublayer_inputs, **kwargs)


class PostNorm(nn.Module):
@@ -81,8 +84,6 @@ def __init__(self, d_model: int, sublayer: nn.Module, use_triton: bool = True):

self.sublayer = sublayer

def forward(self, inputs: Union[torch.Tensor, List[torch.Tensor]], *args, **kwargs):
inputs = _to_tensor_list(inputs)

x = self.sublayer(*inputs, *args, **kwargs)
def forward(self, *args, **kwargs):
x = self.sublayer(*args, **kwargs)
return self.norm(x)
4 changes: 2 additions & 2 deletions xformers/factory/block_factory.py
Original file line number Diff line number Diff line change
@@ -364,8 +364,8 @@ def forward(
else:
target_q, target_k, target_v = target, target, target

x = self.wrap_att([target_q, target_k, target_v], att_mask=decoder_att_mask)
x = self.wrap_cross([x, memory, memory], att_mask=encoder_att_mask)
x = self.wrap_att(target_q, target_k, target_v, att_mask=decoder_att_mask)
x = self.wrap_cross(x, memory, memory, att_mask=encoder_att_mask)
x = self.wrap_ff(x)

return x

0 comments on commit 3f35d57

Please sign in to comment.