Skip to content

Commit

Permalink
ENH: Don't require specific order for fNIRS (#10642)
Browse files Browse the repository at this point in the history
* MAINT: Refactor with noops

* ENH: Order agnostic

* FIX: Missed one

* CI: Ping

* FIX: Fix nominal

* FIX: Marks

* FIX: Flake

* FIX: Fixes

Co-authored-by: Eric Larson <larsoner@monolith.local>
  • Loading branch information
larsoner and Eric Larson authored May 23, 2022
1 parent a62f9a4 commit f435846
Show file tree
Hide file tree
Showing 18 changed files with 336 additions and 256 deletions.
6 changes: 3 additions & 3 deletions mne/channels/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,11 +207,11 @@ def _interpolate_bads_nirs(inst, method='nearest', exclude=(), verbose=None):
from scipy.spatial.distance import pdist, squareform
from mne.preprocessing.nirs import _validate_nirs_info

# Returns pick of all nirs and ensures channels are correctly ordered
picks_nirs = _validate_nirs_info(inst.info)
if len(picks_nirs) == 0:
if len(pick_types(inst.info, fnirs=True, exclude=())) == 0:
return

# Returns pick of all nirs and ensures channels are correctly ordered
picks_nirs = _validate_nirs_info(inst.info)
nirs_ch_names = [inst.info['ch_names'][p] for p in picks_nirs]
nirs_ch_names = [ch for ch in nirs_ch_names if ch not in exclude]
bads_nirs = [ch for ch in inst.info['bads'] if ch in nirs_ch_names]
Expand Down
25 changes: 24 additions & 1 deletion mne/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from mne.coreg import create_default_subject
from mne.datasets import testing
from mne.fixes import has_numba, _compare_version
from mne.io import read_raw_fif, read_raw_ctf
from mne.io import read_raw_fif, read_raw_ctf, read_raw_nirx, read_raw_snirf
from mne.stats import cluster_level
from mne.utils import (_pl, _assert_no_instances, numerics, Bunch,
_check_qt_version, _TempDir)
Expand All @@ -48,6 +48,17 @@
ctf_dir = op.join(test_path, 'CTF')
fname_ctf_continuous = op.join(ctf_dir, 'testdata_ctf.ds')

nirx_path = test_path / 'NIRx'
snirf_path = test_path / 'SNIRF'
nirsport2 = nirx_path / 'nirsport_v2' / 'aurora_recording _w_short_and_acc'
nirsport2_snirf = (
snirf_path / 'NIRx' / 'NIRSport2' / '1.0.3' /
'2021-05-05_001.snirf')
nirsport2_2021_9 = nirx_path / 'nirsport_v2' / 'aurora_2021_9'
nirsport2_20219_snirf = (
snirf_path / 'NIRx' / 'NIRSport2' / '2021.9' /
'2021-10-01_002.snirf')

# data from mne.io.tests.data
base_dir = op.join(op.dirname(__file__), 'io', 'tests', 'data')
fname_raw_io = op.join(base_dir, 'test_raw.fif')
Expand Down Expand Up @@ -925,3 +936,15 @@ def run(nbexec=nbexec, code=code):

item.runtest = run
return


@pytest.mark.filterwarnings('ignore:.*Extraction of measurement.*:')
@pytest.fixture(params=(
[nirsport2, nirsport2_snirf, testing._pytest_param()],
[nirsport2_2021_9, nirsport2_20219_snirf, testing._pytest_param()],
))
def nirx_snirf(request):
"""Return a (raw_nirx, raw_snirf) matched pair."""
pytest.importorskip('h5py')
return (read_raw_nirx(request.param[0], preload=True),
read_raw_snirf(request.param[1], preload=True))
13 changes: 8 additions & 5 deletions mne/fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@
# Lars Buitinck <L.J.Buitinck@uva.nl>
# License: BSD

import functools
import inspect
from math import log
import os
from pathlib import Path
import warnings

import numpy as np
Expand Down Expand Up @@ -72,10 +70,15 @@ def _median_complex(data, axis):


