From 2e79b815f3491205b037c2fb68b6544ce50dc7e9 Mon Sep 17 00:00:00 2001 From: baocheny Date: Sun, 26 Mar 2023 14:04:00 +0800 Subject: [PATCH 1/4] simplify code --- .../fluid/tests/unittests/test_cast_op.py | 34 +++++++------------ 1 file changed, 13 insertions(+), 21 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_cast_op.py b/python/paddle/fluid/tests/unittests/test_cast_op.py index b8bb13d24bd98..e484155240ea2 100644 --- a/python/paddle/fluid/tests/unittests/test_cast_op.py +++ b/python/paddle/fluid/tests/unittests/test_cast_op.py @@ -29,27 +29,19 @@ def convert_to_dtype_(dtype): - if dtype == 5: - return core.VarDesc.VarType.FP32 - elif dtype == 6: - return core.VarDesc.VarType.FP64 - elif dtype == 4: - return core.VarDesc.VarType.FP16 - elif dtype == 2: - return core.VarDesc.VarType.INT32 - elif dtype == 1: - return core.VarDesc.VarType.INT16 - elif dtype == 3: - return core.VarDesc.VarType.INT64 - elif dtype == 0: - return core.VarDesc.VarType.BOOL - elif dtype == 22: - return core.VarDesc.VarType.BF16 - elif dtype == 20: - return core.VarDesc.VarType.UINT8 - elif dtype == 21: - return core.VarDesc.VarType.INT8 - elif dtype == np.complex64: + _dtype_map = { + 5: core.VarDesc.VarType.FP32, + 6: core.VarDesc.VarType.FP64, + 4: core.VarDesc.VarType.FP16, + 2: core.VarDesc.VarType.INT32, + 1: core.VarDesc.VarType.INT16, + 3: core.VarDesc.VarType.INT64, + 0: core.VarDesc.VarType.BOOL, + 22: core.VarDesc.VarType.BF16, + } + if dtype in _dtype_map.keys(): + return _dtype_map[dtype] + else: raise ValueError("Not supported dtype %s" % dtype) From 65dc9aa6e1c1c0e26e68905e8b83df97404aeda8 Mon Sep 17 00:00:00 2001 From: Kim Date: Sun, 26 Mar 2023 19:43:02 +0800 Subject: [PATCH 2/4] Update test_cast_op.py --- python/paddle/fluid/tests/unittests/test_cast_op.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_cast_op.py b/python/paddle/fluid/tests/unittests/test_cast_op.py index e484155240ea2..a63adcb71b142 100644 --- a/python/paddle/fluid/tests/unittests/test_cast_op.py +++ b/python/paddle/fluid/tests/unittests/test_cast_op.py @@ -38,6 +38,8 @@ def convert_to_dtype_(dtype): 3: core.VarDesc.VarType.INT64, 0: core.VarDesc.VarType.BOOL, 22: core.VarDesc.VarType.BF16, + 20: core.VarDesc.VarType.UINT8, + 21: core.VarDesc.VarType.INT8 } if dtype in _dtype_map.keys(): return _dtype_map[dtype] From a70f47fbb4a5417d805b07bbf30663ed7a0c62b1 Mon Sep 17 00:00:00 2001 From: Kim Date: Sun, 26 Mar 2023 20:16:44 +0800 Subject: [PATCH 3/4] Update test_cast_op.py --- .../paddle/fluid/tests/unittests/test_cast_op.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_cast_op.py b/python/paddle/fluid/tests/unittests/test_cast_op.py index a63adcb71b142..e61af5df02a68 100644 --- a/python/paddle/fluid/tests/unittests/test_cast_op.py +++ b/python/paddle/fluid/tests/unittests/test_cast_op.py @@ -30,16 +30,16 @@ def convert_to_dtype_(dtype): _dtype_map = { - 5: core.VarDesc.VarType.FP32, - 6: core.VarDesc.VarType.FP64, - 4: core.VarDesc.VarType.FP16, - 2: core.VarDesc.VarType.INT32, + 0: core.VarDesc.VarType.BOOL, 1: core.VarDesc.VarType.INT16, + 2: core.VarDesc.VarType.INT32, 3: core.VarDesc.VarType.INT64, - 0: core.VarDesc.VarType.BOOL, - 22: core.VarDesc.VarType.BF16, + 4: core.VarDesc.VarType.FP16, + 5: core.VarDesc.VarType.FP32, + 6: core.VarDesc.VarType.FP64, 20: core.VarDesc.VarType.UINT8, - 21: core.VarDesc.VarType.INT8 + 21: core.VarDesc.VarType.INT8, + 22: core.VarDesc.VarType.BF16 } if dtype in _dtype_map.keys(): return _dtype_map[dtype] From 3539c204552575ab77925788f80d7a1fe35dc28a Mon Sep 17 00:00:00 2001 From: baocheny Date: Sun, 26 Mar 2023 23:29:53 +0800 Subject: [PATCH 4/4] make lint happy --- python/paddle/fluid/tests/unittests/test_cast_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_cast_op.py b/python/paddle/fluid/tests/unittests/test_cast_op.py index e61af5df02a68..bfe757fbe53f5 100644 --- a/python/paddle/fluid/tests/unittests/test_cast_op.py +++ b/python/paddle/fluid/tests/unittests/test_cast_op.py @@ -39,7 +39,7 @@ def convert_to_dtype_(dtype): 6: core.VarDesc.VarType.FP64, 20: core.VarDesc.VarType.UINT8, 21: core.VarDesc.VarType.INT8, - 22: core.VarDesc.VarType.BF16 + 22: core.VarDesc.VarType.BF16, } if dtype in _dtype_map.keys(): return _dtype_map[dtype]