@@ -200,13 +200,32 @@ def _verify_saved_data(
200200 @parameterized .named_parameters (
201201 dict (
202202 testcase_name = "_prefill_no_skip_save_2_drop_jax" ,
203+ use_precompiled_swap_ops = False ,
204+ num_skip_leading_tokens = 0 ,
205+ num_tokens_to_save = _DEFAULT_BLOCK_SIZE * 2 ,
206+ num_total_tokens = _DEFAULT_BLOCK_SIZE * 2 + 10 ,
207+ num_blocks_to_save = 2 ,
208+ ),
209+ dict (
210+ testcase_name = "_prefill_no_skip_save_2_drop_jax_precompiled" ,
211+ use_precompiled_swap_ops = True ,
203212 num_skip_leading_tokens = 0 ,
204213 num_tokens_to_save = _DEFAULT_BLOCK_SIZE * 2 ,
205214 num_total_tokens = _DEFAULT_BLOCK_SIZE * 2 + 10 ,
206215 num_blocks_to_save = 2 ,
207216 ),
208217 dict (
209218 testcase_name = "_prefill_no_skip_save_2_drop_pallas" ,
219+ use_precompiled_swap_ops = False ,
220+ num_skip_leading_tokens = 0 ,
221+ num_tokens_to_save = _DEFAULT_BLOCK_SIZE * 2 ,
222+ num_total_tokens = _DEFAULT_BLOCK_SIZE * 2 + 10 ,
223+ num_blocks_to_save = 2 ,
224+ swap_op_type = "pallas" ,
225+ ),
226+ dict (
227+ testcase_name = "_prefill_no_skip_save_2_drop_pallas_precompiled" ,
228+ use_precompiled_swap_ops = True ,
210229 num_skip_leading_tokens = 0 ,
211230 num_tokens_to_save = _DEFAULT_BLOCK_SIZE * 2 ,
212231 num_total_tokens = _DEFAULT_BLOCK_SIZE * 2 + 10 ,
@@ -219,13 +238,32 @@ def _verify_saved_data(
219238 # block and assign 3 blocks to save.
220239 dict (
221240 testcase_name = "_prefill_no_skip_save_2_pad_jax" ,
241+ use_precompiled_swap_ops = False ,
242+ num_skip_leading_tokens = 0 ,
243+ num_tokens_to_save = _DEFAULT_BLOCK_SIZE * 2 + 10 ,
244+ num_total_tokens = _DEFAULT_BLOCK_SIZE * 2 + 10 ,
245+ num_blocks_to_save = 3 ,
246+ ),
247+ dict (
248+ testcase_name = "_prefill_no_skip_save_2_pad_jax_precompiled" ,
249+ use_precompiled_swap_ops = True ,
222250 num_skip_leading_tokens = 0 ,
223251 num_tokens_to_save = _DEFAULT_BLOCK_SIZE * 2 + 10 ,
224252 num_total_tokens = _DEFAULT_BLOCK_SIZE * 2 + 10 ,
225253 num_blocks_to_save = 3 ,
226254 ),
227255 dict (
228256 testcase_name = "_prefill_no_skip_save_2_pad_pallas" ,
257+ use_precompiled_swap_ops = False ,
258+ num_skip_leading_tokens = 0 ,
259+ num_tokens_to_save = _DEFAULT_BLOCK_SIZE * 2 + 10 ,
260+ num_total_tokens = _DEFAULT_BLOCK_SIZE * 2 + 10 ,
261+ num_blocks_to_save = 3 ,
262+ swap_op_type = "pallas" ,
263+ ),
264+ dict (
265+ testcase_name = "_prefill_no_skip_save_2_pad_pallas_precompiled" ,
266+ use_precompiled_swap_ops = True ,
229267 num_skip_leading_tokens = 0 ,
230268 num_tokens_to_save = _DEFAULT_BLOCK_SIZE * 2 + 10 ,
231269 num_total_tokens = _DEFAULT_BLOCK_SIZE * 2 + 10 ,
@@ -234,27 +272,65 @@ def _verify_saved_data(
234272 ),
235273 dict (
236274 testcase_name = "_prefill_skip_2_save_2_drop" ,
275+ use_precompiled_swap_ops = False ,
276+ num_skip_leading_tokens = _DEFAULT_BLOCK_SIZE * 2 ,
277+ num_tokens_to_save = _DEFAULT_BLOCK_SIZE * 2 ,
278+ num_total_tokens = _DEFAULT_BLOCK_SIZE * 4 + 10 ,
279+ num_blocks_to_save = 2 ,
280+ ),
281+ dict (
282+ testcase_name = "_prefill_skip_2_save_2_drop_precompiled" ,
283+ use_precompiled_swap_ops = True ,
237284 num_skip_leading_tokens = _DEFAULT_BLOCK_SIZE * 2 ,
238285 num_tokens_to_save = _DEFAULT_BLOCK_SIZE * 2 ,
239286 num_total_tokens = _DEFAULT_BLOCK_SIZE * 4 + 10 ,
240287 num_blocks_to_save = 2 ,
241288 ),
242289 dict (
243290 testcase_name = "_prefill_skip_2_save_2_pad" ,
291+ use_precompiled_swap_ops = False ,
292+ num_skip_leading_tokens = _DEFAULT_BLOCK_SIZE * 2 ,
293+ num_tokens_to_save = _DEFAULT_BLOCK_SIZE * 2 + 10 ,
294+ num_total_tokens = _DEFAULT_BLOCK_SIZE * 4 + 10 ,
295+ num_blocks_to_save = 3 ,
296+ ),
297+ dict (
298+ testcase_name = "_prefill_skip_2_save_2_pad_precompiled" ,
299+ use_precompiled_swap_ops = True ,
244300 num_skip_leading_tokens = _DEFAULT_BLOCK_SIZE * 2 ,
245301 num_tokens_to_save = _DEFAULT_BLOCK_SIZE * 2 + 10 ,
246302 num_total_tokens = _DEFAULT_BLOCK_SIZE * 4 + 10 ,
247303 num_blocks_to_save = 3 ,
248304 ),
249305 dict (
250306 testcase_name = "_decode_skip_3_save_1" ,
307+ use_precompiled_swap_ops = False ,
308+ num_skip_leading_tokens = _DEFAULT_BLOCK_SIZE * 3 ,
309+ num_tokens_to_save = _DEFAULT_BLOCK_SIZE ,
310+ num_total_tokens = _DEFAULT_BLOCK_SIZE * 4 ,
311+ num_blocks_to_save = 1 ,
312+ ),
313+ dict (
314+ testcase_name = "_decode_skip_3_save_1_precompiled" ,
315+ use_precompiled_swap_ops = True ,
251316 num_skip_leading_tokens = _DEFAULT_BLOCK_SIZE * 3 ,
252317 num_tokens_to_save = _DEFAULT_BLOCK_SIZE ,
253318 num_total_tokens = _DEFAULT_BLOCK_SIZE * 4 ,
254319 num_blocks_to_save = 1 ,
255320 ),
256321 dict (
257322 testcase_name = "_no_save" ,
323+ use_precompiled_swap_ops = False ,
324+ num_skip_leading_tokens = _DEFAULT_BLOCK_SIZE * 2 ,
325+ num_tokens_to_save = 0 ,
326+ num_total_tokens = _DEFAULT_BLOCK_SIZE * 2 ,
327+ num_blocks_to_save = 0 ,
328+ is_final_save = False ,
329+ skip_save = False ,
330+ ),
331+ dict (
332+ testcase_name = "_no_save_precompiled" ,
333+ use_precompiled_swap_ops = True ,
258334 num_skip_leading_tokens = _DEFAULT_BLOCK_SIZE * 2 ,
259335 num_tokens_to_save = 0 ,
260336 num_total_tokens = _DEFAULT_BLOCK_SIZE * 2 ,
@@ -264,6 +340,17 @@ def _verify_saved_data(
264340 ),
265341 dict (
266342 testcase_name = "_final_save_save_1_drop" ,
343+ use_precompiled_swap_ops = False ,
344+ num_skip_leading_tokens = _DEFAULT_BLOCK_SIZE * 2 ,
345+ num_tokens_to_save = _DEFAULT_BLOCK_SIZE ,
346+ num_total_tokens = _DEFAULT_BLOCK_SIZE * 3 + 10 ,
347+ num_blocks_to_save = 1 ,
348+ is_final_save = True ,
349+ skip_save = False ,
350+ ),
351+ dict (
352+ testcase_name = "_final_save_save_1_drop_precompiled" ,
353+ use_precompiled_swap_ops = True ,
267354 num_skip_leading_tokens = _DEFAULT_BLOCK_SIZE * 2 ,
268355 num_tokens_to_save = _DEFAULT_BLOCK_SIZE ,
269356 num_total_tokens = _DEFAULT_BLOCK_SIZE * 3 + 10 ,
@@ -273,6 +360,17 @@ def _verify_saved_data(
273360 ),
274361 dict (
275362 testcase_name = "_final_save_save_1_pad" ,
363+ use_precompiled_swap_ops = False ,
364+ num_skip_leading_tokens = _DEFAULT_BLOCK_SIZE * 2 ,
365+ num_tokens_to_save = 10 ,
366+ num_total_tokens = _DEFAULT_BLOCK_SIZE * 2 + 10 ,
367+ num_blocks_to_save = 1 ,
368+ is_final_save = True ,
369+ skip_save = False ,
370+ ),
371+ dict (
372+ testcase_name = "_final_save_save_1_pad_precompiled" ,
373+ use_precompiled_swap_ops = True ,
276374 num_skip_leading_tokens = _DEFAULT_BLOCK_SIZE * 2 ,
277375 num_tokens_to_save = 10 ,
278376 num_total_tokens = _DEFAULT_BLOCK_SIZE * 2 + 10 ,
@@ -282,6 +380,17 @@ def _verify_saved_data(
282380 ),
283381 dict (
284382 testcase_name = "_final_save_without_data" ,
383+ use_precompiled_swap_ops = False ,
384+ num_skip_leading_tokens = _DEFAULT_BLOCK_SIZE * 2 ,
385+ num_tokens_to_save = 0 ,
386+ num_total_tokens = _DEFAULT_BLOCK_SIZE * 2 ,
387+ num_blocks_to_save = 0 ,
388+ is_final_save = True ,
389+ skip_save = True ,
390+ ),
391+ dict (
392+ testcase_name = "_final_save_without_data_precompiled" ,
393+ use_precompiled_swap_ops = True ,
285394 num_skip_leading_tokens = _DEFAULT_BLOCK_SIZE * 2 ,
286395 num_tokens_to_save = 0 ,
287396 num_total_tokens = _DEFAULT_BLOCK_SIZE * 2 ,
@@ -292,6 +401,7 @@ def _verify_saved_data(
292401 )
293402 def test_tpu_connector_save (
294403 self ,
404+ use_precompiled_swap_ops : bool ,
295405 num_skip_leading_tokens : int ,
296406 num_tokens_to_save : int ,
297407 num_total_tokens : int ,
@@ -300,6 +410,8 @@ def test_tpu_connector_save(
300410 skip_save : bool = False ,
301411 swap_op_type : str = "jax" ,
302412 ):
413+ os .environ [
414+ "TPU_OFFLOAD_SKIP_JAX_PRECOMPILE" ] = "0" if use_precompiled_swap_ops else "1"
303415
304416 # Prepare and Execute Save
305417 total_token_ids = list (range (num_total_tokens ))
@@ -422,25 +534,42 @@ def test_tpu_connector_save(
422534
423535 @parameterized .named_parameters (
424536 dict (
425- testcase_name = "_2_steps" ,
537+ testcase_name = "_2_steps_nobucket" ,
538+ use_precompiled_swap_ops = False ,
539+ num_blocks_step1 = 2 ,
540+ num_blocks_step2 = 1 ,
541+ ),
542+ dict (
543+ testcase_name = "_2_steps_bucketed_precompiled" ,
544+ use_precompiled_swap_ops = True ,
426545 num_blocks_step1 = 2 ,
427546 num_blocks_step2 = 1 ,
428547 ),
429548 dict (
430549 testcase_name = "_zero_token_step2" ,
550+ use_precompiled_swap_ops = False ,
551+ num_blocks_step1 = 2 ,
552+ num_blocks_step2 = 0 ,
553+ ),
554+ dict (
555+ testcase_name = "_zero_token_step2_bucketed_precompiled" ,
556+ use_precompiled_swap_ops = True ,
431557 num_blocks_step1 = 2 ,
432558 num_blocks_step2 = 0 ,
433559 ),
434560 )
435561 def test_tpu_connector_multi_step_save (
436562 self ,
563+ use_precompiled_swap_ops : bool ,
437564 num_blocks_step1 : int ,
438565 num_blocks_step2 : int ,
439566 ):
440567 """
441568 Tests that the TPUConnectorWorker correctly saves the KV cache in multiple
442569 steps, respecting the save watermark (skip_leading_tokens).
443570 """
571+ os .environ [
572+ "TPU_OFFLOAD_SKIP_JAX_PRECOMPILE" ] = "0" if use_precompiled_swap_ops else "1"
444573 num_tokens_step1 = num_blocks_step1 * self .block_size
445574 num_tokens_step2 = num_blocks_step2 * self .block_size
446575 logger .info (
@@ -589,31 +718,64 @@ def test_tpu_connector_multi_step_save(
589718 @parameterized .named_parameters (
590719 dict (
591720 testcase_name = "_full_load_jax" ,
721+ use_precompiled_swap_ops = False ,
722+ swap_op_type = "jax" ,
723+ num_matched_blocks = 4 ,
724+ num_computed_blocks = 0 ,
725+ ),
726+ dict (
727+ testcase_name = "_full_load_jax_precompiled" ,
728+ use_precompiled_swap_ops = True ,
592729 swap_op_type = "jax" ,
593730 num_matched_blocks = 4 ,
594731 num_computed_blocks = 0 ,
595732 ),
596733 dict (
597734 testcase_name = "_delta_load_jax" ,
735+ use_precompiled_swap_ops = False ,
736+ swap_op_type = "jax" ,
737+ num_matched_blocks = 4 ,
738+ num_computed_blocks = 1 ,
739+ ),
740+ dict (
741+ testcase_name = "_delta_load_jax_precompiled" ,
742+ use_precompiled_swap_ops = True ,
598743 swap_op_type = "jax" ,
599744 num_matched_blocks = 4 ,
600745 num_computed_blocks = 1 ,
601746 ),
602747 dict (
603748 testcase_name = "_delta_load_pallas" ,
749+ use_precompiled_swap_ops = False ,
750+ swap_op_type = "pallas" ,
751+ num_matched_blocks = 4 ,
752+ num_computed_blocks = 1 ,
753+ ),
754+ dict (
755+ testcase_name = "_delta_load_pallas_precompiled" ,
756+ use_precompiled_swap_ops = True ,
604757 swap_op_type = "pallas" ,
605758 num_matched_blocks = 4 ,
606759 num_computed_blocks = 1 ,
607760 ),
608761 dict (
609762 testcase_name = "_no_load_jax" ,
763+ use_precompiled_swap_ops = False ,
764+ swap_op_type = "jax" ,
765+ num_matched_blocks = 1 ,
766+ num_computed_blocks = 1 ,
767+ ),
768+ dict (
769+ testcase_name = "_no_load_jax_precompiled" ,
770+ use_precompiled_swap_ops = True ,
610771 swap_op_type = "jax" ,
611772 num_matched_blocks = 1 ,
612773 num_computed_blocks = 1 ,
613774 ),
614775 )
615776 def test_tpu_connector_load (
616777 self ,
778+ use_precompiled_swap_ops : bool ,
617779 swap_op_type : str ,
618780 num_matched_blocks : int ,
619781 num_computed_blocks : int = 0 ,
@@ -654,6 +816,8 @@ def test_tpu_connector_load(
654816 - Assert that the parts of the destination cache that should not have
655817 been touched remain zero.
656818 """
819+ os .environ [
820+ "TPU_OFFLOAD_SKIP_JAX_PRECOMPILE" ] = "0" if use_precompiled_swap_ops else "1"
657821 num_matched_tokens = num_matched_blocks * self .block_size
658822 num_computed_tokens = num_computed_blocks * self .block_size
659823 if num_matched_blocks > self .num_blocks :
0 commit comments