Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MNT: Reduce number of calls to _update #9407

Merged
merged 13 commits into from
May 24, 2021
2 changes: 1 addition & 1 deletion mne/morph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1142,7 +1142,7 @@ def _hemi_morph(tris, vertices_to, vertices_from, smooth, maps, warn):
e = mesh_edges(tris)
e.data[e.data == 2] = 1
n_vertices = e.shape[0]
e = e + sparse.eye(n_vertices)
e += sparse.eye(n_vertices, format='csr')
if isinstance(smooth, str):
_check_option('smooth', smooth, ('nearest',),
extra=' when used as a string.')
Expand Down
17 changes: 10 additions & 7 deletions mne/source_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
_CheckInside)
from .utils import (get_subjects_dir, check_fname, logger, verbose, fill_doc,
_ensure_int, check_version, _get_call_line, warn,
_check_fname, _check_path_like, has_nibabel, _check_sphere,
_check_fname, _check_path_like, _check_sphere,
_validate_type, _check_option, _is_numeric, _pl, _suggest,
object_size, sizeof_fmt)
from .parallel import parallel_func, check_n_jobs
Expand Down Expand Up @@ -1357,17 +1357,17 @@ def read_talxfm(subject, subjects_dir=None, verbose=None):
return mri_mni_t


def _read_mri_info(path, units='m', return_img=False):
if has_nibabel():
def _read_mri_info(path, units='m', return_img=False, use_nibabel=False):
# This is equivalent but 100x slower, so only use nibabel if we need to
# (later):
if use_nibabel:
import nibabel
mgz = nibabel.load(path)
hdr = mgz.header
hdr = nibabel.load(path).header
n_orig = hdr.get_vox2ras()
t_orig = hdr.get_vox2ras_tkr()
dims = hdr.get_data_shape()
zooms = hdr.get_zooms()[:3]
else:
mgz = None
hdr = _get_mgz_header(path)
n_orig = hdr['vox2ras']
t_orig = hdr['vox2ras_tkr']
Expand Down Expand Up @@ -1395,7 +1395,8 @@ def _read_mri_info(path, units='m', return_img=False):

out = (vox_ras_t, vox_mri_t, mri_ras_t, dims, zooms)
if return_img:
out += (mgz,)
nibabel = _import_nibabel()
out += (nibabel.load(path),)
return out


Expand Down Expand Up @@ -2250,6 +2251,8 @@ def _vol_vertex(width, height, jj, kk, pp):

def _get_mgz_header(fname):
"""Adapted from nibabel to quickly extract header info."""
fname = _check_fname(fname, overwrite='read', must_exist=True,
name='MRI image')
if not fname.endswith('.mgz'):
raise IOError('Filename must end with .mgz')
header_dtd = [('version', '>i4'), ('dims', '>i4', (4,)),
Expand Down
49 changes: 29 additions & 20 deletions mne/surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from copy import deepcopy
from distutils.version import LooseVersion
from functools import partial
from functools import partial, lru_cache
from glob import glob
from os import path as op
from struct import pack
Expand All @@ -28,7 +28,7 @@
_get_trans, apply_trans, Transform)
from .utils import (logger, verbose, get_subjects_dir, warn, _check_fname,
_check_option, _ensure_int, _TempDir, run_subprocess,
_check_freesurfer_home)
_check_freesurfer_home, _hashable_ndarray)


###############################################################################
Expand Down Expand Up @@ -343,17 +343,19 @@ def _normal_orth(nn):

