Skip to content

Commit

Permalink
get plot_topomap working [ci skip]
Browse files Browse the repository at this point in the history
  • Loading branch information
drammock committed Jan 22, 2024
1 parent 65b2a44 commit 916ec19
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 10 deletions.
147 changes: 143 additions & 4 deletions mne/time_frequency/spectrogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .._fiff.pick import _picks_to_idx, pick_info
from ..baseline import _check_baseline, rescale
from ..channels.channels import UpdateChannelsMixin
from ..defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT
from ..utils import (
ExtendedTimeMixin,
GetEpochsMixin,
Expand All @@ -25,6 +26,7 @@
_pl,
_time_mask,
check_fname,
copy_function_doc_to_method_doc,
fill_doc,
object_diff,
repr_html,
Expand All @@ -33,6 +35,7 @@
)
from ..utils.check import _check_combine, _validate_type
from ..viz.topo import _imshow_tfr
from ..viz.topomap import plot_tfr_topomap
from ..viz.utils import (
_make_combine_callable,
_set_title_multiple_electrodes,
Expand Down Expand Up @@ -737,15 +740,13 @@ def plot(
# real-valued they should already be power
if np.iscomplexobj(data):
data = (data * data.conj()).real
# dB
if dB:
data = 10 * np.log10(data)
# shape
_axis = self._dims.index("channel")
want_shape = list(self.shape)
want_shape[self._dims.index("channel")] = len(picks) if combine is None else 1
want_shape[_axis] = len(picks) if combine is None else 1
want_shape = tuple(want_shape)
# this: ↓↓↓↓ makes np.squeeze a no-op; the rest lets singleton EpochsTFR work
_axis = None if combine is None else self._dims.index("channel")
# combine
combine = _make_combine_callable(
combine, axis=_axis, valid=("mean", "rms"), keepdims=True
Expand Down Expand Up @@ -847,6 +848,77 @@ def plot(
plt_show(show)
return figs

def plot_joint(self):
"""Plot TFRs as a two-dimensional image with topomaps."""
pass

def plot_topo(self):
"""Plot TFRs as a two-dimensional image with topomaps."""
pass

@copy_function_doc_to_method_doc(plot_tfr_topomap)
def plot_topomap(
self,
tmin=None,
tmax=None,
fmin=0.0,
fmax=np.inf,
*,
ch_type=None,
baseline=None,
mode="mean",
sensors=True,
show_names=False,
mask=None,
mask_params=None,
contours=6,
outlines="head",
sphere=None,
image_interp=_INTERPOLATION_DEFAULT,
extrapolate=_EXTRAPOLATE_DEFAULT,
border=_BORDER_DEFAULT,
res=64,
size=2,
cmap=None,
vlim=(None, None),
cnorm=None,
colorbar=True,
cbar_fmt="%1.1e",
units=None,
axes=None,
show=True,
):
return plot_tfr_topomap(
self,
tmin=tmin,
tmax=tmax,
fmin=fmin,
fmax=fmax,
ch_type=ch_type,
baseline=baseline,
mode=mode,
sensors=sensors,
show_names=show_names,
mask=mask,
mask_params=mask_params,
contours=contours,
outlines=outlines,
sphere=sphere,
image_interp=image_interp,
extrapolate=extrapolate,
border=border,
res=res,
size=size,
cmap=cmap,
vlim=vlim,
cnorm=cnorm,
colorbar=colorbar,
cbar_fmt=cbar_fmt,
units=units,
axes=axes,
show=show,
)

@verbose
def save(self, fname, *, overwrite=False, verbose=None):
"""Save time-frequency data to disk (in HDF5 format).
Expand Down Expand Up @@ -1295,6 +1367,73 @@ def iter_evoked(self, copy=False):
state["nave"] = 1
yield AverageTFR(state, method=None, freqs=None, comment=str(event_id))

@copy_function_doc_to_method_doc(plot_tfr_topomap)
def plot_topomap(
self,
tmin=None,
tmax=None,
fmin=0.0,
fmax=np.inf,
*,
ch_type=None,
baseline=None,
mode="mean",
sensors=True,
show_names=False,
mask=None,
mask_params=None,
contours=6,
outlines="head",
sphere=None,
image_interp=_INTERPOLATION_DEFAULT,
extrapolate=_EXTRAPOLATE_DEFAULT,
border=_BORDER_DEFAULT,
res=64,
size=2,
cmap=None,
vlim=(None, None),
cnorm=None,
colorbar=True,
cbar_fmt="%1.1e",
units=None,
axes=None,
show=True,
):
if self.shape[0] > 1:
raise NotImplementedError(
"Cannot plot topomap for multiple EpochsTFR epochs; please subselect a"
"single epoch before plotting."
)
return list(self.iter_evoked())[0].plot_topomap(
tmin=tmin,
tmax=tmax,
fmin=fmin,
fmax=fmax,
ch_type=ch_type,
baseline=baseline,
mode=mode,
sensors=sensors,
show_names=show_names,
mask=mask,
mask_params=mask_params,
contours=contours,
outlines=outlines,
sphere=sphere,
image_interp=image_interp,
extrapolate=extrapolate,
border=border,
res=res,
size=size,
cmap=cmap,
vlim=vlim,
cnorm=cnorm,
colorbar=colorbar,
cbar_fmt=cbar_fmt,
units=units,
axes=axes,
show=show,
)


class AverageTFR(BaseTFR):
"""Data object for spectrotemporal representations of averaged data.
Expand Down
6 changes: 1 addition & 5 deletions mne/time_frequency/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
_is_numeric,
check_fname,
)
from ..utils.misc import _pl
from ..utils.misc import _identity_function, _pl
from ..utils.spectrum import _split_psd_kwargs
from ..viz.topo import _plot_timeseries, _plot_timeseries_unified, _plot_topo
from ..viz.topomap import _make_head_outlines, _prepare_topomap_plot, plot_psds_topomap
Expand All @@ -59,10 +59,6 @@
from .psd import _check_nfft, psd_array_welch


def _identity_function(x):
return x


class SpectrumMixin:
"""Mixin providing spectral plotting methods to sensor-space containers."""

Expand Down
4 changes: 4 additions & 0 deletions mne/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
from .check import _check_option, _validate_type


def _identity_function(x):
return x


# TODO: no longer needed when py3.9 is minimum supported version
def _empty_hash(kind="md5"):
func = getattr(hashlib, kind)
Expand Down
7 changes: 6 additions & 1 deletion mne/viz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
)
from .._fiff.proj import Projection, setup_proj
from ..defaults import _handle_default
from ..fixes import _median_complex
from ..rank import compute_rank
from ..transforms import apply_trans
from ..utils import (
Expand All @@ -65,6 +66,7 @@
verbose,
warn,
)
from ..utils.misc import _identity_function
from .ui_events import ColormapRange, publish, subscribe

_channel_type_prettyprint = {
Expand Down Expand Up @@ -2352,13 +2354,16 @@ def _make_combine_callable(
"""
kwargs = dict(axis=axis, keepdims=keepdims)
if combine is None:
combine = partial(np.squeeze, axis=axis)
combine = _identity_function
elif isinstance(combine, str):
combine_dict = {
key: partial(getattr(np, key), **kwargs)
for key in valid
if getattr(np, key, None) is not None
}
# marginal median that is safe for complex values:
if "median" in valid:
combine_dict["median"] = partial(_median_complex, axis=axis)
# RMS and GFP are computed the same way
for key in ("gfp", "rms"):
if key in valid:
Expand Down

0 comments on commit 916ec19

Please sign in to comment.