diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py
index 6cd109bda9e3..7611fd961ab6 100755
--- a/paddlenlp/transformers/llama/modeling.py
+++ b/paddlenlp/transformers/llama/modeling.py
@@ -647,20 +647,6 @@ def __init__(self, config):
 
     def forward(self, x):
         if self.fuse_attention_ffn:
-            # FIXME(yangjianbang): use paddle's native swiglu
-            if get_env_device() == "xpu":
-                try:
-                    import paddle_xpu_nn  # noqa: F821
-
-                    out = self.gate_up_fused_proj(x)
-                    out = paddle_xpu_nn.xpu_swiglu(out, axis=-1, turn=True)
-                    out = self.down_proj(out)
-                    return out
-                except ImportError:
-                    gate_out, up_out = paddle.chunk(self.gate_up_fused_proj(x), chunks=2, axis=-1)
-                    out = self.down_proj(F.silu(gate_out) * up_out)
-                    return out
-
             x = swiglu(self.gate_up_fused_proj(x))
         else:
             x = swiglu(self.gate_proj(x), self.up_proj(x))
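
For context, after this change the fused branch relies entirely on the `swiglu` helper that the module resolves at import time. The sketch below is an assumption about that helper's single-argument semantics, not code from this PR; the name `naive_swiglu` and the tensor shapes are illustrative. It shows the math the fused call is expected to perform, which matches the `ImportError` fallback removed above: chunk the fused gate/up projection in half along the last axis and return `silu(gate) * up`.

```python
# Minimal sketch (assumption): the semantics the fused branch relies on.
# PaddleNLP is expected to resolve `swiglu` to paddle.incubate.nn.functional.swiglu
# when available; the naive version below reproduces the same math for reference.
import paddle
import paddle.nn.functional as F


def naive_swiglu(x, y=None):
    # Single-argument form: `x` is the output of the fused gate_up projection;
    # split it in half along the last axis into gate and up parts.
    if y is None:
        x, y = paddle.chunk(x, chunks=2, axis=-1)
    # SwiGLU: silu(gate) * up
    return F.silu(x) * y


# Usage sketch: emulate the fused branch of LlamaMLP.forward.
hidden = paddle.randn([2, 8, 16])                   # [batch, seq, hidden]
gate_up = paddle.concat([hidden, hidden], axis=-1)  # stand-in for gate_up_fused_proj(x)
out = naive_swiglu(gate_up)                         # same shape as `hidden`
print(out.shape)                                    # [2, 8, 16]
```

Under this reading, dropping the XPU-specific `paddle_xpu_nn.xpu_swiglu` path keeps the fused branch on one code path for all devices, with the device-specific dispatch left to whatever `swiglu` resolves to.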