diff --git a/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h b/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h index 972f5ee633bbb..0db16ffb7e20b 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h +++ b/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h @@ -125,10 +125,18 @@ __global__ void VectorizedFusedRopeWithRotateEveryTwoKernel( MPType p0 = static_cast(input[pr_index]); MPType p1 = static_cast(input[ls_index]); - result[pr_index] = - cos_value[pr_index] * p0 - sign * sin_value[ls_index] * p1; - result[ls_index] = - cos_value[ls_index] * p1 + sign * sin_value[pr_index] * p0; + if (sign == 1) { + result[pr_index] = cos_value[pr_index] * p0; + result[pr_index] -= sin_value[pr_index] * p1; + + result[ls_index] = sin_value[ls_index] * p0; + result[ls_index] += cos_value[ls_index] * p1; + } else if (sign == -1) { + result[pr_index] = + cos_value[pr_index] * p0 + sin_value[ls_index] * p1; + result[ls_index] = + cos_value[ls_index] * p1 - sin_value[pr_index] * p0; + } store[pr_index] = static_cast(result[pr_index]); store[ls_index] = static_cast(result[ls_index]);