From c4c4bf33bf2ee684242e89f93947eb026b223284 Mon Sep 17 00:00:00 2001 From: tianhaodongbd Date: Thu, 17 Aug 2023 10:33:42 +0000 Subject: [PATCH 1/8] add rotate_half in fused_rope --- paddle/phi/api/yaml/fused_backward.yaml | 4 +- paddle/phi/api/yaml/fused_ops.yaml | 2 +- paddle/phi/infermeta/backward.cc | 1 + paddle/phi/infermeta/backward.h | 1 + paddle/phi/infermeta/multiary.cc | 1 + paddle/phi/infermeta/multiary.h | 1 + .../fusion/gpu/fused_rope_grad_kernel.cu | 40 +++++-- .../kernels/fusion/gpu/fused_rope_kernel.cu | 40 +++++-- .../phi/kernels/fusion/gpu/fused_rope_utils.h | 109 +++++++++++++++-- .../fused_rotary_position_embedding.py | 9 +- .../test_fused_rotary_position_embedding.py | 112 +++++++++++++++--- 11 files changed, 261 insertions(+), 59 deletions(-) diff --git a/paddle/phi/api/yaml/fused_backward.yaml b/paddle/phi/api/yaml/fused_backward.yaml index 5f49e790e2550b..44b2722bd46476 100644 --- a/paddle/phi/api/yaml/fused_backward.yaml +++ b/paddle/phi/api/yaml/fused_backward.yaml @@ -17,8 +17,8 @@ support_dygraph_mode : true - backward_op : fused_rotary_position_embedding_grad - forward: fused_rotary_position_embedding (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos) -> Tensor(out_q), Tensor(out_k), Tensor(out_v) - args : (Tensor sin, Tensor cos, Tensor out_q_grad, Tensor out_k_grad,Tensor out_v_grad) + forward: fused_rotary_position_embedding (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos, bool use_neox_rotary_style) -> Tensor(out_q), Tensor(out_k), Tensor(out_v) + args : (Tensor sin, Tensor cos, Tensor out_q_grad, Tensor out_k_grad,Tensor out_v_grad, bool use_neox_rotary_style) output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad) optional : sin, cos, out_k_grad, out_v_grad, k_grad, v_grad infer_meta : diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml index a45759b4000051..9d3d9c2f407814 100644 --- a/paddle/phi/api/yaml/fused_ops.yaml +++ b/paddle/phi/api/yaml/fused_ops.yaml @@ -148,7 +148,7 @@ optional : cache_kv, pre_caches, rotary_pos_emb, time_step, seq_lengths, src_mask, gather_index - op : fused_rotary_position_embedding - args : (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos) + args : (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos, bool use_neox_rotary_style = true) output : Tensor(out_q), Tensor(out_k), Tensor(out_v) infer_meta : func : FusedRopeInferMeta diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index a1b1be2861b3d1..a5cedd232e375c 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -1222,6 +1222,7 @@ void FusedRopeGradInferMeta(const MetaTensor& sin, const MetaTensor& dout_q, const MetaTensor& dout_k, const MetaTensor& dout_v, + bool use_neox_rotary_style, MetaTensor* dq, MetaTensor* dk, MetaTensor* dv) { diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 184c8d10fb0748..fd04a00e5d64c4 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -189,6 +189,7 @@ void FusedRopeGradInferMeta(const MetaTensor& sin, const MetaTensor& dout_q, const MetaTensor& dout_k, const MetaTensor& dout_v, + bool use_neox_rotary_style, MetaTensor* dq, MetaTensor* dk, MetaTensor* dv); diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 06232c06907169..16ece2cdfdfb69 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -3848,6 +3848,7 @@ void FusedRopeInferMeta(const MetaTensor& q, const MetaTensor& v, const MetaTensor& 
sin, const MetaTensor& cos, + bool use_neox_rotary_style, MetaTensor* out_q, MetaTensor* out_k, MetaTensor* out_v) { diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index e1bd5bd1fe1e85..b944742a978589 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -769,6 +769,7 @@ void FusedRopeInferMeta(const MetaTensor& q, const MetaTensor& v, const MetaTensor& sin, const MetaTensor& cos, + bool use_neox_rotary_style, MetaTensor* out_q, MetaTensor* out_k, MetaTensor* out_v); diff --git a/paddle/phi/kernels/fusion/gpu/fused_rope_grad_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_rope_grad_kernel.cu index 442317eb53d980..5ade9511736a43 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_rope_grad_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_rope_grad_kernel.cu @@ -30,6 +30,7 @@ void FusedRopeGradKernel(const Context& dev_ctx, const DenseTensor& dout_q, const paddle::optional& dout_k, const paddle::optional& dout_v, + bool use_neox_rotary_style, DenseTensor* dq, DenseTensor* dk, DenseTensor* dv) { @@ -89,18 +90,33 @@ void FusedRopeGradKernel(const Context& dev_ctx, } int sign = -1; - VectorizedFusedRopeKernel - <<>>(ins_data, - sin_cos_data, - flag_sin_cos, - sign, - batch_size, - seq_len, - num_heads, - head_dim, - outs_data, - num_inputs, - div_c); + if (use_neox_rotary_style) { + VectorizedFusedRopeWithRotateEveryTwoKernel + <<>>(ins_data, + sin_cos_data, + flag_sin_cos, + sign, + batch_size, + seq_len, + num_heads, + head_dim, + outs_data, + num_inputs, + div_c); + } else { + VectorizedFusedRopeWithRotateHalfKernel + <<>>(ins_data, + sin_cos_data, + flag_sin_cos, + sign, + batch_size, + seq_len, + num_heads, + head_dim, + outs_data, + num_inputs, + div_c); + } } } // namespace fusion diff --git a/paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu index f6dcbc2a9038f0..5c248cebe84676 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu @@ -30,6 +30,7 @@ void FusedRopeKernel(const Context& dev_ctx, const paddle::optional& v, const paddle::optional& sin, const paddle::optional& cos, + bool use_neox_rotary_style, DenseTensor* out_q, DenseTensor* out_k, DenseTensor* out_v) { @@ -126,18 +127,33 @@ void FusedRopeKernel(const Context& dev_ctx, } int sign = 1; - VectorizedFusedRopeKernel - <<>>(ins_data, - sin_cos_data, - flag_sin_cos, - sign, - batch_size, - seq_len, - num_heads, - head_dim, - outs_data, - num_inputs, - div_c); + if (use_neox_rotary_style) { + VectorizedFusedRopeWithRotateEveryTwoKernel + <<>>(ins_data, + sin_cos_data, + flag_sin_cos, + sign, + batch_size, + seq_len, + num_heads, + head_dim, + outs_data, + num_inputs, + div_c); + } else { + VectorizedFusedRopeWithRotateHalfKernel + <<>>(ins_data, + sin_cos_data, + flag_sin_cos, + sign, + batch_size, + seq_len, + num_heads, + head_dim, + outs_data, + num_inputs, + div_c); + } } } // namespace fusion } // namespace phi diff --git a/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h b/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h index 54ffba19e60c0e..8c30ddcd0f7adc 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h +++ b/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h @@ -20,17 +20,18 @@ namespace phi { namespace fusion { template -__global__ void VectorizedFusedRopeKernel(phi::Array ins_data, - phi::Array sin_cos_data, - bool flag_sin_cos, - int sign, - int64_t batch_size, - int64_t seq_len, - int64_t num_heads, - int64_t head_dim, - 
phi::Array outs_data, - int num_inputs, - MPType div_c) { +__global__ void VectorizedFusedRopeWithRotateEveryTwoKernel( + phi::Array ins_data, + phi::Array sin_cos_data, + bool flag_sin_cos, + int sign, + int64_t batch_size, + int64_t seq_len, + int64_t num_heads, + int64_t head_dim, + phi::Array outs_data, + int num_inputs, + MPType div_c) { int64_t index = (static_cast(blockIdx.x) * static_cast(blockDim.x) + threadIdx.x) * @@ -102,5 +103,91 @@ __global__ void VectorizedFusedRopeKernel(phi::Array ins_data, } } +template +__global__ void VectorizedFusedRopeWithRotateHalfKernel( + phi::Array ins_data, + phi::Array sin_cos_data, + bool flag_sin_cos, + int sign, + int64_t batch_size, + int64_t seq_len, + int64_t num_heads, + int64_t head_dim, + phi::Array outs_data, + int num_inputs, + MPType div_c) { + int64_t index = + (static_cast(blockIdx.x) * static_cast(blockDim.x) + + threadIdx.x) * + VecSize; + int64_t stride = static_cast(gridDim.x) * + static_cast(blockDim.x) * VecSize; + int64_t size = batch_size * seq_len * num_heads * head_dim; + MPType sin_value[VecSize]; + MPType cos_value[VecSize]; + MPType result[VecSize]; + T store[VecSize]; + using VecType = phi::AlignedVector; + constexpr int kVectorsPerThread = VecSize / 2; + + for (; index < size; index += stride) { + if (flag_sin_cos) { +#pragma unroll + for (int64_t nx = 0; nx < VecSize; ++nx) { + int64_t index_wc = (index + nx) % (seq_len * num_heads * head_dim); + int64_t pos_seq = index_wc / (num_heads * head_dim); + int64_t pos_head = index_wc % head_dim; + int64_t index_sc = pos_seq * head_dim + pos_head; + const T* sin_input = sin_cos_data[0] + index_sc; + const T* cos_input = sin_cos_data[1] + index_sc; + + sin_value[nx] = static_cast(sin_input[0]); + cos_value[nx] = static_cast(cos_input[0]); + } + } else { +#pragma unroll + for (int nx = 0; nx < VecSize; ++nx) { + // get sin_index and cos_index + int64_t index_wc = (index + nx) % (seq_len * num_heads * head_dim); + int64_t pos_seq = index_wc / (num_heads * head_dim); + MPType idx = static_cast((index_wc % head_dim) / 2 * 2.0); + MPType indicses = + static_cast(1) / + pow(static_cast(10000), idx * static_cast(div_c)); + MPType value = pos_seq * indicses; + sin_value[nx] = sin(value); + cos_value[nx] = cos(value); + } + } + + // use rotate_half mode + int stride_r = head_dim / 2; +#pragma unroll + for (int iter = 0; iter < 3; iter++) { + if (iter > num_inputs) break; + // get value_index and rotate_half_index + int index_v = index; + int index_r = (index % head_dim) < stride_r ? (index + stride_r) + : (index - stride_r); + MPType sign_r = (index % head_dim) < stride_r ? 
static_cast(-1) + : static_cast(1); + const T* input_v = ins_data[iter] + index_v; + const T* input_r = ins_data[iter] + index_r; + VecType* out = reinterpret_cast(outs_data[iter] + index); + +#pragma unroll + for (int nx = 0; nx < VecSize; ++nx) { + MPType p0 = static_cast(input_v[nx]); + MPType p1 = static_cast(input_r[nx]); + + result[nx] = cos_value[nx] * p0 + sign * sign_r * sin_value[nx] * p1; + + store[nx] = static_cast(result[nx]); + } + out[0] = *(reinterpret_cast(store)); + } + } +} + } // namespace fusion } // namespace phi diff --git a/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py b/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py index e05ae63f07807e..f4b559a42cfeff 100644 --- a/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py +++ b/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py @@ -17,7 +17,9 @@ from paddle.framework import in_dynamic_mode -def fused_rotary_position_embedding(q, k=None, v=None, sin=None, cos=None): +def fused_rotary_position_embedding( + q, k=None, v=None, sin=None, cos=None, use_neox_rotary_style=True +): r""" Fused rotary position embedding. @@ -27,6 +29,7 @@ def fused_rotary_position_embedding(q, k=None, v=None, sin=None, cos=None): v (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if v must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2. sin (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if sin must be [seq_len, head_dim] or [1, 1, seq_len, head_dim] and head_dim must be a multiple of 2. cos (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if cos must be [seq_len, head_dim] or [1, 1, seq_len, head_dim] and head_dim must be a multiple of 2. + use_neox_rotary_style(optional|bool): Use "rotate_every_two" when use_neox_rotary_style is True, use "ratate_half" when use_neox_rotary_style is False. Default True. Returns: out_q/out_k/out_v Tensor representing the fused rotary position embedding, has same shape and data type as `q` . @@ -52,7 +55,9 @@ def fused_rotary_position_embedding(q, k=None, v=None, sin=None, cos=None): out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v, sin=sin, cos=cos) """ if in_dynamic_mode(): - return _C_ops.fused_rotary_position_embedding(q, k, v, sin, cos) + return _C_ops.fused_rotary_position_embedding( + q, k, v, sin, cos, use_neox_rotary_style + ) raise RuntimeError( "This feature is currently supported only in dynamic mode and with CUDAPlace." 
diff --git a/test/legacy_test/test_fused_rotary_position_embedding.py b/test/legacy_test/test_fused_rotary_position_embedding.py index 9842fbf1f4ee8c..48aeed1845c28e 100644 --- a/test/legacy_test/test_fused_rotary_position_embedding.py +++ b/test/legacy_test/test_fused_rotary_position_embedding.py @@ -41,6 +41,24 @@ def mult_qkv(value, cos_tensor, sin_tensor): return query +def mult_qkv_rotate_half(value, cos_tensor, sin_tensor): + rotate_half_q = paddle.reshape( + paddle.concat( + [ + -value[..., value.shape[-1] // 2 :], + value[..., : value.shape[-1] // 2], + ], + axis=-1, + ), + paddle.shape(value), + ) + query = paddle.add( + paddle.multiply(value, cos_tensor), + paddle.multiply(rotate_half_q, sin_tensor), + ) + return query + + def get_sin_cos_tensor(seq_len, head_dim, sign): pos_seq = paddle.arange(0, seq_len, 1, dtype="float32") indices = paddle.arange(0, head_dim, 2, dtype="float32") @@ -74,22 +92,37 @@ def get_sin_cos_tensor(seq_len, head_dim, sign): return tensor_sin, tensor_cos -def paddle_fused_rotary_position_embedding(init_q, init_k, init_v): +def paddle_fused_rotary_position_embedding( + init_q, init_k, init_v, use_neox_rotary_style=True +): # permute q, k, v from [batch_size, seq_len, num_heads, head_dim] # to [batch_size, num_heads, seq_len, head_dim] q, k, v = deal_qkv(init_q, init_k, init_v) - sin_tensor, cos_tensor = get_sin_cos_tensor(q.shape[2], q.shape[3], -1) + if use_neox_rotary_style: + sin_tensor, cos_tensor = get_sin_cos_tensor(q.shape[2], q.shape[3], -1) - # permute sin, cos from [1, seq_len, 1, head_dim] - # to [1, 1, seq_len, head_dim] - perm = [0, 2, 1, 3] - sin_tensor = paddle.transpose(x=sin_tensor, perm=perm) - cos_tensor = paddle.transpose(x=cos_tensor, perm=perm) + # permute sin, cos from [1, seq_len, 1, head_dim] + # to [1, 1, seq_len, head_dim] + perm = [0, 2, 1, 3] + sin_tensor = paddle.transpose(x=sin_tensor, perm=perm) + cos_tensor = paddle.transpose(x=cos_tensor, perm=perm) - query = mult_qkv(q, cos_tensor, sin_tensor) - value = mult_qkv(v, cos_tensor, sin_tensor) - key = mult_qkv(k, cos_tensor, sin_tensor) + query = mult_qkv(q, cos_tensor, sin_tensor) + value = mult_qkv(v, cos_tensor, sin_tensor) + key = mult_qkv(k, cos_tensor, sin_tensor) + else: + sin_tensor, cos_tensor = get_sin_cos_tensor(q.shape[2], q.shape[3], 1) + + # permute sin, cos from [1, seq_len, 1, head_dim] + # to [1, 1, seq_len, head_dim] + perm = [0, 2, 1, 3] + sin_tensor = paddle.transpose(x=sin_tensor, perm=perm) + cos_tensor = paddle.transpose(x=cos_tensor, perm=perm) + + query = mult_qkv_rotate_half(q, cos_tensor, sin_tensor) + value = mult_qkv_rotate_half(v, cos_tensor, sin_tensor) + key = mult_qkv_rotate_half(k, cos_tensor, sin_tensor) # permute the result back to [batch_size, seq_len, num_heads, head_dim] r_query, r_key, r_value = deal_qkv(query, key, value) @@ -112,7 +145,9 @@ def get_paddle_tensor(self): tmp.stop_gradient = False return tmp - def get_forward_backward(self, rope_function, seed, flag=0): + def get_forward_backward( + self, rope_function, seed, flag=0, use_neox_rotary_style=True + ): paddle.disable_static() paddle.seed(seed) fw = [] @@ -120,15 +155,35 @@ def get_forward_backward(self, rope_function, seed, flag=0): tensor_q = self.get_paddle_tensor() tensor_k = self.get_paddle_tensor() tensor_v = self.get_paddle_tensor() - if flag: - tensor_sin, tensor_cos = get_sin_cos_tensor( - tensor_q.shape[1], tensor_q.shape[3], 1 - ) - out_q, out_k, out_v = rope_function( - tensor_q, tensor_k, tensor_v, tensor_sin, tensor_cos - ) + if use_neox_rotary_style: + if flag: 
+ tensor_sin, tensor_cos = get_sin_cos_tensor( + tensor_q.shape[1], tensor_q.shape[3], 1 + ) + out_q, out_k, out_v = rope_function( + tensor_q, tensor_k, tensor_v, tensor_sin, tensor_cos + ) + else: + out_q, out_k, out_v = rope_function( + tensor_q, tensor_k, tensor_v + ) else: - out_q, out_k, out_v = rope_function(tensor_q, tensor_k, tensor_v) + if flag: + tensor_sin, tensor_cos = get_sin_cos_tensor( + tensor_q.shape[1], tensor_q.shape[3], 1 + ) + out_q, out_k, out_v = rope_function( + tensor_q, + tensor_k, + tensor_v, + tensor_sin, + tensor_cos, + use_neox_rotary_style=False, + ) + else: + out_q, out_k, out_v = rope_function( + tensor_q, tensor_k, tensor_v, use_neox_rotary_style=False + ) fw.append(out_q) fw.append(out_k) @@ -176,6 +231,25 @@ def test_fused_rope_with_sin_cos(self): p_bw[i].numpy(), f_bw[i].numpy(), rtol=1e-05 ) + def test_fused_rope_rotate_half(self): + p_fw, p_bw = self.get_forward_backward( + paddle_fused_rotary_position_embedding, + seed=self.seed, + use_neox_rotary_style=False, + ) + f_fw, f_bw = self.get_forward_backward( + fused_rotary_position_embedding, + seed=self.seed, + use_neox_rotary_style=False, + ) + for i in range(len(p_fw)): + np.testing.assert_allclose( + p_fw[i].numpy(), f_fw[i].numpy(), rtol=1e-05 + ) + np.testing.assert_allclose( + p_bw[i].numpy(), f_bw[i].numpy(), rtol=1e-05 + ) + def test_error(self): paddle.enable_static() with self.assertRaises(RuntimeError): From 8cc3cf563fa7f538c020f7bd01eca50264dbd27f Mon Sep 17 00:00:00 2001 From: tianhaodongbd Date: Thu, 24 Aug 2023 11:57:06 +0000 Subject: [PATCH 2/8] modified the fused_rope op based on the review --- .../phi/kernels/fusion/gpu/fused_rope_utils.h | 117 +++++++++--------- .../fused_rotary_position_embedding.py | 2 +- .../test_fused_rotary_position_embedding.py | 23 ++-- 3 files changed, 70 insertions(+), 72 deletions(-) diff --git a/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h b/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h index 8c30ddcd0f7adc..8cc7f968df2c6e 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h +++ b/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h @@ -19,6 +19,49 @@ namespace phi { namespace fusion { +template +__device__ void VectorizedGetSinCos(phi::Array sin_cos_data, + bool flag_sin_cos, + int64_t index, + int64_t seq_len, + int64_t num_heads, + int64_t head_dim, + MPType* out_sin, + MPType* out_cos, + MPType div_c) { + MPType* sin_value = out_sin; + MPType* cos_value = out_cos; + + if (flag_sin_cos) { +#pragma unroll + for (int64_t nx = 0; nx < VecSize; ++nx) { + int64_t index_wc = (index + nx) % (seq_len * num_heads * head_dim); + int64_t pos_seq = index_wc / (num_heads * head_dim); + int64_t pos_head = index_wc % head_dim; + int64_t index_sc = pos_seq * head_dim + pos_head; + const T* sin_input = sin_cos_data[0] + index_sc; + const T* cos_input = sin_cos_data[1] + index_sc; + + sin_value[nx] = static_cast(sin_input[0]); + cos_value[nx] = static_cast(cos_input[0]); + } + } else { +#pragma unroll + for (int nx = 0; nx < VecSize; ++nx) { + // get sin_index and cos_index + int64_t index_wc = (index + nx) % (seq_len * num_heads * head_dim); + int64_t pos_seq = index_wc / (num_heads * head_dim); + MPType idx = static_cast((index_wc % head_dim) / 2 * 2.0); + MPType indicses = + static_cast(1) / + pow(static_cast(10000), idx * static_cast(div_c)); + MPType value = pos_seq * indicses; + sin_value[nx] = sin(value); + cos_value[nx] = cos(value); + } + } +} + template __global__ void VectorizedFusedRopeWithRotateEveryTwoKernel( phi::Array ins_data, @@ -47,34 
+90,15 @@ __global__ void VectorizedFusedRopeWithRotateEveryTwoKernel( constexpr int kVectorsPerThread = VecSize / 2; for (; index < size; index += stride) { - if (flag_sin_cos) { -#pragma unroll - for (int64_t nx = 0; nx < VecSize; ++nx) { - int64_t index_wc = (index + nx) % (seq_len * num_heads * head_dim); - int64_t pos_seq = index_wc / (num_heads * head_dim); - int64_t pos_head = index_wc % head_dim; - int64_t index_sc = pos_seq * head_dim + pos_head; - const T* sin_input = sin_cos_data[0] + index_sc; - const T* cos_input = sin_cos_data[1] + index_sc; - - sin_value[nx] = static_cast(sin_input[0]); - cos_value[nx] = static_cast(cos_input[0]); - } - } else { -#pragma unroll - for (int nx = 0; nx < VecSize; ++nx) { - // get sin_index and cos_index - int64_t index_wc = (index + nx) % (seq_len * num_heads * head_dim); - int64_t pos_seq = index_wc / (num_heads * head_dim); - MPType idx = static_cast((index_wc % head_dim) / 2 * 2.0); - MPType indicses = - static_cast(1) / - pow(static_cast(10000), idx * static_cast(div_c)); - MPType value = pos_seq * indicses; - sin_value[nx] = sin(value); - cos_value[nx] = cos(value); - } - } + VectorizedGetSinCos(sin_cos_data, + flag_sin_cos, + index, + seq_len, + num_heads, + head_dim, + sin_value, + cos_value, + div_c); #pragma unroll for (int iter = 0; iter < 3; iter++) { @@ -131,34 +155,15 @@ __global__ void VectorizedFusedRopeWithRotateHalfKernel( constexpr int kVectorsPerThread = VecSize / 2; for (; index < size; index += stride) { - if (flag_sin_cos) { -#pragma unroll - for (int64_t nx = 0; nx < VecSize; ++nx) { - int64_t index_wc = (index + nx) % (seq_len * num_heads * head_dim); - int64_t pos_seq = index_wc / (num_heads * head_dim); - int64_t pos_head = index_wc % head_dim; - int64_t index_sc = pos_seq * head_dim + pos_head; - const T* sin_input = sin_cos_data[0] + index_sc; - const T* cos_input = sin_cos_data[1] + index_sc; - - sin_value[nx] = static_cast(sin_input[0]); - cos_value[nx] = static_cast(cos_input[0]); - } - } else { -#pragma unroll - for (int nx = 0; nx < VecSize; ++nx) { - // get sin_index and cos_index - int64_t index_wc = (index + nx) % (seq_len * num_heads * head_dim); - int64_t pos_seq = index_wc / (num_heads * head_dim); - MPType idx = static_cast((index_wc % head_dim) / 2 * 2.0); - MPType indicses = - static_cast(1) / - pow(static_cast(10000), idx * static_cast(div_c)); - MPType value = pos_seq * indicses; - sin_value[nx] = sin(value); - cos_value[nx] = cos(value); - } - } + VectorizedGetSinCos(sin_cos_data, + flag_sin_cos, + index, + seq_len, + num_heads, + head_dim, + sin_value, + cos_value, + div_c); // use rotate_half mode int stride_r = head_dim / 2; diff --git a/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py b/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py index f4b559a42cfeff..9d0cf9c3978c6d 100644 --- a/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py +++ b/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py @@ -29,7 +29,7 @@ def fused_rotary_position_embedding( v (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if v must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2. sin (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if sin must be [seq_len, head_dim] or [1, 1, seq_len, head_dim] and head_dim must be a multiple of 2. cos (optional|Tensor): The input tensor. 
The data type is bfloat16, float16, float32 or float64. The shape if cos must be [seq_len, head_dim] or [1, 1, seq_len, head_dim] and head_dim must be a multiple of 2. - use_neox_rotary_style(optional|bool): Use "rotate_every_two" when use_neox_rotary_style is True, use "ratate_half" when use_neox_rotary_style is False. Default True. + use_neox_rotary_style(optional|bool): When the use_neox_rotary_style is True, every two adjacent numbers are calculated. When the use_neox_rotary_style is False, the numbers corresponding to the positions of the front half and back half segments are calculated. Default True. Returns: out_q/out_k/out_v Tensor representing the fused rotary position embedding, has same shape and data type as `q` . diff --git a/test/legacy_test/test_fused_rotary_position_embedding.py b/test/legacy_test/test_fused_rotary_position_embedding.py index 48aeed1845c28e..c7ec0de633d612 100644 --- a/test/legacy_test/test_fused_rotary_position_embedding.py +++ b/test/legacy_test/test_fused_rotary_position_embedding.py @@ -99,27 +99,20 @@ def paddle_fused_rotary_position_embedding( # to [batch_size, num_heads, seq_len, head_dim] q, k, v = deal_qkv(init_q, init_k, init_v) - if use_neox_rotary_style: - sin_tensor, cos_tensor = get_sin_cos_tensor(q.shape[2], q.shape[3], -1) + sign = -1 if use_neox_rotary_style else 1 + sin_tensor, cos_tensor = get_sin_cos_tensor(q.shape[2], q.shape[3], sign) - # permute sin, cos from [1, seq_len, 1, head_dim] - # to [1, 1, seq_len, head_dim] - perm = [0, 2, 1, 3] - sin_tensor = paddle.transpose(x=sin_tensor, perm=perm) - cos_tensor = paddle.transpose(x=cos_tensor, perm=perm) + # permute sin, cos from [1, seq_len, 1, head_dim] + # to [1, 1, seq_len, head_dim] + perm = [0, 2, 1, 3] + sin_tensor = paddle.transpose(x=sin_tensor, perm=perm) + cos_tensor = paddle.transpose(x=cos_tensor, perm=perm) + if use_neox_rotary_style: query = mult_qkv(q, cos_tensor, sin_tensor) value = mult_qkv(v, cos_tensor, sin_tensor) key = mult_qkv(k, cos_tensor, sin_tensor) else: - sin_tensor, cos_tensor = get_sin_cos_tensor(q.shape[2], q.shape[3], 1) - - # permute sin, cos from [1, seq_len, 1, head_dim] - # to [1, 1, seq_len, head_dim] - perm = [0, 2, 1, 3] - sin_tensor = paddle.transpose(x=sin_tensor, perm=perm) - cos_tensor = paddle.transpose(x=cos_tensor, perm=perm) - query = mult_qkv_rotate_half(q, cos_tensor, sin_tensor) value = mult_qkv_rotate_half(v, cos_tensor, sin_tensor) key = mult_qkv_rotate_half(k, cos_tensor, sin_tensor) From 2459faa49c7e9ca5c2e18abe73b462de3b997bf7 Mon Sep 17 00:00:00 2001 From: tianhaodongbd Date: Thu, 31 Aug 2023 03:35:41 +0000 Subject: [PATCH 3/8] add position_ids in fused_rope --- paddle/phi/api/yaml/fused_backward.yaml | 6 +- paddle/phi/api/yaml/fused_ops.yaml | 4 +- paddle/phi/infermeta/backward.cc | 1 + paddle/phi/infermeta/backward.h | 1 + paddle/phi/infermeta/multiary.cc | 1 + paddle/phi/infermeta/multiary.h | 1 + .../fusion/gpu/fused_rope_grad_kernel.cu | 13 ++++ .../kernels/fusion/gpu/fused_rope_kernel.cu | 13 ++++ .../phi/kernels/fusion/gpu/fused_rope_utils.h | 30 ++++++++- .../fused_rotary_position_embedding.py | 11 +++- .../test_fused_rotary_position_embedding.py | 63 ++++++++++++++++--- 11 files changed, 126 insertions(+), 18 deletions(-) diff --git a/paddle/phi/api/yaml/fused_backward.yaml b/paddle/phi/api/yaml/fused_backward.yaml index 44b2722bd46476..8dfa117d44c976 100644 --- a/paddle/phi/api/yaml/fused_backward.yaml +++ b/paddle/phi/api/yaml/fused_backward.yaml @@ -17,10 +17,10 @@ support_dygraph_mode : true - backward_op : 
fused_rotary_position_embedding_grad - forward: fused_rotary_position_embedding (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos, bool use_neox_rotary_style) -> Tensor(out_q), Tensor(out_k), Tensor(out_v) - args : (Tensor sin, Tensor cos, Tensor out_q_grad, Tensor out_k_grad,Tensor out_v_grad, bool use_neox_rotary_style) + forward: fused_rotary_position_embedding (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos, Tensor position_ids, bool use_neox_rotary_style) -> Tensor(out_q), Tensor(out_k), Tensor(out_v) + args : (Tensor sin, Tensor cos, Tensor position_ids, Tensor out_q_grad, Tensor out_k_grad,Tensor out_v_grad, bool use_neox_rotary_style) output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad) - optional : sin, cos, out_k_grad, out_v_grad, k_grad, v_grad + optional : sin, cos, position_ids, out_k_grad, out_v_grad, k_grad, v_grad infer_meta : func : FusedRopeGradInferMeta kernel : diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml index 9d3d9c2f407814..44fd2c951b761d 100644 --- a/paddle/phi/api/yaml/fused_ops.yaml +++ b/paddle/phi/api/yaml/fused_ops.yaml @@ -148,11 +148,11 @@ optional : cache_kv, pre_caches, rotary_pos_emb, time_step, seq_lengths, src_mask, gather_index - op : fused_rotary_position_embedding - args : (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos, bool use_neox_rotary_style = true) + args : (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos, Tensor position_ids, bool use_neox_rotary_style = true) output : Tensor(out_q), Tensor(out_k), Tensor(out_v) infer_meta : func : FusedRopeInferMeta - optional : k,v,sin,cos, out_k, out_v + optional : k, v, sin, cos, position_ids, out_k, out_v kernel : func : fused_rotary_position_embedding data_type : q diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index a5cedd232e375c..9a2f68f3bfb2ee 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -1219,6 +1219,7 @@ void IndexPutGradInferMeta(const MetaTensor& x, void FusedRopeGradInferMeta(const MetaTensor& sin, const MetaTensor& cos, + const MetaTensor& position_ids, const MetaTensor& dout_q, const MetaTensor& dout_k, const MetaTensor& dout_v, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index fd04a00e5d64c4..08ee5aa8370bf0 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -186,6 +186,7 @@ void FusedDropoutAddGradInferMeta(const MetaTensor& seed_offset, void FusedRopeGradInferMeta(const MetaTensor& sin, const MetaTensor& cos, + const MetaTensor& position_ids, const MetaTensor& dout_q, const MetaTensor& dout_k, const MetaTensor& dout_v, diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 16ece2cdfdfb69..7cdadef98c5258 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -3848,6 +3848,7 @@ void FusedRopeInferMeta(const MetaTensor& q, const MetaTensor& v, const MetaTensor& sin, const MetaTensor& cos, + const MetaTensor& position_ids, bool use_neox_rotary_style, MetaTensor* out_q, MetaTensor* out_k, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index b944742a978589..af9720702b9663 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -769,6 +769,7 @@ void FusedRopeInferMeta(const MetaTensor& q, const MetaTensor& v, const MetaTensor& sin, const MetaTensor& cos, + const MetaTensor& position_ids, bool use_neox_rotary_style, MetaTensor* out_q, MetaTensor* out_k, diff --git 
a/paddle/phi/kernels/fusion/gpu/fused_rope_grad_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_rope_grad_kernel.cu index 5ade9511736a43..8ef36e284d3659 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_rope_grad_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_rope_grad_kernel.cu @@ -27,6 +27,7 @@ template void FusedRopeGradKernel(const Context& dev_ctx, const paddle::optional& sin, const paddle::optional& cos, + const paddle::optional& position_ids, const DenseTensor& dout_q, const paddle::optional& dout_k, const paddle::optional& dout_v, @@ -59,6 +60,7 @@ void FusedRopeGradKernel(const Context& dev_ctx, phi::Array outs_data; phi::Array ins_data; phi::Array sin_cos_data; + const int64_t* position_ids_data; ins_data[0] = dout_q.data(); outs_data[0] = dq->data(); @@ -89,12 +91,21 @@ void FusedRopeGradKernel(const Context& dev_ctx, flag_sin_cos = true; } + bool flag_position_ids = false; + if (position_ids.get_ptr()) { + position_ids_data = position_ids->data(); + + flag_position_ids = true; + } + int sign = -1; if (use_neox_rotary_style) { VectorizedFusedRopeWithRotateEveryTwoKernel <<>>(ins_data, sin_cos_data, + position_ids_data, flag_sin_cos, + flag_position_ids, sign, batch_size, seq_len, @@ -107,7 +118,9 @@ void FusedRopeGradKernel(const Context& dev_ctx, VectorizedFusedRopeWithRotateHalfKernel <<>>(ins_data, sin_cos_data, + position_ids_data, flag_sin_cos, + flag_position_ids, sign, batch_size, seq_len, diff --git a/paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu index 5c248cebe84676..30f01e138af0ee 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu @@ -30,6 +30,7 @@ void FusedRopeKernel(const Context& dev_ctx, const paddle::optional& v, const paddle::optional& sin, const paddle::optional& cos, + const paddle::optional& position_ids, bool use_neox_rotary_style, DenseTensor* out_q, DenseTensor* out_k, @@ -60,6 +61,7 @@ void FusedRopeKernel(const Context& dev_ctx, phi::Array outs_data; phi::Array ins_data; phi::Array sin_cos_data; + const int64_t* position_ids_data; ins_data[0] = q.data(); outs_data[0] = out_q->data(); @@ -126,12 +128,21 @@ void FusedRopeKernel(const Context& dev_ctx, flag_sin_cos = true; } + bool flag_position_ids = false; + if (position_ids.get_ptr()) { + position_ids_data = position_ids->data(); + + flag_position_ids = true; + } + int sign = 1; if (use_neox_rotary_style) { VectorizedFusedRopeWithRotateEveryTwoKernel <<>>(ins_data, sin_cos_data, + position_ids_data, flag_sin_cos, + flag_position_ids, sign, batch_size, seq_len, @@ -144,7 +155,9 @@ void FusedRopeKernel(const Context& dev_ctx, VectorizedFusedRopeWithRotateHalfKernel <<>>(ins_data, sin_cos_data, + position_ids_data, flag_sin_cos, + flag_position_ids, sign, batch_size, seq_len, diff --git a/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h b/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h index 8cc7f968df2c6e..8fa14383cb84bd 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h +++ b/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h @@ -21,7 +21,9 @@ namespace fusion { template __device__ void VectorizedGetSinCos(phi::Array sin_cos_data, + const int64_t* position_ids_data, bool flag_sin_cos, + bool flag_position_ids, int64_t index, int64_t seq_len, int64_t num_heads, @@ -36,7 +38,15 @@ __device__ void VectorizedGetSinCos(phi::Array sin_cos_data, #pragma unroll for (int64_t nx = 0; nx < VecSize; ++nx) { int64_t index_wc = (index + nx) % (seq_len * num_heads * head_dim); - 
int64_t pos_seq = index_wc / (num_heads * head_dim); + int64_t pos_seq_ori = index_wc / (num_heads * head_dim); + int64_t pos_seq; + if (flag_position_ids) { + int64_t pos_bs = (index + nx) / (seq_len * num_heads * head_dim); + int64_t index_ids = pos_bs * seq_len + pos_seq_ori; + pos_seq = position_ids_data[index_ids]; + } else { + pos_seq = pos_seq_ori; + } int64_t pos_head = index_wc % head_dim; int64_t index_sc = pos_seq * head_dim + pos_head; const T* sin_input = sin_cos_data[0] + index_sc; @@ -50,7 +60,15 @@ __device__ void VectorizedGetSinCos(phi::Array sin_cos_data, for (int nx = 0; nx < VecSize; ++nx) { // get sin_index and cos_index int64_t index_wc = (index + nx) % (seq_len * num_heads * head_dim); - int64_t pos_seq = index_wc / (num_heads * head_dim); + int64_t pos_seq_ori = index_wc / (num_heads * head_dim); + int64_t pos_seq; + if (flag_position_ids) { + int64_t pos_bs = (index + nx) / (seq_len * num_heads * head_dim); + int64_t index_ids = pos_bs * seq_len + pos_seq_ori; + pos_seq = position_ids_data[index_ids]; + } else { + pos_seq = pos_seq_ori; + } MPType idx = static_cast((index_wc % head_dim) / 2 * 2.0); MPType indicses = static_cast(1) / @@ -66,7 +84,9 @@ template __global__ void VectorizedFusedRopeWithRotateEveryTwoKernel( phi::Array ins_data, phi::Array sin_cos_data, + const int64_t* position_ids_data, bool flag_sin_cos, + bool flag_position_ids, int sign, int64_t batch_size, int64_t seq_len, @@ -91,7 +111,9 @@ __global__ void VectorizedFusedRopeWithRotateEveryTwoKernel( for (; index < size; index += stride) { VectorizedGetSinCos(sin_cos_data, + position_ids_data, flag_sin_cos, + flag_position_ids, index, seq_len, num_heads, @@ -131,7 +153,9 @@ template __global__ void VectorizedFusedRopeWithRotateHalfKernel( phi::Array ins_data, phi::Array sin_cos_data, + const int64_t* position_ids_data, bool flag_sin_cos, + bool flag_position_ids, int sign, int64_t batch_size, int64_t seq_len, @@ -156,7 +180,9 @@ __global__ void VectorizedFusedRopeWithRotateHalfKernel( for (; index < size; index += stride) { VectorizedGetSinCos(sin_cos_data, + position_ids_data, flag_sin_cos, + flag_position_ids, index, seq_len, num_heads, diff --git a/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py b/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py index 9d0cf9c3978c6d..8c142d4757b552 100644 --- a/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py +++ b/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py @@ -18,7 +18,13 @@ def fused_rotary_position_embedding( - q, k=None, v=None, sin=None, cos=None, use_neox_rotary_style=True + q, + k=None, + v=None, + sin=None, + cos=None, + position_ids=None, + use_neox_rotary_style=True, ): r""" Fused rotary position embedding. @@ -29,6 +35,7 @@ def fused_rotary_position_embedding( v (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if v must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2. sin (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if sin must be [seq_len, head_dim] or [1, 1, seq_len, head_dim] and head_dim must be a multiple of 2. cos (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if cos must be [seq_len, head_dim] or [1, 1, seq_len, head_dim] and head_dim must be a multiple of 2. + position_ids (optional|Tensor): The input tensor. The data type is int64. 
The shape if position_ids must be [batch_size, seq_len]. use_neox_rotary_style(optional|bool): When the use_neox_rotary_style is True, every two adjacent numbers are calculated. When the use_neox_rotary_style is False, the numbers corresponding to the positions of the front half and back half segments are calculated. Default True. Returns: @@ -56,7 +63,7 @@ def fused_rotary_position_embedding( """ if in_dynamic_mode(): return _C_ops.fused_rotary_position_embedding( - q, k, v, sin, cos, use_neox_rotary_style + q, k, v, sin, cos, position_ids, use_neox_rotary_style ) raise RuntimeError( diff --git a/test/legacy_test/test_fused_rotary_position_embedding.py b/test/legacy_test/test_fused_rotary_position_embedding.py index c7ec0de633d612..1950a950ebc6e1 100644 --- a/test/legacy_test/test_fused_rotary_position_embedding.py +++ b/test/legacy_test/test_fused_rotary_position_embedding.py @@ -93,7 +93,7 @@ def get_sin_cos_tensor(seq_len, head_dim, sign): def paddle_fused_rotary_position_embedding( - init_q, init_k, init_v, use_neox_rotary_style=True + init_q, init_k, init_v, position_ids=None, use_neox_rotary_style=True ): # permute q, k, v from [batch_size, seq_len, num_heads, head_dim] # to [batch_size, num_heads, seq_len, head_dim] @@ -102,8 +102,16 @@ def paddle_fused_rotary_position_embedding( sign = -1 if use_neox_rotary_style else 1 sin_tensor, cos_tensor = get_sin_cos_tensor(q.shape[2], q.shape[3], sign) - # permute sin, cos from [1, seq_len, 1, head_dim] - # to [1, 1, seq_len, head_dim] + if position_ids is not None: + sin_tensor = sin_tensor.squeeze(axis=[0, 2]) # [seq_len, dim] + cos_tensor = cos_tensor.squeeze(axis=[0, 2]) # [seq_len, dim] + sin_tensor = sin_tensor[position_ids].unsqueeze( + 2 + ) # [bs, seq_len, 1, dim] + cos_tensor = cos_tensor[position_ids].unsqueeze( + 2 + ) # [bs, seq_len, 1, dim] + perm = [0, 2, 1, 3] sin_tensor = paddle.transpose(x=sin_tensor, perm=perm) cos_tensor = paddle.transpose(x=cos_tensor, perm=perm) @@ -128,7 +136,7 @@ def paddle_fused_rotary_position_embedding( ) class TestFusedRotaryPositionEmbedding(unittest.TestCase): def setUp(self): - self.shape = [1, 8, 2, 16] + self.shape = [2, 8, 2, 16] self.dtype = 'float32' self.training = True self.seed = 1203 @@ -139,7 +147,12 @@ def get_paddle_tensor(self): return tmp def get_forward_backward( - self, rope_function, seed, flag=0, use_neox_rotary_style=True + self, + rope_function, + seed, + flag=False, + use_neox_rotary_style=True, + position_ids=None, ): paddle.disable_static() paddle.seed(seed) @@ -154,11 +167,16 @@ def get_forward_backward( tensor_q.shape[1], tensor_q.shape[3], 1 ) out_q, out_k, out_v = rope_function( - tensor_q, tensor_k, tensor_v, tensor_sin, tensor_cos + tensor_q, + tensor_k, + tensor_v, + tensor_sin, + tensor_cos, + position_ids=position_ids, ) else: out_q, out_k, out_v = rope_function( - tensor_q, tensor_k, tensor_v + tensor_q, tensor_k, tensor_v, position_ids=position_ids ) else: if flag: @@ -171,11 +189,16 @@ def get_forward_backward( tensor_v, tensor_sin, tensor_cos, + position_ids=position_ids, use_neox_rotary_style=False, ) else: out_q, out_k, out_v = rope_function( - tensor_q, tensor_k, tensor_v, use_neox_rotary_style=False + tensor_q, + tensor_k, + tensor_v, + position_ids=position_ids, + use_neox_rotary_style=False, ) fw.append(out_q) @@ -214,7 +237,7 @@ def test_fused_rope_with_sin_cos(self): paddle_fused_rotary_position_embedding, seed=self.seed ) f_fw, f_bw = self.get_forward_backward( - fused_rotary_position_embedding, seed=self.seed, flag=1 + 
fused_rotary_position_embedding, seed=self.seed, flag=True ) for i in range(len(p_fw)): np.testing.assert_allclose( @@ -243,6 +266,28 @@ def test_fused_rope_rotate_half(self): p_bw[i].numpy(), f_bw[i].numpy(), rtol=1e-05 ) + def test_fused_rope_position_ids(self): + position_ids = paddle.to_tensor( + [[7, 5, 4, 6, 3, 1, 2, 0], [3, 1, 4, 0, 7, 6, 5, 2]] + ) + p_fw, p_bw = self.get_forward_backward( + paddle_fused_rotary_position_embedding, + seed=self.seed, + position_ids=position_ids, + ) + f_fw, f_bw = self.get_forward_backward( + fused_rotary_position_embedding, + seed=self.seed, + position_ids=position_ids, + ) + for i in range(len(p_fw)): + np.testing.assert_allclose( + p_fw[i].numpy(), f_fw[i].numpy(), rtol=1e-05 + ) + np.testing.assert_allclose( + p_bw[i].numpy(), f_bw[i].numpy(), rtol=1e-05 + ) + def test_error(self): paddle.enable_static() with self.assertRaises(RuntimeError): From 2e8be3338040ffc483adc16f0c2f94f1e1b9a550 Mon Sep 17 00:00:00 2001 From: tianhaodongbd Date: Thu, 31 Aug 2023 06:17:39 +0000 Subject: [PATCH 4/8] modified fused_rope according to review --- .../fusion/gpu/fused_rope_grad_kernel.cu | 13 ++-- .../kernels/fusion/gpu/fused_rope_kernel.cu | 66 +++++++++++++------ .../phi/kernels/fusion/gpu/fused_rope_utils.h | 17 +---- .../fused_rotary_position_embedding.py | 12 ++-- .../test_fused_rotary_position_embedding.py | 1 + 5 files changed, 60 insertions(+), 49 deletions(-) diff --git a/paddle/phi/kernels/fusion/gpu/fused_rope_grad_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_rope_grad_kernel.cu index 8ef36e284d3659..70ea70912f6397 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_rope_grad_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_rope_grad_kernel.cu @@ -60,7 +60,7 @@ void FusedRopeGradKernel(const Context& dev_ctx, phi::Array outs_data; phi::Array ins_data; phi::Array sin_cos_data; - const int64_t* position_ids_data; + const int64_t* position_ids_data = NULL; ins_data[0] = dout_q.data(); outs_data[0] = dq->data(); @@ -89,13 +89,10 @@ void FusedRopeGradKernel(const Context& dev_ctx, sin_cos_data[1] = cos->data(); flag_sin_cos = true; - } - - bool flag_position_ids = false; - if (position_ids.get_ptr()) { - position_ids_data = position_ids->data(); - flag_position_ids = true; + if (position_ids.get_ptr()) { + position_ids_data = position_ids->data(); + } } int sign = -1; @@ -105,7 +102,6 @@ void FusedRopeGradKernel(const Context& dev_ctx, sin_cos_data, position_ids_data, flag_sin_cos, - flag_position_ids, sign, batch_size, seq_len, @@ -120,7 +116,6 @@ void FusedRopeGradKernel(const Context& dev_ctx, sin_cos_data, position_ids_data, flag_sin_cos, - flag_position_ids, sign, batch_size, seq_len, diff --git a/paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu index 30f01e138af0ee..6e032211cc6a09 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu @@ -61,7 +61,7 @@ void FusedRopeKernel(const Context& dev_ctx, phi::Array outs_data; phi::Array ins_data; phi::Array sin_cos_data; - const int64_t* position_ids_data; + const int64_t* position_ids_data = NULL; ins_data[0] = q.data(); outs_data[0] = out_q->data(); @@ -112,15 +112,52 @@ void FusedRopeKernel(const Context& dev_ctx, "The batch_size and num_heads of sin and cos must be 1.")); } int sin_seq_len_dim = (dims_size) == 4 ? 
1 : 0; - PADDLE_ENFORCE_EQ((sin_dims[dims_size - 1] == head_dim && - sin_dims[sin_seq_len_dim] == seq_len), - true, - phi::errors::InvalidArgument( - "The seq_len and head_dim of sin and cos " - "must be the same as those of q. But recieved sin's " - "shape is {%s}, q's shape is {%s}.", - sin_dims, - q.dims())); + + if (position_ids.get_ptr()) { + PADDLE_ENFORCE_EQ( + (sin_dims[dims_size - 1] == head_dim && + sin_dims[sin_seq_len_dim] >= seq_len), + true, + phi::errors::InvalidArgument( + "The seq_len of sin and cos must be greater than or equal to " + "this of q. The head_dim of sin and cos must be the same as this " + "of q. But recieved sin's " + "shape is {%s}, q's shape is {%s}.", + sin_dims, + q.dims())); + + auto position_ids_dims = position_ids.get_ptr()->dims(); + PADDLE_ENFORCE_EQ(position_ids_dims.size(), + 2, + phi::errors::InvalidArgument( + "The dims of position_ids is expected to " + "be 2, but recieved %d.", + position_ids_dims.size())); + + PADDLE_ENFORCE_EQ( + (position_ids_dims[0] == batch_size && + position_ids_dims[1] == seq_len), + true, + phi::errors::InvalidArgument( + "The batch_size and seq_len of position_ids must be the same as " + "those of q. But recieved position_ids's " + "shape is {%s}, q's shape is {%s}.", + position_ids_dims, + q.dims())); + + position_ids_data = position_ids->data(); + } else { + PADDLE_ENFORCE_EQ( + (sin_dims[dims_size - 1] == head_dim && + sin_dims[sin_seq_len_dim] == seq_len), + true, + phi::errors::InvalidArgument( + "The seq_len and head_dim of sin and cos " + "must be the same as those of q. But recieved sin's " + "shape is {%s}, q's shape is {%s}.", + sin_dims, + q.dims())); + } sin_cos_data[0] = sin->data(); sin_cos_data[1] = cos->data(); @@ -128,13 +165,6 @@ void FusedRopeKernel(const Context& dev_ctx, flag_sin_cos = true; } - bool flag_position_ids = false; - if (position_ids.get_ptr()) { - position_ids_data = position_ids->data(); - - flag_position_ids = true; - } - int sign = 1; if (use_neox_rotary_style) { VectorizedFusedRopeWithRotateEveryTwoKernel @@ -142,7 +172,6 @@ void FusedRopeKernel(const Context& dev_ctx, sin_cos_data, position_ids_data, flag_sin_cos, - flag_position_ids, sign, batch_size, seq_len, @@ -157,7 +186,6 @@ void FusedRopeKernel(const Context& dev_ctx, sin_cos_data, position_ids_data, flag_sin_cos, - flag_position_ids, sign, batch_size, seq_len, diff --git a/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h b/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h index 8fa14383cb84bd..972f5ee633bbb0 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h +++ b/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h @@ -23,7 +23,6 @@ template __device__ void VectorizedGetSinCos(phi::Array sin_cos_data, const int64_t* position_ids_data, bool flag_sin_cos, - bool flag_position_ids, int64_t index, int64_t seq_len, int64_t num_heads, @@ -40,7 +39,7 @@ __device__ void VectorizedGetSinCos(phi::Array sin_cos_data, int64_t index_wc = (index + nx) % (seq_len * num_heads * head_dim); int64_t pos_seq_ori = index_wc / (num_heads * head_dim); int64_t pos_seq; - if (flag_position_ids) { + if (position_ids_data) { int64_t pos_bs = (index + nx) / (seq_len * num_heads * head_dim); int64_t index_ids = pos_bs * seq_len + pos_seq_ori; pos_seq = position_ids_data[index_ids]; @@ -60,15 +59,7 @@ __device__ void VectorizedGetSinCos(phi::Array sin_cos_data, for (int nx = 0; nx < VecSize; ++nx) { // get sin_index and cos_index int64_t index_wc = (index + nx) % (seq_len * num_heads * head_dim); - int64_t pos_seq_ori = index_wc / (num_heads * 
head_dim); - int64_t pos_seq; - if (flag_position_ids) { - int64_t pos_bs = (index + nx) / (seq_len * num_heads * head_dim); - int64_t index_ids = pos_bs * seq_len + pos_seq_ori; - pos_seq = position_ids_data[index_ids]; - } else { - pos_seq = pos_seq_ori; - } + int64_t pos_seq = index_wc / (num_heads * head_dim); MPType idx = static_cast((index_wc % head_dim) / 2 * 2.0); MPType indicses = static_cast(1) / @@ -86,7 +77,6 @@ __global__ void VectorizedFusedRopeWithRotateEveryTwoKernel( phi::Array sin_cos_data, const int64_t* position_ids_data, bool flag_sin_cos, - bool flag_position_ids, int sign, int64_t batch_size, int64_t seq_len, @@ -113,7 +103,6 @@ __global__ void VectorizedFusedRopeWithRotateEveryTwoKernel( VectorizedGetSinCos(sin_cos_data, position_ids_data, flag_sin_cos, - flag_position_ids, index, seq_len, num_heads, @@ -155,7 +144,6 @@ __global__ void VectorizedFusedRopeWithRotateHalfKernel( phi::Array sin_cos_data, const int64_t* position_ids_data, bool flag_sin_cos, - bool flag_position_ids, int sign, int64_t batch_size, int64_t seq_len, @@ -182,7 +170,6 @@ __global__ void VectorizedFusedRopeWithRotateHalfKernel( VectorizedGetSinCos(sin_cos_data, position_ids_data, flag_sin_cos, - flag_position_ids, index, seq_len, num_heads, diff --git a/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py b/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py index 8c142d4757b552..a06236acaad619 100644 --- a/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py +++ b/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py @@ -30,12 +30,12 @@ def fused_rotary_position_embedding( Fused rotary position embedding. Args: - q (Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if q must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2. - k (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if k must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2. - v (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if v must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2. - sin (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if sin must be [seq_len, head_dim] or [1, 1, seq_len, head_dim] and head_dim must be a multiple of 2. - cos (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if cos must be [seq_len, head_dim] or [1, 1, seq_len, head_dim] and head_dim must be a multiple of 2. - position_ids (optional|Tensor): The input tensor. The data type is int64. The shape if position_ids must be [batch_size, seq_len]. + q (Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape of q must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2. + k (Tensor, optional): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape of k must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2. + v (Tensor, optional): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape of v must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2. + sin (Tensor, optional): The input tensor. 
The data type is bfloat16, float16, float32 or float64. The shape of sin must be [seq_len, head_dim] or [1, seq_len, 1, head_dim] and head_dim must be a multiple of 2. + cos (Tensor, optional): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape of cos must be [seq_len, head_dim] or [1, seq_len, 1, head_dim] and head_dim must be a multiple of 2. + position_ids (Tensor, optional): The input tensor. The data type is int64. The shape of position_ids must be [batch_size, seq_len]. use_neox_rotary_style(optional|bool): When the use_neox_rotary_style is True, every two adjacent numbers are calculated. When the use_neox_rotary_style is False, the numbers corresponding to the positions of the front half and back half segments are calculated. Default True. Returns: diff --git a/test/legacy_test/test_fused_rotary_position_embedding.py b/test/legacy_test/test_fused_rotary_position_embedding.py index 1950a950ebc6e1..de6355d56a5ee6 100644 --- a/test/legacy_test/test_fused_rotary_position_embedding.py +++ b/test/legacy_test/test_fused_rotary_position_embedding.py @@ -278,6 +278,7 @@ def test_fused_rope_position_ids(self): f_fw, f_bw = self.get_forward_backward( fused_rotary_position_embedding, seed=self.seed, + flag=True, position_ids=position_ids, ) for i in range(len(p_fw)): From b3193753a3ef4f908bf0ecd7bce74e1a6ede4af8 Mon Sep 17 00:00:00 2001 From: tianhaodongbd Date: Thu, 31 Aug 2023 13:48:36 +0000 Subject: [PATCH 5/8] modified examples about fused_rope --- .../fused_rotary_position_embedding.py | 27 ++++++++++++++----- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py b/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py index a06236acaad619..fe591a6be48fd5 100644 --- a/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py +++ b/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py @@ -50,16 +50,29 @@ def fused_rotary_position_embedding( import paddle from paddle.incubate.nn.functional import fused_rotary_position_embedding - q = paddle.randn([1, 1, 4, 10], dtype='float16') - k = paddle.randn([1, 1, 4, 10], dtype='float16') - v = paddle.randn([1, 1, 4, 10], dtype='float16') - out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v) + # batch_size = 2 + # seq_len = 8 + # num_heads = 2 + # head_dim = 10 - x = paddle.randn([1, 1, 1, 10], dtype='float16') - y = paddle.randn([1, 1, 1, 10], dtype='float16') + # q, k, v: [batch_size, seq_len, num_heads, head_dim] + q = paddle.randn([2, 8, 2, 10], dtype='float16') + k = paddle.randn([2, 8, 2, 10], dtype='float16') + v = paddle.randn([2, 8, 2, 10], dtype='float16') + + # sin, cos: [1, seq_len, 1, head_dim] + x = paddle.randn([1, 8, 1, 10], dtype='float16') + y = paddle.randn([1, 8, 1, 10], dtype='float16') sin = paddle.sin(x) cos = paddle.cos(y) - out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v, sin=sin, cos=cos) + + # position_ids: [batch_size, seq_len] + position_ids = paddle.randint(high=8, shape=[2, 8], dtype='int64') + + # out_q, out_k, out_v: [batch_size, seq_len, num_heads, head_dim] + out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v, sin=sin, cos=cos, position_ids=position_ids, use_neox_rotary_style=False) + print(out_q.shape) + # [2, 8, 2, 10] """ if in_dynamic_mode(): return _C_ops.fused_rotary_position_embedding( From 7620f7cdf02aa7205034723ee37e50b0edb7c56b Mon Sep 17 00:00:00 2001 From: tianhaodongbd Date: Fri, 1 Sep 2023 06:36:36 +0000 
Subject: [PATCH 6/8] modified examples according to comment --- .../fused_rotary_position_embedding.py | 65 +++++++++++-------- 1 file changed, 38 insertions(+), 27 deletions(-) diff --git a/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py b/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py index fe591a6be48fd5..3fd80ffbaaa802 100644 --- a/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py +++ b/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py @@ -46,33 +46,44 @@ def fused_rotary_position_embedding( .. code-block:: python - # required: gpu - import paddle - from paddle.incubate.nn.functional import fused_rotary_position_embedding - - # batch_size = 2 - # seq_len = 8 - # num_heads = 2 - # head_dim = 10 - - # q, k, v: [batch_size, seq_len, num_heads, head_dim] - q = paddle.randn([2, 8, 2, 10], dtype='float16') - k = paddle.randn([2, 8, 2, 10], dtype='float16') - v = paddle.randn([2, 8, 2, 10], dtype='float16') - - # sin, cos: [1, seq_len, 1, head_dim] - x = paddle.randn([1, 8, 1, 10], dtype='float16') - y = paddle.randn([1, 8, 1, 10], dtype='float16') - sin = paddle.sin(x) - cos = paddle.cos(y) - - # position_ids: [batch_size, seq_len] - position_ids = paddle.randint(high=8, shape=[2, 8], dtype='int64') - - # out_q, out_k, out_v: [batch_size, seq_len, num_heads, head_dim] - out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v, sin=sin, cos=cos, position_ids=position_ids, use_neox_rotary_style=False) - print(out_q.shape) - # [2, 8, 2, 10] + >>> # required: gpu + >>> # doctest: +REQUIRES(env:GPU) + >>> import paddle + >>> from paddle.incubate.nn.functional import fused_rotary_position_embedding + + >>> # batch_size = 2 + >>> # seq_len = 2 + >>> # num_heads = 2 + >>> # head_dim = 2 + + >>> paddle.seed(1024) + + >>> # q, k, v: [batch_size, seq_len, num_heads, head_dim] + >>> q = paddle.randn([2, 2, 2, 2], dtype='float16') + >>> k = paddle.randn([2, 2, 2, 2], dtype='float16') + >>> v = paddle.randn([2, 2, 2, 2], dtype='float16') + + >>> # sin, cos: [1, seq_len, 1, head_dim] + >>> x = paddle.randn([1, 2, 1, 2], dtype='float16') + >>> y = paddle.randn([1, 2, 1, 2], dtype='float16') + >>> sin = paddle.sin(x) + >>> cos = paddle.cos(y) + + >>> # position_ids: [batch_size, seq_len] + >>> position_ids = paddle.randint(high=2, shape=[2, 2], dtype='int64') + + >>> # out_q, out_k, out_v: [batch_size, seq_len, num_heads, head_dim] + >>> out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v, sin=sin, cos=cos, position_ids=position_ids, use_neox_rotary_style=False) + >>> print(out_q) + Tensor(shape=[2, 2, 2, 2], dtype=float16, place=Place(gpu:0), stop_gradient=True, + [[[[-0.54931641, 0.64990234], + [-1.08691406, 1.18261719]], + [[ 0.57812500, 0.11749268], + [-0.63281250, 0.15551758]]], + [[[-0.77050781, 0.07733154], + [-0.73730469, -0.16735840]], + [[ 0.07116699, -0.90966797], + [-0.03628540, -0.20202637]]]]) """ if in_dynamic_mode(): return _C_ops.fused_rotary_position_embedding( From 7a56d1c883ba4936dec360bb18eddf39a7732305 Mon Sep 17 00:00:00 2001 From: tianhaodongbd Date: Fri, 1 Sep 2023 11:35:24 +0000 Subject: [PATCH 7/8] add set_device in examples --- .../incubate/nn/functional/fused_rotary_position_embedding.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py b/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py index 3fd80ffbaaa802..826edaa956eed0 100644 --- 
a/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py +++ b/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py @@ -51,6 +51,8 @@ def fused_rotary_position_embedding( >>> import paddle >>> from paddle.incubate.nn.functional import fused_rotary_position_embedding + >>> paddle.device.set_device('gpu') + >>> # batch_size = 2 >>> # seq_len = 2 >>> # num_heads = 2 From 8619245b98eca78ddcde768bb4d69838e43c3de5 Mon Sep 17 00:00:00 2001 From: tianhaodongbd <137985359+tianhaodongbd@users.noreply.github.com> Date: Mon, 4 Sep 2023 00:29:01 +0800 Subject: [PATCH 8/8] Update fused_rotary_position_embedding.py --- .../incubate/nn/functional/fused_rotary_position_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py b/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py index 826edaa956eed0..f68dfb1dcd53f9 100644 --- a/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py +++ b/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py @@ -58,7 +58,7 @@ def fused_rotary_position_embedding( >>> # num_heads = 2 >>> # head_dim = 2 - >>> paddle.seed(1024) + >>> paddle.seed(1204) >>> # q, k, v: [batch_size, seq_len, num_heads, head_dim] >>> q = paddle.randn([2, 2, 2, 2], dtype='float16')
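
Note on the two rotation modes: the patches above select between "rotate every two" (use_neox_rotary_style=True) and "rotate half" (use_neox_rotary_style=False). As a quick reference, the NumPy sketch below is one way to express the math used by the test helpers (mult_qkv / mult_qkv_rotate_half), i.e. out = x * cos + rotate(x) * sin; the helper names, the [seq_len, head_dim] sin/cos layout, and the random inputs are illustrative assumptions and are not code taken from the patch series.

    import numpy as np

    def rotate_every_two(x):
        # NeoX style: adjacent pairs (x0, x1) -> (-x1, x0), (x2, x3) -> (-x3, x2), ...
        out = np.empty_like(x)
        out[..., 0::2] = -x[..., 1::2]
        out[..., 1::2] = x[..., 0::2]
        return out

    def rotate_half(x):
        # rotate_half style: concat(-x[d/2:], x[:d/2]) along the last axis,
        # matching mult_qkv_rotate_half in the updated test file.
        half = x.shape[-1] // 2
        return np.concatenate([-x[..., half:], x[..., :half]], axis=-1)

    def apply_rope(x, sin, cos, use_neox_rotary_style=True):
        # x: [batch_size, seq_len, num_heads, head_dim]; sin/cos: [seq_len, head_dim],
        # broadcast over batch and heads. Both modes reduce to x * cos + rotate(x) * sin.
        sin = sin[None, :, None, :]
        cos = cos[None, :, None, :]
        rot = rotate_every_two(x) if use_neox_rotary_style else rotate_half(x)
        return x * cos + rot * sin

    # Toy usage with random sin/cos, analogous to the docstring example.
    q = np.random.randn(2, 8, 2, 16).astype('float32')
    sin = np.sin(np.random.randn(8, 16).astype('float32'))
    cos = np.cos(np.random.randn(8, 16).astype('float32'))
    print(apply_rope(q, sin, cos, use_neox_rotary_style=False).shape)  # (2, 8, 2, 16)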