Add memory tiering for CC (vllm-project#3)
* Add example disk swap config. Add unit tests for CC with memory tiering

* Layered transfer for DRAM. Transfer in CUDA streams

* Fix the missing arg

* Fix context caching online serving

This commit enables layered transfer for DRAM first: the transfer is now
performed on separate CUDA streams. The xformers, FlashInfer, and
FlashAttention backends are supported. Optimized transfer for disk is
still pending.

Cherry-pick Yangshen's commit

---------

Co-authored-by: yangshen <yangshen.d@outlook.com>
PanJason and TKONIY authored Sep 20, 2024
1 parent db31b89 commit f1123cd
Showing 28 changed files with 1,731 additions and 385 deletions.
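The layered transfer described in the commit message overlaps per-layer KV-cache copies with attention compute: the copies for the next layer are queued on a side CUDA stream, and the default stream waits on that stream before each layer's attention runs. Below is a minimal PyTorch sketch of that wait-then-prefetch pattern; the tensor shapes, block layout, and helper names are illustrative assumptions, not the vLLM API.

import torch

# Sketch only: overlap per-layer KV-cache swap-in with compute on a side
# CUDA stream. Shapes and names are illustrative, not vLLM's.

def swap_in_one_layer(gpu_cache: torch.Tensor,
                      cpu_cache: torch.Tensor,
                      block_ids: list,
                      stream: torch.cuda.Stream) -> None:
    # Copy the selected KV blocks of one layer from pinned CPU memory to the
    # GPU on a dedicated transfer stream.
    with torch.cuda.stream(stream):
        for b in block_ids:
            gpu_cache[b].copy_(cpu_cache[b], non_blocking=True)

def demo() -> None:
    num_layers, num_blocks, block_numel = 4, 8, 16 * 128
    gpu = [torch.zeros(num_blocks, block_numel, device="cuda")
           for _ in range(num_layers)]
    cpu = [torch.randn(num_blocks, block_numel, pin_memory=True)
           for _ in range(num_layers)]
    blocks = [1, 3, 5]
    transfer = torch.cuda.Stream()
    default = torch.cuda.default_stream()

    # Prefetch layer 0, then overlap: while layer i "computes" on the default
    # stream, layer i + 1 is copied in on the transfer stream.
    swap_in_one_layer(gpu[0], cpu[0], blocks, transfer)
    for i in range(num_layers):
        default.wait_stream(transfer)   # layer i's blocks must be resident
        if i + 1 < num_layers:
            swap_in_one_layer(gpu[i + 1], cpu[i + 1], blocks, transfer)
        gpu[i].mul_(2.0)                # stand-in for this layer's attention

    torch.cuda.synchronize()

if __name__ == "__main__":
    demo()

The same structure appears in the backend diffs below: each attention backend's forward() waits on attn_metadata.cuda_stream, then queues the next layer's swap_blocks / copy_blocks_one_layer calls on that stream.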
6 changes: 6 additions & 0 deletions swap.cfg
@@ -0,0 +1,6 @@
{
"localdisk": {
"size": "1G",
"path": "/tmp/swap"
}
}
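The swap.cfg example above maps a swap tier name to a capacity and a backing path. As a rough sketch of how such a file could be consumed (an assumption for illustration, not the loader added in this commit), a size string such as "1G" can be normalized to bytes before the swap space is allocated:

import json

_UNITS = {"K": 2**10, "M": 2**20, "G": 2**30, "T": 2**40}

def parse_size(size: str) -> int:
    # "1G" -> 1073741824; plain integers are taken as bytes.
    suffix = size[-1].upper()
    if suffix in _UNITS:
        return int(float(size[:-1]) * _UNITS[suffix])
    return int(size)

def load_swap_config(path: str) -> dict:
    # Hypothetical loader: read the JSON and normalize each tier's size.
    with open(path) as f:
        cfg = json.load(f)
    return {tier: {"size": parse_size(spec["size"]), "path": spec["path"]}
            for tier, spec in cfg.items()}

# load_swap_config("swap.cfg")
# -> {'localdisk': {'size': 1073741824, 'path': '/tmp/swap'}}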
2 changes: 1 addition & 1 deletion tests/core/block/test_block_manager_v2.py
@@ -289,7 +289,7 @@ def test_swap(block_size, num_cpu_blocks, num_gpu_blocks, num_lookahead_slots,
assert block_manager.can_swap_out(seq_group)
before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
mapping = block_manager.swap_out(seq_group)
mapping, _ = block_manager.swap_out(seq_group)
mapping_keys = [key for key, _ in mapping]
assert mapping_keys == gpu_blocks
after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
4 changes: 2 additions & 2 deletions tests/core/test_block_manager.py
@@ -320,7 +320,7 @@ def test_swap():
assert block_manager.can_swap_out(seq_group)
before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
mapping = block_manager.swap_out(seq_group)
mapping, _ = block_manager.swap_out(seq_group)
assert [x[0] for x in mapping] == gpu_blocks
after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
@@ -373,7 +373,7 @@ def test_swap_encoder_decoder():
assert block_manager.can_swap_out(seq_group)
before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
mapping = block_manager.swap_out(seq_group)
mapping, _ = block_manager.swap_out(seq_group)
assert [x[0] for x in mapping] == gpu_blocks
#assert list(mapping.keys()) == gpu_blocks
after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
289 changes: 281 additions & 8 deletions tests/core/test_scheduler.py

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion tests/core/utils.py
@@ -4,6 +4,7 @@
from typing import Tuple

from vllm import SamplingParams
from vllm.caching_params import CachingParams
from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob, Sequence, SequenceGroup

@@ -16,6 +17,7 @@ def create_dummy_prompt(
use_beam_search: bool = False,
best_of: int = 1,
prompt_range: Optional[Tuple[int, int]] = None,
caching_params: Optional[CachingParams] = None,
) -> Tuple[Sequence, SequenceGroup]:
if not block_size:
block_size = prompt_length
@@ -40,7 +42,9 @@ def create_dummy_prompt(
sampling_params=SamplingParams(
use_beam_search=use_beam_search,
best_of=best_of),
lora_request=lora_request)
lora_request=lora_request,
caching_params=caching_params.clone()
if caching_params is not None else None)

return prompt, seq_group

7 changes: 7 additions & 0 deletions tests/worker/test_model_input.py
@@ -54,6 +54,13 @@ def copy_blocks(
) -> None:
pass

@staticmethod
def copy_blocks_one_layer(
kv_cache: torch.Tensor,
src_to_dists: torch.Tensor,
) -> None:
pass


def test_model_runner_input():
sampling_metadata = SamplingMetadata(
3 changes: 2 additions & 1 deletion vllm/attention/backends/abstract.py
@@ -137,7 +137,8 @@ def add_kv_cache_for_layered_transfer(
blocks_to_swap_out: Optional[torch.Tensor] = None,
blocks_to_copy: Optional[torch.Tensor] = None,
gpu_caches: Optional[List[torch.Tensor]] = None,
cpu_caches: Optional[List[torch.Tensor]] = None):
cpu_caches: Optional[List[torch.Tensor]] = None,
cuda_stream: Optional[torch.cuda.Stream] = None):
raise NotImplementedError


4 changes: 3 additions & 1 deletion vllm/attention/backends/blocksparse_attn.py
@@ -257,7 +257,8 @@ def add_kv_cache_for_layered_transfer(
blocks_to_swap_out: Optional[torch.Tensor] = None,
blocks_to_copy: Optional[torch.Tensor] = None,
gpu_caches: Optional[List[torch.Tensor]] = None,
cpu_caches: Optional[List[torch.Tensor]] = None):
cpu_caches: Optional[List[torch.Tensor]] = None,
cuda_stream: Optional[torch.cuda.Stream] = None):
self.enable_layered_transfer = True
self.blocks_to_swap_in = blocks_to_swap_in
self.blocks_to_swap_out = blocks_to_swap_out
@@ -266,6 +267,7 @@ def add_kv_cache_for_layered_transfer(
self.gpu_caches = gpu_caches
self.num_hidden_layers = num_hidden_layers
self.current_layer = 1
self.cuda_stream = cuda_stream


class BlocksparseFlashAttentionMetadataBuilder(
63 changes: 51 additions & 12 deletions vllm/attention/backends/flash_attn.py
@@ -212,7 +212,8 @@ def add_kv_cache_for_layered_transfer(
blocks_to_swap_out: Optional[torch.Tensor] = None,
blocks_to_copy: Optional[torch.Tensor] = None,
gpu_caches: Optional[List[torch.Tensor]] = None,
cpu_caches: Optional[List[torch.Tensor]] = None):
cpu_caches: Optional[List[torch.Tensor]] = None,
cuda_stream: Optional[torch.cuda.Stream] = None):
self.enable_layered_transfer = True
self.blocks_to_swap_in = blocks_to_swap_in
self.blocks_to_swap_out = blocks_to_swap_out
@@ -221,6 +222,7 @@ def add_kv_cache_for_layered_transfer(
self.gpu_caches = gpu_caches
self.num_hidden_layers = num_hidden_layers
self.current_layer = 1
self.cuda_stream = cuda_stream


class FlashAttentionMetadataBuilder(
@@ -489,32 +491,69 @@ def forward(

# First fetch the next layer
if attn_metadata.enable_layered_transfer:
# Wait for the previous layer to finish.
if attn_metadata.cuda_stream is not None:
dev = kv_cache.device
torch.cuda.default_stream(dev).wait_stream(
attn_metadata.cuda_stream)

if attn_metadata.current_layer < attn_metadata.num_hidden_layers:
# Swap out
if (attn_metadata.blocks_to_swap_out is not None
and attn_metadata.blocks_to_swap_out.numel() > 0
and attn_metadata.gpu_caches is not None
and attn_metadata.cpu_caches is not None):
FlashAttentionBackend.swap_blocks(
attn_metadata.gpu_caches[attn_metadata.current_layer],
attn_metadata.cpu_caches[attn_metadata.current_layer],
attn_metadata.blocks_to_swap_out)
if attn_metadata.cuda_stream is not None:
with torch.cuda.stream(attn_metadata.cuda_stream):
FlashAttentionBackend.swap_blocks(
attn_metadata.gpu_caches[
attn_metadata.current_layer],
attn_metadata.cpu_caches[
attn_metadata.current_layer],
attn_metadata.blocks_to_swap_out)
else:
FlashAttentionBackend.swap_blocks(
attn_metadata.gpu_caches[
attn_metadata.current_layer],
attn_metadata.cpu_caches[
attn_metadata.current_layer],
attn_metadata.blocks_to_swap_out)
if (attn_metadata.blocks_to_copy is not None
and attn_metadata.blocks_to_copy.numel() > 0
and attn_metadata.gpu_caches is not None):
FlashAttentionBackend.copy_blocks_one_layer(
attn_metadata.gpu_caches[attn_metadata.current_layer],
attn_metadata.blocks_to_copy)
if attn_metadata.cuda_stream is not None:
with torch.cuda.stream(attn_metadata.cuda_stream):
FlashAttentionBackend.copy_blocks_one_layer(
attn_metadata.gpu_caches[
attn_metadata.current_layer],
attn_metadata.blocks_to_copy)
else:
FlashAttentionBackend.copy_blocks_one_layer(
attn_metadata.gpu_caches[
attn_metadata.current_layer],
attn_metadata.blocks_to_copy)
if (attn_metadata.blocks_to_swap_in is not None
and attn_metadata.blocks_to_swap_in.numel() > 0
and attn_metadata.gpu_caches is not None
and attn_metadata.cpu_caches is not None):
FlashAttentionBackend.swap_blocks(
attn_metadata.cpu_caches[attn_metadata.current_layer],
attn_metadata.gpu_caches[attn_metadata.current_layer],
attn_metadata.blocks_to_swap_in)
if attn_metadata.cuda_stream is not None:
with torch.cuda.stream(attn_metadata.cuda_stream):
FlashAttentionBackend.swap_blocks(
attn_metadata.cpu_caches[
attn_metadata.current_layer],
attn_metadata.gpu_caches[
attn_metadata.current_layer],
attn_metadata.blocks_to_swap_in)
else:
FlashAttentionBackend.swap_blocks(
attn_metadata.cpu_caches[
attn_metadata.current_layer],
attn_metadata.gpu_caches[
attn_metadata.current_layer],
attn_metadata.blocks_to_swap_in)
attn_metadata.current_layer += 1
elif attn_metadata.current_layer == attn_metadata.num_hidden_layers:
# NOTE: Prevent it from being freed but is it necessary?
attn_metadata.current_layer = 1

if attn_type != AttentionType.DECODER:
65 changes: 53 additions & 12 deletions vllm/attention/backends/flashinfer.py
@@ -220,7 +220,8 @@ def add_kv_cache_for_layered_transfer(
blocks_to_swap_out: Optional[torch.Tensor] = None,
blocks_to_copy: Optional[torch.Tensor] = None,
gpu_caches: Optional[List[torch.Tensor]] = None,
cpu_caches: Optional[List[torch.Tensor]] = None):
cpu_caches: Optional[List[torch.Tensor]] = None,
cuda_stream: Optional[torch.cuda.Stream] = None):
self.enable_layered_transfer = True
self.blocks_to_swap_in = blocks_to_swap_in
self.blocks_to_swap_out = blocks_to_swap_out
@@ -229,6 +230,7 @@ def add_kv_cache_for_layered_transfer(
self.gpu_caches = gpu_caches
self.num_hidden_layers = num_hidden_layers
self.current_layer = 1
self.cuda_stream = cuda_stream


class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
@@ -504,32 +506,71 @@ def forward(

# First fetch the next layer
if attn_metadata.enable_layered_transfer:
# Wait for the previous layer to finish.
if attn_metadata.cuda_stream is not None and kv_cache is not None:
dev = kv_cache.device
torch.cuda.default_stream(dev).wait_stream(
attn_metadata.cuda_stream)
if attn_metadata.current_layer < attn_metadata.num_hidden_layers:
# Swap out
if (attn_metadata.blocks_to_swap_out is not None
and attn_metadata.blocks_to_swap_out.numel() > 0
and attn_metadata.gpu_caches is not None
and attn_metadata.cpu_caches is not None):
PagedAttention.swap_blocks(
attn_metadata.gpu_caches[attn_metadata.current_layer],
attn_metadata.cpu_caches[attn_metadata.current_layer],
attn_metadata.blocks_to_swap_out)
if attn_metadata.cuda_stream is not None:
with torch.cuda.stream(attn_metadata.cuda_stream):
PagedAttention.swap_blocks(
attn_metadata.gpu_caches[
attn_metadata.current_layer],
attn_metadata.cpu_caches[
attn_metadata.current_layer],
attn_metadata.blocks_to_swap_out)
else:
PagedAttention.swap_blocks(
attn_metadata.gpu_caches[
attn_metadata.current_layer],
attn_metadata.cpu_caches[
attn_metadata.current_layer],
attn_metadata.blocks_to_swap_out)

if (attn_metadata.blocks_to_copy is not None
and attn_metadata.blocks_to_copy.numel() > 0
and attn_metadata.gpu_caches is not None):
PagedAttention.copy_blocks_one_layer(
attn_metadata.gpu_caches[attn_metadata.current_layer],
attn_metadata.blocks_to_copy)
if attn_metadata.cuda_stream is not None:
with torch.cuda.stream(attn_metadata.cuda_stream):
PagedAttention.copy_blocks_one_layer(
attn_metadata.gpu_caches[
attn_metadata.current_layer],
attn_metadata.blocks_to_copy)
else:
PagedAttention.copy_blocks_one_layer(
attn_metadata.gpu_caches[
attn_metadata.current_layer],
attn_metadata.blocks_to_copy)

if (attn_metadata.blocks_to_swap_in is not None
and attn_metadata.blocks_to_swap_in.numel() > 0
and attn_metadata.gpu_caches is not None
and attn_metadata.cpu_caches is not None):
PagedAttention.swap_blocks(
attn_metadata.cpu_caches[attn_metadata.current_layer],
attn_metadata.gpu_caches[attn_metadata.current_layer],
attn_metadata.blocks_to_swap_in)
if attn_metadata.cuda_stream is not None:
with torch.cuda.stream(attn_metadata.cuda_stream):
PagedAttention.swap_blocks(
attn_metadata.cpu_caches[
attn_metadata.current_layer],
attn_metadata.gpu_caches[
attn_metadata.current_layer],
attn_metadata.blocks_to_swap_in)
else:
PagedAttention.swap_blocks(
attn_metadata.cpu_caches[
attn_metadata.current_layer],
attn_metadata.gpu_caches[
attn_metadata.current_layer],
attn_metadata.blocks_to_swap_in)

attn_metadata.current_layer += 1
elif attn_metadata.current_layer == attn_metadata.num_hidden_layers:
# NOTE: Prevent it from being freed but is it necessary?
attn_metadata.current_layer = 1

assert k_scale == 1.0 and v_scale == 1.0, (