Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix
Browse files Browse the repository at this point in the history
xingmingyyj committed Jan 3, 2024
1 parent 2d692a5 commit d5258c8
Showing 3 changed files with 13 additions and 1 deletion.
1 change: 0 additions & 1 deletion paddle/fluid/pir/dialect/operator/ir/ops.yaml
Original file line number Diff line number Diff line change
@@ -1416,7 +1416,6 @@
output: Tensor(param_out), Tensor(squared_accum_out), Tensor(linear_accum_out)
infer_meta:
func: FtrlInferMeta
param: [param, grad, learning_rate]
kernel:
func: ftrl
data_type: param
8 changes: 8 additions & 0 deletions paddle/phi/infermeta/multiary.cc
Original file line number Diff line number Diff line change
@@ -1529,8 +1529,13 @@ void EditDistanceInferMeta(const MetaTensor& hyps,
}

void FtrlInferMeta(const MetaTensor& param,
const MetaTensor& squared_accumulator,
const MetaTensor& linear_accumulator,
const MetaTensor& grad,
const MetaTensor& learning_rate,
float l1,
float l2,
float lr_power,
MetaTensor* param_out,
MetaTensor* squared_accum_out,
MetaTensor* linear_accum_out) {
@@ -1558,8 +1563,11 @@ void FtrlInferMeta(const MetaTensor& param,
common::product(lr_dim)));

param_out->set_dims(param_dim);
param_out->set_dtype(param.dtype());
squared_accum_out->set_dims(param_dim);
squared_accum_out->set_dtype(param.dtype());
linear_accum_out->set_dims(param_dim);
linear_accum_out->set_dtype(param.dtype());
}

void FusedBatchNormActInferMeta(const MetaTensor& x,
5 changes: 5 additions & 0 deletions paddle/phi/infermeta/multiary.h
Original file line number Diff line number Diff line change
@@ -309,8 +309,13 @@ void EditDistanceInferMeta(const MetaTensor& hyps,
MetaTensor* out);

void FtrlInferMeta(const MetaTensor& param,
const MetaTensor& squared_accumulator,
const MetaTensor& linear_accumulator,
const MetaTensor& grad,
const MetaTensor& learning_rate,
float l1,
float l2,
float lr_power,
MetaTensor* param_out,
MetaTensor* squared_accum_out,
MetaTensor* linear_accum_out);

0 comments on commit d5258c8

Please sign in to comment.