diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index beb8aa7f558d..1f664f52372f 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,4 +1,4 @@
-exclude: 'slm/model_zoo/gpt-3'
+exclude: '(slm/model_zoo/gpt-3|csrc/third_party)'
 repos:
     # For Python files
     - repo: https://github.com/psf/black.git
@@ -61,4 +61,4 @@ repos:
         entry: python scripts/codestyle/check_dead_links.py
         language: python
         files: \.(md|markdown|rst)$
-        pass_filenames: true
\ No newline at end of file
+        pass_filenames: true
diff --git a/README.md b/README.md
index 7151207c6554..70a508d1b495 100644
--- a/README.md
+++ b/README.md
@@ -205,6 +205,22 @@ python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" run_finetune.py
 ```
 更多大模型全流程步骤，请参考[飞桨大模型套件](./llm)介绍。
+另外我们还提供了快速微调方式，无需 clone 源代码：
+
+```python
+from paddlenlp.trl import SFTConfig, SFTTrainer
+from datasets import load_dataset
+
+dataset = load_dataset("ZHUI/alpaca_demo", split="train")
+
+training_args = SFTConfig(output_dir="Qwen/Qwen2.5-0.5B-SFT", device="gpu")
+trainer = SFTTrainer(
+    args=training_args,
+    model="Qwen/Qwen2.5-0.5B",
+    train_dataset=dataset,
+)
+trainer.train()
+```
 
 更多 PaddleNLP 内容可参考：
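A note on the quickstart above: `SFTTrainer` (added later in this patch) also accepts an already constructed model and tokenizer instead of a model identifier string. A minimal sketch of that explicit setup, assuming the `Qwen/Qwen2.5-0.5B` weights and the demo dataset are downloadable; the `dtype` value and output path are illustrative choices, not part of this change:

```python
from datasets import load_dataset

from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer
from paddlenlp.trl import SFTConfig, SFTTrainer

# Build the model and tokenizer explicitly; passing a model_id string to SFTTrainer
# does the same thing internally via AutoModelForCausalLM.from_pretrained.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B", dtype="bfloat16")  # dtype is illustrative
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")

dataset = load_dataset("ZHUI/alpaca_demo", split="train")

training_args = SFTConfig(output_dir="./qwen2.5-0.5b-sft", max_seq_length=1024)
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
trainer.save_model(training_args.output_dir)
```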
diff --git a/paddlenlp/trl/extras/__init__.py b/paddlenlp/trl/extras/__init__.py
new file mode 100644
index 000000000000..fd05a9208165
--- /dev/null
+++ b/paddlenlp/trl/extras/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddlenlp/trl/extras/dataset_formatting.py b/paddlenlp/trl/extras/dataset_formatting.py
new file mode 100644
index 000000000000..786bcf80c3db
--- /dev/null
+++ b/paddlenlp/trl/extras/dataset_formatting.py
@@ -0,0 +1,130 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# https://github.com/huggingface/trl/blob/c10cc8995b6fd45f3a876ec98cade97251abe733/trl/extras/dataset_formatting.py#L74
+
+import logging
+from typing import Callable, Literal, Optional
+
+from datasets import Dataset, Value
+
+from ...transformers import AutoTokenizer
+
+FORMAT_MAPPING = {
+    "chatml": [{"content": Value(dtype="string", id=None), "role": Value(dtype="string", id=None)}],
+    "instruction": {"completion": Value(dtype="string", id=None), "prompt": Value(dtype="string", id=None)},
+    "paddlenlp": {"src": Value(dtype="string", id=None), "tgt": Value(dtype="string", id=None)},
+}
+
+
+def conversations_formatting_function(tokenizer: AutoTokenizer, messages_field: Literal["messages", "conversations"]):
+    r"""
+    Return a callable that formats a "messages"-style dataset by applying the
+    tokenizer's chat template to each conversation.
+    """
+
+    def format_dataset(examples):
+        if isinstance(examples[messages_field][0], list):
+            output_texts = []
+            for i in range(len(examples[messages_field])):
+                output_texts.append(tokenizer.apply_chat_template(examples[messages_field][i], tokenize=False))
+            return output_texts
+        else:
+            return tokenizer.apply_chat_template(examples[messages_field], tokenize=False)
+
+    return format_dataset
+
+
+def instructions_formatting_function(tokenizer: AutoTokenizer):
+    r"""
+    Return a callable that formats a prompt/completion "instructions" dataset by
+    applying the tokenizer's chat template to each pair.
+    """
+
+    def format_dataset(examples):
+        if isinstance(examples["prompt"], list):
+            output_texts = []
+            for i in range(len(examples["prompt"])):
+                converted_sample = [
+                    {"role": "user", "content": examples["prompt"][i]},
+                    {"role": "assistant", "content": examples["completion"][i]},
+                ]
+                output_texts.append(tokenizer.apply_chat_template(converted_sample, tokenize=False))
+            return output_texts
+        else:
+            converted_sample = [
+                {"role": "user", "content": examples["prompt"]},
+                {"role": "assistant", "content": examples["completion"]},
+            ]
+            return tokenizer.apply_chat_template(converted_sample, tokenize=False)
+
+    return format_dataset
+
+
+def paddlenlp_instructions_formatting_function(tokenizer: AutoTokenizer):
+    r"""
+    Return a callable that formats a PaddleNLP-style src/tgt dataset by applying
+    the tokenizer's chat template to each pair.
+    """
+
+    def format_dataset(examples):
+        if isinstance(examples["src"], list):
+            output_texts = []
+            for i in range(len(examples["src"])):
+                converted_sample = [
+                    {"role": "user", "content": examples["src"][i]},
+                    {"role": "assistant", "content": examples["tgt"][i]},
+                ]
+                output_texts.append(tokenizer.apply_chat_template(converted_sample, tokenize=False))
+            return output_texts
+        else:
+            converted_sample = [
+                {"role": "user", "content": examples["src"]},
+                {"role": "assistant", "content": examples["tgt"]},
+            ]
+            return tokenizer.apply_chat_template(converted_sample, tokenize=False)
+
+    return format_dataset
+
+
+def get_formatting_func_from_dataset(dataset: Dataset, tokenizer: AutoTokenizer) -> Optional[Callable]:
+    r"""
+    Finds the correct formatting function based on the dataset structure. Currently supported datasets are:
+    - `chatml` with [{"role": str, "content": str}]
+    - `instruction` with {"prompt": str, "completion": str}
+    - `paddlenlp` with {"src": str, "tgt": str}
+
+    Args:
+        dataset (Dataset): User dataset
+        tokenizer (AutoTokenizer): Tokenizer used for formatting
+
+    Returns:
+        Callable: Formatting function if the dataset format is supported else None
+    """
+    if isinstance(dataset, Dataset):
+        if "messages" in dataset.features:
+            if dataset.features["messages"] == FORMAT_MAPPING["chatml"]:
+                logging.info("Formatting dataset with chatml format")
+                return conversations_formatting_function(tokenizer, "messages")
+        if "conversations" in dataset.features:
+            if dataset.features["conversations"] == FORMAT_MAPPING["chatml"]:
+                logging.info("Formatting dataset with chatml format")
+                return conversations_formatting_function(tokenizer, "conversations")
+        elif dataset.features == FORMAT_MAPPING["instruction"]:
+            logging.info("Formatting dataset with instruction format")
+            return instructions_formatting_function(tokenizer)
+        elif dataset.features == FORMAT_MAPPING["paddlenlp"]:
+            return paddlenlp_instructions_formatting_function(tokenizer)
+
+    return None
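To illustrate the detection above: `get_formatting_func_from_dataset` compares the dataset's `features` schema against `FORMAT_MAPPING`. A small sketch with a toy dataset in the PaddleNLP `src`/`tgt` layout; it assumes the tokenizer ships a chat template (as the Qwen tokenizers do), and the example rows are made up:

```python
from datasets import Dataset

from paddlenlp.transformers import AutoTokenizer
from paddlenlp.trl.extras.dataset_formatting import get_formatting_func_from_dataset

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")

# Toy dataset whose inferred features equal FORMAT_MAPPING["paddlenlp"].
ds = Dataset.from_list(
    [
        {"src": "What is PaddleNLP?", "tgt": "An NLP library built on PaddlePaddle."},
        {"src": "Say hi.", "tgt": "Hi!"},
    ]
)

formatting_func = get_formatting_func_from_dataset(ds, tokenizer)
# The returned callable renders each src/tgt pair through the chat template.
print(formatting_func(ds[:2])[0])
```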
diff --git a/paddlenlp/trl/sft_config.py b/paddlenlp/trl/sft_config.py
index 56315e71ecf7..283244efbb5b 100644
--- a/paddlenlp/trl/sft_config.py
+++ b/paddlenlp/trl/sft_config.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from dataclasses import dataclass, field
-from typing import Optional
+from typing import Any, Dict, Optional
 
 from paddlenlp.trainer import TrainingArguments
 from paddlenlp.trainer.trainer_utils import IntervalStrategy
@@ -49,6 +49,19 @@ class SFTConfig(TrainingArguments):
         default="",
         metadata={"help": "Configs to unify hybrid parallel checkpoint.\n"},
     )
+    dataset_text_field: str = "text"
+    learning_rate: float = 2.0e-5
+    max_seq_length: int = field(
+        default=2048,
+        metadata={
+            "help": "The maximum length that model input tokens can have. When Zero Padding is set to True, it's also the maximum length for Zero Padding data stream"
+        },
+    )
+    dataset_num_proc: Optional[int] = None
+    dataset_batch_size: int = 1000
+    model_init_kwargs: Optional[Dict[str, Any]] = None
+    dataset_kwargs: Optional[Dict[str, Any]] = None
+    eval_packing: Optional[bool] = None
 
     def __post_init__(self):
         super().__post_init__()
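The fields added above steer how `SFTTrainer` preprocesses datasets. A brief sketch of a configuration that touches each of them; all values and paths are illustrative:

```python
from paddlenlp.trl import SFTConfig

config = SFTConfig(
    output_dir="./checkpoints/sft",                # required by TrainingArguments
    max_seq_length=1024,                           # truncation length used during tokenization
    dataset_text_field="text",                     # raw-text column used when no formatting function applies
    dataset_num_proc=4,                            # worker processes for datasets.Dataset.map
    dataset_batch_size=1000,                       # examples tokenized per map() batch
    dataset_kwargs={"add_special_tokens": False},  # forwarded to dataset preparation
)
```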
diff --git a/paddlenlp/trl/sft_trainer.py b/paddlenlp/trl/sft_trainer.py
index 58f1a9a2aa8c..47434ab15d45 100644
--- a/paddlenlp/trl/sft_trainer.py
+++ b/paddlenlp/trl/sft_trainer.py
@@ -13,27 +13,290 @@
 # limitations under the License.
 
 from __future__ import annotations
 
-from typing import Dict, Optional
+import warnings
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
+import datasets
 import numpy as np
 import paddle
 import paddle.distributed as dist
+import paddle.nn as nn
+from datasets import Dataset
 from paddle.distributed import fleet
 from paddle.io import BatchSampler, DataLoader, DistributedBatchSampler
 
-from paddlenlp.trainer import Trainer
-from paddlenlp.trainer.trainer_utils import has_length
-from paddlenlp.utils.log import logger
+from ..data import DataCollator, DataCollatorForSeq2Seq
+from ..trainer import Trainer
+from ..trainer.trainer_callback import TrainerCallback
+from ..trainer.trainer_utils import EvalPrediction, has_length
+from ..transformers import AutoModelForCausalLM, AutoTokenizer
+from ..transformers.model_utils import PretrainedModel
+from ..transformers.tokenizer_utils import PretrainedTokenizer
+from ..utils.log import logger
+from .extras.dataset_formatting import get_formatting_func_from_dataset
+from .sft_config import SFTConfig
 
 __all__ = ["SFTTrainer"]
 
 
 class SFTTrainer(Trainer):
-    def __init__(self, do_generation: bool, gen_args, data_args, **kwargs):
-        super().__init__(**kwargs)
+    def __init__(
+        self,
+        model: Union[PretrainedModel, nn.Layer] = None,
+        criterion: nn.Layer = None,
+        args: SFTConfig = None,
+        data_collator: Optional[DataCollator] = None,
+        train_dataset: Optional[Dataset] = None,
+        eval_dataset: Union[Dataset, Dict[str, Dataset]] = None,
+        tokenizer: Optional[PretrainedTokenizer] = None,
+        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
+        callbacks: Optional[List[TrainerCallback]] = None,
+        optimizers: Tuple[paddle.optimizer.Optimizer, paddle.optimizer.lr.LRScheduler] = (None, None),
+        preprocess_logits_for_metrics: Callable[[paddle.Tensor, paddle.Tensor], paddle.Tensor] = None,
+        do_generation: bool = False,
+        gen_args=None,
+        data_args=None,
+        formatting_func: Optional[Callable] = None,
+    ):
+        self.do_generation = do_generation
         self.gen_args = gen_args
         self.data_args = data_args
+        if self.do_generation:
+            assert gen_args is not None
+            assert data_args is not None
+
+        if args is None:
+            output_dir = "tmp_trainer"
+            warnings.warn(f"No `SFTConfig` passed, using `output_dir={output_dir}`.")
+            args = SFTConfig(output_dir=output_dir)
+        elif args is not None and args.__class__.__name__ == "TrainingArguments":
+            args_as_dict = args.to_dict()
+            # Manually copy token values as TrainingArguments.to_dict() redacts them
+            args_as_dict.update({k: getattr(args, k) for k in args_as_dict.keys() if k.endswith("_token")})
+            args = SFTConfig(**args_as_dict)
+
+        if getattr(args, "model_init_kwargs", None) is None:
+            model_init_kwargs = {}
+        elif not isinstance(model, str):
+            raise ValueError("You passed model_init_kwargs to the SFTConfig, but your model is already instantiated.")
+        else:
+            model_init_kwargs = args.model_init_kwargs
+            dtype = model_init_kwargs.get("dtype")
+            if dtype is not None:
+                # Convert to `paddle.dtype` if a string is passed
+                if isinstance(dtype, str) and dtype != "auto":
+                    dtype = getattr(paddle, dtype)
+                if dtype != "auto" and not isinstance(dtype, paddle.dtype):
+                    raise ValueError(
+                        f"Invalid `dtype` passed to the SFTConfig. Expected a string with either `paddle.dtype` or 'auto', but got {dtype}."
+                    )
+                model_init_kwargs["dtype"] = dtype
+
+        name_or_path = None
+        if isinstance(model, str):
+            warnings.warn(
+                "You passed a model_id to the SFTTrainer. This will automatically create an "
+                "`AutoModelForCausalLM` for you."
+            )
+            name_or_path = model
+            model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
+
+        if tokenizer is None:
+            if name_or_path is not None:
+                tokenizer = AutoTokenizer.from_pretrained(name_or_path)
+            else:
+                raise ValueError("Please pass a `tokenizer` when `model` is not a model identifier string.")
+        if getattr(tokenizer, "pad_token", None) is None:
+            tokenizer.pad_token = tokenizer.eos_token
+
+        if args.max_seq_length is None:
+            # to overcome some issues with broken tokenizers
+            args.max_seq_length = min(tokenizer.model_max_length, 1024)
+
+            warnings.warn(
+                f"You didn't pass a `max_seq_length` argument to the SFTTrainer, this will default to {args.max_seq_length}"
+            )
+
+        self.dataset_num_proc = args.dataset_num_proc
+        self.dataset_batch_size = args.dataset_batch_size
+
+        if args.dataset_kwargs is None:
+            args.dataset_kwargs = {}
+
+        if formatting_func is None:
+            # check if dataset has ChatML format or instruction format and is supported
+            # if not stays None
+            formatting_func = get_formatting_func_from_dataset(train_dataset, tokenizer)
+            # if a template is detected, we don't need to add special tokens again
+            if formatting_func is not None:
+                args.dataset_kwargs["add_special_tokens"] = False
+
+        # Pre-process the datasets only once per node. The remaining processes will use the cache.
+        with args.main_process_first():
+            if train_dataset is not None:
+                train_dataset = self._prepare_dataset(
+                    train_dataset,
+                    tokenizer,
+                    args.dataset_text_field,
+                    args.max_seq_length,
+                    formatting_func,
+                    remove_unused_columns=args.remove_unused_columns if args is not None else True,
+                    **args.dataset_kwargs,
+                )
+            if eval_dataset is not None:
+                _multiple = isinstance(eval_dataset, dict)
+                _eval_datasets = eval_dataset if _multiple else {"singleton": eval_dataset}
+
+                for _eval_dataset_name, _eval_dataset in _eval_datasets.items():
+                    _eval_datasets[_eval_dataset_name] = self._prepare_dataset(
+                        _eval_dataset,
+                        tokenizer,
+                        args.dataset_text_field,
+                        args.max_seq_length,
+                        formatting_func,
+                        remove_unused_columns=args.remove_unused_columns if args is not None else True,
+                        **args.dataset_kwargs,
+                    )
+                if not _multiple:
+                    eval_dataset = _eval_datasets["singleton"]
+
+        if data_collator is None:
+            data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)
+
+        super().__init__(
+            model=model,
+            criterion=criterion,
+            args=args,
+            data_collator=data_collator,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            tokenizer=tokenizer,
+            compute_metrics=compute_metrics,
+            callbacks=callbacks,
+            optimizers=optimizers,
+            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
+        )
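If the automatic format detection used in `__init__` does not recognize a dataset, a custom `formatting_func` can be supplied instead. A sketch with a hypothetical `instruction`/`output` column layout (rows and paths are made up); the function must return a list of strings because tokenization later runs in batched mode:

```python
from datasets import Dataset

from paddlenlp.trl import SFTConfig, SFTTrainer

# Hypothetical dataset whose columns match none of the entries in FORMAT_MAPPING.
ds = Dataset.from_list(
    [
        {"instruction": "Translate 'hello' to French.", "output": "bonjour"},
        {"instruction": "Name a prime number.", "output": "7"},
    ]
)


def formatting_func(examples):
    # Receives batched examples, so every field is a list; return one string per row.
    return [
        f"### Instruction:\n{q}\n### Response:\n{a}"
        for q, a in zip(examples["instruction"], examples["output"])
    ]


trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    args=SFTConfig(output_dir="./sft-custom-format", max_seq_length=512),
    train_dataset=ds,
    formatting_func=formatting_func,
)
```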
+
+    def _prepare_dataset(
+        self,
+        dataset,
+        tokenizer,
+        dataset_text_field: str,
+        max_seq_length,
+        formatting_func: Optional[Callable],
+        remove_unused_columns=True,
+        add_special_tokens=True,
+        skip_prepare_dataset=False,
+    ):
+
+        if dataset is None:
+            raise ValueError("The dataset should not be None")
+
+        if skip_prepare_dataset:
+            return dataset
+
+        # If the dataset is already preprocessed (tokenized), return as-is. Only works if dataset is
+        # a datasets.Dataset or datasets.IterableDataset -- not a paddle.io.Dataset
+        column_names = (
+            dataset.column_names if isinstance(dataset, (datasets.Dataset, datasets.IterableDataset)) else None
+        )
+        if column_names and "input_ids" in column_names:
+            if formatting_func is not None:
+                warnings.warn(
+                    "You passed a dataset that is already processed (contains an `input_ids` field) together with a valid formatting function. Therefore `formatting_func` will be ignored."
+                )
+
+            def formatting_func(x):
+                return x["input_ids"]
+
+            return dataset
+
+        # check if paddle dataset / dataloader and do nothing
+        # see https://github.com/huggingface/trl/pull/1468 for why datasets.IterableDataset needs a separate check
+        if isinstance(dataset, (paddle.io.IterableDataset, paddle.io.Dataset)) and not isinstance(
+            dataset, datasets.IterableDataset
+        ):
+            return dataset
+
+        return self._prepare_non_packed_dataloader(
+            tokenizer,
+            dataset,
+            dataset_text_field,
+            max_seq_length,
+            formatting_func,
+            add_special_tokens,
+            remove_unused_columns,
+        )
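As the early return above shows, a dataset that already carries an `input_ids` column bypasses tokenization entirely. A toy sketch of such a pre-tokenized dataset (the ids are made up, not produced by a real tokenizer):

```python
import datasets

# Already-tokenized toy rows; _prepare_dataset returns such a dataset unchanged
# because an "input_ids" column is present.
pre_tokenized = datasets.Dataset.from_dict(
    {
        "input_ids": [[101, 7592, 102]],
        "attention_mask": [[1, 1, 1]],
        "labels": [[7592, 102, -100]],
    }
)
```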
+
+    def _prepare_non_packed_dataloader(
+        self,
+        tokenizer,
+        dataset,
+        dataset_text_field: str,
+        max_seq_length,
+        formatting_func: Optional[Callable] = None,
+        add_special_tokens=True,
+        remove_unused_columns=True,
+    ):
+
+        # Inspired from: https://huggingface.co/learn/nlp-course/chapter7/6?fw=pt
+        def tokenize(element):
+            outputs = tokenizer(
+                element[dataset_text_field] if formatting_func is None else formatting_func(element),
+                add_special_tokens=add_special_tokens,
+                truncation=True,
+                padding=False,
+                max_length=max_seq_length,
+                return_overflowing_tokens=False,
+                return_length=False,
+            )
+
+            if formatting_func is not None and not isinstance(formatting_func(element), list):
+                raise ValueError(
+                    "The `formatting_func` should return a list of processed strings since it can lead to silent bugs."
+                )
+            labels = []
+            if tokenizer.pad_token_id is not None:
+                if isinstance(outputs["input_ids"][0], list):
+                    for x in outputs["input_ids"]:
+                        sublabels = []
+                        for y in x:
+                            sublabels.append(-100 if y == tokenizer.pad_token_id else y)
+                        sublabels.append(-100)
+                        sublabels = sublabels[1:]
+                        labels.append(sublabels)
+                else:
+                    for x in outputs["input_ids"]:
+                        labels.append(-100 if x == tokenizer.pad_token_id else x)
+                    labels.append(-100)
+                    labels = labels[1:]
+
+            return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"], "labels": labels}
+
+        signature_columns = ["input_ids", "labels", "attention_mask"]
+
+        if dataset.column_names is not None:  # None for IterableDataset
+            extra_columns = list(set(dataset.column_names) - set(signature_columns))
+        else:
+            extra_columns = []
+
+        if not remove_unused_columns and len(extra_columns) > 0:
+            warnings.warn(
+                "You passed `remove_unused_columns=False` on a non-packed dataset. This might create some issues with the default collator and lead to errors. If you want to "
+                f"inspect the other dataset columns (in this case {extra_columns}), you can subclass `DataCollatorForSeq2Seq` (the default collator here) and create your own data collator in order to inspect the unused dataset columns."
+            )
+
+        map_kwargs = {
+            "batched": True,
+            "remove_columns": dataset.column_names if remove_unused_columns else None,
+            "batch_size": self.dataset_batch_size,
+        }
+        if isinstance(dataset, datasets.Dataset):
+            map_kwargs["num_proc"] = self.dataset_num_proc  # this arg is not available for IterableDataset
+        tokenized_dataset = dataset.map(tokenize, **map_kwargs)
+
+        return tokenized_dataset
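The label construction inside `tokenize` amounts to shifting the input ids left by one and masking pad positions (and the final position) with -100. A standalone sketch of that logic on toy ids, assuming a pad id of 0 purely for illustration:

```python
# Toy reproduction of the label shifting performed inside `tokenize` above.
pad_token_id = 0  # illustrative pad id
input_ids = [11, 12, 0, 13]

labels = []
for token_id in input_ids:
    labels.append(-100 if token_id == pad_token_id else token_id)
labels.append(-100)  # sentinel appended after the last token
labels = labels[1:]  # drop the first position so labels[i] is the token following input_ids[i]

print(labels)  # [12, -100, 13, -100]
```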
 
     def prediction_step(
         self,