From d5258c8f6458dce7393832cdb7dca5a0da8a069c Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Wed, 3 Jan 2024 02:21:25 +0000 Subject: [PATCH] Fix FtrlInferMeta: match the ftrl kernel's full argument list and set output dtypes --- paddle/fluid/pir/dialect/operator/ir/ops.yaml | 1 - paddle/phi/infermeta/multiary.cc | 8 ++++++++ paddle/phi/infermeta/multiary.h | 5 +++++ 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/pir/dialect/operator/ir/ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml index 9a288fb683730..e6803d1086a3f 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml @@ -1416,7 +1416,6 @@ output: Tensor(param_out), Tensor(squared_accum_out), Tensor(linear_accum_out) infer_meta: func: FtrlInferMeta - param: [param, grad, learning_rate] kernel: func: ftrl data_type: param diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 798beff23bd1b..45d28e0291b8a 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -1529,8 +1529,13 @@ void EditDistanceInferMeta(const MetaTensor& hyps, } void FtrlInferMeta(const MetaTensor& param, + const MetaTensor& squared_accumulator, + const MetaTensor& linear_accumulator, const MetaTensor& grad, const MetaTensor& learning_rate, + float l1, + float l2, + float lr_power, MetaTensor* param_out, MetaTensor* squared_accum_out, MetaTensor* linear_accum_out) { @@ -1558,8 +1563,11 @@ void FtrlInferMeta(const MetaTensor& param, common::product(lr_dim))); param_out->set_dims(param_dim); + param_out->set_dtype(param.dtype()); squared_accum_out->set_dims(param_dim); + squared_accum_out->set_dtype(param.dtype()); linear_accum_out->set_dims(param_dim); + linear_accum_out->set_dtype(param.dtype()); } void FusedBatchNormActInferMeta(const MetaTensor& x, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index ed28db5740d0f..3dff661dabd55 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -309,8 +309,13 @@ 
void EditDistanceInferMeta(const MetaTensor& hyps, MetaTensor* out); void FtrlInferMeta(const MetaTensor& param, + const MetaTensor& squared_accumulator, + const MetaTensor& linear_accumulator, const MetaTensor& grad, const MetaTensor& learning_rate, + float l1, + float l2, + float lr_power, MetaTensor* param_out, MetaTensor* squared_accum_out, MetaTensor* linear_accum_out);