2020from paddle import base
2121
2222
23+ def skip_if_xpu_or_onednn_and_not_float32 (dtype ):
24+ """Skip test if using XPU or OneDNN and dtype is not float32"""
25+
26+ def decorator (test_func ):
27+ def wrapper (self ):
28+ # Check if we're using XPU
29+ is_xpu = (hasattr (self , 'use_xpu' ) and self .use_xpu ) or (
30+ paddle .device .get_device ().startswith ('xpu' )
31+ )
32+
33+ # Check if we're using OneDNN
34+ is_onednn = base .core .globals ().get ("FLAGS_use_onednn" , False ) or (
35+ hasattr (self , 'use_onednn' ) and self .use_onednn
36+ )
37+
38+ # Skip if using XPU or OneDNN and dtype is not float32
39+ if (is_xpu or is_onednn ) and dtype != 'float32' :
40+ self .skipTest (
41+ f"Skip { dtype } test for XPU/OneDNN, only test float32"
42+ )
43+
44+ return test_func (self )
45+
46+ return wrapper
47+
48+ return decorator
49+
50+
2351class TestMeanDtypeParameter (unittest .TestCase ):
2452 def setUp (self ):
2553 paddle .disable_static ()
@@ -28,16 +56,12 @@ def setUp(self):
2856 def tearDown (self ):
2957 paddle .enable_static ()
3058
31- def test_dtype_float16 (self ):
32- x = paddle .to_tensor (self .x_data )
33- result = paddle .mean (x , dtype = 'float16' )
34- self .assertEqual (result .dtype , paddle .float16 )
35-
3659 def test_dtype_float32 (self ):
3760 x = paddle .to_tensor (self .x_data )
3861 result = paddle .mean (x , dtype = 'float32' )
3962 self .assertEqual (result .dtype , paddle .float32 )
4063
64+ @skip_if_xpu_or_onednn_and_not_float32 ('float64' )
4165 def test_dtype_float64 (self ):
4266 x = paddle .to_tensor (self .x_data )
4367 result = paddle .mean (x , dtype = 'float64' )
@@ -50,18 +74,13 @@ def test_dtype_none_default(self):
5074 self .assertEqual (result1 .dtype , result2 .dtype )
5175 np .testing .assert_allclose (result1 .numpy (), result2 .numpy (), rtol = 1e-05 )
5276
77+ @skip_if_xpu_or_onednn_and_not_float32 ('float64' )
5378 def test_dtype_with_axis (self ):
5479 x = paddle .to_tensor (self .x_data )
5580 result = paddle .mean (x , axis = 1 , dtype = 'float64' )
5681 self .assertEqual (result .dtype , paddle .float64 )
5782 self .assertEqual (result .shape , [3 , 5 ])
5883
59- def test_dtype_with_keepdim (self ):
60- x = paddle .to_tensor (self .x_data )
61- result = paddle .mean (x , axis = 0 , keepdim = True , dtype = 'float16' )
62- self .assertEqual (result .dtype , paddle .float16 )
63- self .assertEqual (result .shape , [1 , 4 , 5 ])
64-
6584
6685class TestMeanOutParameter (unittest .TestCase ):
6786 def setUp (self ):
@@ -115,6 +134,7 @@ def setUp(self):
115134 def tearDown (self ):
116135 paddle .enable_static ()
117136
137+ @skip_if_xpu_or_onednn_and_not_float32 ('float64' )
118138 def test_dtype_and_out_compatible (self ):
119139 x = paddle .to_tensor (self .x_data )
120140 out = paddle .empty ([], dtype = 'float64' )
@@ -124,15 +144,6 @@ def test_dtype_and_out_compatible(self):
124144 self .assertEqual (result .dtype , paddle .float64 )
125145 self .assertTrue (paddle .allclose (out , result ))
126146
127- def test_dtype_and_out_with_axis (self ):
128- x = paddle .to_tensor (self .x_data )
129- out = paddle .empty ([2 , 4 ], dtype = 'float16' )
130- result = paddle .mean (x , axis = 1 , dtype = 'float16' , out = out )
131-
132- self .assertEqual (out .dtype , paddle .float16 )
133- self .assertEqual (result .dtype , paddle .float16 )
134- self .assertEqual (out .shape , [2 , 4 ])
135-
136147 def test_dtype_and_out_with_keepdim (self ):
137148 x = paddle .to_tensor (self .x_data )
138149 out = paddle .empty ([2 , 1 , 4 ], dtype = 'float32' )
@@ -173,6 +184,7 @@ def test_multiple_axis_alias(self):
173184
174185 np .testing .assert_allclose (result1 .numpy (), result2 .numpy (), rtol = 1e-05 )
175186
187+ @skip_if_xpu_or_onednn_and_not_float32 ('float64' )
176188 def test_alias_with_dtype_and_out (self ):
177189 x = paddle .to_tensor (self .x_data )
178190 out1 = paddle .empty ([4 ], dtype = 'float64' )
@@ -186,6 +198,7 @@ def test_alias_with_dtype_and_out(self):
186198
187199
188200class TestMeanNewParametersStatic (unittest .TestCase ):
201+ @skip_if_xpu_or_onednn_and_not_float32 ('float64' )
189202 def test_static_dtype_parameter (self ):
190203 paddle .enable_static ()
191204 main_prog = paddle .static .Program ()
@@ -245,6 +258,7 @@ def test_dtype_with_int_input(self):
245258 expected = 3.5
246259 np .testing .assert_allclose (result .numpy (), expected , rtol = 1e-05 )
247260
261+ @skip_if_xpu_or_onednn_and_not_float32 ('float64' )
248262 def test_all_parameters_combination (self ):
249263 # Test all new parameters together
250264 x_data = np .random .rand (2 , 3 , 4 ).astype ('float32' )
0 commit comments