Skip to content

Commit

Permalink
fix type checking, better variable naming
Browse files Browse the repository at this point in the history
  • Loading branch information
drammock committed Aug 11, 2022
1 parent bcffe8d commit b421d42
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions mne/time_frequency/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,21 +228,22 @@ def __eq__(self, other):

def __repr__(self):
"""Build string representation of the Spectrum object."""
inst_type = self._get_instance_type_string()
inst_type_str = self._get_instance_type_string()
# shape & dimension names
dims = ' × '.join(
[f'{dim[0]} {dim[1]}s'
for dim in zip(self._data.shape, self._dims)])
freq_range = f'{self.freqs[0]:0.1f}-{self.freqs[-1]:0.1f} Hz'
return f'<{self._data_type} (from {inst_type}) | {dims}, {freq_range}>'
return (f'<{self._data_type} '
f'(from {inst_type_str}) | {dims}, {freq_range}>')

def _repr_html_(self, caption=None):
"""Build HTML representation of the Spectrum object."""
from ..html_templates import repr_templates_env

inst_type = self._get_instance_type_string()
inst_type_str = self._get_instance_type_string()
t = repr_templates_env.get_template('spectrum.html.jinja')
t = t.render(spectrum=self, inst_type=inst_type,
t = t.render(spectrum=self, inst_type=inst_type_str,
data_type=self._data_type)
return t

Expand Down Expand Up @@ -309,7 +310,7 @@ def _format_units(self, unit, latex, power=True):
return f'{pre}{unit}{exp}/{denom}{post}'

def _from_file(self, method, data, freqs, sfreq, shape, dims, data_type,
inst_type, info, metadata=None, drop_log=None,
inst_type_str, info, metadata=None, drop_log=None,
event_id=None, events=None, selection=None):
"""Recreate Spectrum object from hdf5 file."""
from .. import Epochs, Evoked, Info
Expand All @@ -324,15 +325,15 @@ def _from_file(self, method, data, freqs, sfreq, shape, dims, data_type,
self.info = Info(**info)
self._data_type = data_type
self.preload = True
if inst_type == 'Epochs':
if inst_type_str == 'Epochs':
self._metadata = metadata
self.drop_log = drop_log
self.event_id = event_id
self.events = events
self.selection = selection
# instance type
inst_types = dict(Raw=Raw, Epochs=Epochs, Evoked=Evoked)
self._inst_type = inst_types[inst_type]
self._inst_type = inst_types[inst_type_str]

def _get_instance_type_string(self):
"""Get string representation of the originating instance type."""
Expand All @@ -341,15 +342,15 @@ def _get_instance_type_string(self):

parent_classes = self._inst_type.__bases__
if BaseRaw in parent_classes:
inst_type = 'Raw'
inst_type_str = 'Raw'
elif BaseEpochs in parent_classes:
inst_type = 'Epochs'
elif Evoked in parent_classes:
inst_type = 'Evoked'
inst_type_str = 'Epochs'
elif self._inst_type == Evoked:
inst_type_str = 'Evoked'
else:
raise RuntimeError(
f'Unknown instance type {self._inst_type} in Spectrum')
return inst_type
return inst_type_str

@property
def _detrend_picks(self):
Expand Down Expand Up @@ -575,7 +576,7 @@ def save(self, fname, *, overwrite=False, verbose=None):
shape=self.shape,
dims=self._dims,
freqs=self.freqs,
inst_type=inst_type_str,
inst_type_str=inst_type_str,
data_type=self._data_type,
info=self.info)
if inst_type_str == 'Epochs':
Expand Down Expand Up @@ -907,7 +908,7 @@ def read_spectrum(fname):
defaults = dict(method=None, fmin=None, fmax=None, tmin=None, tmax=None,
picks=None, proj=None, reject_by_annotation=None,
n_jobs=None, verbose=None)
Klass = EpochsSpectrum if hdf5_dict['inst_type'] == 'Epochs' else Spectrum
Klass = EpochsSpectrum if hdf5_dict['inst_type_str'] == 'Epochs' else Spectrum
return Klass(hdf5_dict, **defaults)


Expand Down

0 comments on commit b421d42

Please sign in to comment.