diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py
index 66ff8f7a54d3..ba9c3bebc437 100644
--- a/tests/models/test_transformers.py
+++ b/tests/models/test_transformers.py
@@ -8,8 +8,7 @@
 from vllm.platforms import current_platform
 
 from ..conftest import HfRunner, VllmRunner
-from ..core.block.e2e.test_correctness_sliding_window import prep_prompts
-from ..utils import multi_gpu_test
+from ..utils import multi_gpu_test, prep_prompts
 from .utils import check_logprobs_close
 
 
diff --git a/tests/utils.py b/tests/utils.py
index 16e1e6039329..9a27c3de4533 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -8,6 +8,7 @@
 import importlib
 import json
 import os
+import random
 import signal
 import subprocess
 import sys
@@ -1150,3 +1151,49 @@ def override_cutlass_fp8_supported(value: bool):
             "vllm.model_executor.layers.quantization.utils.w8a8_utils.cutlass_fp8_supported",
             return_value=value):
         yield
+
+
+def prep_prompts(batch_size: int, ln_range: tuple[int, int] = (800, 1100)):
+    """
+    Generate prompts with a bunch of assignments,
+    then ask for the value of one of them.
+    The prompt is just under 10k tokens; the sliding window is 4k,
+    so the answer lies outside the sliding window but should still be correct.
+    Args:
+        batch_size: number of prompts to generate
+        ln_range: an argument to control the length of the prompt
+    """
+    prompts: list[str] = []
+    answer: list[int] = []
+    indices: list[int] = []
+    random.seed(1)
+    for _ in range(batch_size):
+        idx = random.randint(30, 90)
+        indices.append(idx)
+        prompt = "```python\n# We set a number of variables, " + \
+                 f"x{idx} will be important later\n"
+        ln = random.randint(*ln_range)
+        for k in range(30, ln):
+            v = random.randint(10, 99)
+            if k == idx:
+                answer.append(v)
+            prompt += f"x{k} = {v}\n"
+        prompt += f"# Now, we check the value of x{idx}:\n"
+        prompt += f"assert x{idx} == "
+        prompts.append(prompt)
+    return prompts, answer, indices
+
+
+def check_answers(indices: list[int],
+                  answer: list[int],
+                  outputs: list[str],
+                  accept_rate: float = 0.7):
+    answer2 = [int(text[0:2].strip()) for text in outputs]
+    print(list(zip(indices, zip(answer, answer2))))
+    numok = 0
+    for a1, a2 in zip(answer, answer2):
+        if a1 == a2:
+            numok += 1
+    frac_ok = numok / len(answer)
+    print(f"Num OK: {numok}/{len(answer)} {frac_ok}")
+    assert frac_ok >= accept_rate
diff --git a/tests/v1/e2e/test_correctness_sliding_window.py b/tests/v1/e2e/test_correctness_sliding_window.py
index 4dfe1d3bb33f..5b0c15472251 100644
--- a/tests/v1/e2e/test_correctness_sliding_window.py
+++ b/tests/v1/e2e/test_correctness_sliding_window.py
@@ -6,8 +6,7 @@
 
 from vllm import LLM, SamplingParams
 
-from ...core.block.e2e.test_correctness_sliding_window import (check_answers,
-                                                               prep_prompts)
+from ...utils import check_answers, prep_prompts
 
 
 @dataclass
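
Usage note: the sketch below is a minimal, hypothetical illustration of how the relocated helpers could be exercised outside the test suite; it is not part of this diff. The model name, the import path (which assumes the repo root is on PYTHONPATH), and the sampling settings are placeholders. It mirrors the pattern of the sliding-window correctness tests: prep_prompts() builds prompts whose answer lies outside the sliding window, and check_answers() asserts that enough completions recover the hidden value.

    # sketch: exercising the relocated helpers (illustrative only, not part of this diff)
    from vllm import LLM, SamplingParams

    from tests.utils import check_answers, prep_prompts  # assumes repo root on PYTHONPATH


    def run_sliding_window_check() -> None:
        # Build a few long prompts; each buries one "x<idx> = <value>" assignment
        # far enough back that it falls outside a ~4k sliding window.
        prompts, answer, indices = prep_prompts(batch_size=4)

        # Placeholder model: substitute any model whose config enables
        # sliding-window attention.
        llm = LLM(model="bigcode/starcoder2-3b")
        params = SamplingParams(temperature=0.0, max_tokens=10)

        # Greedy decoding; each completion should start with the hidden two-digit value.
        outputs = llm.generate(prompts, params)
        texts = [out.outputs[0].text for out in outputs]

        # check_answers() asserts that at least accept_rate of the completions
        # echo the value assigned to x<idx>.
        check_answers(indices, answer, texts, accept_rate=0.7)


    if __name__ == "__main__":
        run_sliding_window_check()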