diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py
index b1aaeb350..f30d0e383 100644
--- a/src/axolotl/prompt_tokenizers.py
+++ b/src/axolotl/prompt_tokenizers.py
@@ -6,7 +6,7 @@
 import logging
 from typing import Dict, List, Tuple, Union
 
-from transformers import PreTrainedTokenizer
+from transformers import BatchEncoding, PreTrainedTokenizer
 
 from axolotl.prompters import IGNORE_TOKEN_ID
 
@@ -66,14 +66,21 @@ def _get_assistant_token(self):
             pass
         return False
 
-    def _tokenize(self, prompt: str, add_eos_token=True, strip_bos_token=False):
-        result = self.tokenizer(
-            prompt,
-            truncation=True,
-            max_length=self.sequence_len,
-            padding=False,
-            return_tensors=None,
-        )
+    def _tokenize(
+        self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
+    ) -> BatchEncoding:
+        result: BatchEncoding
+        if not prompt.strip():
+            LOG.warning("Empty text requested for tokenization.")
+            result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
+        else:
+            result = self.tokenizer(
+                prompt,
+                truncation=True,
+                max_length=self.sequence_len,
+                padding=False,
+                return_tensors=None,
+            )
         if len(result["input_ids"]) == 0:
             LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
         if (
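The patch short-circuits whitespace-only prompts before they reach the tokenizer and returns an empty but well-formed `BatchEncoding`, so downstream `len(result["input_ids"])` checks keep working. One motivation for the explicit empty result: many tokenizers still emit special tokens (e.g., BOS) even for an empty string, so calling the tokenizer would not produce a truly empty encoding. Below is a minimal standalone sketch of the same guard; the `tokenize_guarded` helper and the `gpt2` checkpoint are illustrative stand-ins, not part of the patch.

```python
from transformers import AutoTokenizer, BatchEncoding


def tokenize_guarded(tokenizer, prompt: str, max_length: int = 2048) -> BatchEncoding:
    # Mirror of the guard above: skip the tokenizer entirely for
    # whitespace-only input and return an empty-but-valid BatchEncoding.
    if not prompt.strip():
        return BatchEncoding(data={"input_ids": [], "attention_mask": []})
    return tokenizer(
        prompt,
        truncation=True,
        max_length=max_length,
        padding=False,
        return_tensors=None,
    )


tok = AutoTokenizer.from_pretrained("gpt2")  # any HF tokenizer works here
empty = tokenize_guarded(tok, "   ")
print(len(empty["input_ids"]))  # 0 -- the empty-result warning branch still fires
```

Because `BatchEncoding` is dict-like, callers that index `result["input_ids"]` or `result["attention_mask"]` need no changes for the empty case.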