Enable flag to not pass PAD tokens in ffwd #775

Merged: 21 commits, Dec 11, 2023
16 changes: 16 additions & 0 deletions llmfoundry/models/layers/blocks.py
@@ -12,6 +12,11 @@
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, build_ffn
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY

try:
from flash_attn.bert_padding import unpad_input, pad_input # type: ignore # yapf: disable # isort: skip
except:
unpad_input, pad_input = None, None

attn_config_defaults: Dict = {
'attn_type': 'multihead_attention',
'attn_pdrop': 0.0,
@@ -53,6 +58,7 @@ def __init__(
fc_type: str = 'torch',
device: Optional[str] = None,
no_bias: bool = False,
use_pad_tok_in_ffn: bool = True,
**kwargs: Any,
):
if attn_config is None:
@@ -105,6 +111,8 @@ def __init__(
self.resid_attn_dropout = nn.Dropout(resid_pdrop)
self.resid_ffn_dropout = nn.Dropout(resid_pdrop)

self.use_pad_tok_in_ffn = use_pad_tok_in_ffn

def forward(
self,
x: torch.Tensor,
@@ -132,6 +140,14 @@ def forward(
m = x
if self.norm_2 is not None:
m = self.norm_2(x)
batch_size, seq_len = m.size()[:2]
indices = None
if not self.use_pad_tok_in_ffn:
assert unpad_input is not None
m, indices, _, _ = unpad_input(m, attention_mask)
n = self.ffn(m)
if not self.use_pad_tok_in_ffn:
assert pad_input is not None
n = pad_input(n, indices, batch_size, seq_len)
x = x + self.resid_ffn_dropout(n)
return x, attn_weights, past_key_value
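
Taken together, the additions to `MPTBlock.forward` amount to an unpad → FFN → re-pad round trip, so the feedforward network never processes PAD tokens when `use_pad_tok_in_ffn=False`. Below is a minimal pure-PyTorch sketch of that round trip using hypothetical helpers (`unpad`, `repad`); it only mimics what `flash_attn.bert_padding.unpad_input` and `pad_input` do and is not the flash-attn implementation.

```python
import torch

def unpad(hidden: torch.Tensor, attention_mask: torch.Tensor):
    """Drop padded positions, returning a flat (total_tokens, d) tensor and their indices."""
    batch, seqlen, d = hidden.shape
    flat_mask = attention_mask.reshape(-1).bool()
    indices = torch.nonzero(flat_mask, as_tuple=False).flatten()
    return hidden.reshape(batch * seqlen, d)[indices], indices

def repad(hidden: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int):
    """Scatter the unpadded rows back into a zero-filled (batch, seqlen, d) tensor."""
    d = hidden.shape[-1]
    out = torch.zeros(batch * seqlen, d, dtype=hidden.dtype)
    out[indices] = hidden
    return out.reshape(batch, seqlen, d)

x = torch.randn(2, 4, 8)
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])
m, idx = unpad(x, mask)   # shape (5, 8): only real tokens would reach the FFN
n = repad(m, idx, 2, 4)   # shape (2, 4, 8): PAD positions come back as zeros
```

Re-padded PAD positions come back as zeros rather than FFN outputs, which is why the test added below only compares logits at non-PAD positions.
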
10 changes: 10 additions & 0 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -60,6 +60,7 @@ def __init__(
init_config: Dict = init_config_defaults,
fc_type: str = 'torch',
tie_word_embeddings: bool = True,
use_pad_tok_in_ffn: bool = True,
verbose: Optional[int] = None,
**kwargs: Any,
):
@@ -131,6 +132,7 @@ def __init__(
See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
fc_type (str): choose fc layer implementation. Options: torch and te. te layers support fp8 when using H100 GPUs.
tie_word_embeddings (bool): Whether to tie the input embedding and output layers.
use_pad_tok_in_ffn (bool): Whether to forward the pad token in the feedforward networks.
"""
self.d_model = d_model
self.n_heads = n_heads
@@ -151,6 +153,7 @@ def __init__(
self.use_cache = use_cache
self.init_config = init_config
self.fc_type = fc_type
self.use_pad_tok_in_ffn = use_pad_tok_in_ffn
if verbose is not None:
warnings.warn(
DeprecationWarning(
@@ -292,3 +295,10 @@ def _validate_config(self) -> None:
self.ffn_config['fc_type'] = self.fc_type
elif self.ffn_config['ffn_type'] == 'te_ln_mlp':
self.ffn_config['bias'] = not self.no_bias
if not self.use_pad_tok_in_ffn:
try:
from flash_attn.bert_padding import unpad_input, pad_input # type: ignore # yapf: disable # isort: skip
except:
raise ImportError(
'In order to set `use_pad_tok_in_ffn=False`, please install flash-attn==1.0.9 or flash-attn==2.3.2'
)
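
On the configuration side, `use_pad_tok_in_ffn` is just another `MPTConfig` field, validated in `_validate_config`. A minimal usage sketch, assuming flash-attn 1.0.9 or 2.3.2 is installed; the hyperparameter values are arbitrary and for illustration only:

```python
from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM

# Arbitrary small sizes for illustration; setting use_pad_tok_in_ffn=False
# requires flash-attn, otherwise _validate_config raises the ImportError above.
config = MPTConfig(
    d_model=256,
    n_heads=4,
    n_layers=2,
    use_pad_tok_in_ffn=False,  # unpad tokens before each block's FFN
)
model = MPTForCausalLM(config)
```
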
2 changes: 1 addition & 1 deletion llmfoundry/models/mpt/modeling_mpt.py
@@ -419,7 +419,7 @@ def _attn_bias(
attn_bias = attn_bias.masked_fill(
~attention_mask.view(-1, 1, 1, s_k), min_val)

- return attn_bias, None
+ return attn_bias, attention_mask

def _apply_prefix_mask(self, attn_bias: torch.Tensor,
prefix_mask: torch.Tensor) -> torch.Tensor:
39 changes: 39 additions & 0 deletions tests/models/test_model.py
@@ -831,6 +831,45 @@ def test_forward_with_padding(attention_impl: str, pos_emb_config: dict,
batched_output[1, :],
atol=1e-6 if attention_impl == 'torch' else 1e-8)

try:
from flash_attn.bert_padding import unpad_input, pad_input # type: ignore # yapf: disable # isort: skip
except:
unpad_input, pad_input = None, None

if unpad_input is not None and pad_input is not None:
# Checking numerical precision with pad_token ffn
for block in mpt.transformer.blocks:
# Flip the padding usage in the model
block.use_pad_tok_in_ffn = not block.use_pad_tok_in_ffn

right_padding_output_pad_flipped = mpt(
right_padding_input_ids,
attention_mask=right_padding_attention_mask).logits
middle_padding_output_pad_flipped = mpt(
middle_padding_input_ids,
attention_mask=middle_padding_attention_mask).logits
left_padding_output_pad_flipped = mpt(
left_padding_input_ids,
attention_mask=left_padding_attention_mask).logits

pad_vs_unpad_rtol = 1e-5
pad_vs_unpad_atol = 1e-6
assert torch.allclose(right_padding_output[0, :3],
right_padding_output_pad_flipped[0, :3],
rtol=pad_vs_unpad_rtol,
atol=pad_vs_unpad_atol)

assert torch.allclose(middle_padding_output[0, [0, 1, 5]],
middle_padding_output_pad_flipped[0,
[0, 1, 5]],
rtol=pad_vs_unpad_rtol,
atol=pad_vs_unpad_atol)

assert torch.allclose(left_padding_output[0, 3:],
left_padding_output_pad_flipped[0, 3:],
rtol=pad_vs_unpad_rtol,
atol=pad_vs_unpad_atol)


@pytest.mark.parametrize('attention_impl', ['torch', 'triton'])
def test_advanced_mask_building(attention_impl: str):