@@ -274,29 +274,45 @@ def test_reshape(x, data):
274274
275275@given (xps .arrays (dtype = xps .scalar_dtypes (), shape = shared_shapes ()), st .data ())
276276def test_roll (x , data ):
277- shift = data .draw (
278- st .integers () | st .lists (st .integers (), max_size = x .ndim ).map (tuple ),
279- label = "shift" ,
280- )
281- axis_strats = [st .none ()]
282- if x .shape != ():
283- axis_strats .append (st .integers (- x .ndim , x .ndim - 1 ))
284- if isinstance (shift , int ):
285- axis_strats .append (xps .valid_tuple_axes (x .ndim ))
286- kw = data .draw (hh .kwargs (axis = st .one_of (axis_strats )), label = "kw" )
277+ shift_strat = st .integers (- hh .MAX_ARRAY_SIZE , hh .MAX_ARRAY_SIZE )
278+ if x .ndim > 0 :
279+ shift_strat = shift_strat | st .lists (
280+ shift_strat , min_size = 1 , max_size = x .ndim
281+ ).map (tuple )
282+ shift = data .draw (shift_strat , label = "shift" )
283+ if isinstance (shift , tuple ):
284+ axis_strat = xps .valid_tuple_axes (x .ndim ).filter (lambda t : len (t ) == len (shift ))
285+ kw_strat = axis_strat .map (lambda t : {"axis" : t })
286+ else :
287+ axis_strat = st .none ()
288+ if x .ndim != 0 :
289+ axis_strat = axis_strat | st .integers (- x .ndim , x .ndim - 1 )
290+ kw_strat = hh .kwargs (axis = axis_strat )
291+ kw = data .draw (kw_strat , label = "kw" )
287292
288293 out = xp .roll (x , shift , ** kw )
289294
290295 ph .assert_dtype ("roll" , x .dtype , out .dtype )
291296
292297 ph .assert_result_shape ("roll" , (x .shape ,), out .shape )
293298
294- # TODO: test all shift/axis scenarios
295- if isinstance (shift , int ) and kw . get ( "axis" , None ) is None :
299+ if kw . get ( "axis" , None ) is None :
300+ assert isinstance (shift , int ) # sanity check
296301 indices = list (ah .ndindex (x .shape ))
297302 shifted_indices = deque (indices )
298303 shifted_indices .rotate (- shift )
299304 assert_array_ndindex ("roll" , x , indices , out , shifted_indices )
305+ else :
306+ _shift = (shift ,) if isinstance (shift , int ) else shift
307+ axes = normalise_axis (kw ["axis" ], x .ndim )
308+ all_indices = list (ah .ndindex (x .shape ))
309+ for s , a in zip (_shift , axes ):
310+ side = x .shape [a ]
311+ for i in range (side ):
312+ indices = [idx for idx in all_indices if idx [a ] == i ]
313+ shifted_indices = deque (indices )
314+ shifted_indices .rotate (- s )
315+ assert_array_ndindex ("roll" , x , indices , out , shifted_indices )
300316
301317
302318@given (
0 commit comments