Make prepending of bos token configurable. (#1114)
* make prepending of bos token configurable.

* address comments

* fix bug

Co-Authored-By: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/sft_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
pacman100 and younesbelkada authored Dec 20, 2023
1 parent f100ca3 commit f2acd82
Showing 2 changed files with 25 additions and 3 deletions.
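
Usage note: a minimal sketch (not part of this commit) of how the new `dataset_kwargs` argument is meant to be used. The model id, toy dataset, and output directory below are placeholders; passing dataset_kwargs={"add_special_tokens": False} is what turns off the unconditional prepending of BOS/special tokens during dataset preparation.

from datasets import Dataset
from transformers import TrainingArguments
from trl import SFTTrainer

# Placeholder dataset with a plain "text" column.
train_dataset = Dataset.from_dict({"text": ["Hello world", "TRL makes SFT easy"]})

trainer = SFTTrainer(
    "facebook/opt-350m",  # placeholder model id; strings are instantiated via model_init_kwargs
    args=TrainingArguments(output_dir="sft-out", per_device_train_batch_size=2),
    train_dataset=train_dataset,
    dataset_text_field="text",
    max_seq_length=512,
    packing=False,
    # New in this commit: forwarded to dataset preparation so that prepending
    # of BOS/special tokens is no longer hard-coded.
    dataset_kwargs={"add_special_tokens": False},
)
trainer.train()
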
20 changes: 18 additions & 2 deletions trl/trainer/sft_trainer.py
@@ -111,6 +111,8 @@ class SFTTrainer(Trainer):
             fine-tuning. Check out the original paper here: https://arxiv.org/abs/2310.05914 and the original code here: https://github.com/neelsjain/NEFTune
         model_init_kwargs: (`Optional[Dict]`, *optional*):
             Dict of Optional kwargs to pass when instantiating the model from a string
+        dataset_kwargs: (`Optional[Dict]`, *optional*):
+            Dict of Optional kwargs to pass when creating packed or non-packed datasets
     """

     def __init__(
@@ -138,6 +140,7 @@ def __init__(
         dataset_batch_size: int = 1000,
         neftune_noise_alpha: Optional[float] = None,
         model_init_kwargs: Optional[Dict] = None,
+        dataset_kwargs: Optional[Dict] = None,
     ):
         if model_init_kwargs is None:
             model_init_kwargs = {}
@@ -240,6 +243,8 @@ def make_inputs_require_grad(module, input, output):
         if data_collator is None:
             data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

+        if dataset_kwargs is None:
+            dataset_kwargs = {}
         if train_dataset is not None:
             train_dataset = self._prepare_dataset(
                 train_dataset,
@@ -250,6 +255,7 @@ def make_inputs_require_grad(module, input, output):
                 formatting_func,
                 num_of_sequences,
                 chars_per_token,
+                **dataset_kwargs,
             )
         if eval_dataset is not None:
             _multiple = isinstance(eval_dataset, dict)
@@ -264,6 +270,7 @@ def make_inputs_require_grad(module, input, output):
                     formatting_func,
                     num_of_sequences,
                     chars_per_token,
+                    **dataset_kwargs,
                 )
             if not _multiple:
                 eval_dataset = _eval_datasets["singleton"]
@@ -328,6 +335,8 @@ def _prepare_dataset(
         formatting_func,
         num_of_sequences,
         chars_per_token,
+        append_concat_token=True,
+        add_special_tokens=True,
     ):
         if dataset is None:
             raise ValueError("The dataset should not be None")
@@ -338,7 +347,7 @@

         if not packing:
             return self._prepare_non_packed_dataloader(
-                tokenizer, dataset, dataset_text_field, max_seq_length, formatting_func
+                tokenizer, dataset, dataset_text_field, max_seq_length, formatting_func, add_special_tokens
             )

         else:
@@ -350,10 +359,12 @@
                 num_of_sequences,
                 chars_per_token,
                 formatting_func,
+                append_concat_token,
+                add_special_tokens,
             )

     def _prepare_non_packed_dataloader(
-        self, tokenizer, dataset, dataset_text_field, max_seq_length, formatting_func=None
+        self, tokenizer, dataset, dataset_text_field, max_seq_length, formatting_func=None, add_special_tokens=True
     ):
         use_formatting_func = formatting_func is not None and dataset_text_field is None
         self._dataset_sanity_checked = False
@@ -362,6 +373,7 @@ def _prepare_non_packed_dataloader(
         def tokenize(element):
             outputs = tokenizer(
                 element[dataset_text_field] if not use_formatting_func else formatting_func(element),
+                add_special_tokens=add_special_tokens,
                 truncation=True,
                 padding=False,
                 max_length=max_seq_length,
@@ -398,6 +410,8 @@ def _prepare_packed_dataloader(
         num_of_sequences,
         chars_per_token,
         formatting_func=None,
+        append_concat_token=True,
+        add_special_tokens=True,
     ):
         if dataset_text_field is not None or formatting_func is not None:
             if tokenizer is None:
@@ -413,6 +427,8 @@
                 num_of_sequences=num_of_sequences,
                 chars_per_token=chars_per_token,
                 eos_token_id=tokenizer.eos_token_id,
+                append_concat_token=append_concat_token,
+                add_special_tokens=add_special_tokens,
             )

             def data_generator(constant_length_iterator):
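
Note on the non-packed path: the new flag is forwarded verbatim to the tokenizer call inside `tokenize`. The snippet below (not from the diff; the checkpoint name is a placeholder for any tokenizer that prepends a BOS token by default, e.g. Llama-style tokenizers) illustrates what `add_special_tokens` toggles:

from transformers import AutoTokenizer

# Placeholder checkpoint; the only assumption is that its tokenizer prepends a BOS token.
tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

with_special = tok("Hello world", add_special_tokens=True)["input_ids"]
without_special = tok("Hello world", add_special_tokens=False)["input_ids"]

print(with_special[0] == tok.bos_token_id)     # expected: True, BOS is prepended
print(without_special[0] == tok.bos_token_id)  # expected: False, only the text's own tokens
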
8 changes: 7 additions & 1 deletion trl/trainer/utils.py
@@ -361,6 +361,8 @@ class ConstantLengthDataset(IterableDataset):
             Shuffle the examples before they are returned
         append_concat_token ('bool', *optional*, defaults to True)
             If true, appends `eos_token_id` at the end of each sample being packed.
+        add_special_tokens ('bool', *optional*, defaults to True)
+            If true, the tokenizer adds special tokens to each sample being packed.
     """

     def __init__(
@@ -376,6 +378,7 @@ def __init__(
         eos_token_id=0,
         shuffle=True,
         append_concat_token=True,
+        add_special_tokens=True,
     ):
         self.tokenizer = tokenizer

@@ -393,6 +396,7 @@ def __init__(
         self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
         self.shuffle = shuffle
         self.append_concat_token = append_concat_token
+        self.add_special_tokens = add_special_tokens
         if formatting_func is None:
             self.formatting_func = lambda x: x[dataset_text_field]
         else:
@@ -426,7 +430,9 @@ def __iter__(self):
                     else:
                         more_examples = False
                         break
-            tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
+            tokenized_inputs = self.tokenizer(buffer, add_special_tokens=self.add_special_tokens, truncation=False)[
+                "input_ids"
+            ]
             all_token_ids = []
             for tokenized_input in tokenized_inputs:
                 if self.append_concat_token:
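
Note on the packed path: both new flags end up on ConstantLengthDataset. A rough sketch of driving it directly under the new options (toy data; GPT-2 is used only to keep the example small and self-contained, and it adds no BOS itself, so this mainly shows where the knobs live):

from datasets import Dataset
from transformers import AutoTokenizer
from trl.trainer.utils import ConstantLengthDataset

tokenizer = AutoTokenizer.from_pretrained("gpt2")
data = Dataset.from_dict({"text": ["TRL packs training samples into fixed-length chunks. " * 8] * 4})

packed = ConstantLengthDataset(
    tokenizer,
    data,
    dataset_text_field="text",
    seq_length=16,
    infinite=False,
    eos_token_id=tokenizer.eos_token_id,
    append_concat_token=False,  # do not insert eos_token_id between concatenated samples
    add_special_tokens=False,   # do not let the tokenizer add special tokens to each sample
)

for example in packed:
    # Each yielded element is a dict of LongTensors of length seq_length.
    print(example["input_ids"].shape, example["labels"].shape)
    break
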
