
Commit 0be7614

committed: update test, enhance performance
1 parent 4e27f92 commit 0be7614

File tree

3 files changed (+47, -31 lines)


python/paddle/tensor/math.py

Lines changed: 1 addition & 3 deletions
@@ -4583,9 +4583,7 @@ def cumprod_(
     if dim is None:
         dim = -1
         x = _C_ops.flatten_(x, 0, len(x.shape) - 1)
-    if dtype is None:
-        dtype = x.dtype
-    else:
+    if dtype is not None:
         if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
             dtype = convert_np_dtype_to_dtype_(dtype)
         if x.dtype != dtype:
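
The rewrite inverts the old `if dtype is None: ... else:` guard in `cumprod_`: when the caller passes no `dtype`, the whole normalization block is now skipped instead of executing a redundant `dtype = x.dtype` assignment, and a cast only happens when the requested type actually differs from the input's. A minimal NumPy analogue of that optional-dtype pattern, for illustration only (not Paddle code):

import numpy as np

def mean_with_optional_dtype(x, dtype=None):
    # Fast path: with no explicit dtype there is nothing to normalize or cast.
    if dtype is not None:
        if not isinstance(dtype, np.dtype):
            dtype = np.dtype(dtype)   # normalize inputs like 'float64' or np.float64
        if x.dtype != dtype:
            x = x.astype(dtype)       # cast only when the type actually changes
    return x.mean()

print(mean_with_optional_dtype(np.arange(6, dtype='float32')))             # no cast taken
print(mean_with_optional_dtype(np.arange(6, dtype='float32'), 'float64'))  # one cast, up front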

python/paddle/tensor/stat.py

Lines changed: 12 additions & 8 deletions
@@ -34,6 +34,7 @@
 from ..base.data_feeder import check_type, check_variable_and_dtype
 from ..common_ops_import import Variable
 from ..framework import LayerHelper, convert_np_dtype_to_dtype_, core
+from .manipulation import cast_
 from .math import _get_reduce_axis_with_tensor
 
 if TYPE_CHECKING:
@@ -64,6 +65,7 @@ def mean(
     Args:
         x (Tensor): The input Tensor with data type bool, bfloat16, float16, float32,
             float64, int32, int64, complex64, complex128.
+            alias: ``input``
         axis (int|list|tuple|None, optional): The axis along which to perform mean
             calculations. ``axis`` should be int, list(int) or tuple(int). If
             ``axis`` is a list/tuple of dimension(s), mean is calculated along
@@ -72,6 +74,7 @@ def mean(
             ``axis`` or element(s) of ``axis`` is less than 0, it works the
             same way as :math:`axis + D` . If ``axis`` is None, mean is
             calculated over all elements of ``x``. Default is None.
+            alias: ``dim``
         keepdim (bool, optional): Whether to reserve the reduced dimension(s)
             in the output Tensor. If ``keepdim`` is True, the dimensions of
             the output Tensor is the same as ``x`` except in the reduced
@@ -115,19 +118,20 @@ def mean(
             >>> out4 = paddle.mean(x, axis=[0, 2])
             >>> print(out4.numpy())
             [ 8.5 12.5 16.5]
+            >>> out5 = paddle.mean(x, dtype='float64')
+            >>> out5
+            Tensor(shape=[], dtype=float64, place=Place(gpu:0), stop_gradient=True,
+            12.50000000)
     """
-    if dtype is None:
-        dtype = x.dtype
-    elif not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
-        dtype = convert_np_dtype_to_dtype_(dtype)
+    if dtype is not None:
+        if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
+            dtype = convert_np_dtype_to_dtype_(dtype)
+        if x.dtype != dtype:
+            x = cast_(x, dtype)
 
     if in_dynamic_or_pir_mode():
-        if dtype != x.dtype:
-            x = x.astype(dtype)
         return _C_ops.mean(x, axis, keepdim, out=out)
     else:
-        if dtype != x.dtype:
-            x = paddle.cast(x, dtype)
         reduce_all, axis = _get_reduce_axis_with_tensor(axis, x)
         check_variable_and_dtype(
             x,
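
Functionally, this hoists the cast in `mean` out of the execution-mode branch: the dtype is normalized and the input cast once, via the in-place `cast_`, before the `in_dynamic_or_pir_mode()` split, so the dynamic path drops its extra `astype` copy and the static path drops its duplicated `paddle.cast`. A short usage sketch of the resulting behavior, mirroring the new docstring example (output values assume the small float32 input below; the printed place depends on the device):

import paddle

x = paddle.to_tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype='float32')

# The input is cast to float64 once, up front; both execution modes reuse it.
out = paddle.mean(x, axis=0, dtype='float64')
print(out.dtype)    # paddle.float64
print(out.numpy())  # [2.5 3.5 4.5]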

test/legacy_test/test_mean_op_v1.py

Lines changed: 34 additions & 20 deletions
@@ -20,6 +20,34 @@
 from paddle import base
 
 
+def skip_if_xpu_or_onednn_and_not_float32(dtype):
+    """Skip test if using XPU or OneDNN and dtype is not float32"""
+
+    def decorator(test_func):
+        def wrapper(self):
+            # Check if we're using XPU
+            is_xpu = (hasattr(self, 'use_xpu') and self.use_xpu) or (
+                paddle.device.get_device().startswith('xpu')
+            )
+
+            # Check if we're using OneDNN
+            is_onednn = base.core.globals().get("FLAGS_use_onednn", False) or (
+                hasattr(self, 'use_onednn') and self.use_onednn
+            )
+
+            # Skip if using XPU or OneDNN and dtype is not float32
+            if (is_xpu or is_onednn) and dtype != 'float32':
+                self.skipTest(
+                    f"Skip {dtype} test for XPU/OneDNN, only test float32"
+                )
+
+            return test_func(self)
+
+        return wrapper
+
+    return decorator
+
+
 class TestMeanDtypeParameter(unittest.TestCase):
     def setUp(self):
         paddle.disable_static()
@@ -28,16 +56,12 @@ def setUp(self):
     def tearDown(self):
         paddle.enable_static()
 
-    def test_dtype_float16(self):
-        x = paddle.to_tensor(self.x_data)
-        result = paddle.mean(x, dtype='float16')
-        self.assertEqual(result.dtype, paddle.float16)
-
     def test_dtype_float32(self):
         x = paddle.to_tensor(self.x_data)
         result = paddle.mean(x, dtype='float32')
         self.assertEqual(result.dtype, paddle.float32)
 
+    @skip_if_xpu_or_onednn_and_not_float32('float64')
     def test_dtype_float64(self):
         x = paddle.to_tensor(self.x_data)
         result = paddle.mean(x, dtype='float64')
@@ -50,18 +74,13 @@ def test_dtype_none_default(self):
         self.assertEqual(result1.dtype, result2.dtype)
         np.testing.assert_allclose(result1.numpy(), result2.numpy(), rtol=1e-05)
 
+    @skip_if_xpu_or_onednn_and_not_float32('float64')
     def test_dtype_with_axis(self):
         x = paddle.to_tensor(self.x_data)
         result = paddle.mean(x, axis=1, dtype='float64')
         self.assertEqual(result.dtype, paddle.float64)
         self.assertEqual(result.shape, [3, 5])
 
-    def test_dtype_with_keepdim(self):
-        x = paddle.to_tensor(self.x_data)
-        result = paddle.mean(x, axis=0, keepdim=True, dtype='float16')
-        self.assertEqual(result.dtype, paddle.float16)
-        self.assertEqual(result.shape, [1, 4, 5])
-
 
 class TestMeanOutParameter(unittest.TestCase):
     def setUp(self):
@@ -115,6 +134,7 @@ def setUp(self):
     def tearDown(self):
         paddle.enable_static()
 
+    @skip_if_xpu_or_onednn_and_not_float32('float64')
     def test_dtype_and_out_compatible(self):
         x = paddle.to_tensor(self.x_data)
         out = paddle.empty([], dtype='float64')
@@ -124,15 +144,6 @@ def test_dtype_and_out_compatible(self):
         self.assertEqual(result.dtype, paddle.float64)
         self.assertTrue(paddle.allclose(out, result))
 
-    def test_dtype_and_out_with_axis(self):
-        x = paddle.to_tensor(self.x_data)
-        out = paddle.empty([2, 4], dtype='float16')
-        result = paddle.mean(x, axis=1, dtype='float16', out=out)
-
-        self.assertEqual(out.dtype, paddle.float16)
-        self.assertEqual(result.dtype, paddle.float16)
-        self.assertEqual(out.shape, [2, 4])
-
     def test_dtype_and_out_with_keepdim(self):
         x = paddle.to_tensor(self.x_data)
         out = paddle.empty([2, 1, 4], dtype='float32')
@@ -173,6 +184,7 @@ def test_multiple_axis_alias(self):
 
         np.testing.assert_allclose(result1.numpy(), result2.numpy(), rtol=1e-05)
 
+    @skip_if_xpu_or_onednn_and_not_float32('float64')
     def test_alias_with_dtype_and_out(self):
         x = paddle.to_tensor(self.x_data)
         out1 = paddle.empty([4], dtype='float64')
@@ -186,6 +198,7 @@ def test_alias_with_dtype_and_out(self):
 
 
 class TestMeanNewParametersStatic(unittest.TestCase):
+    @skip_if_xpu_or_onednn_and_not_float32('float64')
     def test_static_dtype_parameter(self):
         paddle.enable_static()
         main_prog = paddle.static.Program()
@@ -245,6 +258,7 @@ def test_dtype_with_int_input(self):
         expected = 3.5
         np.testing.assert_allclose(result.numpy(), expected, rtol=1e-05)
 
+    @skip_if_xpu_or_onednn_and_not_float32('float64')
     def test_all_parameters_combination(self):
         # Test all new parameters together
         x_data = np.random.rand(2, 3, 4).astype('float32')
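
The new `skip_if_xpu_or_onednn_and_not_float32` helper is a three-level parameterized decorator: the outer call captures the dtype string, `decorator` receives the test method, and `wrapper` evaluates the skip condition against `self` when the test actually runs. That call-time check is the reason for rolling a custom decorator rather than using `unittest.skipUnless`: the condition reads instance attributes (`self.use_xpu`, `self.use_onednn`) and runtime device state, none of which exist when a class-level skip decorator is evaluated at import time. A self-contained sketch of the same pattern (the `device` attribute is illustrative, not a Paddle convention):

import unittest

def skip_unless_device(required):
    """Skip a test method unless the instance's device matches `required`."""
    def decorator(test_func):
        def wrapper(self):
            # Instance state is only available here, at call time.
            current = getattr(self, 'device', 'cpu')
            if current != required:
                self.skipTest(f"requires {required}, running on {current}")
            return test_func(self)
        return wrapper
    return decorator

class ExampleTest(unittest.TestCase):
    device = 'cpu'

    @skip_unless_device('gpu')
    def test_gpu_only(self):
        self.assertTrue(True)  # skipped here, since device == 'cpu'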
