Skip to content

Commit

Permalink
numpy function imports
Browse files Browse the repository at this point in the history
  • Loading branch information
katosh committed Jun 3, 2024
1 parent f311bc9 commit ba72eb6
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions mellon/validation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections.abc import Iterable
import logging

from jax.numpy import asarray, concatenate, isscalar, full, ndarray, where
from jax.numpy import asarray, concatenate, isscalar, full, ndarray, where, isnan, isinf
from jax.numpy import sum as arraysum
from jax.numpy import min as arraymin
from jax.numpy import all as arrayall
Expand Down Expand Up @@ -477,8 +477,8 @@ def _validate_nn_distances(nn_distances):

zeros = nn_distances == 0
# Check for invalid values
nan_count = np.isnan(nn_distances).sum()
inf_count = np.isinf(nn_distances).sum()
nan_count = isnan(nn_distances).sum()
inf_count = isinf(nn_distances).sum()
negative_count = (nn_distances < 0).sum()
if nan_count > 0 or inf_count > 0 or negative_count > 0:
total_invalid = nan_count + inf_count + negative_count
Expand Down

0 comments on commit ba72eb6

Please sign in to comment.