[BUGFIX] fix bug of top-p sampling #1503
base: master
Changes from 2 commits
```diff
@@ -717,11 +717,20 @@ def forward(self, samples, valid_length, outputs, scores, step, beam_alive_mask,
         probs = mx.npx.softmax(outputs / self._temperature)
 
         if self._sampling_topp > 0:
-            probs = mx.np.where(
-                probs > self._sampling_topp,
-                probs,
-                mx.np.zeros_like(probs)
-            )
+            sorted_probs, sorted_indices = mx.npx.topk(probs, axis=2, k=-1, ret_typ='both', is_ascend=False)
+            cumsum_probs = mx.np.cumsum(sorted_probs, axis=2)
+            masked_probs = mx.np.where(
+                cumsum_probs > self._sampling_topp,
+                mx.np.ones_like(probs),
+                sorted_probs
+            )
+            # choose the borderline prob
+            p_prob = mx.np.min(masked_probs, axis=2, keepdims=True)
```
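To make the review thread below easier to follow, here is a small pure-NumPy trace of what the new masking computes in the ordinary case (toy numbers and a flat 1-D distribution instead of the `(batch, beam, vocab)` arrays above):

```python
import numpy as np

topp = 0.9
sorted_probs = np.array([0.5, 0.3, 0.1, 0.06, 0.04])   # descending, as after topk
cumsum_probs = np.cumsum(sorted_probs)                  # [0.5, 0.8, 0.9, 0.96, 1.0]
# positions whose cumulative sum already exceeds topp are replaced by 1.0
masked_probs = np.where(cumsum_probs > topp, 1.0, sorted_probs)  # [0.5, 0.3, 0.1, 1.0, 1.0]
p_prob = masked_probs.min()                             # 0.1 -- the borderline prob
print(np.where(sorted_probs >= p_prob, sorted_probs, 0.0))
# [0.5 0.3 0.1 0.  0. ] -- only the nucleus survives
```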
Review comment: Is it possible to use exactly the same implementation as https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317?

Review comment: I'm referring to the part in which they choose not to mask the top-1 probability:
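The part of the gist being referred to is presumably this shift, which guarantees the top-1 token is never removed (quoted from memory of the linked gist; `logits` and `top_p` are the gist's own names):

```python
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
```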
Review comment: Sorry for the confusion. I see that both sort and argsort are implemented, but I don't see a way to get both values and indices in one call. The usage of `topk(..., ret_typ='both')` here makes sense, then.
```diff
+            probs = mx.np.where(
+                probs >= p_prob,
+                probs,
+                mx.np.zeros_like(probs)
+            )
```

Review comment: The major difference between the current implementation and the original pytorch-based implementation is that when `self._sampling_topp` is smaller than the top-1 probability, this implementation masks out every token. The pytorch-based implementation will always choose the token that is most probable.
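The same toy numbers make this edge case concrete (again plain NumPy; `topp` is deliberately chosen smaller than the top-1 probability):

```python
import numpy as np

topp = 0.3                                       # smaller than the top-1 prob (0.5)
sorted_probs = np.array([0.5, 0.3, 0.1, 0.06, 0.04])
cumsum_probs = np.cumsum(sorted_probs)           # [0.5, 0.8, 0.9, 0.96, 1.0] -- all > topp
masked_probs = np.where(cumsum_probs > topp, 1.0, sorted_probs)  # all ones
p_prob = masked_probs.min()                      # 1.0
print(np.where(sorted_probs >= p_prob, sorted_probs, 0.0))
# [0. 0. 0. 0. 0.] -- every token is masked; the gist would keep the 0.5 token
```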
```diff
         elif self._sampling_topk > 0:
             topk_probs = mx.npx.topk(probs, axis=2, k=self._sampling_topk, ret_typ='value')
             # choose the k max prob
```
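One way to get the gist's keep-the-top-1 behavior with the same `mx.np` calls already used in the diff is to compare against the exclusive cumulative sum, which plays the role of the gist's shift-right-by-one. A minimal, untested sketch (the helper name `top_p_mask` is made up for illustration):

```python
import mxnet as mx

def top_p_mask(probs, topp):
    """Zero out tokens outside the top-p nucleus; probs has shape (batch, beam, vocab)."""
    sorted_probs = mx.npx.topk(probs, axis=2, k=-1, ret_typ='value', is_ascend=False)
    # Exclusive cumsum: position i holds the total mass of tokens strictly more
    # probable than token i. It is 0 at position 0, so the top-1 token always survives.
    excl_cumsum = mx.np.cumsum(sorted_probs, axis=2) - sorted_probs
    masked_probs = mx.np.where(excl_cumsum > topp, mx.np.ones_like(probs), sorted_probs)
    # borderline prob: the smallest prob among the tokens that are kept
    p_prob = mx.np.min(masked_probs, axis=2, keepdims=True)
    return mx.np.where(probs >= p_prob, probs, mx.np.zeros_like(probs))
```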
Review comment: I think we previously had the results for t=0.9; we should remove that row.