From 48f86eb50b145da8337a5eaea3a1cd955c613303 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Mon, 2 Jan 2023 11:38:12 -0500 Subject: [PATCH] MAINT: Sanitize name lists --- doc/changes/latest.inc | 2 +- mne/annotations.py | 42 ++++++++++++++-------------------------- mne/cov.py | 17 ++++++++-------- mne/forward/forward.py | 10 +++------- mne/io/meas_info.py | 25 +++++++++++++++--------- mne/io/proc_history.py | 14 +++++++------- mne/io/proj.py | 10 ++++++---- mne/io/write.py | 20 +++++++++++++++++++ mne/tests/test_evoked.py | 9 +++++++++ 9 files changed, 85 insertions(+), 64 deletions(-) diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index b93f79d673f..6c84ed1fec8 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -27,7 +27,7 @@ Enhancements Bugs ~~~~ -- None yet +- Fix bug where channel names were not properly sanitized in :func:`mne.write_evokeds` and related functions (:gh:`11399` by `Eric Larson`_) API changes ~~~~~~~~~~~ diff --git a/mne/annotations.py b/mne/annotations.py index b9dc0c2ccc4..78c66ab74ab 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -22,7 +22,8 @@ _check_fname, int_like, _check_option, fill_doc, _on_missing, _is_numeric, _check_dict_keys) -from .io.write import (start_block, end_block, write_float, write_name_list, +from .io.write import (start_block, end_block, write_float, + write_name_list_sanitized, _safe_name_list, write_double, start_file, write_string) from .io.constants import FIFF from .io.open import fiff_open @@ -53,7 +54,7 @@ def _check_o_d_s_c(onset, duration, description, ch_names): if description.ndim != 1: raise ValueError('Description must be a one dimensional array, ' 'got %d.' 
% (description.ndim,)) - _prep_name_list(description, 'check', 'description') + _safe_name_list(description, 'write', 'description') # ch_names: convert to ndarray of tuples _validate_type(ch_names, (None, tuple, list, np.ndarray), 'ch_names') @@ -986,31 +987,14 @@ def _annotations_starts_stops(raw, kinds, name='skip_by_annotation', return onsets, ends -def _prep_name_list(lst, operation, name='description'): - if operation == 'check': - if any(['{COLON}' in val for val in lst]): - raise ValueError( - f'The substring "{{COLON}}" in {name} not supported.') - elif operation == 'write': - # take a list of strings and return a sanitized string - return ':'.join(val.replace(':', '{COLON}') for val in lst) - else: - # take a sanitized string and return a list of strings - assert operation == 'read' - assert isinstance(lst, str) - if not len(lst): - return [] - return [val.replace('{COLON}', ':') for val in lst.split(':')] - - def _write_annotations(fid, annotations): """Write annotations.""" start_block(fid, FIFF.FIFFB_MNE_ANNOTATIONS) write_float(fid, FIFF.FIFF_MNE_BASELINE_MIN, annotations.onset) write_float(fid, FIFF.FIFF_MNE_BASELINE_MAX, annotations.duration + annotations.onset) - write_name_list(fid, FIFF.FIFF_COMMENT, _prep_name_list( - annotations.description, 'write').split(':')) + write_name_list_sanitized( + fid, FIFF.FIFF_COMMENT, annotations.description, name='description') if annotations.orig_time is not None: write_double(fid, FIFF.FIFF_MEAS_DATE, _dt_to_stamp(annotations.orig_time)) @@ -1024,7 +1008,8 @@ def _write_annotations_csv(fname, annot): annot = annot.to_data_frame() if 'ch_names' in annot: annot['ch_names'] = [ - _prep_name_list(ch, 'write') for ch in annot['ch_names']] + _safe_name_list(ch, 'write', name=f'annot["ch_names"][{ci}]') + for ci, ch in enumerate(annot['ch_names'])] annot.to_csv(fname, index=False) @@ -1037,7 +1022,9 @@ def _write_annotations_txt(fname, annot): data = [annot.onset, annot.duration, annot.description] if 
annot._any_ch_names(): content += ', ch_names' - data.append([_prep_name_list(ch, 'write') for ch in annot.ch_names]) + data.append([ + _safe_name_list(ch, 'write', f'annot.ch_names[{ci}]') + for ci, ch in enumerate(annot.ch_names)]) content += '\n' data = np.array(data, dtype=str).T assert data.ndim == 2 @@ -1178,7 +1165,7 @@ def _read_annotations_csv(fname): description = df['description'].values ch_names = None if 'ch_names' in df.columns: - ch_names = [_prep_name_list(val, 'read') + ch_names = [_safe_name_list(val, 'read', 'annotation channel name') for val in df['ch_names'].values] return Annotations(onset, duration, description, orig_time, ch_names) @@ -1261,8 +1248,9 @@ def _read_annotations_txt(fname): duration = [float(d.decode()) for d in np.atleast_1d(duration)] desc = [str(d.decode()).strip() for d in np.atleast_1d(desc)] if ch_names is not None: - ch_names = [_prep_name_list(ch.decode().strip(), 'read') - for ch in ch_names] + ch_names = [ + _safe_name_list(ch.decode().strip(), 'read', f'ch_names[{ci}]') + for ci, ch in enumerate(ch_names)] return onset, duration, desc, ch_names @@ -1286,7 +1274,7 @@ def _read_annotations_fif(fid, tree): duration = tag.data duration = list() if duration is None else duration - onset elif kind == FIFF.FIFF_COMMENT: - description = _prep_name_list(tag.data, 'read') + description = _safe_name_list(tag.data, 'read', 'description') elif kind == FIFF.FIFF_MEAS_DATE: orig_time = tag.data try: diff --git a/mne/cov.py b/mne/cov.py index 173c77d2e26..20ea2a4aee5 100644 --- a/mne/cov.py +++ b/mne/cov.py @@ -23,11 +23,12 @@ _DATA_CH_TYPES_SPLIT) from .io.constants import FIFF -from .io.meas_info import _read_bad_channels, create_info +from .io.meas_info import _read_bad_channels, create_info, _write_bad_channels from .io.tag import find_tag from .io.tree import dir_tree_find -from .io.write import (start_block, end_block, write_int, write_name_list, - write_double, write_float_matrix, write_string) +from .io.write import 
(start_block, end_block, write_int, write_double, + write_float_matrix, write_string, _safe_name_list, + write_name_list_sanitized) from .defaults import _handle_default from .epochs import Epochs from .event import make_fixed_length_events @@ -1965,7 +1966,7 @@ def _read_cov(fid, node, cov_kind, limited=False, verbose=None): if tag is None: names = [] else: - names = tag.data.split(':') + names = _safe_name_list(tag.data, 'read', 'names') if len(names) != dim: raise ValueError('Number of names does not match ' 'covariance matrix dimension') @@ -2048,7 +2049,8 @@ def _write_cov(fid, cov): # Channel names if cov['names'] is not None and len(cov['names']) > 0: - write_name_list(fid, FIFF.FIFF_MNE_ROW_NAMES, cov['names']) + write_name_list_sanitized( + fid, FIFF.FIFF_MNE_ROW_NAMES, cov['names'], 'cov["names"]') # Data if cov['diag']: @@ -2070,10 +2072,7 @@ def _write_cov(fid, cov): _write_proj(fid, cov['projs']) # Bad channels - if cov['bads'] is not None and len(cov['bads']) > 0: - start_block(fid, FIFF.FIFFB_MNE_BAD_CHANNELS) - write_name_list(fid, FIFF.FIFF_MNE_CH_NAME_LIST, cov['bads']) - end_block(fid, FIFF.FIFFB_MNE_BAD_CHANNELS) + _write_bad_channels(fid, cov['bads'], None) # estimator method if 'method' in cov: diff --git a/mne/forward/forward.py b/mne/forward/forward.py index b73f3542b76..40b546da5d7 100644 --- a/mne/forward/forward.py +++ b/mne/forward/forward.py @@ -27,11 +27,10 @@ write_named_matrix) from ..io.meas_info import (_read_bad_channels, write_info, _write_ch_infos, _read_extended_ch_info, _make_ch_names_mapping, - _rename_list) + _write_bad_channels) from ..io.pick import (pick_channels_forward, pick_info, pick_channels, pick_types) -from ..io.write import (write_int, start_block, end_block, - write_coord_trans, write_name_list, +from ..io.write import (write_int, start_block, end_block, write_coord_trans, write_string, start_and_end_file, write_id) from ..io.base import BaseRaw from ..evoked import Evoked, EvokedArray @@ -1030,10 +1029,7 @@ def 
write_forward_meas_info(fid, info): _write_ch_infos(fid, info['chs'], False, ch_names_mapping) if 'bads' in info and len(info['bads']) > 0: # Bad channels - bads = _rename_list(info['bads'], ch_names_mapping) - start_block(fid, FIFF.FIFFB_MNE_BAD_CHANNELS) - write_name_list(fid, FIFF.FIFF_MNE_CH_NAME_LIST, bads) - end_block(fid, FIFF.FIFFB_MNE_BAD_CHANNELS) + _write_bad_channels(fid, info['bads'], ch_names_mapping) end_block(fid, FIFF.FIFFB_MNE_PARENT_MEAS_FILE) diff --git a/mne/io/meas_info.py b/mne/io/meas_info.py index cf61176bee7..e46fc072772 100644 --- a/mne/io/meas_info.py +++ b/mne/io/meas_info.py @@ -30,8 +30,9 @@ from .ctf_comp import _read_ctf_comp, write_ctf_comp from .write import (start_and_end_file, start_block, end_block, write_string, write_dig_points, write_float, write_int, - write_coord_trans, write_ch_info, write_name_list, - write_julian, write_float_matrix, write_id, DATE_NONE) + write_coord_trans, write_ch_info, + write_julian, write_float_matrix, write_id, DATE_NONE, + _safe_name_list, write_name_list_sanitized) from .proc_history import _read_proc_history, _write_proc_history from ..transforms import (invert_transform, Transform, _coord_frame_name, _ensure_trans, _frame_to_str) @@ -1361,11 +1362,21 @@ def _read_bad_channels(fid, node, ch_names_mapping): for node in nodes: tag = find_tag(fid, node, FIFF.FIFF_MNE_CH_NAME_LIST) if tag is not None and tag.data is not None: - bads = tag.data.split(':') - bads[:] = _rename_list(bads, ch_names_mapping) + bads = _safe_name_list(tag.data, 'read', 'bads') + bads[:] = _rename_list(bads, ch_names_mapping) return bads +def _write_bad_channels(fid, bads, ch_names_mapping): + if bads is not None and len(bads) > 0: + ch_names_mapping = {} if ch_names_mapping is None else ch_names_mapping + bads = _rename_list(bads, ch_names_mapping) + start_block(fid, FIFF.FIFFB_MNE_BAD_CHANNELS) + write_name_list_sanitized( + fid, FIFF.FIFF_MNE_CH_NAME_LIST, bads, 'bads') + end_block(fid, FIFF.FIFFB_MNE_BAD_CHANNELS) + + 
@verbose def read_meas_info(fid, tree, clean_bads=False, verbose=None): """Read the measurement info. @@ -2059,11 +2070,7 @@ def write_meas_info(fid, info, data_type=None, reset_range=True): _write_proj(fid, info['projs'], ch_names_mapping=ch_names_mapping) # Bad channels - if len(info['bads']) > 0: - bads = _rename_list(info['bads'], ch_names_mapping) - start_block(fid, FIFF.FIFFB_MNE_BAD_CHANNELS) - write_name_list(fid, FIFF.FIFF_MNE_CH_NAME_LIST, bads) - end_block(fid, FIFF.FIFFB_MNE_BAD_CHANNELS) + _write_bad_channels(fid, info['bads'], ch_names_mapping=ch_names_mapping) # General if info.get('experimenter') is not None: diff --git a/mne/io/proc_history.py b/mne/io/proc_history.py index f8395d11aa2..1f47f4928b1 100644 --- a/mne/io/proc_history.py +++ b/mne/io/proc_history.py @@ -9,7 +9,8 @@ from .tree import dir_tree_find from .write import (start_block, end_block, write_int, write_float, write_string, write_float_matrix, write_int_matrix, - write_float_sparse, write_id) + write_float_sparse, write_id, write_name_list_sanitized, + _safe_name_list) from .tag import find_tag from .constants import FIFF from ..fixes import _csc_matrix_cast @@ -225,10 +226,8 @@ def _read_maxfilter_record(fid, tree): else: if kind == FIFF.FIFF_PROJ_ITEM_CH_NAME_LIST: tag = read_tag(fid, pos) - chs = tag.data.split(':') - # This list can null chars in the last entry, e.g.: - # [..., u'MEG2642', u'MEG2643', u'MEG2641\x00 ... 
\x00'] - chs[-1] = chs[-1].split('\x00')[0] + tag.data = tag.data.rstrip('\x00') + chs = _safe_name_list(tag.data, 'read', 'proj_items_chs') sss_ctc['proj_items_chs'] = chs sss_cal_block = dir_tree_find(tree, FIFF.FIFFB_SSS_CAL) # 503 @@ -278,8 +277,9 @@ def _write_maxfilter_record(fid, record): if key in sss_ctc: writer(fid, id_, sss_ctc[key]) if 'proj_items_chs' in sss_ctc: - write_string(fid, FIFF.FIFF_PROJ_ITEM_CH_NAME_LIST, - ':'.join(sss_ctc['proj_items_chs'])) + write_name_list_sanitized( + fid, FIFF.FIFF_PROJ_ITEM_CH_NAME_LIST, + sss_ctc['proj_items_chs'], 'proj_items_chs') end_block(fid, FIFF.FIFFB_CHANNEL_DECOUPLER) sss_cal = record['sss_cal'] diff --git a/mne/io/proj.py b/mne/io/proj.py index 0d6dcb72487..f3ed778e1bc 100644 --- a/mne/io/proj.py +++ b/mne/io/proj.py @@ -15,8 +15,9 @@ from .pick import pick_types, pick_info, _electrode_types, _ELECTRODE_CH_TYPES from .tag import find_tag, _rename_list from .tree import dir_tree_find -from .write import (write_int, write_float, write_string, write_name_list, - write_float_matrix, end_block, start_block) +from .write import (write_int, write_float, write_string, write_float_matrix, + end_block, start_block, write_name_list_sanitized, + _safe_name_list) from ..defaults import (_INTERPOLATION_DEFAULT, _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT) from ..utils import (logger, verbose, warn, fill_doc, _validate_type, @@ -498,7 +499,7 @@ def _read_proj(fid, node, *, ch_names_mapping=None, verbose=None): tag = find_tag(fid, item, FIFF.FIFF_PROJ_ITEM_CH_NAME_LIST) if tag is not None: - names = tag.data.split(':') + names = _safe_name_list(tag.data, 'read', 'names') else: raise ValueError('Projection item channel list missing') @@ -579,7 +580,8 @@ def _write_proj(fid, projs, *, ch_names_mapping=None): start_block(fid, FIFF.FIFFB_PROJ_ITEM) write_int(fid, FIFF.FIFF_NCHAN, len(proj['data']['col_names'])) names = _rename_list(proj['data']['col_names'], ch_names_mapping) - write_name_list(fid, FIFF.FIFF_PROJ_ITEM_CH_NAME_LIST, 
names) + write_name_list_sanitized( + fid, FIFF.FIFF_PROJ_ITEM_CH_NAME_LIST, names, 'col_names') write_string(fid, FIFF.FIFF_NAME, proj['desc']) write_int(fid, FIFF.FIFF_PROJ_ITEM_KIND, proj['kind']) if proj['kind'] == FIFF.FIFFV_PROJ_ITEM_FIELD: diff --git a/mne/io/write.py b/mne/io/write.py index aa56088d975..80104d98d0b 100644 --- a/mne/io/write.py +++ b/mne/io/write.py @@ -146,6 +146,26 @@ def write_name_list(fid, kind, data): write_string(fid, kind, ':'.join(data)) +def write_name_list_sanitized(fid, kind, lst, name): + write_string(fid, kind, _safe_name_list(lst, 'write', name)) + + +def _safe_name_list(lst, operation, name): + if operation == 'write': + assert isinstance(lst, (list, tuple, np.ndarray)), type(lst) + if any('{COLON}' in val for val in lst): + raise ValueError( + f'The substring "{{COLON}}" in {name} not supported.') + return ':'.join(val.replace(':', '{COLON}') for val in lst) + else: + # take a sanitized string and return a list of strings + assert operation == 'read' + assert lst is None or isinstance(lst, str) + if not lst: # None or empty string + return [] + return [val.replace('{COLON}', ':') for val in lst.split(':')] + + def write_float_matrix(fid, kind, mat): """Write a single-precision floating-point matrix tag.""" FIFFT_MATRIX = 1 << 30 diff --git a/mne/tests/test_evoked.py b/mne/tests/test_evoked.py index 058d759c0a5..d6be4be8e72 100644 --- a/mne/tests/test_evoked.py +++ b/mne/tests/test_evoked.py @@ -284,6 +284,15 @@ def test_io_evoked(tmp_path): aves = read_evokeds(fname_ms, allow_maxshield='yes') assert (all(ave.info['maxshield'] is True for ave in aves)) + # Channel names + with ave.info._unlock(): + ave.info['maxshield'] = False + ave.rename_channels(lambda ch_name: ch_name.replace(' ', ':')) + assert ':' in ave.ch_names[0] + ave.save(fname_ms, overwrite=True) + ave6 = read_evokeds(fname_ms)[0] + assert ave.ch_names == ave6.ch_names + def test_shift_time_evoked(tmp_path): """Test for shifting of time scale."""