
[Dy2St] support train step in to_static #51693

Merged
merged 55 commits into PaddlePaddle:develop from SigureMo:train-step-siguremo on Apr 4, 2023

Conversation

SigureMo
Member

@SigureMo SigureMo commented Mar 15, 2023

PR types

New features

PR changes

Others

Describe

A refinement of #51543; depends on #51411 (already merged).

This PR could only run after reverting the 0D change to lr from #51562 (https://github.com/PaddlePaddle/Paddle/pull/51562/files#diff-0417a927e0148c22ecb722f950e2f9704d6e899e9899521f0a269b173ceb2de2); the impact needed to be confirmed. Resolved.

This PR mainly makes the following changes:

To support whole-graph training, both Tensor.backward (Variable.backward during graph construction) and Optimizer.step need a static-graph branch, so the two methods are modified as follows (a simplified sketch follows the list):

  • Variable.backward was previously only a placeholder; it now has real behavior: the computed grads are temporarily attached to the Variable's grad attribute;
  • Optimizer.step previously supported the dynamic graph only; it now also supports static-graph construction under dynamic-to-static. It consumes the grad attributes attached earlier by Variable.backward and updates parameters via the optimizer.apply_gradients API; the steps inside this call that initialize the lr and similar state are temporarily executed back in dynamic-graph mode via dygraph_guard_if_declarative.
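As a rough illustration, during static-graph construction these two branches effectively reduce to the following minimal sketch built on public static-graph APIs (paddle.static.append_backward plus Optimizer.apply_gradients). This is a simplification for orientation, not the PR's literal code path:

import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[None, 10], dtype='float32')
    loss = paddle.nn.Linear(10, 10)(x).mean()

    # Variable.backward (static branch): compute the grads for the loss;
    # the PR temporarily attaches each grad to its parameter's `grad` attribute.
    params_grads = paddle.static.append_backward(loss)

    # Optimizer.step (static branch): consume the (param, grad) pairs and
    # append the parameter-update ops via apply_gradients.
    sgd = paddle.optimizer.SGD(learning_rate=0.001)
    sgd.apply_gradients(params_grads)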

The following additional changes were also made (a conceptual sketch of the lr sync follows the list):

  • To support LR Scheduler scheduling under dynamic-to-static, the lr scheduler's current value is synced into lr_var before the run_program op is invoked;
  • The name of dropout_state in rnn is changed so that dropout_state is identified by name, replacing the old approach of identifying it as 0D, which caused the 0D lr to be filtered out as well.
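A conceptual sketch of the lr sync; the function name `sync_lr` and the `lr_var_name`/`place` parameters are illustrative stand-ins for the PR's internals, not its exact identifiers:

import numpy as np
import paddle

def sync_lr(scheduler, lr_var_name, place):
    # Before each run_program call, copy the scheduler's current value into
    # the persistable lr variable so the static graph reads an up-to-date lr.
    lr_value = np.array([scheduler.last_lr], dtype='float32')
    tensor = paddle.static.global_scope().var(lr_var_name).get_tensor()
    tensor.set(lr_value, place)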

Usage:

import random

import numpy as np
import paddle

paddle.set_device('cpu')
# paddle.set_device('gpu:1')

paddle.seed(1010)
np.random.seed(1010)
random.seed(1010)


class TinyModel(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.linear = paddle.nn.Linear(in_features=10, out_features=10)

    def forward(self, x):
        return self.linear(x)


def loss_fn(x):
    return x.mean()


def train_step(net, x, loss_fn, opt):
    out = net(x)
    loss = loss_fn(out)
    loss.backward()
    opt.step()
    opt.clear_grad()
    return loss


x = paddle.randn([100, 10])
net = TinyModel()
sgd = paddle.optimizer.SGD(0.001, parameters=net.parameters())

train_step = paddle.jit.to_static(train_step)

for i in range(10):
    train_step(net, x, loss_fn, sgd)

# Note: `lr_scheduler.step` should be called outside train_step; calling it inside train_step is not supported
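Continuing the example above, a hedged sketch of driving an LR scheduler under this pattern (the scheduler type here is an arbitrary choice; the key point is stepping it outside the compiled train_step, per the note above):

scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.001, step_size=3)
opt = paddle.optimizer.SGD(scheduler, parameters=net.parameters())

for i in range(10):
    train_step(net, x, loss_fn, opt)
    scheduler.step()  # stepped outside train_step, as noted above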

TODOs

  • Unit tests
    • all optimizers
    • test the lr scheduler
  • Performance benchmarking (ResNet50) and fixing the performance regression (does not block merging)
  • [Optional] Find a more elegant approach for _set_bool_attr (the related logic has since been changed by [Dy2St]Fix clone for test state problem #51966)
  • Fix the CI failures that occur when adding unit tests, currently tracked in 【Hackathon 4th 】add Trapezoid API && add Cumulative_trapezoid API #51195 (comment); once resolved, CI should pass fully (PR-CI-Py3 is already all green)
  • Determine the impact of double_grad; removing it entirely currently does not break any unit tests. Is it simply no longer needed? (unit tests pass)
  • Support the branch where Adam's param is a dict (currently only list is supported) (does not block merging)
  • Support Adam's multi_tensor branch?? (the use case is still unclear) (does not block merging)

PCard-66972

@paddle-bot

paddle-bot bot commented Mar 15, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

Aurelius84
Aurelius84 previously approved these changes Mar 29, 2023
Contributor

@Aurelius84 Aurelius84 left a comment


LGTM overall

-    is_persistable = len(tensor.shape) > 0
+    # non-persistable. See case of `dropout_state` in lstm api.
+    is_persistable = True
+    if "dropout_state" in tensor.name:
Contributor


Minor suggestion: this name could be defined as a constant, ideally with a special marker that both the RNN and dynamic-to-static modules can recognize, making it easy to remove uniformly once the RNN kernel is upgraded.

Member Author


Now identified via NON_PERSISTABLE_VAR_NAME_SUFFIX, which lives in dygraph.base. A hedged sketch of the resulting check follows (the suffix value shown is illustrative):
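# Illustrative only: the actual suffix string is the constant
# NON_PERSISTABLE_VAR_NAME_SUFFIX defined in dygraph.base.
NON_PERSISTABLE_VAR_NAME_SUFFIX = "__non_persistable"

def is_persistable(tensor):
    # Identify non-persistable vars (e.g. rnn's dropout_state) by name
    # suffix instead of by the old 0D-shape trick that also caught the lr.
    return not tensor.name.endswith(NON_PERSISTABLE_VAR_NAME_SUFFIX)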

@@ -27,6 +27,7 @@
 import multiprocessing
 import sys
 import logging
+import paddle
Contributor


Importing the top-level paddle here may risk a circular import? A relative import is recommended. As a convention, low-level modules should not import top-level modules, although the framework doesn't seem to follow this convention at the moment; plenty of places already import paddle directly.

Member Author


This import exists mainly to reach .backward.append_backward, but `from . import backward` itself triggers a circular import, so the import had to be moved inside the function. The workaround looks roughly like the sketch below (the function name is illustrative):
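def _compute_gradients(loss):
    # Deferred to call time: a module-level `from . import backward` in this
    # low-level module triggers a circular import, so import here instead.
    from paddle.fluid.backward import append_backward
    return append_backward(loss)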

2742195759
2742195759 previously approved these changes Mar 29, 2023
Contributor

@2742195759 2742195759 left a comment


LGTM

@@ -66,6 +66,8 @@ set_tests_properties(test_transformer PROPERTIES TIMEOUT 200)
 set_tests_properties(test_bmn PROPERTIES TIMEOUT 120)
 #set_tests_properties(test_mnist PROPERTIES TIMEOUT 120)
 set_tests_properties(test_build_strategy PROPERTIES TIMEOUT 120)
+set_tests_properties(test_train_step_resnet18_sgd PROPERTIES TIMEOUT 1200)
+set_tests_properties(test_train_step_resnet18_adam PROPERTIES TIMEOUT 1200)
Contributor


Do these two unit tests really need that much time?

Member Author

@SigureMo SigureMo Mar 29, 2023


Yeah, I tried 900 and it wasn't enough, so I bumped it to 1200. Only PR-CI-Windows-OPENBLAS takes this long, though; with 900 all the other pipelines passed.

Member Author

@SigureMo SigureMo Mar 29, 2023


Removed the TIMEOUT settings here and temporarily disabled these two unit tests on CPU; they will be re-enabled once the performance issue is resolved.

UPDATE:

Removing the timeout doesn't work; GPU can't pass either. Temporarily lowered it to 120 to see whether that passes; if not, I'll raise it to a suitable value.

@SigureMo SigureMo dismissed stale reviews from 2742195759 and Aurelius84 via 7183931 March 29, 2023 12:30
Contributor

@Aurelius84 Aurelius84 left a comment


LGTM

Contributor

@XieYunshen XieYunshen left a comment


LGTM for the unit test execution time settings.

@SigureMo SigureMo merged commit 7728efb into PaddlePaddle:develop Apr 4, 2023
@SigureMo SigureMo deleted the train-step-siguremo branch April 4, 2023 10:41