Skip to content

Commit

Permalink
WIP EpochsSpectrum IO [ci skip]
Browse files Browse the repository at this point in the history
  • Loading branch information
drammock committed Aug 5, 2022
1 parent d387fe3 commit 5f1e333
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 12 deletions.
36 changes: 26 additions & 10 deletions mne/time_frequency/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,6 @@ class BaseSpectrum(ContainsMixin, UpdateChannelsMixin):

def __init__(self, inst, method, fmin, fmax, tmin, tmax, picks,
proj, reject_by_annotation, *, n_jobs, verbose, **method_kw):
# triage reading from file
if isinstance(inst, dict):
self._from_file(**inst)
return

# arg checking
self._sfreq = inst.info['sfreq']
if np.isfinite(fmax) and (fmax > self.sfreq / 2):
Expand Down Expand Up @@ -296,6 +291,10 @@ def _compute_spectra(self, data, fmin, fmax, n_jobs, method_kw, verbose):
if method_kw.get('output', '') == 'complex':
self._shape = (
self._shape[:-1] + (self._mt_weights.size,) + self._shape[-1:])
# we don't need these anymore, and they make save/load harder
del self._picks
del self._psd_func
del self._time_mask

def _format_units(self, unit, latex, power=True):
"""Format the measurement units nicely."""
Expand All @@ -309,8 +308,8 @@ def _format_units(self, unit, latex, power=True):
pre, post = (r'$\mathrm{', r'}$') if latex else ('', '')
return f'{pre}{unit}{exp}/{denom}{post}'

def _from_file(self, method, data, freqs, dims, data_type, inst_type,
info):
def _from_file(self, method, data, freqs, sfreq, shape, dims, data_type,
inst_type, info):
"""Recreate Spectrum object from hdf5 file."""
from .. import Epochs, Evoked, Info
from ..io import Raw
Expand All @@ -319,6 +318,8 @@ def _from_file(self, method, data, freqs, dims, data_type, inst_type,
self._data = data
self._freqs = freqs
self._dims = dims
self._sfreq = sfreq
self._shape = shape
self.info = Info(**info)
self._data_type = data_type
self.preload = True
Expand Down Expand Up @@ -364,6 +365,10 @@ def method(self):
def sfreq(self):
return self._sfreq

@property
def shape(self):
return self._shape

def copy(self):
"""Return copy of the Spectrum instance.
Expand Down Expand Up @@ -558,6 +563,8 @@ def save(self, fname, *, overwrite=False, verbose=None):
fname = _check_fname(fname, overwrite=overwrite, verbose=verbose)
out = dict(method=self.method,
data=self.get_data(picks='all', exclude=[]),
sfreq=self.sfreq,
shape=self.shape,
dims=self._dims,
freqs=self.freqs,
inst_type=self._get_instance_type_string(),
Expand Down Expand Up @@ -719,6 +726,11 @@ def __init__(self, inst, method, fmin, fmax, tmin, tmax, picks,
proj, reject_by_annotation, *, n_jobs, verbose, **method_kw):
from ..io import BaseRaw

# triage reading from file
if isinstance(inst, dict):
self._from_file(**inst)
return
# do the basic setup
super().__init__(inst, method, fmin, fmax, tmin, tmax, picks, proj,
reject_by_annotation, n_jobs=n_jobs, verbose=verbose,
**method_kw)
Expand Down Expand Up @@ -815,11 +827,14 @@ class EpochsSpectrum(BaseSpectrum, GetEpochsMixin):

def __init__(self, inst, method, fmin, fmax, tmin, tmax, picks,
proj, reject_by_annotation, *, n_jobs, verbose, **method_kw):

# triage reading from file
if isinstance(inst, dict):
self._from_file(**inst)
return
# do the basic setup
super().__init__(inst, method, fmin, fmax, tmin, tmax, picks, proj,
reject_by_annotation, n_jobs=n_jobs, verbose=verbose,
**method_kw)

# get just the data we want
data = inst.get_data(picks=self._picks)[:, :, self._time_mask]
# compute the spectra
Expand Down Expand Up @@ -878,7 +893,8 @@ def read_spectrum(fname):
defaults = dict(method=None, fmin=None, fmax=None, tmin=None, tmax=None,
picks=None, proj=None, reject_by_annotation=None,
n_jobs=None, verbose=None)
return Spectrum(hdf5_dict, **defaults)
Klass = EpochsSpectrum if hdf5_dict['inst_type'] == 'Epochs' else Spectrum
return Klass(hdf5_dict, **defaults)


def _check_ci(ci):
Expand Down
6 changes: 4 additions & 2 deletions mne/time_frequency/tests/test_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,12 @@ def test_spectrum_params(method, fmin, fmax, tmin, tmax, picks, proj, n_fft,


@requires_h5py
def test_spectrum_io(raw, tmp_path):
@pytest.mark.parametrize('inst', ('raw', 'epochs'))
def test_spectrum_io(inst, tmp_path, request):
"""Test save/load of spectrum objects."""
inst = request.getfixturevalue(inst)
fname = tmp_path / 'spectrum.h5'
orig = raw.compute_psd()
orig = inst.compute_psd()
orig.save(fname)
loaded = read_spectrum(fname)
assert orig == loaded
Expand Down

0 comments on commit 5f1e333

Please sign in to comment.