From ba72eb6a193ee047b285a78e4ab477ffa350bfc2 Mon Sep 17 00:00:00 2001 From: Dominik Date: Mon, 3 Jun 2024 11:43:49 -0700 Subject: [PATCH] numpy function imports --- mellon/validation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mellon/validation.py b/mellon/validation.py index 7306769..9e69b96 100644 --- a/mellon/validation.py +++ b/mellon/validation.py @@ -1,7 +1,7 @@ from collections.abc import Iterable import logging -from jax.numpy import asarray, concatenate, isscalar, full, ndarray, where +from jax.numpy import asarray, concatenate, isscalar, full, ndarray, where, isnan, isinf from jax.numpy import sum as arraysum from jax.numpy import min as arraymin from jax.numpy import all as arrayall @@ -477,8 +477,8 @@ def _validate_nn_distances(nn_distances): zeros = nn_distances == 0 # Check for invalid values - nan_count = np.isnan(nn_distances).sum() - inf_count = np.isinf(nn_distances).sum() + nan_count = isnan(nn_distances).sum() + inf_count = isinf(nn_distances).sum() negative_count = (nn_distances < 0).sum() if nan_count > 0 or inf_count > 0 or negative_count > 0: total_invalid = nan_count + inf_count + negative_count