Skip to content

Commit

Permalink
Use fixes._compare_version for version checks everywhere (#10091)
Browse files Browse the repository at this point in the history
Fixes #9894
  • Loading branch information
hoechenberger authored Dec 6, 2021
1 parent d7c8fe3 commit 8db9aac
Show file tree
Hide file tree
Showing 18 changed files with 49 additions and 56 deletions.
4 changes: 2 additions & 2 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import time
import warnings
from datetime import datetime, timezone
from distutils.version import LooseVersion

import numpy as np
import matplotlib
Expand All @@ -26,9 +25,10 @@
from mne.tests.test_docstring_parameters import error_ignores
from mne.utils import (linkcode_resolve, # noqa, analysis:ignore
_assert_no_instances, sizeof_fmt, run_subprocess)
from mne.fixes import _compare_version
from mne.viz import Brain # noqa

if LooseVersion(sphinx_gallery.__version__) < LooseVersion('0.2'):
if _compare_version(sphinx_gallery.__version__, '<', '0.2'):
raise ImportError('Must have at least version 0.2 of sphinx-gallery, got '
f'{sphinx_gallery.__version__}')

Expand Down
9 changes: 4 additions & 5 deletions mne/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# License: BSD-3-Clause

from contextlib import contextmanager
from distutils.version import LooseVersion
import inspect
from textwrap import dedent
import gc
Expand All @@ -22,7 +21,7 @@
from mne import read_events, pick_types, Epochs
from mne.channels import read_layout
from mne.datasets import testing
from mne.fixes import has_numba
from mne.fixes import has_numba, _compare_version
from mne.io import read_raw_fif, read_raw_ctf
from mne.stats import cluster_level
from mne.utils import (_pl, _assert_no_instances, numerics, Bunch,
Expand Down Expand Up @@ -202,7 +201,7 @@ def matplotlib_config():
class CallbackRegistryReraise(orig):
def __init__(self, exception_handler=None):
args = ()
if LooseVersion(matplotlib.__version__) >= LooseVersion('2.1'):
if _compare_version(matplotlib.__version__, '>=', '2.1'):
args += (exception_handler,)
super(CallbackRegistryReraise, self).__init__(*args)

Expand Down Expand Up @@ -396,7 +395,7 @@ def _check_pyqtgraph():
except ModuleNotFoundError:
pytest.skip('PyQt5 is not installed but needed for pyqtgraph!')
try:
assert LooseVersion(_check_pyqt5_version()) >= LooseVersion('5.12')
assert _compare_version(_check_pyqt5_version(), '>=', '5.12')
except AssertionError:
pytest.skip(f'PyQt5 has version {_check_pyqt5_version()}'
f'but pyqtgraph needs >= 5.12!')
Expand Down Expand Up @@ -665,7 +664,7 @@ def brain_gc(request):
yield
return
import pyvista
if LooseVersion(pyvista.__version__) <= LooseVersion('0.26.1'):
if _compare_version(pyvista.__version__, '<=', '0.26.1'):
yield
return
from mne.viz import Brain
Expand Down
5 changes: 2 additions & 3 deletions mne/cov.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
# License: BSD-3-Clause

from copy import deepcopy
from distutils.version import LooseVersion
import itertools as itt
from math import log
import os
Expand Down Expand Up @@ -42,7 +41,7 @@
from . import viz

from .fixes import (BaseEstimator, EmpiricalCovariance, _logdet,
empirical_covariance, log_likelihood)
empirical_covariance, log_likelihood, _compare_version)


def _check_covs_algebra(cov1, cov2):
Expand Down Expand Up @@ -1020,7 +1019,7 @@ def _eigvec_subspace(eig, eigvec, mask):
def _get_iid_kwargs():
import sklearn
kwargs = dict()
if LooseVersion(sklearn.__version__) < LooseVersion('0.22'):
if _compare_version(sklearn.__version__, '<', '0.22'):
kwargs['iid'] = False
return kwargs

Expand Down
6 changes: 3 additions & 3 deletions mne/datasets/_fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import sys
import os
import os.path as op
from distutils.version import LooseVersion
from shutil import rmtree

from .. import __version__ as mne_version
Expand All @@ -17,6 +16,7 @@
MISC_VERSIONED,
)
from .utils import _dataset_version, _do_path_update, _get_path
from ..fixes import _compare_version


