
Commit d1c5c0f

prepare precompile functions for tpu local connector
1. Prepare precompile functions which will cycle through the load and save jitted functions.
2. Decompose the load and save util functions to be block-bucket aligned.
3. Unit tests for the change.

Signed-off-by: Saikat Roychowdhury <saikat.royc85@gmail.com>
1 parent 205e474 commit d1c5c0f
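Note on point 2: the block-bucket decomposition can be inferred from the new unit-test expectations (for example 23 -> [16, 4, 2, 1] and 100 -> six 16s plus a 4): a block count is split greedily into a small fixed set of bucket sizes so that only a handful of shapes ever need to be jitted. Below is a minimal sketch of that behavior; the helper name _decompose_into_buckets comes from the tests, but the exact bucket list (powers of two capped at 16) is an assumption, not confirmed by this commit.

# Sketch only: greedy block-bucket decomposition implied by the tests below.
# _decompose_into_buckets is the name used in the tests; the bucket sizes are
# an assumption that reproduces the expected outputs, e.g.
# 23 -> [16, 4, 2, 1] and 100 -> [16, 16, 16, 16, 16, 16, 4].
from typing import List

_SWAP_BUCKET_SIZES = [16, 8, 4, 2, 1]  # assumed precompiled bucket sizes


def _decompose_into_buckets(num_blocks: int) -> List[int]:
    """Greedily split num_blocks into the largest available bucket sizes."""
    buckets: List[int] = []
    remaining = num_blocks
    for size in _SWAP_BUCKET_SIZES:
        while remaining >= size:
            buckets.append(size)
            remaining -= size
    return buckets


assert _decompose_into_buckets(23) == [16, 4, 2, 1]
assert _decompose_into_buckets(100) == [16, 16, 16, 16, 16, 16, 4]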

File tree

2 files changed: +435 -10 lines changed

Lines changed: 198 additions & 0 deletions
@@ -0,0 +1,198 @@
# SPDX-License-Identifier: Apache-2.0

import functools
import os
from typing import List

import jax
import jax.numpy as jnp
import numpy as np
from absl.testing import parameterized
from jax._src import compilation_cache as cc
from jax._src import test_util as jtu
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole

from tpu_inference.distributed.local_cpu_backend import LocalCPUBackend
from tpu_inference.distributed.tpu_connector_local import \
    TPUConnector as CPUOffloadingConnector
from tpu_inference.logger import init_logger
from tpu_inference.runner.tpu_jax_runner import TPUModelRunner

logger = init_logger(__name__)

_DEFAULT_BLOCK_SIZE = 64


class MockTPUModelRunner(TPUModelRunner):
    """A mock TPUModelRunner for testing purposes."""

    def __init__(self, kv_caches: List[jax.Array], mesh: Mesh):
        self.kv_caches = kv_caches
        self.mesh = mesh
        self.model_config = None
        self.sampler = None

    def get_kv_cache_layout(self):
        return "NHD"


class MockVllmConfig:

    def __init__(self, block_size=_DEFAULT_BLOCK_SIZE):
        self.model_config = self.Model()
        self.cache_config = self.Cache(block_size)
        self.kv_transfer_config = self.KVTransfer()

    class Model:
        model = "test-model"

    class Cache:

        def __init__(self, block_size):
            self.block_size = block_size

    class KVTransfer:
        kv_ip = "localhost"
        kv_port = 9999


class TestHostOffloadingPrecompile(jtu.JaxTestCase):
    """Test the host offloading precompilation and related functionalities."""

    def setUp(self):
        super().setUp()
        self.vllm_config = MockVllmConfig(block_size=_DEFAULT_BLOCK_SIZE)
        self.num_layers = 2
        self.num_blocks = 128  # Increased for larger tests
        self.block_size = self.vllm_config.cache_config.block_size
        self.num_heads = 8
        self.head_size = 128
        self.mesh = self.create_mesh((1, 8), ("data", "model"))
        if self.mesh is None:
            self.skipTest("Cannot create mesh. Must be run on a TPU node.")
            return

        # Define cache properties
        self.cache_shape = (
            self.num_blocks,
            self.block_size,
            self.num_heads,
            2,
            self.head_size,
        )
        self.cache_dtype = jnp.bfloat16
        partition_spec = PartitionSpec(None, None, "model")
        self.device_sharding = NamedSharding(self.mesh, partition_spec)

    def tearDown(self):
        super().tearDown()
        cc.reset_cache()

    def create_mesh(self, axis_shapes, axis_names):
        """Creates a JAX device mesh with the default device order."""
        try:
            num_required_devices = np.prod(axis_shapes)
            devices = np.array(jax.devices())
            if len(devices) < num_required_devices:
                self.skipTest(
                    f"Not enough devices to create mesh of shape {axis_shapes}."
                )
            device_array = devices[:num_required_devices].reshape(axis_shapes)
            return jax.sharding.Mesh(device_array, axis_names)
        except RuntimeError:
            return None

    def _create_connector(self, swap_op_type: str = "jax"):
        # Clean the singleton backend instance before each test
        LocalCPUBackend._instance = None
        LocalCPUBackend._initialized = False

        os.environ["TPU_OFFLOAD_SWAP_OP_TYPE"] = swap_op_type
        connector = CPUOffloadingConnector(self.vllm_config,
                                           KVConnectorRole.WORKER)
        worker = connector.connector_worker
        assert worker is not None

        @functools.partial(jax.jit, out_shardings=self.device_sharding)
        def create_on_device(key):
            return jax.random.uniform(key,
                                      shape=self.cache_shape,
                                      dtype=self.cache_dtype)

        source_kv_cache = [
            create_on_device(jax.random.key(i)) for i in range(self.num_layers)
        ]
        jax.block_until_ready(source_kv_cache)

        mock_runner = MockTPUModelRunner(kv_caches=source_kv_cache,
                                         mesh=self.mesh)
        worker.register_runner(mock_runner)
        return connector

    @parameterized.named_parameters(
        dict(testcase_name="_zero_blocks", num_blocks=0, expected_buckets=[]),
        dict(testcase_name="_one_block", num_blocks=1, expected_buckets=[1]),
        dict(testcase_name="_five_blocks",
             num_blocks=5,
             expected_buckets=[4, 1]),
        dict(testcase_name="_sixteen_blocks",
             num_blocks=16,
             expected_buckets=[16]),
        dict(testcase_name="_seventeen_blocks",
             num_blocks=17,
             expected_buckets=[16, 1]),
        dict(testcase_name="_twenty_three_blocks",
             num_blocks=23,
             expected_buckets=[16, 4, 2, 1]),
        dict(testcase_name="_thirty_two_blocks",
             num_blocks=32,
             expected_buckets=[16, 16]),
        dict(testcase_name="_large_number_blocks",
             num_blocks=100,
             expected_buckets=[16, 16, 16, 16, 16, 16, 4]),
    )
    def test_decompose_into_buckets(self, num_blocks: int,
                                    expected_buckets: List[int]):
        """
        Tests the _decompose_into_buckets function for correct greedy decomposition.
        """
        os.environ["TPU_OFFLOAD_SKIP_JAX_PRECOMPILE"] = "0"
        connector = self._create_connector()
        worker = connector.connector_worker
        self.assertEqual(worker._decompose_into_buckets(num_blocks),
                         expected_buckets)
        logger.info(
            f"Decomposition for {num_blocks} blocks: {worker._decompose_into_buckets(num_blocks)} matched expected: {expected_buckets}"
        )

    @parameterized.named_parameters(
        dict(testcase_name="_jax", swap_op_type="jax"),
        dict(testcase_name="_pallas", swap_op_type="pallas"),
    )
    def test_precompile_run_success(self, swap_op_type: str):
"""
175+
Tests that _precompile_kv_swap_operations runs without errors and
176+
modifies the cache content.
177+
"""
        # Unset skip flag to allow precompilation to run
        os.environ["TPU_OFFLOAD_SKIP_JAX_PRECOMPILE"] = "0"
        connector = self._create_connector(swap_op_type=swap_op_type)
        worker = connector.connector_worker

        # Keep a copy of the original cache content on the host
        original_cache_host = [
            np.array(cache) for cache in worker.runner.kv_caches
        ]

        worker._precompile_kv_swap_operations()

        # Fetch the new cache content to the host
        new_cache_host = [np.array(cache) for cache in worker.runner.kv_caches]

        self.assertTrue(
            all(
                np.array_equal(orig, new)
                for orig, new in zip(original_cache_host, new_cache_host)),
            "Cache content should not have changed after precompilation.",
        )
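For context on point 1 of the commit message: the tests above only assert that _precompile_kv_swap_operations runs without errors and leaves the KV cache bit-identical. A purely hypothetical sketch of the cycling it refers to, assuming helper names (_swap_out_fn, _swap_in_fn) and bucket sizes that are not part of this commit, might look roughly like the following; the real connector code may differ substantially.

# Hypothetical sketch, not the actual tpu_inference implementation: warm the
# jit compilation cache for every bucket size so no compilation happens on the
# serving path. _swap_out_fn, _swap_in_fn, and the bucket list are assumed names.
def _precompile_kv_swap_operations(self):
    for num_blocks in (16, 8, 4, 2, 1):
        block_ids = list(range(num_blocks))
        for i, cache in enumerate(self.runner.kv_caches):
            # save path: copy the selected device blocks to a host buffer
            host_blocks = self._swap_out_fn(cache, block_ids)
            # load path: write the same data back, so the cache content is
            # unchanged, which is what test_precompile_run_success asserts
            self.runner.kv_caches[i] = self._swap_in_fn(
                cache, host_blocks, block_ids)
    jax.block_until_ready(self.runner.kv_caches)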
