20 changes: 6 additions & 14 deletions tests/kernels/ragged_kv_cache_update_test.py
@@ -21,8 +21,7 @@ def kv_cache_update_ref(new_kv, slot_mapping, kv_cache):
@jtu.with_config(jax_numpy_dtype_promotion="standard")
class KVCacheUpdateTest(jtu.JaxTestCase):

def _generate_data(self, page_size, combined_kv_head_num, head_dim,
num_slices_per_block):
def _generate_data(self, page_size, combined_kv_head_num, head_dim):
page_num = 20
padded_num_tokens = 128
prng_key = jax.random.key(1234)
@@ -45,12 +44,6 @@ def _generate_data(self, page_size, combined_kv_head_num, head_dim,
np.cumsum(slice_lens[:-1])])
slot_mapping_np = np.stack(
[kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1)
padded_size = (slot_mapping_np.shape[0] + num_slices_per_block -
1) // num_slices_per_block * num_slices_per_block
slot_mapping_np = np.pad(
slot_mapping_np,
[[0, padded_size - slot_mapping_np.shape[0]], [0, 0]],
constant_values=0)
slot_mapping_np = np.transpose(slot_mapping_np)
slot_mapping = jnp.array(slot_mapping_np, dtype=jnp.int32)
return new_kv, slot_mapping, kv_cache, num_slices
@@ -59,14 +52,14 @@ def _generate_data(self, page_size, combined_kv_head_num, head_dim,
page_size=[32, 33],
combined_kv_head_num=[2, 16],
head_dim=[128, 256],
num_slices_per_block=[4, 8],
num_slices_per_block=[None, 8],
dynamic_validate_inputs=[False, True],
)
def test_basic(self, page_size: int, combined_kv_head_num: int,
head_dim: int, num_slices_per_block: int,
dynamic_validate_inputs: bool):
new_kv, slot_mapping, kv_cache, num_slices = self._generate_data(
page_size, combined_kv_head_num, head_dim, num_slices_per_block)
page_size, combined_kv_head_num, head_dim)
old_kv_cache_copy = kv_cache.copy()

with jax.disable_jit(disable=dynamic_validate_inputs):
@@ -90,12 +83,12 @@ def test_basic(self, page_size: int, combined_kv_head_num: int,
page_size=[32, 33],
combined_kv_head_num=[16, 32],
head_dim=[128, 256],
num_slices_per_block=[4, 8],
num_slices_per_block=[None, 8],
)
def test_torchax_shard_map(self, page_size: int, combined_kv_head_num: int,
head_dim: int, num_slices_per_block: int):
new_kv, slot_mapping, kv_cache, num_slices = self._generate_data(
page_size, combined_kv_head_num, head_dim, num_slices_per_block)
page_size, combined_kv_head_num, head_dim)
old_kv_cache_copy = kv_cache.copy()

mesh = Mesh(jax.devices(), 'x')
@@ -127,10 +120,9 @@ def test_invalid_inputs(self):
page_size = 32
combined_kv_head_num = 2
head_dim = 128
num_slices_per_block = 4

new_kv, slot_mapping, kv_cache, num_slices = self._generate_data(
page_size, combined_kv_head_num, head_dim, num_slices_per_block)
page_size, combined_kv_head_num, head_dim)

with jax.disable_jit():
# Case 1: new_kv_start < 0
6 changes: 2 additions & 4 deletions tests/models/vllm/test_pallas_torchax.py
@@ -6,8 +6,8 @@
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig

from tpu_commons.attention.backends.pallas_torchax import (
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK, PallasAttentionBackend,
PallasAttentionBackendImpl, PallasMetadata, write_to_kv_cache)
PallasAttentionBackend, PallasAttentionBackendImpl, PallasMetadata,
write_to_kv_cache)


class TestPallasMetadata:
@@ -480,8 +480,6 @@ def test_write_to_kv_cache(mock_kv_cache_update, mock_call_jax):
args, kwargs = mock_call_jax.call_args
assert args[0] == mock_kv_cache_update
assert kwargs['page_size'] == 16
assert kwargs[
'num_slices_per_block'] == NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK


def test_write_to_kv_cache_tensor_shapes():
3 changes: 1 addition & 2 deletions tests/worker/test_tpu_worker_torchax.py
@@ -261,8 +261,7 @@ def test_init_device(self, mock_envs, mock_os, mock_torch, mock_jax,
else:
mock_report_usage_stats.assert_not_called()

@patch('tpu_commons.worker.tpu_worker_torchax.TPU_HEAD_SIZE_ALIGNMENT',
128)
@patch('tpu_commons.utils.TPU_HEAD_SIZE_ALIGNMENT', 128)
@patch('tpu_commons.worker.tpu_worker_torchax.jax')
@patch('tpu_commons.worker.tpu_worker_torchax.logger')
@pytest.mark.parametrize(
22 changes: 8 additions & 14 deletions tpu_commons/attention/backends/pallas_torchax.py
@@ -17,14 +17,10 @@
# Register custom op dispatcher.
from tpu_commons.models.torchax.torchax_wrapper import (kv_cache_update,
ragged_paged_attention)
from tpu_commons.utils import TPU_HEAD_SIZE_ALIGNMENT

logger = init_logger(__name__)

# TPU requires the head size to be a multiple of 128.
TPU_HEAD_SIZE_ALIGNMENT = 128
# Block size used for kv cache updating kernel
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK = 8


class PallasAttentionBackend(AttentionBackend):

@@ -233,7 +229,7 @@ def forward(
# these can be manually adjusted for debugging if necessary.
num_kv_pages_per_block=None,
num_queries_per_block=None,
vmem_limit_bytes=None,
vmem_limit_bytes=100 * 1024 * 1024,
Review comment (bythew3i, Collaborator, Aug 1, 2025):
    Nit: better to test the throughput before and after this change. The kernel itself won't be affected, but the next op's prefetch will be.
Reply (PR author):
    It would crash with a VMEM OOM before this change.
use_kernel=True,
sm_scale=self.scale,
sliding_window=self.sliding_window,
Expand Down Expand Up @@ -270,14 +266,12 @@ def write_to_kv_cache(key: torch.Tensor, value: torch.Tensor,
head_size)

kv_cache = kv_cache.reshape(-1, num_combined_kv_heads, head_size)
kv_cache = call_jax(
kv_cache_update,
kv,
slot_mapping,
kv_cache,
num_slices,
page_size=block_size,
num_slices_per_block=NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK)
kv_cache = call_jax(kv_cache_update,
kv,
slot_mapping,
kv_cache,
num_slices,
page_size=block_size)
kv_cache = kv_cache.reshape(num_blocks, block_size, num_combined_kv_heads,
head_size)
return kv_cache
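Side note on the constant that moved in this file: TPU_HEAD_SIZE_ALIGNMENT is now imported from tpu_commons.utils instead of being defined locally. Below is a minimal sketch of the padding it implies; pad_head_size is a hypothetical helper written for illustration, not part of the PR.

# Hypothetical illustration of the 128-element head-size alignment on TPU.
from tpu_commons.utils import TPU_HEAD_SIZE_ALIGNMENT  # 128

def pad_head_size(head_size: int) -> int:
    # Round up to the next multiple of TPU_HEAD_SIZE_ALIGNMENT.
    return -(-head_size // TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT

print(pad_head_size(64))   # 128
print(pad_head_size(128))  # 128
print(pad_head_size(192))  # 256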
58 changes: 56 additions & 2 deletions tpu_commons/kernels/ragged_kv_cache_update.py
Review comment (Collaborator):
    I would suggest having autotuning and benchmarking in the Google workspace internally first, like we do for RPA and quantized_matmul.
Reply (PR author):
    Thanks for the suggestion! Since I'd like to bring the KV cache update kernel on par with the vLLM torch/xla path, can we add the auto-tuning and benchmarking internally at Google later?
@@ -4,11 +4,14 @@
import functools

import jax
from jax._src import dtypes
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P

from tpu_commons.utils import TPU_HEAD_SIZE_ALIGNMENT, get_dtype_packing


def _ceil_div(a, b):
assert b != 0
@@ -140,8 +143,8 @@ def _kv_cache_update(
page_size: int,
num_slices_per_block: int,
dynamic_validate_inputs: bool,
vmem_limit_bytes: int = 40 * 1024 * 1024,
Review comment (Collaborator):
    How is this calculated?
Reply (PR author):
    Basically I'd like 32 MB of VMEM available for the scratch buffer, and that is rounded up to 40 MB.
):
assert slices.shape[1] % num_slices_per_block == 0
new_token_num, num_combined_kv_heads, head_dim = new_kv.shape
assert kv_cache.shape[1] == num_combined_kv_heads
assert kv_cache.shape[2] == head_dim
@@ -180,11 +183,52 @@
),
out_shape=out_shape,
input_output_aliases={len(scalar_prefetches) + 1: 0},
compiler_params=pltpu.CompilerParams(
vmem_limit_bytes=vmem_limit_bytes, ),
)

return kernel(*scalar_prefetches, new_kv, kv_cache)[0]


def _prev_power_of_2(n: int) -> int:
"""The previous power of 2 (inclusive)"""
if n <= 0:
return 0
return 1 << (n.bit_length() - 1)


def _get_page_size_bytes(block_size: int, num_combined_kv_heads: int,
head_size: int, kv_cache_dtype) -> int:
"""Returns the size in bytes of one page of the KV cache."""
kv_cache_dtype_bit_size = dtypes.bit_width(kv_cache_dtype)
padded_head_size = _ceil_div(
head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT

# NOTE: for the implicit padding in XLA
packing = get_dtype_packing(kv_cache_dtype)
num_combined_kv_heads = _ceil_div(num_combined_kv_heads, packing) * packing

return block_size * num_combined_kv_heads * padded_head_size * kv_cache_dtype_bit_size // 8


def _get_num_slices_per_kv_cache_update_block(page_size_bytes: int,
vmem_limit_bytes: int) -> int:
"""Find the optimum number of slices to copy per Pallas program instance.
Increasing the number of slices copied in one instance of the kernel program
will increase HBM bandwidth utilization via more in-flight DMAs.
However, it will also use more VMEM, and experimentally, we observed
performance regression at 128 slices on v6e, likely due to running
out of scalar registers. Thus this function will limit the number of
slices to 64.
"""
# NOTE: We assume 1MB vmem is used for register spill and others
assert vmem_limit_bytes >= 1024 * 1024, "vmem_limit_bytes must be at least 1MB"
num_slices_per_block = (vmem_limit_bytes - 1024 * 1024) // page_size_bytes
assert num_slices_per_block > 0, "Number of slices should be positive"
num_slices_per_block = _prev_power_of_2(num_slices_per_block)
return min(num_slices_per_block, 64)


@functools.partial(
jax.jit,
static_argnames=[
@@ -201,12 +245,21 @@ def kv_cache_update(
num_slices: jax.Array, # [1]
*,
page_size: int = 32,
num_slices_per_block: int = 8,
num_slices_per_block: int | None = None,
mesh: Mesh | None = None,
kv_cache_pspec: P
| None = None, # Only sharding along head_dim is supported
dynamic_validate_inputs: bool = False,
vmem_limit_bytes: int = 40 * 1024 * 1024,
):
if num_slices_per_block is None:
_, num_combined_kv_heads, head_dim = new_kv.shape
page_size_bytes = _get_page_size_bytes(page_size,
num_combined_kv_heads, head_dim,
kv_cache.dtype)
num_slices_per_block = _get_num_slices_per_kv_cache_update_block(
page_size_bytes, vmem_limit_bytes)

if mesh is None:
return _kv_cache_update(new_kv, slices, kv_cache, num_slices,
page_size, num_slices_per_block,
@@ -224,6 +277,7 @@
page_size=page_size,
num_slices_per_block=num_slices_per_block,
dynamic_validate_inputs=dynamic_validate_inputs,
vmem_limit_bytes=vmem_limit_bytes,
),
mesh=mesh,
in_specs=in_specs,
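For intuition, here is a minimal standalone sketch of how the new default for num_slices_per_block falls out of the VMEM budget. It mirrors _get_page_size_bytes and _get_num_slices_per_kv_cache_update_block above; the bfloat16 cache and the shapes in the example are illustrative assumptions, not values taken from this PR.

# Standalone sketch of the num_slices_per_block heuristic added in this PR.
TPU_HEAD_SIZE_ALIGNMENT = 128

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def prev_power_of_2(n: int) -> int:
    return 1 << (n.bit_length() - 1) if n > 0 else 0

def page_size_bytes(block_size: int, num_combined_kv_heads: int,
                    head_size: int, dtype_bits: int) -> int:
    # Head size is padded to a multiple of 128; the head count is padded to
    # the dtype packing (32 bits per lane divided by the dtype width).
    padded_head = ceil_div(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
    packing = 32 // dtype_bits
    padded_heads = ceil_div(num_combined_kv_heads, packing) * packing
    return block_size * padded_heads * padded_head * dtype_bits // 8

def num_slices_per_block(page_bytes: int, vmem_limit_bytes: int) -> int:
    # Reserve ~1 MB for spills, fill the rest with in-flight page copies,
    # round down to a power of two, and cap at 64 slices.
    budget = (vmem_limit_bytes - 1024 * 1024) // page_bytes
    return min(prev_power_of_2(budget), 64)

pb = page_size_bytes(block_size=32, num_combined_kv_heads=16,
                     head_size=128, dtype_bits=16)   # bfloat16 cache
print(pb)                                            # 131072 bytes per page
print(num_slices_per_block(pb, 40 * 1024 * 1024))    # 64 (hits the cap)

With the default 40 MB VMEM limit, the heuristic saturates at the 64-slice cap for this configuration; a smaller VMEM budget or larger pages would lower it.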
20 changes: 7 additions & 13 deletions tpu_commons/models/jax/attention_interface.py
@@ -11,10 +11,6 @@
ragged_paged_attention
from tpu_commons.models.jax.attention_metadata import AttentionMetadata

# TODO(xiang): put this in attention metadata
# Block size used for kv cache updating kernel
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK = 8


def sharded_ragged_paged_attention(sm_scale: float,
mesh: Mesh,
@@ -114,14 +110,12 @@ def update_kv_cache(k: jax.Array, v: jax.Array, kv_cache: jax.Array,
kv = jnp.concat([k, v], axis=-1).reshape(T, K_2, H)

kv_cache = kv_cache.reshape(-1, K_2, H)
kv_cache = kv_cache_update(
kv,
slices,
kv_cache,
num_slices,
page_size=S,
num_slices_per_block=NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
mesh=mesh,
kv_cache_pspec=P(None, "model", None))
kv_cache = kv_cache_update(kv,
slices,
kv_cache,
num_slices,
page_size=S,
mesh=mesh,
kv_cache_pspec=P(None, "model", None))
kv_cache = kv_cache.reshape(L, S, K_2, H)
return kv_cache
2 changes: 1 addition & 1 deletion tpu_commons/models/torchax/torchax_wrapper.py
@@ -199,7 +199,7 @@ def _kv_cache_update(
num_slices: jax.Array, # [1]
*,
page_size: int = 32,
num_slices_per_block: int = 8,
num_slices_per_block: int = None,
) -> Array:
# TODO: Get rid of this wrapper and call from pallas.py directly. Need to
# find a better way to get mesh in pallas.py.
6 changes: 0 additions & 6 deletions tpu_commons/runner/jax/tpu_jax_runner.py
@@ -47,8 +47,6 @@
INVALID_TOKEN_ID = -1
# Smallest output size
MIN_NUM_SEQS = 8
# Block size used for kv cache updating kernel
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK = 8

DUMMY_METADATA = AttentionMetadata(
input_positions=[],
@@ -1192,8 +1190,4 @@ def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
recompilation."""
padded_num_slices = 2 * max_num_reqs + num_tokens // page_size
padded_num_slices = min(padded_num_slices, num_tokens)
padded_num_slices = (
padded_num_slices + NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK - 1
) // NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK * \
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK
return padded_num_slices
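Worth a quick numeric check of the simplified padding above, now that the round-up to NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK is gone (the kernel derives its own block size). The numbers below are hypothetical, purely for illustration.

# Hypothetical illustration of the simplified slice padding.
def padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
                                      page_size: int) -> int:
    # Rough upper bound: up to two partial-page slices per request plus one
    # slice per full page of new tokens, but never more slices than tokens.
    padded = 2 * max_num_reqs + num_tokens // page_size
    return min(padded, num_tokens)

print(padded_num_kv_cache_update_slices(num_tokens=1024, max_num_reqs=16,
                                        page_size=32))  # 2 * 16 + 32 = 64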
7 changes: 1 addition & 6 deletions tpu_commons/runner/tpu_torchax_runner.py
@@ -46,8 +46,7 @@
is_pin_memory_available)

from tpu_commons.attention.backends.pallas_torchax import (
PallasAttentionBackend, PallasMetadata,
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK)
PallasAttentionBackend, PallasMetadata)

from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
KVCacheConfig, KVCacheSpec,
@@ -1108,8 +1107,4 @@ def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
recompilation."""
padded_num_slices = 2 * max_num_reqs + num_tokens // page_size
padded_num_slices = min(padded_num_slices, num_tokens)
padded_num_slices = (
padded_num_slices + NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK - 1
) // NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK * \
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK
return padded_num_slices
9 changes: 8 additions & 1 deletion tpu_commons/utils.py
@@ -4,11 +4,13 @@
from typing import Any, List, Tuple

import jax
from jax._src import dtypes
from vllm import envs

from tpu_commons.logger import init_logger
from vllm import envs

GBYTES = 1024 * 1024 * 1024
TPU_HEAD_SIZE_ALIGNMENT = 128

_megacore = False
logger = init_logger(__name__)
@@ -99,3 +101,8 @@ def get_padded_num_heads(num_heads: int, sharding_size: int) -> int:
assert sharding_size % num_heads == 0
num_heads = sharding_size
return num_heads


def get_dtype_packing(dtype):
bits = dtypes.bit_width(dtype)
return 32 // bits
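A small usage sketch for the new helper, assuming the tpu_commons package from this PR is importable; the printed packing factors follow directly from 32 // bit_width.

import jax.numpy as jnp

from tpu_commons.utils import get_dtype_packing

print(get_dtype_packing(jnp.float32))   # 32 / 32 -> 1
print(get_dtype_packing(jnp.bfloat16))  # 32 / 16 -> 2
print(get_dtype_packing(jnp.int8))      # 32 / 8  -> 4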
2 changes: 1 addition & 1 deletion tpu_commons/worker/tpu_worker_torchax.py
@@ -21,7 +21,6 @@
init_distributed_environment)
from vllm.model_executor import set_random_seed
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
@@ -40,6 +39,7 @@
from tpu_commons.worker._temporary_vllm_compat import (
adapt_kv_cache_config_if_needed, adapt_scheduler_output_if_needed,
adapt_lora_request_if_needed)
from tpu_commons.utils import TPU_HEAD_SIZE_ALIGNMENT

logger = init_logger(__name__)
