11from hypothesis import given
22from hypothesis import strategies as st
33
4+ from array_api_tests .algos import broadcast_shapes
5+ from array_api_tests .test_manipulation_functions import assert_equals as assert_equals_
46from array_api_tests .test_statistical_functions import (
57 assert_equals ,
68 assert_keepdimable_shape ,
1315from . import array_helpers as ah
1416from . import dtype_helpers as dh
1517from . import hypothesis_helpers as hh
18+ from . import pytest_helpers as ph
1619from . import xps
1720
1821
@@ -95,6 +98,7 @@ def test_argmin(x, data):
9598 assert_equals ("argmin" , int , out_idx , min_i , expected )
9699
97100
101+ # TODO: skip if opted out
98102@given (xps .arrays (dtype = xps .scalar_dtypes (), shape = hh .shapes (min_side = 1 )))
99103def test_nonzero (x ):
100104 out = xp .nonzero (x )
@@ -133,7 +137,6 @@ def test_nonzero(x):
133137 ), f"{ f_idx } is in the wrong position, should be { indices .index (idx )} "
134138
135139
136- # TODO: skip if opted out
137140@given (
138141 shapes = hh .mutually_broadcastable_shapes (3 ),
139142 dtypes = hh .mutually_promotable_dtypes (),
@@ -143,5 +146,17 @@ def test_where(shapes, dtypes, data):
143146 cond = data .draw (xps .arrays (dtype = xp .bool , shape = shapes [0 ]), label = "condition" )
144147 x1 = data .draw (xps .arrays (dtype = dtypes [0 ], shape = shapes [1 ]), label = "x1" )
145148 x2 = data .draw (xps .arrays (dtype = dtypes [1 ], shape = shapes [2 ]), label = "x2" )
146- xp .where (cond , x1 , x2 )
147- # TODO
149+
150+ out = xp .where (cond , x1 , x2 )
151+
152+ shape = broadcast_shapes (* shapes )
153+ ph .assert_shape ("where" , out .shape , shape )
154+ # TODO: generate indices without broadcasting arrays
155+ _cond = xp .broadcast_to (cond , shape )
156+ _x1 = xp .broadcast_to (x1 , shape )
157+ _x2 = xp .broadcast_to (x2 , shape )
158+ for idx in ah .ndindex (shape ):
159+ if _cond [idx ]:
160+ assert_equals_ ("where" , f"_x1[{ idx } ]" , _x1 [idx ], f"out[{ idx } ]" , out [idx ])
161+ else :
162+ assert_equals_ ("where" , f"_x2[{ idx } ]" , _x2 [idx ], f"out[{ idx } ]" , out [idx ])
0 commit comments