[LLM INFER] not use gemm_dequant default and fix bug (#9498)
* not use gemm_dequant default and fix bug
yuanlehome authored Nov 26, 2024
1 parent 4b02477 commit f5ca96e
Showing 2 changed files with 9 additions and 7 deletions.
13 changes: 6 additions & 7 deletions paddlenlp/experimental/transformers/fused_transformer_layers.py
@@ -139,7 +139,7 @@ class AvxConfig:

 @dataclass
 class SpeculateConfig:
-    speculate_max_draft_token_num: int = (1,)
+    speculate_max_draft_token_num: int = 5
     speculate_method: str = None
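
The first hunk fixes a bad dataclass default: `(1,)` is a one-element tuple, and Python does not enforce the `int` annotation, so every `SpeculateConfig` instance carried a tuple where later code expects a number. A minimal standalone sketch of the failure mode, using a throwaway dataclass rather than the real `SpeculateConfig`:

from dataclasses import dataclass


@dataclass
class Cfg:  # throwaway stand-in for SpeculateConfig, not the real class
    buggy_default: int = (1,)  # the trailing comma makes this a tuple; the annotation is not checked at runtime
    fixed_default: int = 5     # the corrected default from this commit


cfg = Cfg()
print(type(cfg.buggy_default))  # <class 'tuple'> -- e.g. `cfg.buggy_default + 1` raises TypeError
print(type(cfg.fixed_default))  # <class 'int'>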


@@ -1690,7 +1690,7 @@ def __init__(self, config: FusedMultiTransformerConfig):
         self.quant_round_type = config.quant_round_type
         self.quant_max_bound = config.quant_max_bound
         self.quant_min_bound = config.quant_min_bound
-        # self.use_gemm_dequant = False
+        self.use_gemm_dequant = False

         self.qkv_out_scales = []
         self.linear_out_scales = []
@@ -1928,7 +1928,6 @@ def compute_qkv_linear(self, ln_out, i):
         if paddle.is_compiled_with_rocm():
             qkv_out = paddle.matmul(ln_out, self.qkv_weights[i])
         else:
-            # TODO: add gemm_dequant after qkv_out
             qkv_out = paddle.matmul(ln_out, self.qkv_weights[i], False, True)
         return qkv_out

@@ -2033,13 +2032,13 @@ def compute_out_linear(self, fmha_out, i):
             out_linear_out = paddle.matmul(fmha_out, self.linear_weights[i])
             out_linear_out = dequant_int8(out_linear_out, self.linear_out_scales[i], self._dtype)
         else:
-            try:
+            if self.use_gemm_dequant:
                 from paddlenlp_ops import gemm_dequant

                 out_linear_out = gemm_dequant(
                     fmha_out, self.linear_weights[i], self.linear_out_scales[i], self._dtype
                 )
-            except:
+            else:
                 out_linear_out = paddle.matmul(fmha_out, self.linear_weights[i], False, True)
                 out_linear_out = dequant_int8(out_linear_out, self.linear_out_scales[i], self._dtype)
         return out_linear_out
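
The core of the commit is this dispatch change in `compute_out_linear` (and the identical change in `compute_ffn2` in the next hunk): the old code selected the fused `gemm_dequant` kernel by wrapping the import and call in a bare `try`/`except`, which silently fell back to plain matmul plus `dequant_int8` on any failure, including real bugs in the fused path. The new code dispatches on the explicit `self.use_gemm_dequant` flag, which `__init__` now hard-codes to `False`, so the fused kernel is opt-in. A minimal, self-contained sketch of the two dispatch styles, using NumPy stand-ins instead of the real Paddle kernels (`fused_gemm_dequant` and `matmul_then_dequant` below are illustrative names, not PaddleNLP APIs):

import numpy as np


def fused_gemm_dequant(x, w, scale):
    # Stand-in for paddlenlp_ops.gemm_dequant: int8 GEMM and dequantization fused in one call.
    return (x.astype(np.int32) @ w.astype(np.int32)) * scale


def matmul_then_dequant(x, w, scale):
    # Stand-in for paddle.matmul(...) followed by dequant_int8(...).
    acc = x.astype(np.int32) @ w.astype(np.int32)
    return acc * scale


def out_linear_old(x, w, scale, fused_is_broken=False):
    # Old style: a bare try/except hides *every* failure, not just a missing custom op.
    try:
        if fused_is_broken:
            raise RuntimeError("bug in the fused kernel")  # swallowed silently below
        return fused_gemm_dequant(x, w, scale)
    except:  # bare except, exactly the pattern this commit removes
        return matmul_then_dequant(x, w, scale)


def out_linear_new(x, w, scale, use_gemm_dequant=False):
    # New style: an explicit opt-in flag; plain matmul + dequant is the default path.
    if use_gemm_dequant:
        return fused_gemm_dequant(x, w, scale)
    return matmul_then_dequant(x, w, scale)


x = np.random.randint(-128, 127, size=(2, 4)).astype(np.int8)
w = np.random.randint(-128, 127, size=(4, 3)).astype(np.int8)
print(np.allclose(out_linear_old(x, w, 0.1), out_linear_new(x, w, 0.1)))  # True

Because `use_gemm_dequant` is hard-coded to `False` in `__init__` and no config switch appears in this diff, re-enabling the fused kernel would presumably mean setting the attribute on the constructed layer, assuming no other switch exists elsewhere in the tree.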
@@ -2094,11 +2093,11 @@ def compute_ffn2(self, ffn1_out, i):
             ffn2_out = paddle.matmul(ffn1_out, self.ffn2_weights[i])
             ffn2_out = dequant_int8(ffn2_out, self.ffn2_out_scales[i], self._dtype)
         else:
-            try:
+            if self.use_gemm_dequant:
                 from paddlenlp_ops import gemm_dequant

                 ffn2_out = gemm_dequant(ffn1_out, self.ffn2_weights[i], self.ffn2_out_scales[i], self._dtype)
-            except:
+            else:
                 ffn2_out = paddle.matmul(ffn1_out, self.ffn2_weights[i], False, True)
                 ffn2_out = dequant_int8(ffn2_out, self.ffn2_out_scales[i], self._dtype)
         return ffn2_out
3 changes: 3 additions & 0 deletions paddlenlp/experimental/transformers/qwen2_moe/modeling.py
@@ -292,7 +292,10 @@ def set_state_dict(self, state_dict):
         self.embed_tokens.weight.set_value(embed_tokens_weight)
         self.norm.weight.set_value(norm_weight)

+        if self.use_weight_only:
+            logger.info("weight only is enabled")
         for idx in range(self.num_layers):
+            logger.info(f"set state for layer {idx}")
             unfused_state_dict = {}
             ln_scale = paddle.to_tensor(state_dict["qwen2_moe.layers.{}.input_layernorm.weight".format(idx)]).cast(
                 self.transformer_block.ln_scales[idx].dtype
