diff --git a/vllm/model_executor/layers/kda.py b/vllm/model_executor/layers/kda.py
index c45e7546fac1..308bc8be1dec 100644
--- a/vllm/model_executor/layers/kda.py
+++ b/vllm/model_executor/layers/kda.py
@@ -40,18 +40,36 @@
 def kda_attention(
-    hidden_states: torch.Tensor,
-    output: torch.Tensor,
+    q_proj_states: torch.Tensor,
+    k_proj_states: torch.Tensor,
+    v_proj_states: torch.Tensor,
+    g1: torch.Tensor,
+    g2: torch.Tensor,
+    beta: torch.Tensor,
+    core_attn_out: torch.Tensor,
     layer_name: str,
 ) -> None:
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.no_compile_layers[layer_name]
-    self._forward(hidden_states=hidden_states, output=output)
+    self._forward(
+        q_proj_states=q_proj_states,
+        k_proj_states=k_proj_states,
+        v_proj_states=v_proj_states,
+        g1=g1,
+        g2=g2,
+        beta=beta,
+        core_attn_out=core_attn_out,
+    )
 
 
 def kda_attention_fake(
-    hidden_states: torch.Tensor,
-    output: torch.Tensor,
+    q_proj_states: torch.Tensor,
+    k_proj_states: torch.Tensor,
+    v_proj_states: torch.Tensor,
+    g1: torch.Tensor,
+    g2: torch.Tensor,
+    beta: torch.Tensor,
+    core_attn_out: torch.Tensor,
     layer_name: str,
 ) -> None:
     return
 
 
@@ -60,7 +78,7 @@ def kda_attention_fake(
 direct_register_custom_op(
     op_name="kda_attention",
     op_func=kda_attention,
-    mutates_args=["output"],
+    mutates_args=["core_attn_out"],
     fake_impl=kda_attention_fake,
 )
 
@@ -241,37 +259,56 @@ def forward(
         hidden_states: torch.Tensor,
         positions: torch.Tensor,
         output: torch.Tensor,
-    ) -> None:
-        return torch.ops.vllm.kda_attention(
-            hidden_states,
-            output,
+    ) -> torch.Tensor:
+        num_tokens = hidden_states.size(0)
+        q = self.q_proj(hidden_states)[0]
+        k = self.k_proj(hidden_states)[0]
+        v = self.v_proj(hidden_states)[0]
+
+        beta = self.b_proj(hidden_states)[0].float().sigmoid()
+        g1 = self.f_b_proj(self.f_a_proj(hidden_states)[0])[0]
+        g1 = fused_kda_gate(g1, self.A_log, self.head_dim, g_bias=self.dt_bias)
+        beta = beta.unsqueeze(0)
+        g1 = g1.unsqueeze(0)
+
+        g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0]
+        g2 = rearrange(g_proj_states, "... (h d) -> ... h d", d=self.head_dim)
+
+        core_attn_out = torch.zeros(
+            (1, num_tokens, self.local_num_heads, self.head_dim),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+        torch.ops.vllm.kda_attention(
+            q,
+            k,
+            v,
+            g1,
+            g2,
+            beta,
+            core_attn_out,
             self.prefix,
         )
+        core_attn_out = self.o_norm(core_attn_out, g2)
+        core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
+
+        return self.o_proj(core_attn_out)[0]
 
     def _forward(
         self,
-        hidden_states: torch.Tensor,
-        output: torch.Tensor,
+        q_proj_states: torch.Tensor,
+        k_proj_states: torch.Tensor,
+        v_proj_states: torch.Tensor,
+        g1: torch.Tensor,
+        g2: torch.Tensor,
+        beta: torch.Tensor,
+        core_attn_out: torch.Tensor,
     ) -> None:
         forward_context = get_forward_context()
         attn_metadata: AttentionMetadata = forward_context.attn_metadata
         if attn_metadata is None:
-            # V1 profile run
-            # Mimic the memory allocation in the real run
-            q = torch.empty_like(hidden_states)
-            k = torch.empty_like(hidden_states)
-            v = torch.empty_like(hidden_states)
-            g = hidden_states.new_empty(
-                hidden_states.size(0),
-                self.local_num_heads,
-                self.head_dim,
-                dtype=torch.float32,
-            )
-            beta = torch.empty(
-                hidden_states.size(0), self.local_num_heads, dtype=torch.float32
-            )
-            core_attn_out = torch.empty_like(hidden_states)
+            # V1 profile run
             return
 
         assert isinstance(attn_metadata, dict)
@@ -288,10 +325,6 @@ def _forward(
             conv_state_k = conv_state_k.transpose(-1, -2)
             conv_state_v = conv_state_v.transpose(-1, -2)
 
-        q_proj_states = self.q_proj(hidden_states)[0]
-        k_proj_states = self.k_proj(hidden_states)[0]
-        v_proj_states = self.v_proj(hidden_states)[0]
-
         q_conv_weights = self.q_conv1d.weight.view(
             self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2)
         )
@@ -374,14 +407,6 @@ def _forward(
             lambda x: rearrange(x, "n (h d) -> 1 n h d", d=self.head_dim), (q, k, v)
         )
 
-        beta = self.b_proj(hidden_states)[0].float().sigmoid()
-
-        g = self.f_b_proj(self.f_a_proj(hidden_states)[0])[0]
-        g = fused_kda_gate(g, self.A_log, self.head_dim, g_bias=self.dt_bias)
-
-        beta = beta.unsqueeze(0)
-        g = g.unsqueeze(0)
-
         if attn_metadata.num_prefills > 0:
             zero_idx = non_spec_state_indices_tensor[~has_initial_state]
             recurrent_state[zero_idx] = 0
@@ -393,7 +418,7 @@ def _forward(
                 q=q,
                 k=k,
                 v=v,
-                g=g,
+                g=g1,
                 beta=beta,
                 initial_state=initial_state,
                 output_final_state=True,
@@ -410,17 +435,12 @@ def _forward(
                 q=q,
                 k=k,
                 v=v,
-                g=g,
+                g=g1,
                 beta=beta,
                 initial_state=recurrent_state,
                 use_qk_l2norm_in_kernel=True,
                 cu_seqlens=non_spec_query_start_loc,
                 ssm_state_indices=non_spec_state_indices_tensor,
             )
-
-        g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0]
-        g = rearrange(g_proj_states, "... (h d) -> ... h d", d=self.head_dim)
-        core_attn_out = self.o_norm(core_attn_out_non_spec, g)
-        core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
-
-        output[:] = self.o_proj(core_attn_out)[0]
+        assert core_attn_out_non_spec.shape == core_attn_out.shape
+        core_attn_out[:] = core_attn_out_non_spec