
Commit 068e169

fix comments

Signed-off-by: Chengji Yao <chengjiyao@google.com>
1 parent bb7fc2d commit 068e169

File tree

5 files changed: +55 −45 lines changed

tests/v1/tpu/test_kv_cache_update_kernel.py

Lines changed: 6 additions & 5 deletions
@@ -15,9 +15,9 @@
 @pytest.mark.parametrize("page_size", [32, 33])
 @pytest.mark.parametrize("combined_kv_head_num", [2, 16])
 @pytest.mark.parametrize("head_dim", [128, 256])
-@pytest.mark.parametrize("kernel_block_size", [4, 8])
+@pytest.mark.parametrize("num_slices_per_block", [4, 8])
 def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int,
-                                head_dim: int, kernel_block_size: int):
+                                head_dim: int, num_slices_per_block: int):
     page_num = 1000
     padded_num_tokens = 128
     kv_cache_cpu = torch.zeros(
@@ -42,11 +42,12 @@ def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int,
                                  np.cumsum(slice_lens[:-1])])
     slot_mapping = np.stack(
         [kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1)
-    padded_size = (slot_mapping.shape[0] + kernel_block_size -
-                   1) // kernel_block_size * kernel_block_size
+    padded_size = (slot_mapping.shape[0] + num_slices_per_block -
+                   1) // num_slices_per_block * num_slices_per_block
     slot_mapping = np.pad(slot_mapping,
                           [[0, padded_size - slot_mapping.shape[0]], [0, 0]],
                           constant_values=0)
+    slot_mapping = np.transpose(slot_mapping)
     slot_mapping_cpu = torch.tensor(slot_mapping,
                                     device="cpu",
                                     dtype=torch.int32)
@@ -56,7 +57,7 @@ def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int,
     torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True)
     new_kv_cache_xla = torch.ops.xla.kv_cache_update_op(
         new_kv_xla, slot_mapping_xla, kv_cache_xla, page_size,
-        kernel_block_size)
+        num_slices_per_block)
     kv_cache_xla.copy_(new_kv_cache_xla)
     torch_xla.sync()
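For readers skimming the test change: the kernel now expects `slot_mapping` padded to a multiple of `num_slices_per_block` and transposed to a `[3, num_slices]` layout. A minimal standalone sketch of that preprocessing (plain NumPy; the sizes below are made up, not taken from the test):

```python
import numpy as np

num_slices_per_block = 8  # must match the kernel's static argument
# One row per slice: (kv_cache_start, new_kv_start, slice_len); toy values.
slot_mapping = np.array([[0, 0, 3], [32, 3, 5], [64, 8, 2]], dtype=np.int32)

# Round the slice count up to a multiple of num_slices_per_block,
padded_size = (slot_mapping.shape[0] + num_slices_per_block -
               1) // num_slices_per_block * num_slices_per_block
# pad with all-zero rows (zero-length slices),
slot_mapping = np.pad(slot_mapping,
                      [[0, padded_size - slot_mapping.shape[0]], [0, 0]],
                      constant_values=0)
# and transpose to the [3, num_slices] layout the kernel now reads.
slot_mapping = np.transpose(slot_mapping)
assert slot_mapping.shape == (3, padded_size)
```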

tests/v1/tpu/test_pallas.py

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ class FakeAttentionLayer:
             context_lens=context_lens,
             query_start_loc=query_start_loc,
             num_seqs=num_seqs,
-            kv_cache_update_block_size=8,
+            num_slices_per_kv_cache_update_block=8,
         )

         with patch("torch.ops.xla.ragged_paged_attention"

vllm/attention/ops/pallas_kv_cache_update.py

Lines changed: 24 additions & 19 deletions
@@ -10,25 +10,28 @@

 def _kv_cache_update_kernel(
     # Prefetch
-    slices_ref,  # [num_slices, 3]
+    slices_ref,  # [3, num_slices], list of (kv_cache_start, new_kv_start,
+    # slice_len)
     # Input
-    new_kv_hbm_ref,  # [tokens, num_combined_kv_heads, head_dim]
-    kv_cache_hbm_ref,
+    new_kv_hbm_ref,  # [num_tokens, num_combined_kv_heads, head_dim]
+    kv_cache_hbm_ref,  # [total_num_pages * page_size, num_combined_kv_heads,
+    # head_dim]
     # Output
     _,  # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
     # Scratch
-    scratch,  # [block_size, page_size, num_combined_kv_heads, head_dim]
+    scratch,  # [num_slices_per_block, page_size, num_combined_kv_heads,
+    # head_dim]
     sem,
 ):
     async_copies = []
     block_idx = pl.program_id(0)
-    block_size = scratch.shape[0]
+    num_slices_per_block = scratch.shape[0]

     # Copy from new_kv_hbm_ref to scratch
-    for i in range(block_size):
-        offset_i = i + block_idx * block_size
-        new_kv_start = slices_ref[offset_i, 1]
-        length = slices_ref[offset_i, 2]
+    for i in range(num_slices_per_block):
+        offset_i = i + block_idx * num_slices_per_block
+        new_kv_start = slices_ref[1, offset_i]
+        length = slices_ref[2, offset_i]
         async_copy = pltpu.make_async_copy(
             new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...],
             scratch.at[i, pl.ds(0, length), ...],
@@ -42,10 +45,10 @@ def _kv_cache_update_kernel(

     # Copy from scratch to kv_cache_hbm_ref
     async_copies.clear()
-    for i in range(block_size):
-        offset_i = i + block_idx * block_size
-        kv_cache_start = slices_ref[offset_i, 0]
-        length = slices_ref[offset_i, 2]
+    for i in range(num_slices_per_block):
+        offset_i = i + block_idx * num_slices_per_block
+        kv_cache_start = slices_ref[0, offset_i]
+        length = slices_ref[2, offset_i]
         async_copy = pltpu.make_async_copy(
             scratch.at[i, pl.ds(0, length), ...],
             kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...],
@@ -59,23 +62,25 @@ def _kv_cache_update_kernel(

 @functools.partial(
     jax.jit,
-    static_argnames=["page_size", "block_size"],
+    static_argnames=["page_size", "num_slices_per_block"],
 )
 def kv_cache_update(
     new_kv: jax.Array,  # [total_num_token, num_combined_kv_heads, head_dim]
     slices: jax.
-    Array,  # [num_slices, 3], list of (kv_cache_start, new_kv_start, slice_len)
+    Array,  # [3, num_slices], list of (kv_cache_start, new_kv_start, slice_len)
     kv_cache: jax.
     Array,  # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
     *,
     page_size: int = 32,
-    block_size: int = 8,
+    num_slices_per_block: int = 8,
 ):
-    assert slices.shape[0] % block_size == 0
+    assert slices.shape[1] % num_slices_per_block == 0
     _, num_combined_kv_heads, head_dim = new_kv.shape
     assert kv_cache.shape[1] == num_combined_kv_heads
     assert kv_cache.shape[2] == head_dim
     assert head_dim % 128 == 0
+    # TODO: Add a dynamic check to make sure that all the slice lengths are
+    # smaller than or equal to page_size

     in_specs = [
         pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
@@ -87,7 +92,7 @@ def kv_cache_update(

     scalar_prefetches = [slices]
     scratch = pltpu.VMEM(
-        (block_size, page_size, num_combined_kv_heads, head_dim),
+        (num_slices_per_block, page_size, num_combined_kv_heads, head_dim),
         new_kv.dtype,
     )

@@ -102,7 +107,7 @@ def kv_cache_update(
             num_scalar_prefetch=len(scalar_prefetches),
             in_specs=in_specs,
             out_specs=out_specs,
-            grid=(slices.shape[0] // block_size, ),
+            grid=(slices.shape[1] // num_slices_per_block, ),
             scratch_shapes=scratch_shapes,
         ),
         out_shape=out_shape,
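To summarize the kernel's effect independent of Pallas: each grid step stages `num_slices_per_block` slices through the VMEM scratch buffer, then writes them into the cache. A pure-NumPy reference of the end-to-end semantics, assuming the transposed `[3, num_slices]` layout (the helper name below is ours, not part of the diff):

```python
import numpy as np

def kv_cache_update_reference(new_kv, slices, kv_cache):
    """Reference semantics only (hypothetical helper, not in the diff).

    new_kv:   [total_num_tokens, num_combined_kv_heads, head_dim]
    slices:   [3, num_slices], rows = (kv_cache_start, new_kv_start, slice_len)
    kv_cache: [total_num_pages * page_size, num_combined_kv_heads, head_dim]
    """
    kv_cache = kv_cache.copy()
    # Each column of `slices` describes one contiguous token-range copy.
    for kv_cache_start, new_kv_start, length in slices.T:
        kv_cache[kv_cache_start:kv_cache_start + length] = \
            new_kv[new_kv_start:new_kv_start + length]
    return kv_cache
```

Note that the kernel reads each triple from scalar-prefetch memory as `slices_ref[row, offset_i]`, which is why the layout change touches every indexing site above.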

vllm/v1/attention/backends/pallas.py

Lines changed: 14 additions & 14 deletions
@@ -111,7 +111,7 @@ class PallasMetadata:
     context_lens: torch.Tensor
     query_start_loc: torch.Tensor
     num_seqs: torch.Tensor
-    kv_cache_update_block_size: int
+    num_slices_per_kv_cache_update_block: int


 class PallasAttentionBackendImpl(AttentionImpl):
@@ -217,10 +217,9 @@ def forward(
             # Write input keys and values to the KV cache.
             # Skip this if sharing KV cache with an earlier attention layer.
             slot_mapping = attn_metadata.slot_mapping
-            kv_cache_update_block_size = \
-                attn_metadata.kv_cache_update_block_size
-            write_to_kv_cache(key, value, kv_cache, slot_mapping,
-                              kv_cache_update_block_size)
+            write_to_kv_cache(
+                key, value, kv_cache, slot_mapping,
+                attn_metadata.num_slices_per_kv_cache_update_block)

         output = torch.ops.xla.ragged_paged_attention(
             query,
@@ -252,15 +251,15 @@ def write_to_kv_cache(
     value: torch.Tensor,
     kv_cache: torch.Tensor,
     slot_mapping: torch.Tensor,
-    kv_cache_update_block_size: int,
+    num_slices_per_kv_cache_update_block: int,
 ) -> None:
     """ Write the key and values to the KV cache.

     Args:
         key: shape = [num_tokens, num_kv_heads * head_size]
         value: shape = [num_tokens, num_kv_heads * head_size]
         kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
-        kv_cache_update_block_size: int
+        num_slices_per_kv_cache_update_block: int
     """
     _, page_size, num_combined_kv_heads, head_size = kv_cache.shape
     head_size = cdiv(head_size,
@@ -272,39 +271,40 @@ def write_to_kv_cache(

     kv_cache = kv_cache.flatten(0, 1)
     new_kv_cache = torch.ops.xla.kv_cache_update_op(
-        kv, slot_mapping, kv_cache, page_size, kv_cache_update_block_size)
+        kv, slot_mapping, kv_cache, page_size,
+        num_slices_per_kv_cache_update_block)
     # NOTE: the in-place copy will be optimized away by XLA compiler.
     kv_cache.copy_(new_kv_cache)


 @requires_jax
 def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
                             kv_cache: torch.Tensor, page_size: int,
-                            block_size: int):
+                            num_slices_per_block: int):
     from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
     new_kv_cache = xb.call_jax(kv_cache_update, (kv, slot_mapping, kv_cache), {
         "page_size": page_size,
-        "block_size": block_size
+        "num_slices_per_block": num_slices_per_block
     })
     return new_kv_cache


 XLA_LIB.define(
     "kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache, "
-    "int page_size, int block_size) -> Tensor", )
+    "int page_size, int num_slices_per_block) -> Tensor", )


 @impl(XLA_LIB, "kv_cache_update_op", "XLA")
 def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
                            kv_cache: torch.Tensor, page_size: int,
-                           block_size: int) -> torch.Tensor:
+                           num_slices_per_block: int) -> torch.Tensor:
     new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache,
-                                           page_size, block_size)
+                                           page_size, num_slices_per_block)
     return new_kv_cache


 @impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
 def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
                                kv_cache: torch.Tensor, page_size: int,
-                               block_size: int) -> torch.Tensor:
+                               num_slices_per_block: int) -> torch.Tensor:
     return kv_cache
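A call-site sketch for the renamed op, with illustrative shapes only (it assumes the `XLA_LIB` registrations above have been imported; on a non-XLA backend the `CompositeExplicitAutograd` registration simply returns `kv_cache` unchanged, which keeps tracing and shape propagation working):

```python
import torch

page_size, num_slices_per_block = 32, 8
num_combined_kv_heads, head_size = 8, 128

kv = torch.zeros(128, num_combined_kv_heads, head_size)
slot_mapping = torch.zeros(3, 24, dtype=torch.int32)  # [3, padded_num_slices]
kv_cache = torch.zeros(1000 * page_size, num_combined_kv_heads, head_size)

# On a TPU/XLA device this dispatches to kv_cache_update_op_xla.
new_kv_cache = torch.ops.xla.kv_cache_update_op(
    kv, slot_mapping, kv_cache, page_size, num_slices_per_block)
kv_cache.copy_(new_kv_cache)  # the in-place copy is optimized away by XLA
```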

vllm/v1/worker/tpu_model_runner.py

Lines changed: 10 additions & 6 deletions
@@ -57,7 +57,7 @@
 # Smallest output size
 MIN_NUM_SEQS = 8
 # Block size used for kv cache updating kernel
-KV_CACHE_UPDATE_KERNEL_BLOCK_SIZE = 8
+NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK = 8


 #########################################################
@@ -720,6 +720,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput",
             slot_mapping_metadata,
             [[0, padded_num_slices - len(slot_mapping_metadata)], [0, 0]],
             constant_values=0)
+        slot_mapping_metadata = np.transpose(slot_mapping_metadata)
         slot_mapping_metadata = torch.tensor(slot_mapping_metadata,
                                              device=self.device)

@@ -742,7 +743,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput",
             num_seqs=torch.tensor([num_reqs],
                                   dtype=torch.int32,
                                   device=self.device),
-            kv_cache_update_block_size=KV_CACHE_UPDATE_KERNEL_BLOCK_SIZE,
+            num_slices_per_kv_cache_update_block=
+            NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
         )
         # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
         # request in the batch. While we should not sample any token from this
@@ -1170,7 +1172,7 @@ def _dummy_run(self, num_tokens: int, num_reqs: int,
                                        dtype=torch.int32).to(self.device)
         padded_num_slices = _get_padded_num_kv_cache_update_slices(
             num_tokens, self.max_num_reqs, self.block_size)
-        slot_mapping = torch.zeros((padded_num_slices, 3),
+        slot_mapping = torch.zeros((3, padded_num_slices),
                                    dtype=torch.int32).to(self.device)
         block_tables = torch.zeros(
             (num_reqs, num_blocks),
@@ -1190,7 +1192,8 @@ def _dummy_run(self, num_tokens: int, num_reqs: int,
             context_lens=context_lens,
             query_start_loc=query_start_loc,
             num_seqs=num_seqs,
-            kv_cache_update_block_size=KV_CACHE_UPDATE_KERNEL_BLOCK_SIZE,
+            num_slices_per_kv_cache_update_block=
+            NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
         )

         if self.is_multimodal_model:
@@ -1802,8 +1805,9 @@ def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
     padded_num_slices = 2 * max_num_reqs + num_tokens // page_size
     padded_num_slices = min(padded_num_slices, num_tokens)
     padded_num_slices = (
-        padded_num_slices + KV_CACHE_UPDATE_KERNEL_BLOCK_SIZE - 1
-    ) // KV_CACHE_UPDATE_KERNEL_BLOCK_SIZE * KV_CACHE_UPDATE_KERNEL_BLOCK_SIZE
+        padded_num_slices + NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK - 1
+    ) // NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK * \
+        NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK
     return padded_num_slices
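A quick worked example of that padding formula (toy numbers, as a standalone sketch; the real helper reads the module-level constant rather than taking a parameter):

```python
def padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
                                      page_size: int,
                                      num_slices_per_block: int = 8) -> int:
    # Mirrors _get_padded_num_kv_cache_update_slices from the diff.
    padded_num_slices = 2 * max_num_reqs + num_tokens // page_size
    padded_num_slices = min(padded_num_slices, num_tokens)
    # Round up to a multiple of num_slices_per_block so the kernel grid
    # divides evenly.
    return (padded_num_slices + num_slices_per_block -
            1) // num_slices_per_block * num_slices_per_block

# e.g. 128 tokens, 8 requests, page size 32:
# 2 * 8 + 128 // 32 = 20 slices, min(20, 128) = 20, rounded up to 24.
assert padded_num_kv_cache_update_slices(128, 8, 32) == 24
```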