Skip to content

Commit

Permalink
MAINT: Sanitize name lists
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner committed Jan 2, 2023
1 parent 9b1e7dd commit 48f86eb
Show file tree
Hide file tree
Showing 9 changed files with 85 additions and 64 deletions.
2 changes: 1 addition & 1 deletion doc/changes/latest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Enhancements

Bugs
~~~~
- None yet
- Fix bug where channel names were not properly sanitized in :func:`mne.write_evokeds` and related functions (:gh:`11399` by `Eric Larson`_)

API changes
~~~~~~~~~~~
Expand Down
42 changes: 15 additions & 27 deletions mne/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
_check_fname, int_like, _check_option, fill_doc,
_on_missing, _is_numeric, _check_dict_keys)

from .io.write import (start_block, end_block, write_float, write_name_list,
from .io.write import (start_block, end_block, write_float,
write_name_list_sanitized, _safe_name_list,
write_double, start_file, write_string)
from .io.constants import FIFF
from .io.open import fiff_open
Expand Down Expand Up @@ -53,7 +54,7 @@ def _check_o_d_s_c(onset, duration, description, ch_names):
if description.ndim != 1:
raise ValueError('Description must be a one dimensional array, '
'got %d.' % (description.ndim,))
_prep_name_list(description, 'check', 'description')
_safe_name_list(description, 'write', 'description')

# ch_names: convert to ndarray of tuples
_validate_type(ch_names, (None, tuple, list, np.ndarray), 'ch_names')
Expand Down Expand Up @@ -986,31 +987,14 @@ def _annotations_starts_stops(raw, kinds, name='skip_by_annotation',
return onsets, ends


def _prep_name_list(lst, operation, name='description'):
if operation == 'check':
if any(['{COLON}' in val for val in lst]):
raise ValueError(
f'The substring "{{COLON}}" in {name} not supported.')
elif operation == 'write':
# take a list of strings and return a sanitized string
return ':'.join(val.replace(':', '{COLON}') for val in lst)
else:
# take a sanitized string and return a list of strings
assert operation == 'read'
assert isinstance(lst, str)
if not len(lst):
return []
return [val.replace('{COLON}', ':') for val in lst.split(':')]


def _write_annotations(fid, annotations):
"""Write annotations."""
start_block(fid, FIFF.FIFFB_MNE_ANNOTATIONS)
write_float(fid, FIFF.FIFF_MNE_BASELINE_MIN, annotations.onset)
write_float(fid, FIFF.FIFF_MNE_BASELINE_MAX,
annotations.duration + annotations.onset)
write_name_list(fid, FIFF.FIFF_COMMENT, _prep_name_list(
annotations.description, 'write').split(':'))
write_name_list_sanitized(
fid, FIFF.FIFF_COMMENT, annotations.description, name='description')
if annotations.orig_time is not None:
write_double(fid, FIFF.FIFF_MEAS_DATE,
_dt_to_stamp(annotations.orig_time))
Expand All @@ -1024,7 +1008,8 @@ def _write_annotations_csv(fname, annot):
annot = annot.to_data_frame()
if 'ch_names' in annot:
annot['ch_names'] = [
_prep_name_list(ch, 'write') for ch in annot['ch_names']]
_safe_name_list(ch, 'write', name=f'annot["ch_names"][{ci}')
for ci, ch in enumerate(annot['ch_names'])]
annot.to_csv(fname, index=False)


Expand All @@ -1037,7 +1022,9 @@ def _write_annotations_txt(fname, annot):
data = [annot.onset, annot.duration, annot.description]
if annot._any_ch_names():
content += ', ch_names'
data.append([_prep_name_list(ch, 'write') for ch in annot.ch_names])
data.append([
_safe_name_list(ch, 'write', f'annot.ch_names[{ci}]')
for ci, ch in enumerate(annot.ch_names)])
content += '\n'
data = np.array(data, dtype=str).T
assert data.ndim == 2
Expand Down Expand Up @@ -1178,7 +1165,7 @@ def _read_annotations_csv(fname):
description = df['description'].values
ch_names = None
if 'ch_names' in df.columns:
ch_names = [_prep_name_list(val, 'read')
ch_names = [_safe_name_list(val, 'read', 'annotation channel name')
for val in df['ch_names'].values]
return Annotations(onset, duration, description, orig_time, ch_names)

Expand Down Expand Up @@ -1261,8 +1248,9 @@ def _read_annotations_txt(fname):
duration = [float(d.decode()) for d in np.atleast_1d(duration)]
desc = [str(d.decode()).strip() for d in np.atleast_1d(desc)]
if ch_names is not None:
ch_names = [_prep_name_list(ch.decode().strip(), 'read')
for ch in ch_names]
ch_names = [
_safe_name_list(ch.decode().strip(), 'read', f'ch_names[{ci}]')
for ci, ch in enumerate(ch_names)]
return onset, duration, desc, ch_names


Expand All @@ -1286,7 +1274,7 @@ def _read_annotations_fif(fid, tree):
duration = tag.data
duration = list() if duration is None else duration - onset
elif kind == FIFF.FIFF_COMMENT:
description = _prep_name_list(tag.data, 'read')
description = _safe_name_list(tag.data, 'read', 'description')
elif kind == FIFF.FIFF_MEAS_DATE:
orig_time = tag.data
try:
Expand Down
17 changes: 8 additions & 9 deletions mne/cov.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@
_DATA_CH_TYPES_SPLIT)

from .io.constants import FIFF
from .io.meas_info import _read_bad_channels, create_info
from .io.meas_info import _read_bad_channels, create_info, _write_bad_channels
from .io.tag import find_tag
from .io.tree import dir_tree_find
from .io.write import (start_block, end_block, write_int, write_name_list,
write_double, write_float_matrix, write_string)
from .io.write import (start_block, end_block, write_int, write_double,
write_float_matrix, write_string, _safe_name_list,
write_name_list_sanitized)
from .defaults import _handle_default
from .epochs import Epochs
from .event import make_fixed_length_events
Expand Down Expand Up @@ -1965,7 +1966,7 @@ def _read_cov(fid, node, cov_kind, limited=False, verbose=None):
if tag is None:
names = []
else:
names = tag.data.split(':')
names = _safe_name_list(tag.data, 'read', 'names')
if len(names) != dim:
raise ValueError('Number of names does not match '
'covariance matrix dimension')
Expand Down Expand Up @@ -2048,7 +2049,8 @@ def _write_cov(fid, cov):

# Channel names
if cov['names'] is not None and len(cov['names']) > 0:
write_name_list(fid, FIFF.FIFF_MNE_ROW_NAMES, cov['names'])
write_name_list_sanitized(
fid, FIFF.FIFF_MNE_ROW_NAMES, cov['names'], 'cov["names"]')

# Data
if cov['diag']:
Expand All @@ -2070,10 +2072,7 @@ def _write_cov(fid, cov):
_write_proj(fid, cov['projs'])

# Bad channels
if cov['bads'] is not None and len(cov['bads']) > 0:
start_block(fid, FIFF.FIFFB_MNE_BAD_CHANNELS)
write_name_list(fid, FIFF.FIFF_MNE_CH_NAME_LIST, cov['bads'])
end_block(fid, FIFF.FIFFB_MNE_BAD_CHANNELS)
_write_bad_channels(fid, cov['bads'], None)

# estimator method
if 'method' in cov:
Expand Down
10 changes: 3 additions & 7 deletions mne/forward/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,10 @@
write_named_matrix)
from ..io.meas_info import (_read_bad_channels, write_info, _write_ch_infos,
_read_extended_ch_info, _make_ch_names_mapping,
_rename_list)
_write_bad_channels)
from ..io.pick import (pick_channels_forward, pick_info, pick_channels,
pick_types)
from ..io.write import (write_int, start_block, end_block,
write_coord_trans, write_name_list,
from ..io.write import (write_int, start_block, end_block, write_coord_trans,
write_string, start_and_end_file, write_id)
from ..io.base import BaseRaw
from ..evoked import Evoked, EvokedArray
Expand Down Expand Up @@ -1030,10 +1029,7 @@ def write_forward_meas_info(fid, info):
_write_ch_infos(fid, info['chs'], False, ch_names_mapping)
if 'bads' in info and len(info['bads']) > 0:
# Bad channels
bads = _rename_list(info['bads'], ch_names_mapping)
start_block(fid, FIFF.FIFFB_MNE_BAD_CHANNELS)
write_name_list(fid, FIFF.FIFF_MNE_CH_NAME_LIST, bads)
end_block(fid, FIFF.FIFFB_MNE_BAD_CHANNELS)
_write_bad_channels(fid, info['bads'], ch_names_mapping)

end_block(fid, FIFF.FIFFB_MNE_PARENT_MEAS_FILE)

Expand Down
25 changes: 16 additions & 9 deletions mne/io/meas_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@
from .ctf_comp import _read_ctf_comp, write_ctf_comp
from .write import (start_and_end_file, start_block, end_block,
write_string, write_dig_points, write_float, write_int,
write_coord_trans, write_ch_info, write_name_list,
write_julian, write_float_matrix, write_id, DATE_NONE)
write_coord_trans, write_ch_info,
write_julian, write_float_matrix, write_id, DATE_NONE,
_safe_name_list, write_name_list_sanitized)
from .proc_history import _read_proc_history, _write_proc_history
from ..transforms import (invert_transform, Transform, _coord_frame_name,
_ensure_trans, _frame_to_str)
Expand Down Expand Up @@ -1361,11 +1362,21 @@ def _read_bad_channels(fid, node, ch_names_mapping):
for node in nodes:
tag = find_tag(fid, node, FIFF.FIFF_MNE_CH_NAME_LIST)
if tag is not None and tag.data is not None:
bads = tag.data.split(':')
bads[:] = _rename_list(bads, ch_names_mapping)
bads = _safe_name_list(tag.data, 'read', 'bads')
bads[:] = _rename_list(bads, ch_names_mapping)
return bads


