Skip to content

Commit de2afce

Browse files
committed
add backward tests, use cast for non-inplace op
1 parent 0be7614 commit de2afce

File tree

4 files changed

+469
-284
lines changed

4 files changed

+469
-284
lines changed

python/paddle/tensor/math.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4536,7 +4536,7 @@ def cumprod(
45364536
if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
45374537
dtype = convert_np_dtype_to_dtype_(dtype)
45384538
if x.dtype != dtype:
4539-
x = cast_(x, dtype)
4539+
x = cast(x, dtype)
45404540

45414541
if in_dynamic_or_pir_mode():
45424542
return _C_ops.cumprod(x, dim, False, False)

python/paddle/tensor/stat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from ..base.data_feeder import check_type, check_variable_and_dtype
3535
from ..common_ops_import import Variable
3636
from ..framework import LayerHelper, convert_np_dtype_to_dtype_, core
37-
from .manipulation import cast_
37+
from .manipulation import cast
3838
from .math import _get_reduce_axis_with_tensor
3939

4040
if TYPE_CHECKING:
@@ -127,7 +127,7 @@ def mean(
127127
if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
128128
dtype = convert_np_dtype_to_dtype_(dtype)
129129
if x.dtype != dtype:
130-
x = cast_(x, dtype)
130+
x = cast(x, dtype)
131131

132132
if in_dynamic_or_pir_mode():
133133
return _C_ops.mean(x, axis, keepdim, out=out)

0 commit comments

Comments
 (0)