[refactor] Moving local attention to sparse backend (facebookresearch#50)

* moving local attention to sparse backend
* better handling of the interplay between causal masking and window sizes
* some cleaning up
blefaudeux authored Apr 20, 2021
1 parent 19fe7e8 commit 5e10833
Showing 6 changed files with 74 additions and 189 deletions.
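
The window-size bump from sequence_length // 8 to sequence_length // 8 + 1 seen throughout the benchmark and test configs follows from the refactor: when the mask is not causal, the sparse-backend LocalAttention asserts an odd window (the query itself plus two equal wings). A minimal sketch of that arithmetic, assuming the power-of-two sequence lengths used in the benchmarks:

# Minimal sketch: why the configs move to an odd window size.
# The value 1024 is an illustrative, power-of-two sequence length.
sequence_length = 1024

old_window = sequence_length // 8      # 128 -> even, rejected by the new assertion
new_window = sequence_length // 8 + 1  # 129 -> odd: the query plus two 64-token wings

assert new_window % 2 == 1, "non-causal local attention expects an odd window size"
half_wing = new_window // 2            # 64 keys attended on each side of the query
print(old_window, new_window, half_wing)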
2 changes: 1 addition & 1 deletion benchmarks/benchmark_attention.py
@@ -138,7 +138,7 @@ def instantiate_xformer(
"name": attention_name,
"dropout": attn_dropout,
"causal": causal,
"window_size": sequence_length // 8,
"window_size": sequence_length // 8 + 1,
"from_seq_dim": sequence_length,
}

Binary file modified docs/plots/memory_vs_attention.png
Binary file modified docs/plots/runtime_vs_attention.png
3 changes: 1 addition & 2 deletions tests/test_attentions.py
@@ -33,9 +33,8 @@ def test_order_invariance(
"name": attention_name,
"dropout": attn_dropout,
"causal": causal,
"window_size": SEQ // 4,
"from_seq_dim": SEQ,
"causal": causal,
"window_size": SEQ // 8 + 1,
}

attention = build_attention(AttentionConfig(**test_config))
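
For reference, a hedged sketch of the registry path the test above exercises, with illustrative values for SEQ and the dropout (the build_attention / AttentionConfig import location is assumed from the surrounding code, not shown in this diff):

# Usage sketch of the config dict exercised in test_attentions.py above.
# SEQ, the dropout and the causal flag are illustrative; "local" is the name
# under which LocalAttention is registered (see register_attention("local")).
from xformers.components.attention import AttentionConfig, build_attention

SEQ = 128
test_config = {
    "name": "local",
    "dropout": 0.1,
    "causal": False,
    "from_seq_dim": SEQ,
    "window_size": SEQ // 8 + 1,  # odd window, as required when not causal
}

attention = build_attention(AttentionConfig(**test_config))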
4 changes: 2 additions & 2 deletions tests/test_block_factory.py
@@ -43,7 +43,7 @@ def test_xformer_encoder_block(
"name": attention_name,
"dropout": attn_dropout,
"causal": causal,
"window_size": SEQ // 8,
"window_size": SEQ // 8 + 1,
"from_seq_dim": SEQ,
}

@@ -101,7 +101,7 @@ def test_xformer_decoder_block(
"name": attention_name,
"dropout": attn_dropout,
"causal": causal,
"window_size": SEQ // 8,
"window_size": SEQ // 8 + 1,
"from_seq_dim": SEQ,
}

254 changes: 70 additions & 184 deletions xformers/components/attention/local.py
@@ -1,124 +1,93 @@
# see https://arxiv.org/pdf/2003.05997.pdf
# and
# FIXME: proper credits

import math
from dataclasses import dataclass
from functools import reduce
from operator import mul
from typing import Optional, Tuple
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn
import torch.nn as nn

from xformers.components.attention import Attention, AttentionConfig, register_attention
from xformers.components.positional_encoding.relative_positional import (
RelativePositionalEncoding,
from xformers.components.attention import (
_SPARSITY_THRESHOLD,
Attention,
AttentionConfig,
register_attention,
)

TOKEN_SELF_ATTN_VALUE = -5e4 # carefully set for half precision to work


def default(value, d):
return d if value is None else value


def max_neg_value(tensor):
return -torch.finfo(tensor.dtype).max


def merge_dims(ind_from, ind_to, tensor):
shape = list(tensor.shape)
arr_slice = slice(ind_from, ind_to + 1)
shape[arr_slice] = [reduce(mul, shape[arr_slice])]
return tensor.reshape(*shape)


def expand_dim(t, dim, k, unsqueeze=True):
if unsqueeze:
t = t.unsqueeze(dim)
expand_shape = [-1] * len(t.shape)
expand_shape[dim] = k
return t.expand(*expand_shape)


def pad_to_multiple(tensor, multiple, dim=-1, value=0):
seqlen = tensor.shape[dim]
m = seqlen / multiple
if m.is_integer():
return tensor
remainder = math.ceil(m) * multiple - seqlen
pad_offset = (0,) * (-1 - dim) * 2
return F.pad(tensor, (*pad_offset, 0, remainder), value=value)


def look_around(x, backward=1, forward=0, pad_value=-1, dim=2):
t = x.shape[1]
dims = (len(x.shape) - dim) * (0, 0)
padded_x = F.pad(x, (*dims, backward, forward), value=pad_value)
tensors = [
padded_x[:, ind : (ind + t), ...] for ind in range(forward + backward + 1)
]
return torch.cat(tensors, dim=dim)
from xformers.components.attention.core import scaled_dot_product_attention


@dataclass(init=False)
class LocalAttentionConfig(AttentionConfig):
causal: bool
window_size: int
autopad: Optional[bool]
shared_qk: Optional[bool]
exact_window_size: Optional[bool]
look_backward: Optional[int]
look_forward: Optional[int]
rel_pos_emb_config: Optional[Tuple[int, int]]


@register_attention("local")
class LocalAttention(Attention):
r"""
An implementation of a sliding window attention, as proposed in LongFormers
https://arxiv.org/pdf/2004.05150.pdf
# Credits : https://github.com/lucidrains/local-attention
"""

def __init__(
self,
dropout: float,
causal: bool,
window_size: int,
look_backward: int = 1,
look_forward: int = 0,
shared_qk: bool = False,
rel_pos_emb_config: Optional[Tuple[int, int]] = None,
autopad: bool = False,
exact_window_size: bool = False,
dropout: float = 0.0,
causal: bool = False,
window_size: int = 5,
*args,
**kwargs,
):

r"""
An implementation of a sliding window attention, as proposed in RoutingTransformer_, LongFormer_ or BigBird_
Args:
dropout (float): the probability of an output to be randomly dropped at training time
causal (bool): apply a causal mask, in that the attention cannot be applied to the future
window_size (int): the overall window size for local attention.
Odd number is expected if the mask is not causal, as the window size will be evenly
distributed on both sides of each query
_RoutingTransformer: "Efficient Content-Based Sparse Attention with Routing Transformers", A. Roy et al.
https://arxiv.org/pdf/2003.05997.pdf
_BigBird: "Big Bird: Transformers for Longer Sequences" M. Zaheer et al
https://arxiv.org/pdf/2007.14062.pdf
_Longformer: "Longformer: The Long-Document Transformer.", I. Beltagy et al
https://arxiv.org/pdf/2004.05150.pdf
"""
super().__init__()
look_forward = default(look_forward, 0 if causal else 1)
assert not (causal and look_forward > 0), "you cannot look forward if causal"

self.attn_drop = nn.Dropout(dropout, inplace=True)
self.causal = causal

if not self.causal:
assert (
window_size % 2 == 1
), "The window size is assumed to be odd (counts self-attention + 2 wings)"

self.window_size = window_size
self.look_backward = look_backward
self.look_forward = look_forward
self.exact_window_size = exact_window_size
self.autopad = autopad
self.mask: Optional[torch.Tensor] = None

def _get_local_mask(self, shape: torch.Size) -> torch.Tensor:
if self.causal:
mask = torch.tril(torch.ones(shape[1], shape[1])).to(dtype=torch.bool)
mask &= ~torch.tril(
torch.ones(shape[1], shape[1]), diagonal=-self.window_size - 1
).to(dtype=torch.bool)
else:
h_win_size = self.window_size // 2
mask = torch.tril(torch.ones(shape[1], shape[1]), diagonal=h_win_size).to(
dtype=torch.bool
)
mask &= ~torch.tril(
torch.ones(shape[1], shape[1]), diagonal=-(h_win_size + 1)
).to(dtype=torch.bool)

self.dropout = nn.Dropout(dropout)
# Take the batch dimension into account
# FIXME: not needed with https://github.com/fairinternal/xformers/issues/42
mask = mask.expand(shape[0], shape[1], shape[1])

self.shared_qk = shared_qk
# Sparsify if that makes sense
if torch.count_nonzero(mask).item() / mask.numel() < _SPARSITY_THRESHOLD:
mask = mask.to_sparse()

self.rel_pos = None
if rel_pos_emb_config is not None:
dim_head, heads = rel_pos_emb_config
rel_pos_length = window_size * (1 + look_forward + look_backward)
self.heads = heads
self.rel_pos = RelativePositionalEncoding(dim_head, rel_pos_length, heads)
return mask

def forward(
self,
@@ -128,98 +97,15 @@ def forward(
att_mask: Optional[torch.Tensor] = None,
*args,
**kwargs,
) -> torch.Tensor:

shape = q.shape

if self.autopad:
# FIXME: This is probably broken
orig_t = q.shape[1]
q, k, v = map(
lambda t: pad_to_multiple(t, self.window_size, dim=-2), (q, k, v)
)

B, S, E = q.size() # batch x sequence x embedding
device, dtype = q.device, q.dtype
assert (
S % self.window_size
) == 0, f"sequence length {S} must be divisible by window size {self.window_size} for local attention"

windows = S // self.window_size

if self.shared_qk:
k = F.normalize(k, 2, dim=-1).type_as(q)

ticker = torch.arange(S, device=device, dtype=dtype)[None, :]
b_t = ticker.reshape(1, windows, self.window_size)

bq, bk, bv = map(lambda t: t.reshape(B, windows, self.window_size, -1), (q, k, v)) # type: ignore

look_around_kwargs = {
"backward": self.look_backward,
"forward": self.look_forward,
}
bk = look_around(bk, **look_around_kwargs)
bv = look_around(bv, **look_around_kwargs)

bq_t = b_t
bq_k = look_around(b_t, **look_around_kwargs)

dots = torch.einsum("bhie,bhje->bhij", bq, bk) * (E ** -0.5)

if self.rel_pos is not None:
rel_attn = self.rel_pos(bq.view(-1, self.heads, *bq.shape[1:])).reshape_as(
dots
)
dots = dots + rel_attn

mask_value = max_neg_value(dots)
):
# Local window attention masking
if self.mask is None or self.mask.shape[1] != q.shape[1]:
self.mask = self._get_local_mask(q.shape).to(q.device)

if self.shared_qk:
mask = bq_t[:, :, :, None] == bq_k[:, :, None, :]
dots.masked_fill_(mask, TOKEN_SELF_ATTN_VALUE)
del mask
# Take into account the optional user mask
mask = self.mask if att_mask is None else self.mask & att_mask

if self.causal:
mask = bq_t[:, :, :, None] < bq_k[:, :, None, :]

if self.exact_window_size:
max_causal_window_size = self.window_size * self.look_backward
mask = mask | (
bq_t[:, :, :, None] > (bq_k[:, :, None, :] + max_causal_window_size)
)

dots.masked_fill_(mask, mask_value)
del mask

mask = bq_k[:, :, None, :] == -1
dots.masked_fill_(mask, mask_value)
del mask

if att_mask is not None:
pass
# FIXME @lefaudeux
# h = B // att_mask.shape[0]
# if self.autopad:
# att_mask = pad_to_multiple(att_mask, self.window_size, dim=-1, value=False)
# att_mask = att_mask.reshape(-1, windows, self.window_size) # type: ignore # Mypy is drunk
# mq = mk = att_mask
# mk = look_around(mk, pad_value=False, **look_around_kwargs)
# mask = mq[:, :, :, None] * mk[:, :, None, :]
# mask = merge_dims(0, 1, expand_dim(mask, 1, h))
# dots.masked_fill_(~mask, mask_value)
# del mask

attn = dots.softmax(dim=-1)
attn = self.dropout(attn)

out = torch.einsum("bhij,bhje->bhie", attn, bv)
out = out.reshape(-1, S, E)

if self.autopad:
out = out[:, :orig_t, :]

return out.reshape(*shape)
return scaled_dot_product_attention(q, k, v, mask, dropout=self.attn_drop)

@classmethod
def from_config(cls, config: AttentionConfig) -> "Attention":
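
The band mask that replaces the old look_around / windowed-gather machinery can be reproduced standalone. A minimal sketch of the two branches of _get_local_mask above, for an illustrative sequence length and window; the 0.30 threshold stands in for _SPARSITY_THRESHOLD, whose actual value is not shown in this diff:

# Standalone sketch of the band mask built by _get_local_mask above.
# seq, window_size, causal and the 0.30 threshold are illustrative values.
import torch

seq, window_size, causal = 16, 3, False

if causal:
    # Lower triangle, keeping each query plus the window_size previous positions.
    mask = torch.tril(torch.ones(seq, seq)).to(dtype=torch.bool)
    mask &= ~torch.tril(torch.ones(seq, seq), diagonal=-window_size - 1).to(dtype=torch.bool)
else:
    # Symmetric band: each query sees itself plus window_size // 2 keys per side.
    half = window_size // 2
    mask = torch.tril(torch.ones(seq, seq), diagonal=half).to(dtype=torch.bool)
    mask &= ~torch.tril(torch.ones(seq, seq), diagonal=-(half + 1)).to(dtype=torch.bool)

density = torch.count_nonzero(mask).item() / mask.numel()
if density < 0.30:  # hand the mask to the sparse backend when the band is narrow
    mask = mask.to_sparse()

print(density, mask.is_sparse)  # 0.1796875 True for these values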
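
A hedged end-to-end sketch of the refactored module follows: the band mask is built lazily on the first call, cached while the sequence length stays the same, AND-ed with any user-supplied att_mask, and handed to scaled_dot_product_attention together with the dropout module. Shapes and hyper-parameters below are illustrative, and the example assumes this revision of xformers is installed:

# Illustrative forward pass through the refactored LocalAttention.
# The (batch, sequence, embedding) layout follows the q.shape usage above;
# batch size, sequence length, embedding dim and window size are made up.
import torch
from xformers.components.attention.local import LocalAttention

attention = LocalAttention(dropout=0.0, causal=False, window_size=11)

B, S, E = 2, 16, 32
q, k, v = torch.randn(B, S, E), torch.randn(B, S, E), torch.randn(B, S, E)

out = attention(q, k, v)  # first call builds and caches the (B, S, S) band mask
assert out.shape == (B, S, E)

# An optional boolean att_mask is AND-ed with the cached local mask.
att_mask = torch.ones(S, S, dtype=torch.bool)
out = attention(q, k, v, att_mask=att_mask)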
