@@ -129,12 +129,16 @@ def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed,
129129 check_answers (indices , answer , test_texts )
130130
131131
132- def prep_prompts (batch_size : int ):
132+ def prep_prompts (batch_size : int , ln_range : tuple [ int , int ] = ( 800 , 1100 ) ):
133133 """
134134 Generate prompts which a bunch of assignments,
135135 then asking for the value of one of them.
136136 The prompt is just under 10k tokens; sliding window is 4k
137137 so the answer is outside sliding window, but should still be correct.
138+
139+ Args:
140+ batch_size: number of prompts to generate
141+ ln_range: an argument to control the length of the prompt
138142 """
139143 prompts : list [str ] = []
140144 answer : list [int ] = []
@@ -145,7 +149,7 @@ def prep_prompts(batch_size: int):
145149 indices .append (idx )
146150 prompt = "```python\n # We set a number of variables, " + \
147151 f"x{ idx } will be important later\n "
148- ln = random .randint (800 , 1100 )
152+ ln = random .randint (* ln_range )
149153 for k in range (30 , ln ):
150154 v = random .randint (10 , 99 )
151155 if k == idx :
@@ -157,7 +161,10 @@ def prep_prompts(batch_size: int):
157161 return prompts , answer , indices
158162
159163
160- def check_answers (indices : list [int ], answer : list [int ], outputs : list [str ]):
164+ def check_answers (indices : list [int ],
165+ answer : list [int ],
166+ outputs : list [str ],
167+ accept_rate : float = 0.7 ):
161168 answer2 = [int (text [0 :2 ].strip ()) for text in outputs ]
162169 print (list (zip (indices , zip (answer , answer2 ))))
163170 numok = 0
@@ -166,7 +173,7 @@ def check_answers(indices: list[int], answer: list[int], outputs: list[str]):
166173 numok += 1
167174 frac_ok = numok / len (answer )
168175 print (f"Num OK: { numok } /{ len (answer )} { frac_ok } " )
169- assert frac_ok > 0.7
176+ assert frac_ok >= accept_rate
170177
171178
172179def check_window (prompts : list [str ]):
0 commit comments