-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon No57】add fp16 & bf16 for max_pool2d_with_index, max_pool3d_with_index #52314
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
def test_check_output(self): | ||
place = core.CUDAPlace(0) | ||
if core.is_bfloat16_supported(place): | ||
self.check_output_with_place(place, atol=1e-2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BF16类型的默认值就是1e-2,这个地方可以不设置这个阈值
) | ||
class TestMaxPool3dBF16(parent): | ||
def init_dtype(self): | ||
self.dtype = np.uint16 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BF16类型的单测不用设置mask_type么
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个有设置一下试试么?
) | ||
class TestMaxPool2dBF16(parent): | ||
def init_dtype(self): | ||
self.dtype = np.uint16 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BF16类型的单测不用设置mask_type么
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BF16类型的单测不用设置mask_type么
这个有设置一下试试么?
def test_check_output(self): | ||
place = core.CUDAPlace(0) | ||
if core.is_bfloat16_supported(place): | ||
self.check_output_with_place(place, atol=1e-2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BF16阈值默认值为1e-2,这里可以不用设置
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done,辛苦再review下~
0f9c524
to
4637224
Compare
很抱歉,经过我们的反复讨论,你的PR暂未达到合入标准,请阅读飞桨原生算子开发规范,你可以重新提交新的PR,我们先将此PR关闭,感谢你的贡献。 |
@Vvsmile done,辛苦再review下~ |
input = np.round(input * 100.0, 2) | ||
if self.is_bfloat16_op(): | ||
input = input.astype(np.float32) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里先初始化为uint16,再cast为float32,有点不太合适,正确的顺序应该是先初始化为float32,计算完结果之后,再将输入和输出使用convert_float_to_uint16转为uint16
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
由于pool_max算子的计算方式,input先在fp32计算,后转为bf16去计算,这两种输入因为精度损失无法对应起来,pool_max除了计算出max的值外,也需要max的index,精度损失使得index无法对应,目前没有想到更好的处理方式,看了一下torch,好像也是先bf16的input做的计算,再将input转成float计算做对照https://github.com/pytorch/pytorch/blob/31f311a816c026bbfca622d6121d6a7fab44260d/test/nn/test_pooling.py#L927
|
||
if self.is_bfloat16_op(): | ||
output = output.astype(np.float32) | ||
mask = mask.astype(np.float32) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
mask还应该是int32吧,无需convert
076cb47
to
f365319
Compare
f365319
to
2f9e3c5
Compare
2f9e3c5
to
78e423c
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
…_with_index (PaddlePaddle#52314) * add fp_bf for pool_max_withidx * fix some error * fix error * codestyle error * fix masktype * fix input bf type * input bf dtype convert error * back to convert input to bf16 first * fix convert error * fix bf16 grad check
PR types
Others
PR changes
APIs
Description
max_pool2d_with_index, 增加FP16,BF16支持,完善单测
max_pool3d_with_index, 增加FP16,BF16支持,完善单测