Enable flag to not pass PAD tokens in ffwd (#775)
* Changing how attention_mask is being passed around

* adding in option to toggle flags for padding

* moving to flash attn import

* removing unused import

* Removing excess stuff from pyproject

* refactor

* Update llmfoundry/models/layers/blocks.py

Co-authored-by: Mihir Patel <mihir.v.patel7@gmail.com>

* Update llmfoundry/models/layers/blocks.py

Co-authored-by: Mihir Patel <mihir.v.patel7@gmail.com>

* Changing some naming conventions, moving where tests are done, adding numerics test

* Updating import

* trying to fix tests

* trying to fix tests

* updating tests

* updating tests

* Update llmfoundry/models/mpt/configuration_mpt.py

Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>

* Changing gating in tests to check for flash attn

---------

Co-authored-by: Mihir Patel <mihir.v.patel7@gmail.com>
Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>
3 people authored Dec 11, 2023
1 parent 34ec2f7 commit 410d5c7
Showing 4 changed files with 66 additions and 1 deletion.
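The commit adds a `use_pad_tok_in_ffn` flag (default `True`) to `MPTConfig` and `MPTBlock`; setting it to `False` strips PAD tokens before each block's feed-forward network. A minimal usage sketch, not part of the commit: it assumes `MPTConfig` and `MPTForCausalLM` are importable from `llmfoundry.models.mpt`, that a supported flash-attn version is installed, and the small config values are hypothetical.

```python
from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM

# Hypothetical small config; only use_pad_tok_in_ffn comes from this commit.
config = MPTConfig(
    d_model=256,
    n_heads=4,
    n_layers=2,
    max_seq_len=512,
    attn_config={'attn_impl': 'torch'},  # illustrative choice of attention impl
    use_pad_tok_in_ffn=False,  # drop PAD tokens before the FFN in each block
)
model = MPTForCausalLM(config)
```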
16 changes: 16 additions & 0 deletions llmfoundry/models/layers/blocks.py
@@ -12,6 +12,11 @@
 from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, build_ffn
 from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
 
+try:
+    from flash_attn.bert_padding import unpad_input, pad_input  # type: ignore # yapf: disable # isort: skip
+except:
+    unpad_input, pad_input = None, None
+
 attn_config_defaults: Dict = {
     'attn_type': 'multihead_attention',
     'attn_pdrop': 0.0,
@@ -53,6 +58,7 @@ def __init__(
         fc_type: str = 'torch',
         device: Optional[str] = None,
         no_bias: bool = False,
+        use_pad_tok_in_ffn: bool = True,
         **kwargs: Any,
     ):
         if attn_config is None:
@@ -105,6 +111,8 @@ def __init__(
         self.resid_attn_dropout = nn.Dropout(resid_pdrop)
         self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
 
+        self.use_pad_tok_in_ffn = use_pad_tok_in_ffn
+
     def forward(
         self,
         x: torch.Tensor,
@@ -132,6 +140,14 @@ def forward(
         m = x
         if self.norm_2 is not None:
             m = self.norm_2(x)
+        batch_size, seq_len = m.size()[:2]
+        indices = None
+        if not self.use_pad_tok_in_ffn:
+            assert unpad_input is not None
+            m, indices, _, _ = unpad_input(m, attention_mask)
         n = self.ffn(m)
+        if not self.use_pad_tok_in_ffn:
+            assert pad_input is not None
+            n = pad_input(n, indices, batch_size, seq_len)
         x = x + self.resid_ffn_dropout(n)
         return x, attn_weights, past_key_value
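For reference, a standalone sketch of the unpad/pad round trip that `MPTBlock.forward` now performs around `self.ffn` when `use_pad_tok_in_ffn=False`. The tensors are toy values; the sketch assumes flash-attn is installed and that `unpad_input` returns the 4-tuple unpacked in the diff above (newer flash-attn releases add extra return values).

```python
import torch
from flash_attn.bert_padding import unpad_input, pad_input

batch_size, seq_len, d_model = 2, 6, 8
m = torch.randn(batch_size, seq_len, d_model)
# 1 = real token, 0 = PAD; the second sequence has two trailing PAD tokens.
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1],
                               [1, 1, 1, 1, 0, 0]], dtype=torch.bool)

# Drop PAD positions: (batch, seq, d) -> (total_real_tokens, d).
m_unpadded, indices, _, _ = unpad_input(m, attention_mask)
assert m_unpadded.shape[0] == int(attention_mask.sum())

# ... the feed-forward network would run on m_unpadded here ...

# Scatter the results back to (batch, seq, d); PAD positions come back zero-filled.
n = pad_input(m_unpadded, indices, batch_size, seq_len)
assert n.shape == m.shape
```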
10 changes: 10 additions & 0 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -60,6 +60,7 @@ def __init__(
         init_config: Dict = init_config_defaults,
         fc_type: str = 'torch',
         tie_word_embeddings: bool = True,
+        use_pad_tok_in_ffn: bool = True,
         verbose: Optional[int] = None,
         **kwargs: Any,
     ):
@@ -131,6 +132,7 @@ def __init__(
                 See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
             fc_type (str): choose fc layer implementation. Options: torch and te. te layers support fp8 when using H100 GPUs.
             tie_word_embeddings (bool): Whether to tie the input embedding and output layers.
+            use_pad_tok_in_ffn (bool): Whether to forward the pad token in the feedforward networks.
         """
         self.d_model = d_model
         self.n_heads = n_heads
@@ -151,6 +153,7 @@ def __init__(
         self.use_cache = use_cache
         self.init_config = init_config
         self.fc_type = fc_type
+        self.use_pad_tok_in_ffn = use_pad_tok_in_ffn
         if verbose is not None:
             warnings.warn(
                 DeprecationWarning(
@@ -292,3 +295,10 @@ def _validate_config(self) -> None:
             self.ffn_config['fc_type'] = self.fc_type
         elif self.ffn_config['ffn_type'] == 'te_ln_mlp':
             self.ffn_config['bias'] = not self.no_bias
+        if not self.use_pad_tok_in_ffn:
+            try:
+                from flash_attn.bert_padding import unpad_input, pad_input  # type: ignore # yapf: disable # isort: skip
+            except:
+                raise ImportError(
+                    'In order to set `use_pad_tok_in_ffn=False`, please install flash-attn==1.0.9 or flash-attn==2.3.2'
+                )
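Because `_validate_config` runs during `MPTConfig.__init__`, requesting `use_pad_tok_in_ffn=False` without flash-attn fails at config construction time rather than at the first forward pass. A small sketch of that behavior (assumes `MPTConfig` is importable as above; the printed message is illustrative):

```python
from llmfoundry.models.mpt import MPTConfig

try:
    config = MPTConfig(use_pad_tok_in_ffn=False)
except ImportError as e:
    # Raised by _validate_config when flash_attn.bert_padding cannot be imported.
    print(f'flash-attn is required for use_pad_tok_in_ffn=False: {e}')
```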
2 changes: 1 addition & 1 deletion llmfoundry/models/mpt/modeling_mpt.py
@@ -419,7 +419,7 @@ def _attn_bias(
             attn_bias = attn_bias.masked_fill(
                 ~attention_mask.view(-1, 1, 1, s_k), min_val)
 
-        return attn_bias, None
+        return attn_bias, attention_mask
 
     def _apply_prefix_mask(self, attn_bias: torch.Tensor,
                            prefix_mask: torch.Tensor) -> torch.Tensor:
39 changes: 39 additions & 0 deletions tests/models/test_model.py
@@ -831,6 +831,45 @@ def test_forward_with_padding(attention_impl: str, pos_emb_config: dict,
                               batched_output[1, :],
                               atol=1e-6 if attention_impl == 'torch' else 1e-8)
 
+        try:
+            from flash_attn.bert_padding import unpad_input, pad_input  # type: ignore # yapf: disable # isort: skip
+        except:
+            unpad_input, pad_input = None, None
+
+        if unpad_input is not None and pad_input is not None:
+            # Checking numerical precision with pad_token ffn
+            for block in mpt.transformer.blocks:
+                # Flip the padding usage in the model
+                block.use_pad_tok_in_ffn = not block.use_pad_tok_in_ffn
+
+            right_padding_output_pad_flipped = mpt(
+                right_padding_input_ids,
+                attention_mask=right_padding_attention_mask).logits
+            middle_padding_output_pad_flipped = mpt(
+                middle_padding_input_ids,
+                attention_mask=middle_padding_attention_mask).logits
+            left_padding_output_pad_flipped = mpt(
+                left_padding_input_ids,
+                attention_mask=left_padding_attention_mask).logits
+
+            pad_vs_unpad_rtol = 1e-5
+            pad_vs_unpad_atol = 1e-6
+            assert torch.allclose(right_padding_output[0, :3],
+                                  right_padding_output_pad_flipped[0, :3],
+                                  rtol=pad_vs_unpad_rtol,
+                                  atol=pad_vs_unpad_atol)
+
+            assert torch.allclose(middle_padding_output[0, [0, 1, 5]],
+                                  middle_padding_output_pad_flipped[0,
+                                                                    [0, 1, 5]],
+                                  rtol=pad_vs_unpad_rtol,
+                                  atol=pad_vs_unpad_atol)
+
+            assert torch.allclose(left_padding_output[0, 3:],
+                                  left_padding_output_pad_flipped[0, 3:],
+                                  rtol=pad_vs_unpad_rtol,
+                                  atol=pad_vs_unpad_atol)
+
 
 @pytest.mark.parametrize('attention_impl', ['torch', 'triton'])
 def test_advanced_mask_building(attention_impl: str):
