[Kernel] optimize kv cache update kernel block size #360
Reviewer: I would suggest having auto-tuning and benchmarking in the Google workspace internally first, like we do for RPA and quantized_matmul.

Author: Thanks for the suggestion! Since I'd like to bring the KV cache update kernel on par with the vLLM torch/xla path first, can we add the auto-tuning and benchmarking internally at Google later?
Changes from all commits:
```diff
@@ -4,11 +4,14 @@
 import functools

 import jax
+from jax._src import dtypes
 from jax.experimental import pallas as pl
 from jax.experimental.pallas import tpu as pltpu
 from jax.sharding import Mesh
 from jax.sharding import PartitionSpec as P

+from tpu_commons.utils import TPU_HEAD_SIZE_ALIGNMENT, get_dtype_packing
+

 def _ceil_div(a, b):
     assert b != 0
```
```diff
@@ -140,8 +143,8 @@ def _kv_cache_update(
     page_size: int,
     num_slices_per_block: int,
     dynamic_validate_inputs: bool,
+    vmem_limit_bytes: int = 40 * 1024 * 1024,
 ):
     assert slices.shape[1] % num_slices_per_block == 0
     new_token_num, num_combined_kv_heads, head_dim = new_kv.shape
     assert kv_cache.shape[1] == num_combined_kv_heads
     assert kv_cache.shape[2] == head_dim
```

Reviewer (on `vmem_limit_bytes: int = 40 * 1024 * 1024`): How is this calculated?

Author: Basically I'd like to have 32MB of VMEM available for the scratch buffer, and it's rounded up to 40MB.
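To make that reply concrete, here is a small worked sketch of how the 40MB default bounds the derived slice count, mirroring the helpers added later in this diff; the 128KB page size is an illustrative assumption, not a value from this PR.

```python
# Sketch of how vmem_limit_bytes bounds num_slices_per_block.
vmem_limit_bytes = 40 * 1024 * 1024   # the new default budget
reserved = 1024 * 1024                # 1MB assumed reserved for register spills
page_size_bytes = 128 * 1024          # assumption: one 128KB KV cache page

budget = (vmem_limit_bytes - reserved) // page_size_bytes  # 312 pages fit
num_slices = 1 << (budget.bit_length() - 1)                # previous power of 2: 256
num_slices = min(num_slices, 64)                           # capped (regression seen at 128 on v6e)
print(num_slices)                                          # -> 64
```

With this assumed page size the 64-slice cap is what binds, so the kernel would keep about 8MB of scratch in flight here; larger pages eat into the 39MB usable budget before the cap does.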
```diff
@@ -180,11 +183,52 @@ def _kv_cache_update(
         ),
         out_shape=out_shape,
         input_output_aliases={len(scalar_prefetches) + 1: 0},
+        compiler_params=pltpu.CompilerParams(
+            vmem_limit_bytes=vmem_limit_bytes, ),
     )

     return kernel(*scalar_prefetches, new_kv, kv_cache)[0]


+def _prev_power_of_2(n: int) -> int:
+    """Returns the previous power of 2 (inclusive)."""
+    if n <= 0:
+        return 0
+    return 1 << (n.bit_length() - 1)
+
+
+def _get_page_size_bytes(block_size: int, num_combined_kv_heads: int,
+                         head_size: int, kv_cache_dtype) -> int:
+    """Returns the size in bytes of one page of the KV cache."""
+    kv_cache_dtype_bit_size = dtypes.bit_width(kv_cache_dtype)
+    padded_head_size = _ceil_div(
+        head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
+
+    # NOTE: account for the implicit padding done by XLA
+    packing = get_dtype_packing(kv_cache_dtype)
+    num_combined_kv_heads = _ceil_div(num_combined_kv_heads, packing) * packing
+
+    return (block_size * num_combined_kv_heads * padded_head_size *
+            kv_cache_dtype_bit_size // 8)
+
+
+def _get_num_slices_per_kv_cache_update_block(page_size_bytes: int,
+                                              vmem_limit_bytes: int) -> int:
+    """Finds the optimal number of slices to copy per Pallas program instance.
+
+    Increasing the number of slices copied in one instance of the kernel
+    program increases HBM bandwidth utilization via more in-flight DMAs.
+    However, it also uses more VMEM, and experimentally we observed a
+    performance regression at 128 slices on v6e, likely due to running out
+    of scalar registers. Thus this function limits the number of slices to 64.
+    """
+    # NOTE: we assume 1MB of VMEM is used for register spills and other overhead
+    assert vmem_limit_bytes >= 1024 * 1024, "vmem_limit_bytes must be at least 1MB"
+    num_slices_per_block = (vmem_limit_bytes - 1024 * 1024) // page_size_bytes
+    assert num_slices_per_block > 0, "Number of slices should be positive"
+    num_slices_per_block = _prev_power_of_2(num_slices_per_block)
+    return min(num_slices_per_block, 64)
+
+
 @functools.partial(
     jax.jit,
     static_argnames=[
```
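As a concrete illustration of `_get_page_size_bytes`, the sketch below walks the arithmetic for one assumed configuration; the constants are assumptions for illustration (TPU_HEAD_SIZE_ALIGNMENT taken as 128, bfloat16 packing taken as 2), not values read from `tpu_commons.utils`.

```python
# Worked example of the page-size arithmetic above. The head alignment (128)
# and bfloat16 packing factor (2) are assumed values for illustration only.
block_size = 32              # the default page_size in kv_cache_update
num_combined_kv_heads = 5    # odd on purpose so the packing padding kicks in
head_size = 96
bit_width = 16               # bfloat16

padded_head_size = -(-head_size // 128) * 128      # 96 -> 128 (head alignment)
padded_heads = -(-num_combined_kv_heads // 2) * 2  # 5 -> 6 (dtype packing)

page_size_bytes = block_size * padded_heads * padded_head_size * bit_width // 8
print(page_size_bytes)  # 49152 bytes, i.e. 48KB per page
```

Plugging this 48KB page into `_get_num_slices_per_kv_cache_update_block` with the 40MB default gives 39MB // 48KB = 832 slices, whose previous power of 2 is 512, so the 64-slice cap again binds.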
```diff
@@ -201,12 +245,21 @@ def kv_cache_update(
     num_slices: jax.Array,  # [1]
     *,
     page_size: int = 32,
-    num_slices_per_block: int = 8,
+    num_slices_per_block: int | None = None,
     mesh: Mesh | None = None,
     kv_cache_pspec: P
     | None = None,  # Only sharding along head_dim is supported
     dynamic_validate_inputs: bool = False,
+    vmem_limit_bytes: int = 40 * 1024 * 1024,
 ):
+    if num_slices_per_block is None:
+        _, num_combined_kv_heads, head_dim = new_kv.shape
+        page_size_bytes = _get_page_size_bytes(page_size,
+                                               num_combined_kv_heads, head_dim,
+                                               kv_cache.dtype)
+        num_slices_per_block = _get_num_slices_per_kv_cache_update_block(
+            page_size_bytes, vmem_limit_bytes)
+
     if mesh is None:
         return _kv_cache_update(new_kv, slices, kv_cache, num_slices,
                                 page_size, num_slices_per_block,
```
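For reference, a call to the updated entry point might look like the sketch below; the array shapes and the 3-row slice layout are illustrative assumptions based on the asserts in the diff, not values confirmed by this PR.

```python
# Hypothetical usage sketch of the updated kv_cache_update entry point.
# Shapes and the [3, padded_num_slices] slice layout are assumptions.
import jax.numpy as jnp

new_kv = jnp.zeros((256, 16, 128), dtype=jnp.bfloat16)     # [new_token_num, kv_heads, head_dim]
kv_cache = jnp.zeros((8192, 16, 128), dtype=jnp.bfloat16)  # [total_slots, kv_heads, head_dim]
slices = jnp.zeros((3, 64), dtype=jnp.int32)               # padded so dim 1 divides num_slices_per_block
num_slices = jnp.array([4], dtype=jnp.int32)               # number of live slice entries

# num_slices_per_block is now derived from the page size and the VMEM budget
# instead of the previous fixed default of 8.
updated_cache = kv_cache_update(new_kv, slices, kv_cache, num_slices)
```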
```diff
@@ -224,6 +277,7 @@ def kv_cache_update(
             page_size=page_size,
             num_slices_per_block=num_slices_per_block,
             dynamic_validate_inputs=dynamic_validate_inputs,
+            vmem_limit_bytes=vmem_limit_bytes,
         ),
         mesh=mesh,
         in_specs=in_specs,
```
Reviewer: Nit: better to test the throughput before and after this change. The kernel itself won't be affected, but the next op's prefetch will be.

Author: It would crash with a VMEM OOM before this change.