Skip to content

Commit

Permalink
helpers: avoid mutation in invertible_matrices
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Nov 13, 2023
1 parent f82c7bc commit bd4ab55
Showing 1 changed file with 3 additions and 8 deletions.
11 changes: 3 additions & 8 deletions array_api_tests/hypothesis_helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import re
import itertools
from contextlib import contextmanager
from functools import reduce
from math import sqrt
Expand Down Expand Up @@ -267,18 +266,14 @@ def invertible_matrices(draw, dtypes=xps.floating_dtypes(), stack_shapes=shapes(
# For now, just generate stacks of diagonal matrices.
n = draw(integers(0, SQRT_MAX_ARRAY_SIZE),)
stack_shape = draw(stack_shapes)
shape = stack_shape + (n, n)
d = draw(xps.arrays(dtypes, shape=n*prod(stack_shape),
d = draw(xps.arrays(dtypes, shape=(*stack_shape, 1, n),
elements=dict(allow_nan=False, allow_infinity=False)))
# Functions that require invertible matrices may do anything when it is
# singular, including raising an exception, so we make sure the diagonals
# are sufficiently nonzero to avoid any numerical issues.
assume(xp.all(xp.abs(d) > 0.5))

a = xp.zeros(shape)
for j, (idx, i) in enumerate(itertools.product(sh.ndindex(stack_shape), range(n))):
a[idx + (i, i)] = d[j]
return a
diag_mask = xp.arange(n) == xp.arange(n)[:, None]
return xp.where(diag_mask, d, xp.zeros_like(d))

# TODO: Better name
@composite
Expand Down

0 comments on commit bd4ab55

Please sign in to comment.