|
3 | 3 | """ |
4 | 4 |
|
5 | 5 | import pytest |
6 | | -from hypothesis import assume, given |
7 | | -from hypothesis import strategies as st |
8 | 6 |
|
9 | | -from .. import _array_module as xp |
10 | | -from .. import dtype_helpers as dh |
11 | | -from .. import hypothesis_helpers as hh |
12 | | -from .. import pytest_helpers as ph |
13 | | -from .._array_module import _UndefinedStub |
14 | 7 | from ..algos import BroadcastError, _broadcast_shapes |
15 | | -from ..function_stubs import elementwise_functions |
16 | 8 |
|
17 | 9 |
|
18 | 10 | @pytest.mark.parametrize( |
@@ -41,33 +33,3 @@ def test_broadcast_shapes(shape1, shape2, expected): |
41 | 33 | def test_broadcast_shapes_fails_on_bad_shapes(shape1, shape2): |
42 | 34 | with pytest.raises(BroadcastError): |
43 | 35 | _broadcast_shapes(shape1, shape2) |
44 | | - |
45 | | - |
46 | | -# TODO: Extend this to all functions (not just elementwise), and handle |
47 | | -# functions that take more than 2 args |
48 | | -@pytest.mark.parametrize( |
49 | | - "func_name", [i for i in elementwise_functions.__all__ if ph.nargs(i) > 1] |
50 | | -) |
51 | | -@given(shape1=hh.shapes(), shape2=hh.shapes(), data=st.data()) |
52 | | -def test_broadcasting_hypothesis(func_name, shape1, shape2, data): |
53 | | - dtype = data.draw(st.sampled_from(dh.func_in_dtypes[func_name]), label="dtype") |
54 | | - if hh.FILTER_UNDEFINED_DTYPES: |
55 | | - assume(not isinstance(dtype, _UndefinedStub)) |
56 | | - func = getattr(xp, func_name) |
57 | | - if isinstance(func, xp._UndefinedStub): |
58 | | - func._raise() |
59 | | - args = [xp.ones(shape1, dtype=dtype), xp.ones(shape2, dtype=dtype)] |
60 | | - try: |
61 | | - broadcast_shape = _broadcast_shapes(shape1, shape2) |
62 | | - except BroadcastError: |
63 | | - ph.raises( |
64 | | - Exception, |
65 | | - lambda: func(*args), |
66 | | - f"{func_name} should raise an exception from not being able to broadcast inputs with hh.shapes {(shape1, shape2)}", |
67 | | - ) |
68 | | - else: |
69 | | - result = ph.doesnt_raise( |
70 | | - lambda: func(*args), |
71 | | - f"{func_name} raised an unexpected exception from broadcastable inputs with hh.shapes {(shape1, shape2)}", |
72 | | - ) |
73 | | - assert result.shape == broadcast_shape, "broadcast hh.shapes incorrect" |
0 commit comments