
remove the AdamOptimizer、SGDOptimizer、MomentumOptimizer、ModelAverage、LookaheadOptimizer、FtrlOptimizer、DecayedAdagradOptimizer、DpsgdOptimizer in fluid and relocate the ExponentialMovingAverage、PipelineOptimizer、GradientMergeOptimizer and change optimizer base for LarsMomentumOptimizer and RecomputeOptimizer #55970

Merged
merged 80 commits into from
Aug 9, 2023

Conversation

longranger2
Contributor

@longranger2 longranger2 commented Aug 3, 2023

PR types

Others

PR changes

APIs

Description

The following optimizers are removed (a migration sketch follows this list):

  • remove the ModelAverage in paddle/fluid/optimizer.py and use python/paddle/incubate/optimizer/modelaverage.py to replace it.
  • remove the LookaheadOptimizer in paddle/fluid/optimizer.py and use the LookAhead in paddle/incubate/optimizer/lookahead.py to replace it.
  • remove the AdamOptimizer in paddle/fluid/contrib/optimizer.py and use python/paddle/optimizer/adam.py to replace it.
  • remove the SGDOptimizer in paddle/fluid/contrib/optimizer.py and use python/paddle/optimizer/sgd.py to replace it.
  • remove the MomentumOptimizer in paddle/fluid/contrib/optimizer.py and use python/paddle/optimizer/momentum.py to replace it.
  • remove the FtrlOptimizer in paddle/fluid/optimizer.py.
  • remove the DecayedAdagradOptimizer in paddle/fluid/optimizer.py.
  • remove the DpsgdOptimizer in paddle/fluid/optimizer.py.
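As an illustration of the replacement pattern, here is a minimal sketch (not code from this PR; the model and hyperparameters are placeholder assumptions) of how code that previously built fluid.optimizer.AdamOptimizer now constructs the 2.0 optimizer directly:

```python
import paddle

model = paddle.nn.Linear(10, 1)  # placeholder model for illustration

# Before (removed in this PR):
# import paddle.fluid as fluid
# opt = fluid.optimizer.AdamOptimizer(learning_rate=1e-3)

# After: the 2.0 API in python/paddle/optimizer/adam.py
opt = paddle.optimizer.Adam(learning_rate=1e-3, parameters=model.parameters())
```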

The following optimizers are relocated:

  • relocate the ExponentialMovingAverage in paddle/fluid/optimizer.py
  • relocate the PipelineOptimizer in paddle/fluid/optimizer.py
  • relocate the GradientMergeOptimizer in paddle/fluid/optimizer.py

The following optimizers are relocated and have their optimizer base class updated:

  • relocate the LarsMomentumOptimizer in paddle/fluid/optimizer.py and change optimizer base
  • relocate the RecomputeOptimizer in paddle/fluid/optimizer.py and change optimizer base

@paddle-bot

paddle-bot bot commented Aug 3, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI first. See the Paddle CI Manual for details.

@paddle-bot paddle-bot bot added contributor External developers status: proposed labels Aug 3, 2023
@longranger2 longranger2 changed the title remove the adam、sgd、momentum in fluid remove the AdamOptimizer、SGDOptimizer、MomentumOptimizer in fluid Aug 3, 2023
@longranger2 longranger2 changed the title remove the AdamOptimizer、SGDOptimizer、MomentumOptimizer in fluid remove the AdamOptimizer、SGDOptimizer、MomentumOptimizer in fluid and change optimizer base for LarsMomentumOptimizer and RecomputeOptimizer Aug 4, 2023
@longranger2 longranger2 changed the title remove the AdamOptimizer、SGDOptimizer、MomentumOptimizer in fluid and change optimizer base for LarsMomentumOptimizer and RecomputeOptimizer remove the AdamOptimizer、SGDOptimizer、MomentumOptimizer、ModelAverage、LookaheadOptimizer、FtrlOptimizer、DecayedAdagradOptimizer、DpsgdOptimizer in fluid and relocate the ExponentialMovingAverage、PipelineOptimizer、GradientMergeOptimizer and change optimizer base for LarsMomentumOptimizer and RecomputeOptimizer Aug 5, 2023
@longranger2 longranger2 mentioned this pull request Aug 5, 2023
Contributor

@XieYunshen XieYunshen left a comment

LGTM
Unit test deletions

found_inf = self._get_auxiliary_var('found_inf')

if found_inf:
    inputs['SkipUpdate'] = found_inf
Contributor

This was flagged by the test_mixed_precision unit test. Currently only fluid.optimizer.Adam and paddle.optimizer.AdamW implement the skip-update-on-inf policy (see paddle.static.amp.decorator), but the Adam OP does support this input, so support for the policy is added to paddle.optimizer.Adam here.
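For context, a minimal sketch (assumed usage, not code from this PR) of how a static-graph program typically wires paddle.optimizer.Adam through the AMP decorator, which is where the found_inf / SkipUpdate path comes into play; the network below is a placeholder:

```python
import paddle

paddle.enable_static()

# Placeholder network for illustration
x = paddle.static.data(name='x', shape=[None, 10], dtype='float32')
y = paddle.static.data(name='y', shape=[None, 1], dtype='float32')
pred = paddle.static.nn.fc(x, size=1)
loss = paddle.mean(paddle.nn.functional.square_error_cost(pred, y))

opt = paddle.optimizer.Adam(learning_rate=1e-3)
# The AMP decorator maintains the loss-scaling state and exposes found_inf,
# which lets the optimizer skip the parameter update on inf/nan gradients.
opt = paddle.static.amp.decorate(optimizer=opt, use_dynamic_loss_scaling=True)
opt.minimize(loss)
```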

@@ -609,12 +609,6 @@ def test_sharding_weight_decay(self):
'c_reduce_sum',
'c_reduce_sum',
'c_sync_comm_stream',
'scale',
Contributor

This is because paddle.optimizer.Momentum fuses L2Decay into the OP, so the extra scale + sum operations no longer exist.
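For illustration, a minimal sketch (an assumed configuration, not code from this PR) of the case where the fusion applies: passing a paddle.regularizer.L2Decay as weight_decay lets the momentum OP absorb the decay instead of emitting separate scale/sum ops:

```python
import paddle

model = paddle.nn.Linear(10, 1)  # placeholder model for illustration

# L2Decay is folded into the momentum OP itself, so the program
# contains no standalone 'scale' + 'sum' ops for regularization.
opt = paddle.optimizer.Momentum(
    learning_rate=0.01,
    momentum=0.9,
    parameters=model.parameters(),
    weight_decay=paddle.regularizer.L2Decay(1e-4),
)
```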

        return optimizer

    def get_optimizer(self):
        optimizer = paddle.optimizer.Lamb(
Contributor

Some unit tests for the 2.0 optimizers (paddle.optimizer.xxx) are deleted here. Is that because they already exist in the v2 unit test file? What is the relationship between these two test files?

Contributor Author

test_imperative_optimizer.py is used to test the legacy optimizers, while test_imperative_optimizer_v2.py mainly tests the 2.0 optimizers. When test_imperative_optimizer.py was first modified, the optimizers were only replaced with their 2.0 implementations but were not removed, which is why the two files have become more and more similar. We will consider deleting test_imperative_optimizer.py later.

9.8336181640625,
8.22379207611084,
8.195695877075195,
10.508796691894531,
Contributor

The 2.0 paddle.optimizer.Momentum automatically fuses L2Decay, which produces a numerical difference from fluid in that case. With other decay types or no decay the results are identical; this unit test falls into the former case.

9.569124221801758,
8.251557350158691,
8.513609886169434,
10.603094100952148,
Contributor

Same as above.

9.559739112854004,
8.430597305297852,
8.109201431274414,
10.224763870239258,
Contributor

Same as above.

@@ -68,7 +68,7 @@ def test_trainable(self):
self.check_trainable(
test_trainable,
feed_dict,
op_count={'adam': 1, 'scale': 0, 'mul_grad': 0},
op_count={'adam': 1, 'scale': 0, 'mul_grad': 1},
Contributor

In theory this spot should not change, and the adamax check just below has also been modified incorrectly. After investigation, the cause is that in static graph mode the Parameter's stop_gradient attribute is not set correctly (fluid.Optimizer checks the trainable attribute rather than stop_gradient, while paddle.Optimizer does the opposite).

A separate follow-up PR is needed to fix this issue.
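A minimal sketch (an illustration based on the comment above, not code from this PR) of why the two attributes can diverge when freezing a parameter in static graph mode:

```python
import paddle

paddle.enable_static()

x = paddle.static.data(name='x', shape=[None, 10], dtype='float32')
out = paddle.static.nn.fc(x, size=1)

param = paddle.static.default_main_program().global_block().all_parameters()[0]

# The legacy fluid optimizers skip parameters whose trainable flag is False,
# while paddle.optimizer.* skips parameters whose stop_gradient flag is True.
# Setting only one of the two makes the two optimizer families disagree.
param.trainable = False      # respected by fluid.optimizer.*
param.stop_gradient = True   # respected by paddle.optimizer.*
```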

Contributor

@zoooo0820 zoooo0820 left a comment

LGTM, the trainable issue will be addressed separately.

Contributor

@zhangting2020 zhangting2020 left a comment

LGTM for amp

Contributor

@6clc 6clc left a comment

LGTM

Contributor

@jeff41404 jeff41404 left a comment

LGTM

Contributor

@sunzhongkai588 sunzhongkai588 left a comment

LGTM

  • Does the corresponding Chinese documentation for ExponentialMovingAverage need to be moved as well? If so, please also open a PR on the Chinese docs side @longranger2
  • The docs preview CI currently has a bug, so the preview is temporarily unavailable; if there is any problem with the preview later I will reply under this comment~

@longranger2
Contributor Author

longranger2 commented Aug 8, 2023

  • Does the corresponding Chinese documentation for ExponentialMovingAverage need to be moved as well? If so, please also open a PR on the Chinese docs side @longranger2

The Chinese documentation for ExponentialMovingAverage should not need to be moved. Although ExponentialMovingAverage has been relocated, the exposed interface is still paddle.static.ExponentialMovingAverage, which is consistent with the documentation~ @sunzhongkai588
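For reference, a minimal usage sketch (assumed from the public paddle.static.ExponentialMovingAverage interface, not code from this PR) showing that the entry point is unchanged by the relocation:

```python
import paddle

paddle.enable_static()

x = paddle.static.data(name='x', shape=[None, 10], dtype='float32')
loss = paddle.mean(paddle.static.nn.fc(x, size=1))

opt = paddle.optimizer.Adam(learning_rate=1e-3)
opt.minimize(loss)

# Same entry point before and after the relocation.
ema = paddle.static.ExponentialMovingAverage(decay=0.999)
ema.update()  # appends the EMA update ops after the optimizer ops

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())

# During evaluation, temporarily swap in the averaged parameters.
with ema.apply(exe):
    pass  # run evaluation here
```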


Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment

LGTM

@jeff41404 jeff41404 merged commit 723c6f7 into PaddlePaddle:develop Aug 9, 2023
@longranger2 longranger2 deleted the adam_sgd_momentum branch August 9, 2023 03:20