49 changes: 48 additions & 1 deletion vllm/v1/attention/backends/pallas.py
@@ -326,7 +326,54 @@ def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
    return kv_cache


# We can move this function to a common utils file if it's also useful for other
# hardware.
def dtype_bits(dtype: torch.dtype):
    if dtype.is_floating_point:
        try:
            return torch.finfo(dtype).bits
        except TypeError:
            pass
    elif dtype.is_complex:
        if dtype is torch.complex32:
            return 32
        elif dtype is torch.complex64:
            return 64
        elif dtype is torch.complex128:
            return 128
    else:
        try:
            return torch.iinfo(dtype).bits
        # torch.iinfo does not support int4, int2, bits8...
        except TypeError:
            pass
    str_dtype = str(dtype)
    # Support sub-byte dtypes such as torch.int4, torch.int5, torch.uint5...
    if str_dtype.startswith("torch.int") or str_dtype.startswith("torch.uint"):
        return int(str_dtype[-1])
    raise TypeError(f"Getting the bit width of {dtype} is not supported")


def get_dtype_packing(dtype):
    bits = dtype_bits(dtype)
    if 32 % bits != 0:
        raise ValueError(
            f"The bit width must evenly divide 32, but got bits={bits}, "
            f"dtype={dtype}")
    return 32 // bits


def get_page_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
                        kv_cache_dtype: torch.dtype) -> int:
    """Returns the size in bytes of one page of the KV cache."""
    padded_head_size = cdiv(head_size,
                            TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
    num_combined_kv_heads = num_kv_heads * 2

    # NOTE: for the implicit padding in XLA
    packing = get_dtype_packing(kv_cache_dtype)
    num_combined_kv_heads = cdiv(num_combined_kv_heads, packing) * packing

    kv_cache_dtype_bits = dtype_bits(kv_cache_dtype)
    return (block_size * num_combined_kv_heads * padded_head_size *
            kv_cache_dtype_bits // 8)
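
A quick sanity check of how these helpers compose (hypothetical asserts, not part of the diff; they assume a PyTorch build that exposes torch.bfloat16 and torch.float8_e4m3fn):

assert dtype_bits(torch.bfloat16) == 16             # via torch.finfo
assert dtype_bits(torch.int8) == 8                  # via torch.iinfo
assert get_dtype_packing(torch.bfloat16) == 2       # 2 bf16 values per 32-bit word
assert get_dtype_packing(torch.float8_e4m3fn) == 4  # 4 fp8 values per 32-bit word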
Collaborator:
What is the 8 in kv_cache_dtype_bits // 8?

Collaborator Author:
1 byte = 8 bits, so dividing the total bit count by 8 converts the page size from bits to bytes.
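
To make the arithmetic concrete, here is a hypothetical walk-through of get_page_size_bytes (illustrative numbers, not vLLM defaults; it assumes TPU_HEAD_SIZE_ALIGNMENT == 128, as defined in this file):

block_size, num_kv_heads, head_size = 32, 8, 128
bits = 8                       # fp8 KV cache
padded_head_size = 128         # cdiv(128, 128) * 128
num_combined_kv_heads = 8 * 2  # K and V heads combined -> 16
packing = 32 // bits           # 4 fp8 values per 32-bit word
# cdiv(16, 4) * 4 == 16, so no extra padding is needed here
page_bits = 32 * 16 * 128 * 8  # 524288 bits per page
assert page_bits // 8 == 65536 # // 8 converts bits to bytes: 64 KiB per page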

5 changes: 3 additions & 2 deletions vllm/v1/worker/tpu_model_runner.py
@@ -1863,8 +1863,9 @@ def _get_num_slices_per_kv_cache_update_block(page_size_bytes: int) -> int:
    out of scalar registers. Thus this function will limit the number of
    slices to 64.
    """
    # The default vmem_limit_bytes of a Pallas kernel is 32 MB. Here we
    # calculate num_slices_per_block based on 16 MB in case of register
    # spills.
    vmem_limit = 16 * 1024 * 1024
Collaborator:
I wonder why we reduce the vmem_limit from 32 MB to 16 MB?

Collaborator Author:
As noted in the comment, there might be register spills.

    num_slices_per_block = vmem_limit // page_size_bytes
    assert num_slices_per_block > 0, "Number of slices should be positive"
    num_slices_per_block = prev_power_of_2(num_slices_per_block)
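
As a rough illustration of the sizing logic above (hypothetical numbers; the inline bit trick is assumed to match vLLM's prev_power_of_2 helper, which rounds down to the nearest power of two):

vmem_limit = 16 * 1024 * 1024  # the 16 MB budget chosen above
page_size_bytes = 65536        # e.g., the 64 KiB page from the earlier example
num_slices_per_block = vmem_limit // page_size_bytes                 # 256
num_slices_per_block = 1 << (num_slices_per_block.bit_length() - 1)  # stays 256
# The surrounding function then caps the result at 64, per its docstring.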