[Bug Report] sigmoid_cross_entropy_with_logits: automatic differentiation of the primitive-op decomposition disagrees with the op's backward kernel #64226
Comments
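The analysis in the issue description is slightly off: the derivation ignores the … used in the forward computation. Re-deriving it, the backward gradient is: … where … The corresponding fix PR: …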
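For reference, here is a derivation sketch that rests on an assumption this thread does not spell out: the forward pass computes the standard positive-weighted binary cross-entropy. In it, x is the logit, z the label, w the positive weight and σ the sigmoid; these symbol names are chosen here. It is not a reproduction of the formula in the fix PR. It shows how the weight enters the gradient through the factor 1 + (w − 1)z.

```latex
\begin{aligned}
\ell(x) &= -\bigl[w\,z\log\sigma(x) + (1-z)\log\bigl(1-\sigma(x)\bigr)\bigr] \\
        &= (1-z)\,x + \bigl(1 + (w-1)z\bigr)\log\bigl(1 + e^{-x}\bigr),
        \qquad \text{using } \log\sigma(x) = -\log\bigl(1+e^{-x}\bigr),\;
        \log\bigl(1-\sigma(x)\bigr) = -x - \log\bigl(1+e^{-x}\bigr) \\
\frac{\partial \ell}{\partial x} &= (1-z) + \bigl(1 + (w-1)z\bigr)\bigl(\sigma(x) - 1\bigr) \\
        &= \bigl(1 + (w-1)z\bigr)\,\sigma(x) - w\,z
\end{aligned}
```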
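The kernel's backward result is checked against a gradient computed numerically in numpy (see the source: op_test.py#L148-L323). The gradient of the decomposed op, in contrast, is obtained by automatic differentiation and checked against the kernel's backward result. My inference is that the kernel's backward computation is wrong. Verification follows. Running … gives:
W0515 05:27:49.445072 36810 gpu_resources.cc:119] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 12.4, Runtime API Version: 11.2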
W0515 05:27:49.449699 36810 gpu_resources.cc:164] device: 0, cuDNN Version: 8.1.
numeric :
[array([[ 3.77183495e-04, -2.20797240e-04, -3.14791652e-04, ...,
-2.02644761e-04, 2.45672779e-04, -6.17090210e-06],
[ 4.84435251e-04, 2.84202716e-04, 1.83931716e-05, ...,
6.29521346e-04, 5.20318281e-04, 2.33842612e-05],
[ 3.12574747e-04, -4.71084098e-04, 1.10182442e-04, ...,
6.98864401e-04, 2.33956572e-04, -7.56920161e-05],
...,
[ 7.15934865e-04, -3.74937504e-04, 3.26225586e-04, ...,
3.84216391e-05, -5.20641936e-04, -4.17575856e-04],
[ 1.96946960e-05, 3.88698082e-04, -2.81023718e-04, ...,
-5.38852117e-05, 3.67850861e-04, -1.84393860e-04],
[-9.46350590e-05, 1.44749951e-05, -2.59066396e-04, ...,
5.43415898e-04, 5.17161748e-05, 5.20940836e-04]])]
analytic_grads :
[array([[ 1.84299020e-04, -2.20797405e-04, -3.14791844e-04, ...,
-2.02644764e-04, 1.03724500e-04, -3.28235993e-04],
[ 1.30288861e-04, -4.92659295e-04, -9.86970736e-05, ...,
3.65435473e-04, 4.38155145e-04, -7.09606712e-04],
[-1.62492418e-04, -4.71084140e-04, 1.10182313e-04, ...,
1.25990168e-04, 1.87285167e-04, -5.22634377e-04],
...,
[-3.11346327e-05, -3.74937645e-04, 3.26225574e-04, ...,
-3.02651920e-04, -5.20642019e-04, -4.17576055e-04],
[-3.08952871e-04, 2.82421633e-04, -2.81023766e-04, ...,
-5.38854093e-05, 6.31943427e-05, -1.84394142e-04],
[-1.67904546e-04, -1.19940036e-05, -2.59066405e-04, ...,
3.36552688e-04, 2.25882243e-05, -9.09629301e-05]])]
max_relative_error :
0.005
.
----------------------------------------------------------------------
Ran 1 test in 2.453s
OK
But when I tighten this tolerance to 0.0005:
I0515 06:12:58.692179 17707 program_interpreter.cc:221] New Executor is Running.
I0515 06:12:58.693336 17707 interpreter_util.cc:652] Standalone Executor is Used.
numeric :
[array([[ 3.77183495e-04, -2.20797240e-04, -3.14791652e-04, ...,
-2.02644761e-04, 2.45672779e-04, -6.17090210e-06],
[ 4.84435251e-04, 2.84202716e-04, 1.83931716e-05, ...,
6.29521346e-04, 5.20318281e-04, 2.33842612e-05],
[ 3.12574747e-04, -4.71084098e-04, 1.10182442e-04, ...,
6.98864401e-04, 2.33956572e-04, -7.56920161e-05],
...,
[ 7.15934865e-04, -3.74937504e-04, 3.26225586e-04, ...,
3.84216391e-05, -5.20641936e-04, -4.17575856e-04],
[ 1.96946960e-05, 3.88698082e-04, -2.81023718e-04, ...,
-5.38852117e-05, 3.67850861e-04, -1.84393860e-04],
[-9.46350590e-05, 1.44749951e-05, -2.59066396e-04, ...,
5.43415898e-04, 5.17161748e-05, 5.20940836e-04]])]
analytic_grads :
[array([[ 1.84299020e-04, -2.20797405e-04, -3.14791844e-04, ...,
-2.02644764e-04, 1.03724500e-04, -3.28235993e-04],
[ 1.30288861e-04, -4.92659295e-04, -9.86970736e-05, ...,
3.65435473e-04, 4.38155145e-04, -7.09606712e-04],
[-1.62492418e-04, -4.71084140e-04, 1.10182313e-04, ...,
1.25990168e-04, 1.87285167e-04, -5.22634377e-04],
...,
[-3.11346327e-05, -3.74937645e-04, 3.26225574e-04, ...,
-3.02651920e-04, -5.20642019e-04, -4.17576055e-04],
[-3.08952871e-04, 2.82421633e-04, -2.81023766e-04, ...,
-5.38854093e-05, 6.31943427e-05, -1.84394142e-04],
[-1.67904546e-04, -1.19940036e-05, -2.59066405e-04, ...,
3.36552688e-04, 2.25882243e-05, -9.09629301e-05]])]
max_relative_error :
0.0005
F
======================================================================
FAIL: test_check_grad (__main__.TestSigmoidCrossEntropyWithLogitsOp4)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/paddle/test/deprecated/legacy_test/test_sigmoid_cross_entropy_with_logits_op.py", line 178, in test_check_grad
self.check_grad(['X'], 'Out', check_pir=True)
File "/paddle/build/test/legacy_test/op_test.py", line 2986, in check_grad
self.check_grad_with_place(
File "/paddle/build/test/legacy_test/op_test.py", line 3298, in check_grad_with_place
numeric_grads = self.check_grad_with_place_for_static(
File "/paddle/build/test/legacy_test/op_test.py", line 3089, in check_grad_with_place_for_static
self._assert_is_close(
File "/paddle/build/test/legacy_test/op_test.py", line 2942, in _assert_is_close
self.assertLessEqual(max_diff, max_relative_error, err_msg())
AssertionError: 0.0007811970012982192 not less than or equal to 0.0005 : Operator sigmoid_cross_entropy_with_logits error, Gradient Check On Place(cpu) variable X (shape: (64, 20), dtype: float64) max gradient diff 7.811970e-04 over limit 5.000000e-04, the first error element is 3, expected 5.481218e-04, but got 2.099690e-05.
----------------------------------------------------------------------
Ran 1 test in 0.521s
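FAILED (failures=1)
So the tolerance was loose enough that the error in the backward computation was never exposed. After the fix PR changes …, running with the 0.0005 tolerance gives:
W0515 06:13:53.214535 18318 gpu_resources.cc:119] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 12.4, Runtime API Version: 11.2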
W0515 06:13:53.220155 18318 gpu_resources.cc:164] device: 0, cuDNN Version: 8.1.
numeric :
[array([[ 3.77183495e-04, -2.20797240e-04, -3.14791652e-04, ...,
-2.02644761e-04, 2.45672779e-04, -6.17090210e-06],
[ 4.84435251e-04, 2.84202716e-04, 1.83931716e-05, ...,
6.29521346e-04, 5.20318281e-04, 2.33842612e-05],
[ 3.12574747e-04, -4.71084098e-04, 1.10182442e-04, ...,
6.98864401e-04, 2.33956572e-04, -7.56920161e-05],
...,
[ 7.15934865e-04, -3.74937504e-04, 3.26225586e-04, ...,
3.84216391e-05, -5.20641936e-04, -4.17575856e-04],
[ 1.96946960e-05, 3.88698082e-04, -2.81023718e-04, ...,
-5.38852117e-05, 3.67850861e-04, -1.84393860e-04],
[-9.46350590e-05, 1.44749951e-05, -2.59066396e-04, ...,
5.43415898e-04, 5.17161748e-05, 5.20940836e-04]])]
analytic_grads :
[array([[ 3.77183699e-04, -2.20797405e-04, -3.14791844e-04, ...,
-2.02644764e-04, 2.45672821e-04, -6.17087178e-06],
[ 4.84435362e-04, 2.84202718e-04, 1.83934196e-05, ...,
6.29521507e-04, 5.20318417e-04, 2.33842613e-05],
[ 3.12574821e-04, -4.71084140e-04, 1.10182313e-04, ...,
6.98864469e-04, 2.33956823e-04, -7.56920088e-05],
...,
[ 7.15934877e-04, -3.74937645e-04, 3.26225574e-04, ...,
3.84217363e-05, -5.20642019e-04, -4.17576055e-04],
[ 1.96948667e-05, 3.88698345e-04, -2.81023766e-04, ...,
-5.38854093e-05, 3.67850951e-04, -1.84394142e-04],
[-9.46347782e-05, 1.44751433e-05, -2.59066405e-04, ...,
5.43416127e-04, 5.17164533e-05, 5.20940836e-04]])]
max_relative_error :
0.0005
.
----------------------------------------------------------------------
Ran 1 test in 2.753s
OK
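The tolerance effect can be reproduced outside Paddle with a small pure-Python sketch, built on several assumptions. It takes the per-element loss to be (1 − z)x + (1 + (w − 1)z)·log(1 + e^(−x)). It scales the sum by a hypothetical factor of 1/1280, as if averaging over a 64 × 20 tensor, so the gradients are of the same order as in the logs. It also uses a simplified error metric in which gradients smaller than 1e-3 are compared by absolute difference; op_test.py's actual metric may differ. The gradient that ignores pos_weight stands in for a wrong backward formula. It is not claimed to be the kernel's code.

```python
import math

SCALE = 1.0 / 1280  # hypothetical: mimics averaging over a 64 x 20 tensor

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def weighted_bce(x, z, w):
    """Assumed per-element loss: (1 - z) * x + (1 + (w - 1) * z) * log(1 + exp(-x))."""
    softplus_neg_x = max(-x, 0.0) + math.log1p(math.exp(-abs(x)))  # stable log(1 + e^-x)
    return (1 - z) * x + (1 + (w - 1) * z) * softplus_neg_x

def total_loss(xs, zs, ws):
    return SCALE * sum(weighted_bce(x, z, w) for x, z, w in zip(xs, zs, ws))

def numeric_grad(xs, zs, ws, delta=0.005):
    """Central-difference estimate of d(total_loss)/dx_i for every i."""
    grads = []
    for i in range(len(xs)):
        plus, minus = list(xs), list(xs)
        plus[i] += delta
        minus[i] -= delta
        grads.append((total_loss(plus, zs, ws) - total_loss(minus, zs, ws)) / (2 * delta))
    return grads

def max_error(numeric, analytic, floor=1e-3):
    """|analytic - numeric| / |numeric|, where |numeric| < floor is replaced by 1.0,
    so tiny gradients are effectively compared by absolute difference (simplified metric)."""
    worst = 0.0
    for n, a in zip(numeric, analytic):
        denom = abs(n) if abs(n) >= floor else 1.0
        worst = max(worst, abs(a - n) / denom)
    return worst

xs, zs, ws = [1.0, -0.5, 0.3], [1.0, 0.0, 1.0], [4.0, 3.0, 0.5]
numeric = numeric_grad(xs, zs, ws)
# Gradient that includes pos_weight: (1 + (w - 1) z) * sigmoid(x) - w z
correct = [SCALE * ((1 + (w - 1) * z) * sigmoid(x) - w * z) for x, z, w in zip(xs, zs, ws)]
# Gradient that ignores pos_weight: sigmoid(x) - z (illustrative wrong formula)
unweighted = [SCALE * (sigmoid(x) - z) for x, z in zip(xs, zs)]

err_correct = max_error(numeric, correct)
err_unweighted = max_error(numeric, unweighted)
print(f"correct: {err_correct:.2e}  unweighted: {err_unweighted:.2e}")
for tol in (0.005, 0.0005):
    print(tol, "unweighted passes" if err_unweighted <= tol else "unweighted fails")
```

With these inputs the wrong gradient's error is about 6.3e-4. It therefore passes a 0.005 tolerance and fails a 0.0005 one, the same pattern as the logs above.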
Describe the Bug
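While implementing the decomposition of the sigmoid_cross_entropy_with_logits op with Paddle APIs, the forward pass produces the same result as the op, but the backward pass shows a precision mismatch. I suspect that automatic differentiation of the primitive ops and the op's backward kernel differ. Reproduction code: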
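As a hypothetical stand-in for a decomposition, the sketch below writes the forward pass twice in plain Python. The first version works directly from the sigmoid probabilities. The second is rewritten into elementary operations with a numerically stable log(1 + e^(−x)). It assumes the standard positive-weighted form of the loss. The function names are illustrative only; they are not Paddle APIs.

```python
import math

def naive_loss(x, z, w):
    """Positive-weighted BCE written from probabilities: -[w z log p + (1 - z) log(1 - p)]."""
    p = 1.0 / (1.0 + math.exp(-x))  # math.exp raises OverflowError for very negative x
    return -(w * z * math.log(p) + (1 - z) * math.log(1 - p))

def decomposed_loss(x, z, w):
    """Same loss from elementary ops: (1 - z) x + (1 + (w - 1) z) log(1 + e^{-x}).
    log(1 + e^{-x}) is computed as max(-x, 0) + log1p(e^{-|x|}) to avoid overflow."""
    log_weight = 1 + (w - 1) * z
    softplus_neg_x = max(-x, 0.0) + math.log1p(math.exp(-abs(x)))
    return (1 - z) * x + log_weight * softplus_neg_x

cases = [(1.0, 1.0, 4.0), (-0.5, 0.0, 3.0), (0.3, 1.0, 0.5), (-3.0, 0.3, 2.0)]
for x, z, w in cases:
    print(x, z, w, naive_loss(x, z, w), decomposed_loss(x, z, w))

print(decomposed_loss(-800.0, 1.0, 2.0))  # finite; naive_loss would overflow here
```

Both versions agree on ordinary inputs. Only the rewritten form survives extreme logits such as x = −800, where exp(800) overflows.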
Additional Supplementary Information
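At this point the cause can be narrowed down to the introduction of pos_weight. When the optional argument pos_weight is not given, an all-ones Tensor is used in its place. In that case automatic differentiation and the kernel's backward pass agree, but when pos_weight is not all ones the results diverge. Detailed analysis: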
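A quick check, assuming the standard positive-weighted loss, shows why the default all-ones pos_weight hides the discrepancy. Its gradient (1 + (w − 1)z)σ(x) − wz collapses to the unweighted σ(x) − z when w = 1. A backward formula that is correct only for w = 1 therefore cannot be told apart from the correct one there. The helper names below are hypothetical.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def grad_weighted(x, z, w):
    # d/dx of (1 - z) x + (1 + (w - 1) z) log(1 + e^{-x}), the assumed forward form
    return (1 + (w - 1) * z) * sigmoid(x) - w * z

def grad_unweighted(x, z):
    # gradient of the same loss without pos_weight
    return sigmoid(x) - z

xs = [-2.0, -0.5, 0.3, 1.7]
zs = [1.0, 0.0, 1.0, 0.6]
gaps = {}
for w in (1.0, 2.5):  # all-ones pos_weight vs. a non-trivial one
    gaps[w] = max(abs(grad_weighted(x, z, w) - grad_unweighted(x, z)) for x, z in zip(xs, zs))
    print(f"pos_weight={w}: max gap {gaps[w]:.3e}")
```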
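The pos_weight-related forward code in the kernel:
which can be written as:
Its partial derivative with respect to x is:
But the backward code is:
which corresponds to:
Hence the discrepancy appears whenever pos_weight is not an all-ones Tensor. I'm not sure my analysis is correct; could someone take a look?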