Commit 5952d8a

[Attention] Get rid of mla cache alignment (#14842)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent a2ae496 commit 5952d8a

4 files changed: +14, -83 lines changed

tests/kernels/test_cache.py

Lines changed: 11 additions & 28 deletions
@@ -8,7 +8,6 @@
 from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
 from vllm import _custom_ops as ops
 from vllm.platforms import current_platform
-from vllm.utils import align_to_256bytes
 
 COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
 DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -450,22 +449,13 @@ def _create_mla_cache(
     dtype: torch.dtype,
     kv_cache_dtype: str,
     device: str,
-    align_cache: bool,
 ) -> torch.Tensor:
     cache_dtype = torch.uint8 if kv_cache_dtype == "fp8" else dtype
-
-    if align_cache:
-        alloc_entry_size = align_to_256bytes(entry_size, cache_dtype)
-        alloc_shape = (num_blocks, block_size, alloc_entry_size)
-        cache_full = torch.zeros(alloc_shape, dtype=cache_dtype, device=device)
-        cache = cache_full[..., :entry_size]
-    else:
-        cache = torch.zeros(num_blocks,
-                            block_size,
-                            entry_size,
-                            dtype=cache_dtype,
-                            device=device)
-    return cache
+    return torch.zeros(num_blocks,
+                       block_size,
+                       entry_size,
+                       dtype=cache_dtype,
+                       device=device)
 
 
 def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str):
@@ -488,7 +478,6 @@ def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str):
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
-@pytest.mark.parametrize("align_cache", [False])
 @torch.inference_mode()
 def test_concat_and_cache_mla(
     kv_lora_rank: int,
@@ -500,7 +489,6 @@ def test_concat_and_cache_mla(
     seed: int,
     device: str,
     kv_cache_dtype: str,
-    align_cache: bool,
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
@@ -520,7 +508,7 @@ def test_concat_and_cache_mla(
 
     scale = torch.tensor(0.1, dtype=torch.float32, device=device)
     kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
-                                 kv_cache_dtype, device, align_cache)
+                                 kv_cache_dtype, device)
     ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
 
     for i in range(num_tokens):
@@ -576,7 +564,6 @@ def test_concat_and_cache_mla(
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
-@pytest.mark.parametrize("align_cache", [False, True])
 @torch.inference_mode()
 def test_copy_blocks_mla(
     kv_lora_rank: int,
@@ -588,7 +575,6 @@ def test_copy_blocks_mla(
     seed: int,
     device: str,
     kv_cache_dtype: str,
-    align_cache: bool,
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
@@ -598,7 +584,7 @@ def test_copy_blocks_mla(
     kv_caches = []
     for _ in range(num_layers):
         kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
-                                     kv_cache_dtype, device, align_cache)
+                                     kv_cache_dtype, device)
         _fill_mla_cache(kv_cache, kv_cache_dtype=kv_cache_dtype)
         kv_caches.append(kv_cache)
 
@@ -642,7 +628,6 @@ def test_copy_blocks_mla(
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
-@pytest.mark.parametrize("align_cache", [False, True])
 @torch.inference_mode()
 def test_swap_blocks_mla(
     kv_lora_rank: int,
@@ -653,17 +638,16 @@ def test_swap_blocks_mla(
     seed: int,
     device: str,
     kv_cache_dtype: str,
-    align_cache: bool,
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
 
     entry_size = kv_lora_rank + qk_rope_head_dim
 
     src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
-                                  kv_cache_dtype, device, align_cache)
+                                  kv_cache_dtype, device)
     dst_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
-                                  kv_cache_dtype, device, align_cache)
+                                  kv_cache_dtype, device)
 
     _fill_mla_cache(src_cache, kv_cache_dtype)
     _fill_mla_cache(dst_cache, kv_cache_dtype)
@@ -704,15 +688,14 @@ def test_swap_blocks_mla(
 @pytest.mark.parametrize("dtype", [torch.float32])
 @pytest.mark.parametrize("kv_cache_dtype",
                          ["auto"])  # You can also test "fp8" if needed.
-@pytest.mark.parametrize("align_cache", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
 def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
                           num_blocks, max_seq_len, batch_size, dtype,
-                          kv_cache_dtype, align_cache, device):
+                          kv_cache_dtype, device):
     entry_size = kv_lora_rank + qk_rope_head_dim
     src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
-                                  kv_cache_dtype, device, align_cache)
+                                  kv_cache_dtype, device)
    _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)
 
     seq_len_tensor = torch.randint(0,
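
Note (not part of the diff): with the align_cache path gone, _create_mla_cache always returns a plainly allocated, contiguous tensor. A minimal sketch contrasting the old pad-then-slice layout with the new one; the sizes assume a DeepSeek-style MLA entry of kv_lora_rank=512 plus qk_rope_head_dim=64 in fp16, which is an illustration rather than anything taken from the tests above.

import torch

num_blocks, block_size, entry_size = 4, 16, 576  # 512 + 64 MLA entry
padded_entry_size = 640  # 576 rounded up to a 256-byte multiple for fp16

# Old aligned path: allocate padded storage, then slice back to entry_size.
old_cache = torch.zeros(num_blocks, block_size, padded_entry_size,
                        dtype=torch.half)[..., :entry_size]
# New path: allocate exactly what is needed.
new_cache = torch.zeros(num_blocks, block_size, entry_size, dtype=torch.half)

print(old_cache.is_contiguous(), old_cache.stride())  # False (10240, 640, 1)
print(new_cache.is_contiguous(), new_cache.stride())  # True  (9216, 576, 1)

The old sliced tensor keeps the padded stride of 640 elements per token, which is the non-contiguous layout the cache kernels previously had to tolerate.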

vllm/envs.py

Lines changed: 0 additions & 10 deletions
@@ -84,7 +84,6 @@
     VLLM_SERVER_DEV_MODE: bool = False
     VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
     VLLM_MLA_DISABLE: bool = False
-    VLLM_MLA_CUDA_MEM_ALIGN_KV_CACHE: bool = True
     VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False
     VLLM_RAY_PER_WORKER_GPUS: float = 1.0
     VLLM_RAY_BUNDLE_INDICES: str = ""
@@ -580,15 +579,6 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     "VLLM_RAY_BUNDLE_INDICES":
     lambda: os.getenv("VLLM_RAY_BUNDLE_INDICES", ""),
 
-    # When on a Nvidia GPU aligns single entries (within a page) so they are 256
-    # byte aligned for better performance, this increases the memory usage of
-    # the cache. Currently this only affects MLA that results in non-256
-    # byte aligned entries. This matches the alignment the CUDA runtime uses
-    # for all allocations. Currently this primarily affects MLA, for most other
-    # models the alignment is already naturally aligned to 256 bytes.
-    "VLLM_CUDA_MEM_ALIGN_KV_CACHE":
-    lambda: bool(int(os.getenv("VLLM_CUDA_MEM_ALIGN_KV_CACHE", "1"))),
-
     # In some system, find_loaded_library() may not work. So we allow users to
     # specify the path through environment variable VLLM_CUDART_SO_PATH.
     "VLLM_CUDART_SO_PATH":

vllm/utils.py

Lines changed: 0 additions & 6 deletions
@@ -827,12 +827,6 @@ def get_dtype_size(dtype: torch.dtype) -> int:
     return torch.tensor([], dtype=dtype).element_size()
 
 
-def align_to_256bytes(extent: int, dtype: torch.dtype) -> int:
-    dtype_size = get_dtype_size(dtype)
-    eles_per_256bytes = 256 // dtype_size
-    return round_up(extent, eles_per_256bytes)
-
-
 # `collections` helpers
 def is_list_of(
     value: object,
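
For reference, the behaviour of the deleted helper can be reproduced standalone. This is a self-contained sketch for illustration only; round_up is re-implemented here just so the example runs on its own, mirroring what vllm.utils.round_up does.

import torch

def round_up(x: int, multiple: int) -> int:
    # Smallest multiple of `multiple` that is >= x.
    return ((x + multiple - 1) // multiple) * multiple

def align_to_256bytes(extent: int, dtype: torch.dtype) -> int:
    # Copy of the removed helper: pad an element count so each entry
    # occupies a whole number of 256-byte chunks.
    dtype_size = torch.tensor([], dtype=dtype).element_size()
    eles_per_256bytes = 256 // dtype_size
    return round_up(extent, eles_per_256bytes)

print(align_to_256bytes(576, torch.half))   # 640 -> 1280 bytes (~11% padding)
print(align_to_256bytes(576, torch.uint8))  # 768 -> 768 bytes (fp8 cache, ~33% padding)
print(align_to_256bytes(576, torch.float))  # 576 -> already a 256-byte multiple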

vllm/worker/cache_engine.py

Lines changed: 3 additions & 39 deletions
@@ -1,18 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0
 """CacheEngine class for managing the KV cache."""
-from math import prod
 from typing import List
 
 import torch
 
-from vllm import envs
 from vllm.attention import get_attn_backend
 from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
 from vllm.logger import init_logger
-from vllm.platforms import current_platform
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType,
-                        align_to_256bytes, get_dtype_size,
-                        is_pin_memory_available)
+                        get_dtype_size, is_pin_memory_available)
 
 logger = init_logger(__name__)
 
@@ -42,7 +38,6 @@ def __init__(
         self.num_attention_layers = model_config.get_num_layers_by_block_type(
             parallel_config, LayerBlockType.attention)
         self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
-        self.align_cache = self._align_cache(model_config)
 
         self.block_size = cache_config.block_size
         self.num_gpu_blocks = cache_config.num_gpu_blocks
@@ -81,38 +76,18 @@ def _allocate_kv_cache(
         pin_memory = is_pin_memory_available() if device == "cpu" else False
         kv_cache: List[torch.Tensor] = []
 
-        # Align entries so they are 256 byte aligned for better performance
-        # Primarily targets MLA as this typically only ends up having entries
-        # be 128 byte aligned.
-        if self.align_cache:
-            # We assume the cache shape is:
-            #   (TOTAL_PAGES, PAGE_SIZE, entry_shape...)
-            # NOTE this assumption currently only holds for MLA so we only apply
-            #  this optimization when `use_mla` is true
-            entry_shape = kv_cache_shape[2:]
-            entry_size = prod(entry_shape)
-            alloc_entry_size = align_to_256bytes(entry_size, self.dtype)
-            alloc_shape = (*kv_cache_shape[:2], alloc_entry_size)
-        else:
-            alloc_shape = kv_cache_shape
-
         for _ in range(self.num_attention_layers):
             # null block in CpuGpuBlockAllocator requires at least that
             # block to be zeroed-out.
             # We zero-out everything for simplicity.
-            layer_kv_cache = torch.zeros(alloc_shape,
+            layer_kv_cache = torch.zeros(kv_cache_shape,
                                          dtype=self.dtype,
                                          pin_memory=pin_memory,
                                          device=device)
 
-            # If we allocated with padding for alignment reasons truncate the
-            # shape while preserving the aligned stride
-            if self.align_cache:
-                layer_kv_cache = layer_kv_cache[..., :entry_size]
-
             # view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
             # when entry_shape is higher than 1D
-            kv_cache.append(layer_kv_cache.view(kv_cache_shape))
+            kv_cache.append(layer_kv_cache)
         return kv_cache
 
     def swap_in(self, src_to_dst: torch.Tensor) -> None:
@@ -128,14 +103,6 @@ def swap_out(self, src_to_dst: torch.Tensor) -> None:
     def copy(self, src_to_dsts: torch.Tensor) -> None:
         self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts)
 
-    @staticmethod
-    def _align_cache(model_config: ModelConfig):
-        # Currently align_cache only applies to MLA models since the other
-        # cache kernels haven't been updated yet to support non-continguous
-        # tensors
-        return model_config.use_mla and current_platform.is_cuda() \
-            and envs.VLLM_CUDA_MEM_ALIGN_KV_CACHE
-
     @staticmethod
     def get_cache_block_size(
         cache_config: CacheConfig,
@@ -153,9 +120,6 @@ def get_cache_block_size(
         dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
 
         key_cache_entry = num_heads * head_size
-        if CacheEngine._align_cache(model_config):
-            key_cache_entry = align_to_256bytes(key_cache_entry,
-                                                model_config.dtype)
 
         # For MLA there is no value cache, since the latent vector
         # is joint keys and values.
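
The get_cache_block_size change drops the padding that the alignment used to count into every cache block. A back-of-the-envelope illustration of that overhead for an MLA cache; the block size, layer count, and entry size below are assumed, roughly DeepSeek-V3-like numbers, and are not taken from the diff.

# Per-block KV-cache bytes for an MLA model, with and without the removed
# 256-byte alignment padding (illustrative numbers only).
block_size = 16    # tokens per block
num_layers = 61    # e.g. a DeepSeek-V3-sized model
entry_elems = 576  # kv_lora_rank 512 + qk_rope_head_dim 64 (no value cache)
dtype_bytes = 2    # fp16 / bf16

unaligned = block_size * num_layers * entry_elems * dtype_bytes
aligned = block_size * num_layers * 640 * dtype_bytes  # 576 padded to 640

print(unaligned, aligned, aligned / unaligned)  # 1124352 1249280 ~1.11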
