-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【NewIR】add add_n_grad and split_with_num_grad (Split op ) #56873
【NewIR】add add_n_grad and split_with_num_grad (Split op ) #56873
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
… dev/refine_mutable_attr_split
@@ -28,6 +28,15 @@ pir::OpResult builtin_combine(const std::vector<pir::Value>& x) { | |||
return combine_op.out(); | |||
} | |||
|
|||
std::vector<pir::OpResult> add_n_grad(std::vector<pir::OpResult> inputs, | |||
pir::OpResult out_grad) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pir::OpResult -> pir::Value
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
@@ -76,5 +85,23 @@ pir::OpResult embedding_grad(pir::Value x, | |||
} | |||
} | |||
|
|||
pir::OpResult split_with_num_grad(std::vector<pir::Value> out_grad, int axis) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里为什么需要手写一个split_with_num_grad的api呢?另外yaml里split_with_num_grad invoke的是concat,为什么不去调用concat op呢
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
对齐split_grad, vjp 需要调用该api, invoke的复用在build函数中手写,由于有split_grad op 不需要加split_with_num
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for 'check_new_ir' in test
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, need fix commnet in next PR
return split_grad_op.result(0); | ||
} | ||
|
||
pir::OpResult split_with_num_grad(std::vector<pir::Value> out_grad, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pir::OpResult split_with_num_grad(std::vector<pir::Value> out_grad, | |
pir::OpResult split_with_num_grad(std::vector<pir::Value>& out_grad, |
虽然 pir::Value 的构造成本很低,但这里还是建议传入vector<>&,如有必要,也要加上const 限定符
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个在vjp处处理直接调用了invoke的api, 如果没有其他直接调用api 的情况此处后续会删除
@@ -25,6 +25,9 @@ namespace dialect { | |||
|
|||
pir::OpResult builtin_combine(const std::vector<pir::Value>& x); | |||
|
|||
std::vector<pir::OpResult> add_n_grad(std::vector<pir::Value> inputs, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
结合代码生成同时修改
pir::CombineOp combine_op_obj = op_obj.inputs() | ||
.dyn_cast<pir::OpResult>() | ||
.owner() | ||
->dyn_cast<pir::CombineOp>(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
对于dyn_cast的使用,建议要check下,因为可能会有空指针导致段错误,导致问题排查比较难
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
下个pr修复
@@ -18,6 +18,7 @@ | |||
#include <vector> | |||
|
|||
#include "paddle/phi/api/include/tensor.h" | |||
#include "paddle/utils/optional.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里为什么要include optional头文件,新增代码里并没有用到?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
下个pr删除
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
comment 修复 #57520
…le#56873) * add reference of lbfgs * add reference of lbfgs * tmp * split gen modify * fix conflict * add split * fix bug * fix bug * test split * add meta tensor * refine code * fix bug * fix bug * fix comflict * Call _C_ops.sum in new ir * modify concat kernel choose * modify ci * modify sum zero_dim optest * modify split_with_num api * modify split -1 * modify split test * fix bug * xxx * delete extra modify * add add_n * tmp * add split_with_num_grad * modify split grad num bug * modify ci * modify ci * clear code * modify * recover * add add_n stop_gradient infer * modify opreslut to value * fix conflict * recover to aviod conflict * recover to aviod conflict * modify opreslut to value * recover complex tanh * modify add_n optest * skip bfp16 * modify split bf16 * fix conflict * delete print --------- Co-authored-by: zhangbo9674 <zhangbo54@baidu.com> Co-authored-by: 0x45f <wangzhen45@baidu.com>
…le#56873) * add reference of lbfgs * add reference of lbfgs * tmp * split gen modify * fix conflict * add split * fix bug * fix bug * test split * add meta tensor * refine code * fix bug * fix bug * fix comflict * Call _C_ops.sum in new ir * modify concat kernel choose * modify ci * modify sum zero_dim optest * modify split_with_num api * modify split -1 * modify split test * fix bug * xxx * delete extra modify * add add_n * tmp * add split_with_num_grad * modify split grad num bug * modify ci * modify ci * clear code * modify * recover * add add_n stop_gradient infer * modify opreslut to value * fix conflict * recover to aviod conflict * recover to aviod conflict * modify opreslut to value * recover complex tanh * modify add_n optest * skip bfp16 * modify split bf16 * fix conflict * delete print --------- Co-authored-by: zhangbo9674 <zhangbo54@baidu.com> Co-authored-by: 0x45f <wangzhen45@baidu.com>
PR types
Bug fixes
PR changes
others
Description
pcard-67164
基于#57218 pr
1.手写add-n反向相关代码,add_n 反向没有op,在api层直接处理;add_ninfermeta 没有设置输出类型;add_n 手写增加stop_gradient 推导。
2.split_with_num反向没有op,复用split_grad, api层手写直接调用split_grad op;
3.开启test_sum_op add_n 全部optest;
add_n _optest 修复:
splitbf16修复: