
Commit

mv bias subset
vchiley committed Mar 17, 2023
1 parent 8cbe513 commit 5fcfb6e
Showing 7 changed files with 23 additions and 55 deletions.
examples/llm/__init__.py (5 changes: 2 additions & 3 deletions)
@@ -12,8 +12,8 @@
                                         ComposerHFPrefixLM, ComposerHFT5)
 from examples.llm.src.models.layers.attention import (
     MultiheadAttention, alibi_bias, attn_bias_, attn_bias_shape,
-    flash_attn_fn, generate_attn_bias,
-    scaled_multihead_dot_product_attention, triton_flash_attn_fn)
+    flash_attn_fn, scaled_multihead_dot_product_attention,
+    triton_flash_attn_fn)
 from examples.llm.src.models.layers.gpt_blocks import GPTMLP, GPTBlock
 from examples.llm.src.models.mosaic_gpt import ComposerMosaicGPT, MosaicGPT
 from examples.llm.src.tokenizer import (TOKENIZER_REGISTRY, HFTokenizer,
@@ -40,7 +40,6 @@
     'MultiheadAttention',
     'attn_bias_shape',
     'attn_bias_',
-    'generate_attn_bias',
     'alibi_bias',
     'GPTMLP',
     'GPTBlock',
examples/llm/src/__init__.py (4 changes: 1 addition & 3 deletions)
@@ -8,8 +8,7 @@
                                         ComposerHFT5)
 from examples.llm.src.models.layers.attention import (
     MultiheadAttention, alibi_bias, attn_bias_, attn_bias_shape, flash_attn_fn,
-    generate_attn_bias, scaled_multihead_dot_product_attention,
-    triton_flash_attn_fn)
+    scaled_multihead_dot_product_attention, triton_flash_attn_fn)
 from examples.llm.src.models.layers.gpt_blocks import GPTMLP, GPTBlock
 from examples.llm.src.models.mosaic_gpt import ComposerMosaicGPT, MosaicGPT
 from examples.llm.src.tokenizer import (TOKENIZER_REGISTRY, HFTokenizer,
@@ -28,7 +27,6 @@
     'MultiheadAttention',
     'attn_bias_shape',
     'attn_bias_',
-    'generate_attn_bias',
     'alibi_bias',
     'GPTMLP',
     'GPTBlock',
examples/llm/src/models/layers/__init__.py (4 changes: 1 addition & 3 deletions)
@@ -3,8 +3,7 @@
 
 from examples.llm.src.models.layers.attention import (
     MultiheadAttention, alibi_bias, attn_bias_, attn_bias_shape, flash_attn_fn,
-    generate_attn_bias, scaled_multihead_dot_product_attention,
-    triton_flash_attn_fn)
+    scaled_multihead_dot_product_attention, triton_flash_attn_fn)
 from examples.llm.src.models.layers.gpt_blocks import GPTMLP, GPTBlock
 
 __all__ = [
@@ -14,7 +13,6 @@
     'MultiheadAttention',
     'attn_bias_shape',
     'attn_bias_',
-    'generate_attn_bias',
     'alibi_bias',
     'GPTMLP',
     'GPTBlock',
examples/llm/src/models/layers/attention.py (35 changes: 14 additions & 21 deletions)
@@ -179,10 +179,16 @@ def triton_flash_attn_fn(
         raise NotImplementedError(
             f'attn_impl: triton cannot return attn weights.')
 
-    if key_padding_mask is not None and key_padding_mask.bool().logical_not(
-    ).any():
-        raise NotImplementedError(
-            f'assumes key_padding_mask is taken care of by attn_bias')
+    if key_padding_mask is not None:
+        b_size, s_k = key_padding_mask.shape
+
+        if attn_bias is not None:
+            attn_bias = attn_bias.expand(b_size, -1, -1, -1)
+        else:
+            attn_bias = query.new_zeros(b_size, 1, 1, s_k)
+
+        attn_bias = attn_bias.masked_fill(
+            ~key_padding_mask.view((b_size, 1, 1, s_k)), -float('inf'))
 
     query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads)
     key = rearrange(key, 'b s (h d) -> b s h d', h=n_heads)
@@ -274,6 +280,10 @@ def forward(self,
 
         # attention
         query, key, value = qkv.chunk(3, dim=2)
+
+        if attn_bias is not None:
+            attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
+
         context, attn_weights = self.attn_fn(
             query,
             key,
@@ -342,23 +352,6 @@ def attn_bias_(attn_impl,
         raise ValueError(f'{attn_impl=} is an invalid setting.')
 
 
-def generate_attn_bias(attn_impl,
-                       attn_bias,
-                       seq_len,
-                       batch_size,
-                       key_padding_mask=None):
-    if attn_bias is not None:
-        # select seq_len subset of attn mask
-        attn_bias = attn_bias[..., :seq_len, :seq_len]
-
-        if attn_impl == 'triton' and key_padding_mask is not None:
-            attn_bias = attn_bias.expand(batch_size, -1, -1, -1)
-            attn_bias.masked_fill(
-                ~key_padding_mask.view((batch_size, 1, 1, seq_len)), -float('inf'))
-
-    return attn_bias
-
-
 def alibi_bias(n_heads,
                seq_len,
                full=False,
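
For readers skimming the diff, here is a small self-contained sketch (not code from this repo) of the padding-mask handling that triton_flash_attn_fn now does inline: when a key_padding_mask is given, the bias is expanded to the batch size (or created as zeros) and padded key positions are filled with -inf. The tensor shapes and values below are made up for illustration.

import torch

# toy shapes; in the real kernel these come from the query/key tensors
b_size, s_k = 2, 5
query = torch.randn(b_size, s_k, 32)
key_padding_mask = torch.tensor([[1, 1, 1, 0, 0],
                                 [1, 1, 1, 1, 1]], dtype=torch.bool)
attn_bias = None  # e.g. no alibi/causal bias was precomputed

if key_padding_mask is not None:
    if attn_bias is not None:
        attn_bias = attn_bias.expand(b_size, -1, -1, -1)
    else:
        attn_bias = query.new_zeros(b_size, 1, 1, s_k)
    # padded key positions get -inf so softmax assigns them zero weight
    attn_bias = attn_bias.masked_fill(
        ~key_padding_mask.view(b_size, 1, 1, s_k), -float('inf'))

print(attn_bias[0, 0, 0])  # tensor([0., 0., 0., -inf, -inf])

In the diff, this replaces the old NotImplementedError that required callers to fold the padding mask into attn_bias before calling the triton path.
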
examples/llm/src/models/mosaic_gpt.py (21 changes: 4 additions & 17 deletions)
@@ -109,12 +109,7 @@ def __init__(self, cfg: DictConfig):
         if cfg.get('verbose') and cfg.get('verbose') > 2:
             print(self)
 
-    def _attn_bias(self,
-                   batch_size=None,
-                   seq_len=None,
-                   key_padding_mask=None,
-                   device=None,
-                   dtype=None):
+    def _attn_bias(self, device, dtype):
         if not self._attn_bias_initialized:
             if self.attn_bias_shape:
                 self.attn_bias = torch.empty(self.attn_bias_shape,
@@ -128,16 +123,12 @@ def _attn_bias(self,
                                      alibi_bias_max=self.alibi_bias_max)
             self._attn_bias_initialized = True
 
-        return attention.generate_attn_bias(self.attn_impl,
-                                            self.attn_bias,
-                                            seq_len,
-                                            batch_size,
-                                            key_padding_mask=key_padding_mask)
+        return self.attn_bias
 
     def forward(self,
                 input_ids: torch.LongTensor,
                 key_padding_mask: Optional[torch.ByteTensor] = None):
-        B, S = input_ids.size()
+        S = input_ids.size(1)
         assert (
             S <= self.cfg.max_seq_len
         ), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.cfg.max_seq_len}'
@@ -160,11 +151,7 @@ def forward(self,
         assert isinstance(self.transformer.emb_drop, nn.Module)  # pyright
         x = self.transformer.emb_drop(x_shrunk)
 
-        attn_bias = self._attn_bias(batch_size=B,
-                                    seq_len=S,
-                                    key_padding_mask=key_padding_mask,
-                                    device=x.device,
-                                    dtype=x.dtype)
+        attn_bias = self._attn_bias(device=x.device, dtype=x.dtype)
         if self.cfg.attn_impl == 'flash' and key_padding_mask is None:
             # HazyResearch FlashMHA appears to use more memory when `key_padding_mask=None`
             # in certain settings like MosaicGPT-7B. So we always provide a tensor.
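
As a rough illustration of the new division of labor (a sketch, not the repo's classes), _attn_bias now lazily builds and caches one full-size bias and returns it, while MultiheadAttention.forward slices out the trailing seq_len-by-seq_len corner it needs. The toy class, shapes, and the distance-based bias below are invented for the example; the point is that for a bias depending only on key-position minus query-position, the bottom-right corner of the full matrix equals the bias built directly for the shorter length.

import torch


class ToyBiasCache:
    """Invented stand-in for the lazy-init pattern in MosaicGPT._attn_bias."""

    def __init__(self, n_heads=4, max_seq_len=8):
        self.n_heads = n_heads
        self.max_seq_len = max_seq_len
        self.attn_bias = None

    def _attn_bias(self, device, dtype):
        if self.attn_bias is None:  # built once, reused on every forward
            pos = torch.arange(self.max_seq_len, device=device)
            dist = pos.view(1, -1) - pos.view(-1, 1)  # key_pos - query_pos
            self.attn_bias = (-dist.clamp(min=0)).to(dtype).expand(
                1, self.n_heads, -1, -1)
        return self.attn_bias


cache = ToyBiasCache()
full_bias = cache._attn_bias(device='cpu', dtype=torch.float32)

# the slicing that now lives in MultiheadAttention.forward:
seq_len = 5
sliced = full_bias[:, :, -seq_len:, -seq_len:]

# because the bias depends only on (key_pos - query_pos), the trailing corner
# matches a bias built directly for seq_len positions
pos = torch.arange(seq_len)
direct = (-(pos.view(1, -1) - pos.view(-1, 1)).clamp(min=0)).float()
assert torch.equal(sliced[0, 0], direct)
print(sliced.shape)  # torch.Size([1, 4, 5, 5])

This lines up with the forward hunk above, where attn_bias is cropped with [:, :, -query.size(1):, -key.size(1):] before being handed to attn_fn.
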
examples/llm/tests/test_flash_triton_torch.py (5 changes: 0 additions & 5 deletions)
@@ -65,11 +65,6 @@ def gen_bias(attn_impl, key_padding_mask):
                                              s,
                                              alibi=alibi,
                                              alibi_bias_max=8)
-        attn_bias = attention.generate_attn_bias(
-            attn_impl, attn_bias, s, n, key_padding_mask=key_padding_mask)
-
-        if attn_impl == 'triton':
-            return attn_bias, key_padding_mask
 
         return attn_bias, key_padding_mask
 
examples/llm/tests/test_model.py (4 changes: 1 addition & 3 deletions)
@@ -157,9 +157,7 @@ def test_attention_mechanism(batch_size=2):
                                        axis=1)
     expected_zerod_weights |= torch_key_padding
 
-    attn_bias = model.model._attn_bias(batch_size=batch_size,
-                                       seq_len=test_cfg.max_seq_len,
-                                       key_padding_mask=key_padding_mask)
+    attn_bias = model.model._attn_bias(device=x.device, dtype=x.dtype)
 
     for block in model.model.transformer.blocks:
         a = block.ln_1(x)
