Skip to content

Commit

Permalink
set self.ctx.model_mask to empty set when len(good_tokens) == 0
Browse files Browse the repository at this point in the history
  • Loading branch information
gabegrand committed Jan 6, 2025
1 parent 582de08 commit ec6b35a
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions hfppl/distributions/lmcontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,11 @@ async def log_prob(self, v):
)
if len(good_tokens) == 0:
# If there are no good tokens, the log probability of v under the mask is -inf
return float("-inf")
logprob_good = float("-inf")
else:
logprob_good = logsumexp(self.ctx.next_token_logprobs[list(good_tokens)])

bad_tokens = [i for i in self.ctx.model_mask if i not in good_tokens]
logprob_good = logsumexp(self.ctx.next_token_logprobs[list(good_tokens)])
self.ctx.next_token_logprobs[bad_tokens] = float("-inf")
self.ctx.next_token_logprobs -= logprob_good
self.ctx.model_mask = good_tokens
Expand Down

0 comments on commit ec6b35a

Please sign in to comment.