-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【NewIR】add add_n_grad and split_with_num_grad (Split op ) #56873
Changes from 74 commits
2c0166c
37883b2
185d30b
92e5303
221b70c
aba6f0e
bce9b3b
c2341a5
4d30fdd
c8f5864
805ceb2
008a8b2
c619c4e
23b37ed
9640576
000a8db
03ec67b
06def9f
d9f20d0
c027bbf
9df5940
3924ce3
d3805c7
4405f3c
a2fa7be
fa23c5d
7273556
24115d0
b8e98c5
a516693
237c493
fff986c
49a0678
9d3c92b
addc342
697c11f
ecc3d21
1bc9720
0848775
56c8666
f29b9d6
3c1c6aa
372f43f
8338414
4d2b8c7
7fac313
224bc4a
a6a9ee4
b10214f
c3cbb6d
050b58c
0bd42a8
dccfe38
8cb9782
6ed24e3
ddfda46
94732c5
54aad03
0ea16a5
a93a8a0
f0a9294
77045b2
e8b0dea
692c87f
5bc2dc7
283fea2
b0dc6d3
6612f3a
0882578
cf910ac
17ad8a6
8f0bd46
825fea4
f5fc3a4
be2ffec
a6e1ff2
8703219
cf64265
e4e57d4
05c4023
df0dd36
eaab743
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -28,6 +28,15 @@ pir::OpResult builtin_combine(const std::vector<pir::Value>& x) { | |||||
return combine_op.out(); | ||||||
} | ||||||
|
||||||
std::vector<pir::OpResult> add_n_grad(std::vector<pir::OpResult> inputs, | ||||||
pir::OpResult out_grad) { | ||||||
std::vector<pir::OpResult> inputs_grad; | ||||||
for (size_t i = 0; i < inputs.size(); i++) { | ||||||
inputs_grad.push_back(out_grad); | ||||||
} | ||||||
return inputs_grad; | ||||||
} | ||||||
|
||||||
pir::OpResult zeros_like(pir::Value x, | ||||||
phi::DataType dtype, | ||||||
const Place& place) { | ||||||
|
@@ -76,5 +85,23 @@ pir::OpResult embedding_grad(pir::Value x, | |||||
} | ||||||
} | ||||||
|
||||||
pir::OpResult split_with_num_grad(std::vector<pir::Value> out_grad, int axis) { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里为什么需要手写一个split_with_num_grad的api呢?另外yaml里split_with_num_grad invoke的是concat,为什么不去调用concat op呢 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 对齐split_grad, vjp 需要调用该api, invoke的复用在build函数中手写,由于有split_grad op 不需要加split_with_num |
||||||
auto out_grad_combine_op = | ||||||
APIBuilder::Instance().GetBuilder()->Build<pir::CombineOp>(out_grad); | ||||||
paddle::dialect::SplitGradOp split_grad_op = | ||||||
APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::SplitGradOp>( | ||||||
out_grad_combine_op.out(), axis); | ||||||
return split_grad_op.result(0); | ||||||
} | ||||||
|
||||||
pir::OpResult split_with_num_grad(std::vector<pir::Value> out_grad, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
虽然 pir::Value 的构造成本很低,但这里还是建议传入vector<>&,如有必要,也要加上const 限定符 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个在vjp处处理直接调用了invoke的api, 如果没有其他直接调用api 的情况此处后续会删除 |
||||||
pir::Value axis) { | ||||||
auto out_grad_combine_op = | ||||||
APIBuilder::Instance().GetBuilder()->Build<pir::CombineOp>(out_grad); | ||||||
paddle::dialect::SplitGradOp split_grad_op = | ||||||
APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::SplitGradOp>( | ||||||
out_grad_combine_op.out(), axis); | ||||||
return split_grad_op.result(0); | ||||||
} | ||||||
} // namespace dialect | ||||||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,6 +25,9 @@ namespace dialect { | |
|
||
pir::OpResult builtin_combine(const std::vector<pir::Value>& x); | ||
|
||
std::vector<pir::OpResult> add_n_grad(std::vector<pir::Value> inputs, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同上 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 结合代码生成同时修改 |
||
pir::Value out_grad); | ||
|
||
pir::OpResult zeros_like(pir::Value x, | ||
phi::DataType dtype = phi::DataType::UNDEFINED, | ||
const Place& place = {}); | ||
|
@@ -41,5 +44,9 @@ pir::OpResult embedding_grad(pir::Value x, | |
int64_t padding_idx = -1, | ||
bool sparse = false); | ||
|
||
pir::OpResult split_with_num_grad(std::vector<pir::Value> out_grad, int axis); | ||
|
||
pir::OpResult split_with_num_grad(std::vector<pir::Value> out_grad, | ||
pir::OpResult axis); | ||
} // namespace dialect | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,6 +18,7 @@ | |
#include <vector> | ||
|
||
#include "paddle/phi/api/include/tensor.h" | ||
#include "paddle/utils/optional.h" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里为什么要include optional头文件,新增代码里并没有用到? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 下个pr删除 |
||
|
||
namespace paddle { | ||
namespace primitive { | ||
|
@@ -28,6 +29,10 @@ using Scalar = paddle::experimental::Scalar; | |
using IntArray = paddle::experimental::IntArray; | ||
using DataType = phi::DataType; | ||
|
||
template <typename T> | ||
std::vector<Tensor> add_n_grad(const std::vector<Tensor>& x, | ||
const Tensor& out_grad); | ||
|
||
} // namespace backend | ||
} // namespace primitive | ||
} // namespace paddle |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pir::OpResult -> pir::Value
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done