
Commit 77b79f7

Author: wangxiaoxin-sherie
Commit message: cleancode
Parent: 86eaf0c

File tree

2 files changed: +48 lines, -3 lines


tests/singlecard/test_offline_inference.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -83,6 +83,7 @@ def test_multimodal(model, prompt_template, vllm_runner):
                         images=images,
                         max_tokens=64)
 
+
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half", "float16"])
 @pytest.mark.parametrize("max_tokens", [5])
@@ -94,7 +95,10 @@ def test_models_topk(model: str, dtype: str, max_tokens: int) -> None:
         "The capital of France is",
         "The future of AI is",
     ]
-    sampling_params = SamplingParams(max_tokens = max_tokens, temperature = 0.0, top_k = 50, top_p = 0.9)
+    sampling_params = SamplingParams(max_tokens=max_tokens,
+                                     temperature=0.0,
+                                     top_k=50,
+                                     top_p=0.9)
 
     with VllmRunner(model,
                     max_model_len=8192,
```
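
For context, here is a minimal sketch of the sampling setup this hunk reformats, using vLLM's public `LLM` entry point rather than the repo's `VllmRunner` test wrapper; the model name is a placeholder, not taken from the test's `MODELS` list.

```python
from vllm import LLM, SamplingParams

prompts = [
    "The capital of France is",
    "The future of AI is",
]

# temperature=0.0 selects greedy decoding in vLLM; top_k/top_p bound the
# candidate set, matching the parameters used by test_models_topk.
sampling_params = SamplingParams(max_tokens=5,
                                 temperature=0.0,
                                 top_k=50,
                                 top_p=0.9)

llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model name
          max_model_len=8192)
for output in llm.generate(prompts, sampling_params):
    print(output.outputs[0].text)
```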

tests/singlecard/test_sampler.py

Lines changed: 43 additions & 2 deletions
```diff
@@ -20,8 +20,6 @@
 
 import torch
 
-from vllm.v1.sample.ops.topk_topp_sampler import \
-    apply_top_k_top_p  # noqa: F401
 from vllm.v1.sample.sampler import Sampler  # noqa: F401
 
 # Set tolerance to 1 for quant ops
@@ -51,6 +49,49 @@ def apply_min_p_new(
     return logits
 
 
+def apply_top_k_top_p(
+    logits: torch.Tensor,
+    k: Optional[torch.Tensor],
+    p: Optional[torch.Tensor],
+) -> torch.Tensor:
+    """Apply top-k and top-p masks to the logits.
+
+    If a top-p is used, this function will sort the logits tensor,
+    which can be slow for large batches.
+
+    The logits tensor may be updated in-place.
+    """
+    if p is None:
+        if k is None:
+            return logits
+
+        # Avoid sorting vocab for top-k only case.
+        return apply_top_k_only(logits, k)
+
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
+
+    if k is not None:
+        # Apply top-k.
+        top_k_mask = logits_sort.size(1) - k.to(torch.long)  # shape: B
+        # Get all the top_k values.
+        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
+        top_k_mask = logits_sort < top_k_mask
+        logits_sort.masked_fill_(top_k_mask, -float("inf"))
+
+    if p is not None:
+        # Apply top-p.
+        probs_sort = logits_sort.softmax(dim=-1)
+        probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
+        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
+        # at least one
+        top_p_mask[:, -1] = False
+        logits_sort.masked_fill_(top_p_mask, -float("inf"))
+
+    # Re-sort the probabilities.
+    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
+    return logits
+
+
 def apply_top_k_top_p_new(
     logits: torch.Tensor,
     k: Optional[torch.Tensor],
```