diff --git a/trinity/common/config.py b/trinity/common/config.py index 982d6049b3..dca1b22e9d 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -668,9 +668,15 @@ def _check_buffer(self) -> None: # noqa: C901 from transformers import AutoTokenizer try: - self.buffer.pad_token_id = AutoTokenizer.from_pretrained( - self.model.model_path - ).pad_token_id + tokenizer = AutoTokenizer.from_pretrained(self.model.model_path) + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + logger.warning( + f"tokenizer.pad_token_id is None. Now set to {tokenizer.eos_token_id}", + stacklevel=1, + ) + self.buffer.pad_token_id = tokenizer.pad_token_id + except Exception: logger.warning(f"Failed to get pad token id from model {self.model.model_path}") self.buffer.pad_token_id = 0