Skip to content

Commit

Permalink
MRG, ENH: Speed up morph map creation (#6612)
Browse files Browse the repository at this point in the history
* WIP: Refactor nearest functions

* ENH: Speed up morph map creation

* FIX: link

* FIX: Always use int64

* ENH: Even faster

* ENH: No longer need to warn about how terribly slow our code is

* ENH: Update sys info and better check
  • Loading branch information
larsoner authored Aug 3, 2019
1 parent 7a79063 commit 9351930
Show file tree
Hide file tree
Showing 12 changed files with 149 additions and 91 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ build
coverage
.cache/
.pytest_cache/
__pycache__/
prof/

dist/
Expand Down
1 change: 1 addition & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ For full functionality, some functions require:
- Mayavi >= 4.6
- PySurfer >= 0.8
- Scikit-learn >= 0.18
- Numba >= 0.40
- NiBabel >= 2.1.0
- Pandas >= 0.18
- Picard >= 0.3
Expand Down
1 change: 1 addition & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@
'scipy': ('https://scipy.github.io/devdocs', None),
'matplotlib': ('https://matplotlib.org', None),
'sklearn': ('https://scikit-learn.org/stable', None),
'numba': ('https://numba.pydata.org/numba-doc/latest', None),
'joblib': ('https://joblib.readthedocs.io/en/latest', None),
'mayavi': ('http://docs.enthought.com/mayavi/mayavi', None),
'nibabel': ('https://nipy.org/nibabel', None),
Expand Down
5 changes: 4 additions & 1 deletion doc/install/mne_python.rst
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,14 @@ MNE-Python and its dependencies. Typical output looks like this::
matplotlib: 3.0.2 {backend=Qt5Agg}

sklearn: 0.20.2
numba: 0.45.0
nibabel: 2.3.3
mayavi: 4.7.0.dev0 {qt_api=pyqt5, PyQt5=5.10.1}
cupy: Not found
pandas: 0.24.0
dipy: 0.15.0
mayavi: 4.7.1 {qt_api=pyqt5, PyQt5=5.10.1}
pyvista: 0.21.3
vtk: 8.2.0

.. collapse:: |hand-stop-o| If you get an error...
:class: danger
Expand Down
2 changes: 2 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ Changelog

- Accept filenames of raw .fif files that end in ``_meg.fif`` to enable complicance with the Brain Imaging Data Structure by `Stefan Appelhoff`_

- Speed up morph map generation in :func:`mne.read_morph_map` by ~5-10x by using :func:`numba.jit` by `Eric Larson`_

- Add :class:`mne.digitization.Digitization` class to simplify montage by `Joan Massich`_

- Add :func:`mne.dig_mri_distances` to compute the distances between digitized head points and the MRI head surface by `Alex Gramfort`_ and `Eric Larson`_
Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ dependencies:
- numpy
- scipy
- matplotlib
- cython
- numba
- pyqt>=5.9
- pandas>=0.18
- xlrd
Expand Down
21 changes: 21 additions & 0 deletions mne/fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1264,3 +1264,24 @@ def _get_status(checks):
return list(checks.get_status())
except AttributeError:
return [x[0].get_visible() for x in checks.lines]


###############################################################################
# Numba (optional requirement)

# Here we choose different defaults to speed things up by default
try:
import numba
if LooseVersion(numba.__version__) < LooseVersion('0.40'):
raise ImportError
prange = numba.prange
def jit(nopython=True, nogil=True, fastmath=True, cache=True,
**kwargs): # noqa
return numba.jit(nopython=nopython, nogil=nogil, fastmath=fastmath,
cache=cache, **kwargs)
except ImportError:
def jit(**kwargs): # noqa
def _jit(func):
return func
return _jit
prange = range
4 changes: 2 additions & 2 deletions mne/forward/_compute_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,8 @@ def _bem_specify_els(bem, els, mults):
for el in els], axis=0)
ws = np.concatenate([el['w'] for el in els])
tri_weights, tri_idx = _project_onto_surface(rrs, scalp)
tri_weights *= ws
weights = einsum('ij,jik->jk', tri_weights,
tri_weights *= ws[:, np.newaxis]
weights = einsum('ij,ijk->ik', tri_weights,
bem['solution'][scalp['tris'][tri_idx]])
# there are way more vertices than electrodes generally, so let's iterate
# over the electrodes
Expand Down
186 changes: 103 additions & 83 deletions mne/surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
from .transforms import (transform_surface_to, _pol_to_cart, _cart_to_sph,
_get_trans, apply_trans)
from .utils import logger, verbose, get_subjects_dir, warn
from .fixes import _serialize_volume_info, _get_read_geometry, einsum
from .fixes import (_serialize_volume_info, _get_read_geometry, einsum, jit,
prange)


###############################################################################
Expand Down Expand Up @@ -261,15 +262,16 @@ def _triangle_neighbors(tris, npts):
return neighbor_tri


def _triangle_coords(r, geom, best):
@jit()
def _triangle_coords(r, best, r1, nn, r12, r13, a, b, c):
"""Get coordinates of a vertex projected to a triangle."""
r1 = geom['r1'][best]
tri_nn = geom['nn'][best]
r12 = geom['r12'][best]
r13 = geom['r13'][best]
a = geom['a'][best]
b = geom['b'][best]
c = geom['c'][best]
r1 = r1[best]
tri_nn = nn[best]
r12 = r12[best]
r13 = r13[best]
a = a[best]
b = b[best]
c = c[best]
rr = r - r1
z = np.sum(rr * tri_nn)
v1 = np.sum(rr * r12)
Expand All @@ -284,23 +286,16 @@ def _project_onto_surface(rrs, surf, project_rrs=False, return_nn=False,
method='accurate'):
"""Project points onto (scalp) surface."""
surf_geom = _get_tri_supp_geom(surf)
coords = np.empty((len(rrs), 3))
tri_idx = np.empty((len(rrs),), int)
if method == 'accurate':
for ri, rr in enumerate(rrs):
# Get index of closest tri on scalp BEM to electrode position
tri_idx[ri] = _find_nearest_tri_pt(rr, surf_geom)[2]
# Calculate a linear interpolation between the vertex values to
# get coords of pt projected onto closest triangle
coords[ri] = _triangle_coords(rr, surf_geom, tri_idx[ri])
weights = np.array([1. - coords[:, 0] - coords[:, 1], coords[:, 0],
coords[:, 1]])
out = (weights, tri_idx)
pt_tris = np.empty((0,), int)
pt_lens = np.zeros(len(rrs) + 1, int)
out = _find_nearest_tri_pts(rrs, pt_tris, pt_lens,
reproject=True, **surf_geom)
if project_rrs: #
out += (einsum('ij,jik->jk', weights,
surf['rr'][surf['tris'][tri_idx]]),)
out += (einsum('ij,ijk->ik', out[0],
surf['rr'][surf['tris'][out[1]]]),)
if return_nn:
out += (surf_geom['nn'][tri_idx],)
out += (surf_geom['nn'][out[1]],)
else: # nearest neighbor
assert project_rrs
idx = _compute_nearest(surf['rr'], rrs)
Expand Down Expand Up @@ -986,8 +981,8 @@ def read_morph_map(subject_from, subject_to, subjects_dir=None, xhemi=False,
if op.exists(fname):
return _read_morph_map(fname, subject_from, subject_to)
# if file does not exist, make it
warn('Morph map "%s" does not exist, creating it and saving it to '
'disk (this may take a few minutes)' % fname)
logger.info('Morph map "%s" does not exist, creating it and saving it to '
'disk' % fname)
logger.info(log_msg % (subject_from, subject_to))
mmap_1 = _make_morph_map(subject_from, subject_to, subjects_dir, xhemi)
if subject_to == subject_from:
Expand Down Expand Up @@ -1065,6 +1060,7 @@ def _write_morph_map(fname, subject_from, subject_to, mmap_1, mmap_2):
end_file(fid)


@jit()
def _get_tri_dist(p, q, p0, q0, a, b, c, dist):
"""Get the distance to a triangle edge."""
p1 = p - p0
Expand All @@ -1073,19 +1069,19 @@ def _get_tri_dist(p, q, p0, q0, a, b, c, dist):
out += q1 * q1 * b
out += p1 * q1 * c
out += dist * dist
return np.sqrt(out, out=out)
return np.sqrt(out)


def _get_tri_supp_geom(surf):
"""Create supplementary geometry information using tris and rrs."""
r1 = surf['rr'][surf['tris'][:, 0], :]
r12 = surf['rr'][surf['tris'][:, 1], :] - r1
r13 = surf['rr'][surf['tris'][:, 2], :] - r1
r1213 = np.array([r12, r13]).swapaxes(0, 1)
r1213 = np.ascontiguousarray(np.array([r12, r13]).swapaxes(0, 1))
a = einsum('ij,ij->i', r12, r12)
b = einsum('ij,ij->i', r13, r13)
c = einsum('ij,ij->i', r12, r13)
mat = np.rollaxis(np.array([[b, -c], [-c, a]]), 2)
mat = np.ascontiguousarray(np.rollaxis(np.array([[b, -c], [-c, a]]), 2))
norm = (a * b - c * c)
norm[norm == 0] = 1. # avoid divide by zero
mat /= norm[:, np.newaxis, np.newaxis]
Expand Down Expand Up @@ -1137,19 +1133,21 @@ def _make_morph_map_hemi(subject_from, subject_to, subjects_dir, reg_from,
_normalize_vectors(to_rr)

# from surface: get nearest neighbors, find triangles for each vertex
nn_pts_idx = _compute_nearest(from_rr, to_rr)
nn_pts_idx = _compute_nearest(from_rr, to_rr, method='cKDTree')
from_pt_tris = _triangle_neighbors(from_tri, len(from_rr))
from_pt_tris = [from_pt_tris[pt_idx] for pt_idx in nn_pts_idx]
from_pt_tris = [from_pt_tris[pt_idx].astype(int) for pt_idx in nn_pts_idx]
from_pt_lens = np.cumsum([0] + [len(x) for x in from_pt_tris])
from_pt_tris = np.concatenate(from_pt_tris)
assert from_pt_tris.ndim == 1
assert from_pt_lens[-1] == len(from_pt_tris)

# find triangle in which point lies and assoc. weights
tri_inds = []
weights = []
tri_geom = _get_tri_supp_geom(dict(rr=from_rr, tris=from_tri))
for pt_tris, to_pt in zip(from_pt_tris, to_rr):
p, q, idx, dist = _find_nearest_tri_pt(to_pt, tri_geom, pt_tris,
run_all=False)
tri_inds.append(idx)
weights.append([1. - (p + q), p, q])
weights, tri_inds = _find_nearest_tri_pts(
to_rr, from_pt_tris, from_pt_lens, run_all=False, reproject=False,
**tri_geom)

nn_idx = from_tri[tri_inds]
weights = np.array(weights)
Expand All @@ -1160,7 +1158,10 @@ def _make_morph_map_hemi(subject_from, subject_to, subjects_dir, reg_from,
return this_map


def _find_nearest_tri_pt(rr, tri_geom, pt_tris=None, run_all=True):
@jit(parallel=True)
def _find_nearest_tri_pts(rrs, pt_triss, pt_lens,
a, b, c, nn, r1, r12, r13, r1213, mat,
run_all=True, reproject=False):
"""Find nearest point mapping to a set of triangles.
If run_all is False, if the point lies within a triangle, it stops.
Expand All @@ -1179,60 +1180,79 @@ def _find_nearest_tri_pt(rr, tri_geom, pt_tris=None, run_all=True):
# qq = (aas * v2s - ccs * v1s) / dets
# pqs = np.array(pp, qq)

# This einsum is equivalent to doing:
# pqs = np.array([np.dot(x, y) for x, y in zip(r1213, r1-to_pt)])
if pt_tris is None: # use all points
pt_tris = slice(len(tri_geom['r1']))
rrs = rr - tri_geom['r1'][pt_tris]
tri_nn = tri_geom['nn'][pt_tris]
vect = einsum('ijk,ik->ij', tri_geom['r1213'][pt_tris], rrs)
mats = tri_geom['mat'][pt_tris]
# This einsum is equivalent to doing:
# pqs = np.array([np.dot(m, v) for m, v in zip(mats, vect)]).T
pqs = einsum('ijk,ik->ji', mats, vect)
found = False
dists = np.sum(rrs * tri_nn, axis=1)

# There can be multiple (sadness), find closest
idx = np.where(np.all(pqs >= 0., axis=0))[0]
idx = idx[np.where(np.all(pqs[:, idx] <= 1., axis=0))[0]]
idx = idx[np.where(np.sum(pqs[:, idx], axis=0) < 1.)[0]]
dist = np.inf
if len(idx) > 0:
found = True
pt = idx[np.argmin(np.abs(dists[idx]))]
p, q = pqs[:, pt]
dist = dists[pt]
weights = np.empty((len(rrs), 3))
tri_idx = np.empty(len(rrs), np.int64)
for ri in prange(len(rrs)):
rr = np.reshape(rrs[ri], (1, 3))
start, stop = pt_lens[ri:ri + 2]
if start == stop == 0: # use all
drs = rr - r1
tri_nn = nn
mats = mat
r1213s = r1213
reindex = False
else:
pt_tris = pt_triss[start:stop]
drs = rr - r1[pt_tris]
tri_nn = nn[pt_tris]
mats = mat[pt_tris]
r1213s = r1213[pt_tris]
reindex = True
use = np.ones(len(drs), np.int64)
pqs = np.empty((len(drs), 2))
dists = np.empty(len(drs))
dist = np.inf
found = False
for ii in range(len(drs)):
pqs[ii] = np.dot(mats[ii], np.dot(r1213s[ii], drs[ii]))
dists[ii] = np.dot(drs[ii], tri_nn[ii])
pp, qq = pqs[ii]
if pp >= 0 and qq >= 0 and pp <= 1 and qq <= 1 and pp + qq < 1:
found = True
use[ii] = False
if np.abs(dists[ii]) < np.abs(dist):
p, q, pt, dist = pp, qq, ii, dists[ii]
# re-reference back to original numbers
if not isinstance(pt_tris, slice):
if found and reindex:
pt = pt_tris[pt]

if found is False or run_all is True:
# don't include ones that we might have found before
# these are the ones that we want to check thesides of
s = np.setdiff1d(np.arange(dists.shape[0]), idx)
# Tough: must investigate the sides
use_pt_tris = s if isinstance(pt_tris, slice) else pt_tris[s]
pp, qq, ptt, distt = _nearest_tri_edge(use_pt_tris, rr, pqs[:, s],
dists[s], tri_geom)
if np.abs(distt) < np.abs(dist):
p, q, pt, dist = pp, qq, ptt, distt
return p, q, pt, dist
if not found or run_all:
# don't include ones that we might have found before
# these are the ones that we want to check the sides of
s = np.where(use)[0]
# Tough: must investigate the sides
if reindex:
use_pt_tris = pt_tris[s].astype(np.int64)
else:
use_pt_tris = s.astype(np.int64)
pp, qq, ptt, distt = _nearest_tri_edge(
use_pt_tris, rr[0], pqs[s], dists[s], a, b, c)
if np.abs(distt) < np.abs(dist):
p, q, pt, dist = pp, qq, ptt, distt
w = (1 - p - q, p, q)
if reproject:
# Calculate a linear interpolation between the vertex values to
# get coords of pt projected onto closest triangle
coords = _triangle_coords(rr[0], pt, r1, nn, r12, r13, a, b, c)
w = (1. - coords[0] - coords[1], coords[0], coords[1])
weights[ri] = w
tri_idx[ri] = pt
return weights, tri_idx


def _nearest_tri_edge(pt_tris, to_pt, pqs, dist, tri_geom):
@jit()
def _nearest_tri_edge(pt_tris, to_pt, pqs, dist, a, b, c):
"""Get nearest location from a point to the edge of a set of triangles."""
# We might do something intelligent here. However, for now
# it is ok to do it in the hard way
aa = tri_geom['a'][pt_tris]
bb = tri_geom['b'][pt_tris]
cc = tri_geom['c'][pt_tris]
pp = pqs[0]
qq = pqs[1]
aa = a[pt_tris]
bb = b[pt_tris]
cc = c[pt_tris]
pp = pqs[:, 0]
qq = pqs[:, 1]
# Find the nearest point from a triangle:
# Side 1 -> 2
p0 = np.minimum(np.maximum(pp + 0.5 * (qq * cc) / aa,
0.0), 1.0)
p0 = np.minimum(np.maximum(pp + 0.5 * (qq * cc) / aa, 0.0), 1.0)
q0 = np.zeros_like(p0)
# Side 2 -> 3
t1 = (0.5 * ((2.0 * aa - cc) * (1.0 - pp) +
Expand All @@ -1248,9 +1268,9 @@ def _nearest_tri_edge(pt_tris, to_pt, pqs, dist, tri_geom):
dist0 = _get_tri_dist(pp, qq, p0, q0, aa, bb, cc, dist)
dist1 = _get_tri_dist(pp, qq, p1, q1, aa, bb, cc, dist)
dist2 = _get_tri_dist(pp, qq, p2, q2, aa, bb, cc, dist)
pp = np.r_[p0, p1, p2]
qq = np.r_[q0, q1, q2]
dists = np.r_[dist0, dist1, dist2]
pp = np.concatenate((p0, p1, p2))
qq = np.concatenate((q0, q1, q2))
dists = np.concatenate((dist0, dist1, dist2))
ii = np.argmin(np.abs(dists))
p, q, pt, dist = pp[ii], qq[ii], pt_tris[ii % len(pt_tris)], dists[ii]
return p, q, pt, dist
Expand Down
7 changes: 5 additions & 2 deletions mne/tests/test_surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,12 @@ def test_make_morph_maps():
('fsaverage_ds', 'sample_ds', False),
('fsaverage_ds', 'fsaverage_ds', True)):
# trigger the creation of morph-maps dir and create the map
with pytest.warns(None):
with catch_logging() as log:
mmap = read_morph_map(subject_from, subject_to, tempdir,
xhemi=xhemi)
xhemi=xhemi, verbose=True)
log = log.getvalue()
assert 'does not exist' in log
assert 'Creating' in log
mmap2 = read_morph_map(subject_from, subject_to, subjects_dir,
xhemi=xhemi)
assert_equal(len(mmap), len(mmap2))
Expand Down
Loading

0 comments on commit 9351930

Please sign in to comment.