use concat & reduce sum
cyita committed Sep 18, 2024
1 parent 56bb432 commit 643d16b
Showing 2 changed files with 29 additions and 9 deletions.
36 changes: 28 additions & 8 deletions python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
@@ -108,6 +108,10 @@ def __init__(self, max_seq_len, transpose_value, dtype, profile=False, device="N
         self.n_splits_linear=n_splits_linear
         self.n_splits_down_proj=n_splits_down_proj

+    def reduce_linear(self, to_concat):
+        concat = self.sequence_concat(to_concat, axis=0)
+        return self.reduce_sum(concat, reduction_axes=0, keep_dims=True)
+
     def attention(self,
                   *,
                   hidden_states,
@@ -188,9 +192,14 @@ def attention(self,
                         wt_dtype=self.dtype,
                     )
                 )
-            query_states = sum(query_states_to_concat)
-            key_states = sum(key_states_to_concat)
-            value_states = sum(value_states_to_concat)
+            if mode == "decode":
+                query_states = self.reduce_linear(query_states_to_concat)
+                key_states = self.reduce_linear(key_states_to_concat)
+                value_states = self.reduce_linear(value_states_to_concat)
+            else:
+                query_states = sum(query_states_to_concat)
+                key_states = sum(key_states_to_concat)
+                value_states = sum(value_states_to_concat)
         if q_bias is not None:
             query_states = query_states + q_bias
         if k_bias is not None:
@@ -277,12 +286,16 @@ def attention(self,
                         sub_attn_output, hidden_size, groupsize, bias=False, wt_dtype=self.dtype
                     )
                 )
-            attn_output = sum(attn_output_to_concat)
+            if mode == "decode":
+                attn_output = self.reduce_linear(attn_output_to_concat)
+            else:
+                attn_output = sum(attn_output_to_concat)

         return attn_output, new_key_states, new_value_states

-    def mlp(self, hidden_states, seq_len=-1):
+    def mlp(self, hidden_states, seq_len=-1, mode="prefill"):
         print(f"mp_models_base mlp")
+        use_concat_reduce = (mode == "decode" and False)
         if self.n_splits_linear == 1:
             mm1 = self.linear(
                 hidden_states, self.intermediate_size, self.hidden_size, bias=False, wt_dtype=self.dtype
@@ -309,8 +322,12 @@ def mlp(self, hidden_states, seq_len=-1):
                         sub_hidden_states, self.intermediate_size, gate_up_groupsize, bias=False, wt_dtype=self.dtype
                     )
                 )
-            mm1 = sum(mm1_to_concat)
-            mm2 = sum(mm2_to_concat)
+            if use_concat_reduce:
+                mm1 = self.reduce_linear(mm1_to_concat)
+                mm2 = self.reduce_linear(mm2_to_concat)
+            else:
+                mm1 = sum(mm1_to_concat)
+                mm2 = sum(mm2_to_concat)
         mm1 = self.eltwise_mul(self.swish(mm1), mm2) # type: ignore[attr-defined]

         if self.n_splits_down_proj == 1:
Expand All @@ -332,7 +349,10 @@ def mlp(self, hidden_states, seq_len=-1):
# print(hidden_states_to_concat[0].shape)
# hidden_states = self.concat_list(hidden_states_to_concat, 0)
# hidden_states = self.reduce_sum(hidden_states, 0)
hidden_states = sum(hidden_states_to_concat)
if use_concat_reduce:
hidden_states = self.reduce_linear(hidden_states_to_concat)
else:
hidden_states = sum(hidden_states_to_concat)
return hidden_states

def layer_norm(self, hidden_states, layernorm_weight):
1 change: 1 addition & 1 deletion (second changed file)
@@ -308,7 +308,7 @@ def build_decoder(
         hidden_states = self.eltwise_add(residual, attn_output)
         residual = hidden_states
         hidden_states = self.layer_norm(hidden_states, post_attention_layernorm_weight)
-        hidden_states = self.mlp(hidden_states, self.seq_len)
+        hidden_states = self.mlp(hidden_states, self.seq_len, self.mode)
         hidden_states = self.eltwise_add(residual, hidden_states)
         hidden_states = self.convert_to_fp16(hidden_states)


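Note on the change: the diff replaces the Python-level sum(...) over per-split linear outputs with a reduce_linear helper that concatenates the splits (sequence_concat along axis 0) and sums them inside the graph (reduce_sum with keep_dims=True). It is used in attention when mode == "decode"; in mlp the same path is wired up behind use_concat_reduce but still switched off. Below is a minimal NumPy sketch of why the two paths agree; the names, shapes, and split count are illustrative assumptions, not the actual NPU graph ops.

import numpy as np

# Assumed shapes for illustration: each split's partial linear output has a
# leading dim of 1, e.g. (1, seq_len, hidden), which is what makes
# concat-along-axis-0 followed by a sum over axis 0 match the elementwise sum.
n_splits, seq_len, hidden = 2, 4, 8
partials = [np.random.rand(1, seq_len, hidden).astype(np.float32)
            for _ in range(n_splits)]

# Old path: Python-level elementwise sum over the split outputs.
summed = sum(partials)

# New path: concatenate the splits, then reduce over the concat axis with
# keepdims, mirroring sequence_concat(axis=0) followed by
# reduce_sum(reduction_axes=0, keep_dims=True) in reduce_linear.
stacked = np.concatenate(partials, axis=0)        # (n_splits, seq_len, hidden)
reduced = np.sum(stacked, axis=0, keepdims=True)  # (1, seq_len, hidden)

assert np.allclose(summed, reduced)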