Break cycle between the attention implementations and KV cache #2627

Draft · wants to merge 2 commits into main
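This PR drops the per-backend `reshape_and_cache` helpers and the `PREFILL_IN_KV_CACHE` flag and instead passes a `KVCache` object into `attention`/`paged_attention`, breaking the cycle between the attention backends and the KV-cache code. A minimal sketch of the interface the new imports appear to rely on — only the `.key`/`.value` attributes are visible in this diff; the constructor shape is an assumption for illustration:

```python
# Sketch of text_generation_server/layers/attention/kv_cache.py as assumed by
# this diff. Only `.key` and `.value` are actually referenced in the changes.
from dataclasses import dataclass

import torch


@dataclass
class KVCache:
    """Paged key/value cache passed into the attention kernels."""

    key: torch.Tensor    # e.g. [num_blocks, block_size, num_kv_heads, head_dim]
    value: torch.Tensor  # same layout as `key`
```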
8 changes: 0 additions & 8 deletions server/text_generation_server/layers/attention/__init__.py
@@ -8,27 +8,21 @@
raise ImportError("`USE_FLASH_ATTENTION` is false.")
if SYSTEM == "cuda":
from .cuda import (
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
attention,
paged_attention,
reshape_and_cache,
)
elif SYSTEM == "rocm":
from .rocm import (
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
attention,
paged_attention,
reshape_and_cache,
)
elif SYSTEM == "ipex":
from .ipex import (
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
attention,
paged_attention,
reshape_and_cache,
)
else:
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
@@ -39,8 +33,6 @@
__all__ = [
"attention",
"paged_attention",
"reshape_and_cache",
"PREFILL_IN_KV_CACHE",
"SUPPORTS_WINDOWING",
"KVCache",
"Seqlen",
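With `reshape_and_cache` and `PREFILL_IN_KV_CACHE` gone from `__all__`, downstream modules import only the unified entry points plus the cache and sequence-length types. A hedged sketch of an import site after this change (exact call sites are not shown in this diff):

```python
# What a caller imports after this change (sketch).
from text_generation_server.layers.attention import (
    KVCache,
    Seqlen,
    attention,
    paged_attention,
)
```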
151 changes: 49 additions & 102 deletions server/text_generation_server/layers/attention/cuda.py
@@ -1,4 +1,5 @@
import torch
from text_generation_server.layers.attention.kv_cache import KVCache
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import (
ATTENTION,
@@ -11,35 +12,10 @@
is_sm75 = major == 7 and minor == 5
_PARTITION_SIZE = 512

try:
from vllm._C import cache_ops
except Exception as e:
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
)


def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
):
if ATTENTION in {"flashdecoding", "flashinfer"}:
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
else:
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0
)


def paged_attention(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
kv_cache: KVCache,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
@@ -80,7 +56,7 @@ def paged_attention(

return decode_state.get().forward(
query.contiguous(),
paged_kv_cache=(key_cache, value_cache),
paged_kv_cache=(kv_cache.key, kv_cache.value),
logits_soft_cap=softcap,
sm_scale=softmax_scale,
)
@@ -98,8 +74,8 @@ def paged_attention(
softcap = 0.0
out = flash_attn_2_cuda.varlen_fwd(
query,
key_cache,
value_cache,
kv_cache.key,
kv_cache.value,
None,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
@@ -135,8 +111,8 @@ def paged_attention(
ops.paged_attention_v1(
out,
query,
key_cache,
value_cache,
kv_cache.key,
kv_cache.value,
kv_head_mapping,
softmax_scale,
block_tables,
@@ -168,8 +144,8 @@ def paged_attention(
max_logits,
tmp_output,
query,
key_cache,
value_cache,
kv_cache.key,
kv_cache.value,
kv_head_mapping,
softmax_scale,
block_tables,
@@ -218,52 +194,42 @@ def paged_attention(

SUPPORTS_WINDOWING = V2

if ATTENTION == "flashinfer":

def attention(
q: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
softcap=0.0,
):
def attention(
*,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: KVCache,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap: float = 0.0,
):
if ATTENTION == "flashinfer":
from text_generation_server.layers.attention.flashinfer import (
prefill_with_paged_kv_state,
)

return prefill_with_paged_kv_state.get().forward(
q.contiguous(),
query.contiguous(),
causal=causal,
paged_kv_cache=(key_cache, value_cache),
paged_kv_cache=(kv_cache.key, kv_cache.value),
logits_soft_cap=softcap,
sm_scale=softmax_scale,
window_left=window_size_left,
)

elif V2:

def attention(
q,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
softcap=0.0,
):
out = torch.empty_like(q)
elif V2:
out = torch.empty_like(query)
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
return flash_attn_2_cuda.varlen_fwd(
q,
key_cache,
value_cache,
query,
kv_cache.key,
kv_cache.value,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
@@ -284,19 +250,7 @@ def attention(
None,
)[0]

else:

def attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap=None,
):
else:
if window_size_left != -1:
raise NotImplementedError(
"window_size_left is only available with flash attn v2"
@@ -305,36 +259,36 @@ def attention(
raise NotImplementedError("softcap is only available with flash attn v2")

# Flash attention v1 requires q, k and v to have the same number of heads
if k.shape[1] != q.shape[1]:
if key.shape[1] != query.shape[1]:
# MQA expand
if k.shape[1] == 1:
k = k.expand(-1, q.shape[1], -1)
if key.shape[1] == 1:
key = key.expand(-1, query.shape[1], -1)
# Grouped attention reshape
else:
original_shape = k.shape
k = (
k.unsqueeze(2)
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
original_shape = key.shape
key = (
key.unsqueeze(2)
.expand(-1, -1, query.shape[1] // key.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
if v.shape[1] != q.shape[1]:
if value.shape[1] != query.shape[1]:
# MQA expand
if v.shape[1] == 1:
v = v.expand(-1, q.shape[1], -1)
if value.shape[1] == 1:
value = value.expand(-1, query.shape[1], -1)
# Grouped attention reshape
else:
original_shape = v.shape
v = (
v.unsqueeze(2)
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
original_shape = value.shape
value = (
value.unsqueeze(2)
.expand(-1, -1, query.shape[1] // value.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)

out = torch.empty_like(q)
out = torch.empty_like(query)
flash_attn_cuda.fwd(
q,
k,
v,
query,
key,
value,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
@@ -351,15 +305,8 @@ def attention(
return out


# Prefill in the cache with every kind of attention, unless we
# have a configuration that requires flash-attention v1, which
# does not support block tables.
PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2

__all__ = [
"PREFILL_IN_KV_CACHE",
"SUPPORTS_WINDOWING",
"attention",
"paged_attention",
"reshape_and_cache",
]
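The three backend-specific `attention` definitions in `cuda.py` collapse into a single keyword-only function that dispatches on `ATTENTION` and `V2` internally. A hypothetical call site under the new signature — only the keyword names are taken from this diff; the local variables (`query`, `key`, `value`, `kv_cache`, `seqlen`, `block_tables`, `softmax_scale`) are assumed to come from the surrounding model code:

```python
# Hypothetical prefill call in a model's attention layer after this change.
attn_output = attention(
    query=query,
    key=key,
    value=value,
    kv_cache=kv_cache,        # KVCache exposing .key / .value paged tensors
    seqlen=seqlen,
    block_tables=block_tables,
    softmax_scale=softmax_scale,
    window_size_left=-1,      # no sliding window
    causal=True,
    softcap=0.0,
)
```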
39 changes: 13 additions & 26 deletions server/text_generation_server/layers/attention/ipex.py
@@ -1,31 +1,33 @@
import intel_extension_for_pytorch as ipex
import torch
from text_generation_server.layers.attention.kv_cache import KVCache
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
from text_generation_server.layers.attention import Seqlen
from typing import Optional

SUPPORTS_WINDOWING = False
PREFILL_IN_KV_CACHE = False


def attention(
q: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
*,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: KVCache,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
softcap: Optional[float] = None,
):
out = torch.empty_like(q)
out = torch.empty_like(query)

# We do not need to check window_size_left (not supported) here, as it is already checked ahead of time at model load.
ipex.llm.functional.varlen_attention(
q.contiguous() if q.device.type == "xpu" else q,
key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
query.contiguous() if query.device.type == "xpu" else query,
key.contiguous() if key.device.type == "xpu" else key,
value.contiguous() if value.device.type == "xpu" else value,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
@@ -42,22 +44,9 @@ def attention(
return out


def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
):
ipex.llm.modules.PagedAttention.reshape_and_cache(
key, value, key_cache, value_cache, slots
)


def paged_attention(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
kv_cache: KVCache,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
@@ -69,8 +58,8 @@ def paged_attention(
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
out,
query,
key_cache,
value_cache,
kv_cache.key,
kv_cache.value,
kv_head_mapping,
softmax_scale,
block_tables,
@@ -83,9 +72,7 @@


__all__ = [
"PREFILL_IN_KV_CACHE",
"SUPPORTS_WINDOWING",
"attention",
"paged_attention",
"reshape_and_cache",
]
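With `reshape_and_cache` removed from both `cuda.py` and `ipex.py`, writing new tokens into the cache presumably moves behind the `KVCache` object. A hedged sketch of what a backend-dispatching store could look like — the function name and placement are assumptions; the two write paths shown are lifted from the removed helpers in this diff (the vllm `cache_ops` path used by plain paged attention on CUDA is omitted):

```python
import torch

from text_generation_server.utils.import_utils import SYSTEM


def store_kv(
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    slots: torch.Tensor,
):
    """Write new key/value tokens into the paged cache (sketch, per backend)."""
    if SYSTEM == "ipex":
        import intel_extension_for_pytorch as ipex

        # Same call the removed ipex reshape_and_cache() made.
        ipex.llm.modules.PagedAttention.reshape_and_cache(
            key, value, key_cache, value_cache, slots
        )
    else:
        # Flash-decoding / flashinfer layout from the removed CUDA helper:
        # index the flattened cache directly by slot.
        shape = key_cache.shape
        key_cache.view(-1, shape[-2], shape[-1])[slots] = key
        value_cache.view(-1, shape[-2], shape[-1])[slots] = value
```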