|
| 1 | +import jax |
| 2 | +import jax.numpy as jnp |
| 3 | +import numpy as np |
| 4 | +from absl.testing import absltest, parameterized |
| 5 | +from jax._src import test_util as jtu |
| 6 | + |
| 7 | +import tpu_inference.kernels.mla.v1.kernel as mla |
| 8 | +from tpu_inference.kernels.ragged_paged_attention.v3.util import ( |
| 9 | + align_to, cdiv, get_dtype_packing) |
| 10 | + |
| 11 | +jax.config.parse_flags_with_absl() |
| 12 | + |
| 13 | + |
| 14 | +@jtu.with_config(jax_numpy_dtype_promotion="standard") |
| 15 | +class MlaRaggedPagedAttentionKernelTest(jtu.JaxTestCase): |
| 16 | + |
| 17 | + def _test_mla_ragged_paged_attention( |
| 18 | + self, |
| 19 | + seq_lens, # List[(q_len, kv_len)] |
| 20 | + num_heads, |
| 21 | + lkv_dim, |
| 22 | + r_dim, |
| 23 | + page_size, |
| 24 | + q_dtype, |
| 25 | + kv_dtype, |
| 26 | + num_pages, |
| 27 | + *, |
| 28 | + num_kv_pages_per_block=8, |
| 29 | + num_queries_per_block=8, |
| 30 | + vmem_limit_bytes=100 * 1024 * 1024, |
| 31 | + sm_scale=1.0, |
| 32 | + sliding_window: int | None = None, |
| 33 | + soft_cap: float | None = None, |
| 34 | + ): |
| 35 | + if not jtu.is_device_tpu_at_least(version=4): |
| 36 | + self.skipTest("Expect TPUv4+") |
| 37 | + rng = np.random.default_rng(1234) |
| 38 | + |
| 39 | + def gen_random(shape, dtype): |
| 40 | + return jnp.array(rng.random(size=shape, |
| 41 | + dtype=np.float32)).astype(dtype) |
| 42 | + |
| 43 | + padded_r_dim = align_to(r_dim, 128) |
| 44 | + padded_lkv_dim = align_to(lkv_dim, 128) |
| 45 | + packing = get_dtype_packing(kv_dtype) |
| 46 | + q_lens = [s[0] for s in seq_lens] |
| 47 | + kv_lens_list = [s[1] for s in seq_lens] |
| 48 | + total_q_len = sum(q_lens) |
| 49 | + cu_q_lens_list = [0] |
| 50 | + for q_len in q_lens: |
| 51 | + cu_q_lens_list.append(cu_q_lens_list[-1] + q_len) |
| 52 | + |
| 53 | + max_kv_len = max(kv_lens_list) if kv_lens_list else 0 |
| 54 | + pages_per_seq = cdiv(max_kv_len, page_size) |
| 55 | + |
| 56 | + page_indices_list = [] |
| 57 | + page_count = 0 |
| 58 | + for kv_len in kv_lens_list: |
| 59 | + num_seq_pages = cdiv(kv_len, page_size) |
| 60 | + indices = list(range(page_count, page_count + num_seq_pages)) |
| 61 | + page_indices_list.extend(indices + [-1] * |
| 62 | + (pages_per_seq - num_seq_pages)) |
| 63 | + page_count += num_seq_pages |
| 64 | + |
| 65 | + total_num_pages = max(num_pages, page_count) |
| 66 | + |
| 67 | + ql_nope = gen_random((total_q_len, num_heads, lkv_dim), q_dtype) |
| 68 | + q_pe = gen_random((total_q_len, num_heads, r_dim), q_dtype) |
| 69 | + new_kv_c = gen_random((total_q_len, lkv_dim), kv_dtype) |
| 70 | + new_k_pe = gen_random((total_q_len, r_dim), kv_dtype) |
| 71 | + |
| 72 | + cache_kv_c = gen_random( |
| 73 | + (total_num_pages, page_size // packing, packing, padded_lkv_dim), |
| 74 | + kv_dtype, |
| 75 | + ) |
| 76 | + cache_k_pe = gen_random( |
| 77 | + (total_num_pages, page_size // packing, packing, padded_r_dim), |
| 78 | + kv_dtype) |
| 79 | + kv_lens = jnp.array(kv_lens_list, dtype=jnp.int32) |
| 80 | + page_indices = jnp.array(page_indices_list, dtype=jnp.int32) |
| 81 | + cu_q_lens = jnp.array(cu_q_lens_list, dtype=jnp.int32) |
| 82 | + distribution = jnp.array([0, 0, len(seq_lens)], dtype=jnp.int32) |
| 83 | + |
| 84 | + ql_nope_for_kernel = ql_nope.copy() |
| 85 | + q_pe_for_kernel = q_pe.copy() |
| 86 | + |
| 87 | + expected_out, expected_updated_kv_c, expeceted_updated_k_pe = ( |
| 88 | + mla.ref_mla_ragged_paged_attention( |
| 89 | + ql_nope, |
| 90 | + q_pe, |
| 91 | + new_kv_c, |
| 92 | + new_k_pe, |
| 93 | + cache_kv_c.copy(), |
| 94 | + cache_k_pe.copy(), |
| 95 | + kv_lens, |
| 96 | + page_indices, |
| 97 | + cu_q_lens, |
| 98 | + distribution, |
| 99 | + sm_scale=sm_scale, |
| 100 | + sliding_window=sliding_window, |
| 101 | + soft_cap=soft_cap, |
| 102 | + )) |
| 103 | + |
| 104 | + kernel_out, kernel_updated_kv_c, kernel_updated_k_pe = ( |
| 105 | + mla.mla_ragged_paged_attention( |
| 106 | + ql_nope_for_kernel, |
| 107 | + q_pe_for_kernel, |
| 108 | + new_kv_c, |
| 109 | + new_k_pe, |
| 110 | + cache_kv_c.copy(), |
| 111 | + cache_k_pe.copy(), |
| 112 | + kv_lens, |
| 113 | + page_indices, |
| 114 | + cu_q_lens, |
| 115 | + distribution, |
| 116 | + sm_scale=sm_scale, |
| 117 | + sliding_window=sliding_window, |
| 118 | + soft_cap=soft_cap, |
| 119 | + num_kv_pages_per_block=num_kv_pages_per_block, |
| 120 | + num_queries_per_block=num_queries_per_block, |
| 121 | + vmem_limit_bytes=vmem_limit_bytes, |
| 122 | + )) |
| 123 | + |
| 124 | + self.assertEqual(expected_out.shape, |
| 125 | + (total_q_len, num_heads, padded_lkv_dim)) |
| 126 | + self.assertEqual( |
| 127 | + expected_updated_kv_c.shape, |
| 128 | + (total_num_pages, page_size // packing, packing, padded_lkv_dim), |
| 129 | + ) |
| 130 | + self.assertEqual( |
| 131 | + expeceted_updated_k_pe.shape, |
| 132 | + (total_num_pages, page_size // packing, packing, padded_r_dim), |
| 133 | + ) |
| 134 | + self.assertEqual(expected_out.dtype, kv_dtype) |
| 135 | + self.assertEqual(expected_updated_kv_c.dtype, kv_dtype) |
| 136 | + self.assertEqual(expeceted_updated_k_pe.dtype, kv_dtype) |
| 137 | + |
| 138 | + self.assertAllClose(expected_out, kernel_out, atol=0.2, rtol=0.2) |
| 139 | + self.assertAllClose(expected_updated_kv_c, |
| 140 | + kernel_updated_kv_c, |
| 141 | + atol=0.2, |
| 142 | + rtol=0.2) |
| 143 | + self.assertAllClose(expeceted_updated_k_pe, |
| 144 | + kernel_updated_k_pe, |
| 145 | + atol=0.2, |
| 146 | + rtol=0.2) |
| 147 | + |
| 148 | + def test_ragged_paged_attention_basic(self): |
| 149 | + dtype = jnp.bfloat16 |
| 150 | + seq_lens = [(192, 328), (128, 180), (64, 255)] |
| 151 | + num_heads = 128 |
| 152 | + lkv_dim = 512 |
| 153 | + r_dim = 64 |
| 154 | + page_size = 16 |
| 155 | + num_pages = 1000 |
| 156 | + |
| 157 | + self._test_mla_ragged_paged_attention( |
| 158 | + seq_lens, |
| 159 | + num_heads, |
| 160 | + lkv_dim, |
| 161 | + r_dim, |
| 162 | + page_size, |
| 163 | + dtype, |
| 164 | + dtype, |
| 165 | + num_pages, |
| 166 | + ) |
| 167 | + |
| 168 | + @parameterized.product(dtype=[jnp.bfloat16], ) |
| 169 | + def test_ragged_paged_attention_decode_only(self, dtype): |
| 170 | + seq_lens = [ |
| 171 | + (1, 18), |
| 172 | + (1, 129), |
| 173 | + (1, 597), |
| 174 | + (1, 122), |
| 175 | + (1, 64), |
| 176 | + (1, 322), |
| 177 | + (1, 463), |
| 178 | + (1, 181), |
| 179 | + (1, 1107), |
| 180 | + (1, 123), |
| 181 | + (1, 31), |
| 182 | + (1, 18), |
| 183 | + (1, 1229), |
| 184 | + (1, 229), |
| 185 | + (1, 87), |
| 186 | + (1, 1328), |
| 187 | + ] |
| 188 | + num_heads = 128 |
| 189 | + lkv_dim = 512 |
| 190 | + r_dim = 64 |
| 191 | + page_size = 16 |
| 192 | + num_pages = 1000 |
| 193 | + |
| 194 | + self._test_mla_ragged_paged_attention( |
| 195 | + seq_lens, |
| 196 | + num_heads, |
| 197 | + lkv_dim, |
| 198 | + r_dim, |
| 199 | + page_size, |
| 200 | + dtype, |
| 201 | + dtype, |
| 202 | + num_pages, |
| 203 | + ) |
| 204 | + |
| 205 | + @parameterized.product(dtype=[jnp.bfloat16], ) |
| 206 | + def test_ragged_paged_attention_prefill_only(self, dtype): |
| 207 | + seq_lens = [ |
| 208 | + (5, 18), |
| 209 | + (15, 129), |
| 210 | + (120, 597), |
| 211 | + (100, 122), |
| 212 | + (21, 64), |
| 213 | + (32, 322), |
| 214 | + (251, 463), |
| 215 | + (40, 181), |
| 216 | + (64, 1107), |
| 217 | + (99, 123), |
| 218 | + (10, 31), |
| 219 | + (5, 18), |
| 220 | + (3, 1229), |
| 221 | + (120, 229), |
| 222 | + (9, 87), |
| 223 | + (2, 1328), |
| 224 | + ] |
| 225 | + num_heads = 128 |
| 226 | + lkv_dim = 512 |
| 227 | + r_dim = 64 |
| 228 | + page_size = 16 |
| 229 | + num_pages = 1000 |
| 230 | + |
| 231 | + self._test_mla_ragged_paged_attention( |
| 232 | + seq_lens, |
| 233 | + num_heads, |
| 234 | + lkv_dim, |
| 235 | + r_dim, |
| 236 | + page_size, |
| 237 | + dtype, |
| 238 | + dtype, |
| 239 | + num_pages, |
| 240 | + ) |
| 241 | + |
| 242 | + @parameterized.product(dtype=[jnp.bfloat16], ) |
| 243 | + def test_ragged_paged_attention_mixed(self, dtype): |
| 244 | + seq_lens = [ |
| 245 | + (5, 18), |
| 246 | + (1, 129), |
| 247 | + (120, 597), |
| 248 | + (1, 122), |
| 249 | + (1, 64), |
| 250 | + (32, 322), |
| 251 | + (251, 463), |
| 252 | + (1, 181), |
| 253 | + (1, 1107), |
| 254 | + (99, 123), |
| 255 | + (1, 31), |
| 256 | + (5, 18), |
| 257 | + (3, 1229), |
| 258 | + (117, 229), |
| 259 | + (1, 87), |
| 260 | + (1, 1328), |
| 261 | + ] |
| 262 | + num_heads = 128 |
| 263 | + lkv_dim = 512 |
| 264 | + r_dim = 64 |
| 265 | + page_size = 16 |
| 266 | + num_pages = 1000 |
| 267 | + |
| 268 | + self._test_mla_ragged_paged_attention( |
| 269 | + seq_lens, |
| 270 | + num_heads, |
| 271 | + lkv_dim, |
| 272 | + r_dim, |
| 273 | + page_size, |
| 274 | + dtype, |
| 275 | + dtype, |
| 276 | + num_pages, |
| 277 | + ) |
| 278 | + |
| 279 | + @parameterized.product(sliding_window=[None, 5, 128], ) |
| 280 | + def test_ragged_paged_attention_sliding_window( |
| 281 | + self, |
| 282 | + sliding_window: int | None, |
| 283 | + ): |
| 284 | + num_seqs = 5 |
| 285 | + num_heads = 128 |
| 286 | + lkv_dim = 512 |
| 287 | + r_dim = 64 |
| 288 | + dtype = jnp.float32 |
| 289 | + rng = np.random.default_rng(1234) |
| 290 | + q_lens = rng.integers(1, 100, num_seqs) |
| 291 | + kv_lens = q_lens + rng.integers(0, 50, num_seqs) |
| 292 | + seq_lens = list(zip(q_lens.tolist(), kv_lens.tolist())) |
| 293 | + page_size = 16 |
| 294 | + num_pages = 1000 |
| 295 | + |
| 296 | + self._test_mla_ragged_paged_attention( |
| 297 | + seq_lens, |
| 298 | + num_heads, |
| 299 | + lkv_dim, |
| 300 | + r_dim, |
| 301 | + page_size, |
| 302 | + dtype, |
| 303 | + dtype, |
| 304 | + num_pages, |
| 305 | + sliding_window=sliding_window, |
| 306 | + ) |
| 307 | + |
| 308 | + @parameterized.product(soft_cap=[None, 50.0], ) |
| 309 | + def test_ragged_paged_attention_logit_soft_capping( |
| 310 | + self, |
| 311 | + soft_cap: float | None, |
| 312 | + ): |
| 313 | + num_heads = 128 |
| 314 | + num_seqs = 2 |
| 315 | + dtype = jnp.float32 |
| 316 | + rng = np.random.default_rng(1234) |
| 317 | + q_lens = rng.integers(1, 100, num_seqs) |
| 318 | + kv_lens = q_lens + rng.integers(0, 50, num_seqs) |
| 319 | + seq_lens = list(zip(q_lens.tolist(), kv_lens.tolist())) |
| 320 | + lkv_dim = 512 |
| 321 | + r_dim = 64 |
| 322 | + page_size = 16 |
| 323 | + num_pages = 1000 |
| 324 | + |
| 325 | + self._test_mla_ragged_paged_attention( |
| 326 | + seq_lens, |
| 327 | + num_heads, |
| 328 | + lkv_dim, |
| 329 | + r_dim, |
| 330 | + page_size, |
| 331 | + dtype, |
| 332 | + dtype, |
| 333 | + num_pages, |
| 334 | + soft_cap=soft_cap, |
| 335 | + ) |
| 336 | + |
| 337 | + def test_ragged_paged_attention_sliding_window_should_be_positive(self): |
| 338 | + dtype = jnp.float32 |
| 339 | + seq_lens = [(192, 328), (128, 180), (64, 255)] |
| 340 | + num_heads = 128 |
| 341 | + lkv_dim = 512 |
| 342 | + r_dim = 64 |
| 343 | + page_size = 16 |
| 344 | + num_pages = 1000 |
| 345 | + |
| 346 | + with self.assertRaisesRegex(ValueError, "must be positive"): |
| 347 | + self._test_mla_ragged_paged_attention( |
| 348 | + seq_lens, |
| 349 | + num_heads, |
| 350 | + lkv_dim, |
| 351 | + r_dim, |
| 352 | + page_size, |
| 353 | + dtype, |
| 354 | + dtype, |
| 355 | + num_pages, |
| 356 | + sliding_window=0, |
| 357 | + ) |
| 358 | + |
| 359 | + with self.assertRaisesRegex(ValueError, "must be positive"): |
| 360 | + self._test_mla_ragged_paged_attention( |
| 361 | + seq_lens, |
| 362 | + num_heads, |
| 363 | + lkv_dim, |
| 364 | + r_dim, |
| 365 | + page_size, |
| 366 | + dtype, |
| 367 | + dtype, |
| 368 | + num_pages, |
| 369 | + sliding_window=-1, |
| 370 | + ) |
| 371 | + |
| 372 | + def test_ragged_paged_attention_soft_cap_cannot_be_zero(self): |
| 373 | + dtype = jnp.float32 |
| 374 | + seq_lens = [(192, 328), (128, 180), (64, 255)] |
| 375 | + num_heads = 128 |
| 376 | + lkv_dim = 512 |
| 377 | + r_dim = 64 |
| 378 | + page_size = 16 |
| 379 | + num_pages = 1000 |
| 380 | + |
| 381 | + with self.assertRaisesRegex(ValueError, "must not be 0.0"): |
| 382 | + self._test_mla_ragged_paged_attention( |
| 383 | + seq_lens, |
| 384 | + num_heads, |
| 385 | + lkv_dim, |
| 386 | + r_dim, |
| 387 | + page_size, |
| 388 | + dtype, |
| 389 | + dtype, |
| 390 | + num_pages, |
| 391 | + soft_cap=0.0, |
| 392 | + ) |
| 393 | + |
| 394 | + |
| 395 | +if __name__ == "__main__": |
| 396 | + absltest.main(testLoader=jtu.JaxTestLoader()) |
0 commit comments