1. support flatting_packing (sketched below)
2. update the Mistral function call format
3. fix the greedy knapsack algorithm, which may cause #5443
4. avoid wrongly truncating supervised examples (#5426)
AlongWY committed Sep 17, 2024
1 parent 1a3e654 commit 558b983
Showing 11 changed files with 224 additions and 115 deletions.
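
For context, a minimal sketch of what flatting packing means here (illustrative only, not code from this commit): packed examples are concatenated without padding, and position_ids restart at zero for every example so the sequences stay distinguishable.

# Illustrative sketch, not part of this commit.
def flatten_pack(sequences):
    """Concatenate token-id lists without padding and restart position_ids per sequence."""
    input_ids, position_ids = [], []
    for ids in sequences:
        input_ids += ids
        position_ids += list(range(len(ids)))
    return {"input_ids": input_ids, "position_ids": position_ids}

# flatten_pack([[1, 2, 3], [4, 5]])
# -> {"input_ids": [1, 2, 3, 4, 5], "position_ids": [0, 1, 2, 0, 1]}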
1 change: 0 additions & 1 deletion src/llamafactory/cli.py
@@ -28,7 +28,6 @@
from .train.tuner import export_model, run_exp
from .webui.interface import run_web_demo, run_web_ui


USAGE = (
"-" * 70
+ "\n"
3 changes: 2 additions & 1 deletion src/llamafactory/data/__init__.py
@@ -17,17 +17,18 @@
MultiModalDataCollatorForSeq2Seq,
PairwiseDataCollatorWithPadding,
SFTDataCollatorWith4DAttentionMask,
SFTDataCollatorWithFlattingPacking,
)
from .data_utils import Role, split_dataset
from .loader import get_dataset
from .template import TEMPLATES, Template, get_template_and_fix_tokenizer


__all__ = [
"KTODataCollatorWithPadding",
"MultiModalDataCollatorForSeq2Seq",
"PairwiseDataCollatorWithPadding",
"SFTDataCollatorWith4DAttentionMask",
"SFTDataCollatorWithFlattingPacking",
"Role",
"split_dataset",
"get_dataset",
39 changes: 37 additions & 2 deletions src/llamafactory/data/collator.py
@@ -19,8 +19,7 @@
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence

import torch
from transformers import DataCollatorForSeq2Seq

from transformers import DataCollatorForSeq2Seq, DefaultDataCollator, default_data_collator, PreTrainedTokenizerBase

if TYPE_CHECKING:
from transformers import ProcessorMixin
@@ -120,6 +119,42 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
return features


@dataclass
class SFTDataCollatorWithFlattingPacking(DefaultDataCollator):
r"""
Data collator for flatting packing.
"""

tokenizer: PreTrainedTokenizerBase = None
label_pad_token_id: int = -100
template: Optional["Template"] = None
processor: Optional["ProcessorMixin"] = None
return_position_ids: bool = True

def __call__(self, features: Sequence[Dict[str, Any]], return_tensors=None) -> Dict[str, "torch.Tensor"]:
# TODO: multimodal inputs are not supported yet
if return_tensors is None:
return_tensors = self.return_tensors
is_labels_provided = "labels" in features[0]
ret = {"input_ids": [], "labels": []}
if self.return_position_ids:
ret.update({"position_ids": []})
for instances in features:
for input_ids, labels in zip(instances["input_ids"], instances["labels"]):
ret["input_ids"] += input_ids
if is_labels_provided:
ret["labels"] += [self.label_pad_token_id] + labels[1:]
else:
ret["labels"] += [self.label_pad_token_id] + input_ids[1:]
if self.return_position_ids:
ret["position_ids"] += list(range(len(input_ids)))

assert len(ret["input_ids"]) == len(ret["labels"])

features: Dict[str, "torch.Tensor"] = default_data_collator([ret], return_tensors)
return features


@dataclass
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
r"""
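
A rough usage sketch of the new collator (assuming a tokenizer object is in scope; the feature layout, with each feature holding lists of per-sequence input_ids and labels, is inferred from the packing code later in this diff):

# Sketch only; assumes each feature carries lists of already-packed sequences.
collator = SFTDataCollatorWithFlattingPacking(tokenizer=tokenizer, return_position_ids=True)
batch = collator([{"input_ids": [[1, 2, 3], [4, 5]], "labels": [[1, 2, 3], [4, 5]]}])
# batch["input_ids"][0]    -> tensor([1, 2, 3, 4, 5])
# batch["position_ids"][0] -> tensor([0, 1, 2, 0, 1])
# batch["labels"][0]       -> tensor([-100, 2, 3, -100, 5])  (first token of each sequence is masked)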
46 changes: 45 additions & 1 deletion src/llamafactory/data/formatter.py
@@ -23,7 +23,6 @@
from .data_utils import SLOTS
from .tool_utils import get_tool_utils


if TYPE_CHECKING:
from .tool_utils import FunctionCall

@@ -129,6 +128,51 @@ def apply(self, **kwargs) -> SLOTS:
return elements


@dataclass
class MistralFunctionFormatter(Formatter):
@override
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
functions: List[Tuple[str, str]] = []
try:
tool_calls = json.loads(content)
if not isinstance(tool_calls, list): # parallel function call
tool_calls = [tool_calls]

for tool_call in tool_calls:
functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))

except json.JSONDecodeError:
functions = []

elements = []
for name, arguments in functions:
elements.append(f""""{{"name":"{name}","arguments":{arguments}}}""")
elements = ["[TOOL_CALLS] [" + ", ".join(elements) + "]"]

return elements


@dataclass
class MistralObservationFormatter(Formatter):
def __post_init__(self):
self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots

@override
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
tool_results: List[str]
try:
tool_results = [json.dumps(result) for result in json.loads(content)]
except json.JSONDecodeError:
tool_results = []

elements = []
for content in tool_results:
elements.append(f"[TOOL_RESULTS] {{\"content\":{content}}}[/TOOL_RESULTS]")
return ["".join(elements)]


@dataclass
class ToolFormatter(Formatter):
def __post_init__(self):
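
For reference, a sketch of the strings the two Mistral formatters above are meant to produce (derived from the code in this diff, not from an official Mistral spec; the tool call content is a hypothetical example):

# Illustration only, based on the formatter code above.
# MistralFunctionFormatter.apply(content='{"name": "get_weather", "arguments": {"city": "Paris"}}')
# -> ['[TOOL_CALLS] [{"name":"get_weather","arguments":{"city": "Paris"}}]']
# MistralObservationFormatter wraps each JSON-encoded tool result as
# -> '[TOOL_RESULTS] {"content":"22C"}[/TOOL_RESULTS]'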
8 changes: 5 additions & 3 deletions src/llamafactory/data/preprocess.py
@@ -22,10 +22,10 @@
preprocess_packed_supervised_dataset,
preprocess_supervised_dataset,
print_supervised_dataset_example,
print_flatting_supervised_dataset_example,
)
from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example


if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin

@@ -78,8 +78,10 @@ def __init__(self, data, **kwargs):
processor=processor,
data_args=data_args,
)

print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
if data_args.packing and data_args.flatting_packing:
print_function = partial(print_flatting_supervised_dataset_example, tokenizer=tokenizer)
else:
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
elif stage == "rm":
preprocess_func = partial(
preprocess_pairwise_dataset,
6 changes: 6 additions & 0 deletions src/llamafactory/data/processors/processor_utils.py
@@ -28,6 +28,8 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
r"""
An efficient greedy algorithm with binary search for the knapsack problem.
"""
# filter out numbers that are larger than the capacity
numbers = [number for number in numbers if number <= capacity]
numbers.sort() # sort numbers in ascending order for binary search
knapsacks = []

@@ -43,6 +45,10 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
remaining_capacity -= numbers[index] # update the remaining capacity
current_knapsack.append(numbers.pop(index)) # add the number to knapsack

# avoid endless loop
if remaining_capacity == capacity:
break

knapsacks.append(current_knapsack)

return knapsacks
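
A quick illustration of the two guards added above, with hypothetical lengths and capacity:

# Illustration only. With capacity 10, the over-long item (12) is now filtered out
# up front instead of leaving a number that can never be packed, and a pass that
# packs nothing (remaining_capacity still equal to capacity) breaks out instead of
# looping forever.
greedy_knapsack([3, 7, 12, 5], capacity=10)
# -> something like [[7, 3], [5]]  (the 12 has been dropped)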
98 changes: 59 additions & 39 deletions src/llamafactory/data/processors/supervised.py
@@ -11,23 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import greedy_knapsack, infer_seqlen


if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin

from ...hparams import DataArguments
from ..mm_plugin import ImageInput, VideoInput
from ..template import Template


logger = get_logger(__name__)


@@ -48,18 +46,12 @@ def _encode_supervised_example(
messages = template.mm_plugin.process_messages(prompt + response, images, videos, processor)
input_ids, labels = template.mm_plugin.process_token_ids([], [], images, videos, tokenizer, processor)
encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
total_length = len(input_ids) + (1 if template.efficient_eos else 0)
if mask_history:
encoded_pairs = encoded_pairs[::-1] # high priority for last turns

for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
if total_length >= cutoff_len:
break

source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length)
source_ids = source_ids[:source_len]
target_ids = target_ids[:target_len]
total_length += source_len + target_len
source_len = len(source_ids)
target_len = len(target_ids)

if train_on_prompt:
source_label = source_ids
@@ -132,13 +124,16 @@ def preprocess_packed_supervised_dataset(
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[Any]]:
# TODO: use `position_ids` to achieve packing
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
valid_num = 0
invalid_num = 0
batch_input_ids, batch_labels, batch_images, batch_videos = [], [], [], []
lengths = []
length2indexes = defaultdict(list)

# one token is reserved for the padding token; flatting_packing does not need it
num_reserved = 0 if data_args.flatting_packing else 1
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
@@ -154,13 +149,13 @@
template=template,
tokenizer=tokenizer,
processor=processor,
cutoff_len=data_args.cutoff_len - 1, # reserved for the padding token
cutoff_len=data_args.cutoff_len - num_reserved,
train_on_prompt=data_args.train_on_prompt,
mask_history=data_args.mask_history,
)
length = len(input_ids)
if length > data_args.cutoff_len:
logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len))
if length > data_args.cutoff_len - num_reserved:
invalid_num += 1
else:
lengths.append(length)
length2indexes[length].append(valid_num)
@@ -170,36 +165,52 @@
batch_videos.append(examples["_videos"][i] or [])
valid_num += 1

if invalid_num > 0:
logger.warning(
"Dropped lengthy {} example with length > {}.".format(invalid_num, data_args.cutoff_len - num_reserved)
)

model_inputs = defaultdict(list)
knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - 1) # reserved for the padding token
knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - num_reserved) # reserve space for the padding token unless flatting_packing
for knapsack in knapsacks:
packed_input_ids, packed_attention_masks, packed_labels = [], [], []
packed_images, packed_videos = [], []
for i, length in enumerate(knapsack):
index = length2indexes[length].pop()
packed_input_ids += batch_input_ids[index]
packed_labels += batch_labels[index]
packed_images += batch_images[index]
packed_videos += batch_videos[index]
if data_args.neat_packing:
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
else:
packed_attention_masks += [1] * len(batch_input_ids[index])

if len(packed_input_ids) < data_args.cutoff_len:
pad_length = data_args.cutoff_len - len(packed_input_ids)
packed_input_ids += [tokenizer.pad_token_id] * pad_length
packed_labels += [IGNORE_INDEX] * pad_length
if data_args.neat_packing:
packed_attention_masks += [0] * pad_length
else:
packed_attention_masks += [1] * pad_length # more efficient flash_attn

if len(packed_input_ids) != data_args.cutoff_len:
raise ValueError("The length of packed example should be identical to the cutoff length.")

if data_args.flatting_packing:
for i, length in enumerate(knapsack):
index = length2indexes[length].pop()
packed_input_ids.append(batch_input_ids[index])
packed_labels.append(batch_labels[index])
packed_images.append(batch_images[index])
packed_videos.append(batch_videos[index])
else:
for i, length in enumerate(knapsack):
index = length2indexes[length].pop()
packed_input_ids += batch_input_ids[index]
packed_labels += batch_labels[index]
packed_images += batch_images[index]
packed_videos += batch_videos[index]
if data_args.neat_packing:
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
else:
packed_attention_masks += [1] * len(batch_input_ids[index])

# flatting_packing does not need attention masks
if len(packed_input_ids) < data_args.cutoff_len:
pad_length = data_args.cutoff_len - len(packed_input_ids)
packed_input_ids += [tokenizer.pad_token_id] * pad_length
packed_labels += [IGNORE_INDEX] * pad_length
if data_args.neat_packing:
packed_attention_masks += [0] * pad_length
else:
packed_attention_masks += [1] * pad_length # more efficient flash_attn

# flatting_packing does not need padding
if len(packed_input_ids) != data_args.cutoff_len:
raise ValueError("The length of packed example should be identical to the cutoff length.")
model_inputs["attention_mask"].append(packed_attention_masks)

model_inputs["input_ids"].append(packed_input_ids)
model_inputs["attention_mask"].append(packed_attention_masks)
model_inputs["labels"].append(packed_labels)
model_inputs["images"].append(packed_images or None)
model_inputs["videos"].append(packed_videos or None)
@@ -213,3 +224,12 @@ def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False)))


def print_flatting_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, itertools.chain(*example["labels"])))
input_ids = list(itertools.chain(*example["input_ids"]))
print("input_ids:\n{}".format(input_ids))
print("inputs:\n{}".format(tokenizer.decode(input_ids, skip_special_tokens=False)))
print("label_ids:\n{}".format(list(itertools.chain(*example["labels"]))))
print("labels:\n{}".format(tokenizer.decode(valid_labels), skip_special_tokens=False))
