diff --git a/docs/source/en/main_classes/data_collator.md b/docs/source/en/main_classes/data_collator.md
index 74e653dd1185e9..e704bb747fe6e0 100644
--- a/docs/source/en/main_classes/data_collator.md
+++ b/docs/source/en/main_classes/data_collator.md
@@ -66,3 +66,8 @@ Examples of use can be found in the [example scripts](../examples) or [example n
     - numpy_mask_tokens
     - tf_mask_tokens
     - torch_mask_tokens
+
+## DataCollatorWithFlattening
+
+[[autodoc]] data.data_collator.DataCollatorWithFlattening
+
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index bc6e786358b68d..05becb96e0b808 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -103,6 +103,7 @@
         "DataCollatorForSOP",
         "DataCollatorForTokenClassification",
         "DataCollatorForWholeWordMask",
+        "DataCollatorWithFlattening",
         "DataCollatorWithPadding",
         "DefaultDataCollator",
         "default_data_collator",
@@ -4764,6 +4765,7 @@
         DataCollatorForSOP,
         DataCollatorForTokenClassification,
         DataCollatorForWholeWordMask,
+        DataCollatorWithFlattening,
         DataCollatorWithPadding,
         DefaultDataCollator,
         default_data_collator,
diff --git a/src/transformers/data/__init__.py b/src/transformers/data/__init__.py
index 1a8ef35ff439e4..8b675aae281f32 100644
--- a/src/transformers/data/__init__.py
+++ b/src/transformers/data/__init__.py
@@ -19,6 +19,7 @@
     DataCollatorForSOP,
     DataCollatorForTokenClassification,
     DataCollatorForWholeWordMask,
+    DataCollatorWithFlattening,
     DataCollatorWithPadding,
     DefaultDataCollator,
     default_data_collator,
diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py
index ce17f79ccfc88e..20a21318786c4a 100644
--- a/src/transformers/data/data_collator.py
+++ b/src/transformers/data/data_collator.py
@@ -1611,3 +1611,38 @@ def numpy_mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any, Any]:
             ) & masked_indices[i]
 
         return inputs.astype(np.int64), perm_mask, target_mapping, labels.astype(np.int64)
+
+
+@dataclass
+class DataCollatorWithFlattening(DefaultDataCollator):
+    """
+    Data collator used for the padding-free approach. Does the following:
+
+    - concatenates the entire mini batch into a single long sequence of shape [1, total_tokens]
+    - adds no padding and returns `input_ids`, `labels` and `position_ids`
+    """
+
+    def __init__(self, *args, return_position_ids=True, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.return_position_ids = return_position_ids
+        warnings.warn(
+            "Using `DataCollatorWithFlattening` will flatten the entire mini batch into a single long sequence. "
+            "Make sure your attention computation is able to handle it!"
+        )
+
+    def __call__(self, features, return_tensors=None):
+        if return_tensors is None:
+            return_tensors = self.return_tensors
+        is_labels_provided = "labels" in features[0]
+        ret = {"input_ids": [], "labels": []}
+        if self.return_position_ids:
+            ret.update({"position_ids": []})
+        for idx in range(0, len(features)):
+            ret["input_ids"] += features[idx]["input_ids"]
+            if is_labels_provided:
+                ret["labels"] += [-100] + features[idx]["labels"][1:]
+            else:
+                ret["labels"] += [-100] + features[idx]["input_ids"][1:]
+            if self.return_position_ids:
+                ret["position_ids"] += list(range(len(features[idx]["input_ids"])))
+        return default_data_collator([ret], return_tensors)
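
For review context, a minimal, self-contained sketch (not part of the patch) of what the new collator produces; the feature values below are made up and mirror the unit test added at the end of this diff:

```python
# Hypothetical pre-tokenized features; DataCollatorWithFlattening packs them into one row.
from transformers import DataCollatorWithFlattening

collator = DataCollatorWithFlattening(return_tensors="pt")

features = [
    {"input_ids": [10, 11, 12]},
    {"input_ids": [20, 21, 22, 23, 24, 25]},
]

batch = collator(features)
# A single packed sequence: no attention_mask, no padding tokens.
print(batch["input_ids"].shape)   # torch.Size([1, 9])
print(batch["position_ids"][0])   # tensor([0, 1, 2, 0, 1, 2, 3, 4, 5])
print(batch["labels"][0])         # tensor([-100, 11, 12, -100, 21, 22, 23, 24, 25])
```
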
+ """ + query = query.view(-1, query.size(-2), query.size(-1)) + key = key.view(-1, key.size(-2), key.size(-1)) + value = value.view(-1, value.size(-2), value.size(-1)) + position_ids = position_ids.flatten() + indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32) + + cu_seq_lens = torch.cat( + ( + indices_q[position_ids == 0], + torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32), + ) + ) + + max_length = position_ids.max() + 1 + + return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length)) + + def _flash_attention_forward( query_states: torch.Tensor, key_states: torch.Tensor, @@ -138,6 +188,7 @@ def _flash_attention_forward( query_length: int, is_causal: bool, dropout: float = 0.0, + position_ids: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, sliding_window: Optional[int] = None, use_top_left_mask: bool = False, @@ -210,6 +261,34 @@ def _flash_attention_forward( **flash_kwargs, ) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + + # if position_ids is provided and check not all examples (row) contain only 1 sequence, + # then use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach + elif position_ids is not None and not (position_ids[:, -1] == position_ids.size(1) - 1).all(): + batch_size = query_states.size(0) + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids( + query_states, key_states, value_states, position_ids + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + **flash_kwargs, + ) + + attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1)) + else: attn_output = flash_attn_func( query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 31810028ef4448..e3be8decbc6b52 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -415,6 +415,7 @@ def forward( value_states, attention_mask, q_len, + position_ids=position_ids, dropout=dropout_rate, is_causal=self.is_causal, use_top_left_mask=self._flash_attn_uses_top_left_mask, diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 663582c8a72a83..fc7a38ed134d4f 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -602,6 +602,7 @@ def forward( value_layer, attention_mask, query_length, + position_ids=position_ids, dropout=attn_dropout, is_causal=self.is_causal, use_top_left_mask=self._flash_attn_uses_top_left_mask, diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 80e97fe700b5ca..5bc1af3e7ec7a9 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -393,6 +393,7 @@ def forward( value_states, attention_mask, q_len, + 
diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py
index 31810028ef4448..e3be8decbc6b52 100644
--- a/src/transformers/models/dbrx/modeling_dbrx.py
+++ b/src/transformers/models/dbrx/modeling_dbrx.py
@@ -415,6 +415,7 @@ def forward(
             value_states,
             attention_mask,
             q_len,
+            position_ids=position_ids,
             dropout=dropout_rate,
             is_causal=self.is_causal,
             use_top_left_mask=self._flash_attn_uses_top_left_mask,
diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py
index 663582c8a72a83..fc7a38ed134d4f 100644
--- a/src/transformers/models/falcon/modeling_falcon.py
+++ b/src/transformers/models/falcon/modeling_falcon.py
@@ -602,6 +602,7 @@ def forward(
             value_layer,
             attention_mask,
             query_length,
+            position_ids=position_ids,
             dropout=attn_dropout,
             is_causal=self.is_causal,
             use_top_left_mask=self._flash_attn_uses_top_left_mask,
diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index 80e97fe700b5ca..5bc1af3e7ec7a9 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -393,6 +393,7 @@ def forward(
             value_states,
             attention_mask,
             q_len,
+            position_ids=position_ids,
             dropout=dropout_rate,
             sliding_window=getattr(self, "sliding_window", None),
             is_causal=self.is_causal,
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 3115cee78f7677..ce76d1d1ec1b9d 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -503,6 +503,7 @@ def forward(
             value_states,
             attention_mask,
             q_len,
+            position_ids=position_ids,
             dropout=dropout_rate,
             sliding_window=getattr(self, "sliding_window", None),
             use_top_left_mask=self._flash_attn_uses_top_left_mask,
diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index dd814cd75fb112..93a60a49dbf34c 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -382,6 +382,7 @@ def forward(
             value_states,
             attention_mask,
             q_len,
+            position_ids=position_ids,
             dropout=dropout_rate,
             sliding_window=getattr(self.config, "sliding_window", None),
             use_top_left_mask=self._flash_attn_uses_top_left_mask,
diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index 82320de79386b5..d2ee6e6b268ae0 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -488,6 +488,7 @@ def forward(
             value_states,
             attention_mask,
             q_len,
+            position_ids=position_ids,
             dropout=dropout_rate,
             sliding_window=getattr(self.config, "sliding_window", None),
             is_causal=self.is_causal,
diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py
index a56baf0653ecd3..74d49d5606c145 100644
--- a/src/transformers/models/olmo/modeling_olmo.py
+++ b/src/transformers/models/olmo/modeling_olmo.py
@@ -428,6 +428,7 @@ def forward(
             value_states,
             attention_mask,
             q_len,
+            position_ids=position_ids,
             dropout=dropout_rate,
             use_top_left_mask=self._flash_attn_uses_top_left_mask,
             is_causal=self.is_causal,
diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py
index f80453d3f7d990..1b23be39e5c05d 100644
--- a/src/transformers/models/phi/modeling_phi.py
+++ b/src/transformers/models/phi/modeling_phi.py
@@ -501,6 +501,7 @@ def forward(
             value_states,
             attention_mask,
             q_len,
+            position_ids=position_ids,
             dropout=attn_dropout,
             softmax_scale=None,
             use_top_left_mask=self._flash_attn_uses_top_left_mask,
diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py
index 76e3fbf514f6d6..90b815184b07a8 100644
--- a/src/transformers/models/phi3/modeling_phi3.py
+++ b/src/transformers/models/phi3/modeling_phi3.py
@@ -563,6 +563,7 @@ def forward(
             value_states,
             attention_mask,
             q_len,
+            position_ids=position_ids,
             dropout=attn_dropout,
             sliding_window=getattr(self.config, "sliding_window", None),
             use_top_left_mask=self._flash_attn_uses_top_left_mask,
diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py
index 68923ed4052dd8..1ff8896ae5f901 100644
--- a/src/transformers/models/qwen2/modeling_qwen2.py
+++ b/src/transformers/models/qwen2/modeling_qwen2.py
@@ -429,6 +429,7 @@ def forward(
             value_states,
             attention_mask,
             q_len,
+            position_ids=position_ids,
             dropout=dropout_rate,
             sliding_window=sliding_window,
             is_causal=self.is_causal,
diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
index d88b5c357e86da..54e91da3347dbc 100644
--- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
+++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -508,6 +508,7 @@ def forward(
             value_states,
             attention_mask,
             q_len,
+            position_ids=position_ids,
             dropout=dropout_rate,
             sliding_window=sliding_window,
             is_causal=self.is_causal,
diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py
index ea50a20edea8a8..3a3b6a9e05f117 100755
--- a/src/transformers/models/stablelm/modeling_stablelm.py
+++ b/src/transformers/models/stablelm/modeling_stablelm.py
@@ -606,6 +606,7 @@ def forward(
             value_states,
             attention_mask,
             q_len,
+            position_ids=position_ids,
             dropout=dropout_rate,
             use_top_left_mask=self._flash_attn_uses_top_left_mask,
             is_causal=self.is_causal,
diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py
index af532b139ca392..f2786f9df48a6b 100644
--- a/src/transformers/models/starcoder2/modeling_starcoder2.py
+++ b/src/transformers/models/starcoder2/modeling_starcoder2.py
@@ -404,6 +404,7 @@ def forward(
             value_states,
             attention_mask,
             q_len,
+            position_ids=position_ids,
             dropout=dropout_rate,
             sliding_window=getattr(self.config, "sliding_window", None),
             is_causal=self.is_causal,
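
The per-model changes above are a one-line addition each: they forward `position_ids` into `_flash_attention_forward`, which is what lets a packed batch reach the varlen path. A rough usage sketch of what this enables (not part of the patch; the checkpoint name is a placeholder and a FlashAttention-2-capable GPU is assumed):

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder checkpoint
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
).to("cuda")

# Two sequences (lengths 4 and 3) packed into one row: no attention_mask, no padding.
input_ids = torch.tensor([[5, 6, 7, 8, 9, 10, 11]], device="cuda")
position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2]], device="cuda")

with torch.no_grad():
    out = model(input_ids=input_ids, position_ids=position_ids)

# Because position_ids restart at 0 inside the row, _flash_attention_forward routes through
# flash_attn_varlen_func, so tokens never attend across the packed-sequence boundary.
print(out.logits.shape)  # torch.Size([1, 7, vocab_size])
```
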
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 19a945aec52799..abe5ddea2c2511 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -4327,6 +4327,78 @@ def test_flash_attn_2_fp32_ln(self):
             # with attention mask
             _ = model(dummy_input, attention_mask=dummy_attention_mask)
 
+    @require_flash_attn
+    @require_torch_gpu
+    @mark.flash_attn_test
+    @slow
+    def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
+        if not self.has_attentions:
+            self.skipTest(reason="Model architecture does not support attentions")
+
+        max_new_tokens = 30
+
+        for model_class in self.all_generative_model_classes:
+            if not model_class._supports_flash_attn_2:
+                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
+
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+            dummy_input = inputs_dict[model_class.main_input_name]
+            if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+                dummy_input = dummy_input.to(torch.float16)
+
+            # make sure that all models have enough positions for generation
+            if hasattr(config, "max_position_embeddings"):
+                config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
+            model = model_class(config)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+
+                assert 0 in inputs_dict["attention_mask"], "assert padding in testing inputs"
+                # ensure left padding, to adapt for some models
+                if 0 in inputs_dict["attention_mask"][:, -1]:
+                    inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
+                dummy_attention_mask = inputs_dict["attention_mask"]
+                inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.pad_token_id
+
+                model = (
+                    model_class.from_pretrained(
+                        tmpdirname,
+                        torch_dtype=torch.float16,
+                        attn_implementation="flash_attention_2",
+                        low_cpu_mem_usage=True,
+                    )
+                    .to(torch_device)
+                    .eval()
+                )
+
+                # flatten
+                padfree_inputs_dict = {
+                    k: v[dummy_attention_mask.bool()].unsqueeze(0)
+                    for k, v in inputs_dict.items()
+                    if not k == "attention_mask"
+                }
+                # add position_ids
+                padfree_inputs_dict["position_ids"] = (
+                    torch.cat([torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()])
+                    .long()
+                    .unsqueeze(0)
+                    .to(torch_device)
+                )
+
+                res_padded = model(**inputs_dict)
+                res_padfree = model(**padfree_inputs_dict)
+
+                logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
+                logits_padfree = res_padfree.logits[0]
+
+                torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), atol=0, rtol=0)
+                # acceptable numerical instability
+                tol = torch.finfo(torch.float16).eps
+                torch.testing.assert_close(logits_padded, logits_padfree, atol=tol, rtol=tol)
+
     @is_pt_tf_cross_test
     def test_tf_from_pt_safetensors(self):
         for model_class in self.all_model_classes:
diff --git a/tests/trainer/test_data_collator.py b/tests/trainer/test_data_collator.py
index 36e1813258d1a3..8c1f593ff4bcb8 100644
--- a/tests/trainer/test_data_collator.py
+++ b/tests/trainer/test_data_collator.py
@@ -26,6 +26,7 @@
     DataCollatorForSeq2Seq,
     DataCollatorForTokenClassification,
     DataCollatorForWholeWordMask,
+    DataCollatorWithFlattening,
     DataCollatorWithPadding,
     default_data_collator,
     is_tf_available,
@@ -1531,6 +1532,24 @@ def test_data_collator_with_padding(self):
         batch = data_collator(features)
         self.assertEqual(batch["input_ids"].shape, (2, 8))
 
+    def test_data_collator_with_flattening(self):
+        features = [
+            {"input_ids": [10, 11, 12]},
+            {"input_ids": [20, 21, 22, 23, 24, 25]},
+            {"input_ids": [30, 31, 32, 33, 34, 35, 36]},
+        ]
+
+        data_collator = DataCollatorWithFlattening(return_tensors="np")
+        batch = data_collator(features)
+        self.assertEqual(batch["input_ids"].shape, (1, 16))
+        self.assertEqual(
+            batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
+        )
+        self.assertNotIn("attention_mask", batch)
+        self.assertIn("position_ids", batch)
+        self.assertEqual(batch["position_ids"].shape, (1, 16))
+        self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
+
     def test_data_collator_for_token_classification(self):
         tokenizer = BertTokenizer(self.vocab_file)
         features = [
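
Finally, an end-to-end training sketch (not part of the patch) combining the new collator with a FlashAttention-2 model under `Trainer`; the model name, dataset, and hyperparameters are placeholders:

```python
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorWithFlattening,
    Trainer,
    TrainingArguments,
)

model_name = "mistralai/Mistral-7B-v0.1"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)

# Placeholder corpus; any dataset of pre-tokenized causal-LM examples works.
raw = load_dataset("imdb", split="train[:1%]")
tokenized = raw.map(lambda ex: tokenizer(ex["text"]), remove_columns=raw.column_names)

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="out", per_device_train_batch_size=4, bf16=True),
    train_dataset=tokenized,
    data_collator=DataCollatorWithFlattening(),  # packs each mini batch; no padding, no attention_mask
)
trainer.train()
```
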