Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

在fused_rope算子中增加rotate_half实现方式 #56401

Merged
merged 8 commits into from
Sep 4, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions paddle/phi/api/yaml/fused_backward.yaml
Original file line number Diff line number Diff line change
@@ -17,10 +17,10 @@
support_dygraph_mode : true

- backward_op : fused_rotary_position_embedding_grad
forward: fused_rotary_position_embedding (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos) -> Tensor(out_q), Tensor(out_k), Tensor(out_v)
args : (Tensor sin, Tensor cos, Tensor out_q_grad, Tensor out_k_grad,Tensor out_v_grad)
forward: fused_rotary_position_embedding (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos, Tensor position_ids, bool use_neox_rotary_style) -> Tensor(out_q), Tensor(out_k), Tensor(out_v)
args : (Tensor sin, Tensor cos, Tensor position_ids, Tensor out_q_grad, Tensor out_k_grad,Tensor out_v_grad, bool use_neox_rotary_style)
output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
optional : sin, cos, out_k_grad, out_v_grad, k_grad, v_grad
optional : sin, cos, position_ids, out_k_grad, out_v_grad, k_grad, v_grad
infer_meta :
func : FusedRopeGradInferMeta
kernel :
4 changes: 2 additions & 2 deletions paddle/phi/api/yaml/fused_ops.yaml
Original file line number Diff line number Diff line change
@@ -148,11 +148,11 @@
optional : cache_kv, pre_caches, rotary_pos_emb, time_step, seq_lengths, src_mask, gather_index

- op : fused_rotary_position_embedding
args : (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos)
args : (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos, Tensor position_ids, bool use_neox_rotary_style = true)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for add inputs

