Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make prepending of bos token configurable. #1114

Merged
merged 4 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions trl/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ class SFTTrainer(Trainer):
fine-tuning. Check out the original paper here: https://arxiv.org/abs/2310.05914 and the original code here: https://github.com/neelsjain/NEFTune
model_init_kwargs: (`Optional[Dict]`, *optional*):
Dict of Optional kwargs to pass when instantiating the model from a string
dataset_kwargs: (`Optional[Dict]`, *optional*):
Dict of Optional kwargs to pass when creating packed or non-packed datasets
"""

def __init__(
Expand Down Expand Up @@ -138,6 +140,7 @@ def __init__(
dataset_batch_size: int = 1000,
neftune_noise_alpha: Optional[float] = None,
model_init_kwargs: Optional[Dict] = None,
dataset_kwargs: Optional[Dict] = None,
):
if model_init_kwargs is None:
model_init_kwargs = {}
Expand Down Expand Up @@ -240,6 +243,8 @@ def make_inputs_require_grad(module, input, output):
if data_collator is None:
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

if dataset_kwargs is None:
dataset_kwargs = {}
if train_dataset is not None:
train_dataset = self._prepare_dataset(
train_dataset,
Expand All @@ -250,6 +255,7 @@ def make_inputs_require_grad(module, input, output):
formatting_func,
num_of_sequences,
chars_per_token,
**dataset_kwargs,
)
if eval_dataset is not None:
_multiple = isinstance(eval_dataset, dict)
Expand All @@ -264,6 +270,7 @@ def make_inputs_require_grad(module, input, output):
formatting_func,
num_of_sequences,
chars_per_token,
**dataset_kwargs,
)
if not _multiple:
eval_dataset = _eval_datasets["singleton"]
Expand Down Expand Up @@ -328,6 +335,8 @@ def _prepare_dataset(
formatting_func,
num_of_sequences,
chars_per_token,
append_concat_token=True,
add_special_tokens=True,
):
if dataset is None:
raise ValueError("The dataset should not be None")
Expand All @@ -338,7 +347,7 @@ def _prepare_dataset(

if not packing:
return self._prepare_non_packed_dataloader(
tokenizer, dataset, dataset_text_field, max_seq_length, formatting_func
tokenizer, dataset, dataset_text_field, max_seq_length, formatting_func, add_special_tokens
)

else:
Expand All @@ -350,10 +359,12 @@ def _prepare_dataset(
num_of_sequences,
chars_per_token,
formatting_func,
append_concat_token,
add_special_tokens,
)

def _prepare_non_packed_dataloader(
self, tokenizer, dataset, dataset_text_field, max_seq_length, formatting_func=None
self, tokenizer, dataset, dataset_text_field, max_seq_length, formatting_func=None, add_special_tokens=True
):
use_formatting_func = formatting_func is not None and dataset_text_field is None
self._dataset_sanity_checked = False
Expand All @@ -362,6 +373,7 @@ def _prepare_non_packed_dataloader(
def tokenize(element):
outputs = tokenizer(
element[dataset_text_field] if not use_formatting_func else formatting_func(element),
add_special_tokens=add_special_tokens,
truncation=True,
padding=False,
max_length=max_seq_length,
Expand Down Expand Up @@ -398,6 +410,8 @@ def _prepare_packed_dataloader(
num_of_sequences,
chars_per_token,
formatting_func=None,
append_concat_token=True,
add_special_tokens=True,
):
if dataset_text_field is not None or formatting_func is not None:
if tokenizer is None:
Expand All @@ -413,6 +427,8 @@ def _prepare_packed_dataloader(
num_of_sequences=num_of_sequences,
chars_per_token=chars_per_token,
eos_token_id=tokenizer.eos_token_id,
append_concat_token=append_concat_token,
add_special_tokens=add_special_tokens,
)

def data_generator(constant_length_iterator):
Expand Down
8 changes: 7 additions & 1 deletion trl/trainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,8 @@ class ConstantLengthDataset(IterableDataset):
Shuffle the examples before they are returned
append_concat_token ('bool', *optional*, defaults to True)
If true, appends `eos_token_id` at the end of each sample being packed.
add_special_tokens ('bool', *optional*, defaults to True)
If true, tokenizers adds special tokens to each sample being packed.
"""

def __init__(
Expand All @@ -376,6 +378,7 @@ def __init__(
eos_token_id=0,
shuffle=True,
append_concat_token=True,
add_special_tokens=True,
):
self.tokenizer = tokenizer

Expand All @@ -393,6 +396,7 @@ def __init__(
self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
self.shuffle = shuffle
self.append_concat_token = append_concat_token
self.add_special_tokens = add_special_tokens
if formatting_func is None:
self.formatting_func = lambda x: x[dataset_text_field]
else:
Expand Down Expand Up @@ -426,7 +430,9 @@ def __iter__(self):
else:
more_examples = False
break
tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
tokenized_inputs = self.tokenizer(buffer, add_special_tokens=self.add_special_tokens, truncation=False)[
"input_ids"
]
all_token_ids = []
for tokenized_input in tokenized_inputs:
if self.append_concat_token:
Expand Down
Loading