11import unittest
2- import timeit
3- import warnings
42
53import xt
64import numpy as np
@@ -19,11 +17,24 @@ def test_mean(self):
1917
2018 self .assertTrue (np .allclose (n , x ))
2119
22- n = timeit .timeit (lambda : np .mean (a ), number = 10 )
23- x = timeit .timeit (lambda : xt .mean (a ), number = 10 )
20+ def test_average (self ):
2421
25- if x / n > 1.1 :
26- warnings .warn (f"efficiency xt.mean { x / n :.2e} " )
22+ a = np .random .random ([103 , 102 , 101 ])
23+ w = np .random .random ([103 , 102 , 101 ])
24+ n = np .average (a , weights = w )
25+ x = xt .average (a , w )
26+
27+ self .assertTrue (np .allclose (n , x ))
28+
29+ def test_average_axes (self ):
30+
31+ a = np .random .random ([103 , 102 , 101 ])
32+ w = np .random .random ([103 , 102 , 101 ])
33+ axis = int (np .random .randint (0 , high = 3 ))
34+ n = np .average (a , weights = w , axis = (axis ,))
35+ x = xt .average (a , w , [axis ])
36+
37+ self .assertTrue (np .allclose (n , x ))
2738
2839 def test_flip (self ):
2940
@@ -34,12 +45,6 @@ def test_flip(self):
3445
3546 self .assertTrue (np .allclose (n , x ))
3647
37- n = timeit .timeit (lambda : np .flip (a , axis ), number = 10 )
38- x = timeit .timeit (lambda : xt .flip (a , axis ), number = 10 )
39-
40- if x / n > 1.1 :
41- warnings .warn (f"efficiency xt.flip { x / n :.2e} " )
42-
4348 def test_cos (self ):
4449
4550 a = np .random .random ([103 , 102 , 101 ])
@@ -48,13 +53,6 @@ def test_cos(self):
4853
4954 self .assertTrue (np .allclose (n , x ))
5055
51- n = timeit .timeit (lambda : np .cos (a ), number = 10 )
52- x = timeit .timeit (lambda : xt .cos (a ), number = 10 )
53-
54- if x / n > 1.1 :
55- warnings .warn (f"efficiency xt.cos { x / n :.2e} " )
56-
57-
5856
5957if __name__ == "__main__" :
6058
0 commit comments