36 changes: 36 additions & 0 deletions examples/gke/pod_tpu_host_offload_unit_tests.yaml
@@ -0,0 +1,36 @@
apiVersion: v1
kind: Pod
metadata:
  name: tpu-job-host-offload-unit-tests
  # This pod runs the distributed unit tests for the TPUConnector and related
  # functionality. It invokes pytest on test files under the tests/distributed/
  # directory; the command below selects which test file to execute.
spec:
  restartPolicy: Never
  nodeSelector:
    cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
    cloud.google.com/gke-tpu-topology: 2x4 # Specify the physical topology for the TPU slice.
  containers:
  - name: tpu-job
    image: gcr.io/gke-shared-ai-dev/tpu-inference:cpu-offload
    imagePullPolicy: Always
    command:
    - /bin/bash
    - -c
    - "pytest -sv tests/distributed/host_offloading_precompile_test.py"
    # - "pytest -sv tests/distributed/cpu_offloading_worker_test.py"
    # - "pytest -sv tests/distributed/cpu_offloading_cache_util_test.py"
    # - "pytest -sv tests/distributed/host_offloading_accuracy_test.py"
    # - "pytest -sv tests/distributed/local_cpu_backend_test.py"
    # - "pytest -sv tests/distributed/host_offloading_precompile_test.py"
    env:
    - name: HUGGING_FACE_HUB_TOKEN
      valueFrom:
        secretKeyRef:
          name: hf-token-secret
          key: token
    resources:
      requests:
        google.com/tpu: 8
      limits:
        google.com/tpu: 8
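
Note: the pod's command runs a single test file; the commented-out entries list the other suites under tests/distributed/ that the same image can run. A minimal sketch of an equivalent local invocation, assuming a checkout of the repository with TPU access (only the TPU_OFFLOAD_SKIP_JAX_PRECOMPILE variable and the test path come from this PR; the rest is plain standard-library Python):

# Sketch: run one host-offload test file outside the pod, toggling the
# precompile-skip flag the same way the new test cases do.
import os
import subprocess

env = dict(os.environ)
# "0" exercises the precompiled swap ops; "1" skips JAX precompilation.
env["TPU_OFFLOAD_SKIP_JAX_PRECOMPILE"] = "0"

subprocess.run(
    ["pytest", "-sv", "tests/distributed/host_offloading_precompile_test.py"],
    check=True,
    env=env,
)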
166 changes: 165 additions & 1 deletion tests/distributed/cpu_offloading_worker_test.py
@@ -200,13 +200,32 @@ def _verify_saved_data(
    @parameterized.named_parameters(
        dict(
            testcase_name="_prefill_no_skip_save_2_drop_jax",
            use_precompiled_swap_ops=False,
            num_skip_leading_tokens=0,
            num_tokens_to_save=_DEFAULT_BLOCK_SIZE * 2,
            num_total_tokens=_DEFAULT_BLOCK_SIZE * 2 + 10,
            num_blocks_to_save=2,
        ),
        dict(
            testcase_name="_prefill_no_skip_save_2_drop_jax_precompiled",
            use_precompiled_swap_ops=True,
            num_skip_leading_tokens=0,
            num_tokens_to_save=_DEFAULT_BLOCK_SIZE * 2,
            num_total_tokens=_DEFAULT_BLOCK_SIZE * 2 + 10,
            num_blocks_to_save=2,
        ),
        dict(
            testcase_name="_prefill_no_skip_save_2_drop_pallas",
            use_precompiled_swap_ops=False,
            num_skip_leading_tokens=0,
            num_tokens_to_save=_DEFAULT_BLOCK_SIZE * 2,
            num_total_tokens=_DEFAULT_BLOCK_SIZE * 2 + 10,
            num_blocks_to_save=2,
            swap_op_type="pallas",
        ),
        dict(
            testcase_name="_prefill_no_skip_save_2_drop_pallas_precompiled",
            use_precompiled_swap_ops=True,
            num_skip_leading_tokens=0,
            num_tokens_to_save=_DEFAULT_BLOCK_SIZE * 2,
            num_total_tokens=_DEFAULT_BLOCK_SIZE * 2 + 10,
@@ -219,13 +238,32 @@ def _verify_saved_data(
        # block and assign 3 blocks to save.
        dict(
            testcase_name="_prefill_no_skip_save_2_pad_jax",
            use_precompiled_swap_ops=False,
            num_skip_leading_tokens=0,
            num_tokens_to_save=_DEFAULT_BLOCK_SIZE * 2 + 10,
            num_total_tokens=_DEFAULT_BLOCK_SIZE * 2 + 10,
            num_blocks_to_save=3,
        ),
        dict(
            testcase_name="_prefill_no_skip_save_2_pad_jax_precompiled",
            use_precompiled_swap_ops=True,
            num_skip_leading_tokens=0,
            num_tokens_to_save=_DEFAULT_BLOCK_SIZE * 2 + 10,
            num_total_tokens=_DEFAULT_BLOCK_SIZE * 2 + 10,
            num_blocks_to_save=3,
        ),
        dict(
            testcase_name="_prefill_no_skip_save_2_pad_pallas",
            use_precompiled_swap_ops=False,
            num_skip_leading_tokens=0,
            num_tokens_to_save=_DEFAULT_BLOCK_SIZE * 2 + 10,
            num_total_tokens=_DEFAULT_BLOCK_SIZE * 2 + 10,
            num_blocks_to_save=3,
            swap_op_type="pallas",
        ),
        dict(
            testcase_name="_prefill_no_skip_save_2_pad_pallas_precompiled",
            use_precompiled_swap_ops=True,
            num_skip_leading_tokens=0,
            num_tokens_to_save=_DEFAULT_BLOCK_SIZE * 2 + 10,
            num_total_tokens=_DEFAULT_BLOCK_SIZE * 2 + 10,
@@ -234,27 +272,65 @@ def _verify_saved_data(
        ),
        dict(
            testcase_name="_prefill_skip_2_save_2_drop",
            use_precompiled_swap_ops=False,
            num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 2,
            num_tokens_to_save=_DEFAULT_BLOCK_SIZE * 2,
            num_total_tokens=_DEFAULT_BLOCK_SIZE * 4 + 10,
            num_blocks_to_save=2,
        ),
        dict(
            testcase_name="_prefill_skip_2_save_2_drop_precompiled",
            use_precompiled_swap_ops=True,
            num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 2,
            num_tokens_to_save=_DEFAULT_BLOCK_SIZE * 2,
            num_total_tokens=_DEFAULT_BLOCK_SIZE * 4 + 10,
            num_blocks_to_save=2,
        ),
        dict(
            testcase_name="_prefill_skip_2_save_2_pad",
            use_precompiled_swap_ops=False,
            num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 2,
            num_tokens_to_save=_DEFAULT_BLOCK_SIZE * 2 + 10,
            num_total_tokens=_DEFAULT_BLOCK_SIZE * 4 + 10,
            num_blocks_to_save=3,
        ),
        dict(
            testcase_name="_prefill_skip_2_save_2_pad_precompiled",
            use_precompiled_swap_ops=True,
            num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 2,
            num_tokens_to_save=_DEFAULT_BLOCK_SIZE * 2 + 10,
            num_total_tokens=_DEFAULT_BLOCK_SIZE * 4 + 10,
            num_blocks_to_save=3,
        ),
        dict(
            testcase_name="_decode_skip_3_save_1",
            use_precompiled_swap_ops=False,
            num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 3,
            num_tokens_to_save=_DEFAULT_BLOCK_SIZE,
            num_total_tokens=_DEFAULT_BLOCK_SIZE * 4,
            num_blocks_to_save=1,
        ),
        dict(
            testcase_name="_decode_skip_3_save_1_precompiled",
            use_precompiled_swap_ops=True,
            num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 3,
            num_tokens_to_save=_DEFAULT_BLOCK_SIZE,
            num_total_tokens=_DEFAULT_BLOCK_SIZE * 4,
            num_blocks_to_save=1,
        ),
        dict(
            testcase_name="_no_save",
            use_precompiled_swap_ops=False,
            num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 2,
            num_tokens_to_save=0,
            num_total_tokens=_DEFAULT_BLOCK_SIZE * 2,
            num_blocks_to_save=0,
            is_final_save=False,
            skip_save=False,
        ),
        dict(
            testcase_name="_no_save_precompiled",
            use_precompiled_swap_ops=True,
            num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 2,
            num_tokens_to_save=0,
            num_total_tokens=_DEFAULT_BLOCK_SIZE * 2,
@@ -264,6 +340,17 @@ def _verify_saved_data(
        ),
        dict(
            testcase_name="_final_save_save_1_drop",
            use_precompiled_swap_ops=False,
            num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 2,
            num_tokens_to_save=_DEFAULT_BLOCK_SIZE,
            num_total_tokens=_DEFAULT_BLOCK_SIZE * 3 + 10,
            num_blocks_to_save=1,
            is_final_save=True,
            skip_save=False,
        ),
        dict(
            testcase_name="_final_save_save_1_drop_precompiled",
            use_precompiled_swap_ops=True,
            num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 2,
            num_tokens_to_save=_DEFAULT_BLOCK_SIZE,
            num_total_tokens=_DEFAULT_BLOCK_SIZE * 3 + 10,
@@ -273,6 +360,17 @@ def _verify_saved_data(
        ),
        dict(
            testcase_name="_final_save_save_1_pad",
            use_precompiled_swap_ops=False,
            num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 2,
            num_tokens_to_save=10,
            num_total_tokens=_DEFAULT_BLOCK_SIZE * 2 + 10,
            num_blocks_to_save=1,
            is_final_save=True,
            skip_save=False,
        ),
        dict(
            testcase_name="_final_save_save_1_pad_precompiled",
            use_precompiled_swap_ops=True,
            num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 2,
            num_tokens_to_save=10,
            num_total_tokens=_DEFAULT_BLOCK_SIZE * 2 + 10,
@@ -282,6 +380,17 @@ def _verify_saved_data(
        ),
        dict(
            testcase_name="_final_save_without_data",
            use_precompiled_swap_ops=False,
            num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 2,
            num_tokens_to_save=0,
            num_total_tokens=_DEFAULT_BLOCK_SIZE * 2,
            num_blocks_to_save=0,
            is_final_save=True,
            skip_save=True,
        ),
        dict(
            testcase_name="_final_save_without_data_precompiled",
            use_precompiled_swap_ops=True,
            num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 2,
            num_tokens_to_save=0,
            num_total_tokens=_DEFAULT_BLOCK_SIZE * 2,
@@ -292,6 +401,7 @@ def _verify_saved_data(
    )
    def test_tpu_connector_save(
        self,
        use_precompiled_swap_ops: bool,
        num_skip_leading_tokens: int,
        num_tokens_to_save: int,
        num_total_tokens: int,
@@ -300,6 +410,8 @@ def test_tpu_connector_save(
        skip_save: bool = False,
        swap_op_type: str = "jax",
    ):
        os.environ[
            "TPU_OFFLOAD_SKIP_JAX_PRECOMPILE"] = "0" if use_precompiled_swap_ops else "1"

        # Prepare and Execute Save
        total_token_ids = list(range(num_total_tokens))
@@ -422,25 +534,42 @@ def test_tpu_connector_save(

    @parameterized.named_parameters(
        dict(
-           testcase_name="_2_steps",
+           testcase_name="_2_steps_nobucket",
            use_precompiled_swap_ops=False,
            num_blocks_step1=2,
            num_blocks_step2=1,
        ),
        dict(
            testcase_name="_2_steps_bucketed_precompiled",
            use_precompiled_swap_ops=True,
            num_blocks_step1=2,
            num_blocks_step2=1,
        ),
        dict(
            testcase_name="_zero_token_step2",
            use_precompiled_swap_ops=False,
            num_blocks_step1=2,
            num_blocks_step2=0,
        ),
        dict(
            testcase_name="_zero_token_step2_bucketed_precompiled",
            use_precompiled_swap_ops=True,
            num_blocks_step1=2,
            num_blocks_step2=0,
        ),
    )
    def test_tpu_connector_multi_step_save(
        self,
        use_precompiled_swap_ops: bool,
        num_blocks_step1: int,
        num_blocks_step2: int,
    ):
        """
        Tests that the TPUConnectorWorker correctly saves the KV cache in multiple
        steps, respecting the save watermark (skip_leading_tokens).
        """
        os.environ[
            "TPU_OFFLOAD_SKIP_JAX_PRECOMPILE"] = "0" if use_precompiled_swap_ops else "1"
        num_tokens_step1 = num_blocks_step1 * self.block_size
        num_tokens_step2 = num_blocks_step2 * self.block_size
        logger.info(
@@ -589,31 +718,64 @@ def test_tpu_connector_multi_step_save(
    @parameterized.named_parameters(
        dict(
            testcase_name="_full_load_jax",
            use_precompiled_swap_ops=False,
            swap_op_type="jax",
            num_matched_blocks=4,
            num_computed_blocks=0,
        ),
        dict(
            testcase_name="_full_load_jax_precompiled",
            use_precompiled_swap_ops=True,
            swap_op_type="jax",
            num_matched_blocks=4,
            num_computed_blocks=0,
        ),
        dict(
            testcase_name="_delta_load_jax",
            use_precompiled_swap_ops=False,
            swap_op_type="jax",
            num_matched_blocks=4,
            num_computed_blocks=1,
        ),
        dict(
            testcase_name="_delta_load_jax_precompiled",
            use_precompiled_swap_ops=True,
            swap_op_type="jax",
            num_matched_blocks=4,
            num_computed_blocks=1,
        ),
        dict(
            testcase_name="_delta_load_pallas",
            use_precompiled_swap_ops=False,
            swap_op_type="pallas",
            num_matched_blocks=4,
            num_computed_blocks=1,
        ),
        dict(
            testcase_name="_delta_load_pallas_precompiled",
            use_precompiled_swap_ops=True,
            swap_op_type="pallas",
            num_matched_blocks=4,
            num_computed_blocks=1,
        ),
        dict(
            testcase_name="_no_load_jax",
            use_precompiled_swap_ops=False,
            swap_op_type="jax",
            num_matched_blocks=1,
            num_computed_blocks=1,
        ),
        dict(
            testcase_name="_no_load_jax_precompiled",
            use_precompiled_swap_ops=True,
            swap_op_type="jax",
            num_matched_blocks=1,
            num_computed_blocks=1,
        ),
    )
    def test_tpu_connector_load(
        self,
        use_precompiled_swap_ops: bool,
        swap_op_type: str,
        num_matched_blocks: int,
        num_computed_blocks: int = 0,
@@ -654,6 +816,8 @@ def test_tpu_connector_load(
        - Assert that the parts of the destination cache that should not have
          been touched remain zero.
        """
        os.environ[
            "TPU_OFFLOAD_SKIP_JAX_PRECOMPILE"] = "0" if use_precompiled_swap_ops else "1"
        num_matched_tokens = num_matched_blocks * self.block_size
        num_computed_tokens = num_computed_blocks * self.block_size
        if num_matched_blocks > self.num_blocks:
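
Note on the parameter sets in cpu_offloading_worker_test.py: the "_pad" and "_drop" case names reflect how num_tokens_to_save maps to num_blocks_to_save. Blocks are whole units of _DEFAULT_BLOCK_SIZE; a trailing partial block still occupies a full block ("_pad" cases), while tokens beyond num_tokens_to_save are simply not saved ("_drop" cases). A small sketch of that relationship, checked against the values above (the block size of 16 and the helper name are assumptions for illustration; the real constant lives in the test module):

# Illustration of the block-count invariant implied by the test parameters.
import math

_DEFAULT_BLOCK_SIZE = 16  # assumed value; the test module defines the real one


def expected_blocks_to_save(num_tokens_to_save: int,
                            block_size: int = _DEFAULT_BLOCK_SIZE) -> int:
    # A partial trailing block is padded out to a whole block.
    return math.ceil(num_tokens_to_save / block_size)


# "_prefill_no_skip_save_2_pad": 2 * block_size + 10 tokens -> 3 blocks
assert expected_blocks_to_save(2 * _DEFAULT_BLOCK_SIZE + 10) == 3
# "_prefill_no_skip_save_2_drop": exactly 2 * block_size tokens -> 2 blocks
assert expected_blocks_to_save(2 * _DEFAULT_BLOCK_SIZE) == 2
# "_final_save_save_1_pad": 10 tokens -> 1 block
assert expected_blocks_to_save(10) == 1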