diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py index 399e418c464b..844b237381a0 100644 --- a/python/tvm/relax/frontend/nn/llm/kv_cache.py +++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py @@ -30,7 +30,16 @@ from tvm.target import Target from .position_embedding import llama_rope_with_position_map, switch_rope_freq_func -from .tree_attn import tree_attn, tree_attn_with_paged_kv_cache +from .tree_attn import ( + tree_attn, + tree_attn_cpu, + tree_attn_with_paged_kv_cache, + tree_attn_with_paged_kv_cache_cpu, +) + + +def _var_cpu(dtype): + return T.alloc_buffer((1,), dtype) def get_max_num_threads_per_block(target: Target) -> int: @@ -371,23 +380,230 @@ def __init__( # pylint: disable=too-many-locals # pylint: disable=line-too-long # fmt: off bb.add_func(_kv_cache_transpose_append(num_key_value_heads, head_dim, dtype), "kv_cache_transpose_append"), - bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, head_dim, dtype, False, rope_scaling, target), "tir_attention_prefill"), - bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, head_dim, dtype, False, rope_scaling, target), "tir_attention_decode"), - bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, head_dim, dtype, True, rope_scaling, target), "tir_attention_prefill_sliding_window"), - bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, head_dim, dtype, True, rope_scaling, target), "tir_attention_decode_sliding_window"), - bb.add_func(_attention_prefill_ragged(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_ragged"), - bb.add_func(_merge_state_inplace(num_attention_heads, head_dim, dtype, target), "tir_attention_merge_state"), - bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, rope_scaling, rotary_dim), "tir_split_rotary"), - bb.add_func(_copy_single_page(num_key_value_heads, page_size, head_dim, dtype, target), "kv_cache_copy_single_page"), - bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"), - bb.add_func(_compact_kv_copy(num_key_value_heads, head_dim, dtype, target), "kv_cache_compact_kv_copy"), - bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask"), - bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache"), - rope_ext_factors, - rx.PrimValue(enable_disaggregation), # fmt: on # pylint: enable=line-too-long ] + + if str(target.kind) == "llvm": + args.extend( + [ + bb.add_func( + _attention_prefill_cpu( + num_key_value_heads, + num_attention_heads, + head_dim, + dtype, + False, + rope_scaling, + ), + "tir_attention_prefill_cpu", + ), + bb.add_func( + _attention_decode_cpu( + num_key_value_heads, + num_attention_heads, + head_dim, + dtype, + False, + rope_scaling, + ), + "tir_attention_decode_cpu", + ), + bb.add_func( + _attention_prefill_cpu( + num_key_value_heads, + num_attention_heads, + head_dim, + dtype, + True, + rope_scaling, + ), + "tir_attention_prefill_cpu_sliding_window", + ), + bb.add_func( + _attention_decode_cpu( + num_key_value_heads, + num_attention_heads, + head_dim, + dtype, + True, + rope_scaling, + ), + "tir_attention_decode_cpu_sliding_window", + ), + bb.add_func( + 
_attention_prefill_ragged_cpu( + num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling + ), + "tir_attention_prefill_ragged_cpu", + ), + bb.add_func( + _merge_state_inplace_cpu(dtype), + "tir_attention_merge_state_cpu", + ), + bb.add_func( + llama_rope_with_position_map( + rope_theta, + rope_scale, + head_dim, + num_attention_heads, + num_key_value_heads, + dtype, + rope_scaling, + rotary_dim, + ), + "tir_split_rotary", + ), + bb.add_func( + _copy_single_page_cpu(num_key_value_heads, page_size, head_dim, dtype), + "kv_cache_copy_single_page_cpu", + ), + bb.add_func( + _kv_cache_debug_get_kv( + num_hidden_layers, num_key_value_heads, head_dim, dtype + ), + "kv_cache_debug_get_kv", + ), + bb.add_func( + _compact_kv_copy_cpu(num_key_value_heads, head_dim, dtype), + "kv_cache_compact_kv_copy_cpu", + ), + bb.add_func( + tree_attn_cpu( + num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling + ), + "tir_attention_prefill_with_tree_mask_cpu", + ), + bb.add_func( + tree_attn_with_paged_kv_cache_cpu( + num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling + ), + "tir_attention_prefill_with_tree_mask_with_paged_kv_cache_cpu", + ), + rope_ext_factors, + rx.PrimValue(enable_disaggregation), + ] + ) + else: + args.extend( + [ + bb.add_func( + _attention_prefill( + num_key_value_heads, + num_attention_heads, + head_dim, + dtype, + False, + rope_scaling, + target, + ), + "tir_attention_prefill", + ), + bb.add_func( + _attention_decode( + num_key_value_heads, + num_attention_heads, + head_dim, + dtype, + False, + rope_scaling, + target, + ), + "tir_attention_decode", + ), + bb.add_func( + _attention_prefill( + num_key_value_heads, + num_attention_heads, + head_dim, + dtype, + True, + rope_scaling, + target, + ), + "tir_attention_prefill_sliding_window", + ), + bb.add_func( + _attention_decode( + num_key_value_heads, + num_attention_heads, + head_dim, + dtype, + True, + rope_scaling, + target, + ), + "tir_attention_decode_sliding_window", + ), + bb.add_func( + _attention_prefill_ragged( + num_key_value_heads, + num_attention_heads, + head_dim, + dtype, + rope_scaling, + target, + ), + "tir_attention_prefill_ragged", + ), + bb.add_func( + _merge_state_inplace(num_attention_heads, head_dim, dtype, target), + "tir_attention_merge_state", + ), + bb.add_func( + llama_rope_with_position_map( + rope_theta, + rope_scale, + head_dim, + num_attention_heads, + num_key_value_heads, + dtype, + rope_scaling, + rotary_dim, + ), + "tir_split_rotary", + ), + bb.add_func( + _copy_single_page(num_key_value_heads, page_size, head_dim, dtype, target), + "kv_cache_copy_single_page", + ), + bb.add_func( + _kv_cache_debug_get_kv( + num_hidden_layers, num_key_value_heads, head_dim, dtype + ), + "kv_cache_debug_get_kv", + ), + bb.add_func( + _compact_kv_copy(num_key_value_heads, head_dim, dtype, target), + "kv_cache_compact_kv_copy", + ), + bb.add_func( + tree_attn( + num_key_value_heads, + num_attention_heads, + head_dim, + dtype, + rope_scaling, + target, + ), + "tir_attention_prefill_with_tree_mask", + ), + bb.add_func( + tree_attn_with_paged_kv_cache( + num_key_value_heads, + num_attention_heads, + head_dim, + dtype, + rope_scaling, + target, + ), + "tir_attention_prefill_with_tree_mask_with_paged_kv_cache", + ), + rope_ext_factors, + rx.PrimValue(enable_disaggregation), + ] + ) + super().__init__( _expr=rx.call_pure_packed( "vm.builtin.paged_attention_kv_cache_create_reduced", @@ -553,6 +769,161 @@ def _get_seq_offset(pos, seq_id, length_info, sliding_window): ) +def 
_attention_prefill_cpu(h_kv, h_q, d, dtype, sliding_window: bool, rope_scaling: Dict[str, Any]): + global_symbol = "batch_prefill_paged_kv_cpu" + if sliding_window: + global_symbol += "_sliding_window" + + group_size = h_q // h_kv + sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1)) + # pylint: disable=line-too-long,too-many-branches + # fmt: off + @T.prim_func + def batch_prefill_paged_kv_cpu( + _0: T.int32, # pylint: disable=unused-argument + var_q: T.handle, # [total_len, h_q, d] + var_q_indptr: T.handle, # [batch_size + 1] + var_pages: T.handle, # [max_num_pages, 2, h_kv, page_size, d] + var_page_indptr: T.handle, # [batch_size + 1] + var_page_values: T.handle, # [nnz_pages] + var_length_info: T.handle, # [b] when sliding window = False, or otherwise [3, b] + var_k_rope_pos_offset: T.handle, # [b] + var_q_rope_position: T.handle, # [total_len] + var_output: T.handle, # [total_len, h_q, d] + var_lse: T.handle, # [total_len, h_q] + causal: T.int32, + rotary_mode: T.int32, + rope_scale: T.float32, + rope_theta: T.float32, + attn_score_scaling_factor: T.float32, + ): + T.func_attr({"global_symbol": global_symbol}) + batch_size = T.int32(is_size_var=True) + total_len = T.int32(is_size_var=True) + nnz_pages = T.int32(is_size_var=True) + max_num_pages = T.int32(is_size_var=True) + q_indptr_elem_offset = T.int32(is_size_var=True) + page_indptr_elem_offset = T.int32(is_size_var=True) + page_values_elem_offset = T.int32(is_size_var=True) + k_rope_pos_offset_elem_offset = T.int32(is_size_var=True) + q_rope_position_elem_offset = T.int32(is_size_var=True) + length_info_elem_offset = T.int32(is_size_var=True) + + q = T.match_buffer(var_q, (total_len, h_q, d), dtype) + q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset) + pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, 16, d), dtype) + page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", elem_offset=page_indptr_elem_offset) + page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", elem_offset=page_values_elem_offset) + k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset) + q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", elem_offset=q_rope_position_elem_offset) + output = T.match_buffer(var_output, (total_len, h_q, d), dtype) + lse = T.match_buffer(var_lse, (total_len, h_q), "float32") # pylint: disable=unused-variable + # The length information of the sequences. + # - It is in shape `(3, batch_size)` when sliding window is enabled. + # For a sequence "i", location + # - "(0, i)" is the number of KV slots used in the last page of the seq ("last_page_len"), + # - "(1, i)" is the starting offset of the sliding window in the seq, + # - "(2, i)" is the attn sink length of the sequence. + # - It is in shape `(batch_size,)` when sliding window is disabled, + # denoting the "last_page_len". 
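+        # "kv_chunk_len" below is the number of valid KV positions of the sequence,
+        # derived from the page count and this length info. For each query token the
+        # loop then streams over those positions, keeping a running max (m_val), a
+        # running softmax denominator (d_val) and an unnormalized output (O_local),
+        # i.e. a serial, flash-attention-style online softmax.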
+ length_info = _declare_length_info(var_length_info, batch_size, sliding_window, length_info_elem_offset) + + + for h_qo in T.serial(h_q): + for b_idx in T.serial(batch_size): + with T.block("attn"): + O_local = T.alloc_buffer((d, ), "float32") + Q_local = T.alloc_buffer((d, ), "float32") + K_local = T.alloc_buffer((d, ), "float32") + V_local = T.alloc_buffer((d, ), "float32") + + kv_chunk_len = T.alloc_buffer((1, ), "int32") + + m_val = T.alloc_buffer((1, ), "float32") + new_m = T.alloc_buffer((1, ), "float32") + d_val = T.alloc_buffer((1, ), "float32") + S_val = T.alloc_buffer((1, ), "float32") + scale_O = T.alloc_buffer((1, ), "float32") + factor = T.alloc_buffer((1, ), "float32") + cur_page_indptr_begin: T.int32 = page_indptr[b_idx] + cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1] + #max_kv_len: T.int32 = max_num_pages * 16 + kv_chunk_len[0] = T.if_then_else( + cur_page_indptr_begin != cur_page_indptr_end, + _get_kv_chunk_len(cur_page_indptr_end - cur_page_indptr_begin, 16, b_idx, length_info, sliding_window), + 0 + ) + + + for q_idx in T.serial(q_indptr[b_idx + 1] - q_indptr[b_idx]): + #init m, d, O + m_val[0] = -5e4 + d_val[0] = 1.0 + for d_idx in T.serial(d): + O_local[d_idx] = 0.0 + curl_q: T.int32 = q_indptr[b_idx] + q_idx + + for d_idx in T.serial(d): + + Q_local[d_idx] = T.if_then_else( + rotary_mode == 1, + _rope(q, q_rope_position[curl_q], d, rope_theta, rope_scale, (curl_q, h_qo, d_idx), dtype, rope_scaling), + q[curl_q, h_qo, d_idx] + ) + for row_idx in T.serial(max_num_pages * 16): + if row_idx < kv_chunk_len[0]: + # seq_offset: T.int32(is_size_var=True) = _get_seq_offset(row_idx, b_idx, length_info, sliding_window) + #seq_offset: T.int32(is_size_var=True) = row_idx + page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + (_get_seq_offset(row_idx, b_idx, length_info, sliding_window) // 16)] + page_offset: T.int32(is_size_var=True) = _get_seq_offset(row_idx, b_idx, length_info, sliding_window) % 16 + + # Load KV + for d_idx in T.serial(d): + K_local[d_idx] = T.if_then_else( + rotary_mode == 1, + _rope(pages, k_rope_pos_offset[b_idx] + row_idx, d, rope_theta, rope_scale, (page_no, 0, h_qo // group_size, page_offset, d_idx), dtype, rope_scaling), + pages[page_no, 0, h_qo // group_size, page_offset, d_idx] + ) + V_local[d_idx] = pages[page_no, 1, h_qo // group_size, page_offset, d_idx] + + # Compute S + # Q[i] * K[i] * attn_score * sm_scale + S_val[0] = 0.0 + for d_idx in T.serial(d): + S_val[0] += Q_local[d_idx] * K_local[d_idx] + S_val[0] *= attn_score_scaling_factor * sm_scale + + # update m_val, d_val , O_local + if _causal_mask(causal, + row=q_idx, + col=row_idx, + kv_len=kv_chunk_len[0], + qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]): + new_m[0] = T.max(m_val[0], S_val[0]) + else: + S_val[0] = -5e4 + # update d_val + d_val[0] *= T.exp2(m_val[0] - new_m[0]) + d_val[0] += T.exp2(S_val[0] - new_m[0]) + + # restore O_local then update O_local + scale_O[0] = T.exp2(m_val[0] - new_m[0]) + m_val[0] = new_m[0] + factor[0] = T.exp2(S_val[0] - m_val[0]) + for d_idx in T.serial(d): + O_local[d_idx] = O_local[d_idx] * scale_O[d_idx] + + + for d_idx in T.serial(d): + O_local[d_idx] += V_local[d_idx] * factor[0] + # Store Output + for d_idx in T.serial(d): + O_local[d_idx] = O_local[d_idx] /d_val[0] + output[curl_q, h_qo, d_idx] = O_local[d_idx] + lse[curl_q, h_qo] = m_val[0] + T.log2(d_val[0]) + return batch_prefill_paged_kv_cpu + + def _attention_prefill( h_kv, h_q, d, dtype, sliding_window: bool, rope_scaling: Dict[str, Any], target: Target ): @@ -920,6 
+1291,189 @@ def apply_to_md(sch, block): return sch.mod["main"].with_attr("tir.is_scheduled", 1) +def _attention_decode_cpu( + num_kv_heads, + num_qo_heads, + head_dim, + qkv_dtype, + sliding_window: bool, + rope_scaling: Dict[str, Any], +): + log2e = math.log2(math.exp(1)) + H_qo = num_qo_heads + H_kv = num_kv_heads + D = head_dim + group_size = num_qo_heads // num_kv_heads + + global_symbol = "batch_decode_paged_kv_cpu" + if sliding_window: + global_symbol += "_sliding_window" + + @T.prim_func(check_well_formed=False) + def batch_decode_paged_kv( + _0: T.int32, # pylint: disable=unused-argument + Q_handle: T.handle, + pages_handle: T.handle, + page_table_indptr_handle: T.handle, + page_table_values_handle: T.handle, + var_length_info: T.handle, # [b] when sliding window = False, or otherwise [3, b] + k_rope_pos_offset_handle: T.handle, + q_rope_position_handle: T.handle, + output_handle: T.handle, + lse_handle: T.handle, + rotary_mode: T.int32, + rope_scale: T.float32, + rope_theta: T.float32, + attn_score_scaling_factor: T.float32, + ): + T.func_attr({"tir.is_scheduled": 1, "global_symbol": global_symbol}) + B = T.int32(is_size_var=True) + nnz_pages = T.int32(is_size_var=True) + max_num_pages = T.int32(is_size_var=True) + page_indptr_elem_offset = T.int32(is_size_var=True) + page_values_elem_offset = T.int32(is_size_var=True) + k_rope_pos_offset_elem_offset = T.int32(is_size_var=True) + q_rope_position_elem_offset = T.int32(is_size_var=True) + length_info_elem_offset = T.int32(is_size_var=True) + + Q = T.match_buffer(Q_handle, (B, H_qo, D), qkv_dtype) # query 值 + pages = T.match_buffer(pages_handle, (max_num_pages, 2, H_kv, 16, D), qkv_dtype) + page_table_indptr = T.match_buffer( + page_table_indptr_handle, (B + 1,), "int32", elem_offset=page_indptr_elem_offset + ) + page_table_values = T.match_buffer( + page_table_values_handle, (nnz_pages,), "int32", elem_offset=page_values_elem_offset + ) + k_rope_pos_offset = T.match_buffer( + k_rope_pos_offset_handle, (B,), "int32", elem_offset=k_rope_pos_offset_elem_offset + ) + q_rope_position = T.match_buffer( + q_rope_position_handle, (B,), "int32", elem_offset=q_rope_position_elem_offset + ) + output = T.match_buffer(output_handle, (B, H_qo, D), qkv_dtype) + lse = T.match_buffer(lse_handle, (B, H_qo), "float32") # pylint: disable=unused-variable + # The length information of the sequences. + # - It is in shape `(3, batch_size)` when sliding window is enabled. + # For a sequence "i", location + # - "(0, i)" is the number of KV slots used in the last page of the seq ("last_page_len"), + # - "(1, i)" is the starting offset of the sliding window in the seq, + # - "(2, i)" is the attn sink length of the sequence. + # - It is in shape `(batch_size,)` when sliding window is disabled, + # denoting the "last_page_len". 
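+        # Decode processes a single query token per sequence: for every query head the
+        # kernel walks the sequence's pages serially with the same running-max /
+        # running-denominator update (m_val, d_val, O_local) as the prefill kernel,
+        # normalizes once at the end, and stores the log-sum-exp in "lse".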
+ length_info = _declare_length_info( + var_length_info, B, sliding_window, length_info_elem_offset + ) + + sm_scale = 1.0 / math.sqrt(float(D)) * log2e + + for b in T.serial(B): + with T.block("attn"): + O_local = T.alloc_buffer((D,), "float32") + Q_local = T.alloc_buffer((D,), "float32") + K_local = T.alloc_buffer((D,), "float32") + V_local = T.alloc_buffer((D,), "float32") + + kv_chunk_len = T.alloc_buffer((1,), "int32") + + m_val = T.alloc_buffer((1,), "float32") + new_m = T.alloc_buffer((1,), "float32") + d_val = T.alloc_buffer((1,), "float32") + S_val = T.alloc_buffer((1,), "float32") + scale_O = T.alloc_buffer((1,), "float32") + factor = T.alloc_buffer((1,), "float32") + + cur_page_indptr_begin: T.int32 = page_table_indptr[b] + cur_page_indptr_end: T.int32 = page_table_indptr[b + 1] + + kv_chunk_len[0] = T.if_then_else( + cur_page_indptr_begin != cur_page_indptr_end, + _get_kv_chunk_len( + cur_page_indptr_end - cur_page_indptr_begin, + 16, + b, + length_info, + sliding_window, + ), + 0, + ) + + for h_qo in T.serial(H_qo): + m_val[0] = -5e4 + d_val[0] = 1.0 + + for d in T.serial(D): + O_local[d] = 0.0 + + for d in T.serial(D): + Q_local[d] = T.if_then_else( + rotary_mode == 1, + _rope( + Q, + q_rope_position[b], + head_dim, + rope_theta, + rope_scale, + (b, h_qo, d), + qkv_dtype, + rope_scaling, + ), + Q[b, h_qo, d], + ) + + for row_idx in T.serial(kv_chunk_len[0]): + seq_offset: T.int32(is_size_var=True) = _get_seq_offset( + row_idx, b, length_info, sliding_window + ) + page_no: T.int32(is_size_var=True) = page_table_values[ + cur_page_indptr_begin + (seq_offset // 16) + ] + page_offset: T.int32(is_size_var=True) = seq_offset % 16 + + for d in T.serial(D): + K_local[d] = T.if_then_else( + rotary_mode == 1, + _rope( + pages, + k_rope_pos_offset[b] + row_idx, + head_dim, + rope_theta, + rope_scale, + (page_no, 0, h_qo // group_size, page_offset, d), + qkv_dtype, + rope_scaling, + ), + pages[page_no, 0, h_qo // group_size, page_offset, d], + ) + S_val[0] = 0.0 + for d in T.serial(D): + S_val[0] += Q_local[d] * K_local[d] + S_val[0] *= attn_score_scaling_factor * sm_scale + + new_m[0] = T.max(m_val[0], S_val[0]) + d_val[0] = (d_val[0] * T.exp2(m_val[0] - new_m[0])) + T.exp2( + S_val[0] - new_m[0] + ) + + scale_O[0] = T.exp2(m_val[0] - new_m[0]) + + for d in T.serial(D): + O_local[d] = O_local[d] * scale_O[0] + + m_val[0] = new_m[0] + for d in T.serial(D): + V_local[d] = pages[page_no, 1, h_qo // group_size, page_offset, d] + + factor[0] = T.exp2(S_val[0] - m_val[0]) + for d in T.serial(D): + O_local[d] = O_local[d] + V_local[d] * factor[0] + for d in T.serial(D): + O_local[d] = O_local[d] / d_val[0] + output[b, h_qo, d] = O_local[d] + lse[b, h_qo] = m_val[0] + T.log2(d_val[0]) + + return batch_decode_paged_kv + + def _attention_decode( num_kv_heads, num_qo_heads, @@ -1179,6 +1733,47 @@ def batch_decode_paged_kv( return batch_decode_paged_kv +def _merge_state_inplace_cpu(v_dtype): + @T.prim_func + def merge_state_inplace_cpu( + v: T.handle, + s: T.handle, + v_other: T.handle, + s_other: T.handle, + ): + T.func_attr({"tir.is_scheduled": 1}) + N = T.int32(is_size_var=True) + H = T.int32(is_size_var=True) + D = T.int32(is_size_var=True) + + V = T.match_buffer(v, (N, H, D), v_dtype) + S = T.match_buffer(s, (N, H), "float32") + V_other = T.match_buffer(v_other, (N, H, D), v_dtype) + S_other = T.match_buffer(s_other, (N, H), "float32") + + for n in T.serial(N): + for h in T.serial(H): + with T.block("merge"): + s_val = _var_cpu("float32") + s_other_val = _var_cpu("float32") + s_max = 
_var_cpu("float32") + scale = _var_cpu("float32") + other_scale = _var_cpu("float32") + + s_val[0] = S[n, h] + s_other_val[0] = S_other[n, h] + s_max[0] = T.max(s_val[0], s_other_val[0]) + s_val[0] = T.exp2(s_val[0] - s_max[0]) + s_other_val[0] = T.exp2(s_other_val[0] - s_max[0]) + scale[0] = s_val[0] / (s_val[0] + s_other_val[0]) + other_scale[0] = s_other_val[0] / (s_val[0] + s_other_val[0]) + for d in T.serial(D): + V[n, h, d] = V[n, h, d] * scale[0] + V_other[n, h, d] * other_scale[0] + S[n, h] = T.log2(s_val[0] + s_other_val[0]) + s_max[0] + + return merge_state_inplace_cpu + + def _merge_state_inplace(num_heads, head_dim, v_dtype, target: Target): v_dtype_bytes = 2 VEC_SIZE = min(max(8 // v_dtype_bytes, head_dim // 32), 4) @@ -1577,6 +2172,175 @@ def apply_schedule(sch): return sch.mod["main"].with_attr("tir.is_scheduled", 1) +def _attention_prefill_ragged_cpu(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any]): + group_size = h_q // h_kv + sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1)) + + @T.prim_func + def batch_prefill_ragged_kv( # pylint: disable=too-many-branches + var_q: T.handle, # [total_len, h_q, d] + var_q_indptr: T.handle, # [batch_size + 1] + var_k: T.handle, # [total_len, h_kv, d] + var_v: T.handle, # [total_len, h_kv, d] + var_kv_indptr: T.handle, # [batch_size + 1] + var_q_rope_position: T.handle, # [total_q_len] + var_k_rope_pos_offset: T.handle, # [b] + var_output: T.handle, # [total_len, h_q, d] + var_lse: T.handle, # [total_len, h_q] + causal: T.int32, + rotary_mode: T.int32, + rope_scale: T.float32, + rope_theta: T.float32, + attn_score_scaling_factor: T.float32, + ): + batch_size = T.int32(is_size_var=True) + qo_len = T.int32(is_size_var=True) + kv_len = T.int32(is_size_var=True) + q_indptr_elem_offset = T.int32(is_size_var=True) + kv_indptr_elem_offset = T.int32(is_size_var=True) + q_rope_position_elem_offset = T.int32(is_size_var=True) + k_rope_pos_offset_elem_offset = T.int32(is_size_var=True) + + q = T.match_buffer(var_q, (qo_len, h_q, d), dtype) + q_indptr = T.match_buffer( + var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset + ) + k = T.match_buffer(var_k, (kv_len, h_kv, d), dtype) + v = T.match_buffer(var_v, (kv_len, h_kv, d), dtype) + kv_indptr = T.match_buffer( + var_kv_indptr, (batch_size + 1,), "int32", elem_offset=kv_indptr_elem_offset + ) + q_rope_position = T.match_buffer( + var_q_rope_position, (qo_len,), "int32", elem_offset=q_rope_position_elem_offset + ) + k_rope_pos_offset = T.match_buffer( + var_k_rope_pos_offset, (batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset + ) + output = T.match_buffer(var_output, (qo_len, h_q, d), dtype) + lse = T.match_buffer(var_lse, (qo_len, h_q), "float32") # pylint: disable=unused-variable + + for b in T.serial(batch_size): + with T.block("attn"): + softmax_sum = T.alloc_buffer([h_q], "float32") + m_prev = T.alloc_buffer([h_q], "float32") + m_new = T.alloc_buffer([h_q], "float32") + d_prev = T.alloc_buffer([h_q], "float32") + d_new = T.alloc_buffer([h_q], "float32") + p_sum = T.alloc_buffer([d], "float32") + max_score = T.alloc_buffer([h_q], "float32") + attention_scores = T.alloc_buffer([kv_len, h_q], "float32") + exp_scores = T.alloc_buffer([kv_len, h_q], "float32") + attention_score = T.alloc_buffer( + [ + 1, + ], + "float32", + ) + query_val = T.alloc_buffer( + [ + 1, + ], + "float32", + ) + key_val = T.alloc_buffer( + [ + 1, + ], + "float32", + ) + result = T.alloc_buffer( + [ + 1, + ], + "float32", + ) + + for q_idx in T.serial(q_indptr[b + 1] - 
q_indptr[b]): + for i in T.serial(h_q): + max_score[i] = -5e4 + m_prev[i] = -5e4 + d_prev[i] = 1.0 + + for k_idx in T.serial(kv_indptr[b + 1] - kv_indptr[b]): + for h in T.serial(h_q): + h_kv_idx = h // group_size + + if _causal_mask( + causal, + row=q_idx, + col=k_idx, + kv_len=kv_indptr[b + 1] - kv_indptr[b], + qo_len=q_indptr[b + 1] - q_indptr[b], + ): + result[0] = 0.0 + for d_idx in T.serial(d): + query_val[0] = T.if_then_else( + rotary_mode == 1, + _rope( + q, + q_rope_position[q_indptr[b] + q_idx], + d, + rope_theta, + rope_scale, + (q_indptr[b] + q_idx, h, d_idx), + dtype, + rope_scaling, + ), + q[q_indptr[b] + q_idx, h, d_idx], + ) + + key_val[0] = T.if_then_else( + rotary_mode == 1, + _rope( + k, + k_rope_pos_offset[b] + k_idx, + d, + rope_theta, + rope_scale, + (kv_indptr[b] + k_idx, h_kv_idx, d_idx), + dtype, + rope_scaling, + ), + k[kv_indptr[b] + k_idx, h_kv_idx, d_idx], + ) + + result[0] += query_val[0] * key_val[0] + attention_score[0] = ( + result[0] * sm_scale * attn_score_scaling_factor + ) + else: + attention_score[0] = -5e4 * sm_scale * attn_score_scaling_factor + attention_scores[k_idx, h] = attention_score[0] + max_score[h] = T.max(max_score[h], attention_score[0]) + m_new[h] = T.max(m_prev[h], max_score[h]) + + for h in T.serial(h_q): + d_new[h] = d_prev[h] * T.exp2(m_prev[h] - m_new[h]) + + for h in T.serial(h_q): + softmax_sum[h] = 0.0 + for k_idx in T.serial(kv_indptr[b + 1] - kv_indptr[b]): + exp_scores[k_idx, h] = T.exp2(attention_scores[k_idx, h] - m_new[h]) + softmax_sum[h] += exp_scores[k_idx, h] + d_new[h] += softmax_sum[h] + d_prev = d_new + m_prev = m_new + + for h in T.serial(h_q): + h_kv_idx = h // group_size + for i in T.serial(d): + p_sum[i] = 0.0 + for v_idx in T.serial(kv_indptr[b + 1] - kv_indptr[b]): + weight = exp_scores[v_idx, h] / d_new[h] + for i in T.serial(d): + p_sum[i] += v[kv_indptr[b] + v_idx, h_kv_idx, i] * weight + for i in T.serial(d): + output[q_indptr[b] + q_idx, h, i] = p_sum[i] + lse[q_indptr[b] + q_idx, h] = m_prev[h] + T.log2(d_prev[h]) + + return batch_prefill_ragged_kv + + def _attention_prefill_ragged(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], target: Target): # pylint: disable=line-too-long NUM_BLKS = 16 @@ -1949,6 +2713,45 @@ def apply_to_md(sch, block): return sch.mod["main"].with_attr("tir.is_scheduled", 1) +def _copy_single_page_cpu(num_heads, page_size, head_dim, dtype): + tx = 1 + + @T.prim_func + def copy_single_page_cpu( + var_pages: T.handle, + src_page_id: T.int64, + tgt_page_id: T.int64, + copy_length: T.int64, + ): + T.func_attr({"tir.is_scheduled": 1}) + num_pages = T.int32() + pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, page_size, head_dim), dtype) + + for b in T.serial((copy_length * num_heads * head_dim + tx - 1) // tx): + for t in T.serial(tx): + with T.block("copy"): + T.where(b * tx + t < copy_length * num_heads * head_dim) + vh = T.axis.spatial( + num_heads, + T.Cast("int32", (b * tx + t) // (copy_length * head_dim)), + ) + vp = T.axis.spatial( + copy_length, + (b * tx + t) % (copy_length * head_dim) // head_dim, + ) + vd = T.axis.spatial( + head_dim, + T.Cast( + "int32", + (b * tx + t) % head_dim, + ), + ) + pages[tgt_page_id, 0, vh, vp, vd] = pages[src_page_id, 0, vh, vp, vd] + pages[tgt_page_id, 1, vh, vp, vd] = pages[src_page_id, 1, vh, vp, vd] + + return copy_single_page_cpu + + def _copy_single_page(num_heads, page_size, head_dim, dtype, target: Target): tx = get_max_num_threads_per_block(target) @@ -1996,6 +2799,55 @@ def copy_single_page( return copy_single_page +def 
_compact_kv_copy_cpu(num_heads, head_dim, dtype): + tx = 8 + + @T.prim_func + def compact_kv_copy_cpu( + var_pages: T.handle, + var_copy_length_indptr: T.handle, + var_copy_src_dst_pos: T.handle, + batch_size: T.int32, + ): + T.func_attr({"tir.is_scheduled": 1}) + num_pages = T.int32() + total_copy_length = T.int32() + copy_length_indptr_elem_offset = T.int32() + copy_src_dst_pos_elem_offset = T.int32() + pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, 16, head_dim), dtype) + copy_length_indptr = T.match_buffer( + var_copy_length_indptr, + (batch_size + 1,), + "int32", + elem_offset=copy_length_indptr_elem_offset, + ) + copy_src_dst_pos = T.match_buffer( + var_copy_src_dst_pos, + (2, total_copy_length), + "int32", + elem_offset=copy_src_dst_pos_elem_offset, + ) + + with T.block("root"): + for bhd_o in T.serial((batch_size * num_heads * head_dim + tx - 1) // tx): + for bhd_i in T.serial(tx): + b: T.int32 = (bhd_o * tx + bhd_i) // (num_heads * head_dim) + h: T.int32 = (bhd_o * tx + bhd_i) // head_dim % num_heads + d: T.int32 = (bhd_o * tx + bhd_i) % head_dim + if (bhd_o * tx + bhd_i) < batch_size * num_heads * head_dim: + for i in T.serial(copy_length_indptr[b + 1] - copy_length_indptr[b]): + src_pos: T.int32 = copy_src_dst_pos[0, copy_length_indptr[b] + i] + dst_pos: T.int32 = copy_src_dst_pos[1, copy_length_indptr[b] + i] + pages[dst_pos // 16, 0, h, dst_pos % 16, d] = pages[ + src_pos // 16, 0, h, src_pos % 16, d + ] + pages[dst_pos // 16, 1, h, dst_pos % 16, d] = pages[ + src_pos // 16, 1, h, src_pos % 16, d + ] + + return compact_kv_copy_cpu + + def _compact_kv_copy(num_heads, head_dim, dtype, target: Target): tx = get_max_num_threads_per_block(target) diff --git a/python/tvm/relax/frontend/nn/llm/tree_attn.py b/python/tvm/relax/frontend/nn/llm/tree_attn.py index 9e4a7ed97e71..fa0146afb618 100644 --- a/python/tvm/relax/frontend/nn/llm/tree_attn.py +++ b/python/tvm/relax/frontend/nn/llm/tree_attn.py @@ -82,6 +82,213 @@ def _check_tree_order(tree_order_indptr, tree_order, batch, row, col, kv_len, qo ) +def _declare_length_info(var_length_info, batch_size, sliding_window, elem_offset): + return ( + T.match_buffer(var_length_info, (3, batch_size), "int32", elem_offset=elem_offset) + if sliding_window + else T.match_buffer(var_length_info, (batch_size,), "int32", elem_offset=elem_offset) + ) + + +def tree_attn_cpu(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any]): + """Generate tree attention kernel for batched tree attention. + + Parameters + ---------- + h_kv : int + Number of heads for key and value. + h_q : int + Number of heads for query. + d : int + Hidden dimension. + dtype : str + Data type. + target : Target + The target device. + + Returns + ------- + mod : tvm.IRModule + The generated IR module. 
+ """ + group_size = h_q // h_kv + sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1)) + + # fmt: off + @T.prim_func + def batch_tree_attn( # pylint: disable=too-many-branches,line-too-long + var_q: T.handle, # [total_len, h_q, d] + var_q_indptr: T.handle, # [batch_size + 1] + var_k: T.handle, # [total_len, h_kv, d] + var_v: T.handle, # [total_len, h_kv, d] + var_kv_indptr: T.handle, # [batch_size + 1], kv_indptr should be the same as q_indptr in this case + var_q_rope_position: T.handle, # [total_q_len] + var_mn_indptr: T.handle, # [batch_size + 1] + var_mask: T.handle, # [mn_indptr[batch_size]] + var_output: T.handle, # [total_len, h_q, d] + var_lse: T.handle, # [total_len, h_q] + rotary_mode: T.int32, + rope_scale: T.float32, + rope_theta: T.float32, + attn_score_scaling_factor: T.float32, + batch_size: T.int32, + ): + qo_len = T.int32(is_size_var=True) + kv_len = T.int32(is_size_var=True) + q_indptr_elem_offset = T.int32(is_size_var=True) + kv_indptr_elem_offset = T.int32(is_size_var=True) + q_rope_position_elem_offset = T.int32(is_size_var=True) + mn_indptr_elem_offset = T.int32(is_size_var=True) + mask_elem_offset = T.int32(is_size_var=True) + tree_size = T.int32(is_size_var=True) + + q = T.match_buffer(var_q, (qo_len, h_q, d), dtype) + q_indptr = T.match_buffer( + var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset + ) + k = T.match_buffer(var_k, (kv_len, h_kv, d), dtype) + v = T.match_buffer(var_v, (kv_len, h_kv, d), dtype) + kv_indptr = T.match_buffer( + var_kv_indptr, (batch_size + 1,), "int32", elem_offset=kv_indptr_elem_offset + ) + q_rope_position = T.match_buffer( + var_q_rope_position, (qo_len,), "int32", elem_offset=q_rope_position_elem_offset + ) + mn_indptr = T.match_buffer( + var_mn_indptr, (batch_size + 1,), "int32", elem_offset=mn_indptr_elem_offset + ) + mask = T.match_buffer(var_mask, (tree_size, 2), "int32", elem_offset=mask_elem_offset) + output = T.match_buffer(var_output, (qo_len, h_q, d), dtype) + lse = T.match_buffer(var_lse, (qo_len, h_q), "float32") # pylint: disable=unused-variable + + for b in T.serial(batch_size): + with T.block("attn"): + + softmax_sum = T.alloc_buffer([h_q], "float32") + m_prev = T.alloc_buffer([h_q], "float32") + m_new = T.alloc_buffer([h_q], "float32") + d_prev = T.alloc_buffer([h_q], "float32") + d_new = T.alloc_buffer([h_q], "float32") + p_sum = T.alloc_buffer([d], "float32") + + max_score = T.alloc_buffer([h_q], "float32") + attention_scores = T.alloc_buffer([kv_len, h_q], "float32") + exp_scores = T.alloc_buffer([kv_len, h_q], "float32") + attention_score = T.alloc_buffer( + [ + 1, + ], + "float32", + ) + query_val = T.alloc_buffer( + [ + 1, + ], + "float32", + ) + key_val = T.alloc_buffer( + [ + 1, + ], + "float32", + ) + result = T.alloc_buffer( + [ + 1, + ], + "float32", + ) + + for q_idx in T.serial(q_indptr[b + 1] - q_indptr[b]): + for i in T.serial(h_q): + max_score[i] = -5e4 + m_prev[i] = -5e4 + d_prev[i] = 1.0 + + for k_idx in T.serial(kv_indptr[b + 1] - kv_indptr[b]): + for h in T.serial(h_q): + h_kv_idx = h // group_size + + if _check_tree_order( + row=q_idx, + col=k_idx, + batch=b, + tree_order=mask, + tree_order_indptr=mn_indptr, + kv_len=kv_indptr[b + 1] - kv_indptr[b], + qo_len=q_indptr[b + 1] - q_indptr[b], + ): + result[0] = 0.0 + for d_idx in T.serial(d): + query_val[0] = T.if_then_else( + rotary_mode == 1, + _rope( + q, + q_rope_position[q_indptr[b] + q_idx], + d, + rope_theta, + rope_scale, + (q_indptr[b] + q_idx, h, d_idx), + dtype, + rope_scaling, + ), + q[q_indptr[b] + q_idx, 
h, d_idx], + ) + + key_val[0] = T.if_then_else( + rotary_mode == 1, + _rope( + k, + q_rope_position[kv_indptr[b] + k_idx], + d, + rope_theta, + rope_scale, + (kv_indptr[b] + k_idx, h_kv_idx, d_idx), + dtype, + rope_scaling, + ), + k[kv_indptr[b] + k_idx, h_kv_idx, d_idx], + ) + + result[0] += query_val[0] * key_val[0] + attention_score[0] = ( + result[0] * sm_scale * attn_score_scaling_factor + ) + else: + attention_score[0] = -5e4 * sm_scale * attn_score_scaling_factor + attention_scores[k_idx, h] = attention_score[0] + max_score[h] = T.max(max_score[h], attention_score[0]) + m_new[h] = T.max(m_prev[h], max_score[h]) + + for h in T.serial(h_q): + d_new[h] = d_prev[h] * T.exp2(m_prev[h] - m_new[h]) + + for h in T.serial(h_q): + softmax_sum[h] = 0.0 + for k_idx in T.serial(kv_indptr[b + 1] - kv_indptr[b]): + exp_scores[k_idx, h] = T.exp2(attention_scores[k_idx, h] - m_new[h]) + softmax_sum[h] += exp_scores[k_idx, h] + d_new[h] += softmax_sum[h] + d_prev = d_new + m_prev = m_new + + for h in T.serial(h_q): + h_kv_idx = h // group_size + for i in T.serial(d): + p_sum[i] = 0.0 + for v_idx in T.serial(kv_indptr[b + 1] - kv_indptr[b]): + weight = exp_scores[v_idx, h] / d_new[h] + for i in T.serial(d): + p_sum[i] += v[kv_indptr[b] + v_idx, h_kv_idx, i] * weight + for i in T.serial(d): + output[q_indptr[b] + q_idx, h, i] = p_sum[i] + lse[q_indptr[b] + q_idx, h] = m_prev[h] + T.log2(d_prev[h]) + + # fmt: on + # pylint: enable=line-too-long,too-many-branches + return batch_tree_attn + + def tree_attn( h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], target: Target ): # pylint: disable=unused-argument @@ -437,6 +644,204 @@ def apply_to_md(sch, block): return sch.mod["main"].with_attr("tir.is_scheduled", 1) +def tree_attn_with_paged_kv_cache_cpu(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any]): + """Generate tree attention kernel for batched tree attention with paged key-value cache. + + Parameters + ---------- + h_kv : int + Number of heads for key and value. + h_q : int + Number of heads for query. + d : int + Hidden dimension. + dtype : str + Data type. + target : Target + The target device. + + Returns + ------- + mod : tvm.IRModule + The generated IR module. 
+ """ + # pylint: disable=import-outside-toplevel + from .kv_cache import ( + _declare_length_info, + _get_kv_chunk_len, + _get_seq_offset, + ) + + global_symbol = "tree_attn_paged_kv_cpu" + sliding_window = False + group_size = h_q // h_kv + sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1)) + # pylint: disable=line-too-long,too-many-branches + # fmt: off + @T.prim_func(check_well_formed=False) + def tree_attn_paged_kv_cpu( + _0: T.int32, # pylint: disable=unused-argument + var_q: T.handle, # [total_len, h_q, d] + var_q_indptr: T.handle, # [batch_size + 1] + var_pages: T.handle, # [max_num_pages, 2, h_kv, page_size, d] + var_page_indptr: T.handle, # [batch_size + 1] + var_page_values: T.handle, # [nnz_pages] + var_length_info: T.handle, # [b] when sliding window = False, or otherwise [3, b] + var_k_rope_pos_offset: T.handle, # [b] + var_q_rope_position: T.handle, # [total_len] + var_output: T.handle, # [total_len, h_q, d] + var_lse: T.handle, # [total_len, h_q] + rotary_mode: T.int32, + rope_scale: T.float32, + rope_theta: T.float32, + attn_score_scaling_factor: T.float32, + tree_order_indptr_handle: T.handle, # [batch_size + 1] + tree_order_handle: T.handle, # [total_len, 2] + ): + T.func_attr({"global_symbol": global_symbol}) + batch_size = T.int32(is_size_var=True) + total_len = T.int32(is_size_var=True) + nnz_pages = T.int32(is_size_var=True) + max_num_pages = T.int32(is_size_var=True) + q_indptr_elem_offset = T.int32(is_size_var=True) + page_indptr_elem_offset = T.int32(is_size_var=True) + page_values_elem_offset = T.int32(is_size_var=True) + k_rope_pos_offset_elem_offset = T.int32(is_size_var=True) + q_rope_position_elem_offset = T.int32(is_size_var=True) + length_info_elem_offset = T.int32(is_size_var=True) + tree_order_elem_offset = T.int32(is_size_var=True) + tree_order_indptr_elem_offset = T.int32(is_size_var=True) + + q = T.match_buffer(var_q, (total_len, h_q, d), dtype) + q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset) + pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, 16, d), dtype) + page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", elem_offset=page_indptr_elem_offset) + page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", elem_offset=page_values_elem_offset) + k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset) + q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", elem_offset=q_rope_position_elem_offset) + output = T.match_buffer(var_output, (total_len, h_q, d), dtype) + lse = T.match_buffer(var_lse, (total_len, h_q), "float32") # pylint: disable=unused-variable + tree_order_indptr = T.match_buffer( + tree_order_indptr_handle, + (batch_size + 1,), + "int32", + elem_offset=tree_order_indptr_elem_offset, + ) + total_tree_order_len = T.int32(is_size_var=True) + tree_order = T.match_buffer( + tree_order_handle, + (total_tree_order_len, 2), + "int32", + elem_offset=tree_order_elem_offset, + ) + # The length information of the sequences. + # - It is in shape `(3, batch_size)` when sliding window is enabled. + # For a sequence "i", location + # - "(0, i)" is the number of KV slots used in the last page of the seq ("last_page_len"), + # - "(1, i)" is the starting offset of the sliding window in the seq, + # - "(2, i)" is the attn sink length of the sequence. + # - It is in shape `(batch_size,)` when sliding window is disabled, + # denoting the "last_page_len". 
+ length_info = _declare_length_info(var_length_info, batch_size, sliding_window, length_info_elem_offset) + + + T.Assert( + rotary_mode == T.int32(0), "Inline rotary mode is not supported in tree attention." + ) + + for h_qo in T.serial(h_q): + for b_idx in T.serial(batch_size): + with T.block("attn"): + O_local = T.alloc_buffer((d, ), "float32") + Q_local = T.alloc_buffer((d, ), "float32") + K_local = T.alloc_buffer((d, ), "float32") + V_local = T.alloc_buffer((d, ), "float32") + + kv_chunk_len = T.alloc_buffer((1, ), "int32") + + m_val = T.alloc_buffer((1, ), "float32") + new_m = T.alloc_buffer((1, ), "float32") + d_val = T.alloc_buffer((1, ), "float32") + S_val = T.alloc_buffer((1, ), "float32") + scale_O = T.alloc_buffer((1, ), "float32") + factor = T.alloc_buffer((1, ), "float32") + cur_page_indptr_begin: T.int32 = page_indptr[b_idx] + cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1] + kv_chunk_len[0] = T.if_then_else( + cur_page_indptr_begin != cur_page_indptr_end, + _get_kv_chunk_len(cur_page_indptr_end - cur_page_indptr_begin, 16, b_idx, length_info, sliding_window), + 0 + ) + + for q_idx in T.serial(q_indptr[b_idx + 1] - q_indptr[b_idx]): + #init m, d, O + m_val[0] = -5e4 + d_val[0] = 1.0 + for d_idx in T.serial(d): + O_local[d_idx] = 0.0 + curl_q: T.int32 = q_indptr[b_idx] + q_idx + + for d_idx in T.serial(d): + Q_local[d_idx] = T.if_then_else( + rotary_mode == 1, + _rope(q, q_rope_position[curl_q], d, rope_theta, rope_scale, (curl_q, h_qo, d_idx), dtype, rope_scaling), + q[curl_q, h_qo, d_idx] + ) + for row_idx in T.serial(max_num_pages * 16): + if row_idx < kv_chunk_len[0]: + page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + (_get_seq_offset(row_idx, b_idx, length_info, sliding_window) // 16)] + page_offset: T.int32(is_size_var=True) = _get_seq_offset(row_idx, b_idx, length_info, sliding_window) % 16 + + # Load KV + for d_idx in T.serial(d): + K_local[d_idx] = T.if_then_else( + rotary_mode == 1, + _rope(pages, k_rope_pos_offset[b_idx] + row_idx, d, rope_theta, rope_scale, (page_no, 0, h_qo // group_size, page_offset, d_idx), dtype, rope_scaling), + pages[page_no, 0, h_qo // group_size, page_offset, d_idx] + ) + V_local[d_idx] = pages[page_no, 1, h_qo // group_size, page_offset, d_idx] + + # Compute S + S_val[0] = 0.0 + for d_idx in T.serial(d): + S_val[0] += Q_local[d_idx] * K_local[d_idx] + S_val[0] *= attn_score_scaling_factor * sm_scale + + # update m_val, d_val , O_local + if _check_tree_order( + tree_order_indptr=tree_order_indptr, + tree_order=tree_order, + batch=b_idx, + row=q_idx, + col=row_idx, + kv_len=kv_chunk_len[0], + qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx], + ): + new_m[0] = T.max(m_val[0], S_val[0]) + else: + S_val[0] = -5e4 + # update d_val + d_val[0] *= T.exp2(m_val[0] - new_m[0]) + d_val[0] += T.exp2(S_val[0] - new_m[0]) + + # restore O_local then update O_local + scale_O[0] = T.exp2(m_val[0] - new_m[0]) + m_val[0] = new_m[0] + factor[0] = T.exp2(S_val[0] - m_val[0]) + for d_idx in T.serial(d): + O_local[d_idx] = O_local[d_idx] * scale_O[d_idx] + + + for d_idx in T.serial(d): + O_local[d_idx] += V_local[d_idx] * factor[0] + # Store Output + for d_idx in T.serial(d): + O_local[d_idx] = O_local[d_idx] /d_val[0] + output[curl_q, h_qo, d_idx] = O_local[d_idx] + lse[curl_q, h_qo] = m_val[0] + T.log2(d_val[0]) + return tree_attn_paged_kv_cpu + + def tree_attn_with_paged_kv_cache( h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], target: Target ): diff --git a/src/runtime/cpu_device_api.cc b/src/runtime/cpu_device_api.cc index 
ccd726a6ece6..5dd470f00d79 100644 --- a/src/runtime/cpu_device_api.cc +++ b/src/runtime/cpu_device_api.cc @@ -34,6 +34,19 @@ #include #endif +#if defined(__linux__) || defined(__ANDROID__) +#include +#endif + +#ifdef _WIN32 +#include +#endif + +#if defined(__APPLE__) +#include +#include +#endif + namespace tvm { namespace runtime { class CPUDeviceAPI final : public DeviceAPI { @@ -43,6 +56,41 @@ class CPUDeviceAPI final : public DeviceAPI { if (kind == kExist) { *rv = 1; } + + switch (kind) { + case kExist: + break; + case kTotalGlobalMemory: { +#if defined(__linux__) || defined(__ANDROID__) + struct sysinfo info; + if (sysinfo(&info) == 0) { + *rv = static_cast(info.totalram) * info.mem_unit; // Convert to bytes + } else { + *rv = -1; + } +#elif defined(_WIN32) + MEMORYSTATUSEX statex; + statex.dwLength = sizeof(statex); + if (GlobalMemoryStatusEx(&statex)) { + *rv = static_cast(statex.ullTotalPhys); // Total physical memory in bytes + } else { + *rv = -1; + } +#elif defined(__APPLE__) + int64_t mem; + size_t size = sizeof(mem); + if (sysctlbyname("hw.memsize", &mem, &size, nullptr, 0) == 0) { + *rv = mem; + } else { + *rv = -1; + } +#else + *rv = -1; +#endif + } + default: + break; + } } void* AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) final { void* ptr; diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index 81c55bfcb645..963a5dc5c16a 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -1302,7 +1302,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // Create the auxiliary data manager for attention. // We only use the merged aux data for CUDA, since direct pointer // operations may have issues on other platforms. - if (device_.device_type == DLDeviceType::kDLCUDA) { + if (device_.device_type == DLDeviceType::kDLCUDA || + device_.device_type == DLDeviceType::kDLCPU) { aux_data_manager_ = std::make_unique( reserved_num_seqs, num_total_pages, prefill_chunk_size, dtype_aux_, device, preferred_host_device, copy_stream_); diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py new file mode 100644 index 000000000000..9487bbf8601a --- /dev/null +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py @@ -0,0 +1,956 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
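+"""Tests for the CPU (LLVM target) PagedKVCache TIR kernels.
+
+The *_cpu prefill/decode/ragged/tree-attention kernels, together with the CPU
+merge-state, page-copy and compact-copy helpers, are built for the local CPU
+target and exercised through the reduced PagedKVCache runtime against a
+NumPy/SciPy reference implementation.
+"""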
+import enum +import itertools +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import pytest +import scipy.special + +import tvm +import tvm.testing +from tvm import dlight as dl +from tvm.relax.frontend.nn.llm.kv_cache import ( + _attention_decode_cpu, + _attention_prefill_cpu, + _attention_prefill_ragged_cpu, + _compact_kv_copy_cpu, + _copy_single_page_cpu, + _kv_cache_debug_get_kv, + _kv_cache_transpose_append, + _merge_state_inplace_cpu, + llama_rope_with_position_map, + tree_attn_cpu, + tree_attn_with_paged_kv_cache_cpu, +) +from tvm.runtime import ShapeTuple + +reserved_nseq = 32 +maximum_total_seq_length = 2048 +prefill_chunk_size = 512 +page_size = 16 +num_layers = 4 +num_qo_heads = 32 +num_kv_heads = 4 +head_dim = None +rope_scale = 1.0 +rope_theta = 1e4 +rope_scaling = {} +dtype = None +device = tvm.cpu() + +fclear = None +fadd_sequence = None +fremove_sequence = None +ffork_sequence = None +fenable_sliding_window_for_seq = None +fpopn = None +fbegin_forward = None +fend_forward = None +fcommit_accepted_token_tree_nodes = None +fattention_with_fuse_qkv = None +fis_empty = None +fdebug_get_kv = None + +ftranspose_append = None +fcopy_cache = None +fattn_prefill = None +fattn_decode = None +fattn_prefill_sliding_window = None +fattn_decode_sliding_window = None +fattn_prefill_ragged = None +fattn_prefill_with_tree_mask = None +fattn_prefill_with_tree_mask_paged_kv_cache = None +fmerge_state = None +fsplit_rotary = None +fattention_rotary = None +fcopy_single_page = None +fcompact_copy = None + + +def set_global_func(head_dim, dtype): + global fclear, fadd_sequence, fremove_sequence, ffork_sequence, fenable_sliding_window_for_seq + global fpopn, fbegin_forward, fend_forward, fcommit_accepted_token_tree_nodes + global fattention_with_fuse_qkv, fis_empty, fdebug_get_kv + global ftranspose_append, fcopy_cache, fattn_prefill, fattn_decode + global fattn_prefill_ragged, fattn_prefill_with_tree_mask, fattn_prefill_with_tree_mask_paged_kv_cache + global fattn_prefill_sliding_window, fattn_decode_sliding_window + global fmerge_state, fsplit_rotary, fattention_rotary, fcopy_single_page, fcompact_copy + + fclear = tvm.get_global_func("vm.builtin.kv_state_clear") + fadd_sequence = tvm.get_global_func("vm.builtin.kv_state_add_sequence") + fremove_sequence = tvm.get_global_func("vm.builtin.kv_state_remove_sequence") + ffork_sequence = tvm.get_global_func("vm.builtin.kv_state_fork_sequence") + fenable_sliding_window_for_seq = tvm.get_global_func( + "vm.builtin.attention_kv_cache_enable_sliding_window_for_seq" + ) + fpopn = tvm.get_global_func("vm.builtin.kv_state_popn") + fbegin_forward = tvm.get_global_func("vm.builtin.kv_state_begin_forward") + fend_forward = tvm.get_global_func("vm.builtin.kv_state_end_forward") + fcommit_accepted_token_tree_nodes = tvm.get_global_func( + "vm.builtin.attention_kv_cache_commit_accepted_token_tree_nodes" + ) + fattention_with_fuse_qkv = tvm.get_global_func( + "vm.builtin.attention_kv_cache_attention_with_fused_qkv" + ) + fis_empty = tvm.get_global_func("vm.builtin.attention_kv_cache_empty") + fdebug_get_kv = tvm.get_global_func("vm.builtin.attention_kv_cache_debug_get_kv") + + target = tvm.target.Target.from_device(device) + builts = [] + for tir_func in [ + _kv_cache_transpose_append(num_kv_heads, head_dim, dtype), + _kv_cache_debug_get_kv(num_layers, num_kv_heads, head_dim, dtype), + _attention_prefill_cpu(num_kv_heads, num_qo_heads, head_dim, dtype, False, rope_scaling), + _attention_decode_cpu(num_kv_heads, num_qo_heads, 
head_dim, dtype, False, rope_scaling), + _attention_prefill_cpu(num_kv_heads, num_qo_heads, head_dim, dtype, True, rope_scaling), + _attention_decode_cpu(num_kv_heads, num_qo_heads, head_dim, dtype, True, rope_scaling), + _attention_prefill_ragged_cpu(num_kv_heads, num_qo_heads, head_dim, dtype, rope_scaling), + tree_attn_cpu(num_kv_heads, num_qo_heads, head_dim, dtype, rope_scaling), + tree_attn_with_paged_kv_cache_cpu( + num_kv_heads, num_qo_heads, head_dim, dtype, rope_scaling + ), + _merge_state_inplace_cpu(dtype), + llama_rope_with_position_map( + rope_theta, rope_scale, head_dim, num_qo_heads, num_kv_heads, dtype, rope_scaling + ), + _copy_single_page_cpu(num_kv_heads, page_size, head_dim, dtype), + _compact_kv_copy_cpu(num_kv_heads, head_dim, dtype), + ]: + mod = tvm.IRModule({"main": tir_func}) + with target: + mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod) + f = tvm.build(mod["main"], target=target) + builts.append(f.entry_func) + + ( + ftranspose_append, + fcopy_cache, + fattn_prefill, + fattn_decode, + fattn_prefill_sliding_window, + fattn_decode_sliding_window, + fattn_prefill_ragged, + fattn_prefill_with_tree_mask, + fattn_prefill_with_tree_mask_paged_kv_cache, + fmerge_state, + fsplit_rotary, + fcopy_single_page, + fcompact_copy, + ) = builts + + +def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window): + fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create_reduced") + cache = fcreate( + tvm.runtime.ShapeTuple( + [ + reserved_nseq, + maximum_total_seq_length, + prefill_chunk_size, + page_size, + int(support_sliding_window), + ] + ), + tvm.runtime.ShapeTuple([0, num_layers]), + num_qo_heads, + num_kv_heads, + head_dim, + rope_mode, + rope_scale, + rope_theta, + tvm.nd.empty((), dtype, device=device), + ftranspose_append, + fattn_prefill, + fattn_decode, + fattn_prefill_sliding_window, + fattn_decode_sliding_window, + fattn_prefill_ragged, + fmerge_state, + fsplit_rotary, + fcopy_single_page, + fcopy_cache, + fcompact_copy, + fattn_prefill_with_tree_mask, + fattn_prefill_with_tree_mask_paged_kv_cache, + None, + False, + ) + return cache + + +class RopeMode(enum.IntEnum): + """The RoPE mode of the Paged KV cache. + If it is none, the KV cache will not apply RoPE to q and k. + If it is normal, RoPE will be applied to k before adding k to cache. + Otherwise, RoPE will be applied to q/k in attention kernel on-the-fly. 
+ """ + + NONE = 0 + NORMAL = 1 + INLINE = 2 + + +@pytest.fixture( + params=itertools.chain( + itertools.product( + [64, 128], + ["float32", "float16"], + [RopeMode.NORMAL], + [False], + ), + itertools.product( + [128], + ["float16"], + [RopeMode.NONE, RopeMode.INLINE], + [False, True], + ), + ) +) +def kv_cache_and_config(request): + global head_dim, dtype + head_dim, dtype, rope_mode, support_sliding_window = request.param + set_global_func(head_dim, dtype) + return create_kv_cache(*request.param), rope_mode, support_sliding_window + + +def verify_cached_kv(kv_cache, seq_ids, expected_k, expected_v): + for seq_id in seq_ids: + keys_expected = expected_k[seq_id] + values_expected = expected_v[seq_id] + assert keys_expected.shape == values_expected.shape + seq_length = expected_k[seq_id].shape[1] + keys = tvm.nd.empty(keys_expected.shape, dtype=dtype, device=device) + values = tvm.nd.empty(values_expected.shape, dtype=dtype, device=device) + fdebug_get_kv(kv_cache, seq_id, 0, seq_length, keys, values) + tvm.testing.assert_allclose(keys.numpy(), keys_expected, rtol=1e-3, atol=1e-3) + tvm.testing.assert_allclose(values.numpy(), values_expected, rtol=1e-3, atol=1e-3) + + +def f_apply_rotary(x, offset, scale, theta, offset_list: Optional[List[int]] = None): + # x: (N, H, D) + assert len(x.shape) == 3 + nfeat = x.shape[-1] + nfeat_half = x.shape[-1] // 2 + x = x.astype("float32") + y = np.concatenate([-x[:, :, nfeat_half:], x[:, :, :nfeat_half]], axis=-1) + + inv_freq = scale / (theta ** (np.arange(0, nfeat, 2).astype("float32") / nfeat)) + t = ( + np.arange(offset, offset + x.shape[0], dtype=inv_freq.dtype) + if offset_list is None + else (np.array(offset_list, dtype=inv_freq.dtype) + offset) + ) + freqs = np.einsum("i,j->ij", t, inv_freq) + emb = np.concatenate((freqs, freqs), axis=-1) + cos_values = np.cos(emb) + sin_values = np.sin(emb) + + return np.einsum("ij,ikj->ikj", cos_values, x) + np.einsum("ij,ikj->ikj", sin_values, y) + + +def apply_attention( + kv_cache, + rope_mode: RopeMode, + batch: List[Tuple[Union[int, Tuple[int, int, int]], int]], + cached_k: Dict[int, np.ndarray], + cached_v: Dict[int, np.ndarray], + sliding_window_sizes: Optional[List[int]] = None, + attn_sink_sizes: Optional[List[int]] = None, + token_tree_parent_ptr_list: Optional[List[List[int]]] = None, + accepted_leaf_indices: Optional[List[int]] = None, +) -> None: + seq_ids = [] + append_lengths = [] + for i, (seq_id, append_length) in enumerate(batch): + fork_parent_id = None + if isinstance(seq_id, tuple): + # Fork sequence + seq_id, fork_parent_id, fork_pos = seq_id + batch[i] = (seq_id, append_length) + seq_ids.append(seq_id) + append_lengths.append(append_length) + if fork_parent_id is not None: + assert fork_parent_id in cached_k + assert seq_id not in cached_k + ffork_sequence(kv_cache, fork_parent_id, seq_id, fork_pos) + if fork_pos == -1: + cached_k[seq_id] = cached_k[fork_parent_id] + cached_v[seq_id] = cached_v[fork_parent_id] + else: + cached_k[seq_id] = cached_k[fork_parent_id][::, :fork_pos] + cached_v[seq_id] = cached_v[fork_parent_id][::, :fork_pos] + elif seq_id not in cached_k: + fadd_sequence(kv_cache, seq_id) + cached_k[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype) + cached_v[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype) + + flattened_token_tree_parent_ptr = None + token_tree_node_depths_list: List[Optional[List[int]]] = [None for _ in batch] + + if token_tree_parent_ptr_list: + assert len(token_tree_node_depths_list) == len(seq_ids) + if 
accepted_leaf_indices is not None: + assert len(accepted_leaf_indices) == len(seq_ids) + flattened_token_tree_parent_ptr = [] + for i, (token_tree_parent_ptr, append_length) in enumerate( + zip(token_tree_parent_ptr_list, append_lengths) + ): + assert len(token_tree_parent_ptr) >= append_length + # parent pointer for the last `append_length` nodes (the new tokens) + append_token_tree_parent_ptr = token_tree_parent_ptr[-append_length:] + flattened_token_tree_parent_ptr += append_token_tree_parent_ptr + token_tree_node_depths = [] + for parent in token_tree_parent_ptr: + token_tree_node_depths.append( + 0 if parent == -1 else token_tree_node_depths[parent] + 1 + ) + # depth of each node in the tree (this contains more than the last `append_length` nodes) + token_tree_node_depths_list[i] = token_tree_node_depths + + fbegin_forward( + kv_cache, + ShapeTuple(seq_ids), + ShapeTuple(append_lengths), + ( + ShapeTuple(flattened_token_tree_parent_ptr) + if flattened_token_tree_parent_ptr is not None + else None + ), + ) + + global_new_q = np.zeros((num_layers, 0, num_qo_heads, head_dim), dtype) + global_new_k = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype) + global_new_v = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype) + + q_array = [] + for i, (seq_id, append_length) in enumerate(batch): + new_q = np.random.rand(num_layers, append_length, num_qo_heads, head_dim).astype(dtype) + new_k = np.random.rand(num_layers, append_length, num_kv_heads, head_dim).astype(dtype) + new_v = np.random.rand(num_layers, append_length, num_kv_heads, head_dim).astype(dtype) + q_array.append(new_q) + + rope_offset = cached_k[seq_id].shape[1] + if token_tree_parent_ptr_list is not None: + prev_tree_size = len(token_tree_parent_ptr_list[i]) - append_length + assert prev_tree_size >= 0 + rope_offset -= prev_tree_size + cached_k[seq_id] = np.concatenate( + [ + cached_k[seq_id], + np.stack( + [ + ( + new_k[l] + if rope_mode != RopeMode.NORMAL + else f_apply_rotary( + new_k[l], + rope_offset, + rope_scale, + rope_theta, + ( + token_tree_node_depths_list[i][-append_length:] + if token_tree_node_depths_list[i] is not None + else None + ), + ) + ) + for l in range(num_layers) + ], + axis=0, + ), + ], + axis=1, + ) + cached_v[seq_id] = np.concatenate([cached_v[seq_id], new_v], axis=1) + global_new_q = np.concatenate([global_new_q, new_q], axis=1) + global_new_k = np.concatenate([global_new_k, new_k], axis=1) + global_new_v = np.concatenate([global_new_v, new_v], axis=1) + + for layer_id in range(num_layers): + queries_np = global_new_q[layer_id] + keys_np = global_new_k[layer_id] + values_np = global_new_v[layer_id] + qkv = tvm.nd.array(np.concatenate([queries_np, keys_np, values_np], axis=1), device) + outputs = tvm.nd.empty(queries_np.shape, dtype, device=device) + fattention_with_fuse_qkv(kv_cache, layer_id, 1.0, qkv, outputs) + + # Compute attention expected results. 
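+        # Reference computation: apply RoPE to q/k according to rope_mode, broadcast
+        # the KV heads to the query heads (GQA), build a causal mask (plus the token
+        # tree mask when a tree is given), run a float32 softmax, and compare the
+        # result against the kernel output per sequence.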
+ outputs = np.expand_dims(outputs.numpy(), axis=0) + sum_length = 0 + for i, (seq_id, append_length) in enumerate(batch): + assert cached_k[seq_id].shape[1] == cached_v[seq_id].shape[1] >= append_length + + rope_offset = cached_k[seq_id].shape[1] + if token_tree_parent_ptr_list is not None: + rope_offset -= len(token_tree_parent_ptr_list[i]) + else: + rope_offset -= append_length + q_seq = ( + q_array[i][layer_id] + if rope_mode == RopeMode.NONE + else f_apply_rotary( + q_array[i][layer_id], + rope_offset, + rope_scale, + rope_theta, + ( + token_tree_node_depths_list[i][-append_length:] + if token_tree_node_depths_list[i] is not None + else None + ), + ) + ).transpose(1, 0, 2) + k_seq = ( + cached_k[seq_id][layer_id] + if rope_mode != RopeMode.INLINE + else f_apply_rotary( + cached_k[seq_id][layer_id], + 0, + rope_scale, + rope_theta, + ( + ( + list(range(rope_offset)) + + [depth + rope_offset for depth in token_tree_node_depths_list[i]] + ) + if token_tree_node_depths_list[i] is not None + else None + ), + ) + ).transpose(1, 2, 0) + v_seq = cached_v[seq_id][layer_id].transpose(1, 0, 2) + + k_seq = np.repeat(k_seq, num_qo_heads // num_kv_heads, axis=0) + v_seq = np.repeat(v_seq, num_qo_heads // num_kv_heads, axis=0) + softmax_input = (q_seq.astype("float32") @ k_seq.astype("float32")) / np.sqrt(head_dim) + softmax_shape = softmax_input.shape + assert softmax_shape[-2] == append_length + length_diff = softmax_shape[-1] - softmax_shape[-2] + assert length_diff >= 0 + mask = np.tril( + np.full_like(softmax_input, np.finfo("float32").max), k=length_diff + ) + np.triu(np.full_like(softmax_input, np.finfo("float32").min), k=length_diff + 1) + if token_tree_parent_ptr_list is not None: + tree_size = len(token_tree_parent_ptr_list[i]) + tree_mask = np.full( + (tree_size, tree_size), np.finfo("float32").min, dtype="float32" + ) + for i, parent in enumerate(token_tree_parent_ptr_list[i]): + if parent != -1: + tree_mask[i] = tree_mask[parent] + tree_mask[i, i] = np.finfo("float32").max + tree_mask = np.broadcast_to(tree_mask, (num_qo_heads, *tree_mask.shape)) + mask[:, :, -tree_size:] = tree_mask[:, -append_length:, :] + + softmax_input = np.minimum(softmax_input, mask) + + results = np.expand_dims( + (scipy.special.softmax(softmax_input, axis=-1) @ v_seq.astype("float32")).transpose( + 1, 0, 2 + ), + axis=0, + ).astype(dtype) + + tvm.testing.assert_allclose( + outputs[:, sum_length : sum_length + append_length, ...], + results, + rtol=1e-3, + atol=1e-3, + ) + sum_length += append_length + fend_forward(kv_cache) + + if accepted_leaf_indices is not None: + seq_ids = [seq_id for seq_id, _ in batch] + fcommit_accepted_token_tree_nodes( + kv_cache, ShapeTuple(seq_ids), ShapeTuple(accepted_leaf_indices) + ) + for i, (accepted_leaf_idx, (seq_id, append_length)) in enumerate( + zip(accepted_leaf_indices, batch) + ): + tree_path = [] + node = accepted_leaf_idx + while node != -1: + tree_path.append(node) + node = token_tree_parent_ptr_list[i][node] + offset = cached_k[seq_id].shape[1] - append_length + length_to_pop = append_length - len(tree_path) + assert 0 <= length_to_pop <= append_length + for dst_pos, src_pos in enumerate(reversed(tree_path)): + if dst_pos == src_pos: + continue + cached_k[seq_id][:, offset + dst_pos, ...] = cached_k[seq_id][ + :, offset + src_pos, ... + ] + cached_v[seq_id][:, offset + dst_pos, ...] = cached_v[seq_id][ + :, offset + src_pos, ... + ] + if length_to_pop > 0: + cached_k[seq_id] = cached_k[seq_id][:, :-length_to_pop, ...] 
+ cached_v[seq_id] = cached_v[seq_id][:, :-length_to_pop, ...] + + for seq_id, _ in batch: + if sliding_window_sizes is not None and len(sliding_window_sizes) > seq_id: + assert len(sliding_window_sizes) > seq_id and len(attn_sink_sizes) > seq_id + sliding_window_size = sliding_window_sizes[seq_id] + attn_sink_size = attn_sink_sizes[seq_id] + if sliding_window_size == 0: + continue + if cached_k[seq_id].shape[1] > sliding_window_size: + # Apply sliding window and sink to cached kv. + length_to_slide = cached_k[seq_id].shape[1] - sliding_window_size + cached_k[seq_id] = np.concatenate( + [ + cached_k[seq_id][:, :attn_sink_size, ...], + cached_k[seq_id][:, attn_sink_size + length_to_slide :, ...], + ], + axis=1, + ) + cached_v[seq_id] = np.concatenate( + [ + cached_v[seq_id][:, :attn_sink_size, ...], + cached_v[seq_id][:, attn_sink_size + length_to_slide :, ...], + ], + axis=1, + ) + assert cached_k[seq_id].shape[1] == sliding_window_size + + # Verify + verify_cached_kv(kv_cache, seq_ids, cached_k, cached_v) + + +def test_paged_attention_kv_cache_prefill_and_decode(kv_cache_and_config): + kv_cache, rope_mode, support_sliding_window = kv_cache_and_config + if support_sliding_window and rope_mode == RopeMode.NORMAL: + # Normal RoPE mode under sliding window settings is not supported. + return + fclear(kv_cache) + + # Prefill. + operation_seq = [[(0, 6)], [(1, 8)], [(2, 11)], [(3, 16)], [(4, 19), (5, 20)]] + operation_seq += [[(6, 21), (7, 24)], [(2, 5), (4, 7), (8, 24)]] + operation_seq += [[(6, 13)], [(8, 19)], [(0, 1)], [(1, 3), (3, 8), (5, 12), (7, 11)]] + # Decode + operation_seq += [[(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1), (7, 1), (8, 1)]] + operation_seq += [[(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1), (7, 1), (8, 1)]] + operation_seq += [[(0, 1), (2, 1), (4, 1), (6, 1), (8, 1)]] + operation_seq += [[(4, 1), (5, 1), (6, 1), (7, 1), (8, 1)]] + + cached_k = {} + cached_v = {} + for batch in operation_seq: + apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v) + + +def test_paged_attention_kv_cache_remove_sequence(kv_cache_and_config): + kv_cache, rope_mode, support_sliding_window = kv_cache_and_config + if support_sliding_window and rope_mode == RopeMode.NORMAL: + # Normal RoPE mode under sliding window settings is not supported. + return + fclear(kv_cache) + + num_sequences = 5 + batch = [(seq_id, 1) for seq_id in range(num_sequences)] + cached_k = {} + cached_v = {} + for seq_id_to_remove in range(num_sequences): + apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v) + # Remove sequence. + fremove_sequence(kv_cache, seq_id_to_remove) + cached_k.pop(seq_id_to_remove) + cached_v.pop(seq_id_to_remove) + verify_cached_kv( + kv_cache, + seq_ids=[seq_id for seq_id in range(num_sequences) if seq_id != seq_id_to_remove], + expected_k=cached_k, + expected_v=cached_v, + ) + + +def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_config): + kv_cache, rope_mode, support_sliding_window = kv_cache_and_config + if support_sliding_window and rope_mode == RopeMode.NORMAL: + # Normal RoPE mode under sliding window settings is not supported. + return + fclear(kv_cache) + + cached_k = {} + cached_v = {} + batch = [(0, 60), (1, 88), (2, 17), (3, 4)] + apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v) + # Fork existing sequences. 
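+    # Batch entries of the form ((child_id, parent_id, fork_pos), append_length) make
+    # apply_attention fork `child_id` from `parent_id` at position `fork_pos` (-1 forks
+    # from the end of the parent) before appending `append_length` new tokens.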
+ apply_attention(kv_cache, rope_mode, [((4, 3, -1), 35)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((5, 0, -1), 20)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((6, 5, -1), 102)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((7, 0, -1), 3)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((8, 5, -1), 71), ((9, 5, -1), 20)], cached_k, cached_v) + # 0 <- 5 <- 6,8,9 + # 0 <- 7 + # 3 <- 4 + # Mixture of decode and prefill. + operation_seq = [ + [(2, 1), (4, 1), (7, 1), (6, 1), (8, 1), (9, 1)], + [(7, 1), (6, 1), (8, 1), (9, 1)], + [(7, 1), (1, 1), (6, 1), (2, 1), (8, 1), (4, 1), (9, 1)], + [(7, 10), (6, 2), (8, 3), (9, 4)], + ] + for batch in operation_seq: + apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v) + + apply_attention(kv_cache, rope_mode, [((10, 1, 33), 11)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((11, 0, 60), 45), ((12, 0, 15), 14)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((13, 0, 16), 19), ((14, 0, 17), 19)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((15, 5, 60), 8), ((16, 5, 80), 10)], cached_k, cached_v) + apply_attention( + kv_cache, + rope_mode, + [((17, 5, 75), 11), ((18, 5, 76), 45), ((19, 5, 77), 14)], + cached_k, + cached_v, + ) + + operation_seq = [ + [(6, 1), (11, 1), (13, 1), (9, 1)], + [(10, 1), (16, 1), (18, 1), (19, 1)], + [(8, 1), (15, 1), (17, 1), (12, 1), (14, 1)], + [(10, 10), (6, 2), (8, 3), (19, 4)], + ] + for batch in operation_seq: + apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v) + + num_sequence = 20 + for i in range(num_sequence): + fremove_sequence(kv_cache, i) + cached_k.pop(i) + cached_v.pop(i) + verify_cached_kv( + kv_cache, + seq_ids=list(range(i + 1, num_sequence)), + expected_k=cached_k, + expected_v=cached_v, + ) + + assert fis_empty(kv_cache), "The KV cache is not empty after removing all sequences" + + # Test fork after page recycle + apply_attention(kv_cache, rope_mode, [(0, 7), (1, 24)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((2, 1, -1), 10)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((3, 0, -1), 20)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [(2, 1), (3, 1)], cached_k, cached_v) + + apply_attention(kv_cache, rope_mode, [(10, 7), (11, 24)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((12, 11, -1), 200)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [(10, 1), (12, 1)], cached_k, cached_v) + + +def test_paged_attention_kv_cache_unlimited_depth(kv_cache_and_config): + kv_cache, rope_mode, support_sliding_window = kv_cache_and_config + if support_sliding_window and rope_mode == RopeMode.NORMAL: + # Normal RoPE mode under sliding window settings is not supported. + return + fclear(kv_cache) + + cached_k = {} + cached_v = {} + apply_attention(kv_cache, rope_mode, [(0, 30)], cached_k, cached_v) + # Fork existing sequences. 
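+    # Every fork below uses position -1 (fork from the end of the parent), and already-forked
+    # sequences are forked again, so the fork chain keeps getting deeper (see the diagram below).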
+ apply_attention(kv_cache, rope_mode, [((1, 0, -1), 15)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((2, 1, -1), 5)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((3, 2, -1), 20)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((4, 3, -1), 26)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((5, 3, -1), 18)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((6, 5, -1), 22)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((7, 5, -1), 12)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((8, 7, -1), 29)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((9, 7, -1), 9)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((10, 9, -1), 31)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((11, 9, -1), 4)], cached_k, cached_v) + # 0 <- 1 <- 2 <- 3 <- 5 <- 7 <- 9 <- 11 + # | | | | + # 4 6 8 10 + # Decode. + operation_seq = [ + [(3, 1), (6, 1), (9, 1)], + [(4, 1), (8, 1), (10, 1)], + [(5, 1), (7, 1), (11, 1)], + ] + for batch in operation_seq: + apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v) + + num_sequence = 12 + for i in range(num_sequence): + fremove_sequence(kv_cache, i) + cached_k.pop(i) + cached_v.pop(i) + verify_cached_kv( + kv_cache, + seq_ids=list(range(i + 1, num_sequence)), + expected_k=cached_k, + expected_v=cached_v, + ) + + assert fis_empty(kv_cache), "The KV cache is not empty after removing all sequences" + + +def test_paged_attention_kv_cache_popn(kv_cache_and_config): + kv_cache, rope_mode, support_sliding_window = kv_cache_and_config + if support_sliding_window and rope_mode == RopeMode.NORMAL: + return + fclear(kv_cache) + + cached_k = {} + cached_v = {} + batch = [(0, 35), (1, 88), (2, 17), (3, 4)] + apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((4, 3, -1), 35)], cached_k, cached_v) + + popn_operations = [(0, 17), (1, 57), (2, 16), (3, 0), (4, 37)] + for seq_id, pop_length in popn_operations: + fpopn(kv_cache, seq_id, pop_length) + if pop_length != 0: + cached_k[seq_id] = cached_k[seq_id][:, :-pop_length, ...] + cached_v[seq_id] = cached_v[seq_id][:, :-pop_length, ...] + verify_cached_kv(kv_cache, seq_ids=list(range(4)), expected_k=cached_k, expected_v=cached_v) + + num_sequence = 5 + for seq_id in range(num_sequence): + fremove_sequence(kv_cache, seq_id) + verify_cached_kv( + kv_cache, + seq_ids=list(range(seq_id + 1, num_sequence)), + expected_k=cached_k, + expected_v=cached_v, + ) + + assert fis_empty(kv_cache), "The KV cache is not empty after removing all sequences" + + +def test_paged_attention_kv_cache_sliding_window(kv_cache_and_config): + kv_cache, rope_mode, support_sliding_window = kv_cache_and_config + if not support_sliding_window or rope_mode == RopeMode.NORMAL: + return + fclear(kv_cache) + + cached_k = {} + cached_v = {} + sliding_window_sizes = [20, 25, 30, 35, 40] + attn_sink_sizes = [6, 4, 8, 3, 7] + for seq_id, (sliding_window_size, attn_sink_size) in enumerate( + zip(sliding_window_sizes, attn_sink_sizes) + ): + fadd_sequence(kv_cache, seq_id) + fenable_sliding_window_for_seq(kv_cache, seq_id, sliding_window_size, attn_sink_size) + cached_k[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype) + cached_v[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype) + + # Prefill. 
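+    # Prefill each sequence beyond its sliding window size so that old entries past the
+    # window are evicted while the attention-sink prefix is kept; apply_attention mirrors
+    # this trimming on the NumPy side and verify_cached_kv checks the result.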
+ operation_seq = [[(0, 4)], [(1, 6)], [(2, 6), (3, 7), (4, 7)]] + operation_seq += [[(0, 20), (1, 19), (2, 30), (3, 35), (4, 40)]] + operation_seq += [[(0, 6), (1, 5), (2, 4), (3, 3), (4, 2)]] + for batch in operation_seq: + apply_attention( + kv_cache, + rope_mode, + batch, + cached_k, + cached_v, + sliding_window_sizes, + attn_sink_sizes, + ) + # Decode + batch = [(0, 1), (1, 1), (2, 1), (3, 1), (4, 1)] + for _ in range(20): + apply_attention( + kv_cache, + rope_mode, + batch, + cached_k, + cached_v, + sliding_window_sizes, + attn_sink_sizes, + ) + + +def test_paged_attention_kv_cache_sliding_window_fork(kv_cache_and_config): + kv_cache, rope_mode, support_sliding_window = kv_cache_and_config + if not support_sliding_window or rope_mode == RopeMode.NORMAL: + return + fclear(kv_cache) + + cached_k = {} + cached_v = {} + sliding_window_sizes = [30, 35, 40] + attn_sink_sizes = [15, 20, 25] + for seq_id, (sliding_window_size, attn_sink_size) in enumerate( + zip(sliding_window_sizes, attn_sink_sizes) + ): + fadd_sequence(kv_cache, seq_id) + fenable_sliding_window_for_seq(kv_cache, seq_id, sliding_window_size, attn_sink_size) + cached_k[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype) + cached_v[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype) + apply_attention( + kv_cache, + rope_mode, + [(0, 12), (1, 18), (2, 28)], + cached_k, + cached_v, + sliding_window_sizes, + attn_sink_sizes, + ) + # seq_len: [12, 18, 25+3] + sliding_window_sizes += [0, 0, 0] + attn_sink_sizes += [0, 0, 0] + apply_attention( + kv_cache, + rope_mode, + [((3, 0, 10), 8), ((4, 1, -1), 20), ((5, 2, 18), 18)], + cached_k, + cached_v, + sliding_window_sizes, + attn_sink_sizes, + ) + # seq_len: [12, 18, 25+3, 18, 38, 36] + apply_attention( + kv_cache, + rope_mode, + [(0, 9), (1, 15), (2, 4), (3, 10), (4, 3), (5, 7)], + cached_k, + cached_v, + sliding_window_sizes, + attn_sink_sizes, + ) + # seq_len: [15+6, 20+13, 25+7, 28, 41, 43] + sliding_window_sizes += [25] + attn_sink_sizes += [24] + ffork_sequence(kv_cache, 3, 6, 18) + fenable_sliding_window_for_seq(kv_cache, 6, sliding_window_sizes[-1], attn_sink_sizes[-1]) + cached_k[6] = cached_k[3][::, :18] + cached_v[6] = cached_v[3][::, :18] + apply_attention( + kv_cache, + rope_mode, + [(3, 10), (6, 12)], + cached_k, + cached_v, + sliding_window_sizes, + attn_sink_sizes, + ) + # seq_len: [15+6, 20+13, 25+7, 38, 41, 43, 24+6] + + +def test_paged_attention_kv_cache_tree_attn(kv_cache_and_config): + kv_cache, rope_mode, support_sliding_window = kv_cache_and_config + if support_sliding_window: + # Normal RoPE mode under sliding window settings is not supported. + return + if rope_mode == RopeMode.INLINE: + # Inline RoPE mode is not supported for tree attention. + return + fclear(kv_cache) + + cached_k = {} + cached_v = {} + # Prefill 4 sequences + apply_attention(kv_cache, rope_mode, [(0, 10), (1, 20), (2, 30), (3, 40)], cached_k, cached_v) + # Tree attention + apply_attention( + kv_cache, + rope_mode, + [(0, 7), (1, 15), (2, 10), (3, 14)], + cached_k, + cached_v, + token_tree_parent_ptr_list=[ + [-1, 0, 0, 1, 1, 2, 2], # complete binary tree of height 3 + [-1, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6], # complete binary tree of height 4 + [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8], # chain of length 10 + [-1, 0, 0, 1, 1, 2, 2, -1, 7, 7, 8, 8, 9, 9], # two complete binary trees of height 3 + ], + accepted_leaf_indices=[6, 11, 6, 13], + ) + # Do 5 rounds of decode. 
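+    # After committing the accepted leaf of each token tree, the rejected draft tokens have
+    # been popped from both the cache and the NumPy reference, so plain decode must still match.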
+    for _ in range(5):
+        apply_attention(kv_cache, rope_mode, [(0, 1), (1, 1), (2, 1), (3, 1)], cached_k, cached_v)
+
+    # Test the cases where all trees are chains.
+    fclear(kv_cache)
+    cached_k = {}
+    cached_v = {}
+    # Prefill 4 sequences
+    apply_attention(kv_cache, rope_mode, [(0, 10), (1, 20), (2, 30), (3, 40)], cached_k, cached_v)
+    # Tree attention
+    apply_attention(
+        kv_cache,
+        rope_mode,
+        [(0, 7), (1, 15), (2, 10), (3, 14)],
+        cached_k,
+        cached_v,
+        token_tree_parent_ptr_list=[
+            [-1, 0, 1, 2, 3, 4, 5],  # chain of length 7
+            [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],  # chain of length 15
+            [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8],  # chain of length 10
+            [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],  # chain of length 14
+        ],
+        accepted_leaf_indices=[2, 6, -1, 4],
+    )
+    # Do 5 rounds of decode.
+    for _ in range(5):
+        apply_attention(kv_cache, rope_mode, [(0, 1), (1, 1), (2, 1), (3, 1)], cached_k, cached_v)
+
+    # Test the cases of tree attn with cached kv.
+    fclear(kv_cache)
+    cached_k = {}
+    cached_v = {}
+    # Prefill 4 sequences
+    apply_attention(kv_cache, rope_mode, [(0, 10), (1, 20), (2, 30), (3, 40)], cached_k, cached_v)
+    # Do 5 rounds of tree decode.
+    num_seq = 4
+    for i in range(5):
+        num_leaf_nodes = 2**i
+        parent_ptr = [(k - 1) // 2 for k in range(0, 2 * num_leaf_nodes - 1)]
+        apply_attention(
+            kv_cache,
+            rope_mode,
+            [(seq_id, num_leaf_nodes) for seq_id in range(num_seq)],
+            cached_k,
+            cached_v,
+            token_tree_parent_ptr_list=[parent_ptr for _ in range(num_seq)],
+            accepted_leaf_indices=(
+                None if i != 4 else [2, 6, -1, 4]
+            ),  # Leaf nodes are committed all at once at the end.
+        )
+
+
+if __name__ == "__main__":
+    HEAD_DIMS = [64, 128]
+    DTYPES = ["float16", "float32"]
+    ROPE_MODES = [RopeMode.NONE, RopeMode.NORMAL, RopeMode.INLINE]
+    SUPPORT_SLIDING_WINDOW = [False, True]
+    for head_dim, dtype, rope_mode, support_sliding_window in itertools.product(
+        HEAD_DIMS, DTYPES, ROPE_MODES, SUPPORT_SLIDING_WINDOW
+    ):
+        set_global_func(head_dim, dtype)
+        cache = create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window)
+        cache_and_config = (cache, rope_mode, support_sliding_window)
+        test_paged_attention_kv_cache_prefill_and_decode(cache_and_config)
+        test_paged_attention_kv_cache_remove_sequence(cache_and_config)
+        test_paged_attention_kv_cache_fork_sequence(cache_and_config)
+        test_paged_attention_kv_cache_popn(cache_and_config)
+        test_paged_attention_kv_cache_sliding_window(cache_and_config)
+        test_paged_attention_kv_cache_tree_attn(cache_and_config)
+        test_paged_attention_kv_cache_unlimited_depth(cache_and_config)