@verbose
def complete_surface_info(surf, do_neighbor_vert=False, copy=True,
verbose=None):
do_neighbor_tri=True, *, verbose=None):
"""Complete surface information.

Parameters
----------
surf : dict
The surface.
do_neighbor_vert : bool
If True, add neighbor vertex information.
If True (default False), add neighbor vertex information.
copy : bool
If True (default), make a copy. If False, operate in-place.
do_neighbor_tri : bool
If True (default), compute triangle neighbors.
%(verbose)s

Returns
Expand Down Expand Up @@ -383,27 +385,28 @@ def complete_surface_info(surf, do_neighbor_vert=False, copy=True,

# Find neighboring triangles, accumulate vertex normals, normalize
logger.info(' Triangle neighbors and vertex normals...')
surf['neighbor_tri'] = _triangle_neighbors(surf['tris'], surf['np'])
surf['nn'] = _accumulate_normals(surf['tris'].astype(int),
surf['tri_nn'], surf['np'])
_normalize_vectors(surf['nn'])

# Check for topological defects
zero, fewer = list(), list()
for ni, n in enumerate(surf['neighbor_tri']):
if len(n) < 3:
if len(n) == 0:
zero.append(ni)
else:
fewer.append(ni)
surf['neighbor_tri'][ni] = np.array([], int)
if len(zero) > 0:
logger.info(' Vertices do not have any neighboring '
'triangles: [%s]' % ', '.join(str(z) for z in zero))
if len(fewer) > 0:
logger.info(' Vertices have fewer than three neighboring '
'triangles, removing neighbors: [%s]'
% ', '.join(str(f) for f in fewer))
if do_neighbor_tri:
surf['neighbor_tri'] = _triangle_neighbors(surf['tris'], surf['np'])
zero, fewer = list(), list()
for ni, n in enumerate(surf['neighbor_tri']):
if len(n) < 3:
if len(n) == 0:
zero.append(ni)
else:
fewer.append(ni)
surf['neighbor_tri'][ni] = np.array([], int)
if len(zero) > 0:
logger.info(' Vertices do not have any neighboring '
'triangles: [%s]' % ', '.join(str(z) for z in zero))
if len(fewer) > 0:
logger.info(' Vertices have fewer than three neighboring '
'triangles, removing neighbors: [%s]'
% ', '.join(str(f) for f in fewer))

# Determine the neighboring vertices and fix errors
if do_neighbor_vert is True:
Expand Down Expand Up @@ -1451,6 +1454,12 @@ def mesh_edges(tris):
edges : sparse matrix
The adjacency matrix.
"""
tris = _hashable_ndarray(tris)
return _mesh_edges(tris=tris)


@lru_cache(maxsize=10)
def _mesh_edges(tris=None):
from scipy.sparse import coo_matrix
if np.max(tris) > len(np.unique(tris)):
raise ValueError(
Expand Down
1 change: 1 addition & 0 deletions mne/tests/test_coreg.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def make_dig(coords, cf):
assert_array_almost_equal(trans_est['trans'], trans['trans'])


@requires_nibabel()
@pytest.mark.slowtest # can take forever on OSX Travis
@testing.requires_testing_data
@pytest.mark.parametrize('scale', (.9, [1, .2, .8]))
Expand Down
7 changes: 6 additions & 1 deletion mne/tests/test_source_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ def test_discrete_source_space(tmpdir):
assert _get_src_type(src_new, None) == 'discrete'


@requires_nibabel()
@pytest.mark.slowtest
@testing.requires_testing_data
def test_volume_source_space(tmpdir):
Expand Down Expand Up @@ -593,6 +594,7 @@ def test_head_to_mni():
assert_allclose(coords_MNI, coords_MNI_2, atol=10.0)


@requires_nibabel()
@testing.requires_testing_data
def test_vertex_to_mni_fs_nibabel(monkeypatch):
"""Test equivalence of vert_to_mni for nibabel and freesurfer."""
Expand All @@ -601,7 +603,10 @@ def test_vertex_to_mni_fs_nibabel(monkeypatch):
vertices = rng.randint(0, 100000, n_check)
hemis = rng.randint(0, 1, n_check)
coords = vertex_to_mni(vertices, hemis, subject, subjects_dir)
monkeypatch.setattr(mne.source_space, 'has_nibabel', lambda: False)
read_mri = mne.source_space._read_mri_info
monkeypatch.setattr(
mne.source_space, '_read_mri_info',
lambda *args, **kwargs: read_mri(*args, use_nibabel=True, **kwargs))
coords_2 = vertex_to_mni(vertices, hemis, subject, subjects_dir)
# less than 0.1 mm error
assert_allclose(coords, coords_2, atol=0.1)
Expand Down
2 changes: 1 addition & 1 deletion mne/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
_mask_to_onsets_offsets, _array_equal_nan,
_julian_to_cal, _cal_to_julian, _dt_to_julian,
_julian_to_dt, _dt_to_stamp, _stamp_to_dt,
_check_dt, _ReuseCycle, _arange_div)
_check_dt, _ReuseCycle, _arange_div, _hashable_ndarray)
from .mixin import (SizeMixin, GetEpochsMixin, _prepare_read_metadata,
_prepare_write_metadata, _FakeNoPandas, ShiftTimeMixin)
from .linalg import (_svd_lwork, _repeated_svd, _sym_mat_pow, sqrtm_sym, eigh,
Expand Down
12 changes: 12 additions & 0 deletions mne/utils/numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,18 @@ def grand_average(all_inst, interpolate_bads=True, drop_bads=True):
return grand_average


class _HashableNdarray(np.ndarray):
def __hash__(self):
return object_hash(self)

def __eq__(self, other):
return NotImplementedError # defer to hash


def _hashable_ndarray(x):
return x.view(_HashableNdarray)


def object_hash(x, h=None):
"""Hash a reasonable python object.

Expand Down
Loading