Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 5th No.3】为 Paddle 新增 masked_fill API #57355

Merged
merged 34 commits into from
Nov 2, 2023

Conversation

AndSonder
Copy link
Contributor

@AndSonder AndSonder commented Sep 15, 2023

PR types

Others

PR changes

Others

Description

为 Paddle 新增 masked_fill API

RFC 文档:https://github.com/PaddlePaddle/community/pull/616/files

@paddle-bot
Copy link

paddle-bot bot commented Sep 15, 2023

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label Sep 15, 2023
@Ligoml Ligoml changed the title 为 Paddle 新增 masked_fill API 【Hackathon 5th No.3】为 Paddle 新增 masked_fill API Sep 15, 2023
@AndSonder AndSonder changed the title 【Hackathon 5th No.3】为 Paddle 新增 masked_fill API 为 Paddle 新增 masked_fill API Sep 15, 2023
@AndSonder AndSonder changed the title 为 Paddle 新增 masked_fill API 【Hackathon 5th No.3】为 Paddle 新增 masked_fill API Sep 19, 2023
@AndSonder
Copy link
Contributor Author

AndSonder commented Sep 21, 2023

@zoooo0820 麻烦研发老师review一下~ , 除了 doc ci 的问题应该都可以了

# broadcast_mask = paddle.cast(broadcast_mask, 'bool')

# if in_dynamic_mode():
# return _C_ops.where_(broadcast_mask, broadcast_x, broadcast_value)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

请移除非说明类的注释内容

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

value (Scaler or 0-D Tensor): The value used to fill the target tensor.
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

参数说明中写明支持的dtype,且要和设计文档中的一致; Scaler -> Scalar

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

value = paddle.full([1], value, x.dtype)

mask = paddle.bitwise_not(mask)
out = paddle.where(mask, x, value)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是否可以把value和x对调,省去一个not op操作

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里应该是不能对调的,paddle.where(cond, x, y) 在 inplace 的时候对boardcast的处理如下:

        zeros_like_x = paddle.zeros_like(x)
        zeros_like_y = paddle.zeros_like(y)
        zeros_like_condition = paddle.zeros_like(condition)
        zeros_like_condition = paddle.cast(zeros_like_condition, x.dtype)
        cast_cond = paddle.cast(condition, x.dtype)

        broadcast_zeros = paddle.add(zeros_like_x, zeros_like_y)
        broadcast_zeros = paddle.add(broadcast_zeros, zeros_like_condition)
        broadcast_x = x.add_(broadcast_zeros)
        broadcast_y = paddle.add(y, broadcast_zeros)
        broadcast_condition = paddle.add(cast_cond, broadcast_zeros)
        broadcast_condition = paddle.cast(broadcast_condition, 'bool')

其中 broadcast_x = x.add_(broadcast_zeros) 用了 inplace 的操作,如果调换位置就会导致 masked_fill 和 mask_fill_ 运行结果不一致问题。broadcast 也会出问题。 我一开始写的时候写的时候也是用的 paddle.where(mask, value, x) 但是跑单侧的时候就很多报错,inplace 和 非 inplace 的梯度信息也不一样

