Skip to content

Commit

Permalink
Merge pull request #21 from probcomp/gg/good_tokens
Browse files Browse the repository at this point in the history
Fix ValueError for `LMTokenMask.log_prob()` when the mask has no support
  • Loading branch information
alex-lew authored Jan 15, 2025
2 parents a191fca + ec6b35a commit 7c68968
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion hfppl/distributions/lmcontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,13 @@ async def log_prob(self, v):
if v
else self.ctx.model_mask - self.mask
)
if len(good_tokens) == 0:
# If there are no good tokens, the log probability of v under the mask is -inf
logprob_good = float("-inf")
else:
logprob_good = logsumexp(self.ctx.next_token_logprobs[list(good_tokens)])

bad_tokens = [i for i in self.ctx.model_mask if i not in good_tokens]
logprob_good = logsumexp(self.ctx.next_token_logprobs[list(good_tokens)])
self.ctx.next_token_logprobs[bad_tokens] = float("-inf")
self.ctx.next_token_logprobs -= logprob_good
self.ctx.model_mask = good_tokens
Expand Down

0 comments on commit 7c68968

Please sign in to comment.