Skip to content

Commit

Permalink
Update visualisation methods
Browse files Browse the repository at this point in the history
  • Loading branch information
wtclarke committed Jun 26, 2024
1 parent 42be86d commit 1664082
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 66 deletions.
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
This document contains the nifti_mrs_tools release history in reverse chronological order.

1.2.2 (WIP)
-------------------------------
1.2.2 (Wednesday 26th June 2024)
--------------------------------
- Added `--full-hdr` argumet to `mrs_tools info` which enables printing of the full header extension.
- Improved NIfTI-MRS object inspection in python interface.
- Added `.plot()` method to `NIFTI_MRS` objects. This matches the behaviour of `mrs_tools vis`.

1.2.1 (Thursday 4th April 2024)
-------------------------------
Expand Down
80 changes: 16 additions & 64 deletions src/mrs_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,72 +187,15 @@ def vis(args):
:type args: Namespace
"""
try:
from fsl_mrs.utils.plotting import plot_spectrum, plot_spectra
from fsl_mrs.utils.mrs_io import read_FID, read_basis
from fsl_mrs.utils.mrs_io import read_basis
from fsl_mrs.utils.mrs_io.main import FileNotRecognisedError
import matplotlib.pyplot as plt
from fsl_mrs.utils.preproc import nifti_mrs_proc
except ImportError:
raise ImportError(
"mrs_tools vis requires FSL-MRS tools to be installed. "
"See fsl-mrs.com for installation instructions.")

import numpy as np
import nibabel as nib

# Single nifti file
def vis_nifti_mrs(file):
data = read_FID(file)

if data.ndim > 4 \
and 'DIM_COIL' in data.dim_tags\
and args.display_dim != 'DIM_COIL':
print('Performing coil combination')
data = nifti_mrs_proc.coilcombine(data)

def average_dim_if_multiple(dd, dim):
"""Averages a dimension if non-singleton"""
if dim is None:
# Protect against loss of dimension during process.
return dd
if dd.shape[dd.dim_position(dim)] > 1:
print(f'Averaging {dim}')
return nifti_mrs_proc.average(dd, dim)
else:
return dd

if np.prod(data.shape[:3]) == 1:
# SVS
if args.display_dim:
for dim in data.dim_tags:
if dim is None:
continue
if dim != args.display_dim:
data = average_dim_if_multiple(data, dim)
fig = plot_spectra(data.mrs(), ppmlim=args.ppmlim, plot_avg=args.no_mean)

else:
for dim in data.dim_tags:
data = average_dim_if_multiple(data, dim)
fig = plot_spectrum(data.mrs(), ppmlim=args.ppmlim)

if args.save is not None:
fig.savefig(args.save)
else:
plt.show()

else:
for dim in data.dim_tags:
data = average_dim_if_multiple(data, dim)

mrsi = data.mrs()
if args.mask is not None:
mask_hdr = nib.load(args.mask)
mask = np.asanyarray(mask_hdr.dataobj)
if mask.ndim == 2:
mask = np.expand_dims(mask, 2)
mrsi.set_mask(mask)
mrsi.plot(ppmlim=args.ppmlim)
from nifti_mrs.nifti_mrs import NIFTI_MRS

# Some logic to figure out what we are dealing with
p = args.file
Expand Down Expand Up @@ -281,18 +224,27 @@ def average_dim_if_multiple(dd, dim):
' NIFTI-MRS file, not a directory (unless'
' it contains basis files).')

elif p.is_file():
vis_nifti_mrs(p)

else:
try:
vis_nifti_mrs(p)
except FileNotRecognisedError as exc:
data = NIFTI_MRS(p)
except FileNotFoundError as exc:
raise FileNotFoundError(
f"No file or directory '{p}' found."
" Please specify correct file extension (e.g. nii.gz) if there is one.")\
from exc

fig = data.plot(
display_dim=args.display_dim,
ppmlim=args.ppmlim,
plot_avg=args.no_mean,
mask=args.mask,
legend=True)

if args.save is not None:
fig.savefig(args.save)
else:
plt.show()


def merge(args):
"""Merges one or more NIfTI-MRS files along a specified dimension
Expand Down
10 changes: 10 additions & 0 deletions src/nifti_mrs/nifti_mrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,3 +654,13 @@ def convert_to_tuples(dict_list):
tvar_array = np.asarray(tvar_tuple, dtype=object).reshape(np.prod(self.shape[4:]), len(tvar_tuple[0]))

