diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py
index 6db9ff22ca..c077ccb535 100644
--- a/llmfoundry/models/layers/blocks.py
+++ b/llmfoundry/models/layers/blocks.py
@@ -12,6 +12,11 @@
 from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, build_ffn
 from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
 
+try:
+    from flash_attn.bert_padding import unpad_input, pad_input  # type: ignore # yapf: disable # isort: skip
+except:
+    unpad_input, pad_input = None, None
+
 attn_config_defaults: Dict = {
     'attn_type': 'multihead_attention',
     'attn_pdrop': 0.0,
@@ -53,6 +58,7 @@ def __init__(
         fc_type: str = 'torch',
         device: Optional[str] = None,
         no_bias: bool = False,
+        use_pad_tok_in_ffn: bool = True,
         **kwargs: Any,
     ):
         if attn_config is None:
@@ -105,6 +111,8 @@
         self.resid_attn_dropout = nn.Dropout(resid_pdrop)
         self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
 
+        self.use_pad_tok_in_ffn = use_pad_tok_in_ffn
+
     def forward(
         self,
         x: torch.Tensor,
@@ -132,6 +140,14 @@ def forward(
         m = x
         if self.norm_2 is not None:
             m = self.norm_2(x)
+        batch_size, seq_len = m.size()[:2]
+        indices = None
+        if not self.use_pad_tok_in_ffn:
+            assert unpad_input is not None
+            m, indices, _, _ = unpad_input(m, attention_mask)
         n = self.ffn(m)
+        if not self.use_pad_tok_in_ffn:
+            assert pad_input is not None
+            n = pad_input(n, indices, batch_size, seq_len)
         x = x + self.resid_ffn_dropout(n)
         return x, attn_weights, past_key_value
diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py
index 47fd5ac9e5..6013c96d0b 100644
--- a/llmfoundry/models/mpt/configuration_mpt.py
+++ b/llmfoundry/models/mpt/configuration_mpt.py
@@ -60,6 +60,7 @@ def __init__(
         init_config: Dict = init_config_defaults,
         fc_type: str = 'torch',
         tie_word_embeddings: bool = True,
+        use_pad_tok_in_ffn: bool = True,
         verbose: Optional[int] = None,
         **kwargs: Any,
     ):
@@ -131,6 +132,7 @@
                 See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
             fc_type (str): choose fc layer implementation. Options: torch and te. te layers support fp8 when using H100 GPUs.
             tie_word_embeddings (bool): Whether to tie the input embedding and output layers.
+            use_pad_tok_in_ffn (bool): Whether to forward pad tokens through the feedforward networks. If False, pad tokens are removed before the FFN and re-inserted afterwards, which requires flash-attn.
         """
         self.d_model = d_model
         self.n_heads = n_heads
@@ -151,6 +153,7 @@
         self.use_cache = use_cache
         self.init_config = init_config
         self.fc_type = fc_type
+        self.use_pad_tok_in_ffn = use_pad_tok_in_ffn
         if verbose is not None:
             warnings.warn(
                 DeprecationWarning(
@@ -292,3 +295,10 @@
             self.ffn_config['fc_type'] = self.fc_type
         elif self.ffn_config['ffn_type'] == 'te_ln_mlp':
             self.ffn_config['bias'] = not self.no_bias
+        if not self.use_pad_tok_in_ffn:
+            try:
+                from flash_attn.bert_padding import unpad_input, pad_input  # type: ignore # yapf: disable # isort: skip
+            except:
+                raise ImportError(
+                    'In order to set `use_pad_tok_in_ffn=False`, please install flash-attn==1.0.9 or flash-attn==2.3.2'
+                )
diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index e2d2ee6fbc..8c134e2b9f 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -419,7 +419,7 @@ def _attn_bias(
             attn_bias = attn_bias.masked_fill(
                 ~attention_mask.view(-1, 1, 1, s_k), min_val)
 
-        return attn_bias, None
+        return attn_bias, attention_mask
 
     def _apply_prefix_mask(self, attn_bias: torch.Tensor,
                            prefix_mask: torch.Tensor) -> torch.Tensor:
diff --git a/tests/models/test_model.py b/tests/models/test_model.py
index 98a556f534..12d7b3de37 100644
--- a/tests/models/test_model.py
+++ b/tests/models/test_model.py
@@ -831,6 +831,45 @@ def test_forward_with_padding(attention_impl: str, pos_emb_config: dict,
                           batched_output[1, :],
                           atol=1e-6 if attention_impl == 'torch' else 1e-8)
 
+    try:
+        from flash_attn.bert_padding import unpad_input, pad_input  # type: ignore # yapf: disable # isort: skip
+    except:
+        unpad_input, pad_input = None, None
+
+    if unpad_input is not None and pad_input is not None:
+        # Check that flipping use_pad_tok_in_ffn leaves non-pad outputs unchanged
+        for block in mpt.transformer.blocks:
+            # Flip the padding usage in the model
+            block.use_pad_tok_in_ffn = not block.use_pad_tok_in_ffn
+
+        right_padding_output_pad_flipped = mpt(
+            right_padding_input_ids,
+            attention_mask=right_padding_attention_mask).logits
+        middle_padding_output_pad_flipped = mpt(
+            middle_padding_input_ids,
+            attention_mask=middle_padding_attention_mask).logits
+        left_padding_output_pad_flipped = mpt(
+            left_padding_input_ids,
+            attention_mask=left_padding_attention_mask).logits
+
+        pad_vs_unpad_rtol = 1e-5
+        pad_vs_unpad_atol = 1e-6
+        assert torch.allclose(right_padding_output[0, :3],
+                              right_padding_output_pad_flipped[0, :3],
+                              rtol=pad_vs_unpad_rtol,
+                              atol=pad_vs_unpad_atol)
+
+        assert torch.allclose(middle_padding_output[0, [0, 1, 5]],
+                              middle_padding_output_pad_flipped[0,
+                                                                [0, 1, 5]],
+                              rtol=pad_vs_unpad_rtol,
+                              atol=pad_vs_unpad_atol)
+
+        assert torch.allclose(left_padding_output[0, 3:],
+                              left_padding_output_pad_flipped[0, 3:],
+                              rtol=pad_vs_unpad_rtol,
+                              atol=pad_vs_unpad_atol)
+
 
 
 @pytest.mark.parametrize('attention_impl', ['torch', 'triton'])
 def test_advanced_mask_building(attention_impl: str):