-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
Copy path_mpl_figure.py
2530 lines (2348 loc) · 103 KB
/
_mpl_figure.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""Figure classes for MNE-Python's 2D plots.
Class Hierarchy
---------------
MNEFigParams Container object, attached to MNEFigure by default. Sets
close_key='escape' plus whatever other key-value pairs are
passed to its constructor.
matplotlib.figure.Figure
└ MNEFigure
├ MNEBrowseFigure Interactive figure for scrollable data.
│ Generated by:
│ - raw.plot()
│ - epochs.plot()
│ - ica.plot_sources(raw)
│ - ica.plot_sources(epochs)
│
├ MNEAnnotationFigure GUI for adding annotations to Raw
│
├ MNESelectionFigure GUI for spatial channel selection. raw.plot()
│ and epochs.plot() will generate one of these
│ alongside an MNEBrowseFigure when
│ group_by == 'selection' or 'position'
│
└ MNELineFigure Interactive figure for non-scrollable data.
Generated by:
- spectrum.plot()
- evoked.plot() TODO Not yet implemented
- evoked.plot_white() TODO Not yet implemented
- evoked.plot_joint() TODO Not yet implemented
"""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import datetime
import platform
from collections import OrderedDict
from contextlib import contextmanager
from functools import partial
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import get_backend
from matplotlib.figure import Figure
from .._fiff.pick import (
_DATA_CH_TYPES_ORDER_DEFAULT,
_DATA_CH_TYPES_SPLIT,
_EYETRACK_CH_TYPES_SPLIT,
_FNIRS_CH_TYPES_SPLIT,
_VALID_CHANNEL_TYPES,
channel_indices_by_type,
pick_types,
)
from ..fixes import _close_event
from ..utils import Bunch, _click_ch_name, check_version, logger
from ._figure import BrowserBase
from .utils import (
DraggableLine,
_events_off,
_fake_click,
_fake_keypress,
_fake_scroll,
_merge_annotations,
_set_window_title,
_validate_if_list_of_axes,
plot_sensors,
plt_show,
)
# Name of this browser backend, and the matplotlib backend active at import.
name = "matplotlib"
BACKEND = get_backend()

# CONSTANTS (inches) used to lay out the annotation dialog
ANNOTATION_FIG_W = 5.0
ANNOTATION_FIG_PAD = 0.1
ANNOTATION_FIG_CHECKBOX_COLUMN_W = 0.5
ANNOTATION_FIG_MIN_H = 2.9  # fixed part, not including radio buttons/labels

# True for matplotlib < 3.7; code paths that touch ``RadioButtons.circles``
# only run when this flag is set.
_OLD_BUTTONS = not check_version("matplotlib", "3.7")
class MNEFigure(Figure):
    """Base class for 2D figures & dialogs; wraps matplotlib.figure.Figure.

    MNE-specific state is kept on the ``mne`` attribute, which is a
    ``BrowserParams`` container unless a subclass created it first.
    """

    def __init__(self, **kwargs):
        """Create the figure and attach (or extend) ``self.mne``.

        Parameters
        ----------
        **kwargs
            ``figsize`` is passed to :class:`matplotlib.figure.Figure`. All
            other keys go to ``BrowserParams`` or, if ``self.mne`` already
            exists, are set on it only where not already present.
            ``fgcolor``/``bgcolor`` default to the ``axes.edgecolor`` and
            ``axes.facecolor`` rcParams.
        """
        from matplotlib import rcParams

        # figsize is the only kwarg we pass to matplotlib Figure()
        figsize = kwargs.pop("figsize", None)
        super().__init__(figsize=figsize)
        # things we'll almost always want
        defaults = dict(
            fgcolor=rcParams["axes.edgecolor"], bgcolor=rcParams["axes.facecolor"]
        )
        for key, value in defaults.items():
            kwargs.setdefault(key, value)
        # add param object if not already added (e.g. by BrowserBase)
        if not hasattr(self, "mne"):
            from mne.viz._figure import BrowserParams

            self.mne = BrowserParams(**kwargs)
        else:
            # never overwrite values that were already stored
            for key in [k for k in kwargs if not hasattr(self.mne, k)]:
                setattr(self.mne, key, kwargs[key])

    def _close(self, event=None):
        """Handle close events by dropping the parent's references to us.

        If this figure has a ``parent_fig``, it is removed from the parent's
        ``child_figs`` list. If it was also registered under a ``fig_name``,
        that attribute on the parent's params is reset to ``None``. A figure
        without a parent has nothing to clean up here.
        """
        logger.debug(f"Closing {self!r}")
        parent = getattr(self.mne, "parent_fig", None)
        if parent is None:
            # top-level figure: no parent references to remove (and no
            # parent on which a named attribute could be reset)
            return
        try:
            parent.mne.child_figs.remove(self)
        except ValueError:
            pass  # already removed (on its own, probably?)
        fig_name = getattr(self.mne, "fig_name", None)
        if fig_name is not None:
            setattr(parent.mne, fig_name, None)

    def _keypress(self, event):
        """Handle keypress events: close key closes, F11 toggles fullscreen."""
        if event.key == self.mne.close_key:
            plt.close(self)
        elif event.key == "f11":  # full screen
            self.canvas.manager.full_screen_toggle()

    def _buttonpress(self, event):
        """Handle buttonpress events (no-op; subclasses override)."""
        pass

    def _pick(self, event):
        """Handle matplotlib pick events (no-op; subclasses override)."""
        pass

    def _resize(self, event):
        """Handle window resize events (no-op; subclasses override)."""
        pass

    def _add_default_callbacks(self, **kwargs):
        """Remove matplotlib's keypress callbacks and connect MNE-Python ones.

        Parameters
        ----------
        **kwargs
            Extra ``event_name=callback`` pairs; these are added to (or
            override) the default resize/key/button/close/pick handlers.
        """
        # Remove matplotlib default keypress catchers
        default_callbacks = list(
            self.canvas.callbacks.callbacks.get("key_press_event", {})
        )
        for callback in default_callbacks:
            self.canvas.callbacks.disconnect(callback)
        # add our event callbacks
        callbacks = dict(
            resize_event=self._resize,
            key_press_event=self._keypress,
            button_press_event=self._buttonpress,
            close_event=self._close,
            pick_event=self._pick,
        )
        callbacks.update(kwargs)
        callback_ids = dict()
        for event, callback in callbacks.items():
            callback_ids[event] = self.canvas.mpl_connect(event, callback)
        # store callback references so they aren't garbage-collected
        self.mne._callback_ids = callback_ids

    def _get_dpi_ratio(self):
        """Get DPI ratio (to handle hi-DPI screens).

        Uses the canvas's ``_device_scale`` if present, else ``_dpi_ratio``,
        else 1.0 (later keys in the loop take precedence).
        """
        dpi_ratio = 1.0
        for key in ("_dpi_ratio", "_device_scale"):
            dpi_ratio = getattr(self.canvas, key, dpi_ratio)
        return dpi_ratio

    def _get_size_px(self):
        """Get figure size in pixels, corrected for the DPI ratio."""
        dpi_ratio = self._get_dpi_ratio()
        return self.get_size_inches() * self.dpi / dpi_ratio

    def _inch_to_rel(self, dim_inches, horiz=True):
        """Convert inches to figure-relative distances.

        Parameters
        ----------
        dim_inches : float
            Distance in inches.
        horiz : bool
            If True, relative to the figure width; else to its height.
        """
        fig_w, fig_h = self.get_size_inches()
        w_or_h = fig_w if horiz else fig_h
        return dim_inches / w_or_h
class MNEAnnotationFigure(MNEFigure):
    """Interactive dialog figure for annotations."""

    def _close(self, event=None):
        """Leave annotation mode in the parent, then run the base cleanup.

        Reached via the close key or the window's [x] button.
        """
        parent_fig = self.mne.parent_fig
        parent_params = parent_fig.mne
        # stop the click-drag span selector on the main axes
        parent_params.ax_main.selector.active = False
        # drop any hover line left from the last mouse motion
        parent_fig._remove_annotation_hover_line()
        # detach the parent's hover (motion) handler
        hover_cid = parent_params._callback_ids["motion_notify_event"]
        parent_fig.canvas.callbacks.disconnect(hover_cid)
        # shared parent/child bookkeeping
        super()._close(event)

    def _keypress(self, event):
        """Edit the new-annotation label from keyboard input.

        The close key closes the dialog, Enter hands the label to the parent
        browser, Backspace deletes a character, and single printable keys
        (except ";") are appended.
        """
        pressed = event.key
        label_text = self.label.get_text()
        if pressed == self.mne.close_key:
            plt.close(self)
        elif pressed == "enter":
            self.mne.parent_fig._add_annotation_label(event)
            return
        elif pressed == "backspace":
            label_text = label_text[:-1]
        elif len(pressed) > 1 or pressed == ";":
            # multi-character key names are modifiers / special keys
            return
        else:
            label_text += pressed
        self.label.set_text(label_text)
        self.canvas.draw()

    def _radiopress(self, event, *, draw=True):
        """Sync the active-label styling and selector color to the radio choice."""
        radio = self.mne.radio_ax.buttons
        label_names = [lbl.get_text() for lbl in radio.labels]
        active_idx = label_names.index(radio.value_selected)
        self._set_active_button(active_idx, draw=False)
        # the click-drag rectangle takes the selected label's color
        parent_params = self.mne.parent_fig.mne
        seg_color = parent_params.annotation_segment_colors[label_names[active_idx]]
        # https://github.com/matplotlib/matplotlib/issues/20618
        # https://github.com/matplotlib/matplotlib/pull/20693
        parent_params.ax_main.selector.set_props(color=seg_color, facecolor=seg_color)
        if draw:
            self.canvas.draw()

    def _click_override(self, event):
        """Replace matplotlib's radio-button hit test with one using transData.

        Only valid for old matplotlib (``_OLD_BUTTONS``). A button is a
        candidate if its label text or its circle contains the click; the
        candidate whose circle center is closest wins.
        """
        assert _OLD_BUTTONS
        radio_ax = self.mne.radio_ax
        radio = radio_ax.buttons
        if radio.ignore(event) or event.button != 1 or event.inaxes != radio_ax:
            return
        click_xy = radio_ax.transData.inverted().transform((event.x, event.y))
        candidates = {}
        for idx, (circle, text) in enumerate(zip(radio.circles, radio.labels)):
            dist = np.linalg.norm(click_xy - circle.center)
            on_text = text.get_window_extent().contains(event.x, event.y)
            if on_text or dist < circle.radius:
                candidates[idx] = dist
        if candidates:
            radio.set_active(min(candidates, key=candidates.get))

    def _set_active_button(self, idx, *, draw=True):
        """Activate radio button ``idx`` with its callbacks suppressed.

        On old matplotlib, the circles are also restyled by hand: all are
        filled with the parent's background color, then the active one gets
        its own edge color at 50% alpha.
        """
        radio = self.mne.radio_ax.buttons
        logger.debug(f"buttons: {radio}")
        logger.debug(f"active idx: {idx}")
        with _events_off(radio):
            radio.set_active(idx)
        if _OLD_BUTTONS:
            circles = radio.circles
            logger.debug(f"circles: {circles}")
            for circle in circles:
                circle.set_facecolor(self.mne.parent_fig.mne.bgcolor)
            fill = list(circles[idx].get_edgecolor())
            logger.debug(f"color: {fill}")
            fill[-1] = 0.5
            circles[idx].set_facecolor(fill)
        if draw:
            self.canvas.draw()
class MNESelectionFigure(MNEFigure):
    """Interactive dialog figure for channel selections."""

    def _close(self, event=None):
        """Handle close events.

        The selection dialog and the main browser are tightly integrated, so
        closing this dialog also closes its parent figure.
        """
        # Like MNEFigure._close, tolerate having already been detached from
        # the parent's child list instead of raising ValueError.
        try:
            self.mne.parent_fig.mne.child_figs.remove(self)
        except ValueError:
            pass  # already removed
        # NOTE(review): this clears ``fig_selection`` on the dialog's own
        # params, not on the parent's; the parent is closed right below.
        self.mne.fig_selection = None
        # selection fig & main fig tightly integrated; closing one closes both
        plt.close(self.mne.parent_fig)

    def _keypress(self, event):
        """Handle keypress events.

        "up", "down" and "b" are forwarded to the parent browser; any other
        key goes to the base-class handler (close key / fullscreen).
        """
        if event.key in ("up", "down", "b"):
            self.mne.parent_fig._keypress(event)
        else:  # check for close key
            super()._keypress(event)

    def _radiopress(self, event):
        """Handle RadioButton clicks for channel selection groups.

        Picking "Custom" while no custom selection exists silently reverts
        to ``self.mne.old_selection``. Otherwise, butterfly mode (if on) is
        turned off, the clicked button is re-activated, and the parent
        updates its channel selection.
        """
        logger.debug(f"Got radio press: {repr(event)}")
        selections_dict = self.mne.parent_fig.mne.ch_selections
        buttons = self.mne.radio_ax.buttons
        labels = [label.get_text() for label in buttons.labels]
        this_label = buttons.value_selected
        parent = self.mne.parent_fig
        if this_label == "Custom" and not len(selections_dict["Custom"]):
            with _events_off(buttons):
                buttons.set_active(self.mne.old_selection)
            return
        # clicking a selection cancels butterfly mode
        if parent.mne.butterfly:
            logger.debug("Disabling butterfly mode")
            parent._toggle_butterfly()
        # re-assert the clicked button (callbacks suppressed)
        with _events_off(buttons):
            buttons.set_active(labels.index(this_label))
        parent._update_selection()

    def _set_custom_selection(self):
        """Store the lasso-selected channels as the "Custom" selection.

        Does nothing if the lasso selection is empty. Otherwise the indices
        of the selected channel names are stored and the "Custom" button is
        activated; callbacks are not suppressed here, so presumably this
        triggers ``_radiopress`` — confirm.
        """
        chs = self.lasso.selection
        parent = self.mne.parent_fig
        buttons = self.mne.radio_ax.buttons
        if not len(chs):
            return
        labels = [label.get_text() for label in buttons.labels]
        inds = np.isin(parent.mne.ch_names, chs)
        parent.mne.ch_selections["Custom"] = inds.nonzero()[0]
        buttons.set_active(labels.index("Custom"))

    def _style_radio_buttons_butterfly(self):
        """Handle RadioButton state for keyboard interactions."""
        # Show all radio buttons as selected when in butterfly mode
        parent = self.mne.parent_fig
        buttons = self.mne.radio_ax.buttons
        color = buttons.activecolor if parent.mne.butterfly else parent.mne.bgcolor
        if _OLD_BUTTONS:
            for circle in buttons.circles:
                circle.set_facecolor(color)
        # when leaving butterfly mode, make most-recently-used selection active
        if not parent.mne.butterfly:
            with _events_off(buttons):
                buttons.set_active(self.mne.old_selection)
        # update the sensors too
        parent._update_highlighted_sensors()
class MNEBrowseFigure(BrowserBase, MNEFigure):
"""Interactive figure with scrollbars, for data browsing."""
    def __init__(self, inst, figsize, ica=None, xlabel="Time (s)", **kwargs):
        """Lay out the browser: main axes, scrollbars, buttons and traces.

        Parameters
        ----------
        inst : object
            The data instance being browsed. Forwarded (with the other
            arguments) to both parent constructors and read back below as
            ``self.mne.inst``.
        figsize : tuple
            Figure size in inches; ``MNEFigure.__init__`` passes it on to
            matplotlib's ``Figure``.
        ica : object | None
            Forwarded to the parent constructors (presumably the ICA object
            when browsing ICA sources — TODO confirm).
        xlabel : str
            X-axis label for the horizontal scrollbar; also set on the main
            axes, where it is only visible when scrollbars are hidden.
        **kwargs
            Remaining browser parameters, forwarded to both parent
            constructors.
        """
        from matplotlib.colors import to_rgba_array
        from matplotlib.patches import Rectangle
        from matplotlib.ticker import (
            FixedFormatter,
            FixedLocator,
            FuncFormatter,
            NullFormatter,
        )
        from matplotlib.transforms import blended_transform_factory
        from matplotlib.widgets import Button
        from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
        from mpl_toolkits.axes_grid1.axes_size import Fixed

        self.backend_name = "matplotlib"
        kwargs.update({"inst": inst, "figsize": figsize, "ica": ica, "xlabel": xlabel})
        # BrowserBase first; if it created self.mne, MNEFigure.__init__ only
        # fills in params that are not already set
        BrowserBase.__init__(self, **kwargs)
        MNEFigure.__init__(self, **kwargs)

        # MAIN AXES: default sizes (inches)
        # XXX simpler with constrained_layout? (when it's no longer "beta")
        l_margin = 1.0
        r_margin = 0.1
        b_margin = 0.45
        t_margin = 0.25
        scroll_width = 0.25
        hscroll_dist = 0.25
        vscroll_dist = 0.1
        help_width = scroll_width * 2

        # MAIN AXES: default margins (figure-relative coordinates)
        # the help button (appended to the left of ax_main below) plus its
        # pad take up part of the left margin, so they are subtracted here
        left = self._inch_to_rel(l_margin - vscroll_dist - help_width)
        right = 1 - self._inch_to_rel(r_margin)
        bottom = self._inch_to_rel(b_margin, horiz=False)
        top = 1 - self._inch_to_rel(t_margin, horiz=False)
        width = right - left
        height = top - bottom
        position = [left, bottom, width, height]

        # Main axes must be a subplot for subplots_adjust to work (so user can
        # adjust margins). That's why we don't use the Divider class directly.
        ax_main = self.add_subplot(1, 1, 1, position=position)
        self.subplotpars.update(left=left, bottom=bottom, top=top, right=right)
        div = make_axes_locatable(ax_main)
        # this only gets shown in zen mode
        self.mne.zen_xlabel = ax_main.set_xlabel(xlabel)
        self.mne.zen_xlabel.set_visible(not self.mne.scrollbars_visible)
        # make sure background color of the axis is set
        if "bgcolor" in kwargs:
            ax_main.set_facecolor(kwargs["bgcolor"])

        # SCROLLBARS
        ax_hscroll = div.append_axes(
            position="bottom", size=Fixed(scroll_width), pad=Fixed(hscroll_dist)
        )
        ax_vscroll = div.append_axes(
            position="right", size=Fixed(scroll_width), pad=Fixed(vscroll_dist)
        )
        ax_hscroll.get_yaxis().set_visible(False)
        ax_hscroll.set_xlabel(xlabel)
        ax_vscroll.set_axis_off()
        # HORIZONTAL SCROLLBAR PATCHES (FOR MARKING BAD EPOCHS)
        if self.mne.is_epochs:
            epoch_nums = self.mne.inst.selection
            # one initially invisible (color="none") rectangle per epoch;
            # note this reuses ``width`` (main-axes width is no longer needed)
            for ix, _ in enumerate(epoch_nums):
                start = self.mne.boundary_times[ix]
                width = np.diff(self.mne.boundary_times[:2])[0]
                ax_hscroll.add_patch(
                    Rectangle(
                        (start, 0),
                        width,
                        1,
                        color="none",
                        zorder=self.mne.zorder["patch"],
                    )
                )
            # both axes, major ticks: gridlines
            for _ax in (ax_main, ax_hscroll):
                _ax.xaxis.set_major_locator(FixedLocator(self.mne.boundary_times[1:-1]))
                _ax.xaxis.set_major_formatter(NullFormatter())
            grid_kwargs = dict(
                color=self.mne.fgcolor, axis="x", zorder=self.mne.zorder["grid"]
            )
            ax_main.grid(linewidth=2, linestyle="dashed", **grid_kwargs)
            ax_hscroll.grid(alpha=0.5, linewidth=0.5, linestyle="solid", **grid_kwargs)
            # main axes, minor ticks: ticklabel (epoch number) for every epoch
            ax_main.xaxis.set_minor_locator(FixedLocator(self.mne.midpoints))
            ax_main.xaxis.set_minor_formatter(FixedFormatter(epoch_nums))
            # hscroll axes, minor ticks: up to 20 ticklabels (epoch numbers)
            ax_hscroll.xaxis.set_minor_locator(
                FixedLocator(self.mne.midpoints, nbins=20)
            )
            ax_hscroll.xaxis.set_minor_formatter(
                FuncFormatter(lambda x, pos: self._get_epoch_num_from_time(x))
            )
            # hide some ticks
            ax_main.tick_params(axis="x", which="major", bottom=False)
            ax_hscroll.tick_params(axis="x", which="both", bottom=False)
        else:
            # RAW / ICA X-AXIS TICK & LABEL FORMATTING
            ax_main.xaxis.set_major_formatter(
                FuncFormatter(partial(self._xtick_formatter, ax_type="main"))
            )
            ax_hscroll.xaxis.set_major_formatter(
                FuncFormatter(partial(self._xtick_formatter, ax_type="hscroll"))
            )
            if self.mne.time_format != "float":
                for _ax in (ax_main, ax_hscroll):
                    _ax.set_xlabel("Time (HH:MM:SS)")
        # VERTICAL SCROLLBAR PATCHES (COLORED BY CHANNEL TYPE)
        ch_order = self.mne.ch_order
        for ix, pick in enumerate(ch_order):
            # bad channels use ch_color_bad; others look up their type's
            # color when ch_color_dict is a dict
            this_color = (
                self.mne.ch_color_bad
                if self.mne.ch_names[pick] in self.mne.info["bads"]
                else self.mne.ch_color_dict
            )
            if isinstance(this_color, dict):
                this_color = this_color[self.mne.ch_types[pick]]
            ax_vscroll.add_patch(
                Rectangle(
                    (0, ix), 1, 1, color=this_color, zorder=self.mne.zorder["patch"]
                )
            )
        # inverted y-limits: first channel at the top
        ax_vscroll.set_ylim(len(ch_order), 0)
        ax_vscroll.set_visible(not self.mne.butterfly)
        # SCROLLBAR VISIBLE SELECTION PATCHES
        sel_kwargs = dict(
            alpha=0.3, linewidth=4, clip_on=False, edgecolor=self.mne.fgcolor
        )
        vsel_patch = Rectangle(
            (0, 0), 1, self.mne.n_channels, facecolor=self.mne.bgcolor, **sel_kwargs
        )
        ax_vscroll.add_patch(vsel_patch)
        hsel_facecolor = np.average(
            np.vstack(
                (to_rgba_array(self.mne.fgcolor), to_rgba_array(self.mne.bgcolor))
            ),
            axis=0,
            weights=(3, 1),
        )  # 75% foreground, 25% background
        hsel_patch = Rectangle(
            (self.mne.t_start, 0),
            self.mne.duration,
            1,
            facecolor=hsel_facecolor,
            **sel_kwargs,
        )
        ax_hscroll.add_patch(hsel_patch)
        ax_hscroll.set_xlim(
            self.mne.first_time,
            self.mne.first_time + self.mne.n_times / self.mne.info["sfreq"],
        )

        # VLINE
        vline_color = (0.0, 0.75, 0.0)
        vline_kwargs = dict(visible=False, zorder=self.mne.zorder["vline"])
        if self.mne.is_epochs:
            # one vline per visible epoch: x in data coords, y spans the axes
            x = np.arange(self.mne.n_epochs)
            vline = ax_main.vlines(x, 0, 1, colors=vline_color, **vline_kwargs)
            vline.set_transform(
                blended_transform_factory(ax_main.transData, ax_main.transAxes)
            )
            vline_hscroll = None
        else:
            vline = ax_main.axvline(0, color=vline_color, **vline_kwargs)
            vline_hscroll = ax_hscroll.axvline(0, color=vline_color, **vline_kwargs)
        vline_text = ax_main.annotate(
            "",
            xy=(0, 0),
            xycoords="axes fraction",
            xytext=(-2, 0),
            textcoords="offset points",
            fontsize=10,
            ha="right",
            va="center",
            color=vline_color,
            **vline_kwargs,
        )

        # HELP BUTTON: initialize in the wrong spot...
        ax_help = div.append_axes(
            position="left", size=Fixed(help_width), pad=Fixed(vscroll_dist)
        )
        # HELP BUTTON: ...move it down by changing its locator
        loc = div.new_locator(nx=0, ny=0)
        ax_help.set_axes_locator(loc)
        # HELP BUTTON: make it a proper button
        with _patched_canvas(ax_help.figure):
            self.mne.button_help = Button(ax_help, "Help")
        # PROJ BUTTON
        # only shown when projectors exist and inst.proj is falsy
        # (presumably: projectors not yet applied — confirm)
        ax_proj = None
        if len(self.mne.projs) and not self.mne.inst.proj:
            proj_button_pos = [
                1 - self._inch_to_rel(r_margin + scroll_width),  # left
                self._inch_to_rel(b_margin, horiz=False),  # bottom
                self._inch_to_rel(scroll_width),  # width
                self._inch_to_rel(scroll_width, horiz=False),  # height
            ]
            loc = div.new_locator(nx=4, ny=0)
            ax_proj = self.add_axes(proj_button_pos)
            ax_proj.set_axes_locator(loc)
            # ax_help.figure is this same figure
            with _patched_canvas(ax_help.figure):
                self.mne.button_proj = Button(ax_proj, "Prj")

        # INIT TRACES
        # n_channels placeholder lines holding NaN; actual data is
        # presumably filled in on redraw
        self.mne.trace_kwargs = dict(antialiased=True, linewidth=0.5)
        self.mne.traces = ax_main.plot(
            np.full((1, self.mne.n_channels), np.nan), **self.mne.trace_kwargs
        )

        # SAVE UI ELEMENT HANDLES
        vars(self.mne).update(
            ax_main=ax_main,
            ax_help=ax_help,
            ax_proj=ax_proj,
            ax_hscroll=ax_hscroll,
            ax_vscroll=ax_vscroll,
            vsel_patch=vsel_patch,
            hsel_patch=hsel_patch,
            vline=vline,
            vline_hscroll=vline_hscroll,
            vline_text=vline_text,
        )
def _get_size(self):
return self.get_size_inches()
def _resize(self, event):
"""Handle resize event for mne_browse-style plots (Raw/Epochs/ICA)."""
old_width, old_height = self.mne.fig_size_px
new_width, new_height = self._get_size_px()
new_margins = _calc_new_margins(
self, old_width, old_height, new_width, new_height
)
self.subplots_adjust(**new_margins)
# zen mode bookkeeping
self.mne.zen_w *= old_width / new_width
self.mne.zen_h *= old_height / new_height
self.mne.fig_size_px = (new_width, new_height)
self.canvas.draw_idle()
def _hover(self, event):
"""Handle motion event when annotating."""
if (
event.button is not None
or event.xdata is None
or event.inaxes != self.mne.ax_main
):
return
if not self.mne.draggable_annotations:
self._remove_annotation_hover_line()
return
from matplotlib.patheffects import Normal, Stroke
for coll in self.mne.annotations:
if coll.contains(event)[0]:
path = coll.get_paths()
assert len(path) == 1
path = path[0]
color = coll.get_edgecolors()[0]
ylim = self.mne.ax_main.get_ylim()
# are we on the left or right edge?
_l = path.vertices[:, 0].min()
_r = path.vertices[:, 0].max()
x = _l if abs(event.xdata - _l) < abs(event.xdata - _r) else _r
mask = path.vertices[:, 0] == x
def drag_callback(x0):
path.vertices[mask, 0] = x0
# create or update the DraggableLine
hover_line = self.mne.annotation_hover_line
if hover_line is None:
line = self.mne.ax_main.plot(
[x, x], ylim, color=color, linewidth=2, pickradius=5.0
)[0]
hover_line = DraggableLine(
line, self._modify_annotation, drag_callback
)
else:
hover_line.set_x(x)
hover_line.drag_callback = drag_callback
# style the line
line = hover_line.line
patheff = [Stroke(linewidth=4, foreground=color, alpha=0.5), Normal()]
line.set_path_effects(
patheff if line.contains(event)[0] else patheff[1:]
)
self.mne.ax_main.selector.active = False
self.mne.annotation_hover_line = hover_line
self.canvas.draw_idle()
return
self._remove_annotation_hover_line()
    def _keypress(self, event):
        """Handle keypress events for the data browser.

        Covers navigation (arrow keys, page up/down, home/end), trace scaling
        ("+", "=", "-") and toggles for dialogs and display modes. Keys not
        handled here are passed to the parent class's ``_keypress`` (close
        key / fullscreen).
        """
        key = event.key
        n_channels = self.mne.n_channels
        # right-hand limit used to clamp scrolling and duration changes
        if self.mne.is_epochs:
            last_time = self.mne.n_times / self.mne.info["sfreq"]
        else:
            last_time = self.mne.inst.times[-1]
        # scroll up/down
        if key in ("down", "up", "shift+down", "shift+up"):
            # shift makes no difference for vertical scrolling
            key = key.split("+")[-1]
            direction = -1 if key == "up" else 1
            # butterfly case
            if self.mne.butterfly:
                return
            # group_by case
            elif self.mne.fig_selection is not None:
                buttons = self.mne.fig_selection.mne.radio_ax.buttons
                labels = [label.get_text() for label in buttons.labels]
                current_label = buttons.value_selected
                current_idx = labels.index(current_label)
                selections_dict = self.mne.ch_selections
                penult = current_idx < (len(labels) - 1)
                pre_penult = current_idx < (len(labels) - 2)
                has_custom = selections_dict.get("Custom", None) is not None
                def_custom = len(selections_dict.get("Custom", list()))
                up_ok = key == "up" and current_idx > 0
                # stepping onto the last button is only allowed when there is
                # no "Custom" group or it is non-empty (presumably "Custom"
                # is the last button — confirm)
                down_ok = key == "down" and (
                    pre_penult
                    or (penult and not has_custom)
                    or (penult and has_custom and def_custom)
                )
                if up_ok or down_ok:
                    buttons.set_active(current_idx + direction)
            # normal case
            else:
                # scroll by a full page of channels, clamped to the valid range
                ceiling = len(self.mne.ch_order) - n_channels
                ch_start = self.mne.ch_start + direction * n_channels
                self.mne.ch_start = np.clip(ch_start, 0, ceiling)
                self._update_picks()
                self._update_vscroll()
                self._redraw()
        # scroll left/right
        elif key in ("right", "left", "shift+right", "shift+left"):
            old_t_start = self.mne.t_start
            direction = 1 if key.endswith("right") else -1
            # shift: full window; otherwise one epoch (epochs) or 1/4 window
            if self.mne.is_epochs:
                denom = 1 if key.startswith("shift") else self.mne.n_epochs
            else:
                denom = 1 if key.startswith("shift") else 4
            t_max = last_time - self.mne.duration
            t_start = self.mne.t_start + direction * self.mne.duration / denom
            self.mne.t_start = np.clip(t_start, self.mne.first_time, t_max)
            if self.mne.t_start != old_t_start:
                self._update_hscroll()
                self._redraw(annotations=True)
        # scale traces
        elif key in ("=", "+", "-"):
            scaler = 1 / 1.1 if key == "-" else 1.1
            self.mne.scale_factor *= scaler
            self._redraw(update_data=False)
        # change number of visible channels
        elif (
            key in ("pageup", "pagedown")
            and self.mne.fig_selection is None
            and not self.mne.butterfly
        ):
            new_n_ch = n_channels + (1 if key == "pageup" else -1)
            self.mne.n_channels = np.clip(new_n_ch, 1, len(self.mne.ch_order))
            # add new chs from above if we're at the bottom of the scrollbar
            ch_end = self.mne.ch_start + self.mne.n_channels
            if ch_end > len(self.mne.ch_order) and self.mne.ch_start > 0:
                self.mne.ch_start -= 1
                self._update_vscroll()
            # redraw only if changed
            if self.mne.n_channels != n_channels:
                self._update_picks()
                self._update_trace_offsets()
                self._redraw(annotations=True)
        # change duration
        elif key in ("home", "end"):
            old_dur = self.mne.duration
            dur_delta = 1 if key == "end" else -1
            if self.mne.is_epochs:
                # prevent from showing zero epochs, or more epochs than we have
                self.mne.n_epochs = np.clip(
                    self.mne.n_epochs + dur_delta, 1, len(self.mne.inst)
                )
                # use the length of one epoch as duration change
                min_dur = len(self.mne.inst.times) / self.mne.info["sfreq"]
                new_dur = self.mne.duration + dur_delta * min_dur
            else:
                # never show fewer than 3 samples
                min_dur = 3 * np.diff(self.mne.inst.times[:2])[0]
                # use multiplicative dur_delta
                dur_delta = 5 / 4 if dur_delta > 0 else 4 / 5
                new_dur = self.mne.duration * dur_delta
            self.mne.duration = np.clip(new_dur, min_dur, last_time)
            if self.mne.duration != old_dur:
                # shift the window left if it would run past the end
                if self.mne.t_start + self.mne.duration > last_time:
                    self.mne.t_start = last_time - self.mne.duration
                self._update_hscroll()
                self._redraw(annotations=True)
        elif key == "?":  # help window
            self._toggle_help_fig(event)
        elif key == "a":  # annotation mode
            self._toggle_annotation_fig()
        elif key == "b" and self.mne.instance_type != "ica":  # butterfly mode
            self._toggle_butterfly()
        elif key == "d":  # DC shift
            self.mne.remove_dc = not self.mne.remove_dc
            self._redraw()
        elif key == "h":  # histogram
            self._toggle_epoch_histogram()
        elif key == "j" and len(self.mne.projs):  # SSP window
            self._toggle_proj_fig()
        elif key == "J" and len(self.mne.projs):
            # "J" is shift+j: toggle all projectors at once
            self._toggle_proj_checkbox(event, toggle_all=True)
        elif key == "p":  # toggle draggable annotations
            self._toggle_draggable_annotations(event)
            # keep the annotation dialog's checkbox in sync (no callbacks)
            if self.mne.fig_annotation is not None:
                checkbox = self.mne.fig_annotation.mne.drag_checkbox
                with _events_off(checkbox):
                    checkbox.set_active(0)
        elif key == "s":  # scalebars
            self._toggle_scalebars(event)
        elif key == "w":  # toggle noise cov whitening
            self._toggle_whitening()
        elif key == "z":  # zen mode: hide scrollbars and buttons
            self._toggle_scrollbars()
            self._redraw(update_data=False)
        elif key == "t":
            self._toggle_time_format()
        else:  # check for close key / fullscreen toggle
            super()._keypress(event)
    def _buttonpress(self, event):
        """Handle mouse clicks.

        Left-click: in the main axes (when not annotating), toggles the bad
        state of the clicked trace or epoch, or otherwise shows a vertical
        line; in the scrollbars it scrolls; on the Prj/Help buttons it
        toggles those dialogs. Right-click: in annotation mode, deletes the
        clicked annotation span; otherwise, in the main axes, hides the
        vertical line. Middle clicks, scroll events and clicks outside any
        axes are ignored.
        """
        from matplotlib.collections import PolyCollection

        from ..annotations import _sync_onset

        butterfly = self.mne.butterfly
        annotating = self.mne.fig_annotation is not None
        ax_main = self.mne.ax_main
        inst = self.mne.inst

        # ignore middle clicks, scroll wheel events, and clicks outside axes
        if event.button not in (1, 3) or event.inaxes is None:
            return

        elif event.button == 1:  # left-click (primary)
            # click in main axes
            if event.inaxes == ax_main and not annotating:
                if self.mne.instance_type == "epochs" or not butterfly:
                    for line in self.mne.traces + self.mne.epoch_traces:
                        if line.contains(event)[0]:
                            if self.mne.instance_type == "epochs":
                                self._toggle_bad_epoch(event)
                            else:
                                idx = self.mne.traces.index(line)
                                self._toggle_bad_channel(idx)
                            return
                self._show_vline(event.xdata)  # butterfly / not on data trace
                self._redraw(update_data=False, annotations=False)
                return
            # click in vertical scrollbar
            elif event.inaxes == self.mne.ax_vscroll:
                if self.mne.fig_selection is not None:
                    self._change_selection_vscroll(event)
                elif self._check_update_vscroll_clicked(event):
                    self._redraw()
            # click in horizontal scrollbar
            elif event.inaxes == self.mne.ax_hscroll:
                if self._check_update_hscroll_clicked(event):
                    self._redraw(annotations=True)
            # click on proj button
            elif event.inaxes == self.mne.ax_proj:
                self._toggle_proj_fig(event)
            # click on help button
            elif event.inaxes == self.mne.ax_help:
                self._toggle_help_fig(event)
        else:  # right-click (secondary)
            if annotating:
                # annotation spans are the PolyCollections in the main axes
                spans = [
                    span
                    for span in ax_main.collections
                    if isinstance(span, PolyCollection)
                ]
                if any(span.contains(event)[0] for span in spans):
                    # shift by first_time so the click is comparable with the
                    # synced onsets below (presumably onsets are relative to
                    # first_time — confirm)
                    xdata = event.xdata - self.mne.first_time
                    start = _sync_onset(inst, inst.annotations.onset)
                    end = start + inst.annotations.duration
                    is_onscreen = self.mne.onscreen_annotations  # boolean array
                    was_clicked = (xdata > start) & (xdata < end) & is_onscreen
                    # determine which annotation label is "selected"
                    buttons = self.mne.fig_annotation.mne.radio_ax.buttons
                    current_label = buttons.value_selected
                    is_active_label = inst.annotations.description == current_label
                    # use z-order as tiebreaker (or if click wasn't on an active span)
                    # (ax_main.collections only includes *visible* annots, so we offset)
                    # NOTE(review): the offset assumes the visible spans map to
                    # a contiguous run of onscreen annotations, in order.
                    visible_zorders = [span.zorder for span in spans]
                    zorders = np.zeros_like(is_onscreen).astype(int)
                    offset = np.where(is_onscreen)[0][0]
                    zorders[offset : (offset + len(visible_zorders))] = visible_zorders
                    # among overlapping clicked spans, prefer removing spans whose label
                    # is the active label; then fall back to zorder as deciding factor
                    active_clicked = was_clicked & is_active_label
                    mask = active_clicked if any(active_clicked) else was_clicked
                    highest = zorders == zorders[mask].max()
                    idx = np.where(highest)[0]
                    inst.annotations.delete(idx)
                    self._remove_annotation_hover_line()
                    self._draw_annotations()
                    self.canvas.draw_idle()
            elif event.inaxes == ax_main:
                self._toggle_vline(False)
def _pick(self, event):
"""Handle matplotlib pick events."""
from matplotlib.text import Text
if self.mne.butterfly:
return
# clicked on channel name
if isinstance(event.artist, Text):
ch_name = event.artist.get_text()
ind = self.mne.ch_names[self.mne.picks].tolist().index(ch_name)
if event.mouseevent.button == 1: # left click
self._toggle_bad_channel(ind)
elif event.mouseevent.button == 3: # right click
self._create_ch_context_fig(ind)
def _create_ch_context_fig(self, idx):
fig = super()._create_ch_context_fig(idx)
plt_show(fig=fig)
def _new_child_figure(self, fig_name, *, layout=None, **kwargs):
"""Instantiate a new MNE dialog figure (with event listeners)."""
fig = _figure(
toolbar=False,
parent_fig=self,
fig_name=fig_name,
layout=layout,
**kwargs,
)
fig._add_default_callbacks()
self.mne.child_figs.append(fig)
if isinstance(fig_name, str):
setattr(self.mne, fig_name, fig)
return fig
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# HELP DIALOG
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
def _create_help_fig(self):
"""Create help dialog window."""
text = {
key: val for key, val in self._get_help_text().items() if val is not None
}
keys = ""
vals = ""
for key, val in text.items():
newsection = "\n" if key.startswith("_") else ""
key = key[1:] if key.startswith("_") else key
newlines = "\n" * len(val.split("\n")) # handle multiline values
keys += f"{newsection}{key} {newlines}"
vals += f"{newsection}{val}\n"
# calc figure size
n_lines = len(keys.split("\n"))
longest_key = max(len(k) for k in text.keys())
longest_val = max(
max(len(w) for w in v.split("\n")) if "\n" in v else len(v)
for v in text.values()
)
width = (longest_key + longest_val) / 12
height = (n_lines) / 5
# create figure and axes
fig = self._new_child_figure(
figsize=(width, height), fig_name="fig_help", window_title="Help"
)
ax = fig.add_axes((0.01, 0.01, 0.98, 0.98))
ax.set_axis_off()
kwargs = dict(va="top", linespacing=1.5, usetex=False)
ax.text(0.42, 1, keys, ma="right", ha="right", **kwargs)
ax.text(0.42, 1, vals, ma="left", ha="left", **kwargs)
def _toggle_help_fig(self, event):
"""Show/hide the help dialog window."""
if self.mne.fig_help is None:
self._create_help_fig()
plt_show(fig=self.mne.fig_help)
else:
plt.close(self.mne.fig_help)
def _get_help_text(self):
"""Generate help dialog text; `None`-valued entries removed later."""
inst = self.mne.instance_type
is_raw = inst == "raw"
is_epo = inst == "epochs"
is_ica = inst == "ica"
has_proj = bool(len(self.mne.projs))
# adapt keys to different platforms
is_mac = platform.system() == "Darwin"
dur_keys = ("fn + ←", "fn + →") if is_mac else ("Home", "End")
ch_keys = ("fn + ↑", "fn + ↓") if is_mac else ("Page up", "Page down")
# adapt descriptions to different instance types
ch_cmp = "component" if is_ica else "channel"
ch_epo = "epoch" if is_epo else "channel"
ica_bad = "Mark/unmark component for exclusion"
dur_vals = (
[f"Show {n} epochs" for n in ("fewer", "more")]
if self.mne.is_epochs
else [f"Show {d} time window" for d in ("shorter", "longer")]
)
ch_vals = [
f"{inc_dec} number of visible {ch_cmp}s"
for inc_dec in ("Increase", "Decrease")
]
lclick_data = ica_bad if is_ica else f"Mark/unmark bad {ch_epo}"
lclick_name = ica_bad if is_ica else "Mark/unmark bad channel"
rclick_name = dict(
ica="Show diagnostics for component",
epochs="Show imageplot for channel",
raw="Show channel location",
)[inst]
# TODO not yet implemented
# ldrag = ('Show spectrum plot for selected time span;\nor (in '
# 'annotation mode) add annotation') if inst== 'raw' else None
ldrag = "add annotation (in annotation mode)" if is_raw else None
noise_cov = None if self.mne.noise_cov is None else "Toggle signal whitening"
scrl = "1 epoch" if self.mne.is_epochs else "¼ window"
# below, value " " is a hack to make "\n".split(value) have length 1
help_text = OrderedDict(
[
("_NAVIGATION", " "),
("→", f"Scroll {scrl} right (scroll full window with Shift + →)"),
("←", f"Scroll {scrl} left (scroll full window with Shift + ←)"),
(dur_keys[0], dur_vals[0]),
(dur_keys[1], dur_vals[1]),
("↑", f"Scroll up ({ch_cmp}s)"),
("↓", f"Scroll down ({ch_cmp}s)"),
(ch_keys[0], ch_vals[0]),
(ch_keys[1], ch_vals[1]),
("_SIGNAL TRANSFORMATIONS", " "),
("+ or =", "Increase signal scaling"),
("-", "Decrease signal scaling"),
("b", "Toggle butterfly mode" if not is_ica else None),
("d", "Toggle DC removal" if is_raw else None),
("w", noise_cov),
("_USER INTERFACE", " "),
("a", "Toggle annotation mode" if is_raw else None),
("h", "Toggle peak-to-peak histogram" if is_epo else None),
("j", "Toggle SSP projector window" if has_proj else None),
("shift+j", "Toggle all SSPs"),
("p", "Toggle draggable annotations" if is_raw else None),
("s", "Toggle scalebars" if not is_ica else None),
("z", "Toggle scrollbars"),
("t", "Toggle time format" if not is_epo else None),
("F11", "Toggle fullscreen" if not is_mac else None),
("?", "Open this help window"),
("esc", "Close focused figure or dialog window"),
("_MOUSE INTERACTION", " "),