diff --git a/tests/kernels/ragged_kv_cache_update_test.py b/tests/kernels/ragged_kv_cache_update_test.py
index f0a0189de..1269b513e 100644
--- a/tests/kernels/ragged_kv_cache_update_test.py
+++ b/tests/kernels/ragged_kv_cache_update_test.py
@@ -21,8 +21,7 @@ def kv_cache_update_ref(new_kv, slot_mapping, kv_cache):
 @jtu.with_config(jax_numpy_dtype_promotion="standard")
 class KVCacheUpdateTest(jtu.JaxTestCase):
 
-    def _generate_data(self, page_size, combined_kv_head_num, head_dim,
-                       num_slices_per_block):
+    def _generate_data(self, page_size, combined_kv_head_num, head_dim):
         page_num = 20
         padded_num_tokens = 128
         prng_key = jax.random.key(1234)
@@ -45,12 +44,6 @@ def _generate_data(self, page_size, combined_kv_head_num, head_dim,
                                                np.cumsum(slice_lens[:-1])])
         slot_mapping_np = np.stack(
             [kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1)
-        padded_size = (slot_mapping_np.shape[0] + num_slices_per_block -
-                       1) // num_slices_per_block * num_slices_per_block
-        slot_mapping_np = np.pad(
-            slot_mapping_np,
-            [[0, padded_size - slot_mapping_np.shape[0]], [0, 0]],
-            constant_values=0)
         slot_mapping_np = np.transpose(slot_mapping_np)
         slot_mapping = jnp.array(slot_mapping_np, dtype=jnp.int32)
         return new_kv, slot_mapping, kv_cache, num_slices
@@ -59,14 +52,14 @@ def _generate_data(self, page_size, combined_kv_head_num, head_dim,
         page_size=[32, 33],
         combined_kv_head_num=[2, 16],
         head_dim=[128, 256],
-        num_slices_per_block=[4, 8],
+        num_slices_per_block=[None, 8],
         dynamic_validate_inputs=[False, True],
     )
     def test_basic(self, page_size: int, combined_kv_head_num: int,
                    head_dim: int, num_slices_per_block: int,
                    dynamic_validate_inputs: bool):
         new_kv, slot_mapping, kv_cache, num_slices = self._generate_data(
-            page_size, combined_kv_head_num, head_dim, num_slices_per_block)
+            page_size, combined_kv_head_num, head_dim)
         old_kv_cache_copy = kv_cache.copy()
 
         with jax.disable_jit(disable=dynamic_validate_inputs):
@@ -90,12 +83,12 @@ def test_basic(self, page_size: int, combined_kv_head_num: int,
         page_size=[32, 33],
         combined_kv_head_num=[16, 32],
         head_dim=[128, 256],
-        num_slices_per_block=[4, 8],
+        num_slices_per_block=[None, 8],
     )
     def test_torchax_shard_map(self, page_size: int, combined_kv_head_num: int,
                                head_dim: int, num_slices_per_block: int):
         new_kv, slot_mapping, kv_cache, num_slices = self._generate_data(
-            page_size, combined_kv_head_num, head_dim, num_slices_per_block)
+            page_size, combined_kv_head_num, head_dim)
         old_kv_cache_copy = kv_cache.copy()
 
         mesh = Mesh(jax.devices(), 'x')
@@ -127,10 +120,9 @@ def test_invalid_inputs(self):
         page_size = 32
         combined_kv_head_num = 2
         head_dim = 128
-        num_slices_per_block = 4
         new_kv, slot_mapping, kv_cache, num_slices = self._generate_data(
-            page_size, combined_kv_head_num, head_dim, num_slices_per_block)
+            page_size, combined_kv_head_num, head_dim)
         old_kv_cache_copy = kv_cache.copy()
 
         with jax.disable_jit():
             # Case 1: new_kv_start < 0
diff --git a/tests/models/vllm/test_pallas_torchax.py b/tests/models/vllm/test_pallas_torchax.py
index 18b9ec6f4..0cc80326a 100644
--- a/tests/models/vllm/test_pallas_torchax.py
+++ b/tests/models/vllm/test_pallas_torchax.py
@@ -6,8 +6,8 @@
 from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
 
 from tpu_commons.attention.backends.pallas_torchax import (
-    NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK, PallasAttentionBackend,
-    PallasAttentionBackendImpl, PallasMetadata, write_to_kv_cache)
+    PallasAttentionBackend, PallasAttentionBackendImpl, PallasMetadata,
+    write_to_kv_cache)
 
 
 class TestPallasMetadata:
@@ -480,8 +480,6 @@ def test_write_to_kv_cache(mock_kv_cache_update, mock_call_jax):
     args, kwargs = mock_call_jax.call_args
     assert args[0] == mock_kv_cache_update
     assert kwargs['page_size'] == 16
-    assert kwargs[
-        'num_slices_per_block'] == NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK
 
 
 def test_write_to_kv_cache_tensor_shapes():
diff --git a/tests/worker/test_tpu_worker_torchax.py b/tests/worker/test_tpu_worker_torchax.py
index dd6ef03ab..3a5115297 100644
--- a/tests/worker/test_tpu_worker_torchax.py
+++ b/tests/worker/test_tpu_worker_torchax.py
@@ -261,8 +261,7 @@ def test_init_device(self, mock_envs, mock_os, mock_torch, mock_jax,
         else:
             mock_report_usage_stats.assert_not_called()
 
-    @patch('tpu_commons.worker.tpu_worker_torchax.TPU_HEAD_SIZE_ALIGNMENT',
-           128)
+    @patch('tpu_commons.utils.TPU_HEAD_SIZE_ALIGNMENT', 128)
     @patch('tpu_commons.worker.tpu_worker_torchax.jax')
     @patch('tpu_commons.worker.tpu_worker_torchax.logger')
     @pytest.mark.parametrize(
diff --git a/tpu_commons/attention/backends/pallas_torchax.py b/tpu_commons/attention/backends/pallas_torchax.py
index a88b99705..b47453c8e 100644
--- a/tpu_commons/attention/backends/pallas_torchax.py
+++ b/tpu_commons/attention/backends/pallas_torchax.py
@@ -17,14 +17,10 @@
 # Register custom op dispatcher.
 from tpu_commons.models.torchax.torchax_wrapper import (kv_cache_update,
                                                         ragged_paged_attention)
+from tpu_commons.utils import TPU_HEAD_SIZE_ALIGNMENT
 
 logger = init_logger(__name__)
 
-# TPU requires the head size to be a multiple of 128.
-TPU_HEAD_SIZE_ALIGNMENT = 128
-# Block size used for kv cache updating kernel
-NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK = 8
-
 
 
 class PallasAttentionBackend(AttentionBackend):
@@ -233,7 +229,7 @@ def forward(
             # these can be manually adjusted for debugging if necessary.
             num_kv_pages_per_block=None,
             num_queries_per_block=None,
-            vmem_limit_bytes=None,
+            vmem_limit_bytes=100 * 1024 * 1024,
             use_kernel=True,
             sm_scale=self.scale,
             sliding_window=self.sliding_window,
@@ -270,14 +266,12 @@ def write_to_kv_cache(key: torch.Tensor, value: torch.Tensor,
                                                  head_size)
     kv_cache = kv_cache.reshape(-1, num_combined_kv_heads, head_size)
 
-    kv_cache = call_jax(
-        kv_cache_update,
-        kv,
-        slot_mapping,
-        kv_cache,
-        num_slices,
-        page_size=block_size,
-        num_slices_per_block=NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK)
+    kv_cache = call_jax(kv_cache_update,
+                        kv,
+                        slot_mapping,
+                        kv_cache,
+                        num_slices,
+                        page_size=block_size)
     kv_cache = kv_cache.reshape(num_blocks, block_size, num_combined_kv_heads,
                                 head_size)
     return kv_cache
diff --git a/tpu_commons/kernels/ragged_kv_cache_update.py b/tpu_commons/kernels/ragged_kv_cache_update.py
index 8fb8bbbdb..7d4963b3a 100644
--- a/tpu_commons/kernels/ragged_kv_cache_update.py
+++ b/tpu_commons/kernels/ragged_kv_cache_update.py
@@ -4,11 +4,14 @@
 import functools
 
 import jax
+from jax._src import dtypes
 from jax.experimental import pallas as pl
 from jax.experimental.pallas import tpu as pltpu
 from jax.sharding import Mesh
 from jax.sharding import PartitionSpec as P
 
+from tpu_commons.utils import TPU_HEAD_SIZE_ALIGNMENT, get_dtype_packing
+
 
 def _ceil_div(a, b):
     assert b != 0
@@ -140,8 +143,8 @@ def _kv_cache_update(
     page_size: int,
     num_slices_per_block: int,
     dynamic_validate_inputs: bool,
+    vmem_limit_bytes: int = 40 * 1024 * 1024,
 ):
-    assert slices.shape[1] % num_slices_per_block == 0
     new_token_num, num_combined_kv_heads, head_dim = new_kv.shape
     assert kv_cache.shape[1] == num_combined_kv_heads
     assert kv_cache.shape[2] == head_dim
@@ -180,11 +183,52 @@ def _kv_cache_update(
         ),
         out_shape=out_shape,
         input_output_aliases={len(scalar_prefetches) + 1: 0},
+        compiler_params=pltpu.CompilerParams(
+            vmem_limit_bytes=vmem_limit_bytes, ),
     )
 
     return kernel(*scalar_prefetches, new_kv, kv_cache)[0]
 
 
+def _prev_power_of_2(n: int) -> int:
+    """The previous power of 2 (inclusive)"""
+    if n <= 0:
+        return 0
+    return 1 << (n.bit_length() - 1)
+
+
+def _get_page_size_bytes(block_size: int, num_combined_kv_heads: int,
+                         head_size: int, kv_cache_dtype) -> int:
+    """Returns the size in bytes of one page of the KV cache."""
+    kv_cache_dtype_bit_size = dtypes.bit_width(kv_cache_dtype)
+    padded_head_size = _ceil_div(
+        head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
+
+    # NOTE: for the implicit padding in XLA
+    packing = get_dtype_packing(kv_cache_dtype)
+    num_combined_kv_heads = _ceil_div(num_combined_kv_heads, packing) * packing
+
+    return block_size * num_combined_kv_heads * padded_head_size * kv_cache_dtype_bit_size // 8
+
+
+def _get_num_slices_per_kv_cache_update_block(page_size_bytes: int,
+                                              vmem_limit_bytes: int) -> int:
+    """Find the optimum number of slices to copy per Pallas program instance.
+    Increasing the number of slices copied in one instance of the kernel program
+    will increase HBM bandwidth utilization via more in-flight DMAs.
+    However, it will also use more VMEM, and experimentally, we observed
+    performance regression at 128 slices on v6e, likely due to running
+    out of scalar registers. Thus this function will limit the number of
+    slices to 64.
+    """
+    # NOTE: We assume 1MB vmem is used for register spill and others
+    assert vmem_limit_bytes >= 1024 * 1024, "vmem_limit_bytes must be at least 1MB"
+    num_slices_per_block = (vmem_limit_bytes - 1024 * 1024) // page_size_bytes
+    assert num_slices_per_block > 0, "Number of slices should be positive"
+    num_slices_per_block = _prev_power_of_2(num_slices_per_block)
+    return min(num_slices_per_block, 64)
+
+
 @functools.partial(
     jax.jit,
     static_argnames=[
@@ -201,12 +245,21 @@ def kv_cache_update(
     num_slices: jax.Array,  # [1]
     *,
     page_size: int = 32,
-    num_slices_per_block: int = 8,
+    num_slices_per_block: int | None = None,
     mesh: Mesh | None = None,
     kv_cache_pspec: P | None = None,  # Only sharding along head_dim is supported
     dynamic_validate_inputs: bool = False,
+    vmem_limit_bytes: int = 40 * 1024 * 1024,
 ):
+    if num_slices_per_block is None:
+        _, num_combined_kv_heads, head_dim = new_kv.shape
+        page_size_bytes = _get_page_size_bytes(page_size,
+                                               num_combined_kv_heads, head_dim,
+                                               kv_cache.dtype)
+        num_slices_per_block = _get_num_slices_per_kv_cache_update_block(
+            page_size_bytes, vmem_limit_bytes)
+
     if mesh is None:
         return _kv_cache_update(new_kv, slices, kv_cache, num_slices,
                                 page_size, num_slices_per_block,
                                 dynamic_validate_inputs)
@@ -224,6 +277,7 @@
             page_size=page_size,
             num_slices_per_block=num_slices_per_block,
             dynamic_validate_inputs=dynamic_validate_inputs,
+            vmem_limit_bytes=vmem_limit_bytes,
         ),
         mesh=mesh,
         in_specs=in_specs,
diff --git a/tpu_commons/models/jax/attention_interface.py b/tpu_commons/models/jax/attention_interface.py
index 0b075efa0..0a462842c 100644
--- a/tpu_commons/models/jax/attention_interface.py
+++ b/tpu_commons/models/jax/attention_interface.py
@@ -11,10 +11,6 @@
     ragged_paged_attention
 from tpu_commons.models.jax.attention_metadata import AttentionMetadata
 
-# TODO(xiang): put this in attention metadata
-# Block size used for kv cache updating kernel
-NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK = 8
-
 
 def sharded_ragged_paged_attention(sm_scale: float,
                                    mesh: Mesh,
@@ -114,14 +110,12 @@ def update_kv_cache(k: jax.Array, v: jax.Array, kv_cache: jax.Array,
     kv = jnp.concat([k, v], axis=-1).reshape(T, K_2, H)
     kv_cache = kv_cache.reshape(-1, K_2, H)
 
-    kv_cache = kv_cache_update(
-        kv,
-        slices,
-        kv_cache,
-        num_slices,
-        page_size=S,
-        num_slices_per_block=NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
-        mesh=mesh,
-        kv_cache_pspec=P(None, "model", None))
+    kv_cache = kv_cache_update(kv,
+                               slices,
+                               kv_cache,
+                               num_slices,
+                               page_size=S,
+                               mesh=mesh,
+                               kv_cache_pspec=P(None, "model", None))
     kv_cache = kv_cache.reshape(L, S, K_2, H)
     return kv_cache
diff --git a/tpu_commons/models/torchax/torchax_wrapper.py b/tpu_commons/models/torchax/torchax_wrapper.py
index 409dfc29e..2408c9c0a 100644
--- a/tpu_commons/models/torchax/torchax_wrapper.py
+++ b/tpu_commons/models/torchax/torchax_wrapper.py
@@ -199,7 +199,7 @@ def _kv_cache_update(
     num_slices: jax.Array,  # [1]
     *,
     page_size: int = 32,
-    num_slices_per_block: int = 8,
+    num_slices_per_block: int = None,
 ) -> Array:
     # TODO: Get rid of this wrapper and call from pallas.py directly. Need to
     # find a better way to get mesh in pallas.py.
diff --git a/tpu_commons/runner/jax/tpu_jax_runner.py b/tpu_commons/runner/jax/tpu_jax_runner.py
index e7f7d509f..ed0be4649 100644
--- a/tpu_commons/runner/jax/tpu_jax_runner.py
+++ b/tpu_commons/runner/jax/tpu_jax_runner.py
@@ -47,8 +47,6 @@
 INVALID_TOKEN_ID = -1
 # Smallest output size
 MIN_NUM_SEQS = 8
-# Block size used for kv cache updating kernel
-NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK = 8
 
 DUMMY_METADATA = AttentionMetadata(
     input_positions=[],
@@ -1192,8 +1190,4 @@ def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
     recompilation."""
     padded_num_slices = 2 * max_num_reqs + num_tokens // page_size
     padded_num_slices = min(padded_num_slices, num_tokens)
-    padded_num_slices = (
-        padded_num_slices + NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK - 1
-    ) // NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK * \
-        NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK
     return padded_num_slices
diff --git a/tpu_commons/runner/tpu_torchax_runner.py b/tpu_commons/runner/tpu_torchax_runner.py
index cf7acb917..112d080c6 100644
--- a/tpu_commons/runner/tpu_torchax_runner.py
+++ b/tpu_commons/runner/tpu_torchax_runner.py
@@ -46,8 +46,7 @@
     is_pin_memory_available)
 
 from tpu_commons.attention.backends.pallas_torchax import (
-    PallasAttentionBackend, PallasMetadata,
-    NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK)
+    PallasAttentionBackend, PallasMetadata)
 
 from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
                                         KVCacheConfig, KVCacheSpec,
@@ -1108,8 +1107,4 @@ def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
     recompilation."""
    padded_num_slices = 2 * max_num_reqs + num_tokens // page_size
     padded_num_slices = min(padded_num_slices, num_tokens)
-    padded_num_slices = (
-        padded_num_slices + NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK - 1
-    ) // NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK * \
-        NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK
     return padded_num_slices
diff --git a/tpu_commons/utils.py b/tpu_commons/utils.py
index 69624a851..d5cd73eb6 100644
--- a/tpu_commons/utils.py
+++ b/tpu_commons/utils.py
@@ -4,11 +4,13 @@
 from typing import Any, List, Tuple
 
 import jax
+from jax._src import dtypes
+from vllm import envs
 
 from tpu_commons.logger import init_logger
-from vllm import envs
 
 GBYTES = 1024 * 1024 * 1024
+TPU_HEAD_SIZE_ALIGNMENT = 128
 _megacore = False
 
 logger = init_logger(__name__)
@@ -99,3 +101,8 @@ def get_padded_num_heads(num_heads: int, sharding_size: int) -> int:
         assert sharding_size % num_heads == 0
         num_heads = sharding_size
     return num_heads
+
+
+def get_dtype_packing(dtype):
+    bits = dtypes.bit_width(dtype)
+    return 32 // bits
diff --git a/tpu_commons/worker/tpu_worker_torchax.py b/tpu_commons/worker/tpu_worker_torchax.py
index 7c43b4c89..ef6a7f270 100644
--- a/tpu_commons/worker/tpu_worker_torchax.py
+++ b/tpu_commons/worker/tpu_worker_torchax.py
@@ -21,7 +21,6 @@
     init_distributed_environment)
 from vllm.model_executor import set_random_seed
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
-from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import ModelRunnerOutput
@@ -40,6 +39,7 @@
 from tpu_commons.worker._temporary_vllm_compat import (
     adapt_kv_cache_config_if_needed, adapt_scheduler_output_if_needed,
     adapt_lora_request_if_needed)
+from tpu_commons.utils import TPU_HEAD_SIZE_ALIGNMENT
 
 
 logger = init_logger(__name__)
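For reference, a rough standalone sketch of the slice-count selection that the new _get_page_size_bytes and _get_num_slices_per_kv_cache_update_block helpers perform. It re-derives the same arithmetic in plain Python; the concrete configuration below (bfloat16 KV cache, page_size=32, 16 combined KV heads, head_dim=128, the default 40 MiB VMEM limit) is an assumed example, not a benchmarked setting from this patch.

# Illustrative sketch only; mirrors the helpers added in
# tpu_commons/kernels/ragged_kv_cache_update.py above.

def prev_power_of_2(n: int) -> int:
    # Largest power of 2 that is <= n (0 for non-positive n).
    return 1 << (n.bit_length() - 1) if n > 0 else 0

def page_size_bytes(block_size: int, num_combined_kv_heads: int,
                    head_size: int, dtype_bits: int) -> int:
    packing = 32 // dtype_bits                          # XLA packs sub-32-bit dtypes
    padded_head = -(-head_size // 128) * 128            # TPU_HEAD_SIZE_ALIGNMENT = 128
    padded_heads = -(-num_combined_kv_heads // packing) * packing
    return block_size * padded_heads * padded_head * dtype_bits // 8

vmem_limit = 40 * 1024 * 1024                           # default vmem_limit_bytes
psb = page_size_bytes(block_size=32, num_combined_kv_heads=16,
                      head_size=128, dtype_bits=16)     # bfloat16 -> 131072 bytes/page
budget = (vmem_limit - 1024 * 1024) // psb              # reserve ~1 MiB for register spills
print(min(prev_power_of_2(budget), 64))                 # -> 64 (capped at 64 slices per block)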