Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix
Browse files Browse the repository at this point in the history
xingmingyyj committed Jan 3, 2024

Unverified

This user has not yet uploaded their public signing key.
1 parent 2d692a5 commit d5258c8
Showing 3 changed files with 13 additions and 1 deletion.
1 change: 0 additions & 1 deletion paddle/fluid/pir/dialect/operator/ir/ops.yaml
Original file line number Diff line number Diff line change
@@ -1416,7 +1416,6 @@
output: Tensor(param_out), Tensor(squared_accum_out), Tensor(linear_accum_out)
infer_meta:
func: FtrlInferMeta
param: [param, grad, learning_rate]
kernel:
func: ftrl
data_type: param
8 changes: 8 additions & 0 deletions paddle/phi/infermeta/multiary.cc
Original file line number Diff line number Diff line change
@@ -1529,8 +1529,13 @@ void EditDistanceInferMeta(const MetaTensor& hyps,
}

void FtrlInferMeta(const MetaTensor& param,
const MetaTensor& squared_accumulator,
const MetaTensor& linear_accumulator,
const MetaTensor& grad,
const MetaTensor& learning_rate,
float l1,
float l2,
float lr_power,
MetaTensor* param_out,
MetaTensor* squared_accum_out,
MetaTensor* linear_accum_out) {
@@ -1558,8 +1563,11 @@ void FtrlInferMeta(const MetaTensor& param,
common::product(lr_dim)));

param_out->set_dims(param_dim);
param_out->set_dtype(param.dtype());
squared_accum_out->set_dims(param_dim);
squared_accum_out->set_dtype(param.dtype());
linear_accum_out->set_dims(param_dim);
linear_accum_out->set_dtype(param.dtype());
}

void FusedBatchNormActInferMeta(const MetaTensor& x,
5 changes: 5 additions & 0 deletions paddle/phi/infermeta/multiary.h
Original file line number Diff line number Diff line change
@@ -309,8 +309,13 @@ void EditDistanceInferMeta(const MetaTensor& hyps,
MetaTensor* out);

// Shape/dtype inference for the FTRL optimizer op.
// All three outputs take `param`'s dims and dtype (see the corresponding
// definition in multiary.cc: each *_out gets set_dims(param_dim) and
// set_dtype(param.dtype())). The scalar hyper-parameters l1, l2 and
// lr_power do not influence the inferred metadata.
// NOTE(review): the full definition is not visible here — confirm any
// additional validation (e.g. learning_rate rank checks) against
// paddle/phi/infermeta/multiary.cc before relying on this summary.
void FtrlInferMeta(const MetaTensor& param,
const MetaTensor& squared_accumulator,
const MetaTensor& linear_accumulator,
const MetaTensor& grad,
const MetaTensor& learning_rate,
float l1,
float l2,
float lr_power,
MetaTensor* param_out,       // inferred: param's dims/dtype
MetaTensor* squared_accum_out,  // inferred: param's dims/dtype
MetaTensor* linear_accum_out);  // inferred: param's dims/dtype

0 comments on commit d5258c8

Please sign in to comment.