Skip to content

Commit

Permalink
[MRG+1] Add function for random parcellation (mne-tools#4955)
Browse files Browse the repository at this point in the history
* Add functions to create random parcellation

* add test for random_parcellation

* add random_parcellation to python_reference

* fix pep8

* fix pep8 mne-tools#2

* fix pep8 mne-tools#3

* add random_state, small corrections

* fixes random_state

* FIX: Alphabetical
  • Loading branch information
makkostya authored and britta-wstnr committed Jul 5, 2018
1 parent bdde82f commit caba372
Show file tree
Hide file tree
Showing 4 changed files with 201 additions and 3 deletions.
1 change: 1 addition & 0 deletions doc/python_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,7 @@ Source Space Data
label_sign_flip
morph_data
morph_data_precomputed
random_parcellation
read_labels_from_annot
read_dipole
read_label
Expand Down
2 changes: 1 addition & 1 deletion mne/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
from .evoked import Evoked, EvokedArray, read_evokeds, write_evokeds, combine_evoked
from .label import (read_label, label_sign_flip,
write_label, stc_to_label, grow_labels, Label, split_label,
BiHemiLabel, read_labels_from_annot, write_labels_to_annot)
BiHemiLabel, read_labels_from_annot, write_labels_to_annot, random_parcellation)
from .misc import parse_config, read_reject_parameters
from .coreg import (create_default_subject, scale_bem, scale_mri, scale_labels,
scale_source_space)
Expand Down
160 changes: 159 additions & 1 deletion mne/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
import numpy as np
from scipy import linalg, sparse

