Skip to content

Commit 8ff8113

Browse files
authored
Speed up missing._get_interpolator (#4776)
* Speed up _get_interpolator Importing scipy.interpolate is slow and should only be done when necessary. Test case from 200ms to 6ms. * typos * retain info from the except.
1 parent 5ddb8d5 commit 8ff8113

File tree

1 file changed

+21
-25
lines changed

1 file changed

+21
-25
lines changed

xarray/core/missing.py

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,16 @@ def bfill(arr, dim=None, limit=None):
437437
).transpose(*arr.dims)
438438

439439

440+
def _import_interpolant(interpolant, method):
441+
"""Import interpolant from scipy.interpolate."""
442+
try:
443+
from scipy import interpolate
444+
445+
return getattr(interpolate, interpolant)
446+
except ImportError as e:
447+
raise ImportError(f"Interpolation with method {method} requires scipy.") from e
448+
449+
440450
def _get_interpolator(method, vectorizeable_only=False, **kwargs):
441451
"""helper function to select the appropriate interpolator class
442452
@@ -459,12 +469,6 @@ def _get_interpolator(method, vectorizeable_only=False, **kwargs):
459469
"akima",
460470
]
461471

462-
has_scipy = True
463-
try:
464-
from scipy import interpolate
465-
except ImportError:
466-
has_scipy = False
467-
468472
# prioritize scipy.interpolate
469473
if (
470474
method == "linear"
@@ -475,32 +479,29 @@ def _get_interpolator(method, vectorizeable_only=False, **kwargs):
475479
interp_class = NumpyInterpolator
476480

477481
elif method in valid_methods:
478-
if not has_scipy:
479-
raise ImportError("Interpolation with method `%s` requires scipy" % method)
480-
481482
if method in interp1d_methods:
482483
kwargs.update(method=method)
483484
interp_class = ScipyInterpolator
484485
elif vectorizeable_only:
485486
raise ValueError(
486-
"{} is not a vectorizeable interpolator. "
487-
"Available methods are {}".format(method, interp1d_methods)
487+
f"{method} is not a vectorizeable interpolator. "
488+
f"Available methods are {interp1d_methods}"
488489
)
489490
elif method == "barycentric":
490-
interp_class = interpolate.BarycentricInterpolator
491+
interp_class = _import_interpolant("BarycentricInterpolator", method)
491492
elif method == "krog":
492-
interp_class = interpolate.KroghInterpolator
493+
interp_class = _import_interpolant("KroghInterpolator", method)
493494
elif method == "pchip":
494-
interp_class = interpolate.PchipInterpolator
495+
interp_class = _import_interpolant("PchipInterpolator", method)
495496
elif method == "spline":
496497
kwargs.update(method=method)
497498
interp_class = SplineInterpolator
498499
elif method == "akima":
499-
interp_class = interpolate.Akima1DInterpolator
500+
interp_class = _import_interpolant("Akima1DInterpolator", method)
500501
else:
501-
raise ValueError("%s is not a valid scipy interpolator" % method)
502+
raise ValueError(f"{method} is not a valid scipy interpolator")
502503
else:
503-
raise ValueError("%s is not a valid interpolator" % method)
504+
raise ValueError(f"{method} is not a valid interpolator")
504505

505506
return interp_class, kwargs
506507

@@ -512,18 +513,13 @@ def _get_interpolator_nd(method, **kwargs):
512513
"""
513514
valid_methods = ["linear", "nearest"]
514515

515-
try:
516-
from scipy import interpolate
517-
except ImportError:
518-
raise ImportError("Interpolation with method `%s` requires scipy" % method)
519-
520516
if method in valid_methods:
521517
kwargs.update(method=method)
522-
interp_class = interpolate.interpn
518+
interp_class = _import_interpolant("interpn", method)
523519
else:
524520
raise ValueError(
525-
"%s is not a valid interpolator for interpolating "
526-
"over multiple dimensions." % method
521+
f"{method} is not a valid interpolator for interpolating "
522+
"over multiple dimensions."
527523
)
528524

529525
return interp_class, kwargs

0 commit comments

Comments
 (0)