diff --git a/chatllms/utils/model_utils.py b/chatllms/utils/model_utils.py
index defb2f5..a053f5d 100644
--- a/chatllms/utils/model_utils.py
+++ b/chatllms/utils/model_utils.py
@@ -290,6 +290,27 @@ def get_logits_processor() -> LogitsProcessorList:
     return logits_processor
 
 
+# Avoid runtime error in model.generate(do_sample=True).
+class InvalidScoreLogitsProcessor(LogitsProcessor):
+    """Zero out NaN/Inf logits so that multinomial sampling cannot fail."""
+
+    def __call__(self, input_ids: torch.LongTensor,
+                 scores: torch.FloatTensor) -> torch.FloatTensor:
+        if torch.isnan(scores).any() or torch.isinf(scores).any():
+            scores.zero_()
+            scores[..., 0] = 1.0
+        return scores
+
+
+# NOTE(review): redefines get_logits_processor declared earlier in this
+# module; the older duplicate definition should be deleted.
+def get_logits_processor() -> LogitsProcessorList:
+    """Return a LogitsProcessorList containing InvalidScoreLogitsProcessor."""
+    logits_processor = LogitsProcessorList()
+    logits_processor.append(InvalidScoreLogitsProcessor())
+    return logits_processor
+
+
 def safe_save_model_for_hf_trainer(trainer: Trainer, output_dir: str):
     """Collects the state dict and dump to disk."""
     state_dict = trainer.model.state_dict()