Skip to content

Commit

Permalink
Merge branch 'dev' of github.com:jianzhnie/Efficient-Tuning-LLMs into…
Browse files Browse the repository at this point in the history
… dev
  • Loading branch information
jianzhnie committed Sep 11, 2023
2 parents bfd1207 + 5e78fb7 commit 4966498
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion chatllms/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList
from transformers.trainer_utils import get_last_checkpoint

from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList
from chatllms.data.data_utils import (DEFAULT_BOS_TOKEN, DEFAULT_EOS_TOKEN,
DEFAULT_PAD_TOKEN, DEFAULT_UNK_TOKEN)

Expand Down Expand Up @@ -275,6 +276,21 @@ def find_last_checkpoint(checkpoint_dir):
last_checkpoint = join(checkpoint_dir, f'checkpoint-{max_step}')
return last_checkpoint

# Avoid runtime error in model.generate(do_sample=True).
class InvalidScoreLogitsProcessor(LogitsProcessor):
    """Logits processor that replaces non-finite scores with fixed values.

    Intended to avoid a runtime error in ``model.generate(do_sample=True)``.
    """

    def __call__(self, input_ids: torch.LongTensor,
                 scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return ``scores``, overwritten in place if any entry is NaN/inf.

        Args:
            input_ids: Not read by this processor; accepted to match the
                logits-processor call signature.
            scores: Score tensor to check. It is modified in place when it
                contains at least one NaN or infinite value.

        Returns:
            The same ``scores`` tensor object that was passed in.
        """
        # Fast path: every entry is finite, so nothing needs repairing.
        if torch.isfinite(scores).all():
            return scores

        # At least one NaN/inf entry: zero the whole tensor, then set
        # index 0 along the last dimension to 1.0.
        scores.zero_()
        scores[..., 0] = 1.0
        return scores


def get_logits_processor() -> LogitsProcessorList:
    """Build the list of logits processors.

    Returns:
        A new ``LogitsProcessorList`` whose only element is an
        ``InvalidScoreLogitsProcessor`` instance.
    """
    return LogitsProcessorList([InvalidScoreLogitsProcessor()])


# Avoid runtime error in model.generate(do_sample=True).
class InvalidScoreLogitsProcessor(LogitsProcessor):
Expand Down

0 comments on commit 4966498

Please sign in to comment.