
Commit 39e2f19

formatting
1 parent cd63932 commit 39e2f19

File tree

1 file changed: +24 -38 lines changed

flashinfer_bench/bench/evaluators/sampling.py

Lines changed: 24 additions & 38 deletions
@@ -53,19 +53,15 @@ def build_baseline(
     outputs: List[Dict[str, torch.Tensor]] = []

     inp = gen_inputs(defn, workload, device=device, stensors=loaded_stensors)
-    if "probs" in inp:
-        inp["probs"] = torch.softmax(
-            inp["probs"], dim=-1
-        ) # convert logits to probs for sampling
     inputs.append(inp)

     thresholding_method = _detect_thresholding_method(defn)
     params = {k: inp[k] for k in ["top_k", "top_p"] if k in inp}
     valid_mask = _compute_valid_sampling_mask(inp["probs"], thresholding_method, params)
-
+
     masked_probs = inp["probs"] * valid_mask.float()
     expected_probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True)
-
+
     outputs.append({"expected_probs": expected_probs})

     latencies: List[float] = []
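
Note: the baseline above builds its reference distribution by zeroing probabilities outside the valid sampling mask and renormalizing. A minimal standalone sketch of that masked renormalization, using toy tensors rather than the benchmark's real inputs:

import torch

# Toy batch of 2 distributions over a 4-token vocabulary (illustrative values only).
probs = torch.tensor([[0.50, 0.30, 0.15, 0.05],
                      [0.40, 0.40, 0.10, 0.10]])
valid_mask = torch.tensor([[True, True, False, False],
                           [True, True, True, False]])

masked_probs = probs * valid_mask.float()
expected_probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True)
# Row 0 renormalizes to [0.625, 0.375, 0, 0]; row 1 to [4/9, 4/9, 1/9, 0].
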
@@ -151,7 +147,7 @@ def check_correctness(
         samples_flat = samples.unsqueeze(0)
     else:
         samples_flat = samples.flatten()
-
+
     batch_size = valid_mask.shape[0]
     for i in range(len(samples_flat)):
         batch_idx = i % batch_size
@@ -160,9 +156,7 @@ def check_correctness(
             correctness = Correctness(
                 max_relative_error=float("inf"), max_absolute_error=float("inf")
            )
-            message = (
-                f"Sample {sample_idx} is outside valid {thresholding_method} mask for batch {batch_idx}"
-            )
+            message = f"Sample {sample_idx} is outside valid {thresholding_method} mask for batch {batch_idx}"
             print(message, file=sys.stderr)
             return correctness, make_eval(
                 status=EvaluationStatus.INCORRECT_NUMERICAL,
@@ -186,11 +180,11 @@ def check_correctness(
     tvds = []
     max_abs_errors = []
     max_rel_errors = []
-
+
     for i in range(batch_size):
         tvd_i = 0.5 * torch.sum(torch.abs(sol_freqs[i] - expected_probs[i])).item()
         tvds.append(tvd_i)
-
+
         max_abs_i, max_rel_i, _, _ = compute_error_stats(sol_freqs[i], expected_probs[i], cfg)
         max_abs_errors.append(max_abs_i)
         max_rel_errors.append(max_rel_i)
@@ -202,9 +196,9 @@ def check_correctness(

     numerical_incorrect = max_tvd > cfg.sampling_tvd_threshold
     correctness = Correctness(
-        max_relative_error=max_rel,
-        max_absolute_error=max_abs,
-        extra={"tvd": max_tvd, "tvds_per_batch": tvds}
+        max_relative_error=max_rel,
+        max_absolute_error=max_abs,
+        extra={"tvd": max_tvd, "tvds_per_batch": tvds},
     )
     if numerical_incorrect:
         return correctness, make_eval(
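
Note: the correctness check above scores each batch row with total variation distance, TVD = 0.5 * sum(|observed_i - expected_i|), and reports INCORRECT_NUMERICAL when the maximum per-row TVD exceeds cfg.sampling_tvd_threshold. A small sketch with made-up frequencies and a hypothetical threshold value:

import torch

expected = torch.tensor([0.625, 0.375, 0.00, 0.00])  # reference distribution
observed = torch.tensor([0.600, 0.390, 0.01, 0.00])  # empirical sample frequencies

tvd = 0.5 * torch.sum(torch.abs(observed - expected)).item()
# 0.5 * (0.025 + 0.015 + 0.01 + 0.0) = 0.025

sampling_tvd_threshold = 0.02  # hypothetical value, not the benchmark's default
numerical_incorrect = tvd > sampling_tvd_threshold  # True for this toy example
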
@@ -234,15 +228,15 @@ def _detect_thresholding_method(defn: Definition) -> str:


 def _compute_valid_sampling_mask(
-    probs: torch.Tensor, method: str, params: Dict[str, Any], eps: float = 1e-5
+    probs: torch.Tensor, method: str, params: Dict[str, Any], eps: float = 5e-2
 ) -> torch.Tensor:
     """
     For tie-breaking in top_k (allows any token with prob >= k-th largest)
     and numerical precision in top_p (allows tokens within eps of nucleus boundary).
     """
     if probs.dim() == 1:
         probs = probs.unsqueeze(0)
-
+
     batch_size, vocab_size = probs.shape
     device = probs.device

@@ -254,48 +248,40 @@ def _compute_valid_sampling_mask(
     if method in ["top_k", "top_k_top_p"]:
         if "top_k" not in params:
             raise ValueError(f"top_k parameter required for {method} but not found")
-
+
         top_k_param = params["top_k"]
         for i in range(batch_size):
-            k = (
-                int(top_k_param[i].item())
-                if top_k_param.dim() > 0
-                else int(top_k_param.item())
-            )
+            k = int(top_k_param[i].item()) if top_k_param.dim() > 0 else int(top_k_param.item())

             if 0 < k < vocab_size:
                 sorted_probs, _ = torch.sort(probs[i], descending=True)
                 # k-th largest value (0-indexed, so k-1)
                 pivot = sorted_probs[k - 1]
-                mask[i] = probs[i] >= pivot # tie-breaking handling
+                mask[i] = probs[i] >= pivot  # tie-breaking handling

     # Apply top_p mask with epsilon tolerance
     if method in ["top_p", "top_k_top_p"]:
         if "top_p" not in params:
             raise ValueError(f"top_p parameter required for {method} but not found")
-
+
         top_p_param = params["top_p"]
         for i in range(batch_size):
-            p = (
-                float(top_p_param[i].item())
-                if top_p_param.dim() > 0
-                else float(top_p_param.item())
-            )
+            p = float(top_p_param[i].item()) if top_p_param.dim() > 0 else float(top_p_param.item())

             if 0 < p < 1:
                 sorted_probs, sorted_indices = torch.sort(probs[i], descending=True)
                 cumsum = torch.cumsum(sorted_probs, dim=0)
-
+
                 # Find tokens in nucleus (cumsum <= p + eps for numerical tolerance)
                 nucleus_mask = cumsum <= (p + eps)
-
+
                 if not nucleus_mask.any():
                     nucleus_mask[0] = True
-
+
                 # Map back to original indices
                 p_mask = torch.zeros(vocab_size, dtype=torch.bool, device=device)
                 p_mask[sorted_indices[nucleus_mask]] = True
-
+
                 mask[i] = mask[i] & p_mask

     return mask
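
Note: _compute_valid_sampling_mask accepts any token a correct top-k/top-p sampler could plausibly return: for top_k, every token tied with the k-th largest probability passes; for top_p, the nucleus boundary is relaxed by eps. A self-contained sketch of both rules on one toy distribution (the values, k, and p are illustrative assumptions, not the benchmark's data):

import torch

probs = torch.tensor([0.30, 0.25, 0.25, 0.15, 0.05])
k, p, eps = 2, 0.78, 5e-2

# top_k with tie-breaking: any token >= the k-th largest probability is valid.
sorted_probs, sorted_indices = torch.sort(probs, descending=True)
pivot = sorted_probs[k - 1]            # 0.25
k_mask = probs >= pivot                # both tied 0.25 tokens pass, so 3 tokens are valid

# top_p with eps tolerance: tokens whose cumulative mass stays within p + eps.
cumsum = torch.cumsum(sorted_probs, dim=0)   # [0.30, 0.55, 0.80, 0.95, 1.00]
nucleus_mask = cumsum <= (p + eps)           # 0.80 <= 0.83, so the first three tokens stay in
if not nucleus_mask.any():
    nucleus_mask[0] = True                   # always allow the argmax token
p_mask = torch.zeros_like(probs, dtype=torch.bool)
p_mask[sorted_indices[nucleus_mask]] = True

valid_mask = k_mask & p_mask           # combined mask, as in the top_k_top_p branch
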
@@ -310,12 +296,12 @@ def _sample_token_distributions(
 ) -> torch.Tensor:
     original_batch_size = inputs["probs"].shape[0] if inputs["probs"].dim() > 1 else 1
     vocab_size = inputs["probs"].shape[-1]
-
+
     # Repeat entire input batch to fill up to target_batch_size for efficient sampling
     target_batch_size = 10000
     repeat_count = target_batch_size // original_batch_size
     actual_batch_size = repeat_count * original_batch_size
-
+
     padded_inputs = {}
     for key, value in inputs.items():
         if isinstance(value, torch.Tensor) and value.dim() > 0:
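
Note: the padding loop above, together with the batch_idx = i % batch_size mapping in check_correctness, suggests the original batch is tiled repeat_count times so one launch draws many samples per distribution. A rough sketch of that idea, assuming plain repetition along the batch dimension (not necessarily the repo's exact padding code):

import torch

# Illustrative only: 2 original distributions tiled 3 times -> 6 rows per trial.
probs = torch.tensor([[0.7, 0.3],
                      [0.4, 0.6]])
repeat_count = 3
original_batch_size = probs.shape[0]

tiled_probs = probs.repeat(repeat_count, 1)                 # shape (6, 2): rows 0,1,0,1,0,1
samples = torch.multinomial(tiled_probs, num_samples=1).squeeze(-1)

# Accumulate per-original-row token counts; row i maps back via i % original_batch_size.
counters = torch.zeros(original_batch_size, probs.shape[-1], dtype=torch.int64)
for i, tok in enumerate(samples.tolist()):
    counters[i % original_batch_size, tok] += 1
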
@@ -341,12 +327,12 @@ def _sample_token_distributions(
         else:
             # For non-tensor inputs, keep as is
             padded_inputs[key] = value
-
+
     counters = torch.zeros(
         (original_batch_size, vocab_size), dtype=torch.int64, device=torch.device(device)
     )

-    trials_needed = (num_trials + actual_batch_size - 1) // actual_batch_size
+    trials_needed = (num_trials + repeat_count - 1) // repeat_count
     total_samples_per_batch = 0

     for _ in range(trials_needed):
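
Note on the last hunk: assuming num_trials is the number of samples wanted per original distribution, each trial contributes repeat_count samples to every original row (not actual_batch_size), so the ceiling division now uses repeat_count. Worked numbers, chosen only for illustration:

original_batch_size = 4
target_batch_size = 10000
num_trials = 5000                                                # desired samples per distribution

repeat_count = target_batch_size // original_batch_size          # 2500 copies of each row per trial
actual_batch_size = repeat_count * original_batch_size            # 10000 rows sampled per trial

trials_needed = (num_trials + repeat_count - 1) // repeat_count            # 2 trials -> 5000 samples/row
old_trials = (num_trials + actual_batch_size - 1) // actual_batch_size     # 1 trial  -> only 2500 samples/row
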
