Skip to content

Commit b27be74

Browse files
committed
fix some dtype convert
1 parent e0e290a commit b27be74

File tree

5 files changed

+46
-17
lines changed

5 files changed

+46
-17
lines changed

python/paddle/amp/auto_cast.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,8 @@ def amp_guard(
529529
raise ValueError("level should be O0, OD, O1 or O2.")
530530

531531
# check amp_dtype: float16 or bfloat16
532+
if isinstance(dtype, paddle.base.core.DataType):
533+
dtype = dtype.name
532534
dtype = dtype.lower()
533535
if enable:
534536
if dtype not in ['float16', 'bfloat16']:

python/paddle/nn/functional/activation.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1753,8 +1753,9 @@ def log_softmax(
17531753
[-12.31326640, -1.31326640 , -0.31326640 , -15.31326640],
17541754
[-3.44018970 , -2.44018970 , -1.44018970 , -0.44018970 ]]])
17551755
"""
1756-
1757-
if (dtype is not None) and (not isinstance(dtype, core.VarDesc.VarType)):
1756+
if dtype is None:
1757+
dtype = x.dtype
1758+
if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
17581759
dtype = convert_np_dtype_to_dtype_(dtype)
17591760

17601761
if in_dynamic_or_pir_mode():

python/paddle/sparse/unary.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -623,9 +623,13 @@ def cast(
623623
assert (
624624
in_dynamic_or_pir_mode()
625625
), "Currently, Sparse API only support dynamic mode or pir mode."
626-
if index_dtype and not isinstance(index_dtype, core.VarDesc.VarType):
626+
if index_dtype and not isinstance(
627+
index_dtype, (core.VarDesc.VarType, core.DataType)
628+
):
627629
index_dtype = convert_np_dtype_to_dtype_(index_dtype)
628-
if value_dtype and not isinstance(value_dtype, core.VarDesc.VarType):
630+
if value_dtype and not isinstance(
631+
value_dtype, (core.VarDesc.VarType, core.DataType)
632+
):
629633
value_dtype = convert_np_dtype_to_dtype_(value_dtype)
630634
return _C_ops.sparse_cast(x, index_dtype, value_dtype)
631635

python/paddle/tensor/creation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3274,9 +3274,8 @@ def tril_indices(
32743274
[[1, 2, 2, 3, 3, 3],
32753275
[0, 0, 1, 0, 1, 2]])
32763276
"""
3277-
if not isinstance(dtype, core.VarDesc.VarType):
3277+
if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
32783278
dtype = convert_np_dtype_to_dtype_(dtype)
3279-
32803279
if not isinstance(row, int) or row < 0:
32813280
raise TypeError("row should be a non-negative int")
32823281

@@ -3355,7 +3354,8 @@ def triu_indices(
33553354
[[0 0 0 0 1 1 1 1 2 2 2 3 3]
33563355
[0 1 2 3 0 1 2 3 1 2 3 2 3]]
33573356
"""
3358-
if not isinstance(dtype, core.VarDesc.VarType):
3357+
3358+
if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
33593359
dtype = convert_np_dtype_to_dtype_(dtype)
33603360

33613361
if not isinstance(row, int) or row < 0:

python/paddle/tensor/math.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4348,7 +4348,11 @@ def cumsum(
43484348
flatten = True
43494349
else:
43504350
flatten = False
4351-
if dtype is not None and x.dtype != convert_np_dtype_to_dtype_(dtype):
4351+
if dtype is None:
4352+
dtype = x.dtype
4353+
if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
4354+
dtype = convert_np_dtype_to_dtype_(dtype)
4355+
if dtype is not None and x.dtype != dtype:
43524356
x = cast(x, dtype)
43534357

43544358
if in_dynamic_or_pir_mode():
@@ -4396,7 +4400,11 @@ def cumsum_(
43964400
flatten = True
43974401
else:
43984402
flatten = False
4399-
if dtype is not None and x.dtype != convert_np_dtype_to_dtype_(dtype):
4403+
if dtype is None:
4404+
dtype = x.dtype
4405+
if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
4406+
dtype = convert_np_dtype_to_dtype_(dtype)
4407+
if dtype is not None and x.dtype != dtype:
44004408
x = cast_(x, dtype)
44014409

44024410
if in_dynamic_mode():
@@ -4652,8 +4660,12 @@ def logcumsumexp(
46524660
flatten = True
46534661
else:
46544662
flatten = False
4655-
if dtype is not None and x.dtype != convert_np_dtype_to_dtype_(dtype):
4656-
x = cast(x, dtype)
4663+
if dtype is None:
4664+
dtype = x.dtype
4665+
if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
4666+
dtype = convert_np_dtype_to_dtype_(dtype)
4667+
if dtype is not None and x.dtype != dtype:
4668+
x = cast_(x, dtype)
46574669

46584670
if in_dynamic_or_pir_mode():
46594671
if axis is None:
@@ -4753,9 +4765,12 @@ def cumprod(
47534765
if dim is None:
47544766
dim = -1
47554767
x = x.flatten(0, len(x.shape) - 1)
4756-
4757-
if dtype is not None and x.dtype != convert_np_dtype_to_dtype_(dtype):
4758-
x = cast(x, dtype)
4768+
if dtype is None:
4769+
dtype = x.dtype
4770+
if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
4771+
dtype = convert_np_dtype_to_dtype_(dtype)
4772+
if dtype is not None and x.dtype != dtype:
4773+
x = cast_(x, dtype)
47594774

47604775
if in_dynamic_or_pir_mode():
47614776
return _C_ops.cumprod(x, dim, False, False)
@@ -4802,8 +4817,11 @@ def cumprod_(
48024817
if dim is None:
48034818
dim = -1
48044819
x = _C_ops.flatten_(x, 0, len(x.shape) - 1)
4805-
4806-
if dtype is not None and x.dtype != convert_np_dtype_to_dtype_(dtype):
4820+
if dtype is None:
4821+
dtype = x.dtype
4822+
if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
4823+
dtype = convert_np_dtype_to_dtype_(dtype)
4824+
if dtype is not None and x.dtype != dtype:
48074825
x = cast_(x, dtype)
48084826

48094827
if in_dynamic_mode():
@@ -5031,14 +5049,18 @@ def prod(
50315049
[24. , 1680.])
50325050
50335051
"""
5052+
if dtype is None:
5053+
dtype = x.dtype
50345054
if dtype is not None:
5055+
if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
5056+
dtype = convert_np_dtype_to_dtype_(dtype)
50355057
check_dtype(
50365058
dtype,
50375059
'dtype',
50385060
['float32', 'float64', 'int32', 'int64', "float16", "uint16"],
50395061
'prod',
50405062
)
5041-
if x.dtype != convert_np_dtype_to_dtype_(dtype):
5063+
if x.dtype != dtype:
50425064
x = cast(x, dtype)
50435065

50445066
# axis is 0-size tensor.

0 commit comments

Comments
 (0)