output : Tensor(out_q), Tensor(out_k), Tensor(out_v)
infer_meta :
func : FusedRopeInferMeta
optional : k,v,sin,cos, out_k, out_v
optional : k, v, sin, cos, position_ids, out_k, out_v
kernel :
func : fused_rotary_position_embedding
data_type : q
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/backward.cc
Original file line number Diff line number Diff line change
@@ -1219,9 +1219,11 @@ void IndexPutGradInferMeta(const MetaTensor& x,

void FusedRopeGradInferMeta(const MetaTensor& sin,
const MetaTensor& cos,
const MetaTensor& position_ids,
const MetaTensor& dout_q,
const MetaTensor& dout_k,
const MetaTensor& dout_v,
bool use_neox_rotary_style,
MetaTensor* dq,
MetaTensor* dk,
MetaTensor* dv) {
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/backward.h
Original file line number Diff line number Diff line change
@@ -186,9 +186,11 @@ void FusedDropoutAddGradInferMeta(const MetaTensor& seed_offset,

void FusedRopeGradInferMeta(const MetaTensor& sin,
const MetaTensor& cos,
const MetaTensor& position_ids,
const MetaTensor& dout_q,
const MetaTensor& dout_k,
const MetaTensor& dout_v,
bool use_neox_rotary_style,
MetaTensor* dq,
MetaTensor* dk,
MetaTensor* dv);
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/multiary.cc
Original file line number Diff line number Diff line change
@@ -3848,6 +3848,8 @@ void FusedRopeInferMeta(const MetaTensor& q,
const MetaTensor& v,
const MetaTensor& sin,
const MetaTensor& cos,
const MetaTensor& position_ids,
bool use_neox_rotary_style,
MetaTensor* out_q,
MetaTensor* out_k,
MetaTensor* out_v) {
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/multiary.h
Original file line number Diff line number Diff line change
@@ -769,6 +769,8 @@ void FusedRopeInferMeta(const MetaTensor& q,
const MetaTensor& v,
const MetaTensor& sin,
const MetaTensor& cos,
const MetaTensor& position_ids,
bool use_neox_rotary_style,
MetaTensor* out_q,
MetaTensor* out_k,
MetaTensor* out_v);
48 changes: 36 additions & 12 deletions paddle/phi/kernels/fusion/gpu/fused_rope_grad_kernel.cu
Original file line number Diff line number Diff line change
@@ -27,9 +27,11 @@ template <typename T, typename Context>
void FusedRopeGradKernel(const Context& dev_ctx,
const paddle::optional<DenseTensor>& sin,
const paddle::optional<DenseTensor>& cos,
const paddle::optional<DenseTensor>& position_ids,
const DenseTensor& dout_q,
const paddle::optional<DenseTensor>& dout_k,
const paddle::optional<DenseTensor>& dout_v,
bool use_neox_rotary_style,
DenseTensor* dq,
DenseTensor* dk,
DenseTensor* dv) {
@@ -58,6 +60,7 @@ void FusedRopeGradKernel(const Context& dev_ctx,
phi::Array<T*, 3> outs_data;
phi::Array<const T*, 3> ins_data;
phi::Array<const T*, 2> sin_cos_data;
const int64_t* position_ids_data = NULL;

ins_data[0] = dout_q.data<T>();
outs_data[0] = dq->data<T>();
@@ -86,21 +89,42 @@ void FusedRopeGradKernel(const Context& dev_ctx,
sin_cos_data[1] = cos->data<T>();

flag_sin_cos = true;

if (position_ids.get_ptr()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里应该可以直接 if (position_ids) 的

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的,下一个pr修改

position_ids_data = position_ids->data<int64_t>();
}
}

int sign = -1;
VectorizedFusedRopeKernel<T, MPType, vec_size>
<<<grid, block, 0, stream>>>(ins_data,
sin_cos_data,
flag_sin_cos,
sign,
batch_size,
seq_len,
num_heads,
head_dim,
outs_data,
num_inputs,
div_c);
if (use_neox_rotary_style) {
VectorizedFusedRopeWithRotateEveryTwoKernel<T, MPType, vec_size>
<<<grid, block, 0, stream>>>(ins_data,
sin_cos_data,
position_ids_data,
flag_sin_cos,
sign,
batch_size,
seq_len,
num_heads,
head_dim,
outs_data,
num_inputs,
div_c);
} else {
VectorizedFusedRopeWithRotateHalfKernel<T, MPType, vec_size>
<<<grid, block, 0, stream>>>(ins_data,
sin_cos_data,
position_ids_data,
flag_sin_cos,
sign,
batch_size,
seq_len,
num_heads,
head_dim,
outs_data,
num_inputs,
div_c);
}
}

} // namespace fusion
99 changes: 78 additions & 21 deletions paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu
Original file line number Diff line number Diff line change
@@ -30,6 +30,8 @@ void FusedRopeKernel(const Context& dev_ctx,
const paddle::optional<DenseTensor>& v,
const paddle::optional<DenseTensor>& sin,
const paddle::optional<DenseTensor>& cos,
const paddle::optional<DenseTensor>& position_ids,
bool use_neox_rotary_style,
DenseTensor* out_q,
DenseTensor* out_k,
DenseTensor* out_v) {
@@ -59,6 +61,7 @@ void FusedRopeKernel(const Context& dev_ctx,
phi::Array<T*, 3> outs_data;
phi::Array<const T*, 3> ins_data;
phi::Array<const T*, 2> sin_cos_data;
const int64_t* position_ids_data = NULL;

ins_data[0] = q.data<T>();
outs_data[0] = out_q->data<T>();
@@ -109,15 +112,52 @@ void FusedRopeKernel(const Context& dev_ctx,
"The batch_size and num_heads of sin and cos must be 1."));
}
int sin_seq_len_dim = (dims_size) == 4 ? 1 : 0;
PADDLE_ENFORCE_EQ((sin_dims[dims_size - 1] == head_dim &&
sin_dims[sin_seq_len_dim] == seq_len),
true,
phi::errors::InvalidArgument(
"The seq_len and head_dim of sin and cos "
"must be the same as those of q. But recieved sin's "
"shape is {%s}, q's shape is {%s}.",
sin_dims,
q.dims()));

if (position_ids.get_ptr()) {
PADDLE_ENFORCE_EQ(
(sin_dims[dims_size - 1] == head_dim &&
sin_dims[sin_seq_len_dim] >= seq_len),
true,
phi::errors::InvalidArgument(
"The seq_len of sin and cos must be greater than or equal to "
"this of q. The head_dim of sin and cos must be the same as this "
"of q. But recieved sin's "
"shape is {%s}, q's shape is {%s}.",
sin_dims,
q.dims()));

auto position_ids_dims = position_ids.get_ptr()->dims();
PADDLE_ENFORCE_EQ(position_ids_dims.size(),
2,
phi::errors::InvalidArgument(
"The dims of position_ids is expected to "
"be 2, but recieved %d.",
position_ids_dims.size()));

PADDLE_ENFORCE_EQ(
(position_ids_dims[0] == batch_size &&
position_ids_dims[1] == seq_len),
true,
phi::errors::InvalidArgument(
"The batch_size and seq_len of position_ids must be the same as "
"those of q. But recieved position_ids's "
"shape is {%s}, q's shape is {%s}.",
position_ids_dims,
q.dims()));

position_ids_data = position_ids->data<int64_t>();
} else {
PADDLE_ENFORCE_EQ(
(sin_dims[dims_size - 1] == head_dim &&
sin_dims[sin_seq_len_dim] == seq_len),
true,
phi::errors::InvalidArgument(
"The seq_len and head_dim of sin and cos "
"must be the same as those of q. But recieved sin's "
"shape is {%s}, q's shape is {%s}.",
sin_dims,
q.dims()));
}

sin_cos_data[0] = sin->data<T>();
sin_cos_data[1] = cos->data<T>();
@@ -126,18 +166,35 @@ void FusedRopeKernel(const Context& dev_ctx,
}

int sign = 1;
VectorizedFusedRopeKernel<T, MPType, vec_size>
<<<grid, block, 0, stream>>>(ins_data,
sin_cos_data,
flag_sin_cos,
sign,
batch_size,
seq_len,
num_heads,
head_dim,
outs_data,
num_inputs,
div_c);
if (use_neox_rotary_style) {
VectorizedFusedRopeWithRotateEveryTwoKernel<T, MPType, vec_size>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我觉得kernel名字改成:

VectorizedFusedNeoxRopeKernel 是不是好点

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的,下一个pr修改

<<<grid, block, 0, stream>>>(ins_data,
sin_cos_data,
position_ids_data,
flag_sin_cos,
sign,
batch_size,
seq_len,
num_heads,
head_dim,
outs_data,
num_inputs,
div_c);
} else {
VectorizedFusedRopeWithRotateHalfKernel<T, MPType, vec_size>
<<<grid, block, 0, stream>>>(ins_data,
sin_cos_data,
position_ids_data,
flag_sin_cos,
sign,
batch_size,
seq_len,
num_heads,
head_dim,
outs_data,
num_inputs,
div_c);
}
}
} // namespace fusion
} // namespace phi
Loading