Skip to content

Commit

Permalink
[bug fix] fix fp16 dtype checking for cumsum op (#50973)
Browse files Browse the repository at this point in the history
  • Loading branch information
CollaborativeFiltering authored Feb 28, 2023
1 parent ab1b630 commit 4975207
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
16 changes: 16 additions & 0 deletions python/paddle/fluid/tests/unittests/test_cumsum_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,5 +461,21 @@ def test_static_and_infer(self):
np.testing.assert_allclose(static_out[0], infer_out)


class TestCumSumOpFp16(unittest.TestCase):
def test_fp16(self):
x_np = np.random.random((100, 100)).astype('float16')
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data(shape=[100, 100], name='x', dtype='float16')
y1 = paddle.cumsum(x)
y2 = paddle.cumsum(x, axis=0)
y3 = paddle.cumsum(x, axis=-1)
y4 = paddle.cumsum(x, axis=-2)
if core.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
out = exe.run(feed={'x': x_np}, fetch_list=[y1, y2, y3, y4])


if __name__ == '__main__':
unittest.main()
6 changes: 6 additions & 0 deletions python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -3208,6 +3208,12 @@ def cumsum(x, axis=None, dtype=None, name=None):
axis = -1
return _C_ops.cumsum(x, axis, flatten, False, False)
else:
check_variable_and_dtype(
x,
'x',
['float16', 'float32', 'float64', 'int32', 'int64'],
'cumsum',
)
check_type(x, 'x', (Variable), 'cumsum')
locals_var = locals().copy()
kwargs = dict()
Expand Down

0 comments on commit 4975207

Please sign in to comment.