【PIR OpTest Fix No.19】 fix test_ftrl_op (#60329)
* fix test_ftrl_op

* fix
xingmingyyj authored Jan 5, 2024
1 parent 2033381 commit 1874d1c
Showing 7 changed files with 72 additions and 0 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
@@ -120,6 +120,7 @@
    'decayed_adagrad',
    'dpsgd',
    'embedding_grad_sparse',
    'ftrl',
    'fused_batch_norm_act_',
    'fused_bn_add_activation_',
    'fused_elemwise_add_activation',
9 changes: 9 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
@@ -1437,6 +1437,15 @@
    func: dpsgd
    data_type: param

- op: ftrl
  args: (Tensor param, Tensor squared_accumulator, Tensor linear_accumulator, Tensor grad, Tensor learning_rate, float l1=0.0f, float l2=0.0f, float lr_power=-0.5f)
  output: Tensor(param_out), Tensor(squared_accum_out), Tensor(linear_accum_out)
  infer_meta:
    func: FtrlInferMeta
  kernel:
    func: ftrl
    data_type: param

- op: fused_attention
  args: (Tensor x, Tensor ln_scale, Tensor ln_bias, Tensor qkv_weight, Tensor qkv_bias, Tensor cache_kv, Tensor src_mask, Tensor out_linear_weight, Tensor out_linear_bias, Tensor ln_scale_2, Tensor ln_bias_2, int num_heads, bool transpose_qkv_wb, bool pre_layer_norm, float epsilon, float attn_dropout_rate, bool is_test, bool attn_dropout_fix_seed, int attn_dropout_seed, str attn_dropout_implementation, float dropout_rate, bool dropout_fix_seed, int dropout_seed, str dropout_implementation, float ln_epsilon, bool add_residual, int ring_id)
  output: Tensor(ln_mean), Tensor(ln_var), Tensor(ln_out), Tensor(qkv_out), Tensor(qkv_bias_out), Tensor(transpose_out_2), Tensor(qk_out), Tensor(qktv_out), Tensor(softmax_out), Tensor(attn_dropout_mask_out), Tensor(attn_dropout_out), Tensor(src_mask_out), Tensor(fmha_out), Tensor(out_linear_out), Tensor(dropout_mask_out), Tensor(ln_mean_2), Tensor(ln_var_2), Tensor(bias_dropout_residual_out), Tensor(cache_kv_out), Tensor(out)
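For context (not part of this commit): the ftrl kernel registered above applies the standard FTRL-Proximal update. Below is a minimal NumPy sketch of that update, assuming dense inputs and a scalar learning rate; the helper name and structure are illustrative and are not Paddle code.

import numpy as np

def ftrl_step(param, squared_accum, linear_accum, grad, lr,
              l1=0.0, l2=0.0, lr_power=-0.5):
    # Illustrative dense FTRL-Proximal step (not Paddle source code).
    new_accum = squared_accum + grad * grad
    if lr_power == -0.5:
        linear_out = linear_accum + grad - (
            np.sqrt(new_accum) - np.sqrt(squared_accum)) / lr * param
        y = np.sqrt(new_accum) / lr + 2.0 * l2
    else:
        linear_out = linear_accum + grad - (
            new_accum ** -lr_power - squared_accum ** -lr_power) / lr * param
        y = new_accum ** -lr_power / lr + 2.0 * l2
    # Proximal step with L1 shrinkage: entries inside the L1 ball go to zero.
    x = l1 * np.sign(linear_out) - linear_out
    param_out = np.where(np.abs(linear_out) > l1, x / y, 0.0)
    # Returns match the op's three outputs: param_out, squared_accum_out, linear_accum_out.
    return param_out, new_accum, linear_out

# Example: one step on a 5-element parameter vector.
p, g = np.random.rand(5).astype("float32"), np.random.rand(5).astype("float32")
sq, lin = np.full(5, 0.1, "float32"), np.zeros(5, "float32")
p, sq, lin = ftrl_step(p, sq, lin, g, lr=0.01)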
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/operator/utils/utils.cc
@@ -36,6 +36,7 @@ const std::unordered_set<std::string> LegacyOpList = {
    CBroadcast_Op::name(),
    CSyncCalcStream_Op::name(),
    CSyncCommStream_Op::name(),
    FtrlOp::name(),
    FusedElemwiseAddActivationOp::name(),
    FusedElemwiseAddActivationGradOp::name(),
    FusedGemmEpilogueOp::name(),
6 changes: 6 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
@@ -3443,6 +3443,12 @@
  inputs: {x: X}
  outputs: {out: Out}

- op: ftrl
  inputs:
    {param: Param, squared_accumulator: SquaredAccumulator, linear_accumulator: LinearAccumulator, grad: Grad, learning_rate: LearningRate}
  outputs:
    {param_out: ParamOut, squared_accum_out: SquaredAccumOut, linear_accum_out: LinearAccumOut}

- op: full_batch_size_like (fill_constant_batch_size_like)
  inputs:
    {input: Input}
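The op_compat entry for ftrl above is purely a name translation: the capitalized names are the legacy fluid argument names, the lowercase ones are the PIR/PHI names from ops.yaml. Restated as a plain mapping for illustration only (this dict is not Paddle code):

# Legacy fluid name -> PIR/PHI name, restating the op_compat.yaml entry above.
FTRL_INPUT_MAP = {
    "Param": "param",
    "SquaredAccumulator": "squared_accumulator",
    "LinearAccumulator": "linear_accumulator",
    "Grad": "grad",
    "LearningRate": "learning_rate",
}
FTRL_OUTPUT_MAP = {
    "ParamOut": "param_out",
    "SquaredAccumOut": "squared_accum_out",
    "LinearAccumOut": "linear_accum_out",
}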
42 changes: 42 additions & 0 deletions paddle/phi/infermeta/multiary.cc
@@ -1525,6 +1525,48 @@ void EditDistanceInferMeta(const MetaTensor& hyps,
  sequencenum->set_dtype(DataType::FLOAT32);
}

void FtrlInferMeta(const MetaTensor& param,
                   const MetaTensor& squared_accumulator,
                   const MetaTensor& linear_accumulator,
                   const MetaTensor& grad,
                   const MetaTensor& learning_rate,
                   float l1,
                   float l2,
                   float lr_power,
                   MetaTensor* param_out,
                   MetaTensor* squared_accum_out,
                   MetaTensor* linear_accum_out) {
  auto param_dim = param.dims();
  PADDLE_ENFORCE_EQ(param_dim,
                    grad.dims(),
                    phi::errors::InvalidArgument(
                        "Two input of FTRL Op's dimension must be same, but "
                        "param_dim is %d, Grad is %d",
                        param_dim,
                        grad.dims()));

  auto lr_dim = learning_rate.dims();
  PADDLE_ENFORCE_NE(common::product(lr_dim),
                    0,
                    phi::errors::InvalidArgument(
                        "Maybe the Input variable LearningRate has not "
                        "been initialized. You may need to confirm "
                        "if you put exe.run(startup_program) "
                        "after optimizer.minimize function."));
  PADDLE_ENFORCE_EQ(common::product(lr_dim),
                    1,
                    phi::errors::InvalidArgument(
                        "Learning Rate should be a scalar, but got %d",
                        common::product(lr_dim)));

  param_out->set_dims(param_dim);
  param_out->set_dtype(param.dtype());
  squared_accum_out->set_dims(param_dim);
  squared_accum_out->set_dtype(param.dtype());
  linear_accum_out->set_dims(param_dim);
  linear_accum_out->set_dtype(param.dtype());
}

void FusedBatchNormActInferMeta(const MetaTensor& x,
                                const MetaTensor& scale,
                                const MetaTensor& bias,
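FtrlInferMeta above only does static shape and dtype propagation: param and grad must have identical dims, learning_rate must hold exactly one element, and all three outputs take param's dims and dtype. A rough Python rendering of that contract, with an invented helper name (not a Paddle API):

import numpy as np

def ftrl_infer_meta(param, grad, learning_rate):
    # Mirrors the PADDLE_ENFORCE checks in FtrlInferMeta, for illustration only;
    # the accumulator inputs are omitted because the C++ code does not check them.
    if param.shape != grad.shape:
        raise ValueError(f"param dims {param.shape} must equal grad dims {grad.shape}")
    if np.prod(learning_rate.shape) != 1:
        raise ValueError("LearningRate must be a scalar (exactly one element)")
    # param_out, squared_accum_out, and linear_accum_out all share param's meta.
    out_meta = (param.shape, param.dtype)
    return out_meta, out_meta, out_meta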
12 changes: 12 additions & 0 deletions paddle/phi/infermeta/multiary.h
@@ -308,6 +308,18 @@ void EditDistanceInferMeta(const MetaTensor& hyps,
                           MetaTensor* sequencenum,
                           MetaTensor* out);

void FtrlInferMeta(const MetaTensor& param,
                   const MetaTensor& squared_accumulator,
                   const MetaTensor& linear_accumulator,
                   const MetaTensor& grad,
                   const MetaTensor& learning_rate,
                   float l1,
                   float l2,
                   float lr_power,
                   MetaTensor* param_out,
                   MetaTensor* squared_accum_out,
                   MetaTensor* linear_accum_out);

void FusedBatchNormActInferMeta(const MetaTensor& x,
                                const MetaTensor& scale,
                                const MetaTensor& bias,
1 change: 1 addition & 0 deletions test/white_list/pir_op_test_white_list
@@ -121,6 +121,7 @@ test_fmax_op
test_fmin_op
test_fold_op
test_frame_op
test_ftrl_op
test_full_like_op
test_fused_attention_op
test_fused_attention_op_api