from .utils import get_subjects_dir, _check_subject, logger, verbose, warn
from .utils import get_subjects_dir, _check_subject, logger, verbose, warn,\
check_random_state
from .source_estimate import (morph_data, SourceEstimate, _center_of_mass,
spatial_src_connectivity)
from .source_space import add_source_space_distances
Expand Down Expand Up @@ -1659,6 +1660,163 @@ def _grow_nonoverlapping_labels(subject, seeds_, extents_, hemis, vertices_,
return labels


def random_parcellation(subject, n_parcel, hemi, subjects_dir=None,
                        surface='white', random_state=None):
    """Generate random cortex parcellation by growing labels.

    This function generates a number of labels which don't intersect and
    cover the whole surface. Regions are growing around randomly chosen
    seeds.

    Parameters
    ----------
    subject : string
        Name of the subject as in SUBJECTS_DIR.
    n_parcel : int
        Total number of cortical parcels.
    hemi : str
        hemisphere id (ie 'lh', 'rh', 'both'). In the case
        of 'both', both hemispheres are processed with (n_parcel // 2)
        parcels per hemisphere.
    subjects_dir : string
        Path to SUBJECTS_DIR if not set in the environment.
    surface : string
        The surface used to grow the labels, defaults to the white surface.
    random_state : None | int | np.random.RandomState
        To specify the random generator state.

    Returns
    -------
    labels : list of Label
        Random cortex parcellation.
    """
    subjects_dir = get_subjects_dir(subjects_dir, raise_error=True)
    # normalize the hemi argument to a 1-d sequence of hemisphere ids
    if hemi == 'both':
        hemi = ['lh', 'rh']
    hemis = np.atleast_1d(hemi)

    # load the surfaces and create the distance graphs
    tris, vert, dist = {}, {}, {}
    for hemi in set(hemis):
        surf_fname = op.join(subjects_dir, subject, 'surf', hemi + '.' +
                             surface)
        vert[hemi], tris[hemi] = read_surface(surf_fname)
        # sparse vertex-to-vertex distance graph used for region growing
        dist[hemi] = mesh_dist(tris[hemi], vert[hemi])

    # create the patches
    labels = _cortex_parcellation(subject, n_parcel, hemis, vert, dist,
                                  random_state)

    # add a unique color to each label
    colors = _n_colors(len(labels))
    for label, color in zip(labels, colors):
        label.color = color

    return labels


def _cortex_parcellation(subject, n_parcel, hemis, vertices_, graphs,
                         random_state=None):
    """Random cortex parcellation.

    Grows parcels from random seed vertices until the surface is covered,
    then merges the smallest parcels until the requested count is reached.

    Parameters
    ----------
    subject : string
        Name of the subject (stored on each returned Label).
    n_parcel : int
        Total number of parcels over all hemispheres; each hemisphere
        receives n_parcel // len(hemis) parcels.
    hemis : array-like of str
        Hemisphere ids ('lh' / 'rh') to process.
    vertices_ : dict
        Hemisphere id -> surface vertices (only the count is used here).
    graphs : dict
        Hemisphere id -> sparse vertex distance graph.
    random_state : None | int | np.random.RandomState
        To specify the random generator state.

    Returns
    -------
    labels : list of Label
        Random cortex parcellation (names are 'label_<i>' per hemisphere).
    """
    labels = []
    rng = check_random_state(random_state)
    for hemi in set(hemis):
        # target number of vertices per parcel in this hemisphere
        parcel_size = len(hemis) * len(vertices_[hemi]) // n_parcel
        graph = graphs[hemi]  # distance graph
        n_vertices = len(vertices_[hemi])

        # prepare parcellation: parc[v] = parcel index of vertex v, -1 = free
        parc = np.full(n_vertices, -1, dtype='int32')

        # initialize active sources with a random seed vertex
        s = rng.choice(range(n_vertices))
        label_idx = 0
        edge = [s]  # queue of vertices to process
        parc[s] = label_idx
        label_size = 1
        rest = len(parc) - 1  # number of still-unassigned vertices
        # grow from sources
        while rest:
            # if there are not free neighbors, start new parcel
            if not edge:
                rest_idx = np.where(parc < 0)[0]
                s = rng.choice(rest_idx)
                edge = [s]
                label_idx += 1
                label_size = 1
                parc[s] = label_idx
                rest -= 1

            vert_from = edge.pop(0)

            # add neighbors within allowable distance
            row = graph[vert_from, :]
            for vert_to, dist in zip(row.indices, row.data):
                vert_to_label = parc[vert_to]

                # abort if the vertex is already occupied
                if vert_to_label >= 0:
                    continue

                # abort if outside of extent: current parcel is full, so
                # this vertex seeds a new parcel instead
                if label_size > parcel_size:
                    label_idx += 1
                    label_size = 1
                    edge = [vert_to]
                    parc[vert_to] = label_idx
                    rest -= 1
                    break

                # assign label value
                parc[vert_to] = label_idx
                label_size += 1
                edge.append(vert_to)
                rest -= 1

        # merging small labels
        # label connectivity matrix (which labels are surface neighbors)
        n_labels = label_idx + 1
        label_sizes = np.empty(n_labels, dtype=int)
        label_conn = np.zeros([n_labels, n_labels], dtype='bool')
        for i in range(n_labels):
            vertices = np.nonzero(parc == i)[0]
            label_sizes[i] = len(vertices)
            neighbor_vertices = graph[vertices, :].indices
            neighbor_labels = np.unique(np.array(parc[neighbor_vertices]))
            label_conn[i, neighbor_labels] = 1
        np.fill_diagonal(label_conn, 0)

        # merging: repeatedly fold the smallest label into its smallest
        # neighbor until the requested per-hemisphere count is reached
        label_id = np.arange(n_labels)
        while n_labels > n_parcel // len(hemis):
            # smallest label and its smallest neighbor
            i = np.argmin(label_sizes)
            neighbors = np.nonzero(label_conn[i, :])[0]
            j = neighbors[np.argmin(label_sizes[neighbors])]

            # merging two labels
            label_conn[j, :] += label_conn[i, :]
            label_conn[:, j] += label_conn[:, i]
            # clear the self-connection BEFORE deleting row/column i: if
            # i < j the deletion shifts j's position, so indexing
            # label_conn[j, j] afterwards would zero the wrong entry and
            # leave the merged label connected to itself
            label_conn[j, j] = 0
            label_conn = np.delete(label_conn, i, 0)
            label_conn = np.delete(label_conn, i, 1)
            label_sizes[j] += label_sizes[i]
            label_sizes = np.delete(label_sizes, i, 0)
            n_labels -= 1
            vertices = np.nonzero(parc == label_id[i])[0]
            parc[vertices] = label_id[j]
            label_id = np.delete(label_id, i, 0)

        # convert parc to labels (range, not the Python-2-only xrange)
        for i in range(n_labels):
            vertices = np.nonzero(parc == label_id[i])[0]
            name = 'label_' + str(i)
            label_ = Label(vertices, hemi=hemi, name=name, subject=subject)
            labels.append(label_)

    return labels


def _read_annot(fname):
"""Read a Freesurfer annotation from a .annot file.
Expand Down
41 changes: 40 additions & 1 deletion mne/tests/test_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from mne import (read_label, stc_to_label, read_source_estimate,
read_source_spaces, grow_labels, read_labels_from_annot,
write_labels_to_annot, split_label, spatial_tris_connectivity,
read_surface)
read_surface, random_parcellation)
from mne.label import Label, _blend_colors, label_sign_flip
from mne.utils import (_TempDir, requires_sklearn, get_subjects_dir,
run_tests_if_main)
Expand Down Expand Up @@ -768,6 +768,45 @@ def test_grow_labels():
assert_array_equal(l1.vertices, l0.vertices)


@testing.requires_testing_data
def test_random_parcellation():
    """Test generation of random cortical parcellation."""
    hemi = 'both'
    n_parcel = 50
    surface = 'white'
    subject = 'sample'
    rng = np.random.RandomState(0)

    # generate the parcellation
    labels = random_parcellation(subject, n_parcel, hemi, subjects_dir,
                                 surface=surface, random_state=rng)

    # exactly the requested number of labels must come back
    assert_equal(len(labels), n_parcel)
    if hemi == 'both':
        hemi = ['lh', 'rh']
    hemis = np.atleast_1d(hemi)
    for hemi in set(hemis):
        covered = []
        for label in labels:
            if label.hemi != hemi:
                continue
            # no label may be empty
            assert_true(len(label.vertices) > 0)

            # accumulate all vertices of this hemisphere covered by labels
            covered = np.append(covered, label.vertices)

        # labels must be pairwise disjoint (no vertex used twice)
        assert_equal(len(np.unique(covered)), len(covered))

        surf_fname = op.join(subjects_dir, subject, 'surf', hemi + '.' +
                             surface)
        vert, _ = read_surface(surf_fname)

        # together the labels must cover the whole surface
        assert_array_equal(np.sort(covered), np.arange(len(vert)))


@testing.requires_testing_data
def test_label_sign_flip():
"""Test label sign flip computation."""
Expand Down

0 comments on commit caba372

Please sign in to comment.