From f1cf8846aa0b9b1715a2424e8475b3e7164e619a Mon Sep 17 00:00:00 2001 From: "haowen.han@mthreads.com" Date: Mon, 13 May 2024 13:29:57 +0800 Subject: [PATCH] Revert "fix fused_rope diff (#60217) (#60593)" This reverts commit 97b65c7d43f4ce28366d17b907051b1ef3d9f643. --- paddle/phi/kernels/fusion/gpu/fused_rope_utils.h | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h b/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h index 0db16ffb7e20bc..972f5ee633bbb0 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h +++ b/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h @@ -125,18 +125,10 @@ __global__ void VectorizedFusedRopeWithRotateEveryTwoKernel( MPType p0 = static_cast(input[pr_index]); MPType p1 = static_cast(input[ls_index]); - if (sign == 1) { - result[pr_index] = cos_value[pr_index] * p0; - result[pr_index] -= sin_value[pr_index] * p1; - - result[ls_index] = sin_value[ls_index] * p0; - result[ls_index] += cos_value[ls_index] * p1; - } else if (sign == -1) { - result[pr_index] = - cos_value[pr_index] * p0 + sin_value[ls_index] * p1; - result[ls_index] = - cos_value[ls_index] * p1 - sin_value[pr_index] * p0; - } + result[pr_index] = + cos_value[pr_index] * p0 - sign * sin_value[ls_index] * p1; + result[ls_index] = + cos_value[ls_index] * p1 + sign * sin_value[pr_index] * p0; store[pr_index] = static_cast(result[pr_index]); store[ls_index] = static_cast(result[ls_index]);