Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test_with_pir_api in error test #60693

Merged
merged 8 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions test/legacy_test/test_arange.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def init_config(self):


class TestArangeOpError(unittest.TestCase):
@test_with_pir_api
def test_static_errors(self):
with program_guard(Program(), Program()):
paddle.enable_static()
Expand Down
1 change: 1 addition & 0 deletions test/legacy_test/test_assign_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def test_assign_LoDTensorArray(self):


class TestAssignOpError(unittest.TestCase):
@test_with_pir_api
def test_errors(self):
paddle.enable_static()
with program_guard(Program(), Program()):
Expand Down
1 change: 1 addition & 0 deletions test/legacy_test/test_cast_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def test_grad(self):


class TestCastOpError(unittest.TestCase):
@test_with_pir_api
def test_errors(self):
paddle.enable_static()
with program_guard(Program(), Program()):
Expand Down
1 change: 1 addition & 0 deletions test/legacy_test/test_compare_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,7 @@ def test_check_output(self):


class TestCompareOpError(unittest.TestCase):
@test_with_pir_api
def test_int16_support(self):
paddle.enable_static()
with paddle.static.program_guard(
Expand Down
11 changes: 1 addition & 10 deletions test/legacy_test/test_numel_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def test_numel_imperative(self):
np.testing.assert_array_equal(out_2.numpy().item(0), np.size(input_2))
paddle.enable_static()

@test_with_pir_api
def test_error(self):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
Expand All @@ -200,16 +201,6 @@ def test_x_type():

self.assertRaises(TypeError, test_x_type)

def test_pir_error(self):
with paddle.pir_utils.IrGuard():

def test_x_type():
shape = [1, 4, 5]
input_1 = np.random.random(shape).astype("int32")
out_1 = paddle.numel(input_1)

self.assertRaises(TypeError, test_x_type)


if __name__ == '__main__':
paddle.enable_static()
Expand Down
1 change: 1 addition & 0 deletions test/legacy_test/test_reshape_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,7 @@ def _set_paddle_api(self):
self.data = paddle.static.data
self.reshape = paddle.reshape

@test_with_pir_api
def _test_errors(self):
paddle.enable_static()
with program_guard(Program(), Program()):
Expand Down
1 change: 1 addition & 0 deletions test/legacy_test/test_scale_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def test_scale_selected_rows_inplace(self):


class TestScaleRaiseError(unittest.TestCase):
@test_with_pir_api
def test_errors(self):
paddle.enable_static()

Expand Down
110 changes: 66 additions & 44 deletions test/legacy_test/test_sum_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,69 +465,91 @@ def test_add_n_and_add_and_grad(self):


class TestRaiseSumError(unittest.TestCase):
@test_with_pir_api
def test_errors(self):
def test_type():
paddle.add_n([11, 22])
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):

self.assertRaises(TypeError, test_type)
def test_type():
paddle.add_n([11, 22])

def test_dtype():
data1 = paddle.static.data(name="input1", shape=[10], dtype="int8")
data2 = paddle.static.data(name="input2", shape=[10], dtype="int8")
paddle.add_n([data1, data2])
self.assertRaises(TypeError, test_type)

self.assertRaises(TypeError, test_dtype)
def test_dtype():
data1 = paddle.static.data(
name="input1", shape=[10], dtype="int8"
)
data2 = paddle.static.data(
name="input2", shape=[10], dtype="int8"
)
paddle.add_n([data1, data2])

def test_dtype1():
data1 = paddle.static.data(name="input1", shape=[10], dtype="int8")
paddle.add_n(data1)
self.assertRaises(TypeError, test_dtype)

def test_dtype1():
data1 = paddle.static.data(
name="input1", shape=[10], dtype="int8"
)
paddle.add_n(data1)

self.assertRaises(TypeError, test_dtype1)
self.assertRaises(TypeError, test_dtype1)


class TestRaiseSumsError(unittest.TestCase):
@test_with_pir_api
def test_errors(self):
def test_type():
paddle.add_n([11, 22])
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):

def test_type():
paddle.add_n([11, 22])

self.assertRaises(TypeError, test_type)
self.assertRaises(TypeError, test_type)

def test_dtype():
data1 = paddle.static.data(name="input1", shape=[10], dtype="int8")
data2 = paddle.static.data(name="input2", shape=[10], dtype="int8")
paddle.add_n([data1, data2])
def test_dtype():
data1 = paddle.static.data(
name="input1", shape=[10], dtype="int8"
)
data2 = paddle.static.data(
name="input2", shape=[10], dtype="int8"
)
paddle.add_n([data1, data2])

self.assertRaises(TypeError, test_dtype)
self.assertRaises(TypeError, test_dtype)

def test_dtype1():
data1 = paddle.static.data(name="input1", shape=[10], dtype="int8")
paddle.add_n(data1)
def test_dtype1():
data1 = paddle.static.data(
name="input1", shape=[10], dtype="int8"
)
paddle.add_n(data1)

self.assertRaises(TypeError, test_dtype1)
self.assertRaises(TypeError, test_dtype1)

def test_out_type():
data1 = paddle.static.data(
name="input1", shape=[10], dtype="flaot32"
)
data2 = paddle.static.data(
name="input2", shape=[10], dtype="float32"
)
out = [10]
out = paddle.add_n([data1, data2])
def test_out_type():
data1 = paddle.static.data(
name="input1", shape=[10], dtype="flaot32"
)
data2 = paddle.static.data(
name="input2", shape=[10], dtype="float32"
)
out = [10]
out = paddle.add_n([data1, data2])

self.assertRaises(TypeError, test_out_type)
self.assertRaises(TypeError, test_out_type)

def test_out_dtype():
data1 = paddle.static.data(
name="input1", shape=[10], dtype="flaot32"
)
data2 = paddle.static.data(
name="input2", shape=[10], dtype="float32"
)
out = paddle.static.data(name="out", shape=[10], dtype="int8")
out = paddle.add_n([data1, data2])
def test_out_dtype():
data1 = paddle.static.data(
name="input1", shape=[10], dtype="flaot32"
)
data2 = paddle.static.data(
name="input2", shape=[10], dtype="float32"
)
out = paddle.static.data(name="out", shape=[10], dtype="int8")
out = paddle.add_n([data1, data2])

self.assertRaises(TypeError, test_out_dtype)
self.assertRaises(TypeError, test_out_dtype)


class TestSumOpError(unittest.TestCase):
Expand Down
1 change: 1 addition & 0 deletions test/legacy_test/test_uniform_random_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def init_dtype(self):


class TestUniformRandomOpError(unittest.TestCase):
@test_with_pir_api
def test_errors(self):
paddle.enable_static()
main_prog = Program()
Expand Down
1 change: 1 addition & 0 deletions test/legacy_test/test_unsqueeze2_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ def test_api(self):
np.testing.assert_array_equal(res_4, input.reshape([3, 2, 5, 1]))
np.testing.assert_array_equal(res_5, input.reshape([3, 1, 1, 2, 5, 1]))

@test_with_pir_api
def test_error(self):
def test_axes_type():
x2 = paddle.static.data(name="x2", shape=[2, 25], dtype="int32")
Expand Down
1 change: 1 addition & 0 deletions test/legacy_test/test_where_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,7 @@ def test_where_condition(self):


class TestWhereOpError(unittest.TestCase):
@test_with_pir_api
def test_errors(self):
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
Expand Down