return tvar_dict2, tvar_tuple2.reshape(self.shape[4:]), tvar_array

def plot(self, display_dim=None, ppmlim=None, plot_avg=False, mask=None, legend=True):
from nifti_mrs.vis import vis_nifti_mrs
return vis_nifti_mrs(
self,
display_dim=display_dim,
ppmlim=ppmlim,
plot_avg=plot_avg,
mask=mask,
legend=legend)
88 changes: 88 additions & 0 deletions src/nifti_mrs/vis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
try:
from fsl_mrs.utils.plotting import plot_spectrum, plot_spectra
from fsl_mrs.core.mrs import MRS
from fsl_mrs.core.mrsi import MRSI
from fsl_mrs.utils.preproc import nifti_mrs_proc
from fsl_mrs.utils import constants
except ImportError:
raise ImportError(
"NIfTI-MRS visualisation requires FSL-MRS tools to be installed. "
"See fsl-mrs.com for installation instructions.")

import numpy as np
import nibabel as nib


def vis_nifti_mrs(data, display_dim=None, ppmlim=None, plot_avg=False, mask=None, legend=True):

if ppmlim is None:
nuc_info = constants.nucleus_constants(data.nucleus[0])
if nuc_info.ppm_range:
ppmlim = nuc_info.ppm_range
else:
ppmlim = (None, None)

if data.ndim > 4 \
and 'DIM_COIL' in data.dim_tags\
and display_dim != 'DIM_COIL':
print('Performing coil combination')
data = nifti_mrs_proc.coilcombine(data)

def average_dim_if_multiple(dd, dim):
"""Averages a dimension if non-singleton"""
if dim is None:
# Protect against loss of dimension during process.
return dd
if dd.shape[dd.dim_position(dim)] > 1:
print(f'Averaging {dim}')
return nifti_mrs_proc.average(dd, dim)
else:
return dd

if np.prod(data.shape[:3]) == 1:
# SVS
if display_dim:
for dim in data.dim_tags:
if dim is None:
continue
if dim != display_dim:
data = average_dim_if_multiple(data, dim)
mrs = []
for fid, _ in data.iterate_over_dims():
mrs.append(
MRS(
fid.squeeze(),
bw=data.bandwidth,
cf=data.spectrometer_frequency[0],
nucleus=data.nucleus[0]))
fig = plot_spectra(mrs, ppmlim=ppmlim, plot_avg=plot_avg, legend=legend)

else:
for dim in data.dim_tags:
data = average_dim_if_multiple(data, dim)
mrs = MRS(
data[:].squeeze(),
bw=data.bandwidth,
cf=data.spectrometer_frequency[0],
nucleus=data.nucleus[0])
fig = plot_spectrum(mrs, ppmlim=ppmlim)

return fig

else:
for dim in data.dim_tags:
data = average_dim_if_multiple(data, dim)

mrsi = MRSI(
data[:],
bw=data.bandwidth,
cf=data.spectrometer_frequency[0],
nucleus=data.nucleus[0])

if mask is not None:
mask_hdr = nib.load(mask)
mask = np.asanyarray(mask_hdr.dataobj)
if mask.ndim == 2:
mask = np.expand_dims(mask, 2)
mrsi.set_mask(mask)
mrsi.plot(ppmlim=ppmlim)

0 comments on commit 1664082

Please sign in to comment.