diff --git a/doc/changes/devel/12555.newfeature.rst b/doc/changes/devel/12555.newfeature.rst new file mode 100644 index 00000000000..1b740088a85 --- /dev/null +++ b/doc/changes/devel/12555.newfeature.rst @@ -0,0 +1 @@ +Add diffusion imaging to :ref:`tut-working-with-seeg` including adding :meth:`mne.viz.Brain.add_streamline` and :meth:`mne.viz.Brain.remove_streamlines` to visualize the fiber tracts, by `Alex Rockhill`_. \ No newline at end of file diff --git a/mne/datasets/config.py b/mne/datasets/config.py index a2f2d7781b7..597921c7383 100644 --- a/mne/datasets/config.py +++ b/mne/datasets/config.py @@ -90,7 +90,7 @@ # here: ↓↓↓↓↓↓↓↓ RELEASES = dict( testing="0.152", - misc="0.27", + misc="0.29", phantom_kit="0.2", ucl_opm_auditory="0.2", ) @@ -131,7 +131,7 @@ ) MNE_DATASETS["misc"] = dict( archive_name=f"{MISC_VERSIONED}.tar.gz", # 'mne-misc-data', - hash="md5:e343d3a00cb49f8a2f719d14f4758afe", + hash="md5:19535192331d9e4e99d8886028c1f447", url=( "https://codeload.github.com/mne-tools/mne-misc-data/tar.gz/" f'{RELEASES["misc"]}' diff --git a/mne/viz/_brain/_brain.py b/mne/viz/_brain/_brain.py index 7d5c858815e..bd773aa747e 100644 --- a/mne/viz/_brain/_brain.py +++ b/mne/viz/_brain/_brain.py @@ -225,6 +225,8 @@ class Brain: +-------------------------------------+--------------+---------------+ | :meth:`add_skull` | | ✓ | +-------------------------------------+--------------+---------------+ + | :meth:`add_streamline` | | ✓ | + +-------------------------------------+--------------+---------------+ | :meth:`add_text` | ✓ | ✓ | +-------------------------------------+--------------+---------------+ | :meth:`add_volume_labels` | | ✓ | @@ -253,6 +255,8 @@ class Brain: +-------------------------------------+--------------+---------------+ | :meth:`remove_skull` | | ✓ | +-------------------------------------+--------------+---------------+ + | :meth:`remove_streamlines` | | ✓ | + +-------------------------------------+--------------+---------------+ | :meth:`remove_text` | | ✓ | +-------------------------------------+--------------+---------------+ | :meth:`remove_volume_labels` | | ✓ | @@ -2505,6 +2509,66 @@ def remove_skull(self): """Remove skull objects from the rendered scene.""" self._remove("skull", render=True) + @fill_doc + def add_streamline( + self, + streamline, + line_width=1, + color="red", + scalars=None, + colormap=None, + vmin=None, + vmax=None, + alpha=1, + ): + """Add a streamlines to render fiber tracts. + + Parameters + ---------- + streamline : array shape=(n_points, 3) + An array with 3D points forming a line in units of m. + line_width : int + The width of the line. + color : list + A list with entries of anything matplotlib accepts: + string, RGB, hex, etc. + scalars : list + A list of scalar values associated with each vertex of + the streamline. + %(colormap)s + vmin : None | float + The minimum value for color scaling. + vmax : None | float + The maximum value for color scaling. + %(alpha)s + + Notes + ----- + .. versionadded:: 0.24 + """ + color = _to_rgb(color) + + for _ in self._iter_views("vol"): + actor, _ = self._renderer.line( + streamline * (1e3 if self._units == "mm" else 1), + color=color, + opacity=alpha, + line_width=line_width, + scalars=scalars, + colormap=colormap, + vmin=vmin, + vmax=vmax, + reset_camera=False, + render=False, + ) + self._add_actor("streamlines", actor) + + self._renderer._update() + + def remove_streamlines(self): + """Remove streamline objects from the rendered scene.""" + self._remove("streamlines", render=True) + @fill_doc def add_volume_labels( self, diff --git a/mne/viz/backends/_abstract.py b/mne/viz/backends/_abstract.py index c31023401ed..f8a06383ca7 100644 --- a/mne/viz/backends/_abstract.py +++ b/mne/viz/backends/_abstract.py @@ -194,6 +194,66 @@ def mesh( """ pass + @abstractclassmethod + def line( + self, + line, + colors, + opacity=1.0, + backface_culling=False, + scalars=None, + colormap=None, + vmin=None, + vmax=None, + interpolate_before_map=True, + line_width=1.0, + polygon_offset=None, + **kwargs, + ): + """Add a mesh in the scene. + + Parameters + ---------- + line : list + A list of array-like points defining the lines. + color : tuple | str + The color of the mesh as a tuple (red, green, blue) of float + values between 0 and 1 or a valid color name (i.e. 'white' + or 'w'). + opacity : float + The opacity of the mesh. + shading : bool + If True, enable the mesh shading. + backface_culling : bool + If True, enable backface culling on the mesh. + scalars : ndarray, shape (n_vertices,) + The scalar valued associated to the vertices. + vmin : float | None + vmin is used to scale the colormap. + If None, the min of the data will be used. + vmax : float | None + vmax is used to scale the colormap. + If None, the max of the data will be used. + colormap : str | np.ndarray | matplotlib.colors.Colormap | None + The colormap to use. + interpolate_before_map : + Enabling makes for a smoother scalars display. Default is True. + When False, OpenGL will interpolate the mapped colors which can + result is showing colors that are not present in the color map. + line_width : int + The width of the line. + polygon_offset : float + If not None, the factor used to resolve coincident topology. + kwargs : args + The arguments to pass to triangular_mesh + + Returns + ------- + line : + Handle of the line in the scene. + """ + pass + @abstractclassmethod def contour( self, diff --git a/mne/viz/backends/_pyvista.py b/mne/viz/backends/_pyvista.py index 4f395344f41..784b8adb5f2 100644 --- a/mne/viz/backends/_pyvista.py +++ b/mne/viz/backends/_pyvista.py @@ -455,6 +455,37 @@ def mesh( **kwargs, ) + def line( + self, + line, + color, + opacity=1.0, + backface_culling=False, + scalars=None, + colormap=None, + vmin=None, + vmax=None, + interpolate_before_map=True, + line_width=1.0, + polygon_offset=None, + **kwargs, + ): + return self.polydata( + mesh=pyvista.MultipleLines(line), + color=color, + opacity=opacity, + normals=np.zeros_like(np.array(line)), + backface_culling=backface_culling, + scalars=scalars, + colormap=colormap, + vmin=vmin, + vmax=vmax, + interpolate_before_map=interpolate_before_map, + line_width=line_width, + polygon_offset=polygon_offset, + **kwargs, + ) + def contour( self, surface, diff --git a/mne/viz/backends/tests/test_renderer.py b/mne/viz/backends/tests/test_renderer.py index b20bb6e4865..5691d977186 100644 --- a/mne/viz/backends/tests/test_renderer.py +++ b/mne/viz/backends/tests/test_renderer.py @@ -101,6 +101,9 @@ def test_3d_backend(renderer): txt_text = "renderer" txt_size = 14 + line_data = [[0, 0, 0], [1, 1, 1]] + line_color = "red" + cam_distance = 5 * tet_size # init scene @@ -123,6 +126,14 @@ def test_3d_backend(renderer): ) rend.remove_mesh(mesh_data) + # use lines + line_actor = rend.line( + line_data, + color=line_color, + line_width=5, + ) + rend.remove_mesh(line_actor) + # use contour rend.contour( surface=ct_surface, scalars=ct_scalars, contours=ct_levels, kind="line" diff --git a/tutorials/clinical/20_seeg.py b/tutorials/clinical/20_seeg.py index ea56ea8a688..942df22e74d 100644 --- a/tutorials/clinical/20_seeg.py +++ b/tutorials/clinical/20_seeg.py @@ -38,7 +38,20 @@ # %% +import matplotlib.pyplot as plt +import nibabel as nib import numpy as np +from dipy.core.gradients import gradient_table +from dipy.data import default_sphere +from dipy.denoise.patch2self import patch2self +from dipy.direction import DeterministicMaximumDirectionGetter +from dipy.direction.peaks import peaks_from_model +from dipy.reconst.dti import TensorModel +from dipy.segment.mask import median_otsu +from dipy.tracking.local_tracking import LocalTracking +from dipy.tracking.stopping_criterion import ThresholdStoppingCriterion +from dipy.tracking.streamline import Streamlines +from dipy.tracking.utils import seeds_from_mask import mne from mne.datasets import fetch_fsaverage @@ -157,6 +170,138 @@ fig, ax = mne.viz.plot_channel_labels_circle(labels, colors, picks=picks) fig.text(0.3, 0.9, "Anatomical Labels", color="white") +# %% +# For electrode contacts in white matter, it can be helpful to visualize +# fiber tracts that pass nearby as well. For that we need to do fiber +# tracking on diffusion MR data. + +# load the diffusion MR data +dwi = nib.load(misc_path / "seeg" / "sample_seeg_dwi.nii.gz") +bvals = np.loadtxt(misc_path / "seeg" / "sample_seeg_dwi.bval") +bvecs = np.loadtxt(misc_path / "seeg" / "sample_seeg_dwi.bvec") +gtab = gradient_table(bvals, bvecs) + +# use B0 diffusion data to align with the T1 +b0_idx = tuple(np.where(bvals < 50)[0]) +dwi_masked, mask = median_otsu(np.array(dwi.dataobj), vol_idx=b0_idx) + +fig, ax = plt.subplots() +ax.imshow(np.rot90(dwi_masked[65, ..., 0]), aspect="auto") + +t1 = nib.load(misc_path / "seeg" / "sample_seeg" / "mri" / "T1.mgz") +dwi_b0_register = nib.Nifti1Image(dwi_masked[..., b0_idx].mean(axis=-1), dwi.affine) + +# %% +# The code below was run once to find the registration matrix, but to +# save computer resources when building the documentation, we won't +# run it every time:: +# +# reg_affine = mne.transforms.compute_volume_registration( +# moving=dwi_b0_register, static=t1, pipeline='rigids') + +reg_affine = np.array( + [ + [0.99804908, -0.05071631, 0.03641263, 1.36631239], + [0.049687, 0.99835418, 0.0286378, 36.79845134], + [-0.03780511, -0.02677269, 0.99892642, 8.30634414], + [0.0, 0.0, 0.0, 1.0], + ] +) +reg_affine_inv = np.linalg.inv(reg_affine) + +# use registration to move the white matter mask computed +# by freesurfer to the diffusion space +wm = nib.load(misc_path / "seeg" / "sample_seeg" / "mri" / "wm.mgz") +wm_data = np.array(wm.dataobj) +wm_mask = (wm_data == 109) | (wm_data == 110) # white matter values +wm = nib.MGHImage(wm_mask.astype(np.float32), wm.affine) +del wm_data, wm_mask + +# apply the backward registration by using the inverse +wm_dwi = mne.transforms.apply_volume_registration( + moving=wm, static=dwi_b0_register, reg_affine=reg_affine_inv +) + +# check that white matter is aligned properly +fig, ax = plt.subplots() +ax.imshow(np.rot90(dwi_b0_register.dataobj[56]), aspect="auto") +ax.imshow(np.rot90(wm_dwi.dataobj[56]), aspect="auto", cmap="hot", alpha=0.5) + +# now, preprocess the diffusion data to remove noise and do +# fiber tracking +denoised = patch2self(dwi_masked, bvals) + +# %% +# Optionally, you can also remove Gibbs artifact:: +# +# from dipy.denoise.gibbs import gibbs_removal +# denoised = gibbs_removal(denoised) + +# %% +# You may also want to do the following, but it registers each direction +# of the diffusion image to the T1, so it takes a lot of computational +# resources so we'll skip it for now:: +# +# from dipy.align import motion_correction +# denoised, _ = motion_correction(denoised, gtab, dwi.affine, b0_ref=0) +# denoised = np.array(denoised.dataobj) + +# compute diffusion tensor imaging to find the peak direction +# for each voxel +tenmodel = TensorModel(gtab) +tenfit = tenmodel.fit(denoised) +pam = peaks_from_model( + tenmodel, + denoised, + default_sphere, + relative_peak_threshold=0.5, + min_separation_angle=25, + mask=wm_dwi.dataobj, + legacy=False, +) + +# do fiber tracking +stopping_criterion = ThresholdStoppingCriterion( + pam.gfa, # use generalized fractional anisotropy from the DTI model + 0.25, # threshold for stopping is when FA goes below 0.25 (default) +) +dg = DeterministicMaximumDirectionGetter.from_shcoeff( + pam.shm_coeff, # use spherical harmonic coefficients from the DTI model + max_angle=30.0, # max angle fiber can change at each voxel + sphere=default_sphere, # use default sphere + sh_to_pmf=True, # speeds up computations, takes more memory + legacy=False, # use newer version +) +# use the white matter mask to seed where the fibers start, +# with 1 mm density in all three dimensions +seeds = seeds_from_mask(wm_dwi.dataobj, dwi.affine, density=(1, 1, 1)) +# generate streamlines to represent tracts using the stopping +# criteria, direction getter and seeds +streamline_generator = LocalTracking( + dg, stopping_criterion, seeds, dwi.affine, step_size=0.5 +) +streamlines = Streamlines(streamline_generator) + +# move streamlines from diffusion space to T1 anatomical space, +# only keep long streamlines +streamlines = [ + mne.transforms.apply_trans(reg_affine_inv, streamline) + for streamline in streamlines + if len(streamline) > 10 +] + +# now convert from scanner RAS to surface RAS +ras2mri = mne.transforms.combine_transforms( + mne.transforms.Transform("ras", "mri_voxel", t1.header.get_ras2vox()), + mne.transforms.Transform("mri_voxel", "mri", t1.header.get_vox2ras_tkr()), + fro="ras", + to="mri", +) +streamlines = [ + mne.transforms.apply_trans(ras2mri, streamline) / 1000 # mm -> m + for streamline in streamlines +] + # %% # Now, let's the electrodes and a few regions of interest that the contacts # of the electrode are proximal to. @@ -191,6 +336,22 @@ figure=fig, ) brain.add_volume_labels(aseg="aparc+aseg", labels=labels) + +# find streamlines near LSMA1 +montage = epochs.get_montage() +montage.apply_trans(mne.transforms.invert_transform(trans)) # head -> mri +ch_pos = montage.get_positions()["ch_pos"] + +thresh = 0.03 # pick streamlines within 30 mm +streamlines_pick = [ + streamline + for streamline in streamlines + if np.linalg.norm(streamline - ch_pos["LPM 1"]).min() < thresh +] + +for streamline in streamlines_pick: + brain.add_streamline(streamline, color="white") + brain.show_view(azimuth=120, elevation=90, distance=0.25) # %%