11import math
2- from typing import Union
2+ from typing import Union , Any , Tuple
33from itertools import takewhile , count
44
55from hypothesis import assume , given , strategies as st
1313from .typing import Shape , DataType , Array , Scalar
1414
1515
16+ @st .composite
17+ def specified_kwargs (draw , * keys_values_defaults : Tuple [str , Any , Any ]):
18+ """Generates valid kwargs given expected defaults.
19+
20+ When we can't realistically use hh.kwargs() and thus test whether xp infact
21+ defaults correctly, this strategy lets us remove generated arguments if they
22+ are of the default value anyway.
23+ """
24+ kw = {}
25+ for key , value , default in keys_values_defaults :
26+ if value is not default or draw (st .booleans ()):
27+ kw [key ] = value
28+ return kw
29+
30+
1631def assert_default_float (func_name : str , dtype : DataType ):
1732 f_dtype = dh .dtype_to_name [dtype ]
1833 f_default = dh .dtype_to_name [dh .default_float ]
@@ -168,7 +183,15 @@ def test_arange(dtype, data):
168183 size <= hh .MAX_ARRAY_SIZE
169184 ), f"{ size = } should be no more than { hh .MAX_ARRAY_SIZE } " # sanity check
170185
171- out = xp .arange (start , stop = stop , step = step , dtype = dtype )
186+ kw = data .draw (
187+ specified_kwargs (
188+ ("stop" , stop , None ),
189+ ("step" , step , None ),
190+ ("dtype" , dtype , None ),
191+ ),
192+ label = "kw" ,
193+ )
194+ out = xp .arange (start , ** kw )
172195
173196 if dtype is None :
174197 if all_int :
@@ -356,15 +379,22 @@ def test_linspace(num, dtype, endpoint, data):
356379 m , M = dh .dtype_ranges [_dtype ]
357380 stop = data .draw (int_stops (start , min_gap , m , M ), label = "stop" )
358381
359- out = xp .linspace (start , stop , num , dtype = dtype , endpoint = endpoint )
382+ kw = data .draw (
383+ specified_kwargs (
384+ ("dtype" , dtype , None ),
385+ ("endpoint" , endpoint , True ),
386+ ),
387+ label = "kw" ,
388+ )
389+ out = xp .linspace (start , stop , num , ** kw )
360390
361391 assert_shape ("linspace" , out .shape , num , start = stop , stop = stop , num = num )
362392
363393 if endpoint :
364394 if num > 1 :
365395 assert ah .equal (
366396 out [- 1 ], ah .asarray (stop , dtype = out .dtype )
367- ), f"out[-1]={ out [- 1 ]} , but should be { stop = } [linspace()]"
397+ ), f"out[-1]={ out [- 1 ]} , but should be { stop = } [linspace({ start = } , { num = } )]"
368398 else :
369399 # linspace(..., num, endpoint=True) should return an array equivalent to
370400 # the first num elements when endpoint=False
@@ -375,8 +405,9 @@ def test_linspace(num, dtype, endpoint, data):
375405 if num > 0 :
376406 assert ah .equal (
377407 out [0 ], ah .asarray (start , dtype = out .dtype )
378- ), f"out[0]={ out [0 ]} , but should be { start = } [linspace()]"
379- # TODO: array assertions ala test_arange
408+ ), f"out[0]={ out [0 ]} , but should be { start = } [linspace({ stop = } , { num = } )]"
409+
410+ # TODO: array assertions ala test_arange
380411
381412
382413def make_one (dtype : DataType ) -> Scalar :
0 commit comments