Skip to content

Commit

Permalink
Merge pull request #11 from nschloe/isin
Browse files Browse the repository at this point in the history
isin_rows
  • Loading branch information
nschloe authored Apr 24, 2021
2 parents 55740ed + 4e74f49 commit 1a67e8d
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 4 deletions.
15 changes: 12 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ If you have a fix for a NumPy method that can't go upstream for some reason, fee
to PR here.


#### `np.dot`
#### `dot`
```python
npx.dot(a, b)
```
Expand All @@ -36,7 +36,7 @@ Solves a linear equation system with a matrix of shape `(n, n)` and an array of
`(n, ...)`. The output has the same shape as the second argument.


#### `np.ufunc.at`
#### `sum_at`/`add_at`
```python
npx.sum_at(a, idx, minlength=0)
npx.add_at(out, idx, a)
Expand All @@ -53,7 +53,7 @@ slower:
Corresponding report: https://github.com/numpy/numpy/issues/11156.


#### `np.unique`
#### `unique_rows`
```python
npx.unique_rows(a, return_inverse=False, return_counts=False)
```
Expand All @@ -62,6 +62,15 @@ axis=0)` is slow.

Corresponding report: https://github.com/numpy/numpy/issues/11136.


#### `isin_rows`
```python
npx.isin_rows(a, b)
```
Returns a boolean array of length `len(a)` specifying if the rows `a[k]` appear in `b`.
Similar to NumPy's own `np.isin` which only works for scalars.


#### SciPy Krylov methods
```python
sol, info = npx.cg(A, b, tol=1.0e-10)
Expand Down
2 changes: 2 additions & 0 deletions npx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .__about__ import __version__
from ._isin import isin_rows
from ._krylov import cg, gmres, minres
from ._main import add_at, dot, solve, subtract_at, sum_at, unique_rows
from ._minimize import minimize
Expand All @@ -12,6 +13,7 @@
"add_at",
"subtract_at",
"unique_rows",
"isin_rows",
"cg",
"gmres",
"minres",
Expand Down
24 changes: 24 additions & 0 deletions npx/_isin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import numpy as np


def isin_rows(a, b):
a = np.asarray(a)
b = np.asarray(b)
if not np.issubdtype(a.dtype, np.integer):
raise ValueError(f"Input array must be integer type, got {a.dtype}.")
if not np.issubdtype(b.dtype, np.integer):
raise ValueError(f"Input array must be integer type, got {b.dtype}.")

a = a.reshape(a.shape[0], np.prod(a.shape[1:], dtype=int))
b = b.reshape(b.shape[0], np.prod(b.shape[1:], dtype=int))

a_view = np.ascontiguousarray(a).view(
np.dtype((np.void, a.dtype.itemsize * a.shape[1]))
)
b_view = np.ascontiguousarray(b).view(
np.dtype((np.void, b.dtype.itemsize * b.shape[1]))
)

out = np.isin(a_view, b_view)

return out.reshape(a.shape[0])
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = npx
version = 0.0.9
version = 0.0.10
author = Nico Schlömer
author_email = nico.schloemer@gmail.com
description = Some useful extensions for NumPy
Expand Down
19 changes: 19 additions & 0 deletions test/test_isin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import numpy as np

import npx


def test_isin():
a = [[0, 3], [1, 0]]
b = [[1, 0], [7, 12], [-1, 5]]

out = npx.isin_rows(a, b)
assert np.all(out == [False, True])


def test_scalar():
a = [0, 3, 5]
b = [-1, 6, 5, 0, 0, 0]

out = npx.isin_rows(a, b)
assert np.all(out == [True, False, True])

0 comments on commit 1a67e8d

Please sign in to comment.