Skip to content

Commit

Permalink
[PIR]Open uts for PReLU (#60645)
Browse files Browse the repository at this point in the history
  • Loading branch information
0x45f authored Jan 10, 2024
1 parent 2d9d46a commit c9e0afd
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions test/legacy_test/test_prelu_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def test_dygraph_api(self):
self.dygraph_check(self.weight_np_0)
self.dygraph_check(self.weight_np_1)

@test_with_pir_api
def test_error(self):
with paddle.static.program_guard(paddle.static.Program()):
weight_fp32 = paddle.static.data(
Expand All @@ -93,10 +94,11 @@ def test_error(self):
)
self.assertRaises(TypeError, F.prelu, x=x_int32, weight=weight_fp32)
# support the input dtype is float16
x_fp16 = paddle.static.data(
name='x_fp16', shape=[2, 3], dtype='float16'
)
F.prelu(x=x_fp16, weight=weight_fp32)
if core.is_compiled_with_cuda():
x_fp16 = paddle.static.data(
name='x_fp16', shape=[2, 3], dtype='float16'
)
F.prelu(x=x_fp16, weight=weight_fp32)


class TestNNPReluAPI(unittest.TestCase):
Expand Down

0 comments on commit c9e0afd

Please sign in to comment.