5 changes: 5 additions & 0 deletions docs/user-guide/scenario-definition.md
@@ -22,6 +22,11 @@ Scenarios are optional. If you don’t provide any and you supply a dataset, gen
- Example: `N(480,240)/(300,150)`
- Uniform: `U(min_input_tokens,max_input_tokens)/(min_output_tokens,max_output_tokens)` or `U(max_input_tokens,max_output_tokens)`
- Examples: `U(50,100)/(200,250)` or `U(100,200)`
- Prefix Repetition (for KV cache benchmarking): `P(prefix_len,suffix_len)/output_len`
- Example: `P(2000,500)/200`
  - All requests share the same prefix (the first request caches it; subsequent requests reuse the cached KV)
- Each request has a unique suffix to ensure different completions
- Useful for benchmarking automatic prefix caching (APC), chunked prefill, and TTFT improvements

- Embeddings
- Embedding: `E(tokens_per_document)`
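As a quick sanity check on what `P(2000,500)/200` implies for an engine with automatic prefix caching, here is a back-of-the-envelope sketch (illustrative only; real hit rates also depend on the engine's cache block size and eviction policy):

```python
prefix_len, suffix_len, output_len = 2000, 500, 200  # P(2000,500)/200

# From the second request onward the 2000-token prefix is a KV cache hit,
# so only the ~500-token unique suffix needs a fresh prefill.
input_tokens = prefix_len + suffix_len    # 2500 prompt tokens per request
cacheable = prefix_len / input_tokens     # 0.8

print(f"~{cacheable:.0%} of each request's prefill tokens are cacheable")
# output_len (200) is the decode budget; the sampler sets ignore_eos so the
# serving engine generates the full amount.
```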
5 changes: 5 additions & 0 deletions genai_bench/cli/cli.py
@@ -382,6 +382,11 @@ def benchmark(
sanitized_scenario_str = sanitize_string(scenario_str)
runner.update_scenario(scenario_str)

# Reset prefix cache for new scenario to ensure fresh prefix
# This is critical for prefix repetition scenarios to work correctly
if hasattr(sampler, "reset_prefix_cache"):
sampler.reset_prefix_cache()

# Store metrics for current scenario for interim plot
scenario_metrics = {
"data": {},
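To make the comment above concrete, here is a minimal, self-contained sketch (not the real `TextSampler`) of the failure mode the reset guards against: the prefix cache in `text.py` below is keyed only by `prefix_{prefix_len}`, so two scenarios that happen to share a `prefix_len` would silently reuse one prefix across runs unless the cache is cleared in between.

```python
import random


class TinySampler:
    """Stand-in that mimics only the prefix-cache behavior of TextSampler."""

    def __init__(self):
        self._shared_prefix_cache: dict[str, str] = {}

    def _get_prefix(self, prefix_len: int) -> str:
        key = f"prefix_{prefix_len}"  # same cache-key scheme as text.py below
        if key not in self._shared_prefix_cache:
            self._shared_prefix_cache[key] = f"prefix-{random.random()}"
        return self._shared_prefix_cache[key]

    def reset_prefix_cache(self):
        self._shared_prefix_cache.clear()


sampler = TinySampler()
first = sampler._get_prefix(2000)   # e.g. scenario P(2000,500)/100
sampler.reset_prefix_cache()        # the call made in cli.py above
second = sampler._get_prefix(2000)  # e.g. scenario P(2000,100)/200
assert first != second              # each scenario run gets a fresh prefix
```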
8 changes: 8 additions & 0 deletions genai_bench/cli/option_groups.py
@@ -453,6 +453,14 @@ def experiment_options(func):
- Examples: U(50,100)/(200,250)
U(100,200)

\b
4. **Prefix Repetition (P)**: For KV cache benchmarking.
- Format: P(prefix_len,suffix_len)/output_len
- Example: P(2000,500)/200
- All requests share the same prefix (cached after the 1st request)
- Each request has a unique suffix
- Tests automatic prefix caching (APC) and TTFT improvements

\b
Supported modalities are:

133 changes: 131 additions & 2 deletions genai_bench/sampling/text.py
@@ -1,3 +1,4 @@
import hashlib
import random
from typing import Any, Dict, List, Optional

@@ -11,6 +12,7 @@
)
from genai_bench.sampling.base import Sampler
from genai_bench.scenarios.base import EmbeddingDistribution, Scenario, TextDistribution
from genai_bench.scenarios.text import PrefixRepetitionScenario

logger = init_logger(__name__)

@@ -42,6 +44,11 @@ def __init__(
self.data = data
self.batch_size = 1 # Default batch size

# Cache for shared prefixes in prefix repetition scenarios
# Key: scenario identifier, Value: generated prefix text
self._shared_prefix_cache: Dict[str, str] = {}
self._suffix_counter = 0

def sample(self, scenario: Optional[Scenario]) -> UserRequest:
"""
Samples a request based on the scenario.
@@ -69,6 +76,10 @@ def _sample_chat_request(self, scenario: Optional[Scenario]) -> UserChatRequest:
num_input_tokens, num_output_tokens = None, None
self.additional_request_params["ignore_eos"] = False
else:
# Check if this is a prefix repetition scenario
if isinstance(scenario, PrefixRepetitionScenario):
return self._sample_prefix_repetition_request(scenario)

# Use scenario-based sampling
self._validate_scenario(scenario)
num_input_tokens, num_output_tokens = scenario.sample()
@@ -187,16 +198,24 @@ def _sample_text(self, num_input_tokens: Optional[int]) -> str:
while left_tokens_to_sample > 0:
random.shuffle(data_copy)
for line in data_copy:
line_tokens = self.tokenizer.encode(line, add_special_tokens=False)
# Tokenize line with space prefix to match how it will be concatenated
line_with_space = (" " if prompt else "") + line
line_tokens = self.tokenizer.encode(
line_with_space, add_special_tokens=False
)
num_line_tokens = len(line_tokens)

if num_line_tokens > left_tokens_to_sample:
# Truncate at token level, decode only needed tokens
truncated_text = self.tokenizer.decode(
line_tokens[:left_tokens_to_sample], skip_special_tokens=True
)
prompt += (" " if prompt else "") + truncated_text
return prompt
prompt += line

# Add line with space separator
# (consistent with truncated text handling)
prompt += (" " if prompt else "") + line
left_tokens_to_sample -= num_line_tokens
return prompt

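The space-prefix change above matters because BPE tokenizers assign different token ids to a word with and without a leading space, and `_sample_text` joins lines into the prompt with a space separator; encoding each line exactly as it will appear keeps the token count honest. A quick illustration, assuming the Hugging Face `transformers` package and the `gpt2` tokenizer are available:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")

# The same word tokenizes differently with a leading space, so counting
# tokens on the bare line would misstate the assembled prompt's length.
print(tok.encode("hello", add_special_tokens=False))   # e.g. [31373]
print(tok.encode(" hello", add_special_tokens=False))  # e.g. [23748]
```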
@@ -222,3 +241,113 @@ def _check_discrepancy(
f"num_prefill_tokens={num_prefill_tokens}, "
f"discrepancy={discrepancy}"
)

def _sample_prefix_repetition_request(self, scenario) -> UserChatRequest:
"""Generate request with shared prefix for KV cache benchmarking.

This method creates requests where all concurrent requests share the
exact same prefix text, enabling benchmarking of:
- KV cache hit rates and speedups
- Automatic prefix caching (APC) performance
- Chunked prefill efficiency
- Time To First Token (TTFT) improvements

Args:
scenario: PrefixRepetitionScenario with prefix_len, suffix_len, output_len

Returns:
UserChatRequest with shared prefix + unique suffix
"""
prefix_len, suffix_len, output_len = scenario.sample()

# Get or create shared prefix (cached for ALL requests in this scenario run)
cache_key = f"prefix_{prefix_len}"
if cache_key not in self._shared_prefix_cache:
# Generate the shared prefix once
prefix = self._sample_text(prefix_len)
self._shared_prefix_cache[cache_key] = prefix

# Calculate hash for verification
prefix_hash = hashlib.md5(prefix.encode()).hexdigest()[:8]

logger.info(
f"🔑 Generated shared prefix ({prefix_len} tokens) "
f"for KV cache benchmarking. "
f"All subsequent requests in this scenario will reuse this prefix."
)
logger.debug(
f" Prefix hash: {prefix_hash} | Preview: {prefix[:100]}..."
)
else:
prefix = self._shared_prefix_cache[cache_key]

# Log cache reuse (only for first few to avoid spam)
if self._suffix_counter < 5:
prefix_hash = hashlib.md5(prefix.encode()).hexdigest()[:8]
logger.debug(
f"♻️ Reusing cached prefix (hash: {prefix_hash}) "
f"for request #{self._suffix_counter + 1}"
)

# Generate unique suffix for THIS specific request
suffix = self._sample_text(suffix_len)
self._suffix_counter += 1

# Log suffix info for first few requests
if self._suffix_counter <= 5:
suffix_hash = hashlib.md5(suffix.encode()).hexdigest()[:8]
suffix_actual_tokens = self.get_token_length(suffix)
logger.debug(
f"📝 Request #{self._suffix_counter}: "
f"Unique suffix generated (hash: {suffix_hash}), "
f"requested {suffix_len} tokens, actual {suffix_actual_tokens} tokens"
)

# Combine prefix + separator + suffix
# The separator helps distinguish requests while keeping prefix identical
separator = f"\n\n--- Request #{self._suffix_counter} ---\n\n"
prompt = f"{prefix}{separator}{suffix}"

num_prefill_tokens = self.get_token_length(prompt)

# Log actual token breakdown for first request
if self._suffix_counter <= 2:
prefix_tokens = self.get_token_length(prefix)
separator_tokens = self.get_token_length(separator)
suffix_tokens = self.get_token_length(suffix)
logger.debug(
f"🔍 Token breakdown for request #{self._suffix_counter}: "
f"prefix={prefix_tokens}, separator={separator_tokens}, "
f"suffix={suffix_tokens}, total={num_prefill_tokens} "
f"(expected ~{prefix_len + suffix_len + 20})"
)

# Expected tokens: prefix + suffix + separator overhead (~20 tokens)
expected_tokens = prefix_len + suffix_len + 20
self._check_discrepancy(expected_tokens, num_prefill_tokens, threshold=0.15)

# Set ignore_eos to ensure we get the expected output length
self.additional_request_params["ignore_eos"] = True

return UserChatRequest(
model=self.model,
prompt=prompt,
num_prefill_tokens=num_prefill_tokens,
max_tokens=output_len,
additional_request_params=self.additional_request_params,
)

def reset_prefix_cache(self):
"""Clear the prefix cache and reset counter.

This should be called between different scenario runs to ensure
each scenario gets a fresh prefix.
"""
if self._suffix_counter > 0:
logger.info(
f"🔄 Resetting prefix cache. "
f"Previous scenario generated {self._suffix_counter} requests "
f"with {len(self._shared_prefix_cache)} cached prefix(es)."
)
self._shared_prefix_cache.clear()
self._suffix_counter = 0
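A small, self-contained sketch of how the shared-prefix invariant can be checked from the outside, mirroring the md5 hashes the 🔑/♻️ log lines above emit; the stand-in prompts below just follow the `--- Request #N ---` separator layout that `_sample_prefix_repetition_request` produces:

```python
import hashlib


def shared_prefix(prompt: str) -> str:
    """Recover the shared prefix using the separator format defined above."""
    return prompt.split("\n\n--- Request #")[0]


p1 = "common 2000-token prefix\n\n--- Request #1 ---\n\nunique suffix one"
p2 = "common 2000-token prefix\n\n--- Request #2 ---\n\nunique suffix two"

assert shared_prefix(p1) == shared_prefix(p2)  # byte-identical => KV cache hit
assert p1 != p2                                # unique suffix per request
print(hashlib.md5(shared_prefix(p1).encode()).hexdigest()[:8])
```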
1 change: 1 addition & 0 deletions genai_bench/scenarios/base.py
@@ -12,6 +12,7 @@ class TextDistribution(Enum):
NORMAL = "N"
DETERMINISTIC = "D"
UNIFORM = "U"
PREFIX_REPETITION = "P"


class EmbeddingDistribution(Enum):
64 changes: 64 additions & 0 deletions genai_bench/scenarios/text.py
@@ -1,3 +1,4 @@
import re
from typing import Optional, Tuple

import numpy as np
@@ -222,3 +223,66 @@ def parse(cls, params_str: str) -> "ReRankScenario":
return cls(
tokens_per_document=tokens_per_document, tokens_per_query=tokens_per_query
)


class PrefixRepetitionScenario(Scenario):
"""
Prefix repetition scenario for KV cache benchmarking.

All concurrent requests share the same prefix but have unique suffixes.
This enables benchmarking of KV cache performance, chunked prefill efficiency,
and automatic prefix caching (APC) features in LLM serving engines.

Format: P(prefix_len,suffix_len)/output_len
Example: P(2000,500)/200

In this example:
- All requests share a 2000-token prefix (cached after first request)
- Each request has a unique 500-token suffix
- Expected output is 200 tokens

This scenario is particularly useful for:
- Testing KV cache hit rates and speedups
- Benchmarking prefill performance with cached prefixes
- Measuring Time To First Token (TTFT) improvements
- Evaluating automatic prefix caching implementations
"""

scenario_type = TextDistribution.PREFIX_REPETITION
validation_pattern = r"^P\(\d+,\d+\)/\d+$"

def __init__(self, prefix_len: int, suffix_len: int, output_len: int):
self.prefix_len = prefix_len
self.suffix_len = suffix_len
self.output_len = output_len

def sample(self) -> Tuple[int, int, int]:
"""Returns (prefix_len, suffix_len, output_len)"""
return self.prefix_len, self.suffix_len, self.output_len

def to_string(self) -> str:
"""
Return the string representation of this scenario,
for example "P(2000,500)/200".
"""
return f"P({self.prefix_len},{self.suffix_len})/{self.output_len}"

@classmethod
def parse(cls, params_str: str) -> "PrefixRepetitionScenario":
"""
Parse the prefix repetition scenario from a string.

Example: "(2000,500)/200" -> PrefixRepetitionScenario(2000, 500, 200)
"""
# Parse P(prefix_len,suffix_len)/output_len
# params_str will be "(2000,500)/200"
match = re.match(r"\((\d+),(\d+)\)/(\d+)", params_str)
if not match:
raise ValueError(
f"Invalid prefix repetition format: {params_str}. "
f"Expected format: (prefix_len,suffix_len)/output_len"
)
prefix_len = int(match.group(1))
suffix_len = int(match.group(2))
output_len = int(match.group(3))
return cls(prefix_len, suffix_len, output_len)
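Usage is a simple round trip; note that `parse` receives only the parameter portion of the scenario string, with the leading `P` already consumed by the distribution dispatcher (as the `"(2000,500)/200"` example in the docstring indicates):

```python
from genai_bench.scenarios.text import PrefixRepetitionScenario

scenario = PrefixRepetitionScenario.parse("(2000,500)/200")
assert scenario.sample() == (2000, 500, 200)      # (prefix, suffix, output)
assert scenario.to_string() == "P(2000,500)/200"  # round-trips to the CLI form
```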