【NewIR】add add_n_grad and split_with_num_grad (Split op) #56873

Merged

Changes from all commits (82 commits)
2c0166c
add reference of lbfgs
xiaoguoguo626807 Aug 11, 2023
37883b2
add reference of lbfgs
xiaoguoguo626807 Aug 11, 2023
185d30b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Aug 18, 2023
92e5303
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Aug 24, 2023
221b70c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Aug 25, 2023
aba6f0e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Aug 28, 2023
bce9b3b
tmp
xiaoguoguo626807 Aug 30, 2023
c2341a5
fix conflict
xiaoguoguo626807 Aug 30, 2023
4d30fdd
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Aug 31, 2023
c8f5864
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Sep 1, 2023
805ceb2
split gen modify
xiaoguoguo626807 Sep 1, 2023
008a8b2
fix conflict
xiaoguoguo626807 Sep 4, 2023
c619c4e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Sep 4, 2023
23b37ed
add split
zhangbo9674 Sep 4, 2023
9640576
Merge branch 'develop', commit 'refs/pull/56924/head' of https://gith…
xiaoguoguo626807 Sep 4, 2023
000a8db
fix bug
zhangbo9674 Sep 4, 2023
03ec67b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zhangbo9674 Sep 4, 2023
06def9f
fix bug
zhangbo9674 Sep 4, 2023
d9f20d0
test split
xiaoguoguo626807 Sep 5, 2023
c027bbf
Merge branch 'develop', commit 'refs/pull/56924/head' of https://gith…
xiaoguoguo626807 Sep 5, 2023
9df5940
add meta tensor
zhangbo9674 Sep 5, 2023
3924ce3
refine code
zhangbo9674 Sep 5, 2023
d3805c7
fix bug
zhangbo9674 Sep 5, 2023
4405f3c
fix bug
zhangbo9674 Sep 5, 2023
a2fa7be
fix comflict
xiaoguoguo626807 Sep 5, 2023
fa23c5d
Merge branch 'develop', commit 'refs/pull/56973/head' of https://gith…
xiaoguoguo626807 Sep 5, 2023
7273556
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Sep 5, 2023
24115d0
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Sep 6, 2023
b8e98c5
Call _C_ops.sum in new ir
0x45f Sep 6, 2023
a516693
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Sep 6, 2023
237c493
fix conflict
xiaoguoguo626807 Sep 6, 2023
fff986c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Sep 6, 2023
49a0678
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Sep 6, 2023
9d3c92b
modify concat kernel choose
xiaoguoguo626807 Sep 6, 2023
addc342
modify ci
xiaoguoguo626807 Sep 7, 2023
697c11f
modify sum zero_dim optest
xiaoguoguo626807 Sep 7, 2023
ecc3d21
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Sep 7, 2023
1bc9720
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Sep 7, 2023
0848775
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
0x45f Sep 7, 2023
56c8666
modify split_with_num api
xiaoguoguo626807 Sep 7, 2023
f29b9d6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Sep 7, 2023
3c1c6aa
fix conflict
xiaoguoguo626807 Sep 7, 2023
372f43f
fix conflict
xiaoguoguo626807 Sep 7, 2023
8338414
modify split -1
xiaoguoguo626807 Sep 8, 2023
4d2b8c7
fix conflict
xiaoguoguo626807 Sep 8, 2023
7fac313
modify split test
xiaoguoguo626807 Sep 11, 2023
224bc4a
fix conflict
xiaoguoguo626807 Sep 11, 2023
a6a9ee4
fix bug
xiaoguoguo626807 Sep 11, 2023
b10214f
xxx
xiaoguoguo626807 Sep 11, 2023
c3cbb6d
fix conflict
xiaoguoguo626807 Sep 11, 2023
050b58c
delete extra modify
xiaoguoguo626807 Sep 11, 2023
0bd42a8
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Sep 11, 2023
dccfe38
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Sep 11, 2023
8cb9782
fix conflict
xiaoguoguo626807 Sep 12, 2023
6ed24e3
add add_n
xiaoguoguo626807 Sep 13, 2023
ddfda46
tmp
xiaoguoguo626807 Sep 14, 2023
94732c5
fix conflict
xiaoguoguo626807 Sep 14, 2023
54aad03
add split_with_num_grad
xiaoguoguo626807 Sep 14, 2023
0ea16a5
fix conflict
xiaoguoguo626807 Sep 14, 2023
a93a8a0
fix conflict
xiaoguoguo626807 Sep 14, 2023
f0a9294
modify split grad num bug
xiaoguoguo626807 Sep 15, 2023
77045b2
fix conflict
xiaoguoguo626807 Sep 15, 2023
e8b0dea
modify ci
xiaoguoguo626807 Sep 16, 2023
692c87f
modify ci
xiaoguoguo626807 Sep 16, 2023
5bc2dc7
modify ci
xiaoguoguo626807 Sep 17, 2023
283fea2
clear code
xiaoguoguo626807 Sep 17, 2023
b0dc6d3
modify
xiaoguoguo626807 Sep 17, 2023
6612f3a
recover
xiaoguoguo626807 Sep 18, 2023
0882578
add add_n stop_gradient infer
xiaoguoguo626807 Sep 18, 2023
cf910ac
modify opreslut to value
xiaoguoguo626807 Sep 18, 2023
17ad8a6
modify opreslut to value
xiaoguoguo626807 Sep 18, 2023
8f0bd46
fix conflict
xiaoguoguo626807 Sep 18, 2023
825fea4
recover to aviod conflict
xiaoguoguo626807 Sep 18, 2023
f5fc3a4
recover to aviod conflict
xiaoguoguo626807 Sep 18, 2023
be2ffec
modify opreslut to value
xiaoguoguo626807 Sep 18, 2023
a6e1ff2
recover complex tanh
xiaoguoguo626807 Sep 18, 2023
8703219
modify add_n optest
xiaoguoguo626807 Sep 18, 2023
cf64265
skip bfp16
xiaoguoguo626807 Sep 18, 2023
e4e57d4
modify split bf16
xiaoguoguo626807 Sep 19, 2023
05c4023
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Sep 19, 2023
df0dd36
fix conflict
xiaoguoguo626807 Sep 19, 2023
eaab743
delete print
xiaoguoguo626807 Sep 19, 2023
7 changes: 6 additions & 1 deletion paddle/fluid/pir/dialect/op_generator/op_gen.py
@@ -183,7 +183,12 @@ class {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{
'bool': 'pir::BoolAttribute',
}

PD_MANUAL_OP_LIST = {'add_n', 'add_n_', 'add_n_with_kernel', 'split_grad'}
PD_MANUAL_OP_LIST = {
'add_n',
'add_n_',
'add_n_with_kernel',
'split_grad',
}


def to_phi_and_fluid_op_name(op_item):
@@ -31,6 +31,7 @@
"add",
"concat",
"split",
"split_with_num",
"gelu",
"matmul",
"erf",
@@ -45,6 +46,7 @@
'layer_norm',
'reshape',
'cast',
"scale",
'softmax',
'silu',
'elementwise_pow',
@@ -53,7 +55,6 @@
'slice',
'transpose',
'slice_double',
'scale',
]
vjp_interface_implementation_gen_op_list = [
"tanh",
@@ -63,6 +64,7 @@
"add",
"concat",
"split",
"split_with_num",
"gelu",
"matmul",
"erf",
@@ -77,6 +79,7 @@
'layer_norm',
'reshape',
'cast',
"scale",
'softmax',
'silu',
'elementwise_pow',
@@ -85,5 +88,4 @@
'slice',
'transpose',
'slice_double',
'scale',
]
30 changes: 30 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_api.cc
@@ -28,6 +28,18 @@ pir::OpResult builtin_combine(const std::vector<pir::Value>& x) {
return combine_op.out();
}

std::vector<pir::OpResult> add_n_grad(std::vector<pir::Value> inputs,
pir::Value out_grad) {
std::vector<pir::OpResult> inputs_grad;
for (size_t i = 0; i < inputs.size(); i++) {
paddle::dialect::ScaleOp scale_op =
APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::ScaleOp>(
out_grad, 1.0, 0.0, true);
inputs_grad.push_back(scale_op.result(0));
}
return inputs_grad;
}

pir::OpResult zeros_like(pir::Value x,
phi::DataType dtype,
const Place& place) {
@@ -76,5 +88,23 @@ pir::OpResult embedding_grad(pir::Value x,
}
}

pir::OpResult split_with_num_grad(std::vector<pir::Value> out_grad, int axis) {
Reviewer comment (Contributor): Why does this need a hand-written split_with_num_grad API? Also, in the YAML, split_with_num_grad invokes concat; why not call the concat op directly?

Author reply: To stay aligned with split_grad, the vjp needs to call this API, and the invoke-style reuse is hand-written in the Build function. Since a split_grad op already exists, there is no need to add a separate split_with_num grad op.

auto out_grad_combine_op =
APIBuilder::Instance().GetBuilder()->Build<pir::CombineOp>(out_grad);
paddle::dialect::SplitGradOp split_grad_op =
APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::SplitGradOp>(
out_grad_combine_op.out(), axis);
return split_grad_op.result(0);
}

pir::OpResult split_with_num_grad(std::vector<pir::Value> out_grad,
Reviewer comment (Contributor): Suggested change: pass the vector by reference, i.e.
pir::OpResult split_with_num_grad(std::vector<pir::Value>& out_grad,
Although pir::Value is cheap to construct, passing a vector<>& is still recommended here; add a const qualifier if appropriate.

Author reply: The vjp handling calls the invoked API directly; if there are no other direct callers of this API, this overload will be removed later.

pir::Value axis) {
auto out_grad_combine_op =
APIBuilder::Instance().GetBuilder()->Build<pir::CombineOp>(out_grad);
paddle::dialect::SplitGradOp split_grad_op =
APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::SplitGradOp>(
out_grad_combine_op.out(), axis);
return split_grad_op.result(0);
}
} // namespace dialect
} // namespace paddle
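Note: the two helpers above encode the standard gradient rules for add_n and split_with_num. The sketch below is illustrative only (it is not part of this diff, and the placeholder inputs are hypothetical); it shows how a caller inside the dialect layer might invoke them:

    // Illustrative sketch, not part of the PR. For y = add_n(x_1, ..., x_n), dL/dx_i == dL/dy
    // for every input, which is why add_n_grad emits one scale(out_grad, 1.0, 0.0) per input.
    // For outs = split_with_num(x, n, axis), dL/dx == concat(dL/douts..., axis), which is what
    // CombineOp + SplitGradOp build above.
    std::vector<pir::Value> xs = {/* forward inputs of the add_n op (placeholders) */};
    pir::Value dy /* = gradient of the add_n output (placeholder) */;
    std::vector<pir::OpResult> dxs = paddle::dialect::add_n_grad(xs, dy);  // dxs.size() == xs.size()

    std::vector<pir::Value> douts = {/* gradients of the split_with_num outputs (placeholders) */};
    pir::OpResult dx = paddle::dialect::split_with_num_grad(douts, /*axis=*/1);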
7 changes: 7 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_api.h
@@ -25,6 +25,9 @@ namespace dialect {

pir::OpResult builtin_combine(const std::vector<pir::Value>& x);

std::vector<pir::OpResult> add_n_grad(std::vector<pir::Value> inputs,
Reviewer comment (Contributor): Same as above: pass the vector by (const) reference.

Author reply: Will change this together with the code-generation update.

pir::Value out_grad);

pir::OpResult zeros_like(pir::Value x,
phi::DataType dtype = phi::DataType::UNDEFINED,
const Place& place = {});
@@ -41,5 +44,9 @@ pir::OpResult embedding_grad(pir::Value x,
int64_t padding_idx = -1,
bool sparse = false);

pir::OpResult split_with_num_grad(std::vector<pir::Value> out_grad, int axis);

pir::OpResult split_with_num_grad(std::vector<pir::Value> out_grad,
pir::Value axis);
} // namespace dialect
} // namespace paddle
2 changes: 2 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_op.cc
@@ -16,6 +16,7 @@
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/primitive/rule/vjp/vjp.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
@@ -148,6 +149,7 @@ void AddNOp::Build(pir::Builder &builder, // NOLINT
dense_out.offset());
argument_outputs.push_back(out_dense_tensor_type);
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
::pir::PassStopGradientsDefaultly(argument);
}

void AddNOp::InferMeta(phi::InferMetaContext *infer_meta) {
17 changes: 9 additions & 8 deletions paddle/fluid/pir/dialect/operator/ir/manual_op.h
@@ -12,18 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#ifdef GET_MANUAL_OP_LIST
#undef GET_MANUAL_OP_LIST
paddle::dialect::AddNOp, paddle::dialect::SplitGradOp, paddle::dialect::IfOp

#else

#pragma once
#include <vector>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/pir/dialect/operator/interface/infermeta.h"
#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h"
#include "paddle/fluid/pir/dialect/operator/interface/vjp.h"
#include "paddle/fluid/pir/dialect/operator/trait/inplace.h"
#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h"
#include "paddle/fluid/pir/dialect/operator/utils/utils.h"
@@ -36,7 +31,10 @@ paddle::dialect::AddNOp, paddle::dialect::SplitGradOp, paddle::dialect::IfOp
namespace paddle {
namespace dialect {

class AddNOp : public pir::Op<AddNOp, OpYamlInfoInterface, InferMetaInterface> {
class AddNOp : public pir::Op<AddNOp,
paddle::dialect::OpYamlInfoInterface,
paddle::dialect::InferMetaInterface,
paddle::dialect::VjpInterface> {
public:
using Op::Op;
static const char *name() { return "pd_op.add_n"; }
@@ -51,6 +49,10 @@ class AddNOp : public pir::Op<AddNOp, OpYamlInfoInterface, InferMetaInterface> {
pir::Value inputs() { return operand_source(0); }
pir::OpResult out() { return result(0); }
static void InferMeta(phi::InferMetaContext *infer_meta);
static std::vector<std::vector<pir::OpResult>> Vjp(
pir::Operation *op,
const std::vector<std::vector<pir::Value>> &out_grads,
const std::vector<std::vector<bool>> &stop_gradients);
};

class AddN_Op : public pir::Op<AddN_Op,
@@ -201,4 +203,3 @@ IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddNWithKernelOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueGradOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::IfOp)
#endif
44 changes: 44 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc
@@ -27,5 +27,49 @@ namespace paddle {
namespace dialect {
using IntArray = paddle::experimental::IntArray;

std::vector<std::vector<pir::OpResult>> AddNOp::Vjp(
pir::Operation* op,
const std::vector<std::vector<pir::Value>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients) {
AddNOp op_obj = op->dyn_cast<AddNOp>();

VLOG(6) << "Prepare inputs of add_n_grad";

pir::CombineOp combine_op_obj = op_obj.inputs()
.dyn_cast<pir::OpResult>()
.owner()
->dyn_cast<pir::CombineOp>();
Reviewer comment (Contributor): Please add a check around this use of dyn_cast; a failed cast yields a null pointer, and the resulting segmentation fault is hard to track down.

Author reply: Will fix in the next PR.

std::vector<Tensor> inputs;
for (size_t idx = 0; idx < combine_op_obj.inputs().size(); idx++) {
inputs.emplace_back(
std::make_shared<primitive::LazyTensor>(combine_op_obj.inputs()[idx]));
}

Tensor out_grad(std::make_shared<primitive::LazyTensor>(out_grads[0][0]));

VLOG(6) << "Vjp prepare Prepare attributes of add_n_grad";

VLOG(6) << "Vjp prepare call add_n's vjp inteface";

std::vector<std::vector<Tensor>> tensor_res =
primitive::add_n_vjp(inputs, out_grad, stop_gradients);

VLOG(6) << "Vjp prepare stop gradient of add_n_grad";

std::vector<std::vector<pir::OpResult>> res(tensor_res.size());
for (size_t i = 0; i < tensor_res.size(); ++i) {
res[i].resize(tensor_res[i].size());
for (size_t j = 0; j < tensor_res[i].size(); ++j) {
if (tensor_res[i][j].defined()) {
res[i][j] = std::static_pointer_cast<primitive::LazyTensor>(
tensor_res[i][j].impl())
->value()
.dyn_cast<pir::OpResult>();
}
}
}
return res;
}

} // namespace dialect
} // namespace paddle
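On the dyn_cast review thread above: a minimal sketch of the requested null check (which the author defers to a follow-up PR) might look as follows; the enforce macro and error type used here are assumptions, not part of this change:

    // Sketch only: validate each dyn_cast before dereferencing (names follow AddNOp::Vjp above).
    pir::OpResult inputs_result = op_obj.inputs().dyn_cast<pir::OpResult>();
    PADDLE_ENFORCE_EQ(static_cast<bool>(inputs_result),
                      true,
                      phi::errors::InvalidArgument(
                          "The inputs of add_n must be an OpResult defined by another op."));
    pir::CombineOp combine_op_obj =
        inputs_result.owner()->dyn_cast<pir::CombineOp>();
    PADDLE_ENFORCE_EQ(static_cast<bool>(combine_op_obj),
                      true,
                      phi::errors::InvalidArgument(
                          "The inputs of add_n must be produced by a builtin.combine op."));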
5 changes: 5 additions & 0 deletions paddle/fluid/primitive/backend/manual/manual_backend.h
@@ -18,6 +18,7 @@
#include <vector>

#include "paddle/phi/api/include/tensor.h"
#include "paddle/utils/optional.h"
Reviewer comment (Contributor): Why include the optional header here? The newly added code does not use it.

Author reply: Will remove it in the next PR.


namespace paddle {
namespace primitive {
@@ -28,6 +29,10 @@ using Scalar = paddle::experimental::Scalar;
using IntArray = paddle::experimental::IntArray;
using DataType = phi::DataType;

template <typename T>
std::vector<Tensor> add_n_grad(const std::vector<Tensor>& x,
const Tensor& out_grad);

} // namespace backend
} // namespace primitive
} // namespace paddle
22 changes: 22 additions & 0 deletions paddle/fluid/primitive/backend/manual/manual_static_backend.cc
@@ -12,7 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pir/dialect/operator/ir/manual_api.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_api.h"
#include "paddle/fluid/primitive/backend/generated/generated_backend.h"
#include "paddle/fluid/primitive/backend/manual/manual_backend.h"
#include "paddle/fluid/primitive/primitive/primitive.h"
#include "paddle/fluid/primitive/type/lazy_tensor.h"
@@ -22,6 +24,26 @@ namespace primitive {
namespace backend {

using LazyTensor = paddle::primitive::LazyTensor;
template <>
std::vector<Tensor> add_n_grad<LazyTensor>(const std::vector<Tensor>& x,
const Tensor& out_grad) {
std::vector<pir::Value> x_res(x.size());
std::transform(x.begin(), x.end(), x_res.begin(), [](const Tensor& t) {
return std::static_pointer_cast<LazyTensor>(t.impl())->value();
});
pir::Value out_grad_res =
std::static_pointer_cast<LazyTensor>(out_grad.impl())->value();
auto op_res = paddle::dialect::add_n_grad(x_res, out_grad_res);

std::vector<Tensor> x_grad(op_res.size());
std::transform(op_res.begin(),
op_res.end(),
x_grad.begin(),
[](const pir::OpResult& res) {
return Tensor(std::make_shared<LazyTensor>(res));
});
return x_grad;
}

} // namespace backend
} // namespace primitive
3 changes: 3 additions & 0 deletions paddle/fluid/primitive/codegen/gen.py
@@ -45,6 +45,7 @@
'sum_grad',
'concat_grad',
'split_grad',
'split_with_num_grad',
'gelu_grad',
'softmax_grad',
'silu_grad',
@@ -97,6 +98,7 @@
'sum_grad',
'concat_grad',
'split_grad',
'split_with_num_grad',
'gelu_grad',
'softmax_grad',
'silu_grad',
@@ -145,6 +147,7 @@
'slice',
'layer_norm_grad',
'embedding_grad',
'sqrt',
'uniform',
]

18 changes: 17 additions & 1 deletion paddle/fluid/primitive/rule/vjp/manual/manual_vjp.cc
@@ -24,5 +24,21 @@
#include "paddle/pir/core/operation.h"

namespace paddle {
namespace primitive {} // namespace primitive
namespace primitive {

std::vector<std::vector<paddle::Tensor>> add_n_vjp(
const std::vector<paddle::Tensor>& x,
const Tensor& out_grad,
const std::vector<std::vector<bool>>& stop_gradients) {
std::vector<std::vector<paddle::Tensor>> vjp_res;
for (auto arg : stop_gradients) {
vjp_res.push_back(std::vector<paddle::Tensor>(arg.size()));
}
auto op_res = backend::add_n_grad<LazyTensor>(x, out_grad);
vjp_res[0] = op_res;
vjp_res = ConstructVjpResultByStopGradients(vjp_res, stop_gradients);
return vjp_res;
}

} // namespace primitive
} // namespace paddle
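To make the stop_gradients contract concrete, the sketch below walks through a hypothetical add_n with three inputs; x0, x1, x2 and dy are assumed to be LazyTensor-backed tensors from a traced program, and none of this code is part of the diff:

    // Hypothetical example: add_n with three inputs, where the third has stop_gradient = true.
    std::vector<paddle::Tensor> x = {x0, x1, x2};  // LazyTensor-backed inputs (assumed)
    paddle::Tensor out_grad = dy;                  // gradient of the add_n output (assumed)
    std::vector<std::vector<bool>> stop_gradients = {{false, false, true}};

    auto grads = paddle::primitive::add_n_vjp(x, out_grad, stop_gradients);
    // grads.size() == 1 and grads[0].size() == 3. After ConstructVjpResultByStopGradients the
    // third entry is expected to be left undefined, which is why AddNOp::Vjp above only converts
    // entries where tensor_res[i][j].defined() is true back to pir::OpResult.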
4 changes: 4 additions & 0 deletions paddle/fluid/primitive/rule/vjp/manual/manual_vjp.h
@@ -23,5 +23,9 @@ namespace paddle {
namespace primitive {

using IntArray = paddle::experimental::IntArray;
std::vector<std::vector<paddle::Tensor>> add_n_vjp(
const std::vector<paddle::Tensor>& x,
const Tensor& out_grad,
const std::vector<std::vector<bool>>& stop_gradients);
} // namespace primitive
} // namespace paddle
1 change: 1 addition & 0 deletions paddle/phi/infermeta/multiary.cc
@@ -390,6 +390,7 @@ void AddNInferMeta(const std::vector<const MetaTensor*>& x,
out->set_dims(in_dim);
}
out->share_lod(*x[0]);
out->set_dtype(x[0]->dtype());
}

// TODO(YuanRisheng) This InferMeta is used in Fluid
17 changes: 17 additions & 0 deletions python/paddle/autograd/ir_backward.py
@@ -207,6 +207,23 @@ def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set):
union_op_flags[i] = False
intersection_op_flags[i] = False

# Some inputs are in no_grad_set, but the op that consumes them is effective,
# so add their defining ops back here.
total_ops_list = list(total_ops)
for i, op in enumerate(total_ops_list):
if union_op_flags[i] is False:
for result in op.results():
if result.has_one_use():
next_op = result.first_use().owner()
if (
next_op in total_ops
and union_op_flags[total_ops_list.index(next_op)]
is True
):
union_op_flags[i] = True
else:
continue

effective_ops = [
total_ops[i] for i in range(len(total_ops)) if intersection_op_flags[i]
]