# helpers to get function arguments
def _get_args(function, varargs=False):
def _get_args(function, varargs=False, *,
exclude=('var_positional', 'var_keyword')):
params = inspect.signature(function).parameters
args = [key for key, param in params.items()
if param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)]
# As of Python 3.10:
# https://docs.python.org/3/library/inspect.html#inspect.Parameter.kind
# POSITIONAL_ONLY, POSITIONAL_OR_KEYWORD, VAR_POSITIONAL, KEYWORD_ONLY,
# VAR_KEYWORD
exclude = set(getattr(inspect.Parameter, ex.upper()) for ex in exclude)
args = [key for key, param in params.items() if param.kind not in exclude]
if varargs:
varargs = [param.name for param in params.values()
if param.kind == param.VAR_POSITIONAL]
Expand Down
20 changes: 0 additions & 20 deletions mne/io/nirx/nirx.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from ..utils import _mult_cal_one
from ..constants import FIFF
from ..meas_info import create_info, _format_dig_points
from ..pick import pick_types
from ...annotations import Annotations
from ..._freesurfer import get_mni_fiducials
from ...transforms import apply_trans, _get_trans
Expand Down Expand Up @@ -459,7 +458,6 @@ def __init__(self, fname, saturated, preload=False, verbose=None):
ch_names.append(list())
annot = Annotations(onset, duration, description, ch_names=ch_names)
self.set_annotations(annot)
self.pick(picks=_nirs_sort_idx(self.info))

