From 7209d6f5fb519572aa07269815dd0f1850d9c7ad Mon Sep 17 00:00:00 2001
From: Alfred Gui
Date: Wed, 28 Aug 2024 21:53:00 -0400
Subject: [PATCH] Fix the residual connection of the chatglm model architecture (#76)

* fix the chatglm model architecture

* remove try catch
---
 .../flashinfer_chatglm_modeling.py | 23 ++++++++++-------------
 1 file changed, 10 insertions(+), 13 deletions(-)

diff --git a/server/text_generation_server/models_flashinfer/custom_modeling/flashinfer_chatglm_modeling.py b/server/text_generation_server/models_flashinfer/custom_modeling/flashinfer_chatglm_modeling.py
index 44fa9ffa..367d38b3 100644
--- a/server/text_generation_server/models_flashinfer/custom_modeling/flashinfer_chatglm_modeling.py
+++ b/server/text_generation_server/models_flashinfer/custom_modeling/flashinfer_chatglm_modeling.py
@@ -288,6 +288,7 @@ def forward(
         q, k, v = self.flashinferWrapper.reshape_qkv_for_attention(
             q, k, v, batch_position
         )
+        rotate_query_key_in_place(q, k, cos, sin, is_neox=False)
 
         attn_outputs_raw = self.flashinferWrapper.computeAttention(
             q,
@@ -338,7 +339,6 @@ def __init__(
     def forward(
         self,
         hidden_states: torch.Tensor,
-        residual: torch.Tensor,
         kvCachePool: KvCachePool,
         is_prefill: bool,
         batch_position: KvCacheBatchPosition,
@@ -346,7 +346,7 @@ def forward(
         sin: torch.Tensor,
         loraWeight: BatchedModelLoraWeight | None,
     ):
-        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
+        normed_hidden_states, _ = self.input_layernorm(hidden_states)
 
         attn_output = self.self_attn(
             normed_hidden_states,
@@ -358,13 +358,13 @@ def forward(
             loraWeight,
         )
 
-        normed_attn_res_output, attn_res = self.post_attention_layernorm(
-            attn_output, res
-        )
-
-        mlp_output = self.mlp(normed_attn_res_output, loraWeight)
+        residual = hidden_states
+        layernorm_input = residual + attn_output
+        normed_attn_res_output, _ = self.post_attention_layernorm(layernorm_input)
+        residual = layernorm_input
+        mlp_output = self.mlp(normed_attn_res_output, loraWeight) + residual
 
-        return mlp_output, attn_res
+        return mlp_output
 
 
 class FlashChatGLM3Model(torch.nn.Module):
@@ -428,8 +428,6 @@ def forward(
         loraWeight: BatchedModelLoraWeight | None,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
-        residual = None
-
         self.flashinferWrapper.prepareAttention(
             is_prefill,
             batch_position,
@@ -446,9 +444,8 @@ def forward(
             is_prefill,
         )
         for i, layer in enumerate(self.layers):
-            hidden_states, residual = layer(
+            hidden_states = layer(
                 hidden_states,
-                residual,
                 kvCachePool,
                 is_prefill,
                 batch_position,
@@ -457,7 +454,7 @@ def forward(
                 loraWeight,
             )
 
-        hidden_states, _ = self.norm(hidden_states, residual)
+        hidden_states, _ = self.norm(hidden_states)
 
         self.flashinferWrapper.endBatchAttention(is_prefill)
         return hidden_states
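
For reference, the new wiring follows the GLM block layout: the attention residual is the block input, and the MLP residual is the input to the post-attention layernorm, so each layer folds its residual in before returning a single tensor and the final norm no longer takes a residual argument. Below is a minimal, self-contained sketch of that residual flow in plain PyTorch; GLMBlockSketch, RMSNormSketch, and the Linear/SiLU sub-modules are hypothetical stand-ins (the real layer uses the flashinfer attention wrapper with a KV cache and a LoRA-aware MLP), and the rotate_query_key_in_place(..., is_neox=False) interleaved rotary step added in the patch is omitted here.

import torch
import torch.nn as nn

class RMSNormSketch(nn.Module):
    # Stand-in for the repo's RMS layernorm wrapper; returns only the normed tensor.
    def __init__(self, hidden_size: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + self.eps)

class GLMBlockSketch(nn.Module):
    # Decoder layer skeleton showing only the residual/layernorm wiring.
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.input_layernorm = RMSNormSketch(hidden_size)
        self.post_attention_layernorm = RMSNormSketch(hidden_size)
        # Hypothetical stand-ins for the flashinfer attention and gated MLP.
        self.self_attn = nn.Linear(hidden_size, hidden_size, bias=False)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size, bias=False),
            nn.SiLU(),
            nn.Linear(intermediate_size, hidden_size, bias=False),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Pre-attention norm sees only hidden_states; no residual is threaded
        # through the layernorm as in the pre-patch code.
        normed_hidden_states = self.input_layernorm(hidden_states)
        attn_output = self.self_attn(normed_hidden_states)

        # GLM block wiring: the attention residual is the block input, and the
        # MLP residual is the input to the post-attention layernorm.
        residual = hidden_states
        layernorm_input = residual + attn_output
        normed_attn_res_output = self.post_attention_layernorm(layernorm_input)
        residual = layernorm_input
        return self.mlp(normed_attn_res_output) + residual

block = GLMBlockSketch(hidden_size=64, intermediate_size=256)
out = block(torch.randn(2, 8, 64))  # (batch, seq_len, hidden_size) in, same shape out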