Skip to content

Commit

Permalink
Merge pull request #88 from jianzhnie/dev
Browse files Browse the repository at this point in the history
update chatllms
  • Loading branch information
jianzhnie authored Aug 15, 2023
2 parents fbcb535 + 5e78fb7 commit 6fab54c
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion chatllms/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import torch
from transformers import PreTrainedModel, PreTrainedTokenizer, Trainer
from transformers.trainer_utils import get_last_checkpoint

from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList
from chatllms.data.data_utils import (DEFAULT_BOS_TOKEN, DEFAULT_EOS_TOKEN,
DEFAULT_PAD_TOKEN, DEFAULT_UNK_TOKEN)

Expand Down Expand Up @@ -273,6 +274,21 @@ def find_last_checkpoint(checkpoint_dir):
last_checkpoint = join(checkpoint_dir, f'checkpoint-{max_step}')
return last_checkpoint

# Avoid runtime error in model.generate(do_sample=True).
class InvalidScoreLogitsProcessor(LogitsProcessor):
def __call__(self, input_ids: torch.LongTensor,
scores: torch.FloatTensor) -> torch.FloatTensor:
if torch.isnan(scores).any() or torch.isinf(scores).any():
scores.zero_()
scores[..., 0] = 1.0
return scores


def get_logits_processor() -> LogitsProcessorList:
logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor())
return logits_processor


def safe_save_model_for_hf_trainer(trainer: Trainer, output_dir: str):
"""Collects the state dict and dump to disk."""
Expand Down

0 comments on commit 6fab54c

Please sign in to comment.