Skip to content

Commit 23a78cd

Browse files
committed
fix
Signed-off-by: Saikat Roychowdhury <saikat.royc85@gmail.com>
1 parent 2e42f61 commit 23a78cd

File tree

2 files changed

+213
-18
lines changed

2 files changed

+213
-18
lines changed

tests/distributed/cpu_offloading_worker_test.py

Lines changed: 165 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -200,13 +200,32 @@ def _verify_saved_data(
200200
@parameterized.named_parameters(
201201
dict(
202202
testcase_name="_prefill_no_skip_save_2_drop_jax",
203+
use_precompiled_swap_ops=False,
204+
num_skip_leading_tokens=0,
205+
num_tokens_to_save=_DEFAULT_BLOCK_SIZE * 2,
206+
num_total_tokens=_DEFAULT_BLOCK_SIZE * 2 + 10,
207+
num_blocks_to_save=2,
208+
),
209+
dict(
210+
testcase_name="_prefill_no_skip_save_2_drop_jax_precompiled",
211+
use_precompiled_swap_ops=True,
203212
num_skip_leading_tokens=0,
204213
num_tokens_to_save=_DEFAULT_BLOCK_SIZE * 2,
205214
num_total_tokens=_DEFAULT_BLOCK_SIZE * 2 + 10,
206215
num_blocks_to_save=2,
207216
),
208217
dict(
209218
testcase_name="_prefill_no_skip_save_2_drop_pallas",
219+
use_precompiled_swap_ops=False,
220+
num_skip_leading_tokens=0,
221+
num_tokens_to_save=_DEFAULT_BLOCK_SIZE * 2,
222+
num_total_tokens=_DEFAULT_BLOCK_SIZE * 2 + 10,
223+
num_blocks_to_save=2,
224+
swap_op_type="pallas",
225+
),
226+
dict(
227+
testcase_name="_prefill_no_skip_save_2_drop_pallas_precompiled",
228+
use_precompiled_swap_ops=True,
210229
num_skip_leading_tokens=0,
211230
num_tokens_to_save=_DEFAULT_BLOCK_SIZE * 2,
212231
num_total_tokens=_DEFAULT_BLOCK_SIZE * 2 + 10,
@@ -219,13 +238,32 @@ def _verify_saved_data(
219238
# block and assign 3 blocks to save.
220239
dict(
221240
testcase_name="_prefill_no_skip_save_2_pad_jax",
241+
use_precompiled_swap_ops=False,
242+
num_skip_leading_tokens=0,
243+
num_tokens_to_save=_DEFAULT_BLOCK_SIZE * 2 + 10,
244+
num_total_tokens=_DEFAULT_BLOCK_SIZE * 2 + 10,
245+
num_blocks_to_save=3,
246+
),
247+
dict(
248+
testcase_name="_prefill_no_skip_save_2_pad_jax_precompiled",
249+
use_precompiled_swap_ops=True,
222250
num_skip_leading_tokens=0,
223251
num_tokens_to_save=_DEFAULT_BLOCK_SIZE * 2 + 10,
224252
num_total_tokens=_DEFAULT_BLOCK_SIZE * 2 + 10,
225253
num_blocks_to_save=3,
226254
),
227255
dict(
228256
testcase_name="_prefill_no_skip_save_2_pad_pallas",
257+
use_precompiled_swap_ops=False,
258+
num_skip_leading_tokens=0,
259+
num_tokens_to_save=_DEFAULT_BLOCK_SIZE * 2 + 10,
260+
num_total_tokens=_DEFAULT_BLOCK_SIZE * 2 + 10,
261+
num_blocks_to_save=3,
262+
swap_op_type="pallas",
263+
),
264+
dict(
265+
testcase_name="_prefill_no_skip_save_2_pad_pallas_precompiled",
266+
use_precompiled_swap_ops=True,
229267
num_skip_leading_tokens=0,
230268
num_tokens_to_save=_DEFAULT_BLOCK_SIZE * 2 + 10,
231269
num_total_tokens=_DEFAULT_BLOCK_SIZE * 2 + 10,
@@ -234,27 +272,65 @@ def _verify_saved_data(
234272
),
235273
dict(
236274
testcase_name="_prefill_skip_2_save_2_drop",
275+
use_precompiled_swap_ops=False,
276+
num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 2,
277+
num_tokens_to_save=_DEFAULT_BLOCK_SIZE * 2,
278+
num_total_tokens=_DEFAULT_BLOCK_SIZE * 4 + 10,
279+
num_blocks_to_save=2,
280+
),
281+
dict(
282+
testcase_name="_prefill_skip_2_save_2_drop_precompiled",
283+
use_precompiled_swap_ops=True,
237284
num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 2,
238285
num_tokens_to_save=_DEFAULT_BLOCK_SIZE * 2,
239286
num_total_tokens=_DEFAULT_BLOCK_SIZE * 4 + 10,
240287
num_blocks_to_save=2,
241288
),
242289
dict(
243290
testcase_name="_prefill_skip_2_save_2_pad",
291+
use_precompiled_swap_ops=False,
292+
num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 2,
293+
num_tokens_to_save=_DEFAULT_BLOCK_SIZE * 2 + 10,
294+
num_total_tokens=_DEFAULT_BLOCK_SIZE * 4 + 10,
295+
num_blocks_to_save=3,
296+
),
297+
dict(
298+
testcase_name="_prefill_skip_2_save_2_pad_precompiled",
299+
use_precompiled_swap_ops=True,
244300
num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 2,
245301
num_tokens_to_save=_DEFAULT_BLOCK_SIZE * 2 + 10,
246302
num_total_tokens=_DEFAULT_BLOCK_SIZE * 4 + 10,
247303
num_blocks_to_save=3,
248304
),
249305
dict(
250306
testcase_name="_decode_skip_3_save_1",
307+
use_precompiled_swap_ops=False,
308+
num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 3,
309+
num_tokens_to_save=_DEFAULT_BLOCK_SIZE,
310+
num_total_tokens=_DEFAULT_BLOCK_SIZE * 4,
311+
num_blocks_to_save=1,
312+
),
313+
dict(
314+
testcase_name="_decode_skip_3_save_1_precompiled",
315+
use_precompiled_swap_ops=True,
251316
num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 3,
252317
num_tokens_to_save=_DEFAULT_BLOCK_SIZE,
253318
num_total_tokens=_DEFAULT_BLOCK_SIZE * 4,
254319
num_blocks_to_save=1,
255320
),
256321
dict(
257322
testcase_name="_no_save",
323+
use_precompiled_swap_ops=False,
324+
num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 2,
325+
num_tokens_to_save=0,
326+
num_total_tokens=_DEFAULT_BLOCK_SIZE * 2,
327+
num_blocks_to_save=0,
328+
is_final_save=False,
329+
skip_save=False,
330+
),
331+
dict(
332+
testcase_name="_no_save_precompiled",
333+
use_precompiled_swap_ops=True,
258334
num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 2,
259335
num_tokens_to_save=0,
260336
num_total_tokens=_DEFAULT_BLOCK_SIZE * 2,
@@ -264,6 +340,17 @@ def _verify_saved_data(
264340
),
265341
dict(
266342
testcase_name="_final_save_save_1_drop",
343+
use_precompiled_swap_ops=False,
344+
num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 2,
345+
num_tokens_to_save=_DEFAULT_BLOCK_SIZE,
346+
num_total_tokens=_DEFAULT_BLOCK_SIZE * 3 + 10,
347+
num_blocks_to_save=1,
348+
is_final_save=True,
349+
skip_save=False,
350+
),
351+
dict(
352+
testcase_name="_final_save_save_1_drop_precompiled",
353+
use_precompiled_swap_ops=True,
267354
num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 2,
268355
num_tokens_to_save=_DEFAULT_BLOCK_SIZE,
269356
num_total_tokens=_DEFAULT_BLOCK_SIZE * 3 + 10,
@@ -273,6 +360,17 @@ def _verify_saved_data(
273360
),
274361
dict(
275362
testcase_name="_final_save_save_1_pad",
363+
use_precompiled_swap_ops=False,
364+
num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 2,
365+
num_tokens_to_save=10,
366+
num_total_tokens=_DEFAULT_BLOCK_SIZE * 2 + 10,
367+
num_blocks_to_save=1,
368+
is_final_save=True,
369+
skip_save=False,
370+
),
371+
dict(
372+
testcase_name="_final_save_save_1_pad_precompiled",
373+
use_precompiled_swap_ops=True,
276374
num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 2,
277375
num_tokens_to_save=10,
278376
num_total_tokens=_DEFAULT_BLOCK_SIZE * 2 + 10,
@@ -282,6 +380,17 @@ def _verify_saved_data(
282380
),
283381
dict(
284382
testcase_name="_final_save_without_data",
383+
use_precompiled_swap_ops=False,
384+
num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 2,
385+
num_tokens_to_save=0,
386+
num_total_tokens=_DEFAULT_BLOCK_SIZE * 2,
387+
num_blocks_to_save=0,
388+
is_final_save=True,
389+
skip_save=True,
390+
),
391+
dict(
392+
testcase_name="_final_save_without_data_precompiled",
393+
use_precompiled_swap_ops=True,
285394
num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 2,
286395
num_tokens_to_save=0,
287396
num_total_tokens=_DEFAULT_BLOCK_SIZE * 2,
@@ -292,6 +401,7 @@ def _verify_saved_data(
292401
)
293402
def test_tpu_connector_save(
294403
self,
404+
use_precompiled_swap_ops: bool,
295405
num_skip_leading_tokens: int,
296406
num_tokens_to_save: int,
297407
num_total_tokens: int,
@@ -300,6 +410,8 @@ def test_tpu_connector_save(
300410
skip_save: bool = False,
301411
swap_op_type: str = "jax",
302412
):
413+
os.environ[
414+
"TPU_OFFLOAD_SKIP_JAX_PRECOMPILE"] = "0" if use_precompiled_swap_ops else "1"
303415

304416
# Prepare and Execute Save
305417
total_token_ids = list(range(num_total_tokens))
@@ -422,25 +534,42 @@ def test_tpu_connector_save(
422534

423535
@parameterized.named_parameters(
424536
dict(
425-
testcase_name="_2_steps",
537+
testcase_name="_2_steps_nobucket",
538+
use_precompiled_swap_ops=False,
539+
num_blocks_step1=2,
540+
num_blocks_step2=1,
541+
),
542+
dict(
543+
testcase_name="_2_steps_bucketed_precompiled",
544+
use_precompiled_swap_ops=True,
426545
num_blocks_step1=2,
427546
num_blocks_step2=1,
428547
),
429548
dict(
430549
testcase_name="_zero_token_step2",
550+
use_precompiled_swap_ops=False,
551+
num_blocks_step1=2,
552+
num_blocks_step2=0,
553+
),
554+
dict(
555+
testcase_name="_zero_token_step2_bucketed_precompiled",
556+
use_precompiled_swap_ops=True,
431557
num_blocks_step1=2,
432558
num_blocks_step2=0,
433559
),
434560
)
435561
def test_tpu_connector_multi_step_save(
436562
self,
563+
use_precompiled_swap_ops: bool,
437564
num_blocks_step1: int,
438565
num_blocks_step2: int,
439566
):
440567
"""
441568
Tests that the TPUConnectorWorker correctly saves the KV cache in multiple
442569
steps, respecting the save watermark (skip_leading_tokens).
443570
"""
571+
os.environ[
572+
"TPU_OFFLOAD_SKIP_JAX_PRECOMPILE"] = "0" if use_precompiled_swap_ops else "1"
444573
num_tokens_step1 = num_blocks_step1 * self.block_size
445574
num_tokens_step2 = num_blocks_step2 * self.block_size
446575
logger.info(
@@ -589,31 +718,64 @@ def test_tpu_connector_multi_step_save(
589718
@parameterized.named_parameters(
590719
dict(
591720
testcase_name="_full_load_jax",
721+
use_precompiled_swap_ops=False,
722+
swap_op_type="jax",
723+
num_matched_blocks=4,
724+
num_computed_blocks=0,
725+
),
726+
dict(
727+
testcase_name="_full_load_jax_precompiled",
728+
use_precompiled_swap_ops=True,
592729
swap_op_type="jax",
593730
num_matched_blocks=4,
594731
num_computed_blocks=0,
595732
),
596733
dict(
597734
testcase_name="_delta_load_jax",
735+
use_precompiled_swap_ops=False,
736+
swap_op_type="jax",
737+
num_matched_blocks=4,
738+
num_computed_blocks=1,
739+
),
740+
dict(
741+
testcase_name="_delta_load_jax_precompiled",
742+
use_precompiled_swap_ops=True,
598743
swap_op_type="jax",
599744
num_matched_blocks=4,
600745
num_computed_blocks=1,
601746
),
602747
dict(
603748
testcase_name="_delta_load_pallas",
749+
use_precompiled_swap_ops=False,
750+
swap_op_type="pallas",
751+
num_matched_blocks=4,
752+
num_computed_blocks=1,
753+
),
754+
dict(
755+
testcase_name="_delta_load_pallas_precompiled",
756+
use_precompiled_swap_ops=True,
604757
swap_op_type="pallas",
605758
num_matched_blocks=4,
606759
num_computed_blocks=1,
607760
),
608761
dict(
609762
testcase_name="_no_load_jax",
763+
use_precompiled_swap_ops=False,
764+
swap_op_type="jax",
765+
num_matched_blocks=1,
766+
num_computed_blocks=1,
767+
),
768+
dict(
769+
testcase_name="_no_load_jax_precompiled",
770+
use_precompiled_swap_ops=True,
610771
swap_op_type="jax",
611772
num_matched_blocks=1,
612773
num_computed_blocks=1,
613774
),
614775
)
615776
def test_tpu_connector_load(
616777
self,
778+
use_precompiled_swap_ops: bool,
617779
swap_op_type: str,
618780
num_matched_blocks: int,
619781
num_computed_blocks: int = 0,
@@ -654,6 +816,8 @@ def test_tpu_connector_load(
654816
- Assert that the parts of the destination cache that should not have
655817
been touched remain zero.
656818
"""
819+
os.environ[
820+
"TPU_OFFLOAD_SKIP_JAX_PRECOMPILE"] = "0" if use_precompiled_swap_ops else "1"
657821
num_matched_tokens = num_matched_blocks * self.block_size
658822
num_computed_tokens = num_computed_blocks * self.block_size
659823
if num_matched_blocks > self.num_blocks:

0 commit comments

Comments
 (0)