diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index c1b5b29c7bf..c1f0cc11f54 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -143,14 +143,11 @@ def _color_palette(cmap, n_colors): pal = cmap(colors_i) except ValueError: # ValueError happens when mpl doesn't like a colormap, try seaborn - if TYPE_CHECKING: - import seaborn as sns - else: - sns = attempt_import("seaborn") - try: - pal = sns.color_palette(cmap, n_colors=n_colors) - except ValueError: + from seaborn import color_palette + + pal = color_palette(cmap, n_colors=n_colors) + except (ValueError, ImportError): # or maybe we just got a single color as a string cmap = ListedColormap([cmap] * n_colors) pal = cmap(colors_i) @@ -192,7 +189,10 @@ def _determine_cmap_params( cmap_params : dict Use depends on the type of the plotting function """ - import matplotlib as mpl + if TYPE_CHECKING: + import matplotlib as mpl + else: + mpl = attempt_import("matplotlib") if isinstance(levels, Iterable): levels = sorted(levels)