Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support custom vjp trait #57106

Merged
merged 40 commits into from
Sep 13, 2023
Merged

Conversation

lxd-cumt
Copy link
Contributor

@lxd-cumt lxd-cumt commented Sep 8, 2023

PR types

New features

PR changes

Others

Description

add custom_vjp trait to support checking whether an op has custom vjp

For example, gelu op has custom vjp rules and is added into custom_vjp list, as follows,

PRIM_VJP = ['divide_grad', 'sum_grad']  # vjp list of primitive op
CUSTOM_VJP = ['gelu_grad']  # custom vjp list of composite op
VJP_COMPS = PRIM_VJP + CUSTOM_VJP

Therefore, for gelu op, CustomVjpTrait is automatically added into it, as follows,

class GeluOp : public ir::Op<GeluOp,paddle::dialect::OpYamlInfoInterface,paddle::dialect::InferMetaInterface,paddle::dialect::VjpInterface,paddle::dialect::CustomVjpTrait>

Finally, when calling has_custom_vjp(gelu_op), it will return True, representing that the custom vjp rules for gelu are currently defined.

Pcard-66975

@paddle-bot
Copy link

paddle-bot bot commented Sep 8, 2023

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@Charles-hit
Copy link
Contributor

PR描述最好可以贴一个生成出来的例子,比如gelu


namespace paddle {
namespace dialect {
class CustomVjpTrait : public ir::OpTraitBase<CustomVjpTrait> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

explain custom vjp

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, thx

@lxd-cumt
Copy link
Contributor Author

lxd-cumt commented Sep 8, 2023

PR描述最好可以贴一个生成出来的例子,比如gelu

done,thx

Charles-hit
Charles-hit previously approved these changes Sep 11, 2023
return newir_program


class TestCustomVjpTrait(unittest.TestCase):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

单测需要加一个反例

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okk,下个PR加入

# import from paddle/fluid/primitive/code_gen/gen.py
sys.path.append(
str(pathlib.Path(__file__).resolve().parents[3] / 'primitive/codegen')
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不可以使用相对路径导入的方式么?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

当前缺少__init__.py,使用相对路径的话需要新增才可以;
python代码规范里面也提到:

Absolute imports are recommended, as they are usually more readable and tend to be better behaved 
(or at least give better error messages) if the import system is incorrectly configured .

所以使用了绝对路径。

zhangbo9674
zhangbo9674 previously approved these changes Sep 11, 2023
@cyber-pioneer cyber-pioneer merged commit fec77c2 into PaddlePaddle:develop Sep 13, 2023
danleifeng pushed a commit to danleifeng/Paddle that referenced this pull request Nov 14, 2023
* test prim custom vjp in New IR

* add a new CustomVjpTrait to represent whether an op has custom vjp

* add has_custom_vjp_op_list to represent ops that have custom vjp

* parse has_custom_vjp_op_list and autogen CustomVjpTrait for those ops

* add pybind to support checking whether an op has custom vjp in python level

* add test

* add test for add op custom vjp

* add pybind to support checking whether an op has custom vjp in python level

* fix bugs

* polish code

* fix bugs

* generate custom_vjp trait based on op list from gen.py

* delete has_custom_vjp_op_list

* fix bugs

* use currently defined list CUSTOM_VJP and VJP_COMPS rather than define new list

* fix ctest

* fix bugs

* divide vjp into prim_vjp and custom vjp

* add code comments

* add code comments

* polish codes

* polish code comments

* polish codes

* fix bugs

* add code comments

* fix bugs

* add custom vjp trait support in new folder

* fix bugs

* add another example for unit testing

* fix bugs

---------

Co-authored-by: cyber-pioneer <chenzhuo@tju.edu.cn>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants