Skip to content

Commit

Permalink
【PIR OpTest Fix No.26】 fix test_fused_token_prune_op (#62644)
Browse files Browse the repository at this point in the history
* fix test_fused_token_prune_op

* fix codestyle

* fix codestyle2

* Update paddle/fluid/pir/dialect/operator/ir/ops.yaml

Co-authored-by: kangguangli <kangguangli@hotmail.com>

* Update paddle/phi/api/yaml/op_compat.yaml

Co-authored-by: kangguangli <kangguangli@hotmail.com>

* fix name

---------

Co-authored-by: kangguangli <kangguangli@hotmail.com>
  • Loading branch information
Eddie-Wang1120 and kangguangli authored Mar 18, 2024
1 parent edc3bfc commit fd2389f
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 0 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@
'fused_elemwise_add_activation',
'fused_scale_bias_relu_conv_bn',
'fused_scale_bias_add_relu',
'fused_token_prune',
'fused_dconv_drelu_dbn',
'fused_dot_product_attention',
'nce',
Expand Down
8 changes: 8 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -805,6 +805,14 @@
func : fused_softmax_mask_upper_triangle
backward: fused_softmax_mask_upper_triangle_grad

- op : fused_token_prune
args : (Tensor attn, Tensor x, Tensor mask, Tensor new_mask, bool keep_first_token = true, bool keep_order = false)
output : Tensor(slimmed_x), Tensor(cls_inds)
infer_meta :
func : FusedTokenPruneInferMeta
kernel:
func : fused_token_prune

- op : gaussian
args : (IntArray shape, float mean, float std, int seed, DataType dtype, Place place={})
output: Tensor(out)
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/operator/utils/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ const std::unordered_set<std::string> LegacyOpList = {
FtrlOp::name(),
FusedElemwiseAddActivationOp::name(),
FusedElemwiseAddActivationGradOp::name(),
FusedTokenPruneOp::name(),
DpsgdOp::name(),
SendV2Op::name(),
RecvV2Op::name(),
Expand Down
6 changes: 6 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3712,6 +3712,12 @@
outputs :
{out : Out}

- op: fused_token_prune
inputs :
{attn: Attn, x: X, mask: Mask, new_mask: NewMask}
outputs :
{slimmed_x : SlimmedX, cls_inds : CLSInds}

- op: fusion_squared_mat_sub
inputs :
x : X
Expand Down
80 changes: 80 additions & 0 deletions paddle/phi/infermeta/multiary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4584,6 +4584,86 @@ void FusedRopeInferMeta(const MetaTensor& q,
}
}

// Infers output shapes and dtypes for the fused_token_prune op.
//
// Expected input layouts (enforced below):
//   attn     : [bsz, nb_head, max_seq_len, max_seq_len] attention scores
//   x        : [bsz, max_seq_len, c] features to be pruned
//   mask     : [bsz, nb_head, max_seq_len, max_seq_len]
//   new_mask : [bsz, nb_head, slim_seq_len, slim_seq_len] — its third dim
//              determines the pruned (slimmed) sequence length
//
// Outputs:
//   slimmed_x : [bsz, slim_seq_len, c], same dtype as x
//   cls_inds  : [bsz, slim_seq_len], int64 indices of the kept tokens
//
// keep_first_token / keep_order only affect kernel behavior, not shapes.
void FusedTokenPruneInferMeta(const MetaTensor& attn,
                              const MetaTensor& x,
                              const MetaTensor& mask,
                              const MetaTensor& new_mask,
                              bool keep_first_token,
                              bool keep_order,
                              MetaTensor* slimmed_x,
                              MetaTensor* cls_inds) {
  auto mask_dim = mask.dims();
  auto attn_dim = attn.dims();
  auto x_dim = x.dims();
  auto new_mask_dim = new_mask.dims();

  // Rank checks. Note: x is 3-D, everything else is 4-D.
  PADDLE_ENFORCE_EQ(
      mask_dim.size(),
      4,
      phi::errors::InvalidArgument("The input mask must be 4-dimension"));
  PADDLE_ENFORCE_EQ(
      attn_dim.size(),
      4,
      phi::errors::InvalidArgument("The input attn must be 4-dimension"));
  PADDLE_ENFORCE_EQ(
      x_dim.size(),
      3,
      phi::errors::InvalidArgument("The input x must be 3-dimension"));
  PADDLE_ENFORCE_EQ(
      new_mask_dim.size(),
      4,
      phi::errors::InvalidArgument("The input new_mask must be 4-dimension"));

  // Cross-tensor consistency checks (batch size, head count, seq len).
  PADDLE_ENFORCE_EQ(mask_dim[0],
                    attn_dim[0],
                    phi::errors::InvalidArgument(
                        "The first dim of mask and attn should be the same, "
                        "which is batch size"));
  PADDLE_ENFORCE_EQ(mask_dim[1],
                    attn_dim[1],
                    phi::errors::InvalidArgument(
                        "The second dim of mask and attn should be the same, "
                        "which is nb_head"));
  PADDLE_ENFORCE_EQ(mask_dim[0],
                    x_dim[0],
                    phi::errors::InvalidArgument(
                        "The first dim of mask and x should be the same, "
                        "which is batch size"));
  PADDLE_ENFORCE_EQ(
      mask_dim[2],
      mask_dim[3],
      phi::errors::InvalidArgument(
          "The third dim and the fourth dim of mask should be the same, "
          "which is max seq len"));
  PADDLE_ENFORCE_EQ(
      attn_dim[2],
      attn_dim[3],
      phi::errors::InvalidArgument(
          "The third dim and the fourth dim of attn should be the same, "
          "which is max seq len"));
  PADDLE_ENFORCE_EQ(attn_dim[2],
                    mask_dim[2],
                    phi::errors::InvalidArgument(
                        "The third dim of mask and attn should be the same, "
                        "which is max seq len"));
  PADDLE_ENFORCE_EQ(attn_dim[2],
                    x_dim[1],
                    phi::errors::InvalidArgument(
                        "The third dim of attn and the second dim of x "
                        "should be the same, which is max seq len"));

  auto bsz = mask_dim[0];
  auto c = x_dim[2];
  // The pruned sequence length is taken from new_mask's third dimension.
  auto slim_seq_len = new_mask_dim[2];

  std::vector<int64_t> slimmed_x_dims({bsz, slim_seq_len, c});
  slimmed_x->set_dims(common::make_ddim(slimmed_x_dims));
  slimmed_x->set_dtype(x.dtype());

  std::vector<int64_t> cls_inds_dims({bsz, slim_seq_len});
  cls_inds->set_dims(common::make_ddim(cls_inds_dims));
  // cls_inds holds token indices, so it is always int64.
  cls_inds->set_dtype(phi::DataType::INT64);
}

void MoeInferMeta(const MetaTensor& x,
const MetaTensor& gate,
const MetaTensor& bmm0,
Expand Down
9 changes: 9 additions & 0 deletions paddle/phi/infermeta/multiary.h
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,15 @@ void FusedRopeInferMeta(const MetaTensor& q,
MetaTensor* out_k,
MetaTensor* out_v);

// Shape/dtype inference for fused_token_prune: produces
// slimmed_x [bsz, slim_seq_len, c] and cls_inds [bsz, slim_seq_len] (int64),
// where slim_seq_len comes from new_mask's third dimension.
// Implemented in multiary.cc.
void FusedTokenPruneInferMeta(const MetaTensor& attn,
const MetaTensor& x,
const MetaTensor& mask,
const MetaTensor& new_mask,
bool keep_first_token,
bool keep_order,
MetaTensor* slimmed_x,
MetaTensor* cls_inds);

void MultiheadMatmulInferMeta(const MetaTensor& input,
const MetaTensor& w,
const MetaTensor& bias,
Expand Down
1 change: 1 addition & 0 deletions test/white_list/pir_op_test_white_list
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ test_fused_fc_elementwise_layernorm_op
test_fused_feedforward_op
test_fused_gate_attention_op
test_fused_multihead_matmul_op
test_fused_token_prune_op
test_fusion_seqexpand_concat_fc_op
test_fusion_transpose_flatten_concat_op
test_gather_nd_op
Expand Down

0 comments on commit fd2389f

Please sign in to comment.