11from hypothesis import given
2+ from hypothesis import strategies as st
3+ from hypothesis .control import assume
24
35from . import _array_module as xp
6+ from . import array_helpers as ah
7+ from . import dtype_helpers as dh
48from . import hypothesis_helpers as hh
9+ from . import pytest_helpers as ph
510from . import xps
11+ from .test_manipulation_functions import assert_equals , axis_ndindex
612
713
814# TODO: generate kwargs
@@ -12,8 +18,52 @@ def test_argsort(x):
1218 # TODO
1319
1420
15- # TODO: generate 0d arrays, generate kwargs
16- @given (xps .arrays (dtype = xps .scalar_dtypes (), shape = hh .shapes (min_dims = 1 )))
17- def test_sort (x ):
18- xp .sort (x )
19- # TODO
21+ # TODO: Test with signed zeros and NaNs (and ignore them somehow)
22+ @given (
23+ x = xps .arrays (
24+ dtype = xps .scalar_dtypes (),
25+ shape = hh .shapes (min_dims = 1 , min_side = 1 ),
26+ elements = {"allow_nan" : False },
27+ ),
28+ data = st .data (),
29+ )
30+ def test_sort (x , data ):
31+ if dh .is_float_dtype (x .dtype ):
32+ assume (not xp .any (x == - 0.0 ) and not xp .any (x == + 0.0 ))
33+
34+ kw = data .draw (
35+ hh .kwargs (
36+ axis = st .integers (- x .ndim , x .ndim - 1 ),
37+ descending = st .booleans (),
38+ stable = st .booleans (),
39+ ),
40+ label = "kw" ,
41+ )
42+
43+ out = xp .sort (x , ** kw )
44+
45+ ph .assert_dtype ("sort" , out .dtype , x .dtype )
46+ ph .assert_shape ("sort" , out .shape , x .shape , ** kw )
47+ axis = kw .get ("axis" , - 1 )
48+ _axis = axis if axis >= 0 else x .ndim + axis
49+ descending = kw .get ("descending" , False )
50+ scalar_type = dh .get_scalar_type (x .dtype )
51+ for idx in axis_ndindex (x .shape , _axis ):
52+ f_idx = ", " .join (str (i ) if isinstance (i , int ) else ":" for i in idx )
53+ indexed_x = x [idx ]
54+ indexed_out = out [idx ]
55+ out_indices = list (ah .ndindex (indexed_x .shape ))
56+ elements = [scalar_type (indexed_x [idx2 ]) for idx2 in out_indices ]
57+ indices_order = sorted (
58+ range (len (out_indices )), key = elements .__getitem__ , reverse = descending
59+ )
60+ x_indices = [out_indices [o ] for o in indices_order ]
61+ for out_idx , x_idx in zip (out_indices , x_indices ):
62+ assert_equals (
63+ "sort" ,
64+ f"x[{ f_idx } ][{ x_idx } ]" ,
65+ indexed_x [x_idx ],
66+ f"out[{ f_idx } ][{ out_idx } ]" ,
67+ indexed_out [out_idx ],
68+ ** kw ,
69+ )
0 commit comments