
Commit 18a032e

refactor: update sampling evaluation logic
1 parent 20a2870 commit 18a032e

File tree

1 file changed: +188 −153 lines changed


flashinfer_bench/bench/evaluators/sampling.py

Lines changed: 188 additions & 153 deletions
@@ -59,10 +59,14 @@ def build_baseline(
         )  # convert logits to probs for sampling
         inputs.append(inp)
 
-        freq_dist = _compute_frequency_distribution(
-            ref_runnable, inp, device, defn, num_trials=50000
-        )
-        outputs.append({"frequency_distribution": freq_dist})
+        thresholding_method = _detect_thresholding_method(defn)
+        params = {k: inp[k] for k in ["top_k", "top_p"] if k in inp}
+        valid_mask = _compute_valid_sampling_mask(inp["probs"], thresholding_method, params)
+
+        masked_probs = inp["probs"] * valid_mask.float()
+        expected_probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True)
+
+        outputs.append({"expected_probs": expected_probs})
 
     latencies: List[float] = []
     for inp in inputs:
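
The baseline no longer estimates a reference distribution from 50,000 empirical draws; it computes the exact target analytically by zeroing masked-out tokens and renormalizing. A minimal standalone sketch of that computation with toy values, building a top-k mask inline instead of calling _compute_valid_sampling_mask:

import torch

# Hypothetical batch of two distributions over a 4-token vocabulary.
probs = torch.tensor([[0.4, 0.3, 0.2, 0.1],
                      [0.7, 0.1, 0.1, 0.1]])
k = 2  # keep the top-2 tokens per row

# Allow any token whose probability >= the k-th largest (ties included).
pivot = torch.sort(probs, descending=True).values[:, k - 1:k]
valid_mask = probs >= pivot

# This mirrors what build_baseline now stores as "expected_probs".
masked_probs = probs * valid_mask.float()
expected_probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True)
print(expected_probs)
# row 0 -> [0.5714, 0.4286, 0.0, 0.0]
# row 1 is unchanged: the three-way tie at the pivot 0.1 keeps every token

Note the >= comparison: a tie at the pivot admits more than k tokens, which is exactly the tie-breaking leniency the new mask allows.
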
@@ -94,15 +98,20 @@ def check_correctness(
     log_path: str,
     device: str,
 ) -> Tuple[Optional[Correctness], Optional[Evaluation]]:
-    ref_freq = ref_outputs[0]["frequency_distribution"]
-    vocab_size = ref_freq.shape[0]
+    expected_probs = ref_outputs[0]["expected_probs"]
+    vocab_size = expected_probs.shape[-1]
 
     inp = inputs[0]
     params = {k: inp[k] for k in ["top_k", "top_p"] if k in inp}
 
     output_names = list(defn.outputs.keys())
     output_dtypes = {k: dtype_str_to_torch_dtype(v.dtype) for k, v in defn.outputs.items()}
 
+    # Compute valid sampling mask based on thresholding
+    thresholding_method = _detect_thresholding_method(defn)
+    probs = inp["probs"]
+    valid_mask = _compute_valid_sampling_mask(probs, thresholding_method, params)
+
     # Validate correct sampling token set
     for _ in range(cfg.sampling_validation_trials):
         try:
@@ -137,27 +146,34 @@ def check_correctness(
                 correctness=correctness,
             )
 
-    # Validate thresholding
-    thresholding_method = _detect_thresholding_method(defn)
-    probs = inp["probs"]
-    if not _check_thresholding(samples, probs, thresholding_method, params):
-        correctness = Correctness(
-            max_relative_error=float("inf"), max_absolute_error=float("inf")
-        )
-        message = (
-            f"Samples {samples.tolist()} does not meet {thresholding_method} thresholding"
-        )
-        print(message, file=sys.stderr)
-        return correctness, make_eval(
-            status=EvaluationStatus.INCORRECT_NUMERICAL,
-            device=device,
-            log_path=log_path,
-            correctness=correctness,
-        )
+    # Validate thresholding - check samples are within valid mask
+    if samples.dim() == 0:
+        samples_flat = samples.unsqueeze(0)
+    else:
+        samples_flat = samples.flatten()
+
+    batch_size = valid_mask.shape[0]
+    for i in range(len(samples_flat)):
+        batch_idx = i % batch_size
+        sample_idx = samples_flat[i].item()
+        if not valid_mask[batch_idx, sample_idx]:
+            correctness = Correctness(
+                max_relative_error=float("inf"), max_absolute_error=float("inf")
+            )
+            message = (
+                f"Sample {sample_idx} is outside valid {thresholding_method} mask for batch {batch_idx}"
+            )
+            print(message, file=sys.stderr)
+            return correctness, make_eval(
+                status=EvaluationStatus.INCORRECT_NUMERICAL,
+                device=device,
+                log_path=log_path,
+                correctness=correctness,
+            )
 
     try:
-        sol_freq = _compute_frequency_distribution(
-            sol_runnable, inp, device, defn, num_trials=50000
+        sol_freqs = _sample_token_distributions(
+            sol_runnable, inp, device, defn, num_trials=500000
         )
         torch.cuda.synchronize(device)
     except Exception:
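
Per-trial threshold re-derivation is gone; each drawn token is now just an index lookup into the precomputed boolean mask. A toy sketch of that lookup, with hypothetical mask and samples:

import torch

# Hypothetical: batch of 2 rows over a 4-token vocabulary.
valid_mask = torch.tensor([[True, True, False, False],
                           [True, False, False, False]])
samples = torch.tensor([1, 3])  # one drawn token per row

batch_size = valid_mask.shape[0]
samples_flat = samples.flatten()
for i in range(len(samples_flat)):
    batch_idx = i % batch_size           # map flattened position back to its row
    sample_idx = samples_flat[i].item()
    if not valid_mask[batch_idx, sample_idx]:
        # row 1 drew token 3, which its mask forbids -> INCORRECT_NUMERICAL path
        print(f"Sample {sample_idx} is outside valid mask for batch {batch_idx}")
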
@@ -166,13 +182,29 @@ def check_correctness(
             status=EvaluationStatus.RUNTIME_ERROR, device=device, log_path=log_path
         )
 
-    # total variation distance
-    tvd = 0.5 * torch.sum(torch.abs(sol_freq - ref_freq)).item()
-    max_abs, max_rel, _, _ = compute_error_stats(sol_freq, ref_freq, cfg)
-
-    numerical_incorrect = tvd > cfg.sampling_tvd_threshold
+    batch_size = expected_probs.shape[0]
+    tvds = []
+    max_abs_errors = []
+    max_rel_errors = []
+
+    for i in range(batch_size):
+        tvd_i = 0.5 * torch.sum(torch.abs(sol_freqs[i] - expected_probs[i])).item()
+        tvds.append(tvd_i)
+
+        max_abs_i, max_rel_i, _, _ = compute_error_stats(sol_freqs[i], expected_probs[i], cfg)
+        max_abs_errors.append(max_abs_i)
+        max_rel_errors.append(max_rel_i)
+
+    # Use the worst (max) TVD and errors across all batch elements
+    max_tvd = max(tvds)
+    max_abs = max(max_abs_errors)
+    max_rel = max(max_rel_errors)
+
+    numerical_incorrect = max_tvd > cfg.sampling_tvd_threshold
     correctness = Correctness(
-        max_relative_error=max_rel, max_absolute_error=max_abs, extra={"tvd": tvd}
+        max_relative_error=max_rel,
+        max_absolute_error=max_abs,
+        extra={"tvd": max_tvd, "tvds_per_batch": tvds}
     )
     if numerical_incorrect:
         return correctness, make_eval(
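
The pass/fail criterion is now the worst per-row total variation distance, TVD(p, q) = 0.5 * sum(|p - q|), rather than a single TVD over a flattened distribution. A self-contained sketch with toy numbers; the 0.05 bound below stands in for cfg.sampling_tvd_threshold:

import torch

expected_probs = torch.tensor([[0.6, 0.4, 0.0],
                               [0.5, 0.5, 0.0]])
sol_freqs = torch.tensor([[0.58, 0.42, 0.0],    # close to the target
                          [0.40, 0.50, 0.10]])  # leaks mass onto a masked token

tvds = [0.5 * torch.sum(torch.abs(sol_freqs[i] - expected_probs[i])).item()
        for i in range(expected_probs.shape[0])]
max_tvd = max(tvds)
print([round(t, 3) for t in tvds])  # [0.02, 0.1]
print(max_tvd > 0.05)               # True: the worst row alone fails the check

Taking the max means one bad row cannot hide behind several good ones, which averaging over the batch would have allowed.
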
@@ -201,23 +233,125 @@ def _detect_thresholding_method(defn: Definition) -> str:
     return "none"  # no thresholding
 
 
-def _compute_frequency_distribution(
+def _compute_valid_sampling_mask(
+    probs: torch.Tensor, method: str, params: Dict[str, Any], eps: float = 1e-5
+) -> torch.Tensor:
+    """
+    For tie-breaking in top_k (allows any token with prob >= k-th largest)
+    and numerical precision in top_p (allows tokens within eps of nucleus boundary).
+    """
+    if probs.dim() == 1:
+        probs = probs.unsqueeze(0)
+
+    batch_size, vocab_size = probs.shape
+    device = probs.device
+
+    if method == "none":
+        return torch.ones((batch_size, vocab_size), dtype=torch.bool, device=device)
+
+    mask = torch.ones((batch_size, vocab_size), dtype=torch.bool, device=device)
+
+    if method in ["top_k", "top_k_top_p"]:
+        if "top_k" not in params:
+            raise ValueError(f"top_k parameter required for {method} but not found")
+
+        top_k_param = params["top_k"]
+        for i in range(batch_size):
+            k = (
+                int(top_k_param[i].item())
+                if top_k_param.dim() > 0
+                else int(top_k_param.item())
+            )
+
+            if 0 < k < vocab_size:
+                sorted_probs, _ = torch.sort(probs[i], descending=True)
+                # k-th largest value (0-indexed, so k-1)
+                pivot = sorted_probs[k - 1]
+                mask[i] = probs[i] >= pivot  # tie-breaking handling
+
+    # Apply top_p mask with epsilon tolerance
+    if method in ["top_p", "top_k_top_p"]:
+        if "top_p" not in params:
+            raise ValueError(f"top_p parameter required for {method} but not found")
+
+        top_p_param = params["top_p"]
+        for i in range(batch_size):
+            p = (
+                float(top_p_param[i].item())
+                if top_p_param.dim() > 0
+                else float(top_p_param.item())
+            )
+
+            if 0 < p < 1:
+                sorted_probs, sorted_indices = torch.sort(probs[i], descending=True)
+                cumsum = torch.cumsum(sorted_probs, dim=0)
+
+                # Find tokens in nucleus (cumsum <= p + eps for numerical tolerance)
+                nucleus_mask = cumsum <= (p + eps)
+
+                if not nucleus_mask.any():
+                    nucleus_mask[0] = True
+
+                # Map back to original indices
+                p_mask = torch.zeros(vocab_size, dtype=torch.bool, device=device)
+                p_mask[sorted_indices[nucleus_mask]] = True
+
+                mask[i] = mask[i] & p_mask
+
+    return mask
+
+
+def _sample_token_distributions(
     runnable: Runnable,
     inputs: Dict[str, Any],
     device: str,
     defn: Definition,
-    num_trials: int = 10000,
+    num_trials: int = 500000,
 ) -> torch.Tensor:
-    batch_size = inputs["probs"].shape[0] if inputs["probs"].dim() > 1 else 1
+    original_batch_size = inputs["probs"].shape[0] if inputs["probs"].dim() > 1 else 1
     vocab_size = inputs["probs"].shape[-1]
-    counter = torch.zeros(vocab_size, dtype=torch.int64, device=torch.device(device))
-
-    trials_needed = (num_trials + batch_size - 1) // batch_size
-    total_samples_collected = 0
+
+    # Repeat entire input batch to fill up to target_batch_size for efficient sampling
+    target_batch_size = 10000
+    repeat_count = target_batch_size // original_batch_size
+    actual_batch_size = repeat_count * original_batch_size
+
+    padded_inputs = {}
+    for key, value in inputs.items():
+        if isinstance(value, torch.Tensor) and value.dim() > 0:
+            if key == "probs":
+                # For probs, repeat the entire batch
+                if value.dim() == 1:
+                    value = value.unsqueeze(0)
+                # Repeat the entire batch repeat_count times
+                padded_value = value.repeat(repeat_count, *([1] * (value.dim() - 1)))
+            elif key in ["top_k", "top_p"]:
+                # For sampling parameters, repeat the entire batch
+                if value.dim() == 0:
+                    padded_value = value.unsqueeze(0).repeat(actual_batch_size)
+                else:
+                    padded_value = value.repeat(repeat_count)
+            else:
+                # For other tensors, repeat entire batch along batch dimension
+                if value.dim() == 0:
+                    padded_value = value.unsqueeze(0).repeat(actual_batch_size)
+                else:
+                    padded_value = value.repeat(repeat_count, *([1] * (value.dim() - 1)))
+            padded_inputs[key] = padded_value
+        else:
+            # For non-tensor inputs, keep as is
+            padded_inputs[key] = value
+
+    counters = torch.zeros(
+        (original_batch_size, vocab_size), dtype=torch.int64, device=torch.device(device)
+    )
+
+    trials_needed = (num_trials + actual_batch_size - 1) // actual_batch_size
+    total_samples_per_batch = 0
 
     for _ in range(trials_needed):
         with torch.no_grad():
-            out = runnable(**inputs)
+            out = runnable(**padded_inputs)
 
         output_names = list(defn.outputs.keys())
         output_dtypes = {k: dtype_str_to_torch_dtype(v.dtype) for k, v in defn.outputs.items()}
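
For intuition on the eps tolerance in the nucleus branch above: with probs [0.5, 0.3, 0.2] and top_p = 0.8, the descending cumulative sum is [0.5, 0.8, 1.0], so whether the second token survives hinges on float equality at the boundary unless a tolerance absorbs the rounding. A sketch of that step in isolation, with toy values on a single row:

import torch

probs = torch.tensor([0.5, 0.3, 0.2])
p, eps = 0.8, 1e-5

sorted_probs, sorted_indices = torch.sort(probs, descending=True)
cumsum = torch.cumsum(sorted_probs, dim=0)  # [0.5, 0.8, 1.0]
nucleus_mask = cumsum <= (p + eps)          # eps keeps the boundary token in
if not nucleus_mask.any():                  # degenerate case: keep the top token
    nucleus_mask[0] = True

# Scatter the nucleus back to original token positions.
p_mask = torch.zeros(probs.shape[0], dtype=torch.bool)
p_mask[sorted_indices[nucleus_mask]] = True
print(p_mask)  # tensor([ True,  True, False])
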
@@ -229,118 +363,19 @@ def _compute_frequency_distribution(
         samples = out_normalized["samples"]
 
         if samples.dim() == 0:
+            # Single sample - assign to first batch element
             sample_idx = samples.item()
-            counter[sample_idx] += 1
-            total_samples_collected += 1
-        else:  # Batch of samples
-            for i in range(samples.numel()):
-                sample_idx = samples.flatten()[i].item()
-                counter[sample_idx] += 1
-                total_samples_collected += 1
-
-    frequency = counter.float() / total_samples_collected
-    return frequency
-
-
-def _check_thresholding(
-    samples: torch.Tensor, probs: torch.Tensor, method: str, params: Dict[str, Any]
-) -> bool:
-    """Check if samples conform to the specified thresholding method.
-
-    Parameters
-    ----------
-    samples : torch.Tensor
-        Sampled token indices.
-    probs : torch.Tensor
-        Probability distribution used for sampling.
-    method : str
-        Thresholding method: "top_k", "top_p", "top_k_top_p", or "none".
-    params : Dict[str, Any]
-        Sampling parameters (top_k, top_p values).
-
-    Returns
-    -------
-    bool
-        True if samples are valid, False otherwise.
-    """
-    batch_size, vocab_size = probs.shape
-    device = probs.device
-
-    for i in range(batch_size):
-        prob_row = probs[i]
-        sample = samples[i].item()
-
-        if method == "top_k":
-            if "top_k" not in params:
-                raise ValueError("top_k parameter is required for top_k thresholding but not found")
-            k = (
-                int(params["top_k"][i].item())
-                if params["top_k"].dim() > 0
-                else int(params["top_k"].item())
-            )
-
-            if 0 < k < vocab_size:
-                sorted_prob_desc, _ = torch.sort(prob_row, descending=True)
-                pivot = sorted_prob_desc[k - 1]
-                mask_top_k = (prob_row >= pivot).int()
-                if mask_top_k[sample] != 1:
-                    return False
-
-        elif method == "top_p":
-            if "top_p" not in params:
-                raise ValueError("top_p parameter is required for top_p thresholding but not found")
-            p = (
-                float(params["top_p"][i].item())
-                if params["top_p"].dim() > 0
-                else float(params["top_p"].item())
-            )
-
-            if 0 < p < 1:
-                eps = 1e-4  # numerical stability
-                sorted_probs, indices = torch.sort(prob_row, descending=False)
-                cdf = torch.cumsum(sorted_probs, dim=0)
-                valid_mask = cdf > (1 - p) - eps
-                valid_indices = indices[valid_mask]
-
-                if sample not in valid_indices:
-                    return False
-
-        elif method == "top_k_top_p":
-            if "top_k" not in params or "top_p" not in params:
-                raise ValueError(
-                    "top_k and top_p parameters are both required for top_k_top_p thresholding but not found"
-                )
-            k = (
-                int(params["top_k"][i].item())
-                if params["top_k"].dim() > 0
-                else int(params["top_k"].item())
-            )
-            p = (
-                float(params["top_p"][i].item())
-                if params["top_p"].dim() > 0
-                else float(params["top_p"].item())
-            )
-
-            if 0 < k < vocab_size:
-                sorted_prob_desc, _ = torch.sort(prob_row, descending=True)
-                pivot = sorted_prob_desc[k - 1]
-                mask_top_k = (prob_row >= pivot).int()
-            else:
-                mask_top_k = torch.ones(vocab_size, dtype=torch.int32, device=device)
-
-            if 0 < p < 1:
-                eps = 1e-4
-                sorted_probs_asc, indices = torch.sort(prob_row, descending=False)
-                cdf = torch.cumsum(sorted_probs_asc, dim=0)
-                mask_top_p = torch.zeros(vocab_size, dtype=torch.int32, device=device)
-                valid_p_mask = cdf > (1 - p) - eps
-                mask_top_p[indices[valid_p_mask]] = 1
-            else:
-                mask_top_p = torch.ones(vocab_size, dtype=torch.int32, device=device)
-
-            joint_mask = torch.minimum(mask_top_k, mask_top_p)
-
-            if joint_mask[sample] != 1:
-                return False
-
-    return True
+            counters[0, sample_idx] += 1
+            total_samples_per_batch += 1
+        else:
+            # slice and accumulate per original batch element
+            samples_flat = samples.flatten()
+            for i in range(samples_flat.numel()):
+                batch_idx = i % original_batch_size
+                sample_idx = samples_flat[i].item()
+                counters[batch_idx, sample_idx] += 1
+            total_samples_per_batch += repeat_count
+
+    # [batch_size, vocab_size]
+    frequencies = counters.float() / total_samples_per_batch
+    return frequencies
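
The per-row accumulation rests on one invariant: Tensor.repeat tiles the whole batch, so padded row i always aliases original row i % original_batch_size. A quick sketch that checks exactly the property the counters indexing depends on, using toy shapes:

import torch

original_batch_size, vocab_size, repeat_count = 2, 4, 3
probs = torch.rand(original_batch_size, vocab_size)

padded = probs.repeat(repeat_count, 1)  # rows tile as [0, 1, 0, 1, 0, 1]
for i in range(padded.shape[0]):
    # every padded row must equal its source row under the modulo mapping
    assert torch.equal(padded[i], probs[i % original_batch_size])

# A sample drawn from padded row i is therefore credited to
# counters[i % original_batch_size], and each trial contributes
# repeat_count draws to every original row.
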
