| 
8 | 8 | import importlib  | 
9 | 9 | import json  | 
10 | 10 | import os  | 
 | 11 | +import random  | 
11 | 12 | import signal  | 
12 | 13 | import subprocess  | 
13 | 14 | import sys  | 
@@ -1150,3 +1151,49 @@ def override_cutlass_fp8_supported(value: bool):  | 
1150 | 1151 |             "vllm.model_executor.layers.quantization.utils.w8a8_utils.cutlass_fp8_supported",  | 
1151 | 1152 |             return_value=value):  | 
1152 | 1153 |         yield  | 
 | 1154 | + | 
 | 1155 | + | 
 | 1156 | +def prep_prompts(batch_size: int, ln_range: tuple[int, int] = (800, 1100)):  | 
 | 1157 | +    """  | 
 | 1158 | +    Generate prompts which a bunch of assignments,  | 
 | 1159 | +    then asking for the value of one of them.  | 
 | 1160 | +    The prompt is just under 10k tokens; sliding window is 4k  | 
 | 1161 | +    so the answer is outside sliding window, but should still be correct.  | 
 | 1162 | +    Args:  | 
 | 1163 | +        batch_size: number of prompts to generate  | 
 | 1164 | +        ln_range: an argument to control the length of the prompt  | 
 | 1165 | +    """  | 
 | 1166 | +    prompts: list[str] = []  | 
 | 1167 | +    answer: list[int] = []  | 
 | 1168 | +    indices: list[int] = []  | 
 | 1169 | +    random.seed(1)  | 
 | 1170 | +    for _ in range(batch_size):  | 
 | 1171 | +        idx = random.randint(30, 90)  | 
 | 1172 | +        indices.append(idx)  | 
 | 1173 | +        prompt = "```python\n# We set a number of variables, " + \  | 
 | 1174 | +                 f"x{idx} will be important later\n"  | 
 | 1175 | +        ln = random.randint(*ln_range)  | 
 | 1176 | +        for k in range(30, ln):  | 
 | 1177 | +            v = random.randint(10, 99)  | 
 | 1178 | +            if k == idx:  | 
 | 1179 | +                answer.append(v)  | 
 | 1180 | +            prompt += f"x{k} = {v}\n"  | 
 | 1181 | +        prompt += f"# Now, we check the value of x{idx}:\n"  | 
 | 1182 | +        prompt += f"assert x{idx} == "  | 
 | 1183 | +        prompts.append(prompt)  | 
 | 1184 | +    return prompts, answer, indices  | 
 | 1185 | + | 
 | 1186 | + | 
 | 1187 | +def check_answers(indices: list[int],  | 
 | 1188 | +                  answer: list[int],  | 
 | 1189 | +                  outputs: list[str],  | 
 | 1190 | +                  accept_rate: float = 0.7):  | 
 | 1191 | +    answer2 = [int(text[0:2].strip()) for text in outputs]  | 
 | 1192 | +    print(list(zip(indices, zip(answer, answer2))))  | 
 | 1193 | +    numok = 0  | 
 | 1194 | +    for a1, a2 in zip(answer, answer2):  | 
 | 1195 | +        if a1 == a2:  | 
 | 1196 | +            numok += 1  | 
 | 1197 | +    frac_ok = numok / len(answer)  | 
 | 1198 | +    print(f"Num OK: {numok}/{len(answer)} {frac_ok}")  | 
 | 1199 | +    assert frac_ok >= accept_rate  | 
0 commit comments