
Commit

Break cycle between the attention implementations and KV cache
danieldk committed Oct 9, 2024
1 parent c56df2d commit 89dd19f
Showing 5 changed files with 38 additions and 69 deletions.
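In short: the per-backend reshape_and_cache helpers are deleted from cuda.py, rocm.py, and ipex.py (kv_cache.py used to import reshape_and_cache back from layers.attention, creating a circular dependency), and the paged-cache write now lives in kv_cache.py as paged_reshape_and_cache, dispatched on SYSTEM. The flashdecoding/flashinfer path remains a plain scatter into a flat view of the cache; the snippet below is a minimal, self-contained sketch of that scatter with made-up shapes, not code from the diff.

# Toy sketch (shapes are made up) of the scatter KVCache.store uses for the
# flashdecoding/flashinfer layouts: the cache is viewed as
# [num_blocks * block_size, num_heads, head_size] and each new token is written
# into its flat slot index.
import torch

num_blocks, block_size, num_heads, head_size = 4, 16, 2, 8
key_cache = torch.zeros(num_blocks, block_size, num_heads, head_size)
key = torch.randn(3, num_heads, head_size)  # three new tokens
slots = torch.tensor([0, 1, 17])  # flat slot index per token

shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key

# The token assigned slot 17 lands in block 1, position 1 of the cache.
assert torch.equal(key_cache[1, 1], key[2])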
4 changes: 0 additions & 4 deletions server/text_generation_server/layers/attention/__init__.py
@@ -11,21 +11,18 @@
SUPPORTS_WINDOWING,
attention,
paged_attention,
reshape_and_cache,
)
elif SYSTEM == "rocm":
from .rocm import (
SUPPORTS_WINDOWING,
attention,
paged_attention,
reshape_and_cache,
)
elif SYSTEM == "ipex":
from .ipex import (
SUPPORTS_WINDOWING,
attention,
paged_attention,
reshape_and_cache,
)
else:
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
@@ -36,7 +33,6 @@
__all__ = [
"attention",
"paged_attention",
"reshape_and_cache",
"SUPPORTS_WINDOWING",
"KVCache",
"Seqlen",
25 changes: 0 additions & 25 deletions server/text_generation_server/layers/attention/cuda.py
@@ -12,30 +12,6 @@
is_sm75 = major == 7 and minor == 5
_PARTITION_SIZE = 512

try:
from vllm._C import cache_ops
except Exception as e:
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
)


def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
):
if ATTENTION in {"flashdecoding", "flashinfer"}:
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
else:
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0
)


def paged_attention(
query: torch.Tensor,
@@ -333,5 +309,4 @@ def attention(
"SUPPORTS_WINDOWING",
"attention",
"paged_attention",
"reshape_and_cache",
]
13 changes: 0 additions & 13 deletions server/text_generation_server/layers/attention/ipex.py
@@ -44,18 +44,6 @@ def attention(
return out


def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
):
ipex.llm.modules.PagedAttention.reshape_and_cache(
key, value, key_cache, value_cache, slots
)


def paged_attention(
query: torch.Tensor,
kv_cache: KVCache,
@@ -87,5 +75,4 @@ def paged_attention(
"SUPPORTS_WINDOWING",
"attention",
"paged_attention",
"reshape_and_cache",
]
41 changes: 38 additions & 3 deletions server/text_generation_server/layers/attention/kv_cache.py
@@ -115,6 +115,41 @@ def store(
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
else:
from text_generation_server.layers.attention import reshape_and_cache

reshape_and_cache(key, value, key_cache, value_cache, slots)
paged_reshape_and_cache(key, value, key_cache, value_cache, slots)


def paged_reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
):
if SYSTEM == "cuda":
try:
from vllm._C import cache_ops
except Exception as e:
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
)
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0
)
elif SYSTEM == "rocm":
try:
import vllm._custom_ops as ops
except Exception as e:
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
)
ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
elif SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex

ipex.llm.modules.PagedAttention.reshape_and_cache(
key, value, key_cache, value_cache, slots
)
else:
raise NotImplementedError(
f"Cannot reshape and cache for paged attention, system '{SYSTEM}' not supportedattention"
)
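A hedged usage sketch follows: after this commit, code that needs to write new key/value tokens into a paged cache can depend on the KV cache module alone rather than on a specific attention backend. The cache_new_tokens wrapper below is hypothetical, and the tensor layouts are dictated by the backend kernels (vllm on CUDA/ROCm, intel_extension_for_pytorch on XPU), which must be installed for the call to work.

# Hypothetical wrapper, not part of the diff: shows only the call site of the
# relocated helper, which now lives in kv_cache.py instead of the backends.
import torch
from text_generation_server.layers.attention.kv_cache import paged_reshape_and_cache


def cache_new_tokens(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slots: torch.Tensor,
) -> None:
    # key/value: new tokens to cache; key_cache/value_cache: preallocated paged
    # caches; slots: flat slot index per token. Layouts follow the backend kernels.
    paged_reshape_and_cache(key, value, key_cache, value_cache, slots)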
24 changes: 0 additions & 24 deletions server/text_generation_server/layers/attention/rocm.py
@@ -3,7 +3,6 @@
import torch
from text_generation_server.layers.attention.kv_cache import KVCache
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import ATTENTION
from text_generation_server.layers.attention import Seqlen
from text_generation_server.utils.log import log_master
from loguru import logger
@@ -28,28 +27,6 @@
)
use_rocm_custom_paged_attn = False

try:
import vllm._custom_ops as ops
except Exception as e:
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
)


def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
):
if ATTENTION == "flashdecoding":
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
else:
ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)


def paged_attention(
query: torch.Tensor,
@@ -302,5 +279,4 @@ def attention(
"SUPPORTS_WINDOWING",
"attention",
"paged_attention",
"reshape_and_cache",
]
