Enable gemma model #68

Merged: 6 commits, Jan 2, 2025
39 changes: 33 additions & 6 deletions csrc/xpu/attention_xpu.cpp
@@ -28,6 +28,14 @@
#define DIVIDE_ROUND_UP(a, b) (((a) + (b)-1) / (b))
using namespace sycl::ext::intel::esimd;

template <typename T>
static inline T attn_softcapping(T qk, float attn_logit_softcapping) {
  // Soft-cap the attention logit: qk = cap * tanh(qk / cap), with tanh
  // written out via exponentials.
  qk = qk / attn_logit_softcapping;
  qk = (sycl::exp(qk) - sycl::exp(-qk)) / (sycl::exp(qk) + sycl::exp(-qk));
  qk = qk * attn_logit_softcapping;
  return qk;
}

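As a side note on the math: the helper above applies tanh-based soft-capping, qk = cap * tanh(qk / cap). A minimal standalone Python sketch of the same formula (not part of this diff; the 50.0 value is the usual Gemma-2 config setting):

    import math

    def soft_cap(logit: float, cap: float) -> float:
        # cap == 0.0 means soft-capping is disabled, mirroring the kernel's check.
        if cap == 0.0:
            return logit
        return cap * math.tanh(logit / cap)

    # Gemma-2 configs typically set attn_logit_softcapping = 50.0, so scores are
    # smoothly squashed into (-50, 50):
    print(soft_cap(120.0, 50.0))  # ~49.2 instead of an unbounded 120.0
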
template <typename T>
struct Float_Trait {
using Type = T;
@@ -1274,6 +1282,7 @@ void paged_attention_kernel(
const int q_stride,
const int kv_block_stride,
const int kv_head_stride,
const float attn_logit_softcapping,
const sycl::nd_item<3>& item_ct1,
uint8_t* dpct_local,
Q_Vec_t* q_vecs,
@@ -1429,6 +1438,10 @@
qk +=
(alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;

// Apply attention logit soft-capping if enabled.
if (attn_logit_softcapping != 0.0) {
qk = attn_softcapping(qk, attn_logit_softcapping);
}
if (thread_group_offset == 0) {
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
@@ -1679,6 +1692,7 @@ void paged_attention_v1_kernel(
const int q_stride,
const int kv_block_stride,
const int kv_head_stride,
const float attn_logit_softcapping,
const sycl::nd_item<3>& item_ct1,
uint8_t* dpct_local,
Q_Vec_t* q_vecs,
@@ -1705,6 +1719,7 @@ void paged_attention_v1_kernel(
q_stride,
kv_block_stride,
kv_head_stride,
attn_logit_softcapping,
item_ct1,
dpct_local,
q_vecs,
@@ -1751,6 +1766,7 @@
auto q_stride_ct10 = q_stride; \
auto kv_block_stride_ct11 = kv_block_stride; \
auto kv_head_stride_ct12 = kv_head_stride; \
auto attn_logit_softcapping_ct13 = attn_logit_softcapping; \
\
cgh.parallel_for( \
sycl::nd_range<3>(grid * block, block), \
@@ -1775,6 +1791,7 @@
q_stride_ct10, \
kv_block_stride_ct11, \
kv_head_stride_ct12, \
attn_logit_softcapping_ct13, \
item_ct1, \
dpct_local_acc_ct1.get_pointer(), \
q_vecs_acc_ct1.get_pointer(), \
@@ -1793,7 +1810,8 @@ void paged_attention_xpu_v1_impl_launcher(
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes) {
const c10::optional<torch::Tensor>& alibi_slopes,
const float attn_logit_softcapping) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@@ -1907,7 +1925,8 @@ void paged_attention_xpu_v1_impl_launcher(
block_tables, \
context_lens, \
max_context_len, \
alibi_slopes);
alibi_slopes, \
attn_logit_softcapping);

#define CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \
@@ -2101,6 +2120,7 @@ void paged_attention_v2_kernel(
const int q_stride,
const int kv_block_stride,
const int kv_head_stride,
const float attn_logit_softcapping,
const sycl::nd_item<3>& item_ct1,
uint8_t* dpct_local,
Q_Vec_t* q_vecs,
@@ -2128,6 +2148,7 @@ void paged_attention_v2_kernel(
q_stride,
kv_block_stride,
kv_head_stride,
attn_logit_softcapping,
item_ct1,
dpct_local,
q_vecs,
@@ -2157,6 +2178,7 @@
auto q_stride_ct12 = q_stride; \
auto kv_block_stride_ct13 = kv_block_stride; \
auto kv_head_stride_ct14 = kv_head_stride; \
auto attn_logit_softcapping_ct15 = attn_logit_softcapping; \
\
cgh.parallel_for( \
sycl::nd_range<3>(grid * block, block), \
@@ -2184,6 +2206,7 @@
q_stride_ct12, \
kv_block_stride_ct13, \
kv_head_stride_ct14, \
attn_logit_softcapping_ct15, \
item_ct1, \
dpct_local_acc_ct1.get_pointer(), \
q_vecs_acc_ct1.get_pointer(), \
@@ -2243,7 +2266,8 @@ void paged_attention_v2_launcher(
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes) {
const c10::optional<torch::Tensor>& alibi_slopes,
const float attn_logit_softcapping) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@@ -2402,7 +2426,8 @@ void paged_attention_v2_launcher(
block_tables, \
context_lens, \
max_context_len, \
alibi_slopes);
alibi_slopes, \
attn_logit_softcapping);

#define CALL_V2_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \
@@ -2435,7 +2460,8 @@ void paged_attention_v1(
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype,
const float kv_scale) {
const float kv_scale,
const float attn_logit_softcapping) {
VLLM_XPU_DISPATCH_FLOATING_TYPES_FLOAT_ONLY(
query.scalar_type(), "paged_attention_xpu_v1_impl", [&] {
CALL_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t);
@@ -2458,7 +2484,8 @@ void paged_attention_v2(
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype,
const float kv_scale) {
const float kv_scale,
const float attn_logit_softcapping) {
VLLM_XPU_DISPATCH_FLOATING_TYPES_FLOAT_ONLY(
query.scalar_type(), "paged_attention_xpu_v2_impl", [&] {
CALL_V2_LAUNCHER_BLOCK_SIZE(scalar_t);
4 changes: 2 additions & 2 deletions csrc/xpu/xpu_ops.h
@@ -30,15 +30,15 @@ void paged_attention_v1(
torch::Tensor &value_cache, int num_kv_heads, float scale,
torch::Tensor &block_tables, torch::Tensor &context_lens, int block_size,
int max_context_len, const c10::optional<torch::Tensor> &alibi_slopes,
const std::string& kv_cache_dtype, const float kv_scale);
const std::string& kv_cache_dtype, const float kv_scale, const float attn_logit_softcapping);

void paged_attention_v2(
torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits,
torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache,
torch::Tensor &value_cache, int num_kv_heads, float scale,
torch::Tensor &block_tables, torch::Tensor &context_lens, int block_size,
int max_context_len, const c10::optional<torch::Tensor> &alibi_slopes,
const std::string& kv_cache_dtype, const float kv_scale);
const std::string& kv_cache_dtype, const float kv_scale, const float attn_logit_softcapping);

torch::Tensor context_attention_forward_v1(
torch::Tensor query, // [num_tokens, num_kv_head, head_dim]
6 changes: 4 additions & 2 deletions vllm/_ipex_ops.py
@@ -74,6 +74,7 @@ def paged_attention_v1(
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
logits_soft_cap: float,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
@@ -86,7 +87,7 @@
key_cache.view_as(value_cache),
value_cache, num_kv_heads, scale,
block_tables, context_lens, block_size,
max_context_len, alibi_slopes, kv_cache_dtype, k_scale)
max_context_len, alibi_slopes, kv_cache_dtype, k_scale, logits_soft_cap)

@staticmethod
def paged_attention_v2(
@@ -107,6 +108,7 @@ def paged_attention_v2(
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
logits_soft_cap: float,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
@@ -120,7 +122,7 @@
key_cache.view_as(value_cache),
value_cache, num_kv_heads, scale, block_tables,
context_lens, block_size,
max_context_len, alibi_slopes,kv_cache_dtype, k_scale)
max_context_len, alibi_slopes,kv_cache_dtype, k_scale, logits_soft_cap)

@staticmethod
def rotary_embedding(
48 changes: 31 additions & 17 deletions vllm/attention/backends/ipex_attn.py
@@ -207,16 +207,17 @@ def _make_attention_mask(
return mask


def use_sdp_causal(head_dim, query_states):
def use_sdp_causal(head_dim, query_states, logits_soft_cap):
return (
head_dim in [-1, 64, 80, 96, 128] # for now
and query_states.device.type == "xpu" # GPU
(logits_soft_cap != 0 # for gemma model
or head_dim in [-1, 64, 80, 96, 128]) # for now
and query_states.device.type == "xpu" # GPU
and query_states.dtype in [torch.float, torch.half] # fp32/fp16
)

def use_gqa_kernel(num_heads, num_kv_heads):
def use_gqa_kernel(num_heads, num_kv_heads, head_size, logits_soft_cap):
kv_cache_format = os.environ.get('USE_VLLM_KVCACHE')
if kv_cache_format is None and num_heads != num_kv_heads:
if kv_cache_format is None and num_heads != num_kv_heads and head_size in [128, 96, 80, 64] and logits_soft_cap == 0:
return True
else:
return False
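To make the gating concrete, here is a small usage sketch of the two helpers above (hypothetical Gemma-2-style shapes, assuming an XPU-enabled PyTorch build, USE_VLLM_KVCACHE unset, and the helpers in scope; not part of this diff):

    import torch

    num_heads, num_kv_heads, head_size = 16, 8, 256
    logits_soft_cap = 50.0

    # Decode path: head_size 256 is outside [64, 80, 96, 128] and soft-capping is
    # enabled, so the GQA kernel is skipped and the paged_attention_v1/v2 ops
    # (which now accept attn_logit_softcapping) are used instead.
    print(use_gqa_kernel(num_heads, num_kv_heads, head_size, logits_soft_cap))  # False

    # Prefill path: a non-zero logits_soft_cap bypasses the head_dim allow-list,
    # so the SDP causal path remains eligible for fp16/fp32 queries on XPU.
    query = torch.empty(1, num_heads, head_size, dtype=torch.half, device="xpu")
    print(use_sdp_causal(head_size, query, logits_soft_cap))  # True
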
@@ -238,8 +239,6 @@ def __init__(
if blocksparse_params is not None:
raise ValueError(
"IPEX backend does not support block-sparse attention.")
if logits_soft_cap is not None:
raise ValueError("IPEX backend does not support logits_soft_cap.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
@@ -254,6 +253,9 @@
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.need_mask = (self.alibi_slopes is not None
or self.sliding_window is not None)
if logits_soft_cap is None:
logits_soft_cap = 0.0
self.logits_soft_cap = logits_soft_cap

supported_head_sizes = PagedAttention.get_supported_head_sizes()
if head_size not in supported_head_sizes:
@@ -333,8 +335,7 @@ def forward(
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)

using_gqa_kernel = use_gqa_kernel(self.num_heads, self.num_kv_heads)

using_gqa_kernel = use_gqa_kernel(self.num_heads, self.num_kv_heads, self.head_size, self.logits_soft_cap)
if kv_cache is not None:
if using_gqa_kernel:
key_cache, value_cache = self.split_kv_cache_ipexllm(
@@ -435,17 +436,27 @@ def forward(
for seq_len, mask in zip(prefill_meta.seq_lens,
prefill_meta.attn_bias):
end = start + seq_len
if self.alibi_slopes is None and use_sdp_causal(self.head_size, query):
if self.alibi_slopes is None and use_sdp_causal(self.head_size, query, self.logits_soft_cap):
import xe_addons
if mask is not None:
mask = mask.unsqueeze(0)
sub_out = xe_addons.sdp_causal(
query[None, :, start:end, :].contiguous(),
key[None, :, start:end, :].contiguous(),
value[None, :, start:end, :].contiguous(),
mask,
scale).squeeze(0).movedim(
query.dim() - 2, 0)
if self.logits_soft_cap == 0 or self.head_size != 256:
sub_out = xe_addons.sdp_causal(
query[None, :, start:end, :].contiguous(),
key[None, :, start:end, :].contiguous(),
value[None, :, start:end, :].contiguous(),
mask,
scale).squeeze(0).movedim(
query.dim() - 2, 0)
else:
sub_out = xe_addons.gemma2_sdp_causal(
query[None, :, start:end, :].contiguous(),
key[None, :, start:end, :].contiguous(),
value[None, :, start:end, :].contiguous(),
mask,
self.logits_soft_cap,
self.scale).squeeze(0).movedim(
query.dim() - 2, 0)
else:
sub_out = torch.nn.functional.scaled_dot_product_attention(
query[None, :, start:end, :],
@@ -496,6 +507,7 @@ def forward(

bsz = len(decode_meta.seq_lens)
import vllm._C.ops

if using_gqa_kernel:
block_size = value_cache.shape[2]
vllm._C.ops.paged_attention_gqa(
@@ -534,6 +546,7 @@ def forward(
self.kv_cache_dtype,
k_scale,
v_scale,
self.logits_soft_cap,
)
else:
# Run PagedAttention V2.
@@ -567,6 +580,7 @@ def forward(
self.kv_cache_dtype,
k_scale,
v_scale,
self.logits_soft_cap,
)
output[num_prefill_tokens:] = out

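End to end, the effect of the backend change is that a Gemma-2-style model no longer hits the removed ValueError. A rough construction sketch follows; the class name and keyword arguments are assumed from the __init__ shown above and from vLLM's AttentionImpl interface, and the values are hypothetical, so treat it as illustrative rather than authoritative:

    # Previously this raised "IPEX backend does not support logits_soft_cap."
    impl = IpexAttnBackendImpl(
        num_heads=16,
        head_size=256,
        scale=256 ** -0.5,
        num_kv_heads=8,
        alibi_slopes=None,
        sliding_window=None,
        kv_cache_dtype="auto",
        blocksparse_params=None,
        logits_soft_cap=50.0,  # now stored as self.logits_soft_cap and forwarded
    )                          # to the SDP kernels and paged_attention_v1/v2 ops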