Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix extra BOS token in front of response for some tokenizers #1003

Merged
merged 9 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion llmfoundry/data/finetuning/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,29 @@ def _slice_chat_formatted_example(
return prompt, response


def _tokenize_with_bos_removal(tokenizer: PreTrainedTokenizerBase, text: str,
text_target: str) -> TokenizedExample:
"""Tokenizes the prompt and response using the provided tokenizer.

Args:
tokenizer (PreTrainedTokenizerBase): The tokenizer to use for tokenization.
prompt (str): The prompt to tokenize.
response (str): The response to tokenize.

Returns:
TokenizedExample: The tokenized example.
"""
tokenized_sample = tokenizer(text=text, text_target=text_target)

# Remove the BOS token from the start of the labels if it was automatically added
if hasattr(tokenizer, 'add_bos_token') and tokenizer.add_bos_token:
if tokenizer.bos_token_id is not None and tokenized_sample['labels'][
0] == tokenizer.bos_token_id:
tokenized_sample['labels'] = tokenized_sample['labels'][1:]

return tokenized_sample


def _tokenize_chat_formatted_example(
example: ChatFormattedDict,
tokenizer: PreTrainedTokenizerBase) -> TokenizedExample:
Expand Down Expand Up @@ -246,7 +269,11 @@ def _tokenize_prompt_response_formatted_example(
f'Unable to tokenize example because {response_key} was not a string. {example=}'
)

return tokenizer(text=prompt, text_target=response)
return _tokenize_with_bos_removal(
tokenizer=tokenizer,
text=prompt,
text_target=response,
)


def tokenize_formatted_example(
Expand Down
27 changes: 27 additions & 0 deletions tests/data/test_template_tokenization.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,30 @@ def test_tokenize_instruct_example_well_formed():
tokenized_example = tokenize_formatted_example(example, tokenizer)
assert 'input_ids' in tokenized_example
assert 'labels' in tokenized_example


def test_tokenize_no_labels_bos_pr():
# This tokenizer automatically adds bos tokens
tokenizer = transformers.AutoTokenizer.from_pretrained(
'mistralai/Mixtral-8x7B-v0.1')

dakinggg marked this conversation as resolved.
Show resolved Hide resolved
example = {'prompt': 'prompt', 'response': 'response'}

assert tokenizer.add_bos_token == True

tokenized_example = tokenize_formatted_example(example, tokenizer)

assert len(tokenized_example['labels']) == 1
assert tokenized_example['labels'][0] != tokenizer.bos_token_id
assert tokenized_example['input_ids'][0] == tokenizer.bos_token_id

# This tokenizer does not have the add_bos_token attribute
tokenizer = transformers.AutoTokenizer.from_pretrained('mosaicml/mpt-7b')

assert not hasattr(tokenizer, 'add_bos_token')

tokenized_example = tokenize_formatted_example(example, tokenizer)

assert len(tokenized_example['labels']) == 1
dakinggg marked this conversation as resolved.
Show resolved Hide resolved
assert tokenized_example['labels'][0] != tokenizer.bos_token_id
assert tokenized_example['input_ids'][0] != tokenizer.bos_token_id
Loading