
Conversation

@zanderjiang (Collaborator) commented Oct 28, 2025

This PR fixes several issues in the previous sampling evaluation logic:
The previous version compressed all input probs into a single-dimension frequency vector, which breaks when the input tensor's batch size is > 1. This PR addresses that by retaining the input shape for the sampled token distributions.

For sampled tokens, we compute the TVD of each per-input probability distribution against the ground truth. The Evaluation class records the worst (max) TVD among all input batch elements.

To reduce the number of correctness sampling iterations, we repeat the original input tensor 10,000 // original_batch_size times. This still lets us sample the non-deterministic kernel while running fewer forward passes, reducing benchmarking time.
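For reference, a minimal sketch of the per-batch TVD check described above; expected_probs and sampled_freqs are illustrative names, not necessarily the exact variables used in sampling.py:

    import torch

    def per_batch_tvd(expected_probs: torch.Tensor, sampled_freqs: torch.Tensor) -> torch.Tensor:
        """Total variation distance per batch row: 0.5 * sum_v |p_v - q_v|.

        Both tensors are [batch_size, vocab_size]; each row is a probability distribution.
        """
        return 0.5 * (expected_probs - sampled_freqs).abs().sum(dim=-1)

    # The evaluator records the worst (max) TVD across batch elements.
    expected = torch.tensor([[0.7, 0.3, 0.0], [0.5, 0.25, 0.25]])
    observed = torch.tensor([[0.68, 0.31, 0.01], [0.52, 0.24, 0.24]])
    tvds = per_batch_tvd(expected, observed)  # one TVD per input row
    max_tvd = tvds.max().item()               # worst case is compared against the threshold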

Summary by CodeRabbit

  • New Features

    • Baseline now reports threshold-aware expected probabilities and uses them for downstream checks
    • Added public helpers to support threshold-aware sampling and validation
  • Bug Fixes

    • More accurate detection of valid tokens under top-k/top-p with tie and tolerance handling
    • Enhanced per-batch error reporting including TVD, max absolute and relative errors
  • Refactor

    • Increased sampling trials and optimized batched sampling for stronger statistical validation
    • Streamlined validation flow to apply masks per batch element

coderabbitai bot commented Oct 28, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Replaces frequency-distribution baseline with thresholding-aware expected probabilities, adds per-batch valid-token mask computation (top_k/top_p/combined), increases sampling to 500k trials with batched handling, and validates samples using per-batch TVD and expanded error reporting.

Changes

Cohort / File(s) Summary
Sampling Evaluator Refactoring
flashinfer_bench/bench/evaluators/sampling.py
Replaced baseline frequency-distribution output with thresholding-aware expected_probs; added _compute_valid_sampling_mask(probs, method, params, eps) to produce per-batch boolean masks for top_k/top_p/combined (with tie and eps handling); added _sample_token_distributions(runnable, inputs, device, defn, num_trials=500000) to produce large-sample token distributions with batched repetition/padding; updated correctness flow to validate samples against masks, compute per-batch TVD and per-batch error stats, and report max TVD, max absolute error, and max relative error in results.
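As a rough illustration of the batched sampling summarized above (the 500k trials and 10,000 target batch size come from this walkthrough and the PR description; the runnable call signature and helper name are assumptions, not the actual implementation):

    import torch

    def sample_token_distributions_sketch(runnable, probs, num_trials=500_000, target_batch_size=10_000):
        """Illustrative only: tile the input batch, run the sampler repeatedly,
        and accumulate per-original-row token counts into frequencies."""
        original_batch_size, vocab_size = probs.shape          # probs: [B, V]
        repeat_count = max(1, target_batch_size // original_batch_size)
        padded_probs = probs.repeat(repeat_count, 1)            # [repeat_count * B, V]
        counters = torch.zeros(original_batch_size, vocab_size, dtype=torch.int64)

        trials_done = 0
        while trials_done < num_trials:
            with torch.no_grad():
                samples = runnable(probs=padded_probs)          # assumed to return [repeat_count * B] int64 token ids
            reshaped = samples.view(repeat_count, original_batch_size)
            for b in range(original_batch_size):
                counters[b] += torch.bincount(reshaped[:, b], minlength=vocab_size)
            trials_done += repeat_count

        return counters.float() / counters.sum(dim=-1, keepdim=True)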

Sequence Diagram

sequenceDiagram
    participant Input as Input Probabilities
    participant Mask as Valid Sampling Mask
    participant Masked as Masked Probabilities
    participant Sampler as Sampler (500k trials)
    participant Validator as Correctness Validator
    participant Stats as Per-Batch Stats

    Input->>Mask: probs, method, params
    activate Mask
    Note over Mask: compute per-batch boolean mask\n(top_k / top_p / combined, tie & eps)
    Mask-->>Masked: mask
    deactivate Mask

    Input->>Masked: apply mask -> expected_probs
    activate Masked
    Masked-->>Sampler: expected_probs (batched, padded)
    deactivate Masked

    Sampler->>Sampler: generate large-sample counts\n(500,000 trials, batched)
    Sampler-->>Validator: sample frequency distributions

    Validator->>Validator: validate each sample\nagainst per-batch mask
    Validator-->>Stats: compute per-batch TVD, abs/rel errors
    activate Stats
    Note over Stats: aggregate max TVD, max errors
    Stats-->>Validator: metrics
    deactivate Stats

    Validator-->>Input: correctness payload\n(expected_probs, per-batch TVD list, max errors)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Review mask computation for top_k/top_p/tie-handling and eps behavior.
  • Verify batched repetition/padding and memory/performance in 500k-trial sampler.
  • Validate per-batch TVD calculation and correct aggregation of max metrics.
  • Inspect updated correctness payload shape and downstream consumers.

Poem

🐰
I hopped through masks and probability streams,
Counting five hundred thousand little dreams.
Per-batch TVD lights up the night,
Expected probs now land just right.
A whisper of carrots, data delight.

Pre-merge checks and finishing touches

✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title 'refactor: update sampling evaluation logic' directly reflects the main changes: refactoring the sampling evaluation logic in the flashinfer_bench evaluators module to address batch size handling and improve validation accuracy.


@gemini-code-assist (Contributor)

Summary of Changes

Hello @zanderjiang, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly refactors the sampling evaluation logic to enhance its robustness and accuracy, particularly for scenarios involving batched inputs. The core changes involve moving from a single-dimension frequency distribution to per-input probability distributions for TVD calculation, ensuring that the evaluation correctly reflects the behavior of each batch element. It also introduces a more precise method for validating sampled tokens against top_k and top_p thresholding rules and optimizes the sampling process for benchmarking efficiency.

Highlights

  • Batched Input Handling: The evaluation logic now correctly handles input tensors with batch_size > 1 by retaining the input shape for sampled token distributions, addressing previous vulnerabilities.
  • Per-Input TVD Calculation: Total Variation Distance (TVD) is now computed for each input probability distribution against its ground truth, rather than a compressed single-dimension frequency.
  • Worst-Case Error Recording: The Evaluation class records the maximum (worst) TVD and errors observed across all elements within a batch, providing a more robust error metric.
  • Optimized Sampling Iterations: To reduce benchmarking time, the original input tensor is repeated multiple times (up to a target batch size of 10,000) to efficiently collect samples for non-deterministic kernels.
  • Improved Thresholding Validation: A new _compute_valid_sampling_mask function is introduced to accurately determine valid tokens for top_k and top_p sampling, including tie-breaking and numerical precision, replacing the previous _check_thresholding logic.
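A minimal sketch of the kind of per-row mask such a helper can compute, assuming the tie rule and eps tolerance described above (illustrative only, not the exact code in _compute_valid_sampling_mask):

    import torch
    from typing import Optional

    def valid_sampling_mask_sketch(probs: torch.Tensor,
                                   top_k: Optional[int] = None,
                                   top_p: Optional[float] = None,
                                   eps: float = 1e-4) -> torch.Tensor:
        """Boolean [batch, vocab] mask of tokens a correct top-k/top-p sampler may return."""
        batch_size, vocab_size = probs.shape
        mask = torch.ones_like(probs, dtype=torch.bool)

        if top_k is not None and 0 < top_k < vocab_size:
            # Tie handling: keep every token whose prob >= the k-th largest prob.
            kth = torch.topk(probs, top_k, dim=-1).values[:, -1:]
            mask &= probs >= kth

        if top_p is not None and 0.0 < top_p < 1.0:
            # Nucleus cut with eps tolerance for floating-point error.
            sorted_probs, indices = torch.sort(probs, dim=-1, descending=False)
            cdf = torch.cumsum(sorted_probs, dim=-1)
            keep_sorted = cdf > (1.0 - top_p) - eps
            row_idx = torch.arange(batch_size).unsqueeze(1)   # broadcasts against indices [B, V]
            top_p_mask = torch.zeros_like(mask)
            top_p_mask[row_idx, indices] = keep_sorted
            mask &= top_p_mask

        return mask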

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist bot left a comment

Code Review

This pull request refactors the sampling evaluation logic to correctly handle batched inputs and improve efficiency. The changes replace the previous frequency distribution computation with a more robust method that calculates expected probabilities and validates samples against a generated mask. The logic for collecting sample distributions is also updated to efficiently handle large numbers of trials by repeating inputs.

My review identifies two potential issues: a critical bug that could lead to a ZeroDivisionError when the input batch size is zero, and a high-severity issue with incorrect handling of scalar sample outputs in batch mode. I've provided suggestions to fix both of these to make the evaluation logic more robust.

Comment on lines 311 to 317
original_batch_size = inputs["probs"].shape[0] if inputs["probs"].dim() > 1 else 1
vocab_size = inputs["probs"].shape[-1]
counter = torch.zeros(vocab_size, dtype=torch.int64, device=torch.device(device))

trials_needed = (num_trials + batch_size - 1) // batch_size
total_samples_collected = 0

# Repeat entire input batch to fill up to target_batch_size for efficient sampling
target_batch_size = 10000
repeat_count = target_batch_size // original_batch_size
actual_batch_size = repeat_count * original_batch_size

critical

There's a potential ZeroDivisionError here. If inputs["probs"] has a shape like (0, vocab_size), original_batch_size will be 0, causing a crash on line 316 when calculating repeat_count.

Additionally, if original_batch_size is larger than target_batch_size, repeat_count will be 0, leading to an actual_batch_size of 0. This will create 0-sized tensors and likely cause issues in the runnable.

I suggest handling the original_batch_size == 0 case explicitly and ensuring repeat_count is at least 1 to prevent these issues.

    original_batch_size = inputs["probs"].shape[0] if inputs["probs"].dim() > 1 else 1
    vocab_size = inputs["probs"].shape[-1]

    if original_batch_size == 0:
        return torch.empty((0, vocab_size), dtype=torch.float32, device=torch.device(device))

    # Repeat entire input batch to fill up to target_batch_size for efficient sampling
    target_batch_size = 10000
    repeat_count = max(1, target_batch_size // original_batch_size)
    actual_batch_size = repeat_count * original_batch_size

Comment on lines 231 to +369
if samples.dim() == 0:
# Single sample - assign to first batch element
sample_idx = samples.item()
counter[sample_idx] += 1
total_samples_collected += 1
else: # Batch of samples
for i in range(samples.numel()):
sample_idx = samples.flatten()[i].item()
counter[sample_idx] += 1
total_samples_collected += 1

frequency = counter.float() / total_samples_collected
return frequency


def _check_thresholding(
samples: torch.Tensor, probs: torch.Tensor, method: str, params: Dict[str, Any]
) -> bool:
"""Check if samples conform to the specified thresholding method.
Parameters
----------
samples : torch.Tensor
Sampled token indices.
probs : torch.Tensor
Probability distribution used for sampling.
method : str
Thresholding method: "top_k", "top_p", "top_k_top_p", or "none".
params : Dict[str, Any]
Sampling parameters (top_k, top_p values).
Returns
-------
bool
True if samples are valid, False otherwise.
"""
batch_size, vocab_size = probs.shape
device = probs.device

for i in range(batch_size):
prob_row = probs[i]
sample = samples[i].item()

if method == "top_k":
if "top_k" not in params:
raise ValueError("top_k parameter is required for top_k thresholding but not found")
k = (
int(params["top_k"][i].item())
if params["top_k"].dim() > 0
else int(params["top_k"].item())
)

if 0 < k < vocab_size:
sorted_prob_desc, _ = torch.sort(prob_row, descending=True)
pivot = sorted_prob_desc[k - 1]
mask_top_k = (prob_row >= pivot).int()
if mask_top_k[sample] != 1:
return False

elif method == "top_p":
if "top_p" not in params:
raise ValueError("top_p parameter is required for top_p thresholding but not found")
p = (
float(params["top_p"][i].item())
if params["top_p"].dim() > 0
else float(params["top_p"].item())
)

if 0 < p < 1:
eps = 1e-4 # numerical stability
sorted_probs, indices = torch.sort(prob_row, descending=False)
cdf = torch.cumsum(sorted_probs, dim=0)
valid_mask = cdf > (1 - p) - eps
valid_indices = indices[valid_mask]

if sample not in valid_indices:
return False

elif method == "top_k_top_p":
if "top_k" not in params or "top_p" not in params:
raise ValueError(
"top_k and top_p parameters are both required for top_k_top_p thresholding but not found"
)
k = (
int(params["top_k"][i].item())
if params["top_k"].dim() > 0
else int(params["top_k"].item())
)
p = (
float(params["top_p"][i].item())
if params["top_p"].dim() > 0
else float(params["top_p"].item())
)

if 0 < k < vocab_size:
sorted_prob_desc, _ = torch.sort(prob_row, descending=True)
pivot = sorted_prob_desc[k - 1]
mask_top_k = (prob_row >= pivot).int()
else:
mask_top_k = torch.ones(vocab_size, dtype=torch.int32, device=device)

if 0 < p < 1:
eps = 1e-4
sorted_probs_asc, indices = torch.sort(prob_row, descending=False)
cdf = torch.cumsum(sorted_probs_asc, dim=0)
mask_top_p = torch.zeros(vocab_size, dtype=torch.int32, device=device)
valid_p_mask = cdf > (1 - p) - eps
mask_top_p[indices[valid_p_mask]] = 1
else:
mask_top_p = torch.ones(vocab_size, dtype=torch.int32, device=device)

joint_mask = torch.minimum(mask_top_k, mask_top_p)

if joint_mask[sample] != 1:
return False

return True
counters[0, sample_idx] += 1
total_samples_per_batch += 1

high

When samples.dim() == 0, the code assumes a batch size of 1 and assigns the sample to the first batch element's counter. This is incorrect if original_batch_size > 1, as it would misattribute samples and lead to an incorrect frequency distribution for all batch items. The runnable should be expected to return a batch of samples matching actual_batch_size. If it returns a scalar when a batch is expected, it's a contract violation that should be flagged with an error.

        if samples.dim() == 0:
            if actual_batch_size != 1:
                raise ValueError(
                    f"Expected a batch of samples (size {actual_batch_size}), but got a scalar."
                )
            # Single sample - assign to first batch element
            sample_idx = samples.item()
            counters[0, sample_idx] += 1
            total_samples_per_batch += 1

@coderabbitai bot left a comment

Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer_bench/bench/evaluators/sampling.py (1)

120-121: cuda synchronize called with a string; also guard for non‑CUDA devices.

torch.cuda.synchronize expects a CUDA device or index. Passing a str like "cuda:0" may fail; on CPU it always fails. Convert to torch.device and guard by type.

-                torch.cuda.synchronize(device)
+                _dev = torch.device(device)
+                if _dev.type == "cuda":
+                    torch.cuda.synchronize(_dev)
🧹 Nitpick comments (9)
flashinfer_bench/bench/evaluators/sampling.py (9)

127-129: normalize_outputs receives device=str here but torch.device elsewhere. Make consistent.

Pass torch.device(device) for consistency and to avoid downstream type assumptions.

-            out_normalized = normalize_outputs(
-                out, device=device, output_names=output_names, output_dtypes=output_dtypes
-            )
+            out_normalized = normalize_outputs(
+                out, device=torch.device(device),
+                output_names=output_names, output_dtypes=output_dtypes
+            )

224-234: Thresholding method detection should use params, not name.

Relying on defn.name can drift from runtime params; infer from presence/values of top_k/top_p.

-def _detect_thresholding_method(defn: Definition) -> str:
-    name = defn.name.lower()
-    if "top_k_top_p" in name:
-        return "top_k_top_p"
-    elif "top_k" in name:
-        return "top_k"
-    elif "top_p" in name:
-        return "top_p"
-    else:
-        return "none"  # no thresholding
+def _detect_thresholding_method(defn: Definition, params: Optional[Dict[str, Any]] = None) -> str:
+    params = params or {}
+    has_k = "top_k" in params and (int(params["top_k"].item()) if isinstance(params["top_k"], torch.Tensor) and params["top_k"].dim()==0 else True)
+    has_p = "top_p" in params
+    if has_k and has_p:
+        return "top_k_top_p"
+    if has_k:
+        return "top_k"
+    if has_p:
+        return "top_p"
+    return "none"

And update call sites to pass params.


254-271: Ruff TRY003: long exception messages.

Trim or refactor messages to constants to satisfy TRY003.

-            raise ValueError(f"top_k parameter required for {method} but not found")
+            raise ValueError("missing required parameter: top_k")
...
-            raise ValueError(f"top_p parameter required for {method} but not found")
+            raise ValueError("missing required parameter: top_p")

Also applies to: 273-276


349-351: total_samples_per_batch accounting is fragile for unusual outputs; compute from observed shapes.

Derive per‑iteration contribution from samples.numel() to avoid assumptions.

-    trials_needed = (num_trials + actual_batch_size - 1) // actual_batch_size
-    total_samples_per_batch = 0
+    trials_needed = (num_trials + actual_batch_size - 1) // actual_batch_size
+    total_samples_per_batch = 0
@@
-        else:
-            # slice and accumulate per original batch element
-            samples_flat = samples.flatten()
-            for i in range(samples_flat.numel()):
-                batch_idx = i % original_batch_size
-                sample_idx = samples_flat[i].item()
-                counters[batch_idx, sample_idx] += 1
-            total_samples_per_batch += repeat_count
+        else:
+            # slice and accumulate per original batch element
+            samples_flat = samples.flatten()
+            for i in range(samples_flat.numel()):
+                batch_idx = i % original_batch_size
+                sample_idx = int(samples_flat[i])
+                counters[batch_idx, sample_idx] += 1
+            total_samples_per_batch += samples_flat.numel() // original_batch_size

Also applies to: 380-381


352-361: Move invariant computations out of the sampling loop.

output_names/output_dtypes don’t change per iteration; compute once before the loop for speed.

-    for _ in range(trials_needed):
-        with torch.no_grad():
-            out = runnable(**padded_inputs)
-
-        output_names = list(defn.outputs.keys())
-        output_dtypes = {k: dtype_str_to_torch_dtype(v.dtype) for k, v in defn.outputs.items()}
+    output_names = list(defn.outputs.keys())
+    output_dtypes = {k: dtype_str_to_torch_dtype(v.dtype) for k, v in defn.outputs.items()}
+    for _ in range(trials_needed):
+        with torch.no_grad():
+            out = runnable(**padded_inputs)

365-377: Count samples with vectorized bincount per batch to reduce Python loops.

This greatly speeds up 500k trials.

-            samples_flat = samples.flatten()
-            for i in range(samples_flat.numel()):
-                batch_idx = i % original_batch_size
-                sample_idx = samples_flat[i].item()
-                counters[batch_idx, sample_idx] += 1
-            total_samples_per_batch += repeat_count
+            samples = samples.view(-1)  # [actual_batch_size]
+            # reshape to [repeat_count, original_batch_size] if divisible
+            if samples.numel() % original_batch_size == 0:
+                reshaped = samples.view(-1, original_batch_size)
+                for b in range(original_batch_size):
+                    counts = torch.bincount(reshaped[:, b], minlength=vocab_size)
+                    counters[b] += counts
+                total_samples_per_batch += reshaped.size(0)
+            else:
+                # fallback to scalar loop (rare)
+                for i in range(samples.numel()):
+                    counters[i % original_batch_size, int(samples[i])] += 1
+                total_samples_per_batch += samples.numel() // original_batch_size

352-355: Optional: synchronize inside the sampling loop only when using CUDA.

Keeps timing/state consistent if kernels are async.

-        with torch.no_grad():
-            out = runnable(**padded_inputs)
+        with torch.no_grad():
+            out = runnable(**padded_inputs)
+        _dev = torch.device(device)
+        if _dev.type == "cuda":
+            torch.cuda.synchronize(_dev)

52-60: Nit: normalize logits once; consider explicit dtype.

If inp["probs"] is already probs, softmax again can skew. Gate behind a flag or defn meta.

-        if "probs" in inp:
+        if "probs" in inp and defn.inputs.get("probs", {}).get("is_logits", True):
             inp["probs"] = torch.softmax(
                 inp["probs"], dim=-1
             )  # convert logits to probs for sampling

246-302: Mask construction looks correct; consider vectorizing top_k/top_p over batch.

Current per‑row loops are fine functionally; vectorization would simplify and speed up.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 20a2870 and 18a032e.

📒 Files selected for processing (1)
  • flashinfer_bench/bench/evaluators/sampling.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer_bench/bench/evaluators/sampling.py (4)
flashinfer_bench/utils.py (1)
  • dtype_str_to_torch_dtype (39-45)
flashinfer_bench/data/trace.py (2)
  • Correctness (72-95)
  • EvaluationStatus (126-146)
flashinfer_bench/bench/utils.py (2)
  • make_eval (261-279)
  • compute_error_stats (89-116)
flashinfer_bench/compile/runnable.py (1)
  • Runnable (6-38)
🪛 GitHub Actions: .github/workflows/linting.yaml
flashinfer_bench/bench/evaluators/sampling.py

[error] 1-1: trailing-whitespace: Hooks detected and fixed trailing whitespace. 1 file was modified by this hook.


[error] 1-1: black: Reformatted 1 file(s). All done! 1 file reformatted, 101 files left unchanged. Files were modified by this hook.

🪛 Ruff (0.14.1)
flashinfer_bench/bench/evaluators/sampling.py

256-256: Avoid specifying long messages outside the exception class

(TRY003)


275-275: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: Run unit tests on ubuntu-latest and Python 3.10
  • GitHub Check: Run unit tests on ubuntu-latest and Python 3.12
  • GitHub Check: Run unit tests on ubuntu-latest and Python 3.9
  • GitHub Check: Run unit tests on ubuntu-latest and Python 3.11
🔇 Additional comments (1)
flashinfer_bench/bench/evaluators/sampling.py (1)

1-5: No issues found; review comment is accurate.

The file contains no trailing whitespace and is properly formatted per Black standards. The .pre-commit-config.yaml is configured with the appropriate hooks (trailing-whitespace, black, isort). The review comment correctly identifies that formatting has been applied and appropriately advises running hooks locally before pushing—this is informational guidance with no outstanding fixes needed.

Comment on lines 62 to 69
thresholding_method = _detect_thresholding_method(defn)
params = {k: inp[k] for k in ["top_k", "top_p"] if k in inp}
valid_mask = _compute_valid_sampling_mask(inp["probs"], thresholding_method, params)

masked_probs = inp["probs"] * valid_mask.float()
expected_probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True)

outputs.append({"expected_probs": expected_probs})

⚠️ Potential issue | 🟠 Major

Guard against zero-sum masked probs to avoid NaNs.

If the valid_mask zeros out all tokens (edge params), masked_probs.sum can be 0 leading to NaNs in expected_probs. Clamp the denominator.

-        masked_probs = inp["probs"] * valid_mask.float()
-        expected_probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True)
+        masked_probs = inp["probs"] * valid_mask.float()
+        denom = masked_probs.sum(dim=-1, keepdim=True)
+        # Avoid NaNs if no tokens survive; fall back to uniform over valid_mask
+        denom = torch.where(denom > 0, denom, torch.ones_like(denom))
+        expected_probs = masked_probs / denom
+        # If denom was 0, distribute uniformly across valid tokens
+        zero_rows = (masked_probs.sum(dim=-1, keepdim=True) == 0)
+        if zero_rows.any():
+            uniform = valid_mask.float() / valid_mask.float().sum(dim=-1, keepdim=True).clamp_min(1)
+            expected_probs = torch.where(zero_rows, uniform, expected_probs)
🤖 Prompt for AI Agents
In flashinfer_bench/bench/evaluators/sampling.py around lines 62 to 69, the code
divides by masked_probs.sum which can be zero if valid_mask zeros out all
tokens; change the denominator to a clamped value to avoid NaNs by computing the
sum with keepdim=True and then applying .clamp_min(eps) (use a small constant
like 1e-12) before dividing so expected_probs = masked_probs / denom_clamped,
ensuring you preserve shapes and device when creating eps.

Comment on lines 159 to 166
if not valid_mask[batch_idx, sample_idx]:
correctness = Correctness(
max_relative_error=float("inf"), max_absolute_error=float("inf")
)
message = (
f"Sample {sample_idx} is outside valid {thresholding_method} mask for batch {batch_idx}"
)
print(message, file=sys.stderr)

⚠️ Potential issue | 🔴 Critical

Avoid truthiness on 0‑dim Torch bool Tensor.

Using if not valid_mask[...] can raise “Boolean value of Tensor is ambiguous.” Convert to Python bool.

-                if not valid_mask[batch_idx, sample_idx]:
+                if not bool(valid_mask[batch_idx, sample_idx].item()):
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if not valid_mask[batch_idx, sample_idx]:
correctness = Correctness(
max_relative_error=float("inf"), max_absolute_error=float("inf")
)
message = (
f"Sample {sample_idx} is outside valid {thresholding_method} mask for batch {batch_idx}"
)
print(message, file=sys.stderr)
if not bool(valid_mask[batch_idx, sample_idx].item()):
correctness = Correctness(
max_relative_error=float("inf"), max_absolute_error=float("inf")
)
message = (
f"Sample {sample_idx} is outside valid {thresholding_method} mask for batch {batch_idx}"
)
print(message, file=sys.stderr)
🤖 Prompt for AI Agents
In flashinfer_bench/bench/evaluators/sampling.py around lines 159 to 166, the
condition uses truthiness on a 0‑dim Torch bool Tensor (if not
valid_mask[batch_idx, sample_idx]) which can raise “Boolean value of Tensor is
ambiguous.” Convert the tensor to a Python bool by calling .item() (and .cpu()
if it may be on GPU) before negation, and use that boolean in the if check so
the branch evaluates correctly.

Comment on lines +175 to 179
sol_freqs = _sample_token_distributions(
sol_runnable, inp, device, defn, num_trials=500000
)
torch.cuda.synchronize(device)
except Exception:

⚠️ Potential issue | 🟡 Minor

Second cuda synchronize has same issue; apply the same guard.

Mirror the earlier fix after sampling distributions.

-            torch.cuda.synchronize(device)
+            _dev = torch.device(device)
+            if _dev.type == "cuda":
+                torch.cuda.synchronize(_dev)
🧰 Tools
🪛 Ruff (0.14.1)

179-179: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
In flashinfer_bench/bench/evaluators/sampling.py around lines 175 to 179, the
second torch.cuda.synchronize(device) call needs the same protective guard as
the earlier synchronize to avoid raising on non-CUDA or unavailable CUDA
devices; wrap the synchronize in the same conditional/try-except used previously
(e.g., only call if device.type == "cuda" and torch.cuda.is_available(), or call
inside a try: ... except Exception: pass block) so any CUDA sync errors are
handled identically to the earlier fix.

Comment on lines 203 to 208
numerical_incorrect = max_tvd > cfg.sampling_tvd_threshold
correctness = Correctness(
max_relative_error=max_rel, max_absolute_error=max_abs, extra={"tvd": tvd}
max_relative_error=max_rel,
max_absolute_error=max_abs,
extra={"tvd": max_tvd, "tvds_per_batch": tvds}
)

⚠️ Potential issue | 🟠 Major

Protect against missing cfg.sampling_tvd_threshold.

If cfg.sampling_tvd_threshold is None/absent, comparison raises. Default or fail fast with message.

-        numerical_incorrect = max_tvd > cfg.sampling_tvd_threshold
+        tvd_thresh = getattr(cfg, "sampling_tvd_threshold", None)
+        if tvd_thresh is None:
+            raise ValueError("cfg.sampling_tvd_threshold must be set for sampling evaluation")
+        numerical_incorrect = max_tvd > tvd_thresh
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
numerical_incorrect = max_tvd > cfg.sampling_tvd_threshold
correctness = Correctness(
max_relative_error=max_rel, max_absolute_error=max_abs, extra={"tvd": tvd}
max_relative_error=max_rel,
max_absolute_error=max_abs,
extra={"tvd": max_tvd, "tvds_per_batch": tvds}
)
tvd_thresh = getattr(cfg, "sampling_tvd_threshold", None)
if tvd_thresh is None:
raise ValueError("cfg.sampling_tvd_threshold must be set for sampling evaluation")
numerical_incorrect = max_tvd > tvd_thresh
correctness = Correctness(
max_relative_error=max_rel,
max_absolute_error=max_abs,
extra={"tvd": max_tvd, "tvds_per_batch": tvds}
)
🤖 Prompt for AI Agents
In flashinfer_bench/bench/evaluators/sampling.py around lines 203 to 208, the
code compares max_tvd against cfg.sampling_tvd_threshold without handling the
case where that config value is None or missing; add a guard before the
comparison that checks if cfg.sampling_tvd_threshold is None and fail fast with
a clear ValueError (or optionally set a documented default threshold) so the
comparison never raises a TypeError; update the subsequent comparison to use the
validated threshold variable.

Comment on lines 315 to 318
target_batch_size = 10000
repeat_count = target_batch_size // original_batch_size
actual_batch_size = repeat_count * original_batch_size


⚠️ Potential issue | 🔴 Critical

repeat_count can be 0 when original_batch_size > 10_000 → division by zero and empty batches.

Ensure at least one repeat; prefer ceil to keep high utilization.

-    target_batch_size = 10000
-    repeat_count = target_batch_size // original_batch_size
-    actual_batch_size = repeat_count * original_batch_size
+    import math
+    target_batch_size = 10000
+    repeat_count = max(1, math.ceil(target_batch_size / original_batch_size))
+    actual_batch_size = repeat_count * original_batch_size

Also consider making target_batch_size configurable.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
target_batch_size = 10000
repeat_count = target_batch_size // original_batch_size
actual_batch_size = repeat_count * original_batch_size
import math
target_batch_size = 10000
repeat_count = max(1, math.ceil(target_batch_size / original_batch_size))
actual_batch_size = repeat_count * original_batch_size
🤖 Prompt for AI Agents
In flashinfer_bench/bench/evaluators/sampling.py around lines 315-318, the
current computation uses integer division target_batch_size //
original_batch_size which can yield 0 when original_batch_size > 10000 (leading
to empty batches) and underutilization; change to use a ceiling division and
ensure at least one repeat: compute repeat_count = max(1, ceil(target_batch_size
/ original_batch_size)) (or equivalent integer math), then set actual_batch_size
= repeat_count * original_batch_size, and expose target_batch_size as a
configurable parameter (with validation to be a positive int) so it can be tuned
instead of hardcoding 10000.

@coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (6)
flashinfer_bench/bench/evaluators/sampling.py (6)

172-172: Guard CUDA synchronization for non-CUDA devices.

torch.cuda.synchronize(device) will raise an error if device is not CUDA or if CUDA is unavailable. Add a guard to only synchronize for CUDA devices.

Apply this diff:

-            torch.cuda.synchronize(device)
+            _dev = torch.device(device)
+            if _dev.type == "cuda":
+                torch.cuda.synchronize(_dev)

351-355: Validate runnable contract when receiving scalar samples.

If samples.dim() == 0 (scalar) but actual_batch_size > 1, the runnable has violated its contract by returning a single sample when a batch was expected. The current code silently assigns this to the first batch element, which produces incorrect frequency distributions.

Apply this diff to catch contract violations:

         if samples.dim() == 0:
+            if actual_batch_size != 1:
+                raise ValueError(
+                    f"Expected a batch of samples (size {actual_batch_size}), but got a scalar."
+                )
             # Single sample - assign to first batch element
             sample_idx = samples.item()
             counters[0, sample_idx] += 1
             total_samples_per_batch += 1

62-63: Guard against zero-sum masked probabilities to prevent NaNs.

If valid_mask zeros out all tokens (possible with edge-case top_k/top_p parameters), masked_probs.sum() will be zero, resulting in NaN values in expected_probs.

Apply this diff to add a safeguard:

 masked_probs = inp["probs"] * valid_mask.float()
-expected_probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True)
+denom = masked_probs.sum(dim=-1, keepdim=True).clamp_min(1e-12)
+expected_probs = masked_probs / denom

197-197: Validate cfg.sampling_tvd_threshold before comparison.

If cfg.sampling_tvd_threshold is None or missing, the comparison will raise a TypeError. Validate the attribute exists and has a valid value.

Apply this diff:

-        numerical_incorrect = max_tvd > cfg.sampling_tvd_threshold
+        tvd_thresh = getattr(cfg, "sampling_tvd_threshold", None)
+        if tvd_thresh is None:
+            raise ValueError("cfg.sampling_tvd_threshold must be set for sampling evaluation")
+        numerical_incorrect = max_tvd > tvd_thresh

155-155: Convert Tensor to Python bool to avoid ambiguity error.

Using if not valid_mask[batch_idx, sample_idx] directly on a 0-dimensional boolean Tensor can raise "Boolean value of Tensor is ambiguous." Convert to a Python boolean first.

Apply this diff:

-                if not valid_mask[batch_idx, sample_idx]:
+                if not valid_mask[batch_idx, sample_idx].item():

297-303: Fix repeat_count calculation to prevent zero-batch and handle large batches.

Multiple critical issues:

  1. If original_batch_size > target_batch_size (e.g., 15000), repeat_count becomes 0, leading to actual_batch_size = 0, empty tensors, and division by zero at line 366.
  2. If inputs["probs"].shape[0] is 0 (empty batch), the fallback to 1 is semantically incorrect.
  3. Integer division underutilizes batch capacity when original_batch_size doesn't evenly divide target_batch_size.

Apply this diff to ensure at least one repeat and handle edge cases:

-    original_batch_size = inputs["probs"].shape[0] if inputs["probs"].dim() > 1 else 1
+    original_batch_size = inputs["probs"].shape[0] if inputs["probs"].dim() > 1 else 1
+    
+    if original_batch_size == 0:
+        return torch.empty((0, vocab_size), dtype=torch.float32, device=torch.device(device))
+    
     vocab_size = inputs["probs"].shape[-1]
 
     # Repeat entire input batch to fill up to target_batch_size for efficient sampling
     target_batch_size = 10000
-    repeat_count = target_batch_size // original_batch_size
+    repeat_count = max(1, target_batch_size // original_batch_size)
     actual_batch_size = repeat_count * original_batch_size
🧹 Nitpick comments (1)
flashinfer_bench/bench/evaluators/sampling.py (1)

250-250: Consider using custom exception classes for complex error messages.

Static analysis suggests avoiding long messages in exception constructors. While the current approach works, custom exception classes can improve maintainability for complex error scenarios.

Based on static analysis hints.

Also applies to: 265-265

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 18a032e and 39e2f19.

📒 Files selected for processing (1)
  • flashinfer_bench/bench/evaluators/sampling.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer_bench/bench/evaluators/sampling.py (4)
flashinfer_bench/utils.py (1)
  • dtype_str_to_torch_dtype (39-45)
flashinfer_bench/data/trace.py (2)
  • Correctness (72-95)
  • EvaluationStatus (126-146)
flashinfer_bench/bench/utils.py (2)
  • make_eval (261-279)
  • compute_error_stats (89-116)
flashinfer_bench/compile/runnable.py (1)
  • Runnable (6-38)
🪛 Ruff (0.14.3)
flashinfer_bench/bench/evaluators/sampling.py

250-250: Avoid specifying long messages outside the exception class

(TRY003)


265-265: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: Run unit tests on ubuntu-latest and Python 3.13
  • GitHub Check: Run unit tests on ubuntu-latest and Python 3.10
  • GitHub Check: Run unit tests on ubuntu-latest and Python 3.12
  • GitHub Check: Run unit tests on ubuntu-latest and Python 3.11
🔇 Additional comments (3)
flashinfer_bench/bench/evaluators/sampling.py (3)

179-195: Well-structured per-batch TVD and error computation.

The implementation correctly computes TVD and error statistics for each batch element independently, then aggregates the worst-case metrics. This approach properly handles batch-wise validation and provides detailed diagnostics.


230-287: Robust thresholding mask implementation with proper tie-breaking.

The _compute_valid_sampling_mask function correctly implements:

  • Top-k tie-breaking by including all tokens with probability ≥ k-th largest
  • Top-p epsilon tolerance for numerical precision
  • Proper handling of edge cases (k=0, k≥vocab_size, p=0, p≥1)
  • Batch-wise mask computation with correct indexing

305-329: Proper input batching with dimension-aware repetition.

The input padding logic correctly handles different tensor types:

  • Repeats probability tensors along batch dimension
  • Properly expands scalar and batched sampling parameters
  • Maintains correct shapes for multi-dimensional inputs
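A rough sketch of such dimension-aware padding, under the assumption that inputs is a dict of tensors and scalars keyed by parameter name (this mirrors the description above, not the exact code):

    import torch
    from typing import Any, Dict

    def pad_inputs_sketch(inputs: Dict[str, Any], repeat_count: int) -> Dict[str, Any]:
        """Tile each batched tensor input repeat_count times along dim 0 so the
        kernel sees repeat_count copies of the original batch."""
        padded: Dict[str, Any] = {}
        for name, value in inputs.items():
            if not isinstance(value, torch.Tensor):
                padded[name] = value                       # non-tensor config passes through
            elif value.dim() == 0:
                padded[name] = value                       # a scalar top_k/top_p applies to every row
            elif value.dim() == 1:
                padded[name] = value.repeat(repeat_count)  # per-row params, e.g. top_k of shape [B]
            else:
                # batched tensors such as probs of shape [B, V]
                padded[name] = value.repeat(repeat_count, *([1] * (value.dim() - 1)))
        return padded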