def _write_bad_channels(fid, bads, ch_names_mapping):
if bads is not None and len(bads) > 0:
ch_names_mapping = {} if ch_names_mapping is None else ch_names_mapping
bads = _rename_list(bads, ch_names_mapping)
start_block(fid, FIFF.FIFFB_MNE_BAD_CHANNELS)
write_name_list_sanitized(
fid, FIFF.FIFF_MNE_CH_NAME_LIST, bads, 'bads')
end_block(fid, FIFF.FIFFB_MNE_BAD_CHANNELS)


@verbose
def read_meas_info(fid, tree, clean_bads=False, verbose=None):
"""Read the measurement info.
Expand Down Expand Up @@ -2059,11 +2070,7 @@ def write_meas_info(fid, info, data_type=None, reset_range=True):
_write_proj(fid, info['projs'], ch_names_mapping=ch_names_mapping)

# Bad channels
if len(info['bads']) > 0:
bads = _rename_list(info['bads'], ch_names_mapping)
start_block(fid, FIFF.FIFFB_MNE_BAD_CHANNELS)
write_name_list(fid, FIFF.FIFF_MNE_CH_NAME_LIST, bads)
end_block(fid, FIFF.FIFFB_MNE_BAD_CHANNELS)
_write_bad_channels(fid, info['bads'], ch_names_mapping=ch_names_mapping)

# General
if info.get('experimenter') is not None:
Expand Down
14 changes: 7 additions & 7 deletions mne/io/proc_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from .tree import dir_tree_find
from .write import (start_block, end_block, write_int, write_float,
write_string, write_float_matrix, write_int_matrix,
write_float_sparse, write_id)
write_float_sparse, write_id, write_name_list_sanitized,
_safe_name_list)
from .tag import find_tag
from .constants import FIFF
from ..fixes import _csc_matrix_cast
Expand Down Expand Up @@ -225,10 +226,8 @@ def _read_maxfilter_record(fid, tree):
else:
if kind == FIFF.FIFF_PROJ_ITEM_CH_NAME_LIST:
tag = read_tag(fid, pos)
chs = tag.data.split(':')
# This list can null chars in the last entry, e.g.:
# [..., u'MEG2642', u'MEG2643', u'MEG2641\x00 ... \x00']
chs[-1] = chs[-1].split('\x00')[0]
tag.data = tag.data.rstrip('\x00')
chs = _safe_name_list(tag.data, 'read', 'proj_items_chs')
sss_ctc['proj_items_chs'] = chs

sss_cal_block = dir_tree_find(tree, FIFF.FIFFB_SSS_CAL) # 503
Expand Down Expand Up @@ -278,8 +277,9 @@ def _write_maxfilter_record(fid, record):
if key in sss_ctc:
writer(fid, id_, sss_ctc[key])
if 'proj_items_chs' in sss_ctc:
write_string(fid, FIFF.FIFF_PROJ_ITEM_CH_NAME_LIST,
':'.join(sss_ctc['proj_items_chs']))
write_name_list_sanitized(
fid, FIFF.FIFF_PROJ_ITEM_CH_NAME_LIST,
sss_ctc['proj_items_chs'], 'proj_items_chs')
end_block(fid, FIFF.FIFFB_CHANNEL_DECOUPLER)

sss_cal = record['sss_cal']
Expand Down
10 changes: 6 additions & 4 deletions mne/io/proj.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
from .pick import pick_types, pick_info, _electrode_types, _ELECTRODE_CH_TYPES
from .tag import find_tag, _rename_list
from .tree import dir_tree_find
from .write import (write_int, write_float, write_string, write_name_list,
write_float_matrix, end_block, start_block)
from .write import (write_int, write_float, write_string, write_float_matrix,
end_block, start_block, write_name_list_sanitized,
_safe_name_list)
from ..defaults import (_INTERPOLATION_DEFAULT, _BORDER_DEFAULT,
_EXTRAPOLATE_DEFAULT)
from ..utils import (logger, verbose, warn, fill_doc, _validate_type,
Expand Down Expand Up @@ -498,7 +499,7 @@ def _read_proj(fid, node, *, ch_names_mapping=None, verbose=None):

tag = find_tag(fid, item, FIFF.FIFF_PROJ_ITEM_CH_NAME_LIST)
if tag is not None:
names = tag.data.split(':')
names = _safe_name_list(tag.data, 'read', 'names')
else:
raise ValueError('Projection item channel list missing')

Expand Down Expand Up @@ -579,7 +580,8 @@ def _write_proj(fid, projs, *, ch_names_mapping=None):
start_block(fid, FIFF.FIFFB_PROJ_ITEM)
write_int(fid, FIFF.FIFF_NCHAN, len(proj['data']['col_names']))
names = _rename_list(proj['data']['col_names'], ch_names_mapping)
write_name_list(fid, FIFF.FIFF_PROJ_ITEM_CH_NAME_LIST, names)
write_name_list_sanitized(
fid, FIFF.FIFF_PROJ_ITEM_CH_NAME_LIST, names, 'col_names')
write_string(fid, FIFF.FIFF_NAME, proj['desc'])
write_int(fid, FIFF.FIFF_PROJ_ITEM_KIND, proj['kind'])
if proj['kind'] == FIFF.FIFFV_PROJ_ITEM_FIELD:
Expand Down
20 changes: 20 additions & 0 deletions mne/io/write.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,26 @@ def write_name_list(fid, kind, data):
write_string(fid, kind, ':'.join(data))


def write_name_list_sanitized(fid, kind, lst, name):
write_string(fid, kind, _safe_name_list(lst, 'write', name))


def _safe_name_list(lst, operation, name):
if operation == 'write':
assert isinstance(lst, (list, tuple, np.ndarray)), type(lst)
if any('{COLON}' in val for val in lst):
raise ValueError(
f'The substring "{{COLON}}" in {name} not supported.')
return ':'.join(val.replace(':', '{COLON}') for val in lst)
else:
# take a sanitized string and return a list of strings
assert operation == 'read'
assert lst is None or isinstance(lst, str)
if not lst: # None or empty string
return []
return [val.replace('{COLON}', ':') for val in lst.split(':')]


def write_float_matrix(fid, kind, mat):
"""Write a single-precision floating-point matrix tag."""
FIFFT_MATRIX = 1 << 30
Expand Down
9 changes: 9 additions & 0 deletions mne/tests/test_evoked.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,15 @@ def test_io_evoked(tmp_path):
aves = read_evokeds(fname_ms, allow_maxshield='yes')
assert (all(ave.info['maxshield'] is True for ave in aves))

# Channel names
with ave.info._unlock():
ave.info['maxshield'] = False
ave.rename_channels(lambda ch_name: ch_name.replace(' ', ':'))
assert ':' in ave.ch_names[0]
ave.save(fname_ms, overwrite=True)
ave6 = read_evokeds(fname_ms)[0]
assert ave.ch_names == ave6.ch_names


def test_shift_time_evoked(tmp_path):
"""Test for shifting of time scale."""
Expand Down

0 comments on commit 48f86eb

Please sign in to comment.