Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support calling SFT training without cloning PaddleNLP #9516

Merged
merged 25 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
exclude: 'slm/model_zoo/gpt-3'
exclude: 'slm/model_zoo/gpt-3|csrc/third_party'
repos:
# For Python files
- repo: https://github.com/psf/black.git
Expand Down Expand Up @@ -61,4 +61,4 @@ repos:
entry: python scripts/codestyle/check_dead_links.py
language: python
files: \.(md|markdown|rst)$
pass_filenames: true
pass_filenames: true
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,22 @@ python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" run_finetune.py
```

更多大模型全流程步骤,请参考[飞桨大模型套件](./llm)介绍。
另外我们还提供了快速微调方式, 无需 clone 源代码:

```python
from paddlenlp.trl import SFTConfig, SFTTrainer
from datasets import load_dataset

dataset = load_dataset("ZHUI/alpaca_demo", split="train")

training_args = SFTConfig(output_dir="Qwen/Qwen2.5-0.5B-SFT", device="gpu")
trainer = SFTTrainer(
args=training_args,
model="Qwen/Qwen2.5-0.5B",
train_dataset=dataset,
)
trainer.train()
```

更多 PaddleNLP 内容可参考:

Expand Down
13 changes: 13 additions & 0 deletions paddlenlp/trl/extras/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
129 changes: 129 additions & 0 deletions paddlenlp/trl/extras/dataset_formatting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# https://github.com/huggingface/trl/blob/c10cc8995b6fd45f3a876ec98cade97251abe733/trl/extras/dataset_formatting.py#L74

import logging
from typing import Callable, Literal, Optional, Union

from datasets import Dataset, Value

from ...transformers import AutoTokenizer

# Maps each supported dataset schema name to the `datasets` feature signature
# used to auto-detect it in `get_formatting_func_from_dataset`:
# - "chatml": a list of {"content", "role"} string message dicts
# - "instruction": flat {"completion", "prompt"} string columns
# - "paddlenlp": flat {"src", "tgt"} string columns (PaddleNLP SFT data style)
FORMAT_MAPPING = {
    "chatml": [{"content": Value(dtype="string", id=None), "role": Value(dtype="string", id=None)}],
    "instruction": {"completion": Value(dtype="string", id=None), "prompt": Value(dtype="string", id=None)},
    "paddlenlp": {"src": Value(dtype="string", id=None), "tgt": Value(dtype="string", id=None)},
}


def conversations_formatting_function(tokenizer: "AutoTokenizer", messages_field: Literal["messages", "conversations"]):
    r"""
    Build a formatting callable for conversation-style ("chatml") datasets.

    Args:
        tokenizer (AutoTokenizer): Tokenizer whose chat template is applied.
        messages_field (Literal["messages", "conversations"]): Name of the
            dataset column that holds the list of role/content message dicts.

    Returns:
        Callable: A function mapping one example (or a batch of examples) to
        the chat-template-rendered string(s); no tokenization is performed.
    """

    def format_dataset(examples):
        if isinstance(examples[messages_field][0], list):
            # Batched call: the column holds a list of conversations.
            return [
                tokenizer.apply_chat_template(conversation, tokenize=False)
                for conversation in examples[messages_field]
            ]
        # Single example: the column holds one conversation.
        return tokenizer.apply_chat_template(examples[messages_field], tokenize=False)

    return format_dataset

Check warning on line 46 in paddlenlp/trl/extras/dataset_formatting.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trl/extras/dataset_formatting.py#L46

Added line #L46 was not covered by tests


def instructions_formatting_function(tokenizer: "AutoTokenizer"):
    r"""
    Build a formatting callable for "instruction" datasets with
    "prompt"/"completion" string columns.

    Args:
        tokenizer (AutoTokenizer): Tokenizer whose chat template is applied.

    Returns:
        Callable: A function mapping one example (or a batch of examples) to
        the chat-template-rendered string(s); no tokenization is performed.
    """

    def _as_conversation(prompt, completion):
        # Wrap a prompt/completion pair into the message schema expected by
        # `apply_chat_template`.
        return [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": completion},
        ]

    def format_dataset(examples):
        if isinstance(examples["prompt"], list):
            # Batched call: format each prompt/completion pair. Indexing both
            # columns by position preserves an IndexError on ragged batches.
            return [
                tokenizer.apply_chat_template(
                    _as_conversation(examples["prompt"][i], examples["completion"][i]), tokenize=False
                )
                for i in range(len(examples["prompt"]))
            ]
        # Single example.
        return tokenizer.apply_chat_template(
            _as_conversation(examples["prompt"], examples["completion"]), tokenize=False
        )

    return format_dataset

Check warning on line 72 in paddlenlp/trl/extras/dataset_formatting.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trl/extras/dataset_formatting.py#L72

Added line #L72 was not covered by tests


def paddlenlp_instructions_formatting_function(tokenizer: "AutoTokenizer"):
    r"""
    Build a formatting callable for PaddleNLP-style instruction datasets with
    "src"/"tgt" string columns.

    Args:
        tokenizer (AutoTokenizer): Tokenizer whose chat template is applied.

    Returns:
        Callable: A function mapping one example (or a batch of examples) to
        the chat-template-rendered string(s); no tokenization is performed.
    """

    def _as_conversation(src, tgt):
        # Wrap a src/tgt pair into the message schema expected by
        # `apply_chat_template`.
        return [
            {"role": "user", "content": src},
            {"role": "assistant", "content": tgt},
        ]

    def format_dataset(examples):
        if isinstance(examples["src"], list):
            # Batched call: format each src/tgt pair. Indexing both columns
            # by position preserves an IndexError on ragged batches.
            return [
                tokenizer.apply_chat_template(
                    _as_conversation(examples["src"][i], examples["tgt"][i]), tokenize=False
                )
                for i in range(len(examples["src"]))
            ]
        # Single example.
        return tokenizer.apply_chat_template(_as_conversation(examples["src"], examples["tgt"]), tokenize=False)

    return format_dataset

Check warning on line 98 in paddlenlp/trl/extras/dataset_formatting.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trl/extras/dataset_formatting.py#L98

Added line #L98 was not covered by tests


def get_formatting_func_from_dataset(dataset: Dataset, tokenizer: "AutoTokenizer") -> Optional[Callable]:
    r"""
    Finds the correct formatting function based on the dataset structure. Currently supported datasets are:
    - `ChatML` with [{"role": str, "content": str}] under a "messages" or "conversations" column
    - `instruction` with {"prompt": str, "completion": str} columns
    - `paddlenlp` with {"src": str, "tgt": str} columns

    Args:
        dataset (Dataset): User dataset
        tokenizer (AutoTokenizer): Tokenizer used for formatting

    Returns:
        Callable: Formatting function if the dataset format is supported else None
    """
    if isinstance(dataset, Dataset):
        if "messages" in dataset.features:
            if dataset.features["messages"] == FORMAT_MAPPING["chatml"]:
                logging.info("Formatting dataset with chatml format")
                return conversations_formatting_function(tokenizer, "messages")
        if "conversations" in dataset.features:
            if dataset.features["conversations"] == FORMAT_MAPPING["chatml"]:
                logging.info("Formatting dataset with chatml format")
                return conversations_formatting_function(tokenizer, "conversations")
        elif dataset.features == FORMAT_MAPPING["instruction"]:
            logging.info("Formatting dataset with instruction format")
            return instructions_formatting_function(tokenizer)
        elif dataset.features == FORMAT_MAPPING["paddlenlp"]:
            # Log the detected format, consistent with the other branches.
            logging.info("Formatting dataset with paddlenlp (src/tgt) format")
            return paddlenlp_instructions_formatting_function(tokenizer)

    # Unrecognized schema (or not a `Dataset`): caller falls back to its own handling.
    return None
15 changes: 14 additions & 1 deletion paddlenlp/trl/sft_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from dataclasses import dataclass, field
from typing import Optional
from typing import Any, Optional

from paddlenlp.trainer import TrainingArguments
from paddlenlp.trainer.trainer_utils import IntervalStrategy
Expand Down Expand Up @@ -49,6 +49,19 @@ class SFTConfig(TrainingArguments):
default="",
metadata={"help": "Configs to unify hybrid parallel checkpoint.\n"},
)
# Name of the dataset column that holds the raw text — presumably consumed by
# the SFT data pipeline; verify against SFTTrainer.
dataset_text_field: str = "text"
# SFT-specific default learning rate, overriding the TrainingArguments default.
learning_rate: float = 2.0e-5
max_seq_length: int = field(
    default=2048,
    metadata={
        "help": "The maximum length that model input tokens can have. When Zero Padding is set to True, it's also the maximum length for Zero Padding data stream"
    },
)
# Worker-process count for dataset preprocessing (None presumably means single-process).
dataset_num_proc: Optional[int] = None
# Batch size used when mapping/formatting the dataset.
dataset_batch_size: int = 1000
# Extra kwargs forwarded when instantiating the model from a name/path.
# NOTE(review): `dict[str, Any]` is evaluated at class-creation time and needs
# Python >= 3.9 unless `from __future__ import annotations` is in effect — confirm
# against the project's minimum supported Python version.
model_init_kwargs: Optional[dict[str, Any]] = None
# Extra kwargs forwarded to dataset preparation — TODO confirm consumer.
dataset_kwargs: Optional[dict[str, Any]] = None
# If not None, presumably overrides sequence packing for the eval dataset — verify.
eval_packing: Optional[bool] = None

def __post_init__(self):
super().__post_init__()
Expand Down
Loading