@@ -51,13 +51,9 @@ def assert_shape(func_name: str, out_shape: Shape, expected: Union[int, Shape],
5151 assert out_shape == expected , msg
5252
5353
54-
5554def assert_fill (func_name : str , fill : float , dtype : DataType , out : Array , ** kw ):
5655 f_kw = ", " .join (f"{ k } ={ v } " for k , v in kw .items ())
57- msg = (
58- f"out not filled with { fill } [{ func_name } ({ f_kw } )]\n "
59- f"{ out = } "
60- )
56+ msg = f"out not filled with { fill } [{ func_name } ({ f_kw } )]\n " f"{ out = } "
6157 if math .isnan (fill ):
6258 assert ah .all (ah .isnan (out )), msg
6359 else :
@@ -96,7 +92,7 @@ def reals(min_value=None, max_value=None) -> st.SearchStrategy[Union[int, float]
9692 # in https://github.com/HypothesisWorks/hypothesis/issues/2907
9793 st .floats (min_value , max_value , allow_nan = False , allow_infinity = False ).filter (
9894 lambda n : float_min <= n <= float_max
99- )
95+ ),
10096 )
10197
10298
@@ -118,9 +114,9 @@ def test_arange(start, dtype, data):
118114 step = data .draw (
119115 st .one_of (
120116 reals (min_value = tol ).filter (lambda n : n != 0 ),
121- reals (max_value = - tol ).filter (lambda n : n != 0 )
117+ reals (max_value = - tol ).filter (lambda n : n != 0 ),
122118 ),
123- label = "step"
119+ label = "step" ,
124120 )
125121
126122 all_int = all (arg is None or isinstance (arg , int ) for arg in [start , stop , step ])
@@ -147,11 +143,15 @@ def test_arange(start, dtype, data):
147143 else :
148144 condition = lambda x : x >= _stop
149145 scalar_type = int if dh .is_int_dtype (_dtype ) else float
150- elements = list (scalar_type (n ) for n in takewhile (condition , count (_start , step )))
146+ elements = list (
147+ scalar_type (n ) for n in takewhile (condition , count (_start , step ))
148+ )
151149 else :
152150 elements = []
153151 size = len (elements )
154- assert size <= hh .MAX_ARRAY_SIZE , f"{ size = } , should be no more than { hh .MAX_ARRAY_SIZE = } "
152+ assert (
153+ size <= hh .MAX_ARRAY_SIZE
154+ ), f"{ size = } , should be no more than { hh .MAX_ARRAY_SIZE = } "
155155
156156 out = xp .arange (start , stop = stop , step = step , dtype = dtype )
157157
@@ -178,7 +178,8 @@ def test_arange(start, dtype, data):
178178 if dh .is_int_dtype (_dtype ):
179179 ah .assert_exactly_equal (out , ah .asarray (elements , dtype = _dtype ))
180180 else :
181- pass # TODO: either emulate array module behaviour or assert a rough equals
181+ pass # TODO: either emulate array module behaviour or assert a rough equals
182+
182183
183184@given (hh .shapes (), hh .kwargs (dtype = st .none () | hh .shared_dtypes ))
184185def test_empty (shape , kw ):
@@ -192,7 +193,7 @@ def test_empty(shape, kw):
192193
193194@given (
194195 x = xps .arrays (dtype = xps .scalar_dtypes (), shape = hh .shapes ()),
195- kw = hh .kwargs (dtype = st .none () | xps .scalar_dtypes ())
196+ kw = hh .kwargs (dtype = st .none () | xps .scalar_dtypes ()),
196197)
197198def test_empty_like (x , kw ):
198199 out = xp .empty_like (x , ** kw )
@@ -209,7 +210,7 @@ def test_empty_like(x, kw):
209210 kw = hh .kwargs (
210211 k = st .integers (),
211212 dtype = xps .numeric_dtypes (),
212- )
213+ ),
213214)
214215def test_eye (n_rows , n_cols , kw ):
215216 out = xp .eye (n_rows , n_cols , ** kw )
@@ -237,10 +238,11 @@ def test_eye(n_rows, n_cols, kw):
237238)
238239
239240
240-
241241@st .composite
242242def full_fill_values (draw ) -> st .SearchStrategy [float ]:
243- kw = draw (st .shared (hh .kwargs (dtype = st .none () | xps .scalar_dtypes ()), key = "full_kw" ))
243+ kw = draw (
244+ st .shared (hh .kwargs (dtype = st .none () | xps .scalar_dtypes ()), key = "full_kw" )
245+ )
244246 dtype = kw .get ("dtype" , None ) or draw (default_safe_dtypes )
245247 return draw (xps .from_dtype (dtype ))
246248
@@ -262,7 +264,7 @@ def test_full(shape, fill_value, kw):
262264 dtype = dh .default_float
263265 if kw .get ("dtype" , None ) is None :
264266 if isinstance (fill_value , bool ):
265- pass # TODO
267+ pass # TODO
266268 elif isinstance (fill_value , int ):
267269 assert_default_int ("full" , out .dtype )
268270 else :
@@ -275,7 +277,9 @@ def test_full(shape, fill_value, kw):
275277
276278@st .composite
277279def full_like_fill_values (draw ):
278- kw = draw (st .shared (hh .kwargs (dtype = st .none () | xps .scalar_dtypes ()), key = "full_like_kw" ))
280+ kw = draw (
281+ st .shared (hh .kwargs (dtype = st .none () | xps .scalar_dtypes ()), key = "full_like_kw" )
282+ )
279283 dtype = kw .get ("dtype" , None ) or draw (hh .shared_dtypes )
280284 return draw (xps .from_dtype (dtype ))
281285
@@ -295,6 +299,7 @@ def test_full_like(x, fill_value, kw):
295299 assert_shape ("full_like" , out .shape , x .shape )
296300 assert_fill ("full_like" , fill_value , dtype , out , fill_value = fill_value )
297301
302+
298303finite_kw = {"allow_nan" : False , "allow_infinity" : False }
299304
300305
@@ -303,10 +308,7 @@ def int_stops(draw, start: int, min_gap: int, m: int, M: int):
303308 sign = draw (st .booleans ().map (int ))
304309 max_gap = abs (M - m )
305310 max_int = math .floor (math .sqrt (max_gap ))
306- gap = draw (
307- st .just (0 ),
308- st .integers (1 , max_int ).map (lambda n : min_gap ** n )
309- )
311+ gap = draw (st .just (0 ) | st .integers (1 , max_int ).map (lambda n : min_gap ** n ))
310312 stop = start + sign * gap
311313 assume (m <= stop <= M )
312314 return stop
0 commit comments