diff --git a/mne/io/base.py b/mne/io/base.py index 4c403a438e0..bacdd2699a8 100644 --- a/mne/io/base.py +++ b/mne/io/base.py @@ -16,13 +16,14 @@ import os.path as op import shutil from collections import defaultdict +from dataclasses import dataclass, field import numpy as np from ..filter import _check_resamp_noop from ..event import find_events, concatenate_events from .._fiff.constants import FIFF -from .._fiff.utils import _construct_bids_filename, _check_orig_units +from .._fiff.utils import _make_split_fnames, _check_orig_units from .._fiff.pick import ( pick_types, pick_channels, @@ -1650,11 +1651,11 @@ def save( endings_err = (".fif", ".fif.gz") # convert to str, check for overwrite a few lines later - fname = str(_check_fname(fname, overwrite=True, verbose="error")) + fname = _check_fname(fname, overwrite=True, verbose="error") check_fname(fname, "raw", endings, endings_err=endings_err) split_size = _get_split_size(split_size) - if not self.preload and fname in self._filenames: + if not self.preload and str(fname) in self._filenames: raise ValueError( "You cannot save data to the same file." " Please use a different filename." @@ -1667,17 +1668,6 @@ def save( "command-line MNE tools will not work." ) - type_dict = dict( - short=FIFF.FIFFT_DAU_PACK16, - int=FIFF.FIFFT_INT, - single=FIFF.FIFFT_FLOAT, - double=FIFF.FIFFT_DOUBLE, - ) - _check_option("fmt", fmt, type_dict.keys()) - reset_dict = dict(short=False, int=False, single=True, double=True) - reset_range = reset_dict[fmt] - data_type = type_dict[fmt] - data_test = self[0, 0][0] if fmt == "short" and np.iscomplexobj(data_test): raise ValueError( @@ -1685,7 +1675,7 @@ def save( ) # check for file existence and expand `~` if present - fname = str(_check_fname(fname=fname, overwrite=overwrite, verbose="error")) + fname = _check_fname(fname=fname, overwrite=overwrite, verbose="error") if proj: info = deepcopy(self.info) @@ -1707,25 +1697,9 @@ def save( _validate_type(split_naming, str, "split_naming") _check_option("split_naming", split_naming, ("neuromag", "bids")) - _write_raw( - fname, - self, - info, - picks, - fmt, - data_type, - reset_range, - start, - stop, - buffer_size, - projector, - drop_small_buffer, - split_size, - split_naming, - 0, - None, - overwrite, - ) + cfg = _RawFidWriterCfg(buffer_size, split_size, drop_small_buffer, fmt) + raw_fid_writer = _RawFidWriter(self, info, picks, projector, start, stop, cfg) + _write_raw(raw_fid_writer, fname, split_naming, overwrite) @verbose def export( @@ -2562,98 +2536,49 @@ def set_annotations(self, annotations): ############################################################################### # Writing -def _write_raw( - fname, - raw, - info, - picks, - fmt, - data_type, - reset_range, - start, - stop, - buffer_size, - projector, - drop_small_buffer, - split_size, - split_naming, - part_idx, - prev_fname, - overwrite, -): - """Write raw file with splitting.""" - # we've done something wrong if we hit this - n_times_max = len(raw.times) - if start >= stop or stop > n_times_max: - raise RuntimeError( - "Cannot write raw file with no data: %s -> %s " - "(max: %s) requested" % (start, stop, n_times_max) - ) - # Expand `~` if present - fname = str(_check_fname(fname=fname, overwrite=overwrite)) +# Assume we never hit more than 100 splits, like for epochs +MAX_N_SPLITS = 100 + - base, ext = op.splitext(fname) - if part_idx > 0: - if split_naming == "neuromag": - # insert index in filename - use_fname = "%s-%d%s" % (base, part_idx, ext) +def _write_raw(raw_fid_writer, fpath, split_naming, overwrite): + """Write raw file with splitting.""" + dir_path = fpath.parent + # We have to create one extra filename here to make the for loop below happy, + # but it will raise an error if it actually gets used + split_fnames = _make_split_fnames( + fpath.name, n_splits=MAX_N_SPLITS + 1, split_naming=split_naming + ) + is_next_split, prev_fname = True, None + for part_idx in range(0, MAX_N_SPLITS): + if not is_next_split: + break + bids_special_behavior = part_idx == 0 and split_naming == "bids" + if bids_special_behavior: + reserved_fname = dir_path / split_fnames[0] + logger.info(f"Reserving possible split file {reserved_fname.name}") + _check_fname(reserved_fname, overwrite) + reserved_ctx = _ReservedFilename(reserved_fname) + use_fpath = fpath else: - assert split_naming == "bids" - use_fname = _construct_bids_filename(base, ext, part_idx) - else: - use_fname = fname - # check for file existence - _check_fname(use_fname, overwrite) - # reserve our BIDS split fname in case we need to split - if split_naming == "bids" and part_idx == 0: - # reserve our possible split name - reserved_fname = _construct_bids_filename(base, ext, part_idx) - logger.info(f"Reserving possible split file {op.basename(reserved_fname)}") - _check_fname(reserved_fname, overwrite) - ctx = _ReservedFilename(reserved_fname) + reserved_ctx = nullcontext() + use_fpath = dir_path / split_fnames[part_idx] + next_fname = split_fnames[part_idx + 1] + _check_fname(use_fpath, overwrite) + + logger.info(f"Writing {use_fpath}") + with start_and_end_file(use_fpath) as fid, reserved_ctx: + is_next_split = raw_fid_writer.write(fid, part_idx, prev_fname, next_fname) + logger.info(f"Closing {use_fpath}") + if bids_special_behavior and is_next_split: + logger.info(f"Renaming BIDS split file {fpath.name}") + prev_fname = dir_path / split_fnames[0] + shutil.move(use_fpath, prev_fname) + prev_fname = use_fpath else: - reserved_fname = use_fname - ctx = nullcontext() - logger.info("Writing %s" % use_fname) - - picks = _picks_to_idx(info, picks, "all", ()) - with start_and_end_file(use_fname) as fid: - cals = _start_writing_raw( - fid, info, picks, data_type, reset_range, raw.annotations - ) - with ctx: - final_fname = _write_raw_fid( - raw, - info, - picks, - fid, - cals, - part_idx, - start, - stop, - buffer_size, - prev_fname, - split_size, - use_fname, - projector, - drop_small_buffer, - fmt, - fname, - reserved_fname, - data_type, - reset_range, - split_naming, - overwrite=True, # we've started writing already above - ) - if final_fname != use_fname: - assert split_naming == "bids" - logger.info(f"Renaming BIDS split file {op.basename(final_fname)}") - ctx.remove = False - shutil.move(use_fname, final_fname) - if part_idx == 0: - logger.info("[done]") - return final_fname, part_idx + raise RuntimeError(f"Exceeded maximum number of splits ({MAX_N_SPLITS}).") + + logger.info("[done]") class _ReservedFilename: @@ -2672,29 +2597,104 @@ def __exit__(self, exc_type, exc_value, traceback): os.remove(self.fname) -def _write_raw_fid( +@dataclass(frozen=True) +class _RawFidWriterCfg: + buffer_size: int + split_size: int + drop_small_buffer: bool + fmt: str + reset_range: bool = field(init=False) + data_type: int = field(init=False) + + def __post_init__(self): + type_dict = dict( + short=FIFF.FIFFT_DAU_PACK16, + int=FIFF.FIFFT_INT, + single=FIFF.FIFFT_FLOAT, + double=FIFF.FIFFT_DOUBLE, + ) + _check_option("fmt", self.fmt, type_dict.keys()) + reset_dict = dict(short=False, int=False, single=True, double=True) + object.__setattr__(self, "reset_range", reset_dict[self.fmt]) + object.__setattr__(self, "data_type", type_dict[self.fmt]) + + +class _RawFidWriter: + def __init__(self, raw, info, picks, projector, start, stop, cfg): + self.raw = raw + self.picks = _picks_to_idx(info, picks, "all", ()) + self.info = pick_info(info, sel=self.picks, copy=True) + for k in range(self.info["nchan"]): + # Scan numbers may have been messed up + self.info["chs"][k]["scanno"] = k + 1 # scanno starts at 1 in FIF format + if cfg.reset_range: + self.info["chs"][k]["range"] = 1.0 + self.projector = projector + # self.start is the only mutable attribute in this design! + self.start, self.stop = start, stop + self.cfg = cfg + + def write(self, fid, part_idx, prev_fname, next_fname): + self._check_start_stop_within_bounds() + start_block(fid, FIFF.FIFFB_MEAS) + _write_raw_metadata( + fid, + self.info, + self.cfg.data_type, + self.cfg.reset_range, + self.raw.annotations, + ) + self.start = _write_raw_data( + self.raw, + self.info, + self.picks, + fid, + part_idx, + self.start, + self.stop, + self.cfg.buffer_size, + prev_fname, + self.cfg.split_size, + next_fname, + self.projector, + self.cfg.drop_small_buffer, + self.cfg.fmt, + ) + end_block(fid, FIFF.FIFFB_MEAS) + is_next_split = self.start < self.stop + return is_next_split + + def _check_start_stop_within_bounds(self): + # we've done something wrong if we hit this + n_times_max = len(self.raw.times) + error_msg = ( + "Can't write raw file with no data: {0} -> {1} (max: {2}) requested" + ).format(self.start, self.stop, n_times_max) + if self.start >= self.stop or self.stop > n_times_max: + raise RuntimeError(error_msg) + + +def _write_raw_data( raw, info, picks, fid, - cals, part_idx, start, stop, buffer_size, prev_fname, split_size, - use_fname, + next_fname, projector, drop_small_buffer, fmt, - fname, - reserved_fname, - data_type, - reset_range, - split_naming, - overwrite, ): + # Start the raw data + data_kind = "IAS_" if info.get("maxshield", False) else "" + data_kind = getattr(FIFF, f"FIFFB_{data_kind}RAW_DATA") + start_block(fid, data_kind) + first_samp = raw.first_samp + start if first_samp != 0: write_int(fid, FIFF.FIFF_FIRST_SAMPLE, first_samp) @@ -2736,8 +2736,10 @@ def _write_raw_fid( "output buffer_size, will be written as zeroes." ) + cals = [ch["cal"] * ch["range"] for ch in info["chs"]] + # Write the blocks n_current_skip = 0 - final_fname = use_fname + new_start = start for first, last in zip(firsts, lasts): if do_skips: if ((first >= sk_onsets) & (last <= sk_ends)).any(): @@ -2784,6 +2786,7 @@ def _write_raw_fid( ) ) + new_start = last # Split files if necessary, leave some space for next file info # make sure we check to make sure we actually *need* another buffer # with the "and" check @@ -2791,48 +2794,23 @@ def _write_raw_fid( pos >= split_size - this_buff_size_bytes - _NEXT_FILE_BUFFER and first + buffer_size < stop ): - final_fname = reserved_fname - next_fname, next_idx = _write_raw( - fname, - raw, - info, - picks, - fmt, - data_type, - reset_range, - first + buffer_size, - stop, - buffer_size, - projector, - drop_small_buffer, - split_size, - split_naming, - part_idx + 1, - final_fname, - overwrite, - ) - start_block(fid, FIFF.FIFFB_REF) write_int(fid, FIFF.FIFF_REF_ROLE, FIFF.FIFFV_ROLE_NEXT_FILE) write_string(fid, FIFF.FIFF_REF_FILE_NAME, op.basename(next_fname)) if info["meas_id"] is not None: write_id(fid, FIFF.FIFF_REF_FILE_ID, info["meas_id"]) - write_int(fid, FIFF.FIFF_REF_FILE_NUM, next_idx) + write_int(fid, FIFF.FIFF_REF_FILE_NUM, part_idx + 1) end_block(fid, FIFF.FIFFB_REF) + break pos_prev = pos - logger.info("Closing %s" % use_fname) - if info.get("maxshield", False): - end_block(fid, FIFF.FIFFB_IAS_RAW_DATA) - else: - end_block(fid, FIFF.FIFFB_RAW_DATA) - end_block(fid, FIFF.FIFFB_MEAS) - return final_fname + end_block(fid, data_kind) + return new_start @fill_doc -def _start_writing_raw(fid, info, sel, data_type, reset_range, annotations): +def _write_raw_metadata(fid, info, data_type, reset_range, annotations): """Start write raw data in file. Parameters @@ -2840,9 +2818,6 @@ def _start_writing_raw(fid, info, sel, data_type, reset_range, annotations): fid : file The created file. %(info_not_none)s - sel : array of int | None - Indices of channels to include. If None, all channels - are included. data_type : int The data_type in case it is necessary. Should be 4 (FIFFT_FLOAT), 5 (FIFFT_DOUBLE), 16 (FIFFT_DAU_PACK16), or 3 (FIFFT_INT) for raw data. @@ -2851,36 +2826,14 @@ def _start_writing_raw(fid, info, sel, data_type, reset_range, annotations): annotations : instance of Annotations The annotations to write. - Returns - ------- - fid : file - The file descriptor. - cals : list - calibration factors. """ - # - # Measurement info - # - info = pick_info(info, sel) - # # Create the file and save the essentials # - start_block(fid, FIFF.FIFFB_MEAS) write_id(fid, FIFF.FIFF_BLOCK_ID) if info["meas_id"] is not None: write_id(fid, FIFF.FIFF_PARENT_BLOCK_ID, info["meas_id"]) - cals = [] - for k in range(info["nchan"]): - # - # Scan numbers may have been messed up - # - info["chs"][k]["scanno"] = k + 1 # scanno starts at 1 in FIF format - if reset_range is True: - info["chs"][k]["range"] = 1.0 - cals.append(info["chs"][k]["cal"] * info["chs"][k]["range"]) - write_meas_info(fid, info, data_type=data_type, reset_range=reset_range) # @@ -2889,16 +2842,6 @@ def _start_writing_raw(fid, info, sel, data_type, reset_range, annotations): if len(annotations) > 0: # don't save empty annot _write_annotations(fid, annotations) - # - # Start the raw data - # - if info.get("maxshield", False): - start_block(fid, FIFF.FIFFB_IAS_RAW_DATA) - else: - start_block(fid, FIFF.FIFFB_RAW_DATA) - - return cals - def _write_raw_buffer(fid, buf, cals, fmt): """Write raw buffer. diff --git a/mne/io/fiff/tests/test_raw_fiff.py b/mne/io/fiff/tests/test_raw_fiff.py index b53b94ce077..ebc4645f705 100644 --- a/mne/io/fiff/tests/test_raw_fiff.py +++ b/mne/io/fiff/tests/test_raw_fiff.py @@ -551,7 +551,7 @@ def test_split_files(tmp_path, mod, monkeypatch): annot = Annotations(np.arange(20), np.ones((20,)), "test") raw_1.set_annotations(annot) - split_fname = tmp_path / "split_raw.fif" + split_fname = tmp_path / f"split_{mod}.fif" raw_1.save(split_fname, buffer_size_sec=1.0, split_size="10MB") raw_2 = read_raw_fif(split_fname) assert_allclose(raw_2.buffer_size_sec, 1.0, atol=1e-2) # samp rate @@ -641,12 +641,37 @@ def test_split_files(tmp_path, mod, monkeypatch): raw_crop.save(tmp_path / "test.fif", split_naming="bids", verbose="error") # reserved file is deleted - fname = tmp_path / "test_raw.fif" - monkeypatch.setattr(base, "_write_raw_fid", _err) - with pytest.raises(RuntimeError, match="Killed mid-write"): - raw_1.save(fname, split_size="10MB", split_naming="bids") + fname = tmp_path / f"test_{mod}.fif" + with monkeypatch.context() as m: + m.setattr(base, "_write_raw_data", _err) + with pytest.raises(RuntimeError, match="Killed mid-write"): + raw_1.save(fname, split_size="10MB", split_naming="bids") + assert fname.is_file() + assert not (tmp_path / "test_split-01_{mod}.fif").is_file() + + # MAX_N_SPLITS exceeeded + raw = RawArray(np.zeros((1, 2000000)), create_info(1, 1000.0, "eeg")) + fname.unlink() + kwargs = dict(split_size="2MB", overwrite=True, verbose=True) + with monkeypatch.context() as m: + m.setattr(base, "MAX_N_SPLITS", 2) + with pytest.raises(RuntimeError, match="Exceeded maximum number of splits"): + raw.save(fname, split_naming="bids", **kwargs) + fname_1, fname_2, fname_3 = [ + (tmp_path / f"test_split-{ii:02d}_{mod}.fif") for ii in range(1, 4) + ] + assert not fname.is_file() + assert fname_1.is_file() + assert fname_2.is_file() + assert not fname_3.is_file() + with monkeypatch.context() as m: + m.setattr(base, "MAX_N_SPLITS", 2) + with pytest.raises(RuntimeError, match="Exceeded maximum number of splits"): + raw.save(fname, split_naming="neuromag", **kwargs) + fname_2, fname_3 = [(tmp_path / f"test_{mod}-{ii}.fif") for ii in range(1, 3)] assert fname.is_file() - assert not (tmp_path / "test_split-01_raw.fif").is_file() + assert fname_2.is_file() + assert not fname_3.is_file() def _err(*args, **kwargs): diff --git a/mne/io/nsx/tests/test_nsx.py b/mne/io/nsx/tests/test_nsx.py index 7206d1d40a6..8aa22677552 100644 --- a/mne/io/nsx/tests/test_nsx.py +++ b/mne/io/nsx/tests/test_nsx.py @@ -82,7 +82,7 @@ def test_nsx_ver_31(): assert raw.annotations[0]["onset"] * raw.info["sfreq"] == 101 assert raw.annotations[0]["duration"] * raw.info["sfreq"] == 49 - # Ignore following RuntimeWarning in mne/io/base.py in _write_raw_fid + # Ignore following RuntimeWarning in mne/io/base.py in _write_raw_data # "Acquisition skips detected but did not fit evenly into output" # "buffer_size, will be written as zeroes." with pytest.warns(RuntimeWarning, match="skips detected"):