[Dy2St] support train step in to_static #51693
Conversation
2. remove double_grad in Dy2static. (test_grad.py is OK)
… train-step-siguremo
Your PR was submitted successfully. Thank you for your contribution to the open-source project!
LGTM overall
python/paddle/fluid/dygraph/base.py
Outdated
is_persistable = len(tensor.shape) > 0
# non-persistable. See case of `dropout_state` in lstm api.
is_persistable = True
if "dropout_state" in tensor.name:
Minor suggestion: this name could be written as a constant, preferably with a special marker that both the RNN and dynamic-to-static modules can recognize, making it easier to delete it uniformly after the RNN kernel upgrade.
Done: it is now marked using NON_PERSISTABLE_VAR_NAME_SUFFIX, which lives in dygraph.base.
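As a rough, self-contained sketch of the suggestion (the constant name comes from the reply above, but the suffix value and the helper function are assumptions for illustration, not the framework's actual code):

```python
# Hypothetical sketch: keep the marker as one shared constant so that both
# the RNN code that names `dropout_state` and the dynamic-to-static module
# recognize it, and so it can be deleted in a single place after the RNN
# kernel upgrade. The suffix value and helper name are illustrative.

NON_PERSISTABLE_VAR_NAME_SUFFIX = "__non_persistable"

def is_persistable(tensor_name):
    # Decide persistability from the shared name marker instead of relying
    # on a 0-D shape, so 0-D tensors such as lr are no longer filtered out.
    return NON_PERSISTABLE_VAR_NAME_SUFFIX not in tensor_name

print(is_persistable("dropout_state" + NON_PERSISTABLE_VAR_NAME_SUFFIX))  # False
print(is_persistable("learning_rate_0"))                                  # True
```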
python/paddle/fluid/framework.py
Outdated
@@ -27,6 +27,7 @@
import multiprocessing
import sys
import logging
import paddle
Importing the top-level paddle here may risk a circular import? A relative import is recommended instead. As a convention, low-level modules should not import top-level modules. That said, the framework does not seem to follow this convention consistently; quite a few places already import paddle directly.
This import was added mainly to access `.backward.append_backward`, but using `from . import backward` actually triggers a circular import, so the import had to be moved inside the function body.
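A minimal, self-contained illustration of the deferred-import workaround described in this thread (the two simulated modules are stand-ins built with `exec`, not Paddle's real files): the cycle only breaks because the import runs at call time, after both modules are fully initialized.

```python
import sys
import types

def make_module(name, source):
    # Register the module first, then execute its body, mimicking how the
    # interpreter makes a partially initialized module visible to importers.
    mod = types.ModuleType(name)
    sys.modules[name] = mod
    exec(source, mod.__dict__)
    return mod

# `framework` defers its import of `backward` into the function body, so the
# cycle is only resolved when run_backward is actually called.
framework_src = """
def run_backward(loss):
    import backward  # deferred import: avoids the circular-import error
    return backward.append_backward(loss)
"""

# `backward` imports `framework` at module level, closing the cycle.
backward_src = """
import framework

def append_backward(loss):
    return loss + "_grad"
"""

framework = make_module("framework", framework_src)
backward = make_module("backward", backward_src)

print(framework.run_backward("loss"))  # -> loss_grad
```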
LGTM
@@ -66,6 +66,8 @@ set_tests_properties(test_transformer PROPERTIES TIMEOUT 200)
set_tests_properties(test_bmn PROPERTIES TIMEOUT 120)
#set_tests_properties(test_mnist PROPERTIES TIMEOUT 120)
set_tests_properties(test_build_strategy PROPERTIES TIMEOUT 120)
set_tests_properties(test_train_step_resnet18_sgd PROPERTIES TIMEOUT 1200)
set_tests_properties(test_train_step_resnet18_adam PROPERTIES TIMEOUT 1200)
Do these two unit tests really need this much time?
Yes, I tried 900 and it was not enough, so I raised it to 1200. Only PR-CI-Windows-OPENBLAS takes this long, though; with 900 all the other pipelines passed.
I have removed the TIMEOUT setting here and temporarily disabled these two unit tests on CPU; they will be added back once the performance issue is resolved.
UPDATE:
Removing it does not work, as GPU fails too. I have temporarily lowered the timeout to 120 to see whether it passes; if not, I will raise it to a suitable value.
LGTM
LGTM for the unit test execution time settings
PR types
New features
PR changes
Others
Describe
An improvement based on #51543.
Depends on #51411 (already merged). The current PR only runs after reverting the 0D change to lr made in #51562 (https://github.com/PaddlePaddle/Paddle/pull/51562/files#diff-0417a927e0148c22ecb722f950e2f9704d6e899e9899521f0a269b173ceb2de2); we need to confirm that this impact has been resolved. This PR mainly makes the following changes:
To support whole-graph training, both Tensor.backward (Variable.backward during graph construction) and Optimizer.step need a static-graph branch, so the two methods are modified as follows:
Variable.backward
uses the grad attributes attached to the parameters and updates them via the `optimizer.apply_gradients` API; the steps in this call that initialize lr and the like are executed by temporarily switching back to dynamic-graph mode via `dygraph_guard_if_declarative`. In addition, the following extra change was made: a suffix in the name of `dropout_state` is used to identify `dropout_state`, replacing the original approach of identifying it by a 0D shape, which caused the 0D lr to be filtered out as well.
Usage:
TODOs
[Optional] Find a more elegant approach for `_set_bool_attr` (the related logic has already been changed by [Dy2St] Fix clone for test state problem #51966). PCard-66972