6 changes: 2 additions & 4 deletions python/paddle/tensor/math.py
@@ -4483,7 +4483,7 @@ def cumprod(
if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
dtype = convert_np_dtype_to_dtype_(dtype)
if x.dtype != dtype:
-            x = cast_(x, dtype)
+            x = cast(x, dtype)

if in_dynamic_or_pir_mode():
return _C_ops.cumprod(x, dim, False, False)
@@ -4530,9 +4530,7 @@ def cumprod_(
if dim is None:
dim = -1
x = _C_ops.flatten_(x, 0, len(x.shape) - 1)
-    if dtype is None:
-        dtype = x.dtype
-    else:
+    if dtype is not None:
if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
dtype = convert_np_dtype_to_dtype_(dtype)
if x.dtype != dtype:
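The math.py change above replaces the in-place cast_ with the out-of-place cast, so asking cumprod for a different accumulation dtype no longer mutates the caller's tensor, and cumprod_ now skips the dtype handling entirely when no dtype is passed. A minimal sketch of the intended behavior after this change (the tensor values are illustrative, not taken from the PR's tests):

import paddle

x = paddle.to_tensor([[1, 2], [3, 4]], dtype='int64')

# Requesting a wider dtype casts a copy of the input before the kernel runs;
# with the out-of-place cast the caller's tensor keeps its original dtype.
y = paddle.cumprod(x, dim=1, dtype='float64')

assert x.dtype == paddle.int64    # input left untouched (the old cast_ would have changed it)
assert y.dtype == paddle.float64  # result carries the requested dtype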
34 changes: 26 additions & 8 deletions python/paddle/tensor/stat.py
@@ -27,40 +27,45 @@
)
from paddle.utils.decorator_utils import (
ParamAliasDecorator,
param_two_alias,
param_two_alias_one_default,
)

from ..base.data_feeder import check_type, check_variable_and_dtype
from ..common_ops_import import Variable
-from ..framework import (
-    LayerHelper,
-    core,
-)
+from ..framework import LayerHelper, convert_np_dtype_to_dtype_, core
from .manipulation import cast
from .math import _get_reduce_axis_with_tensor

if TYPE_CHECKING:
from collections.abc import Sequence

from paddle import Tensor
from paddle._typing import DTypeLike

_Interpolation: TypeAlias = Literal[
'linear', 'higher', 'lower', 'midpoint', 'nearest'
]
__all__ = []


@param_two_alias(["x", "input"], ["axis", "dim"])
def mean(
x: Tensor,
axis: int | Sequence[int] | None = None,
keepdim: bool = False,
name: str | None = None,
*,
dtype: DTypeLike | None = None,
out: Tensor | None = None,
) -> Tensor:
"""
Computes the mean of the input tensor's elements along ``axis``.

Args:
x (Tensor): The input Tensor with data type bool, bfloat16, float16, float32,
float64, int32, int64, complex64, complex128.
alias: ``input``
axis (int|list|tuple|None, optional): The axis along which to perform mean
calculations. ``axis`` should be int, list(int) or tuple(int). If
``axis`` is a list/tuple of dimension(s), mean is calculated along
@@ -69,13 +74,16 @@ def mean(
``axis`` or element(s) of ``axis`` is less than 0, it works the
same way as :math:`axis + D` . If ``axis`` is None, mean is
calculated over all elements of ``x``. Default is None.
alias: ``dim``
keepdim (bool, optional): Whether to reserve the reduced dimension(s)
in the output Tensor. If ``keepdim`` is True, the dimensions of
the output Tensor is the same as ``x`` except in the reduced
dimensions(it is of size 1 in this case). Otherwise, the shape of
the output Tensor is squeezed in ``axis`` . Default is False.
name (str|None, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
dtype (str): The desired data type of returned tensor. Default: None.
out(Tensor|None, optional): The output tensor. Default: None.

Returns:
Tensor, results of average along ``axis`` of ``x``, with the same data
@@ -110,9 +118,19 @@ def mean(
>>> out4 = paddle.mean(x, axis=[0, 2])
>>> print(out4.numpy())
[ 8.5 12.5 16.5]
>>> out5 = paddle.mean(x, dtype='float64')
>>> out5
Tensor(shape=[], dtype=float64, place=Place(gpu:0), stop_gradient=True,
12.50000000)
"""
Contributor review comment: Only handle the case where dtype != None, to avoid unnecessary checks.

if dtype is not None:
if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
dtype = convert_np_dtype_to_dtype_(dtype)
if x.dtype != dtype:
x = cast(x, dtype)

if in_dynamic_or_pir_mode():
-        return _C_ops.mean(x, axis, keepdim)
+        return _C_ops.mean(x, axis, keepdim, out=out)
else:
reduce_all, axis = _get_reduce_axis_with_tensor(axis, x)
check_variable_and_dtype(
@@ -146,14 +164,14 @@ def mean(
helper = LayerHelper('mean', **locals())

attrs = {'dim': axis, 'keep_dim': keepdim, 'reduce_all': reduce_all}
-        out = helper.create_variable_for_type_inference(x.dtype)
+        out_tensor = helper.create_variable_for_type_inference(x.dtype)
helper.append_op(
type='reduce_mean',
inputs={'X': x},
-            outputs={'Out': out},
+            outputs={'Out': out_tensor},
attrs=attrs,
)
-        return out
+        return out_tensor


@ParamAliasDecorator({"x": ["input"], "axis": ["dim"]})
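Taken together, the stat.py changes add keyword-only dtype and out parameters to paddle.mean: dtype casts the input before the reduction, out is forwarded to _C_ops.mean in dynamic/PIR mode, and the local variable in the static-graph branch is renamed to out_tensor, presumably so it no longer collides with the new parameter. A rough usage sketch of the updated signature in dynamic mode (the input/dim spellings come from the param_two_alias decorator shown above):

import paddle

x = paddle.arange(6, dtype='float32').reshape([2, 3])

# dtype=: the input is cast up front, so the reduction runs in float64.
m1 = paddle.mean(x, axis=0, dtype='float64')

# out=: the row means are written into a pre-allocated tensor.
buf = paddle.empty([2], dtype='float32')
paddle.mean(x, axis=1, out=buf)

# input= / dim=: alias spellings resolved by the decorator.
m2 = paddle.mean(input=x, dim=1, keepdim=True)

print(m1.dtype)   # paddle.float64
print(buf.shape)  # [2]
print(m2.shape)   # [2, 1]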