Skip to content

Commit

Permalink
[PIR] A-17 Adapt transpose test_errors (#61119)
Browse files Browse the repository at this point in the history
* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* CI
  • Loading branch information
enkilee authored Feb 19, 2024
1 parent fd8eaa4 commit cb53825
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 12 deletions.
6 changes: 5 additions & 1 deletion python/paddle/tensor/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def transpose(x, perm, name=None):
perm[i]-th dimension of `input`.
Args:
x (Tensor): The input Tensor. It is a N-D Tensor of data types bool, float32, float64, int32.
x (Tensor): The input Tensor. It is a N-D Tensor of data types bool, float16, bfloat16, float32, float64, int8, int16, int32, int64, uint8, uint16, complex64, complex128.
perm (list|tuple): Permute the input according to the data of perm.
name (str, optional): The name of this layer. For more information, please refer to :ref:`api_guide_Name`. Default is None.
Expand Down Expand Up @@ -119,8 +119,12 @@ def transpose(x, perm, name=None):
[
'bool',
'float16',
'bfloat16',
'float32',
'float64',
'int8',
'uint8',
'int16',
'int32',
'int64',
'uint16',
Expand Down
16 changes: 5 additions & 11 deletions test/legacy_test/test_transpose_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import paddle
from paddle import base
from paddle.base import Program, core, program_guard
from paddle.base import core
from paddle.pir_utils import test_with_pir_api

paddle.enable_static()
Expand Down Expand Up @@ -499,9 +499,12 @@ def initTestCase(self):


class TestTransposeOpError(unittest.TestCase):
@test_with_pir_api
def test_errors(self):
paddle.enable_static()
with program_guard(Program(), Program()):
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.static.data(
name='x', shape=[-1, 10, 5, 3], dtype='float64'
)
Expand All @@ -512,15 +515,6 @@ def test_x_Variable_check():

self.assertRaises(TypeError, test_x_Variable_check)

def test_x_dtype_check():
# the Input(x)'s dtype must be one of [bool, float16, float32, float64, int32, int64]
x1 = paddle.static.data(
name='x1', shape=[-1, 10, 5, 3], dtype='int8'
)
paddle.transpose(x1, perm=[1, 0, 2])

self.assertRaises(TypeError, test_x_dtype_check)

def test_perm_list_check():
# Input(perm)'s type must be list
paddle.transpose(x, perm="[1, 0, 2]")
Expand Down

0 comments on commit cb53825

Please sign in to comment.