|
2 | 2 | from operator import mul |
3 | 3 | from math import sqrt |
4 | 4 | import itertools |
5 | | -from typing import Tuple |
| 5 | +from typing import Tuple, Optional |
6 | 6 |
|
7 | 7 | from hypothesis import assume |
8 | 8 | from hypothesis.strategies import (lists, integers, sampled_from, |
9 | 9 | shared, floats, just, composite, one_of, |
10 | | - none, booleans) |
11 | | -from hypothesis.strategies._internal.strategies import SearchStrategy |
| 10 | + none, booleans, SearchStrategy) |
12 | 11 |
|
13 | 12 | from .pytest_helpers import nargs |
14 | 13 | from .array_helpers import ndindex |
@@ -77,10 +76,34 @@ def _dtypes_sorter(dtype_pair): |
77 | 76 | ] |
78 | 77 |
|
79 | 78 |
|
80 | | -def mutually_promotable_dtypes(dtype_objs=dh.all_dtypes): |
81 | | - return sampled_from( |
82 | | - [(i, j) for i, j in promotable_dtypes if i in dtype_objs and j in dtype_objs] |
83 | | - ) |
| 79 | +def mutually_promotable_dtypes( |
| 80 | + max_size: Optional[int] = 2, |
| 81 | + *, |
| 82 | + dtypes=dh.all_dtypes, |
| 83 | +) -> SearchStrategy[Tuple]: |
| 84 | + if max_size == 2: |
| 85 | + return sampled_from( |
| 86 | + [(i, j) for i, j in promotable_dtypes if i in dtypes and j in dtypes] |
| 87 | + ) |
| 88 | + if isinstance(max_size, int) and max_size < 2: |
| 89 | + raise ValueError(f'{max_size=} should be >=2') |
| 90 | + strats = [] |
| 91 | + category_samples = { |
| 92 | + category: [d for d in dtypes if d in category] for category in _dtype_categories |
| 93 | + } |
| 94 | + for samples in category_samples.values(): |
| 95 | + if len(samples) > 0: |
| 96 | + strat = lists(sampled_from(samples), min_size=2, max_size=max_size) |
| 97 | + strats.append(strat) |
| 98 | + if len(category_samples[dh.uint_dtypes]) > 0 and len(category_samples[dh.int_dtypes]) > 0: |
| 99 | + mixed_samples = category_samples[dh.uint_dtypes] + category_samples[dh.int_dtypes] |
| 100 | + strat = lists(sampled_from(mixed_samples), min_size=2, max_size=max_size) |
| 101 | + if xp.uint64 in mixed_samples: |
| 102 | + strat = strat.filter( |
| 103 | + lambda l: not (xp.uint64 in l and any(d in dh.int_dtypes for d in l)) |
| 104 | + ) |
| 105 | + return one_of(strats).map(tuple) |
| 106 | + |
84 | 107 |
|
85 | 108 | # shared() allows us to draw either the function or the function name and they |
86 | 109 | # will both correspond to the same function. |
@@ -324,9 +347,9 @@ def multiaxis_indices(draw, shapes): |
324 | 347 |
|
325 | 348 |
|
326 | 349 | def two_mutual_arrays( |
327 | | - dtype_objs=dh.all_dtypes, two_shapes=two_mutually_broadcastable_shapes |
| 350 | + dtypes=dh.all_dtypes, two_shapes=two_mutually_broadcastable_shapes |
328 | 351 | ): |
329 | | - mutual_dtypes = shared(mutually_promotable_dtypes(dtype_objs)) |
| 352 | + mutual_dtypes = shared(mutually_promotable_dtypes(dtypes=dtypes)) |
330 | 353 | mutual_shapes = shared(two_shapes) |
331 | 354 | arrays1 = xps.arrays( |
332 | 355 | dtype=mutual_dtypes.map(lambda pair: pair[0]), |
|
0 commit comments