[2., 2., 1.]])
"""
if np.isscalar(value):
value = paddle.full([1], value, x.dtype)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里虽然不影响结果,从Scalar的语义来看,是否应该full([], ...) 构建一个0-D Tensor

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

value = paddle.full([1], value, x.dtype)

mask = paddle.bitwise_not(mask)
out = paddle.where_(mask, x, value)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上,是否可以通过调换x , value 来避免额外的not操作

fetch_list=[out],
)
np.testing.assert_allclose(
res[0], self.out_np, atol=1e-5, rtol=1e-5
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个地方,如果是因为fp16/bf16的特殊case的话,建议给fp16/bf16单独设置下阈值,其他case使用默认阈值

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个阈值一般应该设置多少呢

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个阈值一般应该设置多少呢

默认即可,不要特意减小.

fp16单测过不去的话可以单独拿出来设置下

mask = paddle.to_tensor(self.mask_np).astype('bool')
value = paddle.to_tensor(self.value_np, dtype=self.dtype)
result = paddle.masked_fill(x, mask, value)
np.testing.assert_allclose(self.out_np, result.numpy(), rtol=1e-05)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

@@ -0,0 +1,187 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个文件建议不用命名成test_xxx_op, 这个pr中并没有新增这个op,直接test_xxx就行

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

self.x_shape = (300, 1)
self.mask_shape = (300, 40)
self.dtype = "float16"

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

测试用例需要和设计文档中的对齐一下:

  • 可以补充一个bf16的case
  • 0-dvalue 的case,需要一起验证下反向

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

验证反向是指加在 test_inplace 里面的单侧吗

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

验证反向是指加在 test_inplace 里面的单侧吗

直接在这个文件中写一个带backward的case,验证value.grad是否为预期即可

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

验证反向是指加在 test_inplace 里面的单侧吗

直接在这个文件中写一个带backward的case,验证value.grad是否为预期即可

请问有类似的例子吗,我看其他测试backward的都是新添加op的情况下去测,构建了一个静态图去跑。没有看到python端组合的api去测backward的代码

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

验证反向是指加在 test_inplace 里面的单侧吗

直接在这个文件中写一个带backward的case,验证value.grad是否为预期即可

请问有类似的例子吗,我看其他测试backward的都是新添加op的情况下去测,构建了一个静态图去跑。没有看到python端组合的api去测backward的代码

添加单测原因是,需要验证下目前组合的API的方式是否满足设计预期的效果,可参考:

(y.grad.numpy().astype('float32') == expected_grad).all(),

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

验证反向是指加在 test_inplace 里面的单侧吗

直接在这个文件中写一个带backward的case,验证value.grad是否为预期即可

请问有类似的例子吗,我看其他测试backward的都是新添加op的情况下去测,构建了一个静态图去跑。没有看到python端组合的api去测backward的代码

添加单测原因是,需要验证下目前组合的API的方式是否满足设计预期的效果,可参考:

(y.grad.numpy().astype('float32') == expected_grad).all(),

好的,我参考看一下,谢谢研发老师

@AndSonder
Copy link
Contributor Author

还麻烦研发老师帮忙看下doc格式和 windows 下 bf16单侧的问题

@AndSonder
Copy link
Contributor Author

AndSonder commented Oct 11, 2023

还麻烦研发老师帮忙看下doc格式和 windows 下 bf16单侧的问题
@AndSonder

  • code-style问题请参考报错信息更新下precommit,或者根据给出的diff信息修改下代码
  • bf16问题,不太好直接确认原因,初步看了下可能跟初始化设置有关,建议参考下其他单测,使用 convert_float_to_uint16 / convert_float_to_uint16 完成转换

辛苦再看下剩余的CI问题:

  • coverage 覆盖率不足,需要补充value非tensor的case
  • codestyle挂了
  • static-check显示有print的使用,辛苦检查下除文档以外的部分是否有

此外 third_party 的变动不要加到这个PR中

  • 已添加value非tensor的case
  • codestyle 已修复
  • 已检查,只有文档部分有print
  • third_party 的改动已去除

@zoooo0820
Copy link
Contributor

@AndSonder
新增的单测中有报错,辛苦再看下;此外覆盖率未过,也辛苦再调整单测case覆盖下

@AndSonder
Copy link
Contributor Author

AndSonder commented Oct 12, 2023

@AndSonder 新增的单测中有报错,辛苦再看下;此外覆盖率未过,也辛苦再调整单测case覆盖下

已经都修复了,辛苦研发老师再看下

@AndSonder
Copy link
Contributor Author

还麻烦研发老师有空的时候帮忙看下还有没有什么其他问题 @zoooo0820

Copy link
Contributor

@zoooo0820 zoooo0820 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

if np.isscalar(value):
value = paddle.full([], value, x.dtype)

mask = paddle.bitwise_not(mask)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this scenario, paddle.logical_not is more safe and semantical than paddle.bitwise_not, or if paddle.logical_not is used, it is necessary to check whether the data type of mask is bool, as other types may produce unexpected results.

if np.isscalar(value):
value = paddle.full([], value, x.dtype)

mask = paddle.bitwise_not(mask)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as the above issue

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

jeff41404
jeff41404 previously approved these changes Oct 19, 2023
Copy link
Contributor

@jeff41404 jeff41404 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Comment on lines 4569 to 4577
x (Tensor) : The Destination Tensor. Supported data types are float,
double, int, int64_t,float16 and bfloat16.
mask (Tensor): The boolean tensor indicate the position to be filled.
The data type of mask must be bool.
value (Scalar or 0-D Tensor): The value used to fill the target tensor.
Supported data types are float, double, int, int64_t,float16 and bfloat16.
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
x (Tensor) : The Destination Tensor. Supported data types are float,
double, int, int64_t,float16 and bfloat16.
mask (Tensor): The boolean tensor indicate the position to be filled.
The data type of mask must be bool.
value (Scalar or 0-D Tensor): The value used to fill the target tensor.
Supported data types are float, double, int, int64_t,float16 and bfloat16.
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
x (Tensor) : The Destination Tensor. Supported data types are float,
double, int, int64_t,float16 and bfloat16.
mask (Tensor): The boolean tensor indicate the position to be filled.
The data type of mask must be bool.
value (Scalar or 0-D Tensor): The value used to fill the target tensor.
Supported data types are float, double, int, int64_t,float16 and bfloat16.
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
image

Comment on lines 4582 to 4583
.. code-block:: python
>>> # doctest: +REQUIRES(env:GPU)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

code-block 下加空行

Suggested change
.. code-block:: python
>>> # doctest: +REQUIRES(env:GPU)
.. code-block:: python
>>> # doctest: +REQUIRES(env:GPU)
image

Comment on lines 4612 to 4613
.. code-block:: python
>>> # doctest: +REQUIRES(env:GPU)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
.. code-block:: python
>>> # doctest: +REQUIRES(env:GPU)
.. code-block:: python
>>> # doctest: +REQUIRES(env:GPU)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all done

Copy link
Contributor

@sunzhongkai588 sunzhongkai588 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM,请同步提供中文文档

@AndSonder
Copy link
Contributor Author

AndSonder commented Oct 26, 2023

LGTM,请同步提供中文文档

已提供中文文档

@sunzhongkai588

Copy link

paddle-ci-bot bot commented Nov 1, 2023

Sorry to inform you that bccd1f6's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

@AndSonder
Copy link
Contributor Author

@luotao1 本 PR 的中文文档也没问题了,还麻烦帮忙看看可以合入了吗

@luotao1 luotao1 merged commit d192740 into PaddlePaddle:develop Nov 2, 2023
28 checks passed
zeroRains pushed a commit to zeroRains/Paddle that referenced this pull request Nov 8, 2023
* add masked_fill for paddle

* update doc

* update some test case

* remove full_like

* update test codes

* update test cases

* recover codes

* update test codes

* fix gradients error

* update test codes

* fix

* add bf16 test cases

* update code-block

* update code-block

* update test codes

* Update __init__.py

* fix

* fix code style and recover third_party

* add v grad check

* add scalar value case

* fix test case

* use logical_not

* fix doc style

* Update manipulation.py
danleifeng pushed a commit to danleifeng/Paddle that referenced this pull request Nov 14, 2023
* add masked_fill for paddle

* update doc

* update some test case

* remove full_like

* update test codes

* update test cases

* recover codes

* update test codes

* fix gradients error

* update test codes

* fix

* add bf16 test cases

* update code-block

* update code-block

* update test codes

* Update __init__.py

* fix

* fix code style and recover third_party

* add v grad check

* add scalar value case

* fix test case

* use logical_not

* fix doc style

* Update manipulation.py
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants