Skip to content

Commit

Permalink
make sure n and m are integer-valued
Browse files Browse the repository at this point in the history
  • Loading branch information
jiadongdan committed Sep 20, 2024
1 parent 0e31338 commit 936af1a
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 6 deletions.
16 changes: 13 additions & 3 deletions mtflearn/features/_zmoments.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import numpy as np

import numpy as np

def nm2j(n, m):
"""
Convert Zernike radial order `n` and azimuthal frequency `m` to a single index `j`.
Expand Down Expand Up @@ -36,12 +38,21 @@ def nm2j(n, m):
>>> nm2j([0, 1, 1], [0, -1, 1])
array([0, 1, 2])
"""
n = np.asarray(n, dtype=int)
m = np.asarray(m, dtype=int)
n = np.asarray(n)
m = np.asarray(m)

if n.shape != m.shape:
raise ValueError("`n` and `m` must have the same shape.")

# Validate that n and m are integer-valued
if not np.all(np.isclose(n % 1, 0)):
raise ValueError("Radial order `n` must be integer-valued.")
if not np.all(np.isclose(m % 1, 0)):
raise ValueError("Azimuthal frequency `m` must be integer-valued.")

n = n.astype(int)
m = m.astype(int)

# Validate inputs
if np.any(n < 0):
raise ValueError("Radial order `n` must be non-negative.")
Expand All @@ -59,7 +70,6 @@ def nm2j(n, m):
else:
return j


def nm2j_complex(n, m):
n = np.atleast_1d(n)
m = np.atleast_1d(m)
Expand Down
6 changes: 3 additions & 3 deletions mtflearn/features/test_zmoments.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ def test_scalar_inputs():
assert nm2j(0, 0) == 0
assert nm2j(1, -1) == 1
assert nm2j(2, 0) == 4
assert nm2j(3, 1) == 7
assert nm2j(4, -4) == 8
assert nm2j(5, 3) == 13
assert nm2j(3, 1) == 8
assert nm2j(4, -4) == 10
assert nm2j(5, 3) == 19

def test_array_inputs():
"""Test function with valid array inputs."""
Expand Down

0 comments on commit 936af1a

Please sign in to comment.