add swa_utils modules #9781

Merged
merged 43 commits into from
Jan 29, 2023
Conversation

process852
Contributor

This PR mainly adds swa_utils to match PyTorch, along with some other small changes.
For details, please see the two issues below:
swa_utils details
oneflow.profiler.kineto_available
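
For reviewers unfamiliar with these utilities, a minimal usage sketch follows. It mirrors the torch.optim.swa_utils API that this PR ports; the import path, model, and data below are illustrative assumptions rather than code from the PR.

import oneflow as flow
from oneflow.optim.swa_utils import AveragedModel, SWALR, update_bn  # assumed path, mirroring torch.optim.swa_utils

model = flow.nn.Sequential(
    flow.nn.Linear(4, 8), flow.nn.BatchNorm1d(8), flow.nn.ReLU(), flow.nn.Linear(8, 2)
)
optimizer = flow.optim.SGD(model.parameters(), lr=0.1)
loss_fn = flow.nn.MSELoss()
swa_model = AveragedModel(model)               # keeps a running average of the weights
swa_scheduler = SWALR(optimizer, swa_lr=0.05)  # anneals the lr toward a constant SWA value

loader = [(flow.randn(8, 4), flow.randn(8, 2)) for _ in range(10)]  # toy data

for epoch in range(5):
    for x, y in loader:
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    swa_model.update_parameters(model)  # fold the current weights into the average
    swa_scheduler.step()

update_bn(loader, swa_model)  # recompute BatchNorm running stats for the averaged model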

from oneflow.test_utils.automated_test_util import *


class TestLRScheduler(flow.unittest.TestCase):
Contributor

If this test was ported from elsewhere, it would be good to add a link to the source as well.

from oneflow.nn import Module
from oneflow.nn.optimizer.lr_scheduler import LRScheduler

__all__ = ["AveragedModel", "update_bn", "SWALR"]
Contributor

A link to the reference implementation could also be added here.

use_buffers (bool): if ``True``, it will compute running averages for
both the parameters and the buffers of the model. (default: ``False``)

Example:
Contributor

This code example does not follow OneFlow's documentation conventions: we require that any code snippet in the docs can be copied out and run on its own, but this one does not even include the imports.
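
For reference, a self-contained version of such an example might look like the sketch below; the import path is an assumption, and only the overall shape (imports included, runnable as-is) is the point.

import oneflow as flow
from oneflow.optim.swa_utils import AveragedModel  # assumed path, mirroring torch.optim.swa_utils

model = flow.nn.Sequential(flow.nn.Linear(3, 3), flow.nn.BatchNorm1d(3))
# use_buffers=True also averages buffers such as the BatchNorm running stats,
# not only the learnable parameters (the default is False)
swa_model = AveragedModel(model, use_buffers=True)

for _ in range(3):
    x = flow.randn(8, 3)
    model(x)                           # stand-in for a real training step
    swa_model.update_parameters(model)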

big_suite = unittest.TestSuite(suites_list)
return big_suite


Contributor

Please check how long this test takes; if it takes a long time, it can be moved into oneflow/python/oneflow/test/expensive.

Contributor Author

I tested this locally and it only takes about 5 to 6 seconds to finish, so it is not very time-consuming.

@@ -126,6 +126,11 @@ def forward(self, x):
if self.training and self.track_running_stats:
if self.num_batches_tracked is not None:
self.num_batches_tracked.add_(1)
# add cumulative moving average to match pytorch
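
For context on the `# add cumulative moving average to match pytorch` line above: in PyTorch, when `momentum` is `None`, BatchNorm derives the running-stats update factor from `num_batches_tracked` instead of using a fixed momentum. A tiny illustrative helper (hypothetical, just to state the rule):

def bn_average_factor(momentum, num_batches_tracked):
    # momentum is None -> cumulative moving average: every batch seen so far
    # contributes equally to the running statistics
    if momentum is None:
        return 1.0 / float(num_batches_tracked)
    # otherwise -> exponential moving average with the fixed momentum
    return momentum

print(bn_average_factor(None, 10))  # 0.1, shrinks as more batches are tracked
print(bn_average_factor(0.1, 10))   # 0.1, fixed regardless of the batch count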
Contributor
@daquexian daquexian Jan 29, 2023

This comment line can be deleted; it is better suited as a PR comment.

Contributor Author

OK.

@process852 process852 removed the request for review from oneflow-ci-bot January 29, 2023 06:12
@process852 process852 requested review from oneflow-ci-bot and removed request for oneflow-ci-bot January 29, 2023 07:14
@mergify mergify bot merged commit 1442b30 into Oneflow-Inc:master Jan 29, 2023
preactivation_var = preactivation_var - preactivation_mean ** 2

update_bn(dl_xy, dnn, device=x.device)
self.assertTrue(flow.allclose(preactivation_mean, dnn.bn.running_mean))
Contributor

CI occasionally fails here; you could raise rtol a bit.

Contributor Author

For now I have set rtol uniformly to 1e-3 in the mean-computation checks.

update_bn(dl_xy, dnn, device=x.device)
self.assertTrue(flow.allclose(preactivation_mean, dnn.bn.running_mean))
self.assertTrue(
flow.allclose(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=0)
Contributor

What is the reason for specifying atol here rather than rtol?

Contributor Author

OK, I will look into it.

Contributor Author

What is the reason for specifying atol here rather than rtol?

When computing the variance here, the errors I observed earlier were fairly large (around 1e-2 to 1e-3) and the test often failed, so I took the values directly from the torch test case: atol=1e-1, rtol=0. To avoid the occasional CI failures, rtol can be raised to 1e-1 to widen the tolerance.

for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
self.assertTrue(flow.max(flow.abs(p_avg - p_swa)) < 1e-5)
for b_avg, b_swa in zip(dnn.buffers(), averaged_dnn.module.buffers()):
self.assertTrue(flow.max(flow.abs(b_avg - b_swa)) < 1e-5)
Contributor

Manually comparing the max diff like this is also unusual. Unless there is a specific reason, np.allclose(..., rtol=xxx, atol=xxx) should be used instead. Of rtol and atol, rtol is the more commonly used one; atol is mainly for tolerating errors caused by limited data-type precision.
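
For illustration, a small sketch of the check allclose performs versus the manual max-diff comparison (the tolerance values here are only examples, not the ones finally chosen):

import numpy as np

a = np.array([1.0000, 2.0000])
b = np.array([1.0001, 2.0003])

# allclose passes when, elementwise, |a - b| <= atol + rtol * |b|
print(np.allclose(a, b, rtol=1e-3, atol=1e-8))  # True: relative error is about 1e-4
print(np.max(np.abs(a - b)) < 1e-5)             # False: the manual max-diff check is much stricter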

Contributor Author

That is because I had not noticed the flow.allclose function earlier; I changed the later ones but probably forgot these. I will update them all together shortly.
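
What the updated assertions in the snippet above would roughly become (the tolerances are placeholders until the final values are settled):

for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
    self.assertTrue(flow.allclose(p_avg, p_swa, rtol=1e-3, atol=1e-5))
for b_avg, b_swa in zip(dnn.buffers(), averaged_dnn.module.buffers()):
    self.assertTrue(flow.allclose(b_avg, b_swa, rtol=1e-3, atol=1e-5))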

Comment on lines +492 to +507
def suite():
test_classes_to_run = [TestLRScheduler, TestSWAUtils]
loader = unittest.TestLoader()

suites_list = []
for test_class in test_classes_to_run:
suite = loader.loadTestsFromTestCase(test_class)
suites_list.append(suite)

big_suite = unittest.TestSuite(suites_list)
return big_suite


if __name__ == "__main__":
runner = unittest.TextTestRunner()
runner.run(suite())
Contributor

What is the intent of this? Why not use the default unittest.main()?
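
For comparison, the default idiom the reviewer refers to, which discovers and runs every TestCase in the module automatically, is simply:

if __name__ == "__main__":
    unittest.main()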

Contributor

I changed it to unittest.main() in #9632; after testing, the behavior is the same.

Contributor Author

OK. However, when I pull the latest code into my local repo and test this script, an intermittent error still appears, shown as a Segmentation fault. So far I have traced it to the _test method of the TestLRScheduler test class, which is called by the test_multiplicative_lr method.

def _test(self, schedulers, targets, epochs=10):
    if isinstance(schedulers, LRScheduler):
        schedulers = [schedulers]
    for epoch in range(epochs):
        for param_group, target in zip(self.opt.param_groups, targets):
            self.assertTrue(
                flow.allclose(
                    flow.tensor(target[epoch]),
                    flow.tensor(param_group["lr"]),
                    atol=1e-6,
                    rtol=1e-5,
                ),
                msg="LR is wrong in epoch {}: expected {}, got {}".format(
                    epoch, target[epoch], param_group["lr"]
                ),
            )
        [scheduler.step() for scheduler in schedulers]

Without the test_multiplicative_lr test case, I ran test_swautils.py 20 times locally and it passed every time; with test_multiplicative_lr included it passes most of the time, but a Segmentation fault still shows up occasionally.

Using the core file produced under /var/lib/apport/coredump, the command gdb python <core file name> gives the following information:
[screenshot of the gdb output]

I then suspected that the [scheduler.step() for scheduler in schedulers] code might be the problem, so to see which update was going wrong I added print statements; in the failing runs, execution stops at the [scheduler.step() for scheduler in schedulers] line:

print("step before")
[scheduler.step() for scheduler in schedulers]
print("step after")

[screenshot of the console output from the failing run]

I checked the step method and it seems fine, so I am not sure whether the problem is in the C++ part. I would appreciate some guidance on the messages pointing into the C++ source.

Contributor

Let me debug this one; you can address the other comments first and get this PR merged first.

Contributor Author

OK.

mergify bot added a commit that referenced this pull request Feb 8, 2023
1. Replaced the original fabs-based error checks with the `flow.allclose` function
2. Removed the duplicated test case `test_multiplicative_lr` (two identical copies had been written)
3. Raised the rtol value in some of the flow.allclose calls to fix the intermittent failures #9781 (comment)

---------

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>