Skip to content

Commit

Permalink
Merge pull request #388 from jcapriot/is_scalar_bugfix
Browse files Browse the repository at this point in the history
improve scalar test to handle arbitrary dimensional ndarrays
  • Loading branch information
jcapriot authored Dec 17, 2024
2 parents 8d780d2 + 06e057c commit edba2be
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
6 changes: 5 additions & 1 deletion discretize/utils/code_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@ def is_scalar(f):
"""
if isinstance(f, SCALARTYPES):
return True
elif isinstance(f, np.ndarray) and f.size == 1 and isinstance(f[0], SCALARTYPES):
elif (
isinstance(f, np.ndarray)
and f.size == 1
and isinstance(f.reshape(-1)[0], SCALARTYPES)
):
return True
return False

Expand Down
12 changes: 9 additions & 3 deletions tests/base/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,9 +238,15 @@ def test_is_scalar(self):
self.assertTrue(is_scalar(1.0))
self.assertTrue(is_scalar(1))
self.assertTrue(is_scalar(1j))
self.assertTrue(is_scalar(np.r_[1.0]))
self.assertTrue(is_scalar(np.r_[1]))
self.assertTrue(is_scalar(np.r_[1j]))
self.assertTrue(is_scalar(np.array(1.0)))
self.assertTrue(is_scalar(np.array(1)))
self.assertTrue(is_scalar(np.array(1j)))
self.assertTrue(is_scalar(np.array([1.0])))
self.assertTrue(is_scalar(np.array([1])))
self.assertTrue(is_scalar(np.array([1j])))
self.assertTrue(is_scalar(np.array([[1.0]])))
self.assertTrue(is_scalar(np.array([[1]])))
self.assertTrue(is_scalar(np.array([[1j]])))

def test_as_array_n_by_dim(self):
true = np.array([[1, 2, 3]])
Expand Down

0 comments on commit edba2be

Please sign in to comment.