|
10 | 10 | from . import dtype_helpers as dh |
11 | 11 | from . import pytest_helpers as ph |
12 | 12 | from . import xps |
13 | | -from .typing import Shape, DataType, Array |
| 13 | +from .typing import Shape, DataType, Array, Scalar |
14 | 14 |
|
15 | 15 |
|
16 | 16 | def assert_default_float(func_name: str, dtype: DataType): |
@@ -43,21 +43,25 @@ def assert_kw_dtype(func_name: str, kw_dtype: DataType, out_dtype: DataType): |
43 | 43 | assert out_dtype == kw_dtype, msg |
44 | 44 |
|
45 | 45 |
|
46 | | -def assert_shape(func_name: str, out_shape: Shape, expected: Union[int, Shape], **kw): |
| 46 | +def assert_shape( |
| 47 | + func_name: str, out_shape: Shape, expected: Union[int, Shape], /, **kw |
| 48 | +): |
47 | 49 | f_kw = ", ".join(f"{k}={v}" for k, v in kw.items()) |
48 | 50 | msg = f"out.shape={out_shape}, but should be {expected} [{func_name}({f_kw})]" |
49 | 51 | if isinstance(expected, int): |
50 | 52 | expected = (expected,) |
51 | 53 | assert out_shape == expected, msg |
52 | 54 |
|
53 | 55 |
|
54 | | -def assert_fill(func_name: str, fill: float, dtype: DataType, out: Array, **kw): |
| 56 | +def assert_fill( |
| 57 | + func_name: str, fill_value: Scalar, dtype: DataType, out: Array, /, **kw |
| 58 | +): |
55 | 59 | f_kw = ", ".join(f"{k}={v}" for k, v in kw.items()) |
56 | | - msg = f"out not filled with {fill} [{func_name}({f_kw})]\n" f"{out=}" |
57 | | - if math.isnan(fill): |
| 60 | + msg = f"out not filled with {fill_value} [{func_name}({f_kw})]\n" f"{out=}" |
| 61 | + if math.isnan(fill_value): |
58 | 62 | assert ah.all(ah.isnan(out)), msg |
59 | 63 | else: |
60 | | - assert ah.all(ah.equal(out, ah.asarray(fill, dtype=dtype))), msg |
| 64 | + assert ah.all(ah.equal(out, ah.asarray(fill_value, dtype=dtype))), msg |
61 | 65 |
|
62 | 66 |
|
63 | 67 | # Testing xp.arange() requires bounding the start/stop/step arguments to only |
@@ -375,7 +379,7 @@ def test_linspace(num, dtype, endpoint, data): |
375 | 379 | # TODO: array assertions ala test_arange |
376 | 380 |
|
377 | 381 |
|
378 | | -def make_one(dtype: DataType) -> Union[bool, float]: |
| 382 | +def make_one(dtype: DataType) -> Scalar: |
379 | 383 | if dtype is None or dh.is_float_dtype(dtype): |
380 | 384 | return 1.0 |
381 | 385 | elif dh.is_int_dtype(dtype): |
@@ -411,7 +415,7 @@ def test_ones_like(x, kw): |
411 | 415 | assert_fill("ones_like", make_one(dtype), dtype, out) |
412 | 416 |
|
413 | 417 |
|
414 | | -def make_zero(dtype: DataType) -> Union[bool, float]: |
| 418 | +def make_zero(dtype: DataType) -> Scalar: |
415 | 419 | if dtype is None or dh.is_float_dtype(dtype): |
416 | 420 | return 0.0 |
417 | 421 | elif dh.is_int_dtype(dtype): |
|
0 commit comments