def _read_segment_file(self, data, idx, fi, start, stop, cals, mult):
"""Read a segment of data from a file.
Expand Down Expand Up @@ -512,21 +510,3 @@ def _convert_fnirs_to_head(trans, fro, to, src_locs, det_locs, ch_locs):
det_locs = apply_trans(mri_head_t, det_locs)
ch_locs = apply_trans(mri_head_t, ch_locs)
return src_locs, det_locs, ch_locs, mri_head_t


def _nirs_sort_idx(info):
# TODO: Remove any actual reordering that is done and just use this
# function to get picks to operate on in an ordered way. This should be
# done by refactoring mne.preprocessing.nirs.nirs._check_channels_ordered
# and this function to make sure the picks we obtain here are in the
# correct order.
nirs_picks = pick_types(info, fnirs=True, exclude=())
other_picks = np.setdiff1d(np.arange(info['nchan']), nirs_picks)
prefixes = [info['ch_names'][pick].split()[0] for pick in nirs_picks]
nirs_names = [info['ch_names'][pick] for pick in nirs_picks]
nirs_sorted = sorted(nirs_names,
key=lambda name: (prefixes.index(name.split()[0]),
name.split(maxsplit=1)[1]))
nirs_picks = nirs_picks[
[nirs_names.index(name) for name in nirs_sorted]]
return np.concatenate((nirs_picks, other_picks))
27 changes: 6 additions & 21 deletions mne/io/nirx/tests/test_nirx.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@

from mne import pick_types
from mne.datasets.testing import data_path, requires_testing_data
from mne.io import read_raw_nirx, read_raw_snirf
from mne.utils import requires_h5py
from mne.io import read_raw_nirx
from mne.io.tests.test_raw import _test_raw_reader
from mne.preprocessing import annotate_nan
from mne.transforms import apply_trans, _get_trans
from mne.preprocessing.nirs import source_detector_distances,\
short_channels
short_channels, _reorder_nirx
from mne.io.constants import FIFF

testing_path = data_path(download=False)
Expand All @@ -46,31 +45,17 @@
testing_path, 'NIRx', 'nirsport_v1', 'nirx_15_3_recording_w_'
'saturation_on_montage_channels')

# NIRSport2 device using Aurora software and matching snirf file
# NIRSport2 device using Aurora software
nirsport2 = op.join(
testing_path, 'NIRx', 'nirsport_v2', 'aurora_recording _w_short_and_acc')
nirsport2_snirf = op.join(
testing_path, 'SNIRF', 'NIRx', 'NIRSport2', '1.0.3',
'2021-05-05_001.snirf')

nirsport2_2021_9 = op.join(
testing_path, 'NIRx', 'nirsport_v2', 'aurora_2021_9')
snirf_nirsport2_20219 = op.join(
testing_path, 'SNIRF', 'NIRx', 'NIRSport2', '2021.9',
'2021-10-01_002.snirf')


@requires_h5py
@requires_testing_data
@pytest.mark.filterwarnings('ignore:.*Extraction of measurement.*:')
@pytest.mark.parametrize('fname_nirx, fname_snirf', (
[nirsport2, nirsport2_snirf],
[nirsport2_2021_9, snirf_nirsport2_20219],
))
def test_nirsport_v2_matches_snirf(fname_nirx, fname_snirf):
def test_nirsport_v2_matches_snirf(nirx_snirf):
"""Test NIRSport2 raw files return same data as snirf."""
raw = read_raw_nirx(fname_nirx, preload=True)
raw_snirf = read_raw_snirf(fname_snirf, preload=True)
raw, raw_snirf = nirx_snirf
_reorder_nirx(raw_snirf)
assert raw.ch_names == raw_snirf.ch_names

assert_allclose(raw._data, raw_snirf._data)
Expand Down
6 changes: 1 addition & 5 deletions mne/io/snirf/_snirf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ..constants import FIFF
from .._digitization import _make_dig_points
from ...transforms import _frame_to_str, apply_trans
from ..nirx.nirx import _convert_fnirs_to_head, _nirs_sort_idx
from ..nirx.nirx import _convert_fnirs_to_head
from ..._freesurfer import get_mni_fiducials


Expand Down Expand Up @@ -409,10 +409,6 @@ def natural_keys(text):
annot.append(data[:, 0], 1.0, desc.decode('UTF-8'))
self.set_annotations(annot, emit_warning=False)

# MNE requires channels are paired as alternating wavelengths
if len(_validate_nirs_info(self.info, throw_errors=False)) == 0:
self.pick(picks=_nirs_sort_idx(self.info))

# Validate that the fNIRS info is correctly formatted
_validate_nirs_info(self.info)

Expand Down
46 changes: 22 additions & 24 deletions mne/io/snirf/tests/test_snirf.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from mne.io import read_raw_snirf, read_raw_nirx
from mne.io.tests.test_raw import _test_raw_reader
from mne.preprocessing.nirs import (optical_density, beer_lambert_law,
short_channels, source_detector_distances)
short_channels, source_detector_distances,
_reorder_nirx)
from mne.transforms import apply_trans, _get_trans
from mne.io.constants import FIFF

Expand Down Expand Up @@ -93,23 +94,18 @@ def test_snirf_gowerlabs():
assert len(raw.ch_names) == 216
assert_allclose(raw.info['sfreq'], 10.0)
# we don't force them to be sorted according to a naive split
# (but we do force them to be interleaved, which is tested by beer_lambert
# above)
assert raw.ch_names != sorted(raw.ch_names)
# ... and this file does have a nice logical ordering already
# ... but this file does have a nice logical ordering already
print(raw.ch_names)
assert raw.ch_names == sorted(
raw.ch_names, # use a key which is (source int, detector int)
key=lambda name: (int(name.split()[0].split('_')[0][1:]),
int(name.split()[0].split('_')[1][1:])))
prefixes = [name.split()[0] for name in raw.ch_names]
# TODO: This is actually not the order on disk -- we reorder to ravel as
# S-D then freq, but gowerlabs order is freq then S-D. So hopefully soon
# we can change these lines to check that the first half of prefixes
# matches the second half of prefixes, rather than every-other matching the
# other every-other
assert prefixes[::2] == prefixes[1::2]
prefixes = prefixes[::2]
assert prefixes == ['S1_D1', 'S1_D2', 'S1_D3', 'S1_D4', 'S1_D5', 'S1_D6', 'S1_D7', 'S1_D8', 'S1_D9', 'S1_D10', 'S1_D11', 'S1_D12', 'S2_D1', 'S2_D2', 'S2_D3', 'S2_D4', 'S2_D5', 'S2_D6', 'S2_D7', 'S2_D8', 'S2_D9', 'S2_D10', 'S2_D11', 'S2_D12', 'S3_D1', 'S3_D2', 'S3_D3', 'S3_D4', 'S3_D5', 'S3_D6', 'S3_D7', 'S3_D8', 'S3_D9', 'S3_D10', 'S3_D11', 'S3_D12', 'S4_D1', 'S4_D2', 'S4_D3', 'S4_D4', 'S4_D5', 'S4_D6', 'S4_D7', 'S4_D8', 'S4_D9', 'S4_D10', 'S4_D11', 'S4_D12', 'S5_D1', 'S5_D2', 'S5_D3', 'S5_D4', 'S5_D5', 'S5_D6', 'S5_D7', 'S5_D8', 'S5_D9', 'S5_D10', 'S5_D11', 'S5_D12', 'S6_D1', 'S6_D2', 'S6_D3', 'S6_D4', 'S6_D5', 'S6_D6', 'S6_D7', 'S6_D8', 'S6_D9', 'S6_D10', 'S6_D11', 'S6_D12', 'S7_D1', 'S7_D2', 'S7_D3', 'S7_D4', 'S7_D5', 'S7_D6', 'S7_D7', 'S7_D8', 'S7_D9', 'S7_D10', 'S7_D11', 'S7_D12', 'S8_D1', 'S8_D2', 'S8_D3', 'S8_D4', 'S8_D5', 'S8_D6', 'S8_D7', 'S8_D8', 'S8_D9', 'S8_D10', 'S8_D11', 'S8_D12', 'S9_D1', 'S9_D2', 'S9_D3', 'S9_D4', 'S9_D5', 'S9_D6', 'S9_D7', 'S9_D8', 'S9_D9', 'S9_D10', 'S9_D11', 'S9_D12'] # noqa: E501
raw.ch_names,
# use a key which is (src triplet, freq, src, freq, det)
key=lambda name: (
(int(name.split()[0].split('_')[0][1:]) - 1) // 3,
int(name.split()[1]),
int(name.split()[0].split('_')[0][1:]),
int(name.split()[0].split('_')[1][1:])
))


@requires_testing_data
Expand All @@ -122,13 +118,13 @@ def test_snirf_basic():
assert raw.info['sfreq'] == 12.5

# Test channel naming
assert raw.info['ch_names'][:4] == ["S1_D1 760", "S1_D1 850",
"S1_D9 760", "S1_D9 850"]
assert raw.info['ch_names'][24:26] == ["S5_D13 760", "S5_D13 850"]
assert raw.info['ch_names'][:4] == ["S1_D1 760", "S1_D9 760",
"S2_D3 760", "S2_D10 760"]
assert raw.info['ch_names'][24:26] == ['S5_D8 850', 'S5_D13 850']

# Test frequency encoding
assert raw.info['chs'][0]['loc'][9] == 760
assert raw.info['chs'][1]['loc'][9] == 850
assert raw.info['chs'][24]['loc'][9] == 850

# Test source locations
assert_allclose([-8.6765 * 1e-2, 0.0049 * 1e-2, -2.6167 * 1e-2],
Expand Down Expand Up @@ -159,6 +155,7 @@ def test_snirf_basic():
def test_snirf_against_nirx():
"""Test against file snirf was created from."""
raw = read_raw_snirf(sfnirs_homer_103_wShort, preload=True)
_reorder_nirx(raw)
raw_orig = read_raw_nirx(sfnirs_homer_103_wShort_original, preload=True)

# Check annotations are the same
Expand Down Expand Up @@ -225,13 +222,13 @@ def test_snirf_nirsport2():
assert_almost_equal(raw.info['sfreq'], 7.6, decimal=1)

# Test channel naming
assert raw.info['ch_names'][:4] == ['S1_D1 760', 'S1_D1 850',
'S1_D3 760', 'S1_D3 850']
assert raw.info['ch_names'][24:26] == ['S6_D4 760', 'S6_D4 850']
assert raw.info['ch_names'][:4] == ['S1_D1 760', 'S1_D3 760',
'S1_D9 760', 'S1_D16 760']
assert raw.info['ch_names'][24:26] == ['S8_D15 760', 'S8_D20 760']

# Test frequency encoding
assert raw.info['chs'][0]['loc'][9] == 760
assert raw.info['chs'][1]['loc'][9] == 850
assert raw.info['chs'][-1]['loc'][9] == 850

assert sum(short_channels(raw.info)) == 16

Expand All @@ -257,6 +254,7 @@ def test_snirf_nirsport2_w_positions():
"""Test reading SNIRF files with known positions."""
raw = read_raw_snirf(nirx_nirsport2_103_2, preload=True,
optode_frame="mri")
_reorder_nirx(raw)

# Test data import
assert raw._data.shape == (40, 128)
Expand Down
6 changes: 5 additions & 1 deletion mne/preprocessing/ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,13 @@ def get_score_funcs():
score_funcs.update({n: _make_xy_sfunc(f)
for n, f in xy_arg_dist_funcs
if _get_args(f) == ['u', 'v']})
# In SciPy 1.9+, pearsonr has (u, v, *, alternative='two-sided'), so we
# should just look at the positional_only and positional_or_keyword entries
exclude = ('var_positional', 'var_keyword', 'keyword_only')
score_funcs.update({n: _make_xy_sfunc(f, ndim_output=True)
for n, f in xy_arg_stats_funcs
if _get_args(f) == ['x', 'y']})
if _get_args(f, exclude=exclude) == ['x', 'y']})
assert 'pearsonr' in score_funcs
return score_funcs


Expand Down
5 changes: 3 additions & 2 deletions mne/preprocessing/nirs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@

from .nirs import (short_channels, source_detector_distances,
_check_channels_ordered, _channel_frequencies,
_fnirs_check_bads, _fnirs_spread_bads, _channel_chromophore,
_validate_nirs_info, _fnirs_optode_names, _optode_position)
_fnirs_spread_bads, _channel_chromophore,
_validate_nirs_info, _fnirs_optode_names, _optode_position,
_reorder_nirx)
from ._optical_density import optical_density
from ._beer_lambert_law import beer_lambert_law
from ._scalp_coupling_index import scalp_coupling_index
Expand Down
20 changes: 10 additions & 10 deletions mne/preprocessing/nirs/_beer_lambert_law.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
from ...io import BaseRaw
from ...io.constants import FIFF
from ...utils import _validate_type, warn
from ..nirs import source_detector_distances, _channel_frequencies,\
_check_channels_ordered, _channel_chromophore
from ..nirs import source_detector_distances, _validate_nirs_info


def beer_lambert_law(raw, ppf=6.):
Expand All @@ -35,8 +34,10 @@ def beer_lambert_law(raw, ppf=6.):
_validate_type(raw, BaseRaw, 'raw')
_validate_type(ppf, 'numeric', 'ppf')
ppf = float(ppf)
freqs = np.unique(_channel_frequencies(raw.info, nominal=True))
picks = _check_channels_ordered(raw.info, freqs)
picks = _validate_nirs_info(raw.info, fnirs='od', which='Beer-lambert')
# This is the one place we *really* need the actual/accurate frequencies
freqs = np.array(
[raw.info['chs'][pick]['loc'][9] for pick in picks], float)
abs_coef = _load_absorption(freqs)
distances = source_detector_distances(raw.info)
if (distances == 0).any():
Expand All @@ -49,25 +50,24 @@ def beer_lambert_law(raw, ppf=6.):
'likely due to optode locations being stored in a '
' unit other than meters.')
rename = dict()
for ii in picks[::2]:
for ii, jj in zip(picks[::2], picks[1::2]):
EL = abs_coef * distances[ii] * ppf
iEL = linalg.pinv(EL)

raw._data[[ii, ii + 1]] = iEL @ raw._data[[ii, ii + 1]] * 1e-3
raw._data[[ii, jj]] = iEL @ raw._data[[ii, jj]] * 1e-3

# Update channel information
coil_dict = dict(hbo=FIFF.FIFFV_COIL_FNIRS_HBO,
hbr=FIFF.FIFFV_COIL_FNIRS_HBR)
for ki, kind in enumerate(('hbo', 'hbr')):
ch = raw.info['chs'][ii + ki]
for ki, kind in zip((ii, jj), ('hbo', 'hbr')):
ch = raw.info['chs'][ki]
ch.update(coil_type=coil_dict[kind], unit=FIFF.FIFF_UNIT_MOL)
new_name = f'{ch["ch_name"].split(" ")[0]} {kind}'
rename[ch['ch_name']] = new_name
raw.rename_channels(rename)

# Validate the format of data after transformation is valid
chroma = np.unique(_channel_chromophore(raw.info))
_check_channels_ordered(raw.info, chroma)
_validate_nirs_info(raw.info, fnirs='hb')
return raw


Expand Down
Loading

0 comments on commit f435846

Please sign in to comment.