diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py
index 126ed43812..8faddb2825 100644
--- a/llmfoundry/data/finetuning/tasks.py
+++ b/llmfoundry/data/finetuning/tasks.py
@@ -195,6 +195,29 @@ def _slice_chat_formatted_example(
     return prompt, response
 
 
+def _tokenize_with_bos_removal(tokenizer: PreTrainedTokenizerBase, text: str,
+                               text_target: str) -> TokenizedExample:
+    """Tokenizes the prompt and response using the provided tokenizer.
+
+    Args:
+        tokenizer (PreTrainedTokenizerBase): The tokenizer to use for tokenization.
+        text (str): The prompt text to tokenize.
+        text_target (str): The response text to tokenize.
+
+    Returns:
+        TokenizedExample: The tokenized example.
+    """
+    tokenized_sample = tokenizer(text=text, text_target=text_target)
+
+    # Remove the BOS token from the start of the labels if it was automatically added
+    if hasattr(tokenizer, 'add_bos_token') and tokenizer.add_bos_token:
+        if tokenizer.bos_token_id is not None and tokenized_sample['labels'][
+                0] == tokenizer.bos_token_id:
+            tokenized_sample['labels'] = tokenized_sample['labels'][1:]
+
+    return tokenized_sample
+
+
 def _tokenize_chat_formatted_example(
         example: ChatFormattedDict,
         tokenizer: PreTrainedTokenizerBase) -> TokenizedExample:
@@ -246,7 +269,11 @@ def _tokenize_prompt_response_formatted_example(
             f'Unable to tokenize example because {response_key} was not a string. {example=}'
         )
 
-    return tokenizer(text=prompt, text_target=response)
+    return _tokenize_with_bos_removal(
+        tokenizer=tokenizer,
+        text=prompt,
+        text_target=response,
+    )
 
 
 def tokenize_formatted_example(
diff --git a/tests/data/test_template_tokenization.py b/tests/data/test_template_tokenization.py
index 5491b94521..fdaf30ccc5 100644
--- a/tests/data/test_template_tokenization.py
+++ b/tests/data/test_template_tokenization.py
@@ -163,3 +163,30 @@ def test_tokenize_instruct_example_well_formed():
     tokenized_example = tokenize_formatted_example(example, tokenizer)
     assert 'input_ids' in tokenized_example
     assert 'labels' in tokenized_example
+
+
+def test_tokenize_no_labels_bos_pr():
+    # This tokenizer automatically adds bos tokens
+    tokenizer = transformers.AutoTokenizer.from_pretrained(
+        'mistralai/Mixtral-8x7B-v0.1')
+
+    example = {'prompt': 'prompt', 'response': 'response'}
+
+    assert tokenizer.add_bos_token == True
+
+    tokenized_example = tokenize_formatted_example(example, tokenizer)
+
+    assert len(tokenized_example['labels']) == 1
+    assert tokenized_example['labels'][0] != tokenizer.bos_token_id
+    assert tokenized_example['input_ids'][0] == tokenizer.bos_token_id
+
+    # This tokenizer does not have the add_bos_token attribute
+    tokenizer = transformers.AutoTokenizer.from_pretrained('mosaicml/mpt-7b')
+
+    assert not hasattr(tokenizer, 'add_bos_token')
+
+    tokenized_example = tokenize_formatted_example(example, tokenizer)
+
+    assert len(tokenized_example['labels']) == 1
+    assert tokenized_example['labels'][0] != tokenizer.bos_token_id
+    assert tokenized_example['input_ids'][0] != tokenizer.bos_token_id