Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
zrr1999 committed Jan 17, 2024
1 parent ec4b1b5 commit 135d211
Showing 1 changed file with 64 additions and 44 deletions.
108 changes: 64 additions & 44 deletions test/legacy_test/test_sum_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,69 +467,89 @@ def test_add_n_and_add_and_grad(self):
class TestRaiseSumError(unittest.TestCase):
@test_with_pir_api
def test_errors(self):
def test_type():
paddle.add_n([11, 22])
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):

self.assertRaises(TypeError, test_type)
def test_type():
paddle.add_n([11, 22])

def test_dtype():
data1 = paddle.static.data(name="input1", shape=[10], dtype="int8")
data2 = paddle.static.data(name="input2", shape=[10], dtype="int8")
paddle.add_n([data1, data2])
self.assertRaises(TypeError, test_type)

self.assertRaises(TypeError, test_dtype)
def test_dtype():
data1 = paddle.static.data(
name="input1", shape=[10], dtype="int8"
)
data2 = paddle.static.data(
name="input2", shape=[10], dtype="int8"
)
paddle.add_n([data1, data2])

def test_dtype1():
data1 = paddle.static.data(name="input1", shape=[10], dtype="int8")
paddle.add_n(data1)
self.assertRaises(TypeError, test_dtype)

def test_dtype1():
data1 = paddle.static.data(
name="input1", shape=[10], dtype="int8"
)
paddle.add_n(data1)

self.assertRaises(TypeError, test_dtype1)
self.assertRaises(TypeError, test_dtype1)


class TestRaiseSumsError(unittest.TestCase):
@test_with_pir_api
def test_errors(self):
def test_type():
paddle.add_n([11, 22])
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):

self.assertRaises(TypeError, test_type)
def test_type():
paddle.add_n([11, 22])

def test_dtype():
data1 = paddle.static.data(name="input1", shape=[10], dtype="int8")
data2 = paddle.static.data(name="input2", shape=[10], dtype="int8")
paddle.add_n([data1, data2])
self.assertRaises(TypeError, test_type)

self.assertRaises(TypeError, test_dtype)
def test_dtype():
data1 = paddle.static.data(
name="input1", shape=[10], dtype="int8"
)
data2 = paddle.static.data(
name="input2", shape=[10], dtype="int8"
)
paddle.add_n([data1, data2])

def test_dtype1():
data1 = paddle.static.data(name="input1", shape=[10], dtype="int8")
paddle.add_n(data1)
self.assertRaises(TypeError, test_dtype)

self.assertRaises(TypeError, test_dtype1)
def test_dtype1():
data1 = paddle.static.data(
name="input1", shape=[10], dtype="int8"
)
paddle.add_n(data1)

def test_out_type():
data1 = paddle.static.data(
name="input1", shape=[10], dtype="flaot32"
)
data2 = paddle.static.data(
name="input2", shape=[10], dtype="float32"
)
out = [10]
out = paddle.add_n([data1, data2])
self.assertRaises(TypeError, test_dtype1)

self.assertRaises(TypeError, test_out_type)
def test_out_type():
data1 = paddle.static.data(
name="input1", shape=[10], dtype="flaot32"
)
data2 = paddle.static.data(
name="input2", shape=[10], dtype="float32"
)
out = [10]
out = paddle.add_n([data1, data2])

def test_out_dtype():
data1 = paddle.static.data(
name="input1", shape=[10], dtype="flaot32"
)
data2 = paddle.static.data(
name="input2", shape=[10], dtype="float32"
)
out = paddle.static.data(name="out", shape=[10], dtype="int8")
out = paddle.add_n([data1, data2])
self.assertRaises(TypeError, test_out_type)

def test_out_dtype():
data1 = paddle.static.data(
name="input1", shape=[10], dtype="flaot32"
)
data2 = paddle.static.data(
name="input2", shape=[10], dtype="float32"
)
out = paddle.static.data(name="out", shape=[10], dtype="int8")
out = paddle.add_n([data1, data2])

self.assertRaises(TypeError, test_out_dtype)
self.assertRaises(TypeError, test_out_dtype)


class TestSumOpError(unittest.TestCase):
Expand Down

0 comments on commit 135d211

Please sign in to comment.