
【NewIR】add add_n_grad and split_with_num_grad (Split op) #56873

Merged

Conversation

@xiaoguoguo626807 (Contributor) commented Sep 1, 2023

PR types

Bug fixes

PR changes

others

Description

pcard-67164
Based on PR #57218.
1. Hand-write the add_n backward code: add_n has no backward op, so it is handled directly at the API layer. AddNInferMeta did not set the output dtype; the hand-written add_n also adds stop_gradient inference.
2. split_with_num has no backward op; reuse split_grad instead, with the hand-written API layer calling the split_grad op directly.
3. Enable all add_n optests in test_sum_op.
add_n optest fix: (screenshot)
split bf16 fix: (screenshot)
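
As a rough sketch of point 1 above (the helper name is an assumption, not the exact Paddle source): since add_n computes y = x_0 + x_1 + ... + x_n, the gradient of every input is simply out_grad, so the backward can live entirely in the API layer without a dedicated add_n_grad op.

```cpp
// Sketch only: dL/dx_i == dL/dy for every input of add_n, so the backward
// just fans out_grad out to each input at the API layer.
std::vector<pir::OpResult> add_n_grad(const std::vector<pir::Value>& inputs,
                                      pir::Value out_grad) {
  std::vector<pir::OpResult> inputs_grad;
  inputs_grad.reserve(inputs.size());
  for (size_t i = 0; i < inputs.size(); ++i) {
    // MaterializeGrad is a hypothetical helper standing in for whatever
    // identity-like op the real code builds so each input gets its own
    // OpResult that is numerically equal to out_grad.
    inputs_grad.push_back(MaterializeGrad(out_grad));
  }
  return inputs_grad;
}
```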

@paddle-bot (bot) commented Sep 1, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI first. See the Paddle CI Manual for details.

@@ -28,6 +28,15 @@ pir::OpResult builtin_combine(const std::vector<pir::Value>& x) {
return combine_op.out();
}

std::vector<pir::OpResult> add_n_grad(std::vector<pir::OpResult> inputs,
pir::OpResult out_grad) {
Contributor:
pir::OpResult -> pir::Value

Contributor Author:
done
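
For reference, the signature after this change (matching the declaration shown later in this diff): read-only inputs become generic SSA values, while the freshly built gradients are still returned as results.

```cpp
// Post-review signature: inputs are consumed as pir::Value, and the newly
// created gradient values are returned as pir::OpResult.
std::vector<pir::OpResult> add_n_grad(std::vector<pir::Value> inputs,
                                      pir::Value out_grad);
```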

@@ -76,5 +85,23 @@ pir::OpResult embedding_grad(pir::Value x,
}
}

pir::OpResult split_with_num_grad(std::vector<pir::Value> out_grad, int axis) {
Contributor:

Why does a hand-written split_with_num_grad API need to be added here? Also, in the yaml, split_with_num_grad invokes concat; why not call the concat op instead?

Contributor Author:

To stay aligned with split_grad: the vjp needs to call this API, and the invoke-style reuse is hand-written in the build function. Since a split_grad op already exists, there is no need to add a split_with_num_grad op.
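
A minimal sketch of that reuse (the builder call is an assumption, not the exact Paddle source): the per-slice gradients are packed with builtin.combine and fed to the existing split_grad op, which conceptually concatenates them along the split axis.

```cpp
// Sketch only: reuse the existing split_grad op instead of adding a new
// split_with_num_grad op to the dialect.
pir::OpResult split_with_num_grad(const std::vector<pir::Value>& out_grad,
                                  int axis) {
  // Pack the per-slice gradients into one vector-typed value.
  pir::OpResult combined = builtin_combine(out_grad);
  // BuildSplitGradOp is a placeholder for however the builder constructs the
  // split_grad op from the combined gradients and the axis.
  auto split_grad_op = BuildSplitGradOp(combined, axis);
  return split_grad_op.result(0);
}
```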

lanxianghit previously approved these changes Sep 19, 2023

@lanxianghit (Contributor) left a comment:
LGTM for 'check_new_ir' in test

@xiaoguoguo626807 merged commit 0029a24 into PaddlePaddle:develop Sep 19, 2023
27 checks passed
@Aurelius84 (Contributor) left a comment:

LGTM, need to fix comments in next PR

return split_grad_op.result(0);
}

pir::OpResult split_with_num_grad(std::vector<pir::Value> out_grad,
Contributor:
Suggested change
pir::OpResult split_with_num_grad(std::vector<pir::Value> out_grad,
pir::OpResult split_with_num_grad(std::vector<pir::Value>& out_grad,

Although constructing a pir::Value is cheap, it is still recommended to pass a vector<>& here, and to add the const qualifier if appropriate.

Contributor Author:

The vjp path handles this by directly calling the invoke'd API; if nothing else calls this API directly, this function will be removed later.
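
Spelled out, the reviewer's suggestion amounts to taking the vector by const reference so the whole container is not copied, e.g.:

```cpp
// Pass the gradient list by const reference: pir::Value itself is cheap to
// copy, but copying the std::vector still allocates.
pir::OpResult split_with_num_grad(const std::vector<pir::Value>& out_grad,
                                  int axis);
```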

@@ -25,6 +25,9 @@ namespace dialect {

pir::OpResult builtin_combine(const std::vector<pir::Value>& x);

std::vector<pir::OpResult> add_n_grad(std::vector<pir::Value> inputs,
Contributor:

Same as above.

Contributor Author:

Will be changed together with the code generation.

pir::CombineOp combine_op_obj = op_obj.inputs()
.dyn_cast<pir::OpResult>()
.owner()
->dyn_cast<pir::CombineOp>();
Contributor:

When using dyn_cast, it is recommended to check the result, since a null pointer here can cause a segfault that is hard to debug.

Contributor Author:

Will be fixed in the next PR.
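
A sketch of the defensive check the reviewer asks for, assuming the op handle converts to bool like other pir handles (the exact error-handling macro used in the follow-up PR may differ):

```cpp
// dyn_cast yields an invalid handle when the cast fails, so verify it before
// use instead of letting a null pointer turn into a hard-to-trace segfault.
pir::CombineOp combine_op_obj = op_obj.inputs()
                                    .dyn_cast<pir::OpResult>()
                                    .owner()
                                    ->dyn_cast<pir::CombineOp>();
if (!combine_op_obj) {
  PADDLE_THROW(phi::errors::InvalidArgument(
      "The input of add_n_grad is expected to be produced by a "
      "builtin.combine op."));
}
```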

@@ -18,6 +18,7 @@
#include <vector>

#include "paddle/phi/api/include/tensor.h"
#include "paddle/utils/optional.h"
Contributor:

Why is the optional header included here? The newly added code does not use it.

Contributor Author:

Will be removed in the next PR.

@xiaoguoguo626807 (Contributor Author) left a comment:

Comments fixed in #57520.

@xiaoguoguo626807 deleted the split_op_genbug branch September 20, 2023 03:19
Frida-a pushed a commit to Frida-a/Paddle that referenced this pull request Oct 14, 2023
…le#56873)

* add reference of lbfgs

* add reference of lbfgs

* tmp

* split gen modify

* fix conflict

* add split

* fix bug

* fix bug

* test split

* add meta tensor

* refine code

* fix bug

* fix bug

* fix comflict

* Call _C_ops.sum in new ir

* modify concat kernel choose

* modify ci

* modify sum zero_dim optest

* modify split_with_num api

* modify split -1

* modify split test

* fix bug

* xxx

* delete extra modify

* add add_n

* tmp

* add split_with_num_grad

* modify split grad num bug

* modify ci

* modify ci

* clear code

* modify

* recover

* add add_n stop_gradient infer

* modify opreslut to value

* fix conflict

* recover to aviod conflict

* recover to aviod conflict

* modify opreslut to value

* recover complex tanh

* modify add_n optest

* skip bfp16

* modify split bf16

* fix conflict

* delete print

---------

Co-authored-by: zhangbo9674 <zhangbo54@baidu.com>
Co-authored-by: 0x45f <wangzhen45@baidu.com>
danleifeng pushed a commit to danleifeng/Paddle that referenced this pull request Nov 14, 2023
…le#56873) (same squashed commit message as the Frida-a commit above)