|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | + |
| 3 | +import functools |
| 4 | +import os |
| 5 | +from typing import List |
| 6 | + |
| 7 | +import jax |
| 8 | +import jax.numpy as jnp |
| 9 | +import numpy as np |
| 10 | +from absl.testing import parameterized |
| 11 | +from jax._src import compilation_cache as cc |
| 12 | +from jax._src import test_util as jtu |
| 13 | +from jax.sharding import Mesh, NamedSharding, PartitionSpec |
| 14 | +from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole |
| 15 | + |
| 16 | +from tpu_inference.distributed.local_cpu_backend import LocalCPUBackend |
| 17 | +from tpu_inference.distributed.tpu_connector_local import \ |
| 18 | + TPUConnector as CPUOffloadingConnector |
| 19 | +from tpu_inference.logger import init_logger |
| 20 | +from tpu_inference.runner.tpu_jax_runner import TPUModelRunner |
| 21 | + |
| 22 | +logger = init_logger(__name__) |
| 23 | + |
| 24 | +_DEFAULT_BLOCK_SIZE = 64 |
| 25 | + |
| 26 | + |
| 27 | +class MockTPUModelRunner(TPUModelRunner): |
| 28 | + """A mock TPUModelRunner for testing purposes.""" |
| 29 | + |
| 30 | + def __init__(self, kv_caches: List[jax.Array], mesh: Mesh): |
| 31 | + self.kv_caches = kv_caches |
| 32 | + self.mesh = mesh |
| 33 | + self.model_config = None |
| 34 | + self.sampler = None |
| 35 | + |
| 36 | + def get_kv_cache_layout(self): |
| 37 | + return "NHD" |
| 38 | + |
| 39 | + |
| 40 | +class MockVllmConfig: |
| 41 | + |
| 42 | + def __init__(self, block_size=_DEFAULT_BLOCK_SIZE): |
| 43 | + self.model_config = self.Model() |
| 44 | + self.cache_config = self.Cache(block_size) |
| 45 | + self.kv_transfer_config = self.KVTransfer() |
| 46 | + |
| 47 | + class Model: |
| 48 | + model = "test-model" |
| 49 | + |
| 50 | + class Cache: |
| 51 | + |
| 52 | + def __init__(self, block_size): |
| 53 | + self.block_size = block_size |
| 54 | + |
| 55 | + class KVTransfer: |
| 56 | + kv_ip = "localhost" |
| 57 | + kv_port = 9999 |
| 58 | + |
| 59 | + |
| 60 | +class TestHostOffloadingPrecompile(jtu.JaxTestCase): |
| 61 | + """Test the host offloading precompilation and related functionalities.""" |
| 62 | + |
| 63 | + def setUp(self): |
| 64 | + super().setUp() |
| 65 | + self.vllm_config = MockVllmConfig(block_size=_DEFAULT_BLOCK_SIZE) |
| 66 | + self.num_layers = 2 |
| 67 | + self.num_blocks = 128 # Increased for larger tests |
| 68 | + self.block_size = self.vllm_config.cache_config.block_size |
| 69 | + self.num_heads = 8 |
| 70 | + self.head_size = 128 |
| 71 | + self.mesh = self.create_mesh((1, 8), ("data", "model")) |
| 72 | + if self.mesh is None: |
| 73 | + self.skipTest("Cannot create mesh. Must be run on a TPU node.") |
| 74 | + return |
| 75 | + |
| 76 | + # Define cache properties |
| 77 | + self.cache_shape = ( |
| 78 | + self.num_blocks, |
| 79 | + self.block_size, |
| 80 | + self.num_heads, |
| 81 | + 2, |
| 82 | + self.head_size, |
| 83 | + ) |
| 84 | + self.cache_dtype = jnp.bfloat16 |
| 85 | + partition_spec = PartitionSpec(None, None, "model") |
| 86 | + self.device_sharding = NamedSharding(self.mesh, partition_spec) |
| 87 | + |
| 88 | + def tearDown(self): |
| 89 | + super().tearDown() |
| 90 | + cc.reset_cache() |
| 91 | + |
| 92 | + def create_mesh(self, axis_shapes, axis_names): |
| 93 | + """Creates a JAX device mesh with the default device order.""" |
| 94 | + try: |
| 95 | + num_required_devices = np.prod(axis_shapes) |
| 96 | + devices = np.array(jax.devices()) |
| 97 | + if len(devices) < num_required_devices: |
| 98 | + self.skipTest( |
| 99 | + f"Not enough devices to create mesh of shape {axis_shapes}." |
| 100 | + ) |
| 101 | + device_array = devices[:num_required_devices].reshape(axis_shapes) |
| 102 | + return jax.sharding.Mesh(device_array, axis_names) |
| 103 | + except RuntimeError: |
| 104 | + return None |
| 105 | + |
| 106 | + def _create_connector(self, swap_op_type: str = "jax"): |
| 107 | + # Clean the singleton backend instance before each test |
| 108 | + LocalCPUBackend._instance = None |
| 109 | + LocalCPUBackend._initialized = False |
| 110 | + |
| 111 | + os.environ["TPU_OFFLOAD_SWAP_OP_TYPE"] = swap_op_type |
| 112 | + connector = CPUOffloadingConnector(self.vllm_config, |
| 113 | + KVConnectorRole.WORKER) |
| 114 | + worker = connector.connector_worker |
| 115 | + assert worker is not None |
| 116 | + |
| 117 | + @functools.partial(jax.jit, out_shardings=self.device_sharding) |
| 118 | + def create_on_device(key): |
| 119 | + return jax.random.uniform(key, |
| 120 | + shape=self.cache_shape, |
| 121 | + dtype=self.cache_dtype) |
| 122 | + |
| 123 | + source_kv_cache = [ |
| 124 | + create_on_device(jax.random.key(i)) for i in range(self.num_layers) |
| 125 | + ] |
| 126 | + jax.block_until_ready(source_kv_cache) |
| 127 | + |
| 128 | + mock_runner = MockTPUModelRunner(kv_caches=source_kv_cache, |
| 129 | + mesh=self.mesh) |
| 130 | + worker.register_runner(mock_runner) |
| 131 | + return connector |
| 132 | + |
| 133 | + @parameterized.named_parameters( |
| 134 | + dict(testcase_name="_zero_blocks", num_blocks=0, expected_buckets=[]), |
| 135 | + dict(testcase_name="_one_block", num_blocks=1, expected_buckets=[1]), |
| 136 | + dict(testcase_name="_five_blocks", |
| 137 | + num_blocks=5, |
| 138 | + expected_buckets=[4, 1]), |
| 139 | + dict(testcase_name="_sixteen_blocks", |
| 140 | + num_blocks=16, |
| 141 | + expected_buckets=[16]), |
| 142 | + dict(testcase_name="_seventeen_blocks", |
| 143 | + num_blocks=17, |
| 144 | + expected_buckets=[16, 1]), |
| 145 | + dict(testcase_name="_twenty_three_blocks", |
| 146 | + num_blocks=23, |
| 147 | + expected_buckets=[16, 4, 2, 1]), |
| 148 | + dict(testcase_name="_thirty_two_blocks", |
| 149 | + num_blocks=32, |
| 150 | + expected_buckets=[16, 16]), |
| 151 | + dict(testcase_name="_large_number_blocks", |
| 152 | + num_blocks=100, |
| 153 | + expected_buckets=[16, 16, 16, 16, 16, 16, 4]), |
| 154 | + ) |
| 155 | + def test_decompose_into_buckets(self, num_blocks: int, |
| 156 | + expected_buckets: List[int]): |
| 157 | + """ |
| 158 | + Tests the _decompose_into_buckets function for correct greedy decomposition. |
| 159 | + """ |
| 160 | + os.environ["TPU_OFFLOAD_SKIP_JAX_PRECOMPILE"] = "0" |
| 161 | + connector = self._create_connector() |
| 162 | + worker = connector.connector_worker |
| 163 | + self.assertEqual(worker._decompose_into_buckets(num_blocks), |
| 164 | + expected_buckets) |
| 165 | + logger.info( |
| 166 | + f"Decomposition for {num_blocks} blocks: {worker._decompose_into_buckets(num_blocks)} matched expected: {expected_buckets}" |
| 167 | + ) |
| 168 | + |
| 169 | + @parameterized.named_parameters( |
| 170 | + dict(testcase_name="_jax", swap_op_type="jax"), |
| 171 | + dict(testcase_name="_pallas", swap_op_type="pallas"), |
| 172 | + ) |
| 173 | + def test_precompile_run_success(self, swap_op_type: str): |
| 174 | + """ |
| 175 | + Tests that _precompile_kv_swap_operations runs without errors and |
| 176 | + modifies the cache content. |
| 177 | + """ |
| 178 | + # Unset skip flag to allow precompilation to run |
| 179 | + os.environ["TPU_OFFLOAD_SKIP_JAX_PRECOMPILE"] = "0" |
| 180 | + connector = self._create_connector(swap_op_type=swap_op_type) |
| 181 | + worker = connector.connector_worker |
| 182 | + |
| 183 | + # Keep a copy of the original cache content on the host |
| 184 | + original_cache_host = [ |
| 185 | + np.array(cache) for cache in worker.runner.kv_caches |
| 186 | + ] |
| 187 | + |
| 188 | + worker._precompile_kv_swap_operations() |
| 189 | + |
| 190 | + # Fetch the new cache content to the host |
| 191 | + new_cache_host = [np.array(cache) for cache in worker.runner.kv_caches] |
| 192 | + |
| 193 | + self.assertTrue( |
| 194 | + all( |
| 195 | + np.array_equal(orig, new) |
| 196 | + for orig, new in zip(original_cache_host, new_cache_host)), |
| 197 | + "Cache content should not have changed after precompilation.", |
| 198 | + ) |
0 commit comments