diff --git a/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h b/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h index 0db16ffb7e20bc..972f5ee633bbb0 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h +++ b/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h @@ -125,18 +125,10 @@ __global__ void VectorizedFusedRopeWithRotateEveryTwoKernel( MPType p0 = static_cast(input[pr_index]); MPType p1 = static_cast(input[ls_index]); - if (sign == 1) { - result[pr_index] = cos_value[pr_index] * p0; - result[pr_index] -= sin_value[pr_index] * p1; - - result[ls_index] = sin_value[ls_index] * p0; - result[ls_index] += cos_value[ls_index] * p1; - } else if (sign == -1) { - result[pr_index] = - cos_value[pr_index] * p0 + sin_value[ls_index] * p1; - result[ls_index] = - cos_value[ls_index] * p1 - sin_value[pr_index] * p0; - } + result[pr_index] = + cos_value[pr_index] * p0 - sign * sin_value[ls_index] * p1; + result[ls_index] = + cos_value[ls_index] * p1 + sign * sin_value[pr_index] * p0; store[pr_index] = static_cast(result[pr_index]); store[ls_index] = static_cast(result[ls_index]);