|
20 | 20 |
|
21 | 21 | import torch |
22 | 22 |
|
23 | | -from vllm.v1.sample.ops.topk_topp_sampler import \ |
24 | | - apply_top_k_top_p # noqa: F401 |
25 | 23 | from vllm.v1.sample.sampler import Sampler # noqa: F401 |
26 | 24 |
|
27 | 25 | # Set tolerance to 1 for quant ops |
@@ -51,6 +49,49 @@ def apply_min_p_new( |
51 | 49 | return logits |
52 | 50 |
|
53 | 51 |
|
def apply_top_k_top_p(
    logits: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """Apply top-k and top-p masks to the logits.

    If a top-p is used, this function will sort the logits tensor,
    which can be slow for large batches.

    The logits tensor may be updated in-place.
    """
    if p is None and k is None:
        return logits
    if p is None:
        # Top-k alone does not require sorting the full vocab.
        return apply_top_k_only(logits, k)

    # Ascending sort: the kept (largest) tokens end up on the right.
    sorted_vals, sort_order = logits.sort(dim=-1, descending=False)

    if k is not None:
        # Index of the k-th largest value within the ascending order.
        cutoff = sorted_vals.size(1) - k.to(torch.long)  # shape: B
        # Per-row threshold: everything strictly below it is dropped.
        threshold = sorted_vals.gather(1, cutoff.unsqueeze(dim=1))
        sorted_vals.masked_fill_(sorted_vals < threshold, -float("inf"))

    # Top-p (nucleus) filtering over the sorted distribution.
    cum_probs = sorted_vals.softmax(dim=-1).cumsum_(dim=-1)
    drop = cum_probs <= 1 - p.unsqueeze(dim=1)
    # Always keep at least one token (the largest).
    drop[:, -1] = False
    sorted_vals.masked_fill_(drop, -float("inf"))

    # Undo the sort so every value returns to its original vocab slot.
    return sorted_vals.scatter(dim=-1, index=sort_order, src=sorted_vals)
| 94 | + |
54 | 95 | def apply_top_k_top_p_new( |
55 | 96 | logits: torch.Tensor, |
56 | 97 | k: Optional[torch.Tensor], |
|
0 commit comments