
Commit 4688f34

yaochengji authored and sierraisland committed
[Kernel] implement 1st version of data-movement friendly MLA kernel with no kv update fused (#1022)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
1 parent 7865786 commit 4688f34
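
For orientation, below is a minimal, single-sequence decode invocation of the new kernel entry point, assembled only from the call pattern in the test added by this commit (tests/kernels/mla_v1_test.py). It is a sketch, not part of the commit: the concrete shapes, the zero-filled inputs, the choice of a 33-token context, and the assumption that block-size keyword arguments may be omitted are all illustrative. Like the test, it requires TPUv4+ to actually run.

    # Minimal decode-style sketch assembled from the test below; values are illustrative.
    import jax.numpy as jnp

    import tpu_inference.kernels.mla.v1.kernel as mla
    from tpu_inference.kernels.ragged_paged_attention.v3.util import (
        align_to, cdiv, get_dtype_packing)

    num_heads, lkv_dim, r_dim, page_size = 128, 512, 64, 16
    dtype = jnp.bfloat16
    packing = get_dtype_packing(dtype)       # expected to be 2 for bfloat16
    padded_lkv_dim = align_to(lkv_dim, 128)  # 512
    padded_r_dim = align_to(r_dim, 128)      # 128

    q_len, kv_len, num_pages = 1, 33, 4      # one new token over a 33-token context
    pages_per_seq = cdiv(kv_len, page_size)  # 3 pages hold this sequence

    # Query inputs (no-position and positional-encoding parts) and new-token latent KV.
    ql_nope = jnp.zeros((q_len, num_heads, lkv_dim), dtype)
    q_pe = jnp.zeros((q_len, num_heads, r_dim), dtype)
    new_kv_c = jnp.zeros((q_len, lkv_dim), dtype)
    new_k_pe = jnp.zeros((q_len, r_dim), dtype)

    # Paged latent-KV caches, shaped (num_pages, page_size // packing, packing, padded_dim)
    # exactly as the test constructs them.
    cache_kv_c = jnp.zeros(
        (num_pages, page_size // packing, packing, padded_lkv_dim), dtype)
    cache_k_pe = jnp.zeros(
        (num_pages, page_size // packing, packing, padded_r_dim), dtype)

    # Ragged-batch metadata, built the same way as in the test helper.
    kv_lens = jnp.array([kv_len], dtype=jnp.int32)
    page_indices = jnp.array(list(range(pages_per_seq)), dtype=jnp.int32)
    cu_q_lens = jnp.array([0, q_len], dtype=jnp.int32)
    distribution = jnp.array([0, 0, 1], dtype=jnp.int32)

    out, updated_kv_c, updated_k_pe = mla.mla_ragged_paged_attention(
        ql_nope, q_pe, new_kv_c, new_k_pe, cache_kv_c, cache_k_pe,
        kv_lens, page_indices, cu_q_lens, distribution, sm_scale=1.0)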

File tree

4 files changed: +1745, -0 lines

tests/kernels/mla_v1_test.py

Lines changed: 396 additions & 0 deletions
@@ -0,0 +1,396 @@
import jax
import jax.numpy as jnp
import numpy as np
from absl.testing import absltest, parameterized
from jax._src import test_util as jtu

import tpu_inference.kernels.mla.v1.kernel as mla
from tpu_inference.kernels.ragged_paged_attention.v3.util import (
    align_to, cdiv, get_dtype_packing)

jax.config.parse_flags_with_absl()


@jtu.with_config(jax_numpy_dtype_promotion="standard")
class MlaRaggedPagedAttentionKernelTest(jtu.JaxTestCase):

    def _test_mla_ragged_paged_attention(
        self,
        seq_lens,  # List[(q_len, kv_len)]
        num_heads,
        lkv_dim,
        r_dim,
        page_size,
        q_dtype,
        kv_dtype,
        num_pages,
        *,
        num_kv_pages_per_block=8,
        num_queries_per_block=8,
        vmem_limit_bytes=100 * 1024 * 1024,
        sm_scale=1.0,
        sliding_window: int | None = None,
        soft_cap: float | None = None,
    ):
        if not jtu.is_device_tpu_at_least(version=4):
            self.skipTest("Expect TPUv4+")
        rng = np.random.default_rng(1234)

        def gen_random(shape, dtype):
            return jnp.array(rng.random(size=shape,
                                        dtype=np.float32)).astype(dtype)

        padded_r_dim = align_to(r_dim, 128)
        padded_lkv_dim = align_to(lkv_dim, 128)
        packing = get_dtype_packing(kv_dtype)
        q_lens = [s[0] for s in seq_lens]
        kv_lens_list = [s[1] for s in seq_lens]
        total_q_len = sum(q_lens)
        cu_q_lens_list = [0]
        for q_len in q_lens:
            cu_q_lens_list.append(cu_q_lens_list[-1] + q_len)

        max_kv_len = max(kv_lens_list) if kv_lens_list else 0
        pages_per_seq = cdiv(max_kv_len, page_size)

        page_indices_list = []
        page_count = 0
        for kv_len in kv_lens_list:
            num_seq_pages = cdiv(kv_len, page_size)
            indices = list(range(page_count, page_count + num_seq_pages))
            page_indices_list.extend(indices + [-1] *
                                     (pages_per_seq - num_seq_pages))
            page_count += num_seq_pages

        total_num_pages = max(num_pages, page_count)

        ql_nope = gen_random((total_q_len, num_heads, lkv_dim), q_dtype)
        q_pe = gen_random((total_q_len, num_heads, r_dim), q_dtype)
        new_kv_c = gen_random((total_q_len, lkv_dim), kv_dtype)
        new_k_pe = gen_random((total_q_len, r_dim), kv_dtype)

        cache_kv_c = gen_random(
            (total_num_pages, page_size // packing, packing, padded_lkv_dim),
            kv_dtype,
        )
        cache_k_pe = gen_random(
            (total_num_pages, page_size // packing, packing, padded_r_dim),
            kv_dtype)
        kv_lens = jnp.array(kv_lens_list, dtype=jnp.int32)
        page_indices = jnp.array(page_indices_list, dtype=jnp.int32)
        cu_q_lens = jnp.array(cu_q_lens_list, dtype=jnp.int32)
        distribution = jnp.array([0, 0, len(seq_lens)], dtype=jnp.int32)

        ql_nope_for_kernel = ql_nope.copy()
        q_pe_for_kernel = q_pe.copy()

        expected_out, expected_updated_kv_c, expected_updated_k_pe = (
            mla.ref_mla_ragged_paged_attention(
                ql_nope,
                q_pe,
                new_kv_c,
                new_k_pe,
                cache_kv_c.copy(),
                cache_k_pe.copy(),
                kv_lens,
                page_indices,
                cu_q_lens,
                distribution,
                sm_scale=sm_scale,
                sliding_window=sliding_window,
                soft_cap=soft_cap,
            ))

        kernel_out, kernel_updated_kv_c, kernel_updated_k_pe = (
            mla.mla_ragged_paged_attention(
                ql_nope_for_kernel,
                q_pe_for_kernel,
                new_kv_c,
                new_k_pe,
                cache_kv_c.copy(),
                cache_k_pe.copy(),
                kv_lens,
                page_indices,
                cu_q_lens,
                distribution,
                sm_scale=sm_scale,
                sliding_window=sliding_window,
                soft_cap=soft_cap,
                num_kv_pages_per_block=num_kv_pages_per_block,
                num_queries_per_block=num_queries_per_block,
                vmem_limit_bytes=vmem_limit_bytes,
            ))

        self.assertEqual(expected_out.shape,
                         (total_q_len, num_heads, padded_lkv_dim))
        self.assertEqual(
            expected_updated_kv_c.shape,
            (total_num_pages, page_size // packing, packing, padded_lkv_dim),
        )
        self.assertEqual(
            expected_updated_k_pe.shape,
            (total_num_pages, page_size // packing, packing, padded_r_dim),
        )
        self.assertEqual(expected_out.dtype, kv_dtype)
        self.assertEqual(expected_updated_kv_c.dtype, kv_dtype)
        self.assertEqual(expected_updated_k_pe.dtype, kv_dtype)

        self.assertAllClose(expected_out, kernel_out, atol=0.2, rtol=0.2)
        self.assertAllClose(expected_updated_kv_c,
                            kernel_updated_kv_c,
                            atol=0.2,
                            rtol=0.2)
        self.assertAllClose(expected_updated_k_pe,
                            kernel_updated_k_pe,
                            atol=0.2,
                            rtol=0.2)

    def test_ragged_paged_attention_basic(self):
        dtype = jnp.bfloat16
        seq_lens = [(192, 328), (128, 180), (64, 255)]
        num_heads = 128
        lkv_dim = 512
        r_dim = 64
        page_size = 16
        num_pages = 1000

        self._test_mla_ragged_paged_attention(
            seq_lens,
            num_heads,
            lkv_dim,
            r_dim,
            page_size,
            dtype,
            dtype,
            num_pages,
        )

    @parameterized.product(dtype=[jnp.bfloat16], )
    def test_ragged_paged_attention_decode_only(self, dtype):
        seq_lens = [
            (1, 18),
            (1, 129),
            (1, 597),
            (1, 122),
            (1, 64),
            (1, 322),
            (1, 463),
            (1, 181),
            (1, 1107),
            (1, 123),
            (1, 31),
            (1, 18),
            (1, 1229),
            (1, 229),
            (1, 87),
            (1, 1328),
        ]
        num_heads = 128
        lkv_dim = 512
        r_dim = 64
        page_size = 16
        num_pages = 1000

        self._test_mla_ragged_paged_attention(
            seq_lens,
            num_heads,
            lkv_dim,
            r_dim,
            page_size,
            dtype,
            dtype,
            num_pages,
        )

    @parameterized.product(dtype=[jnp.bfloat16], )
    def test_ragged_paged_attention_prefill_only(self, dtype):
        seq_lens = [
            (5, 18),
            (15, 129),
            (120, 597),
            (100, 122),
            (21, 64),
            (32, 322),
            (251, 463),
            (40, 181),
            (64, 1107),
            (99, 123),
            (10, 31),
            (5, 18),
            (3, 1229),
            (120, 229),
            (9, 87),
            (2, 1328),
        ]
        num_heads = 128
        lkv_dim = 512
        r_dim = 64
        page_size = 16
        num_pages = 1000

        self._test_mla_ragged_paged_attention(
            seq_lens,
            num_heads,
            lkv_dim,
            r_dim,
            page_size,
            dtype,
            dtype,
            num_pages,
        )

    @parameterized.product(dtype=[jnp.bfloat16], )
    def test_ragged_paged_attention_mixed(self, dtype):
        seq_lens = [
            (5, 18),
            (1, 129),
            (120, 597),
            (1, 122),
            (1, 64),
            (32, 322),
            (251, 463),
            (1, 181),
            (1, 1107),
            (99, 123),
            (1, 31),
            (5, 18),
            (3, 1229),
            (117, 229),
            (1, 87),
            (1, 1328),
        ]
        num_heads = 128
        lkv_dim = 512
        r_dim = 64
        page_size = 16
        num_pages = 1000

        self._test_mla_ragged_paged_attention(
            seq_lens,
            num_heads,
            lkv_dim,
            r_dim,
            page_size,
            dtype,
            dtype,
            num_pages,
        )

    @parameterized.product(sliding_window=[None, 5, 128], )
    def test_ragged_paged_attention_sliding_window(
        self,
        sliding_window: int | None,
    ):
        num_seqs = 5
        num_heads = 128
        lkv_dim = 512
        r_dim = 64
        dtype = jnp.float32
        rng = np.random.default_rng(1234)
        q_lens = rng.integers(1, 100, num_seqs)
        kv_lens = q_lens + rng.integers(0, 50, num_seqs)
        seq_lens = list(zip(q_lens.tolist(), kv_lens.tolist()))
        page_size = 16
        num_pages = 1000

        self._test_mla_ragged_paged_attention(
            seq_lens,
            num_heads,
            lkv_dim,
            r_dim,
            page_size,
            dtype,
            dtype,
            num_pages,
            sliding_window=sliding_window,
        )

    @parameterized.product(soft_cap=[None, 50.0], )
    def test_ragged_paged_attention_logit_soft_capping(
        self,
        soft_cap: float | None,
    ):
        num_heads = 128
        num_seqs = 2
        dtype = jnp.float32
        rng = np.random.default_rng(1234)
        q_lens = rng.integers(1, 100, num_seqs)
        kv_lens = q_lens + rng.integers(0, 50, num_seqs)
        seq_lens = list(zip(q_lens.tolist(), kv_lens.tolist()))
        lkv_dim = 512
        r_dim = 64
        page_size = 16
        num_pages = 1000

        self._test_mla_ragged_paged_attention(
            seq_lens,
            num_heads,
            lkv_dim,
            r_dim,
            page_size,
            dtype,
            dtype,
            num_pages,
            soft_cap=soft_cap,
        )

    def test_ragged_paged_attention_sliding_window_should_be_positive(self):
        dtype = jnp.float32
        seq_lens = [(192, 328), (128, 180), (64, 255)]
        num_heads = 128
        lkv_dim = 512
        r_dim = 64
        page_size = 16
        num_pages = 1000

        with self.assertRaisesRegex(ValueError, "must be positive"):
            self._test_mla_ragged_paged_attention(
                seq_lens,
                num_heads,
                lkv_dim,
                r_dim,
                page_size,
                dtype,
                dtype,
                num_pages,
                sliding_window=0,
            )

        with self.assertRaisesRegex(ValueError, "must be positive"):
            self._test_mla_ragged_paged_attention(
                seq_lens,
                num_heads,
                lkv_dim,
                r_dim,
                page_size,
                dtype,
                dtype,
                num_pages,
                sliding_window=-1,
            )

    def test_ragged_paged_attention_soft_cap_cannot_be_zero(self):
        dtype = jnp.float32
        seq_lens = [(192, 328), (128, 180), (64, 255)]
        num_heads = 128
        lkv_dim = 512
        r_dim = 64
        page_size = 16
        num_pages = 1000

        with self.assertRaisesRegex(ValueError, "must not be 0.0"):
            self._test_mla_ragged_paged_attention(
                seq_lens,
                num_heads,
                lkv_dim,
                r_dim,
                page_size,
                dtype,
                dtype,
                num_pages,
                soft_cap=0.0,
            )


if __name__ == "__main__":
    absltest.main(testLoader=jtu.JaxTestLoader())
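
For readers skimming the diff, the paging metadata that _test_mla_ragged_paged_attention feeds the kernel can be spelled out for the test_ragged_paged_attention_basic case. The sketch below simply replays the helper's own cdiv/cumulative-sum logic; the exact semantics of distribution are not documented in this diff, so the comment only restates how the test builds it.

    from tpu_inference.kernels.ragged_paged_attention.v3.util import cdiv

    seq_lens = [(192, 328), (128, 180), (64, 255)]  # (q_len, kv_len) per sequence
    page_size = 16

    cu_q_lens = [0]
    for q_len, _ in seq_lens:
        cu_q_lens.append(cu_q_lens[-1] + q_len)
    # cu_q_lens == [0, 192, 320, 384]; total_q_len == 384

    pages_per_seq = cdiv(max(kv for _, kv in seq_lens), page_size)  # 21
    page_indices, page_count = [], 0
    for _, kv_len in seq_lens:
        n = cdiv(kv_len, page_size)                 # 21, 12, 16 pages per sequence
        page_indices += list(range(page_count, page_count + n))
        page_indices += [-1] * (pages_per_seq - n)  # pad each row to pages_per_seq
        page_count += n                             # 49 pages used in total

    distribution = [0, 0, len(seq_lens)]            # [0, 0, 3], as the helper builds it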

tpu_inference/kernels/mla/__init__.py

Whitespace-only changes.

tpu_inference/kernels/mla/v1/__init__.py

Whitespace-only changes.
