diff --git a/python/paddle/fluid/tests/unittests/test_bmm_op.py b/python/paddle/fluid/tests/unittests/test_bmm_op.py index 9d475c323991b..a1c8266842087 100644 --- a/python/paddle/fluid/tests/unittests/test_bmm_op.py +++ b/python/paddle/fluid/tests/unittests/test_bmm_op.py @@ -79,7 +79,7 @@ def test_api_error(self): y_data = np.arange(16, dtype='float32').reshape((2, 4, 2)) y_data_wrong1 = np.arange(16, dtype='float32').reshape((2, 2, 4)) y_data_wrong2 = np.arange(16, dtype='float32').reshape((2, 2, 2, 2)) - y_data_wrong3 = np.arange(24, dtype='float32').reshape((3, 2, 4)) + y_data_wrong3 = np.arange(24, dtype='float32').reshape((3, 4, 2)) self.assertRaises(ValueError, paddle.bmm, x_data, y_data_wrong1) self.assertRaises(ValueError, paddle.bmm, x_data, y_data_wrong2) self.assertRaises(ValueError, paddle.bmm, x_data, y_data_wrong3)