From 7cab73b57fe80283dd1a0166d1e23d9e602e7c4c Mon Sep 17 00:00:00 2001 From: ylfeng Date: Wed, 18 Sep 2024 21:31:56 +0800 Subject: [PATCH] 1. support flat_packing 2. fix knapsack, may cause #5443 3. avoid supervised examples wrongly truncation --- src/llamafactory/data/__init__.py | 3 +- src/llamafactory/data/collator.py | 39 ++++++- src/llamafactory/data/preprocess.py | 8 +- .../data/processors/processor_utils.py | 6 + .../data/processors/supervised.py | 103 +++++++++++------- src/llamafactory/hparams/data_args.py | 11 ++ src/llamafactory/train/sft/workflow.py | 39 +++++-- 7 files changed, 155 insertions(+), 54 deletions(-) diff --git a/src/llamafactory/data/__init__.py b/src/llamafactory/data/__init__.py index ea1a02f20c..2161422793 100644 --- a/src/llamafactory/data/__init__.py +++ b/src/llamafactory/data/__init__.py @@ -17,17 +17,18 @@ MultiModalDataCollatorForSeq2Seq, PairwiseDataCollatorWithPadding, SFTDataCollatorWith4DAttentionMask, + SFTDataCollatorWithFlattingPacking, ) from .data_utils import Role, split_dataset from .loader import get_dataset from .template import TEMPLATES, Template, get_template_and_fix_tokenizer - __all__ = [ "KTODataCollatorWithPadding", "MultiModalDataCollatorForSeq2Seq", "PairwiseDataCollatorWithPadding", "SFTDataCollatorWith4DAttentionMask", + "SFTDataCollatorWithFlattingPacking", "Role", "split_dataset", "get_dataset", diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index 92d86cc754..75fb937b88 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -19,8 +19,7 @@ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence import torch -from transformers import DataCollatorForSeq2Seq - +from transformers import DataCollatorForSeq2Seq, DefaultDataCollator, default_data_collator, PreTrainedTokenizerBase if TYPE_CHECKING: from transformers import ProcessorMixin @@ -120,6 +119,42 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tenso return features +@dataclass +class SFTDataCollatorWithFlattingPacking(DefaultDataCollator): + r""" + Data collator for flatting packing. + """ + + tokenizer: PreTrainedTokenizerBase = None + label_pad_token_id: int = -100 + template: Optional["Template"] = None + processor: Optional["ProcessorMixin"] = None + return_position_ids: bool = True + + def __call__(self, features: Sequence[Dict[str, Any]], return_tensors=None) -> Dict[str, "torch.Tensor"]: + # todo: not support multi-model + if return_tensors is None: + return_tensors = self.return_tensors + is_labels_provided = "labels" in features[0] + ret = {"input_ids": [], "labels": []} + if self.return_position_ids: + ret.update({"position_ids": []}) + for instances in features: + for input_ids, labels in zip(instances["input_ids"], instances["labels"]): + ret["input_ids"] += input_ids + if is_labels_provided: + ret["labels"] += [self.label_pad_token_id] + labels[1:] + else: + ret["labels"] += [self.label_pad_token_id] + input_ids[1:] + if self.return_position_ids: + ret["position_ids"] += list(range(len(input_ids))) + + assert len(ret["input_ids"]) == len(ret["labels"]) + + features: Dict[str, "torch.Tensor"] = default_data_collator([ret], return_tensors) + return features + + @dataclass class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq): r""" diff --git a/src/llamafactory/data/preprocess.py b/src/llamafactory/data/preprocess.py index 9f015b3823..860137aaf7 100644 --- a/src/llamafactory/data/preprocess.py +++ b/src/llamafactory/data/preprocess.py @@ -22,10 +22,10 @@ preprocess_packed_supervised_dataset, preprocess_supervised_dataset, print_supervised_dataset_example, + print_flatting_supervised_dataset_example, ) from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example - if TYPE_CHECKING: from transformers import PreTrainedTokenizer, ProcessorMixin @@ -78,8 +78,10 @@ def __init__(self, data, **kwargs): processor=processor, data_args=data_args, ) - - print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer) + if data_args.packing and data_args.flat_packing: + print_function = partial(print_flatting_supervised_dataset_example, tokenizer=tokenizer) + else: + print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer) elif stage == "rm": preprocess_func = partial( preprocess_pairwise_dataset, diff --git a/src/llamafactory/data/processors/processor_utils.py b/src/llamafactory/data/processors/processor_utils.py index 8e13d100bc..b7297df34f 100644 --- a/src/llamafactory/data/processors/processor_utils.py +++ b/src/llamafactory/data/processors/processor_utils.py @@ -28,6 +28,8 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]: r""" An efficient greedy algorithm with binary search for the knapsack problem. """ + # filter out numbers that are larger than the capacity + numbers = [number for number in numbers if number <= capacity] numbers.sort() # sort numbers in ascending order for binary search knapsacks = [] @@ -43,6 +45,10 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]: remaining_capacity -= numbers[index] # update the remaining capacity current_knapsack.append(numbers.pop(index)) # add the number to knapsack + # avoid endless loop + if remaining_capacity == capacity: + break + knapsacks.append(current_knapsack) return knapsacks diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py index 666256407a..84c7ed441b 100644 --- a/src/llamafactory/data/processors/supervised.py +++ b/src/llamafactory/data/processors/supervised.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import itertools from collections import defaultdict from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple @@ -19,7 +19,6 @@ from ...extras.logging import get_logger from .processor_utils import greedy_knapsack, infer_seqlen - if TYPE_CHECKING: from transformers import PreTrainedTokenizer, ProcessorMixin @@ -27,7 +26,6 @@ from ..mm_plugin import ImageInput, VideoInput from ..template import Template - logger = get_logger(__name__) @@ -53,13 +51,16 @@ def _encode_supervised_example( encoded_pairs = encoded_pairs[::-1] # high priority for last turns for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs): - if total_length >= cutoff_len: + if total_length >= cutoff_len and cutoff_len > 0: break - source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length) - source_ids = source_ids[:source_len] - target_ids = target_ids[:target_len] - total_length += source_len + target_len + if cutoff_len > 0: + source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length) + source_ids = source_ids[:source_len] + target_ids = target_ids[:target_len] + total_length += source_len + target_len + else: + source_len, target_len = len(source_ids), len(target_ids) if train_on_prompt: source_label = source_ids @@ -112,7 +113,7 @@ def preprocess_supervised_dataset( template=template, tokenizer=tokenizer, processor=processor, - cutoff_len=data_args.cutoff_len, + cutoff_len=data_args.cutoff_len if data_args.allow_truncation else 0, train_on_prompt=data_args.train_on_prompt, mask_history=data_args.mask_history, ) @@ -132,13 +133,16 @@ def preprocess_packed_supervised_dataset( processor: Optional["ProcessorMixin"], data_args: "DataArguments", ) -> Dict[str, List[Any]]: - # TODO: use `position_ids` to achieve packing # build inputs with format ` X1 Y1 X2 Y2 ` # and labels with format ` ... Y1 ... Y2 ` valid_num = 0 + invalid_num = 0 batch_input_ids, batch_labels, batch_images, batch_videos = [], [], [], [] lengths = [] length2indexes = defaultdict(list) + + # reserved for the padding token / flat_packing don't need + num_reserved = 0 if data_args.flat_packing else 1 for i in range(len(examples["_prompt"])): if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1: logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])) @@ -154,13 +158,13 @@ def preprocess_packed_supervised_dataset( template=template, tokenizer=tokenizer, processor=processor, - cutoff_len=data_args.cutoff_len - 1, # reserved for the padding token + cutoff_len=data_args.cutoff_len - num_reserved if data_args.allow_truncation else 0, train_on_prompt=data_args.train_on_prompt, mask_history=data_args.mask_history, ) length = len(input_ids) - if length > data_args.cutoff_len: - logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len)) + if length > data_args.cutoff_len - num_reserved: + invalid_num += 1 else: lengths.append(length) length2indexes[length].append(valid_num) @@ -170,36 +174,52 @@ def preprocess_packed_supervised_dataset( batch_videos.append(examples["_videos"][i] or []) valid_num += 1 + if invalid_num > 0: + logger.warning( + "Dropped lengthy {} example with length > {}.".format(invalid_num, data_args.cutoff_len - num_reserved) + ) + model_inputs = defaultdict(list) - knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - 1) # reserved for the padding token + knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - num_reserved) # reserved for the padding token for knapsack in knapsacks: packed_input_ids, packed_attention_masks, packed_labels = [], [], [] packed_images, packed_videos = [], [] - for i, length in enumerate(knapsack): - index = length2indexes[length].pop() - packed_input_ids += batch_input_ids[index] - packed_labels += batch_labels[index] - packed_images += batch_images[index] - packed_videos += batch_videos[index] - if data_args.neat_packing: - packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1 - else: - packed_attention_masks += [1] * len(batch_input_ids[index]) - - if len(packed_input_ids) < data_args.cutoff_len: - pad_length = data_args.cutoff_len - len(packed_input_ids) - packed_input_ids += [tokenizer.pad_token_id] * pad_length - packed_labels += [IGNORE_INDEX] * pad_length - if data_args.neat_packing: - packed_attention_masks += [0] * pad_length - else: - packed_attention_masks += [1] * pad_length # more efficient flash_attn - - if len(packed_input_ids) != data_args.cutoff_len: - raise ValueError("The length of packed example should be identical to the cutoff length.") + + if data_args.flat_packing: + for i, length in enumerate(knapsack): + index = length2indexes[length].pop() + packed_input_ids.append(batch_input_ids[index]) + packed_labels.append(batch_labels[index]) + packed_images.append(batch_images[index]) + packed_videos.append(batch_videos[index]) + else: + for i, length in enumerate(knapsack): + index = length2indexes[length].pop() + packed_input_ids += batch_input_ids[index] + packed_labels += batch_labels[index] + packed_images += batch_images[index] + packed_videos += batch_videos[index] + if data_args.neat_packing: + packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1 + else: + packed_attention_masks += [1] * len(batch_input_ids[index]) + + # flat_packing don't need attention masks + if len(packed_input_ids) < data_args.cutoff_len: + pad_length = data_args.cutoff_len - len(packed_input_ids) + packed_input_ids += [tokenizer.pad_token_id] * pad_length + packed_labels += [IGNORE_INDEX] * pad_length + if data_args.neat_packing: + packed_attention_masks += [0] * pad_length + else: + packed_attention_masks += [1] * pad_length # more efficient flash_attn + + # flatting packing don't need pad + if len(packed_input_ids) != data_args.cutoff_len: + raise ValueError("The length of packed example should be identical to the cutoff length.") + model_inputs["attention_mask"].append(packed_attention_masks) model_inputs["input_ids"].append(packed_input_ids) - model_inputs["attention_mask"].append(packed_attention_masks) model_inputs["labels"].append(packed_labels) model_inputs["images"].append(packed_images or None) model_inputs["videos"].append(packed_videos or None) @@ -213,3 +233,12 @@ def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: " print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) print("label_ids:\n{}".format(example["labels"])) print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False))) + + +def print_flatting_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None: + valid_labels = list(filter(lambda x: x != IGNORE_INDEX, itertools.chain(*example["labels"]))) + input_ids = list(itertools.chain(*example["input_ids"])) + print("input_ids:\n{}".format(input_ids)) + print("inputs:\n{}".format(tokenizer.decode(input_ids, skip_special_tokens=False))) + print("label_ids:\n{}".format(list(itertools.chain(*example["labels"])))) + print("labels:\n{}".format(tokenizer.decode(valid_labels), skip_special_tokens=False)) diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py index 1adcf2d0df..7c51060aa2 100644 --- a/src/llamafactory/hparams/data_args.py +++ b/src/llamafactory/hparams/data_args.py @@ -105,6 +105,14 @@ class DataArguments: default=False, metadata={"help": "Enable sequence packing without cross-attention."}, ) + flat_packing: bool = field( + default=False, + metadata={"help": "Enable sequence packing with flattening, need flash atten."} + ) + allow_truncation: bool = field( + default=False, + metadata={"help": "Allow truncation when processing supervised examples."} + ) tool_format: Optional[str] = field( default=None, metadata={"help": "Tool format to use for constructing function calling examples."}, @@ -148,3 +156,6 @@ def split_arg(arg): if self.mask_history and self.train_on_prompt: raise ValueError("`mask_history` is incompatible with `train_on_prompt`.") + + if self.neat_packing and self.flat_packing: + raise ValueError("`neat_packing` is incompatible with `flat_packing`.") diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py index 43a9aef16f..9df02be84b 100644 --- a/src/llamafactory/train/sft/workflow.py +++ b/src/llamafactory/train/sft/workflow.py @@ -17,21 +17,24 @@ from typing import TYPE_CHECKING, List, Optional -from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer +from ...data import SFTDataCollatorWith4DAttentionMask, SFTDataCollatorWithFlattingPacking, get_dataset, \ + get_template_and_fix_tokenizer from ...extras.constants import IGNORE_INDEX from ...extras.misc import get_logits_processor from ...extras.ploting import plot_loss +from ...extras.logging import get_logger from ...model import load_model, load_tokenizer from ..trainer_utils import create_modelcard_and_push from .metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor from .trainer import CustomSeq2SeqTrainer - if TYPE_CHECKING: from transformers import Seq2SeqTrainingArguments, TrainerCallback from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments +logger = get_logger(__name__) + def run_sft( model_args: "ModelArguments", @@ -50,15 +53,29 @@ def run_sft( if getattr(model, "is_quantized", False) and not training_args.do_train: setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction - data_collator = SFTDataCollatorWith4DAttentionMask( - template=template, - pad_to_multiple_of=8 if training_args.do_train else None, # for shift short attention - label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id, - block_diag_attn=model_args.block_diag_attn, - attn_implementation=getattr(model.config, "_attn_implementation", None), - compute_dtype=model_args.compute_dtype, - **tokenizer_module, - ) + if ( + data_args.packing and + data_args.flat_packing and + (getattr(model.config, "_attn_implementation", None) != "flash_attention_2") + ): + logger.warning("The `flat_packing` only support `flash_attention_2`! Maybe cause Out of memory!") + + if (data_args.packing and data_args.flat_packing): + data_collator = SFTDataCollatorWithFlattingPacking( + template=template, + label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id, + **tokenizer_module, + ) + else: + data_collator = SFTDataCollatorWith4DAttentionMask( + template=template, + pad_to_multiple_of=8 if training_args.do_train else None, # for shift short attention + label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id, + block_diag_attn=model_args.block_diag_attn, + attn_implementation=getattr(model.config, "_attn_implementation", None), + compute_dtype=model_args.compute_dtype, + **tokenizer_module, + ) # Override the decoding parameters of Seq2SeqTrainer training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len