
[Hackathon No.28] implement logcumsumexp #42267

Merged (27 commits) on Jun 10, 2022

Conversation

tiancaishaonvjituizi
Contributor

PR types

New features

PR changes

APIs

Describe

Implement logcumsumexp

Hackathon issue link: #40305
Hackathon PR link: PaddlePaddle/community#82

@paddle-bot-old
Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@paddle-bot-old paddle-bot-old bot added contributor External developers status: proposed labels Apr 26, 2022
@tiancaishaonvjituizi tiancaishaonvjituizi changed the title implement logcumsumexp [Hackathon No.28] implement logcumsumexp Apr 26, 2022
auto reducer = Reducer();
ScanKernel<T, Context, Reducer>(
dev_ctx, x, axis, flatten, exclusive, reverse, reducer, out);
}
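The hunk above passes a `Reducer` into a generic `ScanKernel`. As a rough illustration (hypothetical Python names, not Paddle's actual API), a reducer-parameterized inclusive scan works like this:

```python
from typing import Callable, List

def inclusive_scan(xs: List[float],
                   op: Callable[[float, float], float]) -> List[float]:
    # Generic inclusive prefix scan: out[i] = op(out[i-1], xs[i]).
    out: List[float] = []
    for x in xs:
        out.append(x if not out else op(out[-1], x))
    return out

# cumsum is a scan with +; logcumsumexp swaps in a log-add-exp reducer.
assert inclusive_scan([1.0, 2.0, 3.0], lambda a, b: a + b) == [1.0, 3.0, 6.0]
```

Parameterizing the scan on the binary operation is what lets cumsum and logcumsumexp share one kernel.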
Contributor Author

@tiancaishaonvjituizi tiancaishaonvjituizi Apr 26, 2022

Everything in this file above this line was moved from cumsum_kernel.cc, with a Reducer parameter added.

ma);
return cmp_lt(ma, Eigen::NumTraits<T>::lowest()) ? initialize() : logsumexp;
}
};
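The fragment above guards the log-sum-exp result when the running maximum sits at the type's lowest value. A minimal Python sketch of the same max trick (a hypothetical helper, not the Eigen functor itself):

```python
import math

def logaddexp(a: float, b: float) -> float:
    # Stable log(exp(a) + exp(b)) via the max trick:
    #   log(exp(a) + exp(b)) = m + log(exp(a - m) + exp(b - m)),  m = max(a, b)
    m = max(a, b)
    if m == float('-inf'):  # both operands are -inf: the exp sum is 0
        return float('-inf')
    return m + math.log(math.exp(a - m) + math.exp(b - m))

# The naive form log(exp(700) + exp(700)) overflows; the stable one does not.
assert abs(logaddexp(700.0, 700.0) - (700.0 + math.log(2.0))) < 1e-9
```

Subtracting the maximum keeps both exponentials in [0, 1], so neither overflows; the explicit guard covers the degenerate all-`-inf` case the C++ code checks with `lowest()`.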
};

template <typename T>
struct Identity<T, LogAddExp> {
Contributor Author

The identity element of the monoid formed by the binary operation LogAddExp.
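Concretely, that identity is negative infinity: folding it into log-add-exp leaves the other operand unchanged. A small check in plain Python (hypothetical helper names):

```python
import math

NEG_INF = float('-inf')

def logaddexp(a: float, b: float) -> float:
    # log(exp(a) + exp(b)), computed stably with the max trick.
    m = max(a, b)
    if m == NEG_INF:
        return NEG_INF
    return m + math.log(math.exp(a - m) + math.exp(b - m))

# logaddexp(-inf, x) = log(exp(-inf) + exp(x)) = log(exp(x)) = x
assert logaddexp(NEG_INF, 3.5) == 3.5
assert logaddexp(2.0, NEG_INF) == 2.0
```

This is the value `Identity<T, LogAddExp>` has to supply so that exclusive scans and empty prefixes behave correctly.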

}
return;
}

Contributor Author

@tiancaishaonvjituizi tiancaishaonvjituizi Apr 26, 2022

The thrust implementation was removed, because cub's documentation shows its prefix scan is much faster than thrust's, and TF does not use thrust either.

Contributor

  1. It is recommended that this PR not modify the implementation of existing operators, as that may introduce unknown precision/performance issues. Also, the RFC document states that other modules are not affected (positively or negatively).
  2. You can submit that change as a separate PR and verify op performance and model precision through the CE regression tests.

Contributor Author

OK.

Contributor

@tiancaishaonvjituizi You can first submit a separate PR to speed up cumsum. You are also welcome to sign up for the [PFCC-Roadmap] operator performance optimization program; see #42286 for details.

bool exclusive,
bool reverse,
MetaTensor* out) {
void CumInferMeta(const MetaTensor& x,
Contributor Author

Cumsum and Logcumsumexp reuse the same infer meta function.

return x


class TestLogcumsumexpOp(unittest.TestCase):
Contributor Author

Gradients are not checked yet, because importing OpTest raises an error:

Traceback (most recent call last):
  File "test_logcumsumexp_op.py", line 24, in <module>
    from op_test import OpTest
  File "/home/dev/files/repos/Paddle2/python/paddle/fluid/tests/unittests/op_test.py", line 40, in <module>
    from paddle.fluid.tests.unittests.testsuite import (
ModuleNotFoundError: No module named 'paddle.fluid.tests'

Am I using this the wrong way? I could not find any related documentation either.

Contributor

How are you running the unit test? Are you running ctest -R test_logcumsumexp_op in the build directory? https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/dev_guides/api_contributing_guides/new_python_api_cn.html#yunxingdanyuanceshi

Contributor Author

Nice, I had not found that document before.

Contributor Author

How are you running the unit test? Are you running ctest -R test_logcumsumexp_op in the build directory? https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/dev_guides/api_contributing_guides/new_python_api_cn.html#yunxingdanyuanceshi

@luotao1 FYI, several key links in that document are broken 😂; Paddle's attention to documentation could use some improvement. I submitted a PR at PaddlePaddle/docs#4742. I will continue updating this PR next.

@paddle-bot-old

paddle-bot-old bot commented May 6, 2022

The format inspection passed. Your PR will be reviewed by experts of Paddle and developers from the open-source community. Stay tuned.

Contributor

@luotao1 luotao1 left a comment

The review is not finished; I will continue tomorrow.

paddle/utils/variant.h (outdated, resolved)
paddle/fluid/operators/cum_op.cc (outdated, resolved)
paddle/phi/kernels/cpu/cum_kernel.cc (outdated, resolved)
paddle/phi/kernels/gpu/cum_kernel.cu (resolved)
python/paddle/tensor/math.py (resolved)
python/paddle/tensor/math.py (resolved)
@tiancaishaonvjituizi
Contributor Author

@luotao1 The unit tests have been added, but the gradient check fails for precision reasons. Is there a way to raise the threshold slightly? check_grad's max_relative_error parameter is forcibly overwritten inside the function, so setting it has no effect.

paddle/fluid/operators/cum_op.cc (outdated, resolved)
python/paddle/tensor/math.py (outdated, resolved)
@luotao1
Contributor

luotao1 commented May 10, 2022

The unit tests have been added, but the gradient check fails for precision reasons. Is there a way to raise the threshold slightly? check_grad's max_relative_error parameter is forcibly overwritten inside the function, so setting it has no effect.

def test_check_grad_normal(self):
    self.check_grad(
        ['X', 'Filter'],
        'Out',
        max_relative_error=0.06,
        check_dygraph=False)

We have many places internally that use max_relative_error; you can search for examples.

@tiancaishaonvjituizi
Contributor Author

tiancaishaonvjituizi commented May 12, 2022

We have many places internally that use max_relative_error; you can search for examples.

@luotao1

So should I add logcumsumexp to NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST? From the name (NEED_FIX) I assumed it was a temporary patch.

if self.dtype == np.float64 and \
        self.op_type not in op_threshold_white_list.NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST:
    numeric_grad_delta = 1e-5
    max_relative_error = 1e-7

Also, this code forcibly overwrites numeric_grad_delta and max_relative_error without any notice, violating the user's intent without their knowledge, which I think should not be encouraged. If it is acceptable, I will submit another PR to change this behavior to raise an exception.

@luotao1
Contributor

luotao1 commented May 12, 2022

NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST

It is indeed a temporary patch (the existing entries in the list will be fixed).

this code forcibly overwrites numeric_grad_delta and max_relative_error without any notice, violating the user's intent without their knowledge, which I think should not be encouraged

This is required by the Op precision standard: in unit tests, atol, rtol, eps, and max_relative_error may not be raised on one's own.

When checking forward outputs and backward gradients in op unit tests, there was a problem of tests passing by inflating thresholds. To better guarantee op quality, this rule was introduced and the corresponding checks were added to CI.

Why set numeric_grad_delta = 1e-5 and max_relative_error = 1e-7?

See this standard: upgrading op unit test precision to float64.

The overall op development and unit test standards are being consolidated into the API testing and acceptance standard @DDDivano

To summarize: to guarantee op precision, inflating thresholds to pass unit tests is in principle not allowed (i.e., this is transparent to users). If a test really cannot pass and max_relative_error is used, CI will intercept it and a dedicated reviewer will audit the precision.

@tiancaishaonvjituizi
Contributor Author

tiancaishaonvjituizi commented May 13, 2022

To summarize: to guarantee op precision, inflating thresholds to pass unit tests is in principle not allowed (i.e., this is transparent to users). If a test really cannot pass and max_relative_error is used, CI will intercept it and a dedicated reviewer will audit the precision.

What I find inappropriate is that it should not ignore the user's intent without any notice. The correct behavior would be: if the user specifies max_relative_error but the op is not in that list, raise an error or a warning telling the user "you should not modify max_relative_error; we will override max_relative_error to 1e-7; if you really need to change it, follow these steps ...", rather than overriding it without informing the user. This kind of carelessness causes unnecessary debugging costs downstream and is not a practice to be encouraged in engineering (it directly violates the principle of least astonishment).
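To illustrate the suggestion, here is a small sketch (hypothetical function and parameter names, not Paddle's actual code) of surfacing the override instead of applying it silently:

```python
import warnings

def resolve_tolerance(user_value, forced_value, allow_override=False):
    # Hypothetical sketch: instead of silently replacing a user-supplied
    # tolerance, warn so the caller knows the override happened.
    if user_value is not None and user_value != forced_value and not allow_override:
        warnings.warn(
            f"max_relative_error={user_value} is ignored; the framework "
            f"enforces {forced_value}. See the FP64 precision policy for "
            f"how to request an exception.")
        return forced_value
    return user_value if user_value is not None else forced_value

assert resolve_tolerance(0.06, 1e-7) == 1e-7          # overridden, but with a warning
assert resolve_tolerance(0.06, 1e-7, True) == 0.06    # explicit opt-out
```

The behavior is unchanged by default; the only difference is that the user is told their argument was discarded.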

logcumsumexp(x)_{ij} = \log \sum_{i=0}^{j} \exp(x_{ij})

Note:
The first element of the result is the same of the first element of the input.
Contributor

The preview here does not trigger the Note style; I suggest adding an indentation.
(screenshot of the rendered docs)

Contributor Author

OK, fixed.


.. math::

logcumsumexp(x)_{ij} = \log \sum_{i=0}^{j} \exp(x_{ij})
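The formula above is a running log-sum-exp. A plain-Python sketch (a hypothetical helper, not the Paddle kernel) shows why it is computed as a stable scan rather than as log(cumsum(exp(x))):

```python
import math

def logcumsumexp(xs):
    # Running log-sum-exp: out[j] = log(sum_{i<=j} exp(xs[i])),
    # computed as a prefix scan with the stable log-add-exp operation.
    out = []
    acc = float('-inf')  # identity element of log-add-exp
    for x in xs:
        m = max(acc, x)
        acc = m if m == float('-inf') else \
            m + math.log(math.exp(acc - m) + math.exp(x - m))
        out.append(acc)
    return out

xs = [1000.0, 1000.0, 1000.0]
# Naive log(sum(exp(x))) overflows here; the scan stays finite.
assert abs(logcumsumexp(xs)[2] - (1000.0 + math.log(3.0))) < 1e-9
assert logcumsumexp([0.5])[0] == 0.5  # first output equals first input
```

The last assertion is exactly the Note in the docstring: the first element of the result equals the first element of the input.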
Contributor

Why does the English documentation have the formula but the Chinese documentation does not?

Contributor Author

Fixed in PaddlePaddle/docs#4807.

@@ -2908,7 +2908,7 @@ def cumsum(x, axis=None, dtype=None, name=None):
The cumulative sum of the elements along a given axis.

**Note**:
The first element of the result is the same of the first element of the input.
The first element of the result is the same as the first element of the input.
Contributor

What I meant was just to add an indentation, like this:

Note:
    The first element of the result is the same as the first element of the input. 

Note does not need to be bold.

Contributor

Also, for documentation changes you can add ;test=document_fix to the commit message to skip the code-check CI.

Contributor Author

Fixed; please take another look @Ligoml

@luotao1
Contributor

luotao1 commented May 27, 2022

2022-05-27 17:10:12 + git merge --no-edit develop
2022-05-27 17:10:12 fatal: refusing to merge unrelated histories

Please merge the latest develop branch and resubmit.

@tiancaishaonvjituizi
Contributor Author

Please merge the latest develop branch and resubmit.

OK, merged. @luotao1

Ligoml
Ligoml previously approved these changes May 30, 2022
Contributor

@Ligoml Ligoml left a comment

LGTM for docs

with fluid.program_guard(fluid.Program()):

    with self.assertRaises(TypeError):
        data_np = np.random.random((100, 100), dtype=np.int32)
Contributor

@luotao1 luotao1 May 30, 2022

Unit test timeouts are still frequent; within each PR we retry three times and only report failure if all three fail. If the test time cannot be reduced, you can use the Timeout property to increase the unit test time limit (the default is 15s):

list(REMOVE_ITEM TEST_OPS test_warpctc_op)

py_test_modules(test_warpctc_op MODULES test_warpctc_op)
set_tests_properties(test_warpctc_op PROPERTIES TIMEOUT 120)

@luotao1
Contributor

luotao1 commented Jun 9, 2022

Please use pre-commit to fix the code format issues in the static-check pipeline.

@tiancaishaonvjituizi
Contributor Author

@luotao1 Fixed.

Contributor

@Ligoml Ligoml left a comment

LGTM for docs

Contributor

@XieYunshen XieYunshen left a comment

LGTM for set_tests_properties(test_logcumsumexp_op PROPERTIES TIMEOUT 30)

@luotao1 luotao1 merged commit 19a7524 into PaddlePaddle:develop Jun 10, 2022
@tiancaishaonvjituizi tiancaishaonvjituizi deleted the logcumsumexp branch June 10, 2022 10:23
@tiancaishaonvjituizi
Contributor Author

Great!
