diff --git a/csrc/xpu/attention_xpu.cpp b/csrc/xpu/attention_xpu.cpp index 833f46eaaf726..24134cfafffa4 100644 --- a/csrc/xpu/attention_xpu.cpp +++ b/csrc/xpu/attention_xpu.cpp @@ -4,6 +4,7 @@ #endif #include #include +#include // clang-format on #include @@ -19,6 +20,7 @@ #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b)-1) / (b)) +using namespace sycl::ext::intel::esimd; template struct Float_Trait { @@ -139,6 +141,779 @@ inline float block_sum( item_ct1.get_sub_group(), sum, 0); } +// How about implement a first edition that can be used with non-chunked +// prefill requests, so that we can make sure the reference for heads is +// correct +template +void context_attention_kernel_v1( + void* query, void* key, void* value, const void* block_tables, + const float scale, const void* query_start_loc, const void* seq_lens, + const void* context_lens, const int block_size, + const int x, // x in kv_cache + void* out, // output + const int block_table_stride_batch, const int block_table_stride_seq, + const int query_stride_bs, const int query_stride_head, + const int query_stride_dim, const int k_cache_stride_tokens, + const int k_cache_stride_head, const int k_cache_stride_dim, + const int k_cache_stride_block_size, const int k_cache_stride_x, + const int v_cache_stride_tokens, const int v_cache_stride_head, + const int v_cache_stride_dim, const int v_cache_stride_block_size, + const int out_stride_tokens, const int out_stride_head, + const int num_queries_per_kv, const int max_input_length, + const int batch_size, const int num_heads) { + static_assert(GS * HD * sizeof(scalar_t) * 2 < 64 * 1024); + + const size_t key_slm_offset = 0; + const size_t value_slm_offset = GS * HD * sizeof(scalar_t); + sycl::queue& queue = vllm::xpu::vllmGetQueue(); + + // Get the maximum seq_lens + sycl::range<3> global_size(batch_size, num_heads, + (max_input_length + GS - 1) / GS * GS); + sycl::range<3> local_size(1, 1, GS); + + auto cgf = [&](sycl::handler& handle) { + handle.parallel_for( + sycl::nd_range<3>(global_size, local_size), + [=](sycl::nd_item<3> item) SYCL_ESIMD_KERNEL { + slm_init(); + + const size_t bsz_idx = item.get_global_id(0); + const size_t head_idx = item.get_global_id(1); + // Assuming we have 32 query head and 8 kv_heads. Then + // num_queries_per_group should be 4 For head_idx 13, then + // kv_head_idx = 13 / 4 = 3, which is correct + const size_t kv_head_idx = head_idx / num_queries_per_kv; + const int32_t seq_idx = item.get_global_id(2); + const size_t gid = item.get_group(2); + const size_t tid = item.get_local_id(2); + + // const int64_t * seq_len = (const int64_t *) seq_lens; + const int32_t* seq_len = (const int32_t*)seq_lens; + int32_t seq_bound = seq_len[bsz_idx]; + + const int32_t* query_loc = (const int32_t*)query_start_loc; + // There is a possibility that the current token index pass + // over the seq_len, therefore: token_idx is the position in + // the query + int32_t token_idx = + query_loc[bsz_idx] + std::min(seq_idx, seq_bound - 1); + + const int32_t* context_len_pointer = (const int32_t*)context_lens; + + const int* block_tables_ptr = (const int*)block_tables; + const int* block_table = + block_tables_ptr + bsz_idx * block_table_stride_batch; + // I guess this context_len should be 0... 
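+            // Note (assumption): this v1 kernel targets non-chunked prefill,
+            // where every request starts from scratch, so context_lens is
+            // expected to be all zeros here and token_position below reduces
+            // to seq_idx. The "context part" loops further down only do real
+            // work once chunked prefill / prefix caching supplies a non-zero
+            // context_len.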
+ const int32_t context_len = context_len_pointer[bsz_idx]; + + // Position in the sequence + // context + seq_idx + // const int32_t token_position = + // context_len + std::min(seq_idx, seq_bound - 1); + const int32_t token_position = context_len + seq_idx; + + // static const CONSTANT char FMT[] = + // "Invoke target function...\n "; + + // sycl::ext::oneapi::experimental::printf(FMT); + // static const CONSTANT char FMT[] = + // "GroupID = %6d bsz_idx = %6d seq_len = %6d seq_idx = + // %6d" "local_id = " + // "%6d " + // "token_idx = %6d " + // "context_len = %6d " + // "v_cache_stride_head_dim = %6d " + // "token_position = %6d\n"; + // sycl::ext::oneapi::experimental::printf( + // FMT, gid, bsz_idx, seq_bound, seq_idx, tid, + // token_idx, context_len, v_cache_stride_dim, + // token_position); + + const scalar_t* query_head = (const scalar_t*)query + + token_idx * query_stride_bs + + head_idx * query_stride_head; + // Target output + scalar_t* out_head = + (scalar_t*)out + + (query_loc[bsz_idx] + seq_idx) * out_stride_tokens + + head_idx * out_stride_head; + + int32_t context_groups = context_len / GS; + + // Each token load its query_row + simd query_row = + block_load(query_head) * scale; + simd accv = 0; + simd softmaxv = 0; + scalar_t max_attn = -sycl::detail::max_v(); + + // ################# Handle n * GS context part ###################### + int32_t n = context_len / GS; + int32_t context_offset = context_len % GS; + + for (int32_t group = 0; group < n; ++group) { + size_t target_key_position = group * GS + tid; + int which_block = target_key_position / block_size; + int which_slot = target_key_position % block_size; + + int physical_block_number = block_table[which_block]; + const scalar_t* key_head = + (const scalar_t*)key + + physical_block_number * k_cache_stride_tokens + + kv_head_idx * k_cache_stride_head + + which_slot * k_cache_stride_block_size; + for (int i = 0; i < HD / x; i++) { + // Load 8 elements, decided by x + simd key_row = + block_load(key_head + i * k_cache_stride_dim); + slm_block_store(key_slm_offset + tid * HD * sizeof(scalar_t) + + 8 * i * sizeof(scalar_t), + key_row); + } + + const scalar_t* value_head = + (const scalar_t*)value + + physical_block_number * v_cache_stride_tokens + + kv_head_idx * v_cache_stride_head + which_slot; + for (int i = 0; i < HD; i++) { + scalar_t temp_value = value_head[i * v_cache_stride_dim]; + slm_scalar_store(value_slm_offset + + tid * HD * sizeof(scalar_t) + + i * sizeof(scalar_t), + temp_value); + } + barrier(); + + // Calculate QK^T for this group... + simd attnv; +#pragma unroll + for (size_t r = 0; r < GS; ++r) { + simd key_row = slm_block_load( + key_slm_offset + r * HD * sizeof(scalar_t)); + scalar_t attn = + sycl::ext::intel::esimd::detail::sum( + query_row * key_row); + attnv[r] = attn; + } + scalar_t new_max_attn = + std::max(hmax(attnv), max_attn); + scalar_t attn_exp = exp(max_attn - new_max_attn); + accv = accv * attn_exp; + softmaxv = softmaxv * attn_exp; + max_attn = new_max_attn; + const simd attn_expv = exp(attnv - max_attn); +#pragma unorll + for (size_t r = 0; r < GS; ++r) { + simd value_row = slm_block_load( + value_slm_offset + r * HD * sizeof(scalar_t)); + accv += value_row * attn_expv[r]; + } + softmaxv += attn_expv; + barrier(); + } + + // ########## End for handling context n * GS part ########### + + // ########## Handle n * GS ################ + for (size_t group = 0; group < gid; ++group) { + // 1. 
begins to load each position's key and value + size_t target_key_position = context_len + group * GS + tid; + int which_block = target_key_position / block_size; + int which_slot = target_key_position % block_size; + + int physical_block_number = block_table[which_block]; + const scalar_t* key_head = + (const scalar_t*)key + + physical_block_number * k_cache_stride_tokens + + kv_head_idx * k_cache_stride_head + + which_slot * k_cache_stride_block_size; + for (int i = 0; i < HD / x; i++) { + // Load 8 elements + simd key_row = + block_load(key_head + i * k_cache_stride_dim); + slm_block_store(key_slm_offset + tid * HD * sizeof(scalar_t) + + 8 * i * sizeof(scalar_t), + key_row); + } + + const scalar_t* value_head = + (const scalar_t*)value + + physical_block_number * v_cache_stride_tokens + + kv_head_idx * v_cache_stride_head + which_slot; + for (int i = 0; i < HD; i++) { + scalar_t temp_value = value_head[i * v_cache_stride_dim]; + slm_scalar_store(value_slm_offset + + tid * HD * sizeof(scalar_t) + + i * sizeof(scalar_t), + temp_value); + } + barrier(); + simd attnv; +#pragma unroll + for (size_t r = 0; r < GS; ++r) { + simd key_row = slm_block_load( + key_slm_offset + r * HD * sizeof(scalar_t)); + scalar_t attn = + sycl::ext::intel::esimd::detail::sum( + query_row * key_row); + attnv[r] = attn; + } + + scalar_t new_max_attn = + std::max(hmax(attnv), max_attn); + scalar_t attn_exp = exp(max_attn - new_max_attn); + accv = accv * attn_exp; + + softmaxv = softmaxv * attn_exp; + max_attn = new_max_attn; + const simd attn_expv = exp(attnv - max_attn); +#pragma unroll + for (size_t r = 0; r < GS; ++r) { + simd value_row = slm_block_load( + value_slm_offset + r * HD * sizeof(scalar_t)); + accv += value_row * attn_expv[r]; + } + softmaxv += attn_expv; + barrier(); + } + + // ######### End of handle n * GS part ########## + + // ################ Handle offset part #################### + scalar_t softmax = + sycl::ext::intel::esimd::detail::sum( + softmaxv); + + // ########### handle context offset ############ + if (tid < context_offset) { + size_t target_key_position = n * GS + tid; + int which_block = target_key_position / block_size; + int which_slot = target_key_position % block_size; + + int physical_block_number = block_table[which_block]; + const scalar_t* key_head = + (const scalar_t*)key + + physical_block_number * k_cache_stride_tokens + + kv_head_idx * k_cache_stride_head + + which_slot * k_cache_stride_block_size; + for (int i = 0; i < HD / x; i++) { + // Load 8 elements + simd key_row = + block_load(key_head + i * k_cache_stride_dim); + slm_block_store(key_slm_offset + tid * HD * sizeof(scalar_t) + + 8 * i * sizeof(scalar_t), + key_row); + } + + const scalar_t* value_head = + (const scalar_t*)value + + physical_block_number * v_cache_stride_tokens + + kv_head_idx * v_cache_stride_head + which_slot; + for (int i = 0; i < HD; i++) { + // Seems to have an error here + scalar_t temp_value = value_head[i * v_cache_stride_dim]; + slm_scalar_store(value_slm_offset + + tid * HD * sizeof(scalar_t) + + i * sizeof(scalar_t), + temp_value); + } + } + + barrier(); + + if (token_position < seq_bound) { +#pragma unroll + for (size_t r = 0; r < context_offset; ++r) { + simd key_row = slm_block_load( + key_slm_offset + r * HD * sizeof(scalar_t)); + simd value_row = slm_block_load( + value_slm_offset + r * HD * sizeof(scalar_t)); + scalar_t attn = + sycl::ext::intel::esimd::detail::sum( + query_row * key_row); + if (attn <= max_attn) { + scalar_t attn_exp = + sycl::ext::intel::esimd::exp(attn - 
max_attn); + accv += value_row * attn_exp; + softmax += attn_exp; + } else { + scalar_t attn_exp = + sycl::ext::intel::esimd::exp(max_attn - attn); + accv = accv * attn_exp + value_row; + softmax = softmax * attn_exp + 1; + max_attn = attn; + } + } + } + barrier(); + + // ############## handle seq offset ################# + if (token_position < seq_bound) { + const int64_t which_block = + static_cast(token_position / block_size); + const int64_t which_slot = + static_cast(token_position % block_size); + + const int64_t physical_block_number = + static_cast(block_table[which_block]); + + const scalar_t* key_head = + (const scalar_t*)key + + physical_block_number * k_cache_stride_tokens + + kv_head_idx * k_cache_stride_head + + which_slot * k_cache_stride_block_size; + + for (int i = 0; i < HD / x; i++) { + // Load 8 elements + simd key_row = + block_load(key_head + i * k_cache_stride_dim); + slm_block_store(key_slm_offset + tid * HD * sizeof(scalar_t) + + 8 * i * sizeof(scalar_t), + key_row); + } + + // [num_blocks, num_kv_heads, head_size, block_size] + const scalar_t* value_head = + (const scalar_t*)value + + physical_block_number * v_cache_stride_tokens + + kv_head_idx * v_cache_stride_head + which_slot; + for (int i = 0; i < HD; i++) { + scalar_t temp_value = value_head[i * v_cache_stride_dim]; + slm_scalar_store(value_slm_offset + + tid * HD * sizeof(scalar_t) + + i * sizeof(scalar_t), + temp_value); + } + } + barrier(); + + if (token_position < seq_bound) { + for (size_t r = 0; r <= tid; ++r) { + simd key_row = slm_block_load( + key_slm_offset + r * HD * sizeof(scalar_t)); + simd value_row = slm_block_load( + value_slm_offset + r * HD * sizeof(scalar_t)); + scalar_t attn = + sycl::ext::intel::esimd::detail::sum( + query_row * key_row); + if (attn <= max_attn) { + scalar_t attn_exp = + sycl::ext::intel::esimd::exp(attn - max_attn); + accv += value_row * attn_exp; + softmax += attn_exp; + } else { + scalar_t attn_exp = + sycl::ext::intel::esimd::exp(max_attn - attn); + accv = accv * attn_exp + value_row; + softmax = softmax * attn_exp + 1; + max_attn = attn; + } + } + + if (softmax > 0) { + simd result = accv / softmax; + block_store(out_head, result); + } else { + simd result = 0; + block_store(out_head, result); + } + } + // ######## Ending of handling seq offset ########## + }); + }; + queue.submit(cgf); +} + +template +void context_attention_kernel_v2( + void* query, void* key, void* value, const void* block_tables, + const float scale, const void* query_start_loc, const void* seq_lens, + const void* context_lens, const int block_size, + const int x, // x in kv_cache + void* out, // output + const int block_table_stride_batch, const int block_table_stride_seq, + const int query_stride_bs, const int query_stride_head, + const int query_stride_dim, const int k_cache_stride_tokens, + const int k_cache_stride_head, const int k_cache_stride_dim, + const int k_cache_stride_block_size, const int k_cache_stride_x, + const int v_cache_stride_tokens, const int v_cache_stride_head, + const int v_cache_stride_dim, const int v_cache_stride_block_size, + const int out_stride_tokens, const int out_stride_head, + const int num_queries_per_kv, const int max_input_length, + const int batch_size, const int num_heads, const int num_tokens, + const int max_context_len) { + constexpr int BLOCK_SIZE = 16; + constexpr int NUM_THREADS = 128; + // Each wrap handles one context block, therefore, each thread_group_size is + // this. 
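+  // Worked example (assuming WARP_SIZE == 32, i.e. the required sub-group
+  // size, and T == half): THREAD_GROUP_SIZE = 32 / 16 = 2 threads cooperate
+  // on one token, VEC_SIZE = 16 / (2 * sizeof(half)) = 4 elements per
+  // 16-byte load, and with HD = 128 each thread owns 128 / 2 / 4 = 16 vecs.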
+ constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); + // Each query, and key thread_group loads 16 bytes + // Assume TGS=4 then 16 / 4 / sizeof(half) = 2 + constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(T)), 1); + using sycl_t = vllm::xpu::SyclTypeTrait::Type; + using Q_Vec = typename Vec::Type; + + // Assuming HD = 128, TGS = 2, then 128 / 2 / 2 = 32 + int num_vecs_per_thread = HD / THREAD_GROUP_SIZE / VEC_SIZE; + sycl_t* out_p = reinterpret_cast(out); + sycl_t* query_ptr = reinterpret_cast(query); + sycl_t* key_cache_ptr = reinterpret_cast(key); + sycl_t* value_cache_ptr = reinterpret_cast(value); + const int* query_loc_ptr = reinterpret_cast(query_start_loc); + const int* block_tables_ptr = reinterpret_cast(block_tables); + const int* context_lens_ptr = reinterpret_cast(context_lens); + const int* seq_lens_ptr = reinterpret_cast(seq_lens); + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int padded_max_context_len = + DIVIDE_ROUND_UP(max_context_len + 1 + max_input_length, BLOCK_SIZE) * BLOCK_SIZE; + int logits_size = padded_max_context_len * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * HD * sizeof(float); + // Python-side check in + // vllm.worker.worker._check_if_can_support_max_seq_len Keep that in + // sync with the logic here! + int shared_mem_size = std::max(logits_size, outputs_size); + // WARN: we have changed this... + sycl::range<3> grid(batch_size, num_heads, max_input_length); + // One work-group that is executing on the device + sycl::range<3> block(1, 1, NUM_THREADS); + sycl::queue& queue = vllm::xpu::vllmGetQueue(); + + auto cgf = [&](sycl::handler& handle) { + sycl::local_accessor dpct_local_acc_ct1( + sycl::range<1>(shared_mem_size), handle); + sycl::local_accessor q_vecs_acc_ct1( + sycl::range<1>(THREAD_GROUP_SIZE * num_vecs_per_thread), handle); + sycl::local_accessor red_smem_acc_ct1( + sycl::range<1>(2 * NUM_WARPS), handle); + + handle.parallel_for( + sycl::nd_range<3>(grid * block, block), + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { + const int bsz_idx = item_ct1.get_group(0); + const int seq_idx = item_ct1.get_group(2); + constexpr bool USE_PARTITIONING = false; + int context_len = context_lens_ptr[bsz_idx] + seq_idx; + const int seq_len = seq_lens_ptr[bsz_idx]; + uint8_t* dpct_local = dpct_local_acc_ct1.get_pointer(); + Q_Vec* q_vecs = q_vecs_acc_ct1.get_pointer(); + float* red_smem = red_smem_acc_ct1.get_pointer(); + + // output_stream << "Original context_len: " << + // context_lens_ptr[bsz_idx] << sycl::endl; output_stream << + // "Batch_idx: " << bsz_idx << " Seq_idx: " << seq_idx + // << " Context_len: " << context_len << " Original context_len: " + // << context_lens_ptr[bsz_idx] << " Seq_len: " << seq_len + // << " Max input length: " << max_input_length + // << sycl::endl; + if (context_len >= seq_len) { + return; + } + + context_len = context_len + 1; + + const int num_context_blocks = + DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int num_blocks_per_partition = num_context_blocks; + + const int start_block_idx = 0; + const int end_block_idx = + MIN(start_block_idx + num_context_blocks, num_context_blocks); + + const int num_blocks = end_block_idx - start_block_idx; + const int start_token_idx = start_block_idx * BLOCK_SIZE; + const int end_token_idx = + MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len); + const int num_tokens = end_token_idx - start_token_idx; + constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); + constexpr int NUM_THREAD_GROUPS = + 
NUM_THREADS / + THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE + constexpr int NUM_TOKENS_PER_THREAD_GROUP = + DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int thread_idx = item_ct1.get_local_id(2); + const int warp_idx = thread_idx / WARP_SIZE; + const int lane = thread_idx % WARP_SIZE; + const int head_idx = item_ct1.get_group(1); + const int num_heads = item_ct1.get_group_range(1); + const int kv_head_idx = head_idx / num_queries_per_kv; + // TODO: consider alibi_slope later + constexpr int NUM_ELEMS_PER_THREAD = HD / THREAD_GROUP_SIZE; + constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; + const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; + const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; + const sycl_t* q_ptr = + query_ptr + (query_loc_ptr[bsz_idx] + seq_idx) * query_stride_bs + + head_idx * HD; + +#pragma unroll + for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; + i += NUM_THREAD_GROUPS) { + const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; + q_vecs[thread_group_offset * NUM_VECS_PER_THREAD + i] = + *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); + } + // Loaded q_vecs + item_ct1.barrier(sycl::access::fence_space::local_space); + auto shared_mem = (char*)dpct_local; + float* logits = reinterpret_cast(shared_mem); + constexpr int x = 16 / sizeof(sycl_t); + float qk_max = -FLT_MAX; + const int* block_table = + block_tables_ptr + bsz_idx * block_table_stride_batch; + + // Loading key + for (int block_idx = start_block_idx + warp_idx; + block_idx < end_block_idx; block_idx += NUM_WARPS) { + const int64_t physical_block_number = + static_cast(block_table[block_idx]); + for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { + const int physical_block_offset = + (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int token_idx = + block_idx * BLOCK_SIZE + physical_block_offset; + + Q_Vec k_vecs[NUM_VECS_PER_THREAD]; + +#pragma unroll + for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { + const sycl_t* k_ptr = + key_cache_ptr + + physical_block_number * k_cache_stride_tokens + + kv_head_idx * k_cache_stride_head + + physical_block_offset * x; + + const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; + const int offset1 = (vec_idx * VEC_SIZE) / x; + const int offset2 = (vec_idx * VEC_SIZE) % x; + k_vecs[j] = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); + } + + // Compute dot product. + // This includes a reduction across the threads in the + // same thread group. Q_Vec_t + // q_vec_[NUM_VECS_PER_THREAD] = q_vecs + + // thread_group_offset * THREAD_GROUP_SIZE; + float qk = scale * + Qk_dot::template dot< + Q_Vec, NUM_VECS_PER_THREAD>( + q_vecs + thread_group_offset * NUM_VECS_PER_THREAD, + k_vecs, item_ct1); + + if (thread_group_offset == 0) { + // Store the partial reductions to shared memory. + // NOTE(woosuk): It is required to zero out the + // masked logits. + const bool mask = token_idx > context_len; + logits[token_idx - start_token_idx] = mask ? 0.f : qk; + qk_max = mask ? qk_max : sycl::fmax(qk_max, qk); + } + } + } +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { + /* + DPCT1096:38: The right-most dimension of the work-group used + in the SYCL kernel that calls this function may be less than + "32". The function "dpct::permute_sub_group_by_xor" may + return an unexpected result on the CPU device. 
Modify the + size of the work-group to ensure that the value of the + right-most dimension is a multiple of "32". + */ + qk_max = + sycl::fmax(qk_max, dpct::permute_sub_group_by_xor( + item_ct1.get_sub_group(), qk_max, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = qk_max; + } + item_ct1.barrier(sycl::access::fence_space::local_space); + // TODO(woosuk): Refactor this part. + // Get the max qk value for the sequence. + qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + /* + DPCT1096:39: The right-most dimension of the work-group used + in the SYCL kernel that calls this function may be less than + "32". The function "dpct::permute_sub_group_by_xor" may + return an unexpected result on the CPU device. Modify the + size of the work-group to ensure that the value of the + right-most dimension is a multiple of "32". + */ + qk_max = + sycl::fmax(qk_max, dpct::permute_sub_group_by_xor( + item_ct1.get_sub_group(), qk_max, mask)); + } + qk_max = + dpct::select_from_sub_group(item_ct1.get_sub_group(), qk_max, 0); + + // Get the sum of the exp values. + float exp_sum = 0.f; + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + float val = sycl::exp(logits[i] - qk_max); + logits[i] = val; + exp_sum += val; + } + exp_sum = + block_sum(&red_smem[NUM_WARPS], exp_sum, item_ct1); + // Compute softmax. + const float inv_sum = 1.f / (exp_sum + 1e-6f); +#pragma unroll + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + logits[i] *= inv_sum; + } + + item_ct1.barrier(sycl::access::fence_space::local_space); + constexpr int V_VEC_SIZE = MIN(16 / sizeof(sycl_t), BLOCK_SIZE); + using V_vec = typename Vec::Type; + using L_vec = typename Vec::Type; + using Float_L_vec = typename FloatVec::Type; + constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; + constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; + constexpr int NUM_ROWS_PER_THREAD = + DIVIDE_ROUND_UP(HD, NUM_ROWS_PER_ITER); + // NOTE(woosuk): We use FP32 for the accumulator for better + // accuracy. + float accs[NUM_ROWS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + accs[i] = 0.f; + } + + sycl_t zero_value; + zero(zero_value); + for (int block_idx = start_block_idx + warp_idx; + block_idx < end_block_idx; block_idx += NUM_WARPS) { + // NOTE(woosuk): The block number is stored in int32. + // However, we cast it to int64 because int32 can lead to + // overflow when this variable is multiplied by large + // numbers (e.g., kv_block_stride). + const int64_t physical_block_number = + static_cast(block_table[block_idx]); + const int physical_block_offset = + (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; + const int token_idx = + block_idx * BLOCK_SIZE + physical_block_offset; + L_vec logits_vec; + vllm::from_float( + logits_vec, *reinterpret_cast(logits + token_idx - + start_token_idx)); + + const sycl_t* v_ptr = + value_cache_ptr + + physical_block_number * v_cache_stride_tokens + + kv_head_idx * v_cache_stride_head; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = + lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HD) { + const int offset = row_idx * BLOCK_SIZE + physical_block_offset; + V_vec v_vec = *reinterpret_cast(v_ptr + offset); + if (block_idx == num_context_blocks - 1) { + // NOTE(woosuk): When v_vec contains the tokens + // that are out of the context, we should + // explicitly zero out the values since they may + // contain NaNs. 
See + // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 + sycl_t* v_vec_ptr = reinterpret_cast(&v_vec); +#pragma unroll + for (int j = 0; j < V_VEC_SIZE; j++) { + v_vec_ptr[j] = + token_idx + j < context_len ? v_vec_ptr[j] : zero_value; + } + } + accs[i] += vllm::dot(logits_vec, v_vec); + } + } + } + // Perform reduction within each warp. +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + float acc = accs[i]; +#pragma unroll + for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { + /* + DPCT1096:41: The right-most dimension of the work-group + used in the SYCL kernel that calls this function may be + less than "32". The function + "dpct::permute_sub_group_by_xor" may return an + unexpected result on the CPU device. Modify the size of + the work-group to ensure that the value of the + right-most dimension is a multiple of "32". + */ + acc += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), + acc, mask); + } + accs[i] = acc; + } + + // NOTE(woosuk): A barrier is required because the shared memory + // space for logits is reused for the output. + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // Perform reduction across warps. + float* out_smem = reinterpret_cast(shared_mem); +#pragma unroll + for (int i = NUM_WARPS; i > 1; i /= 2) { + int mid = i / 2; + // Upper warps write to shared memory. + if (warp_idx >= mid && warp_idx < i) { + float* dst = &out_smem[(warp_idx - mid) * HD]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = + lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HD && lane % NUM_V_VECS_PER_ROW == 0) { + dst[row_idx] = accs[i]; + } + } + } + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // Lower warps update the output. + if (warp_idx < mid) { + const float* src = &out_smem[warp_idx * HD]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = + lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HD && lane % NUM_V_VECS_PER_ROW == 0) { + accs[i] += src[row_idx]; + } + } + } + + item_ct1.barrier(sycl::access::fence_space::local_space); + } + + // Write the final output. 
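+          // After the tree reduction above, warp 0 holds the fully reduced
+          // per-row accumulators, so only its lanes with
+          // lane % NUM_V_VECS_PER_ROW == 0 convert each row back to sycl_t
+          // and store it at this (token, head) slot of the output.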
+ if (warp_idx == 0) { + sycl_t* out_ptr = + out_p + (query_loc_ptr[bsz_idx] + seq_idx) * out_stride_tokens + + head_idx * out_stride_head; + +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = + lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HD && lane % NUM_V_VECS_PER_ROW == 0) { + vllm::from_float(*(out_ptr + row_idx), accs[i]); + } + } + } + }); + // Each thread_group handles one token + }; + queue.submit(cgf); +} + template < typename scalar_t, typename Q_Vec_t, @@ -1251,4 +2026,199 @@ void paged_attention_v2( query.scalar_type(), "paged_attention_xpu_v2_impl", [&] { CALL_V2_LAUNCHER_BLOCK_SIZE(scalar_t); }); -} \ No newline at end of file +} + +torch::Tensor context_attention_forward_v2( + torch::Tensor query, // [num_tokens, num_kv_head, head_dim] + torch::Tensor key, // [num_tokens, num_kv_heads * head_size] + torch::Tensor value, // [num_tokens, num_kv_heads * head_size] + torch::Tensor block_tables, torch::Tensor query_start_loc, + torch::Tensor seq_lens, torch::Tensor context_lens, int max_input_length, + int max_context_length) { + // Currently, only support fp16 here + int64_t num_tokens = query.size(0); + int64_t num_heads = query.size(1); + int64_t head_dim = query.size(2); + int64_t batch_size = seq_lens.size(0); + int num_kv_heads = value.size(1); + + int key_dimension = key.dim(); + auto output = at::empty({query.size(0), query.size(1), query.size(2)}, + at::device(query.device()).dtype(query.dtype())); + + assert(key_dimension == 5); + assert(query.scalar_type() == key.scalar_type() && + query.scalar_type() == value.scalar_type()); + assert(head_dim == 128); + assert(query.scalar_type() == at::ScalarType::Half); + + int query_stride_token = query.stride(0); + int query_stride_head = query.stride(1); + int query_stride_dim = query.stride(2); + const float attn_scale = 1 / std::sqrt((float)head_dim); + + assert(num_heads % num_kv_heads == 0); + int num_queries_per_kv = num_heads / num_kv_heads; + + + // key: num_blocks, num_kv_heads, head_size // x, num_blocks, x) + // value: [num_blocks, num_kv_heads, head_size, block_dim] + int block_size = value.size(3); + // Currently, only block_size 16 is supported... 
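+  // Layout assumption: the key cache is [num_blocks, num_kv_heads,
+  // head_size / x, block_size, x] and the value cache is [num_blocks,
+  // num_kv_heads, head_size, block_size], hence block_size = value.size(3)
+  // and x = key.size(4) below.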
+ assert(block_size == 16); + int x = key.size(4); + int block_table_stride_bsz = block_tables.stride(0); + int block_table_stride_seq = block_tables.stride(1); + int k_cache_stride_token = key.stride(0); + int k_cache_stride_head = key.stride(1); + int k_cache_stride_head_dim = key.stride(2); + int k_cache_stride_block = key.stride(3); + int k_cache_stride_x = key.stride(4); + + int v_cache_stride_token = value.stride(0); + int v_cache_stride_head = value.stride(1); + int v_cache_stride_head_dim = value.stride(2); + int v_cache_stride_block = value.stride(3); + switch(head_dim) { + case 128: + vllm::context_attention_kernel_v2( + query.data_ptr(), key.data_ptr(), value.data_ptr(), + block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), + seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, + output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, + query_stride_token, query_stride_head, query_stride_dim, + k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, + k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, + v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block, + output.stride(0), output.stride(1), num_queries_per_kv, + max_input_length, batch_size, num_heads, query.size(0), + max_context_length); + break; + case 64: + vllm::context_attention_kernel_v2( + query.data_ptr(), key.data_ptr(), value.data_ptr(), + block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), + seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, + output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, + query_stride_token, query_stride_head, query_stride_dim, + k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, + k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, + v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block, + output.stride(0), output.stride(1), num_queries_per_kv, + max_input_length, batch_size, num_heads, query.size(0), + max_context_length); + break; + case 80: + vllm::context_attention_kernel_v2( + query.data_ptr(), key.data_ptr(), value.data_ptr(), + block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), + seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, + output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, + query_stride_token, query_stride_head, query_stride_dim, + k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, + k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, + v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block, + output.stride(0), output.stride(1), num_queries_per_kv, + max_input_length, batch_size, num_heads, query.size(0), + max_context_length); + break; + case 96: + vllm::context_attention_kernel_v2( + query.data_ptr(), key.data_ptr(), value.data_ptr(), + block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), + seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, + output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, + query_stride_token, query_stride_head, query_stride_dim, + k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, + k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, + v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block, + output.stride(0), output.stride(1), num_queries_per_kv, + max_input_length, batch_size, num_heads, query.size(0), + max_context_length); + break; + default: throw std::runtime_error("unsupported head_dim"); + } + return output; +} + +torch::Tensor 
context_attention_forward_v1( + torch::Tensor query, // [num_tokens, num_kv_head, head_dim] + torch::Tensor key, // [num_tokens, num_kv_heads * head_size] + torch::Tensor value, // [num_tokens, num_kv_heads * head_size] + torch::Tensor block_tables, torch::Tensor query_start_loc, + torch::Tensor seq_lens, torch::Tensor context_lens, int max_input_length, + int max_context_length) { + // Currently, only support fp16 + int64_t num_tokens = query.size(0); + int64_t num_heads = query.size(1); + int64_t head_dim = query.size(2); + int64_t batch_size = seq_lens.size(0); + int num_kv_heads = value.size(1); + + int key_dimension = key.dim(); + auto output = at::empty({query.size(0), query.size(1), query.size(2)}, + at::device(query.device()).dtype(query.dtype())); + + // key should be in shape: + // 1. [num_tokens, num_kv_head, head_dim] + assert(key_dimension == 3 or key_dimension == 5); + assert(query.scalar_type() == key.scalar_type() && + query.scalar_type() == value.scalar_type()); + assert(head_dim == 128); + assert(query.scalar_type() == at::ScalarType::Half); + + int query_stride_token = query.stride(0); + int query_stride_head = query.stride(1); + int query_stride_dim = query.stride(2); + const float attn_scale = 1 / std::sqrt((float)head_dim); + + assert(num_heads % num_kv_heads == 0); + int num_queries_per_kv = num_heads / num_kv_heads; + + // key: num_blocks, num_kv_heads, head_size // x, num_blocks, x) + // value: [num_blocks, num_kv_heads, head_size, block_dim] + int block_size = value.size(3); + int x = key.size(4); + int block_table_stride_bsz = block_tables.stride(0); + int block_table_stride_seq = block_tables.stride(1); + int k_cache_stride_token = key.stride(0); + int k_cache_stride_head = key.stride(1); + int k_cache_stride_head_dim = key.stride(2); + int k_cache_stride_block = key.stride(3); + int k_cache_stride_x = key.stride(4); + + int v_cache_stride_token = value.stride(0); + int v_cache_stride_head = value.stride(1); + int v_cache_stride_head_dim = value.stride(2); + int v_cache_stride_block = value.stride(3); + switch(head_dim) { + case 128: + vllm::context_attention_kernel_v1( + query.data_ptr(), key.data_ptr(), value.data_ptr(), + block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), + seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, + output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, + query_stride_token, query_stride_head, query_stride_dim, + k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, + k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, + v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block, + output.stride(0), output.stride(1), num_queries_per_kv, + max_input_length, batch_size, num_heads); + break; + case 64: + vllm::context_attention_kernel_v1( + query.data_ptr(), key.data_ptr(), value.data_ptr(), + block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), + seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, + output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, + query_stride_token, query_stride_head, query_stride_dim, + k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, + k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, + v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block, + output.stride(0), output.stride(1), num_queries_per_kv, + max_input_length, batch_size, num_heads); + break; + default: throw std::runtime_error("unsupported head_dim"); + } + return output; +} diff --git a/csrc/xpu/pybind.cpp 
b/csrc/xpu/pybind.cpp index 4e7f2fa6bd80f..e42ae45c6a50c 100644 --- a/csrc/xpu/pybind.cpp +++ b/csrc/xpu/pybind.cpp @@ -75,4 +75,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "awq_dequantize", &awq_dequantize, "dequant method for awq"); + + ops.def("context_attention_forward_v1", &context_attention_forward_v1, + "Context attention forward_v1"); + + ops.def("context_attention_forward_v2", &context_attention_forward_v2, + "Context attention forward_v2"); } diff --git a/csrc/xpu/xpu_ops.h b/csrc/xpu/xpu_ops.h index 6125b19ac80b5..db7ceef1da343 100644 --- a/csrc/xpu/xpu_ops.h +++ b/csrc/xpu/xpu_ops.h @@ -40,6 +40,22 @@ void paged_attention_v2( int max_context_len, const c10::optional &alibi_slopes, const std::string& kv_cache_dtype, const float kv_scale); +torch::Tensor context_attention_forward_v1( + torch::Tensor query, // [num_tokens, num_kv_head, head_dim] + torch::Tensor key, // [num_tokens, num_kv_heads * head_size] + torch::Tensor value, // [num_tokens, num_kv_heads * head_size] + torch::Tensor block_tables, torch::Tensor query_start_loc, + torch::Tensor seq_lens, torch::Tensor context_lens, int max_input_length, + int max_context_length); + +torch::Tensor context_attention_forward_v2( + torch::Tensor query, // [num_tokens, num_kv_head, head_dim] + torch::Tensor key, // [num_tokens, num_kv_heads * head_size] + torch::Tensor value, // [num_tokens, num_kv_heads * head_size] + torch::Tensor block_tables, torch::Tensor query_start_loc, + torch::Tensor seq_lens, torch::Tensor context_lens, int max_input_length, + int max_context_length); + void copy_blocks( std::vector &key_caches, std::vector &value_caches, diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index 23b3e6ec8c0cc..4bd6276bd84fd 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -10,6 +10,7 @@ AttentionMetadata, AttentionType) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) +import os _PARTITION_SIZE = 512 @@ -67,6 +68,12 @@ class IpexAttnMetadata(AttentionMetadata, PagedAttentionMetadata): seq_lens: Optional[List[int]] seqlen_q: Optional[torch.Tensor] max_seqlen: Optional[int] + query_start_loc: Optional[torch.Tensor] + context_lens: Optional[torch.Tensor] + + + _cached_prefill_metadata: Optional["IpexAttnMetadata"] = None + _cached_decode_metadata: Optional["IpexAttnMetadata"] = None def __post_init__(self): # Set during the execution of the first attention op. 
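+    # NOTE: _cached_prefill_metadata / _cached_decode_metadata (added above)
+    # memoize the sliced views that prefill_metadata / decode_metadata build
+    # below, so the per-phase metadata is constructed at most once per step.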
@@ -78,21 +85,65 @@ def __post_init__(self): @property def prefill_metadata(self) -> Optional["IpexAttnMetadata"]: - # Currently chunked prefill is not supported - if self.num_decode_tokens == 0: - assert self.num_prefills > 0 - return self + if self.num_prefills == 0: + return None - return None + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens is not None + assert self.block_tables is not None + + self._cached_prefill_metadata = IpexAttnMetadata( + is_prompt=self.is_prompt, + seqlen_q=self.seqlen_q, + max_seqlen=self.max_seqlen, + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + # max_query_len=self.max_query_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + # seq_start_loc=None, + context_lens=self.context_lens[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + ) + return self._cached_prefill_metadata @property def decode_metadata(self) -> Optional["IpexAttnMetadata"]: - # Currently chunked prefill is not supported - if self.num_prefills > 0: - assert self.num_decode_tokens == 0 + if self.num_decode_tokens == 0: return None - return self + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._cached_decode_metadata = IpexAttnMetadata( + is_prompt=self.is_prompt, + seqlen_q=self.seqlen_q, + max_seqlen=self.max_seqlen, + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + # max_query_len=None, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + # seq_start_loc=None, + context_lens=None, + block_tables=self.block_tables[self.num_prefills:], + ) + return self._cached_decode_metadata from torch.nn.functional import scaled_dot_product_attention @@ -244,56 +295,52 @@ def forward( v_scale, ) - if attn_metadata.is_prompt: - assert attn_metadata.seq_lens is not None - if (kv_cache is None or attn_metadata.block_tables.numel() == 0): + # New added code-segment + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert query.shape[0] == num_prefill_tokens + num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + + output = torch.empty_like(query) + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + # QKV for prefill. 
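+        # Assumption: the flattened batch puts all prefill tokens before all
+        # decode tokens, so plain slicing is enough to split the two phases
+        # (prepare_model_input emits slot_mapping / seq_lens in that order).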
+ query = query[:num_prefill_tokens] + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + if prefill_meta := attn_metadata.prefill_metadata: + assert prefill_meta.seq_lens is not None + if (kv_cache is None or prefill_meta.block_tables.numel() == 0): if self.num_kv_heads != self.num_heads: key = key.repeat_interleave(self.num_queries_per_kv, dim=1) value = value.repeat_interleave(self.num_queries_per_kv, dim=1) - if attn_metadata.attn_bias is None: + if prefill_meta.attn_bias is None: if self.alibi_slopes is not None: att_masks = _make_alibi_bias( self.alibi_slopes, query.dtype, - attn_metadata.seq_lens) # type: ignore + prefill_meta.seq_lens) # type: ignore elif self.sliding_window is not None: att_masks = _make_sliding_window_bias( - attn_metadata.seq_lens, self.sliding_window, + prefill_meta.seq_lens, self.sliding_window, query.dtype) # type: ignore else: - att_masks = [None] * len(attn_metadata.seq_lens) - attn_metadata.attn_bias = att_masks - - # output = torch.empty( - # (num_tokens, self.num_heads, self.head_size), - # dtype=query.dtype, - # device=query.device) - # ipex_ops.varlen_attention(query, - # key, - # value, - # output, - # attn_metadata.seqlen_q, - # attn_metadata.seqlen_q, - # attn_metadata.max_seqlen, - # attn_metadata.max_seqlen, - # pdropout=0.0, - # softmax_scale=self.scale, - # zero_tensors=False, - # is_causal=True, - # return_softmax=False, - # gen_=None) - - output = torch.empty( - (num_tokens, self.num_heads, self.head_size), - dtype=query.dtype, device=query.device) + att_masks = [None] * len(prefill_meta.seq_lens) + prefill_meta.attn_bias = att_masks query = query.movedim(0, query.dim() - 2) key = key.movedim(0, key.dim() - 2) value = value.movedim(0, value.dim() - 2) start = 0 - for seq_len, mask in zip(attn_metadata.seq_lens, - attn_metadata.attn_bias): + for seq_len, mask in zip(prefill_meta.seq_lens, + prefill_meta.attn_bias): end = start + seq_len if use_sdp_causal(self.head_size, query): import xe_addons @@ -318,16 +365,27 @@ def forward( output[start:end, :, :] = sub_out start = end else: - # prefix-enabled attention - raise RuntimeError( - "IPEX backend doesn't support prefix decoding.") + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=1) + value = value.repeat_interleave(self.num_queries_per_kv, + dim=1) + import vllm._C.ops + value = os.environ.get('USE_CONTEXT_V1') + if self.head_size == 128 and value is None: + out = vllm._C.ops.context_attention_forward_v2(query, key_cache, value_cache, prefill_meta.block_tables, prefill_meta.query_start_loc, prefill_meta.seq_lens, prefill_meta.context_lens, prefill_meta.max_seqlen, torch.amax(prefill_meta.context_lens).item()) + else: + out = vllm._C.ops.context_attention_forward_v1(query, key_cache, value_cache, prefill_meta.block_tables, prefill_meta.query_start_loc, prefill_meta.seq_lens, prefill_meta.context_lens, prefill_meta.max_seqlen, torch.amax(prefill_meta.context_lens).item()) + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out + - else: + if decode_meta := attn_metadata.decode_metadata: # Decoding run. 
- max_seq_len = attn_metadata.max_decode_seq_len - output = torch.empty_like(query) + max_seq_len = decode_meta.max_decode_seq_len + out = torch.empty_like(decode_query) block_size = value_cache.shape[3] - num_seqs, num_heads, head_size = query.shape + # print(f"In decoding, the shape is:{decode_query.shape}") + num_seqs, num_heads, head_size = decode_query.shape max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE) # NOTE(woosuk): We use a simple heuristic to decide whether to use @@ -343,14 +401,14 @@ def forward( if use_v1: # Run PagedAttention V1. ipex_ops.paged_attention_v1( - output, - query, + out, + decode_query, key_cache, value_cache, self.num_kv_heads, self.scale, - attn_metadata.block_tables, - attn_metadata.seq_lens_tensor, + decode_meta.block_tables, + decode_meta.seq_lens_tensor, block_size, max_seq_len, self.alibi_slopes, @@ -373,17 +431,17 @@ def forward( ) max_logits = torch.empty_like(exp_sums) ipex_ops.paged_attention_v2( - output, + out, exp_sums, max_logits, tmp_output, - query, + decode_query, key_cache, value_cache, self.num_kv_heads, self.scale, - attn_metadata.block_tables, - attn_metadata.seq_lens_tensor, + decode_meta.block_tables, + decode_meta.seq_lens_tensor, block_size, max_seq_len, self.alibi_slopes, @@ -391,8 +449,9 @@ def forward( k_scale, v_scale, ) + output[num_prefill_tokens:] = out - # Reshape the output tensor. + # Reshape the output tensor. return output.view(-1, self.num_heads * self.head_size) diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 025449cfe4853..5f5154dbab5e6 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from collections import defaultdict from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union, Mapping import torch @@ -26,6 +27,7 @@ _add_sampling_metadata_broadcastable_dict, _init_attn_metadata_from_tensor_dict, _init_sampling_metadata_from_tensor_dict) +from vllm.attention.backends.utils import is_block_tables_empty if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -237,35 +239,210 @@ def prepare_model_input( virtual_engine: int = 0, finished_requests_ids: Optional[List[str]] = None ) -> ModelInputForXPU: - multi_modal_kwargs = None - # NOTE: We assume that all sequences in the group are all prompts or - # all decodes. - is_prompt = seq_group_metadata_list[0].is_prompt - # Prepare input tensors. 
- if is_prompt: - (input_tokens, input_positions, attn_metadata, seq_lens, - multi_modal_kwargs - ) = self._prepare_prompt(seq_group_metadata_list) - else: - (input_tokens, input_positions, - attn_metadata) = self._prepare_decode(seq_group_metadata_list) - seq_lens = [] + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] + + seq_lens: List[int] = [] + # Prefill's seq_len: query_length + chunked_prefill length + prefill_seq_lens: List[int] = [] + decode_seq_lens: List[int] = [] + context_lens: List[int] = [] + query_lens: List[int] = [] + # One for each sequence, physical blocks + block_tables: List[List[int]] = [] + + num_prefills = 0 + num_prefill_tokens = 0 + num_decode_tokens = 0 + + if len(seq_group_metadata_list) == 0: + return None + + assert self.sliding_window is None, "TODO: support sliding window later" + + for seq_group_metadata in seq_group_metadata_list: + seq_ids = list(seq_group_metadata.seq_data.keys()) + # is_prompt indicates that it is still in prompt states + # TODO: remove this is_prompt + is_prompt = seq_group_metadata.is_prompt + + # Iterate over all the seqs in the seq_group + for seq_id in seq_ids: + # Check for prefix caching + computed_block_nums = seq_group_metadata.computed_block_nums + if (self.scheduler_config is not None + and self.scheduler_config.chunked_prefill_enabled + and not (computed_block_nums is None or computed_block_nums == [])): + raise RuntimeError("chunked prefill cannot be used with prefix caching") + seq_data = seq_group_metadata.seq_data[seq_id] + # Context_len: how many tokens that have been computed + if is_prompt: + if computed_block_nums is not None: + context_len = len(computed_block_nums) * self.block_size + else: + context_len = 0 + else: + context_len = seq_data.get_len() - 1 + + # Get tokens for this sequence + # For prefill, the seq_len will be the second one. + # For decoding, the seq_len will be the first one. 
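+                # e.g. with chunked prefill: a 10-token prompt with
+                # context_len = 8 and token_chunk_size = 4 gives
+                # seq_len = min(10, 8 + 4) = 10, so this step handles query
+                # positions [8, 10); for decode, context_len = get_len() - 1,
+                # so seq_len = get_len() and query_len = 1.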
+ seq_len = min(seq_data.get_len(), context_len + seq_group_metadata.token_chunk_size) + + if is_prompt: + tokens = seq_data.get_token_ids()[context_len: seq_len] + else: + # Last token + tokens = [seq_data.get_last_token_id()] + + # FIXME: add prefix caching + if (computed_block_nums is not None or self.scheduler_config.chunked_prefill_enabled or not is_prompt): + # Chunked prefill or decoding + # For chunked prefill, the block tables may not be None + if seq_group_metadata.block_tables is not None: + block_table = seq_group_metadata.block_tables[seq_id] + else: + block_table = [] + else: + # Prefill without chunked prefill + block_table = [] + block_tables.append(block_table) + # Total seq_lens + seq_lens.append(seq_len) + context_lens.append(context_len) + query_len = seq_len - context_len + query_lens.append(query_len) + input_tokens.extend(tokens) + input_positions.extend(list(range(context_len, seq_len))) + if is_prompt: + assert len(seq_ids) == 1 + num_prefills += 1 + num_prefill_tokens += len(tokens) + prefill_seq_lens.append(seq_len) + else: + assert query_len == 1, "Wrong query length in decoding" + num_decode_tokens += 1 + decode_seq_lens.append(seq_len) + if is_block_tables_empty(seq_group_metadata.block_tables): + slot_mapping.extend([_PAD_SLOT_ID] * seq_len) + continue + # seq_id: List[int] + block_table = seq_group_metadata.block_tables[seq_id] + + # TODO: add sliding window + for i in range(context_len, seq_len): + # if i < start_idx: + # slot_mapping.append(_PAD_SLOT_ID) + # continue + block_number = block_table[i // self.block_size] + block_offset = i % self.block_size + # slot_mapping is when we flatteren the blocks, and see which block it is located + # block_table is a logical -> to physical transition... + # i // block_size is the logical block number + slot_mapping.append(block_number * self.block_size + block_offset) + max_query_len = max(query_lens) + max_decode_seq_len = max(decode_seq_lens, default=0) + + max_block_table_len = max( + len(block_table) for block_table in block_tables) + block_tables = make_tensor_with_pad( + block_tables, + max_len=max_block_table_len, + pad=0, + dtype=torch.int, + device=self.device, + ) + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=self.device) + + # What is the usage of this seq_start_loc? + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=self.device) + + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. 
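+        # query_start_loc (built a few lines below) is the same cumsum over
+        # query_lens instead: e.g. seq_lens = [4, 6] with context_lens = [1, 2]
+        # gives query_lens = [3, 4] and query_start_loc = [0, 3, 7], which is
+        # what the prefill kernels use to locate each sequence's query rows.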
+ torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + input_tokens_tensor = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions_tensor = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) + slot_mapping_tensor = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) + + context_lens_tensor = torch.tensor(context_lens, + dtype=torch.int, + device=self.device) + query_lens_tensor = torch.tensor(query_lens, + dtype=torch.long, + device=self.device) + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=self.device) + + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=self.device) + + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + + tmp = [0] + tmp.extend(seq_lens) + seqlen = torch.tensor(tmp) + seqlen_q = torch.cumsum(seqlen, dim=0).to(device=self.device) + + # Generate attn_metadata + is_prompt = (seq_group_metadata_list[0].is_prompt + if seq_group_metadata_list else None) + attn_metadata = self.attn_backend.make_metadata( + # FIXME: Later maybe we can get rid of this parameter + is_prompt=is_prompt, #1 + num_prefills=num_prefills, # 6 + slot_mapping=slot_mapping_tensor, # 2 + num_prefill_tokens=num_prefill_tokens, # 7 + num_decode_tokens=num_decode_tokens, # 8 + seq_lens=seq_lens_tensor, # 3 + seqlen_q=seqlen_q, # 4 + # max_seqlen=max_seqlen, # 5 + max_seqlen=max(query_lens), + seq_lens_tensor=seq_lens_tensor, # 9 + # max_query_len=max_query_len, + max_decode_seq_len=max_decode_seq_len, # 10 + query_start_loc=query_start_loc, + # seq_start_loc=seq_start_loc, + context_lens=context_lens_tensor, + block_tables=block_tables if (self.scheduler_config.chunked_prefill_enabled or not is_prompt or self.cache_config.enable_prefix_caching) else torch.tensor([], device=self.device, dtype=torch.int) # 11 + ) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens, # subquery_lens is not needed if chunked prefill is not # supported. Since CPU worker doesn't support chunked prefill # just use seq_lens instead. - seq_lens, + query_lens, self.device, pin_memory=False) - - return ModelInputForXPU(input_tokens=input_tokens, - input_positions=input_positions, - attn_metadata=attn_metadata, - sampling_metadata=sampling_metadata, - multi_modal_kwargs=multi_modal_kwargs, - virtual_engine=virtual_engine) + return ModelInputForXPU( + input_tokens=input_tokens_tensor, + input_positions=input_positions_tensor, + attn_metadata=attn_metadata, + sampling_metadata=sampling_metadata, + multi_modal_kwargs=None, + virtual_engine=virtual_engine + ) def _prepare_decode( self,