Skip to content

Commit

Permalink
Modifying cartesian product to allow for >2D input arrays (pymc-devs#…
Browse files Browse the repository at this point in the history
…4482)

* Modifying cartesian product to allow for more >2D input arrays
* Assert for equality in cartesian test
* Mention pymc-devs#4482 in new features

Co-authored-by: Michael Osthege <michael.osthege@outlook.com>
  • Loading branch information
ckrapu and michaelosthege authored Feb 26, 2021
1 parent f0c823e commit 47b3658
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 4 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
+ ...

### New Features
+ `pm.math.cartesian` can now handle inputs that are themselves >1D (see [#4482](https://github.com/pymc-devs/pymc3/pull/4482)).
+ ...

### Maintenance
Expand Down
12 changes: 9 additions & 3 deletions pymc3/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,17 @@ def cartesian(*arrays):
Parameters
----------
arrays: 1D array-like
1D arrays where earlier arrays loop more slowly than later ones
arrays: N-D array-like
N-D arrays where earlier arrays loop more slowly than later ones
"""
N = len(arrays)
return np.stack(np.meshgrid(*arrays, indexing="ij"), -1).reshape(-1, N)
arrays_np = [np.asarray(x) for x in arrays]
arrays_2d = [x[:, None] if np.asarray(x).ndim == 1 else x for x in arrays_np]
arrays_integer = [np.arange(len(x)) for x in arrays_2d]
product_integers = np.stack(np.meshgrid(*arrays_integer, indexing="ij"), -1).reshape(-1, N)
return np.concatenate(
[array[product_integers[:, i]] for i, array in enumerate(arrays_2d)], axis=-1
)


def kron_matrix_op(krons, m, op):
Expand Down
19 changes: 18 additions & 1 deletion pymc3/tests/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,24 @@ def test_cartesian():
]
)
auto_cart = cartesian(a, b, c)
np.testing.assert_array_almost_equal(manual_cartesian, auto_cart)
np.testing.assert_array_equal(manual_cartesian, auto_cart)


def test_cartesian_2d():
np.random.seed(1)
a = [[1, 2], [3, 4]]
b = [5, 6]
c = [0]
manual_cartesian = np.array(
[
[1, 2, 5, 0],
[1, 2, 6, 0],
[3, 4, 5, 0],
[3, 4, 6, 0],
]
)
auto_cart = cartesian(a, b, c)
np.testing.assert_array_equal(manual_cartesian, auto_cart)


def test_kron_dot():
Expand Down

0 comments on commit 47b3658

Please sign in to comment.