Fix the residual connection of the chatglm model architecture (#76)
* Fix the ChatGLM model architecture

* Remove the try/catch block
alfredgui2 authored Aug 29, 2024
1 parent cf69455 commit 7209d6f
Showing 1 changed file with 10 additions and 13 deletions.
@@ -288,6 +288,7 @@ def forward(
         q, k, v = self.flashinferWrapper.reshape_qkv_for_attention(
             q, k, v, batch_position
         )
+
         rotate_query_key_in_place(q, k, cos, sin, is_neox=False)
         attn_outputs_raw = self.flashinferWrapper.computeAttention(
             q,
@@ -338,15 +339,14 @@ def __init__(
     def forward(
         self,
         hidden_states: torch.Tensor,
-        residual: torch.Tensor,
         kvCachePool: KvCachePool,
         is_prefill: bool,
         batch_position: KvCacheBatchPosition,
         cos: torch.Tensor,
         sin: torch.Tensor,
         loraWeight: BatchedModelLoraWeight | None,
     ):
-        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
+        normed_hidden_states, _ = self.input_layernorm(hidden_states)

         attn_output = self.self_attn(
             normed_hidden_states,
@@ -358,13 +358,13 @@ def forward(
             loraWeight,
         )

-        normed_attn_res_output, attn_res = self.post_attention_layernorm(
-            attn_output, res
-        )
-
-        mlp_output = self.mlp(normed_attn_res_output, loraWeight)
+        residual = hidden_states
+        layernorm_input = residual + attn_output
+        normed_attn_res_output, _ = self.post_attention_layernorm(layernorm_input)
+        residual = layernorm_input
+        mlp_output = self.mlp(normed_attn_res_output, loraWeight) + residual

-        return mlp_output, attn_res
+        return mlp_output


 class FlashChatGLM3Model(torch.nn.Module):
@@ -428,8 +428,6 @@ def forward(
         loraWeight: BatchedModelLoraWeight | None,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
-        residual = None
-
         self.flashinferWrapper.prepareAttention(
             is_prefill,
             batch_position,
@@ -446,9 +444,8 @@ def forward(
             is_prefill,
         )
         for i, layer in enumerate(self.layers):
-            hidden_states, residual = layer(
+            hidden_states = layer(
                 hidden_states,
-                residual,
                 kvCachePool,
                 is_prefill,
                 batch_position,
@@ -457,7 +454,7 @@ def forward(
                 loraWeight,
             )

-        hidden_states, _ = self.norm(hidden_states, residual)
+        hidden_states, _ = self.norm(hidden_states)
         self.flashinferWrapper.endBatchAttention(is_prefill)
         return hidden_states
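For reference, the data flow this patch moves to looks roughly like the sketch below. This is a minimal illustration of the corrected residual connections, not code from this repository: GLMBlockSketch, the nn.LayerNorm and nn.Identity stand-ins, and the single-tensor signature are assumptions, while the real layer also threads KV-cache, rotary-embedding, and LoRA arguments and uses the repository's own normalization modules. The essential point of the fix is that each residual add uses un-normalized activations, the raw hidden_states for the attention branch and layernorm_input for the MLP branch, rather than values returned by the layernorms.

    import torch
    from torch import nn

    class GLMBlockSketch(nn.Module):
        """Hypothetical block showing only the corrected residual flow."""

        def __init__(self, hidden_size: int):
            super().__init__()
            # Stand-ins: the real model wires in its own norm, attention, and MLP modules.
            self.input_layernorm = nn.LayerNorm(hidden_size)
            self.post_attention_layernorm = nn.LayerNorm(hidden_size)
            self.self_attn = nn.Identity()
            self.mlp = nn.Identity()

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            # Attention branch: normalize, attend, then add the un-normalized input.
            normed_hidden_states = self.input_layernorm(hidden_states)
            attn_output = self.self_attn(normed_hidden_states)
            layernorm_input = hidden_states + attn_output  # first residual add

            # MLP branch: normalize the sum, run the MLP, and add the
            # pre-normalization sum back in.
            normed_attn_res_output = self.post_attention_layernorm(layernorm_input)
            return self.mlp(normed_attn_res_output) + layernorm_input  # second residual add

Because each layer now produces a single tensor, the per-layer interface collapses to hidden_states in, hidden_states out, which is why FlashChatGLM3Model.forward no longer seeds or threads a residual value through its layer loop.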
