Benchmark Fix: Remove special tokens from warmup prompts (vllm-project#140)

Summary:
When sampling words at random for prompt generation, we sometimes pick
up the `<pad>` token.
The tokenizer doesn't recognize this as a special token and leaves it in
the prompt as-is. This causes the backend to fail with errors like:
```
../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [312,0,0], thread: [64,0,0] Assertion `srcIndex < srcSelectDimSize` failed
```
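
For illustration, here is a minimal sketch of the round trip, assuming any Hugging Face tokenizer (the model name below is an arbitrary example, not part of this change): decoding with `skip_special_tokens=True` strips special-token ids that random vocab sampling can otherwise smuggle into a prompt.

```python
# Minimal sketch; "facebook/opt-125m" is an example model, not part of
# this patch.
import random

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

# Raw vocab keys include literal special tokens such as "</s>" and "<pad>",
# so random sampling can place them in the middle of a synthetic prompt.
words = list(tokenizer.get_vocab().keys())
prompt = " ".join(random.choices(words, k=16))

prompt_ids = tokenizer(prompt).input_ids

# Decoding with skip_special_tokens=True drops pad/bos/eos ids, so the
# round-tripped prompt no longer carries special tokens into the backend.
clean_prompt = tokenizer.decode(prompt_ids, skip_special_tokens=True)
print(clean_prompt)
```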

Test:
Manual tests

Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
varun-sundar-rabindranath and Varun Sundar Rabindranath authored Mar 20, 2024
1 parent e072350 commit d630f71
Showing 1 changed file with 5 additions and 11 deletions.

neuralmagic/benchmarks/scripts/common.py

```diff
@@ -46,14 +46,6 @@ def get_benchmarking_context() -> dict:
     }
 
 
-def remove_special_tokens_and_decode(
-        prompt_ids: list[int], tokenizer: PreTrainedTokenizerBase) -> str:
-    # Remove special tokens from prompt ids
-    prompt_ids = list(
-        filter(lambda id: id not in tokenizer.all_special_ids, prompt_ids))
-    return tokenizer.decode(prompt_ids)
-
-
 def generate_synthetic_requests(
         num_input_tokens: int, num_output_tokens: int, num_requests: int,
         tokenizer: PreTrainedTokenizerBase) -> List[Tuple[str, int, int]]:
@@ -88,7 +80,7 @@ def generate_synthetic_requests(
             continue
 
         prompt_ids = prompt_ids[:num_input_tokens]
-        prompt = remove_special_tokens_and_decode(prompt_ids, tokenizer)
+        prompt = tokenizer.decode(prompt_ids, skip_special_tokens=True)
 
         sampled_requests.append((prompt, num_input_tokens, num_output_tokens))
 
@@ -103,15 +95,17 @@ def warmup_requests(tokenizer: PreTrainedTokenizerBase,
     """
     Given a tokenizer, generate `num_requests` requests used for warmup
     """
-    words = list(tokenizer.get_vocab().keys())
+    all_words = list(tokenizer.get_vocab().keys())
+    # Remove special tokens like <s>, </s>, <pad> etc. from all_words
+    words = list(filter(lambda word: not word.startswith('<'), all_words))
     requests = []
     for _ in range(num_requests):
         # We make up random prompts for warmups in order to avoid the effects of
         # prefix caching during actual benchmarking.
         prompt = " ".join(random.choices(words, k=num_input_tokens))
         prompt_ids = tokenizer(prompt).input_ids
         prompt_ids = prompt_ids[:num_input_tokens]
-        prompt = remove_special_tokens_and_decode(prompt_ids, tokenizer)
+        prompt = tokenizer.decode(prompt_ids, skip_special_tokens=True)
         requests.append((prompt, num_input_tokens, num_output_tokens))
     return requests
 
```
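Design note: the `startswith('<')` filter is a heuristic that assumes special tokens are spelled `<...>` (as with `<s>`, `</s>`, `<pad>`). A more general variant, sketched below as a hypothetical alternative rather than what this commit does, filters against the tokenizer's declared special tokens:

```python
# Hypothetical alternative to the startswith('<') heuristic (not part of
# this commit): filter the sampling pool against the tokenizer's declared
# special tokens instead of relying on their surface spelling.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")  # example model

special = set(tokenizer.all_special_tokens)
words = [w for w in tokenizer.get_vocab().keys() if w not in special]

# No declared special token remains in the warmup sampling pool.
assert special.isdisjoint(words)
```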