_FAKE_VERSION = None # used for monkeypatching while testing versioning
Expand Down Expand Up @@ -171,7 +171,7 @@ def fetch_dataset(
# get the version of the dataset and then check if the version is outdated
data_version = _dataset_version(final_path, name)
outdated = (want_version is not None and
LooseVersion(want_version) > LooseVersion(data_version))
_compare_version(want_version, '>', data_version))

if outdated:
logger.info(
Expand Down Expand Up @@ -273,7 +273,7 @@ def fetch_dataset(
data_version = _dataset_version(path, name)
# 0.7 < 0.7.git should be False, therefore strip
if check_version and (
LooseVersion(data_version) < LooseVersion(mne_version.strip(".git"))
_compare_version(data_version, '<', mne_version.strip(".git"))
):
warn(
"The {name} dataset (version {current}) is older than "
Expand Down
9 changes: 4 additions & 5 deletions mne/datasets/sleep_physionet/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,12 @@
import os
import os.path as op
import numpy as np
from distutils.version import LooseVersion

from ...utils import (verbose, _TempDir, _check_pandas_installed,
_on_missing)
from ...utils.check import _soft_import
from ..utils import _get_path

from ...fixes import _compare_version

AGE_SLEEP_RECORDS = op.join(op.dirname(__file__), 'age_records.csv')
TEMAZEPAM_SLEEP_RECORDS = op.join(op.dirname(__file__),
Expand Down Expand Up @@ -110,7 +109,7 @@ def _update_sleep_temazepam_records(fname=TEMAZEPAM_SLEEP_RECORDS):

# Load and massage the data.
data = pd.read_excel(subjects_fname, header=[0, 1])
if LooseVersion(pd.__version__) >= LooseVersion('0.24.0'):
if _compare_version(pd.__version__, '>=', '0.24.0'):
data = data.set_index(('Subject - age - sex', 'Nr'))
data.index.name = 'subject'
data.columns.names = [None, None]
Expand All @@ -132,7 +131,7 @@ def _update_sleep_temazepam_records(fname=TEMAZEPAM_SLEEP_RECORDS):
data = data.set_index(['id', 'subject', 'age', 'sex', 'drug',
'lights off', 'night nr', 'record type']).unstack()
data.columns = [l1 + '_' + l2 for l1, l2 in data.columns]
if LooseVersion(pd.__version__) < LooseVersion('0.21.0'):
if _compare_version(pd.__version__, '<' '0.21.0'):
data = data.reset_index().drop(labels=['id'], axis=1)
else:
data = data.reset_index().drop(columns=['id'])
Expand Down Expand Up @@ -189,7 +188,7 @@ def _update_sleep_age_records(fname=AGE_SLEEP_RECORDS):
.str.split('.', expand=True)[0]
.astype('category'))

if LooseVersion(pd.__version__) < LooseVersion('0.21.0'):
if _compare_version(pd.__version__, '<', '0.21.0'):
data = data.reset_index().drop(labels=['id'], axis=1)
else:
data = data.reset_index().drop(columns=['id'])
Expand Down
4 changes: 2 additions & 2 deletions mne/externals/tqdm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from distutils.version import LooseVersion
from ...fixes import _compare_version

try:
from tqdm import __version__
if LooseVersion(__version__) < LooseVersion('4.36'):
if _compare_version(__version__, '<', '4.36'):
raise ImportError
except ImportError: # use our copy
from ._tqdm import *
Expand Down
4 changes: 2 additions & 2 deletions mne/io/eeglab/tests/test_eeglab.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
# License: BSD-3-Clause

from copy import deepcopy
from distutils.version import LooseVersion
import os.path as op
import shutil
from unittest import SkipTest
Expand All @@ -25,6 +24,7 @@
from mne.utils import check_version, Bunch
from mne.annotations import events_from_annotations, read_annotations
from mne.externals.pymatreader import read_mat
from mne.fixes import _compare_version

base_dir = op.join(testing.data_path(download=False), 'EEGLAB')

Expand Down Expand Up @@ -221,7 +221,7 @@ def test_io_set_raw_more(tmp_path):
# SciPy writing bug that pops up sometimes:
# nopos_chanlocs = np.array(chanlocs[['labels', 'Z']])

if LooseVersion(np.__version__) == '1.14.0':
if _compare_version(np.__version__, '==', '1.14.0'):
# There is a bug in 1.14.0 (or maybe with SciPy 1.0.0?) that causes
# this write to fail!
raise SkipTest('Need to fix bug in NumPy 1.14.0!')
Expand Down
5 changes: 2 additions & 3 deletions mne/surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
# C code.

from copy import deepcopy
from distutils.version import LooseVersion
from functools import partial, lru_cache
from collections import OrderedDict
from glob import glob
Expand All @@ -21,7 +20,7 @@

from .channels.channels import _get_meg_system
from .fixes import (_serialize_volume_info, _get_read_geometry, jit,
prange, bincount)
prange, bincount, _compare_version)
from .io.constants import FIFF
from .io.pick import pick_types
from .parallel import parallel_func
Expand Down Expand Up @@ -1101,7 +1100,7 @@ def write_surface(fname, coords, faces, create_stamp='', volume_info=None,
if file_format == 'freesurfer':
try:
import nibabel as nib
has_nibabel = LooseVersion(nib.__version__) > LooseVersion('2.1.0')
has_nibabel = _compare_version(nib.__version__, '>', '2.1.0')
except ImportError:
has_nibabel = False
if has_nibabel:
Expand Down
4 changes: 2 additions & 2 deletions mne/tests/test_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
# License: BSD-3-Clause

from copy import deepcopy
from distutils.version import LooseVersion
from functools import partial
from io import BytesIO
import os
Expand Down Expand Up @@ -42,6 +41,7 @@
from mne.utils import (requires_pandas, object_diff,
catch_logging, _FakeNoPandas,
assert_meg_snr, check_version, _dt_to_stamp)
from mne.fixes import _compare_version

data_path = testing.data_path(download=False)
fname_raw_testing = op.join(data_path, 'MEG', 'sample',
Expand Down Expand Up @@ -3106,7 +3106,7 @@ def assert_metadata_equal(got, exp):
assert isinstance(exp, pandas.DataFrame)
assert isinstance(got, pandas.DataFrame)
assert set(got.columns) == set(exp.columns)
if LooseVersion(pandas.__version__) < LooseVersion('0.25'):
if _compare_version(pandas.__version__, '<', '0.25'):
# Old Pandas does not necessarily order them properly
got = got[exp.columns]
check = (got == exp)
Expand Down
5 changes: 2 additions & 3 deletions mne/time_frequency/tests/test_multitaper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from distutils.version import LooseVersion

import numpy as np
import pytest
from numpy.testing import assert_array_almost_equal
Expand All @@ -8,6 +6,7 @@
from mne.time_frequency.multitaper import dpss_windows
from mne.utils import requires_nitime
from mne.io import RawArray
from mne.fixes import _compare_version
from mne import create_info


Expand Down Expand Up @@ -48,7 +47,7 @@ def test_multitaper_psd():
raw = RawArray(data, info)
pytest.raises(ValueError, psd_multitaper, raw, sfreq,
normalization='foo')
ni_5 = (LooseVersion(ni.__version__) >= LooseVersion('0.5'))
ni_5 = _compare_version(ni.__version__, '>=', '0.5')
norm = 'full' if ni_5 else 'length'
for adaptive, n_jobs in zip((False, True, True), (1, 1, 2)):
psd, freqs = psd_multitaper(raw, adaptive=adaptive,
Expand Down
11 changes: 5 additions & 6 deletions mne/utils/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
# License: BSD-3-Clause

from contextlib import contextmanager
from distutils.version import LooseVersion
from functools import partial, wraps
import os
import inspect
Expand All @@ -21,6 +20,7 @@

from ._logging import warn, ClosingStringIO
from .numerics import object_diff
from ..fixes import _compare_version


def _explain_exception(start=-1, stop=None, prefix='> '):
Expand Down Expand Up @@ -109,8 +109,8 @@ def requires_module(function, name, call=None):

_pandas_call = """
import pandas
version = LooseVersion(pandas.__version__)
if version < '0.8.0':
version = pandas.__version__
if _compare_version(version, '<', '0.8.0'):
raise ImportError
"""

Expand Down Expand Up @@ -193,9 +193,8 @@ def check_version(library, min_version):
ok = False
else:
if min_version:
this_version = LooseVersion(
getattr(library, '__version__', '0.0').lstrip('v'))
if this_version < min_version:
this_version = getattr(library, '__version__', '0.0').lstrip('v')
if _compare_version(this_version, '<', min_version):
ok = False
return ok

Expand Down
4 changes: 2 additions & 2 deletions mne/utils/tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
#
# License: BSD-3-Clause

from distutils.version import LooseVersion
import numpy as np
from numpy.testing import assert_allclose, assert_array_equal
from scipy import linalg
import pytest

from mne.utils import _sym_mat_pow, _reg_pinv, requires_version
from mne.fixes import _compare_version


@requires_version('numpy', '1.17') # pinv bugs
Expand All @@ -28,7 +28,7 @@
])
def test_pos_semidef_inv(ndim, dtype, n, deficient, reduce_rank, psdef, func):
"""Test positive semidefinite matrix inverses."""
if LooseVersion(np.__version__) >= LooseVersion('1.19'):
if _compare_version(np.__version__, '>=', '1.19'):
svd = np.linalg.svd
else:
from mne.fixes import svd
Expand Down
9 changes: 5 additions & 4 deletions mne/viz/_brain/_scraper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import os.path as op
from distutils.version import LooseVersion

from ...fixes import _compare_version
from ._brain import Brain


Expand Down Expand Up @@ -70,9 +70,10 @@ def func(frame):
# Out to sphinx-gallery:
#
# 1. A static image but hide it (useful for carousel)
if LooseVersion(matplotlib.__version__) >= \
LooseVersion('3.3.1') and \
animation.FFMpegWriter.isAvailable():
if (
_compare_version(matplotlib.__version__, '>=', '3.3.1') and
animation.FFMpegWriter.isAvailable()
):
writer = 'ffmpeg'
elif animation.ImageMagickWriter.isAvailable():
writer = 'imagemagick'
Expand Down
4 changes: 2 additions & 2 deletions mne/viz/backends/_notebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
# License: Simplified BSD

from contextlib import contextmanager, nullcontext
from distutils.version import LooseVersion

import pyvista
from IPython.display import display
from ipywidgets import (Button, Dropdown, FloatSlider, BoundedFloatText, HBox,
IntSlider, IntText, Text, VBox, IntProgress, Play,
Checkbox, RadioButtons, jsdlink)

from ...fixes import _compare_version
from ._abstract import (_AbstractDock, _AbstractToolBar, _AbstractMenuBar,
_AbstractStatusBar, _AbstractLayout, _AbstractWidget,
_AbstractWindow, _AbstractMplCanvas, _AbstractPlayback,
Expand Down Expand Up @@ -461,7 +461,7 @@ def show(self):
self._create_default_tool_bar()
display(self._tool_bar)
# viewer
if LooseVersion(pyvista.__version__) < LooseVersion('0.30'):
if _compare_version(pyvista.__version__, '<', '0.30'):
viewer = self.plotter.show(
use_ipyvtk=True, return_viewer=True)
else: # pyvista>=0.30.0
Expand Down
Loading

0 comments on commit 8db9aac

Please sign in to comment.