-
Notifications
You must be signed in to change notification settings - Fork 34
/
epochs_multivariate.py
1406 lines (1190 loc) · 58.9 KB
/
epochs_multivariate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Authors: Thomas Samuel Binns <t.s.binns@outlook.com>
# Tien Dung Nguyen <>
# Richard M. Köhler <koehler.richard@charite.de>
#
# License: BSD (3-clause)
import inspect
import copy
import numpy as np
from mne import BaseEpochs
from mne.parallel import parallel_func
from mne.utils import _arange_div, logger, ProgressBar, _time_mask
from ..base import (
MultivariateSpectralConnectivity, MultivariateSpectroTemporalConnectivity
)
from .epochs import (
_assemble_spectral_params, _check_estimators, _compute_freqs,
_compute_freq_mask, _epoch_spectral_connectivity,
_get_and_verify_data_sizes, _get_n_epochs, _prepare_connectivity
)
class _MVCSpectralEpochs():
    """Computes multivariate spectral connectivity of epoched data for the
    multivariate_spectral_connectivity_epochs function."""

    # names of the keyword arguments the class must be initialised with (see
    # `__init__`)
    init_attrs = [
        'data', 'indices', 'names', 'method', 'sfreq', 'mode', 'tmin', 'tmax',
        'fmin', 'fmax', 'fskip', 'faverage', 'cwt_freqs', 'mt_bandwidth',
        'mt_adaptive', 'mt_low_bias', 'cwt_n_cycles', 'n_components',
        'gc_n_lags', 'block_size', 'n_jobs', 'verbose'
    ]

    # user-facing aliases of the Granger causality methods and the
    # corresponding estimator names (same order)
    gc_method_aliases = ['gc', 'net_gc', 'trgc', 'net_trgc']
    gc_method_names = ['GC', 'Net GC', 'TRGC', 'Net TRGC']

    # possible forms of GC with information on how to use the methods;
    # `method_class` is filled in with an estimator instance at runtime
    possible_gc_forms = {
        'seeds -> targets': dict(
            flip_seeds_targets=False, reverse_time=False,
            for_methods=['GC', 'Net GC', 'TRGC', 'Net TRGC'], method_class=None
        ),
        'targets -> seeds': dict(
            flip_seeds_targets=True, reverse_time=False,
            for_methods=['Net GC', 'Net TRGC'], method_class=None
        ),
        'time-reversed[seeds -> targets]': dict(
            flip_seeds_targets=False, reverse_time=True,
            for_methods=['TRGC', 'Net TRGC'], method_class=None
        ),
        'time-reversed[targets -> seeds]': dict(
            flip_seeds_targets=True, reverse_time=True,
            for_methods=['Net TRGC'], method_class=None
        )
    }
    # possible forms of coherence with information on how to use the methods
    possible_coh_forms = {
        'MIC & MIM': dict(
            for_methods=['MIC', 'MIM'], exclude_methods=[], method_class=None
        ),
        'MIC': dict(
            for_methods=['MIC'], exclude_methods=['MIM'], method_class=None
        ),
        'MIM': dict(
            for_methods=['MIM'], exclude_methods=['MIC'], method_class=None
        )
    }
    # threshold for classifying singular values as being non-zero
    rank_nonzero_tol = 1e-5

    # NOTE(review): the attributes below are mutable objects shared by all
    # instances of the class; the methods visible here appear to rebind them
    # on the instance before extending them (e.g. `separate_gc_con`), but
    # confirm no code path mutates the class-level objects in place.
    perform_svd = False
    # whether the requested frequencies are discontinuous (e.g. different bands)
    discontinuous_freqs = False
    # whether or not GC must be computed separately from other methods
    compute_gc_separately = False
    # storage for GC results if SVD used or freqs are discontinuous (which must
    # be computed separately from other methods)
    separate_gc_method_types = []
    separate_gc_con = []
    separate_gc_topo = []
    # storage for coherence results (and GC results if no SVD used and requested
    # frequencies are continuous)
    remaining_con = []
    remaining_topo = []
def __init__(self, **kwargs):
assert all(attr in self.init_attrs for attr in kwargs.keys()), (
'Not all inputs to the _MVCSpectralEpochs class have been '
'provided. Please contact the mne-connectivity developers.'
)
for name, value in kwargs.items():
assert name in self.init_attrs, (
'An input to the _MVCSpectralEpochs class is not recognised. '
'Please contact the mne-connectivity developers.'
)
setattr(self, name, value)
self._sort_inputs()
def _sort_inputs(self):
"""Checks the format of the input parameters and enacts them to create
new object attributes."""
self._sort_parallelisation_inputs()
self._sort_data_info()
self._sort_estimator_inputs()
self._sort_freq_inputs()
self._sort_indices_inputs()
self._sort_svd_inputs()
if self.perform_svd or self.discontinuous_freqs:
self.compute_gc_separately = True
def _sort_parallelisation_inputs(self):
"""Establishes parallelisation of the function for computing the CSD if
n_jobs > 1, else uses the standard, non-parallelised function."""
self.parallel, self._epoch_spectral_connectivity, _ = (
parallel_func(
_epoch_spectral_connectivity, self.n_jobs, verbose=self.verbose
)
)
    def _sort_data_info(self):
        """Extracts information stored in the data if it is an Epochs object,
        otherwise sets this information to `None`.

        For Epochs input, also sets `names`, `sfreq`, and
        `annots_in_metadata` from the Epochs object.
        """
        if isinstance(self.data, BaseEpochs):
            self.names = self.data.ch_names
            self.times_in = self.data.times  # input times for Epochs input type
            self.sfreq = self.data.info['sfreq']
            self.events = self.data.events
            self.event_id = self.data.event_id

            # Extract metadata from the Epochs data structure.
            # Make Annotations persist through by adding them to the metadata.
            metadata = self.data.metadata
            if metadata is None:
                self.annots_in_metadata = False
            else:
                # NOTE(review): despite its name, this flag is True when the
                # annotation columns are *absent* from the metadata, and the
                # annotations are then only added below when the columns are
                # already present — this looks inverted; confirm against the
                # equivalent logic in the bivariate connectivity code.
                self.annots_in_metadata = all(
                    name not in metadata.columns for name in
                    ['annot_onset', 'annot_duration', 'annot_description']
                )
            if (
                hasattr(self.data, 'annotations') and not
                self.annots_in_metadata
            ):
                self.data.add_annotations_to_metadata(overwrite=True)
            self.metadata = self.data.metadata
        else:
            # array-like input carries no MNE metadata
            self.times_in = None
            self.events = None
            self.event_id = None
            self.metadata = None
def _sort_estimator_inputs(self):
"""Assign names to connectivity methods, check the methods and mode are
recognised, and finds which Granger causality methods are being
called."""
if not isinstance(self.method, (list, tuple)):
self.method = [self.method] # make it a list so we can iterate
self.con_method_types, _, _, _ = _check_estimators(
self.method, self.mode
)
metrics_str = ', '.join([meth.name for meth in self.con_method_types])
logger.info(
' the following metrics will be computed: %s' % metrics_str
)
# find which Granger causality methods are being called
self.present_gc_methods = [
con_method for con_method in self.method
if con_method in self.gc_method_aliases
]
self.remaining_method_types = copy.deepcopy(self.con_method_types)
def _sort_freq_inputs(self):
"""Formats frequency-related inputs and checks they are appropriate."""
if self.fmin is None:
self.fmin = -np.inf # set it to -inf, so we can adjust it later
self.fmin = np.array((self.fmin,), dtype=float).ravel()
self.fmax = np.array((self.fmax,), dtype=float).ravel()
if len(self.fmin) != len(self.fmax):
raise ValueError('fmin and fmax must have the same length')
if np.any(self.fmin > self.fmax):
raise ValueError('fmax must be larger than fmin')
self.n_bands = len(self.fmin)
if self.present_gc_methods:
self._check_for_discontinuous_freqs()
    def _check_for_discontinuous_freqs(self):
        """Checks whether the requested frequencies to analyse are
        discontinuous (occurs in the case that different frequency bands are
        specified in fmin and fmax, but there is a gap between the boundaries of
        each frequency band). The state-space GC method used has a
        cross-frequency relationship which would be disrupted by computing
        connectivity on a discontinuous set of frequencies, so this checks to
        see if GC needs to be computed separately on a continuous set of
        frequencies spanning from the lowest fmin and highest fmax values which
        can then be split into the specified frequency bands.

        A simpler check would be to set discontinuous = True if n_bands > 1,
        however it could be the case that the bands are continuous, e.g. 8-12 Hz
        and 13-20 Hz (with a freq. resolution of 1 Hz), in which case the
        frequencies are not discontinuous and GC can be computed alongside other
        methods.

        Sets the `discontinuous_freqs` attribute.
        """
        n_times = self._get_n_used_times()
        # compute frequencies to analyze based on number of samples, sampling
        # rate, specified wavelet frequencies and mode
        freqs = _compute_freqs(n_times, self.sfreq, self.cwt_freqs, self.mode)
        # compute the mask based on specified min/max and decimation factor
        freq_mask = _compute_freq_mask(freqs, self.fmin, self.fmax, 0)
        # formula for finding if indices of freqs being analysed is continuous;
        # array should not contain repeats (but that should always be the
        # case for these frequency indices) and should start from 1 (we make
        # this adjustment)
        use_freqs = np.nonzero(freq_mask)[0]
        use_freqs = (use_freqs - min(use_freqs)) + 1  # need to start from 1
        # a strictly-increasing sequence starting at 1 is gap-free iff the sum
        # of 1..n equals the triangular number of its last element (i.e. the
        # last index equals the count of indices)
        if (
            sum(np.arange(1, len(use_freqs) + 1)) !=
            use_freqs[-1] * (use_freqs[-1] + 1) / 2
        ):
            # a single band always spans a continuous range, so reaching here
            # with one band indicates an internal logic error
            assert self.n_bands != 1, (
                'Frequencies have been detected as discontinuous, yet there is '
                'only a single frequency band in the data. Please contact the '
                'mne-connectivity developers.'
            )
            self.discontinuous_freqs = True
def _get_n_used_times(self):
"""Finds and returns the number of timepoints being examined in the
data."""
if self.times_in is None:
if isinstance(self.data, BaseEpochs):
n_times = self.data.get_data().shape[2]
else:
n_times = self.data.shape[2]
times = _arange_div(n_times, self.sfreq)
else:
times = self.times_in
time_mask = _time_mask(times, self.tmin, self.tmax, sfreq=self.sfreq)
tmin_idx, tmax_idx = np.where(time_mask)[0][[0, -1]]
return len(times[tmin_idx : tmax_idx + 1])
def _sort_indices_inputs(self):
"""Checks that the indices are appropriate and sets the number of seeds
and targets in each connection."""
if self.indices is None:
raise ValueError('indices must be specified, got `None`')
if len(self.indices[0]) != len(self.indices[1]):
raise ValueError(
f'the number of seeds ({len(self.indices[0])}) and targets '
f'({len(self.indices[1])}) must match'
)
self.n_cons = len(self.indices[0])
for seeds, targets in zip(self.indices[0], self.indices[1]):
if not isinstance(seeds, list) or not isinstance(targets, list):
raise TypeError(
'seeds and targets for each connection must be given as a '
'list of ints'
)
if (
not all(isinstance(seed, int) for seed in seeds) or
not all(isinstance(target, int) for target in targets)
):
raise TypeError(
'seeds and targets for each connection must be given as a '
'list of ints'
)
if set.intersection(set(seeds), set(targets)):
raise ValueError(
'there are common indices present in the seeds and targets '
'for a single connection, however multivariate '
'connectivity between shared channels is not supported'
)
if isinstance(self.data, BaseEpochs):
n_channels = self.data.get_data().shape[1]
else:
n_channels = self.data.shape[1]
if len(self._get_unique_signals(self.indices)) > n_channels:
raise ValueError(
'the number of unique signals in indices is greater than the '
'number of channels in the data'
)
def _sort_svd_inputs(self):
"""Checks that the SVD parameters are appropriate and finds the correct
dimensionality reduction settings to use, if applicable.
This involves the rank of the data being computed based its non-zero
singular values. We use a cut-off of the largest singular value * 1e-5
by default to determine when a value is non-zero, as using numpy's
default cut-off is too liberal (i.e. low) for our purposes where we need
to be stricter.
"""
self.n_components = copy.copy(self.n_components)
# finds if any SVD has been requested for seeds and/or targets
if self.n_components is None:
self.n_components = (
[None for _ in range(self.n_cons)],
[None for _ in range(self.n_cons)]
)
if self.n_components == 'rank':
self.n_components = (
['rank' for _ in range(self.n_cons)],
['rank' for _ in range(self.n_cons)]
)
if not isinstance(self.n_components, tuple):
raise TypeError('n_components must be a tuple')
for group_i, group_comps in enumerate(self.n_components):
if group_comps is None:
group_comps[group_i] = [None for _ in range(self.n_cons)]
if not isinstance(group_comps, list):
raise TypeError('entries of n_components must be lists')
if len(group_comps) != self.n_cons:
raise ValueError(
'entries of n_components must have the same length as '
'specified the number of connections in indices'
)
if not self.perform_svd and any(
con_comps is not None for con_comps in group_comps
):
self.perform_svd = True
# if SVD is requested, extract the data and perform subsequent checks
if self.perform_svd:
if isinstance(self.data, BaseEpochs):
epochs = self.data.get_data(picks=self.data.ch_names)
else:
epochs = self.data
for group_idcs, group_comps in zip(self.indices, self.n_components):
if any(con_comps is not None for con_comps in group_comps):
index_i = 0
for con_comps, con_chs in zip(group_comps, group_idcs):
if isinstance(con_comps, int):
if con_comps > len(con_chs):
raise ValueError(
'the number of components to take cannot '
'be greater than the number of channels in '
'a given seed/target'
)
if con_comps <= 0:
raise ValueError(
'the number of components to take must be '
'greater than 0'
)
elif isinstance(con_comps, str):
if con_comps != 'rank':
raise ValueError(
'if the number of components is specified '
'as a string, it must be the string "rank"'
)
# compute the rank of the seeds/targets for a con
S = np.linalg.svd(
epochs[:, con_chs, :],
compute_uv=False
)
group_comps[index_i] = min([np.count_nonzero(
s >= s[0]*self.rank_nonzero_tol
) for s in S])
elif con_comps is not None:
raise TypeError(
'n_components must be tuples of lists of '
'`None`, `int`, or the string "rank"'
)
index_i += 1
def compute_csd_and_connectivity(self):
"""Compute the CSD of the data and derive connectivity results from
it."""
# if SVD is requested or the specified fbands are discontinuous, the
# CSD (and hence connectivity) has to be computed separately for any GC
# methods
if self.present_gc_methods and self.compute_gc_separately:
self._compute_separate_gc_csd_and_connectivity()
# if GC has been computed separately, the coherence methods are computed
# here, otherwise both GC and coherence methods are computed here
if self.remaining_method_types:
self._compute_remaining_csd_and_connectivity()
# combine all connectivity results (if GC was computed seperately)
self._collate_connectivity_results()
    def _compute_separate_gc_csd_and_connectivity(self):
        """Computes the CSD and connectivity for GC methods separately from
        other methods if SVD is being performed, or the requested fbands are
        discontinuous.

        If SVD is being performed with GC, this has to be done on the
        timeseries data for each connection separately, and so this transformed
        data cannot be used to compute the CSD for coherence-based connectivity
        methods.

        Unlike the coherence methods, the state-space GC methods used here rely
        on cross-frequency relationships, so discontinuous frequencies will mess
        up the results. Hence, GC must be computed on a continuous set of
        frequencies, and then have the requested frequency band results taken.

        Fills `separate_gc_con`/`separate_gc_topo` and reduces
        `remaining_method_types` to the non-GC methods.
        """
        # finds the GC methods to compute
        self.separate_gc_method_types = [
            mtype for mtype in self.con_method_types if mtype.name in
            self.gc_method_names
        ]
        # per-connection (possibly SVD-reduced) data, seeds before targets on
        # the channel axis
        seed_target_data, n_seeds = self._seeds_targets_svd()
        # computes GC for each connection separately (no topographies for GC)
        n_gc_methods = len(self.present_gc_methods)
        self.separate_gc_con = [[] for _ in range(n_gc_methods)]
        self.separate_gc_topo = [None for _ in range(n_gc_methods)]
        for con_data, n_seed_comps in zip(seed_target_data, n_seeds):
            # in the per-connection array the first n_seed_comps channels are
            # the seeds and the rest are the targets
            new_indices = (
                [np.arange(n_seed_comps).tolist()],
                [np.arange(n_seed_comps, con_data.shape[1]).tolist()]
            )
            con_methods = self._compute_csd(
                con_data, self.separate_gc_method_types, new_indices
            )
            # accumulate this connection's results per method
            this_con, _, = self._compute_connectivity(con_methods, new_indices)
            for method_i in range(n_gc_methods):
                self.separate_gc_con[method_i].extend(this_con[method_i])
        # finds the methods still needing to be computed
        self.remaining_method_types = [
            mtype for mtype in self.con_method_types if
            mtype not in self.separate_gc_method_types
        ]
def _seeds_targets_svd(self):
"""SVDs the epoched data separately for the seeds and targets of each
connection according to the specified number of seed and target
components. If the number of components for a given instance is `None`,
the original data is returned."""
if isinstance(self.data, BaseEpochs):
epochs = self.data.get_data(picks=self.data.ch_names).copy()
else:
epochs = self.data.copy()
seed_target_data = []
n_seeds = []
for seeds, targets, n_seed_comps, n_target_comps in zip(
self.indices[0], self.indices[1], self.n_components[0],
self.n_components[1]
):
if n_seed_comps is not None: # SVD seed data
seed_data = self._epochs_svd(epochs[:, seeds, :], n_seed_comps)
else: # use unaltered seed data
seed_data = epochs[:, seeds, :]
n_seeds.append(seed_data.shape[1])
if n_target_comps is not None: # SVD target data
target_data = self._epochs_svd(
epochs[:, targets, :], n_target_comps
)
else: # use unaltered target data
target_data = epochs[:, targets, :]
seed_target_data.append(np.append(seed_data, target_data, axis=1))
return seed_target_data, n_seeds
def _epochs_svd(self, epochs, n_comps):
"""Performs an SVD on epoched data and selects the first k components
for dimensionality reduction before reconstructing the data with
(U_k @ S_k @ V_k)."""
# mean-centre the data epoch-wise
centred_epochs = np.array([epoch - epoch.mean() for epoch in epochs])
# compute the SVD (transposition so that the channels are the columns of
# each epoch)
U, S, V = np.linalg.svd(
centred_epochs.transpose(0, 2, 1), full_matrices=False
)
# take the first k components
U_k = U[:, :, :n_comps]
S_k = np.eye(n_comps) * S[:, np.newaxis][:, :n_comps, :n_comps]
V_k = V[:, :n_comps, :n_comps]
# reconstruct the dimensionality-reduced data (have to transpose the
# data back into [epochs x channels x timepoints])
return (U_k @ (S_k @ V_k)).transpose(0, 2, 1)
    def _compute_remaining_csd_and_connectivity(self):
        """Computes connectivity where a single CSD can be computed and the
        connectivity computations performed for all connections together (i.e.
        anything other than GC with SVD and/or GC with discontinuous
        frequencies).

        Results are stored in `remaining_con` and `remaining_topo`.
        """
        con_methods = self._compute_csd(
            self.data, self.remaining_method_types, self.indices
        )
        # NOTE: `remapped_indices` (indices relabelled to the unique signals)
        # is set as a side effect of `_compute_csd` (via
        # `_prepare_csd_computation` -> `_sort_con_indices`), so it is only
        # valid after the call above
        self.remaining_con, self.remaining_topo = (
            self._compute_connectivity(con_methods, self.remapped_indices)
        )
    def _compute_csd(self, data, con_method_types, indices):
        """Computes the cross-spectral density of the data in preparation for
        the multivariate connectivity computations.

        Returns the instantiated connectivity estimators with the CSD of all
        epochs accumulated in them. Also sets `n_epochs` to the number of
        epochs processed.
        """
        logger.info('Connectivity computation...')
        con_methods = self._prepare_csd_computation(
            data, con_method_types, indices
        )
        # performs the CSD computation for each epoch block
        logger.info('Computing cross-spectral density from epochs')
        self.n_epochs = 0
        for epoch_block in ProgressBar(
            self.epoch_blocks, mesg='CSD epoch blocks'
        ):
            # check dimensions and time scale
            for this_epoch in epoch_block:
                # warn_times is carried over so repeat warnings are suppressed
                _, _, _, self.warn_times = _get_and_verify_data_sizes(
                    this_epoch, self.sfreq, self.n_signals, self.n_times_in,
                    self.times_in, warn_times=self.warn_times
                )
                self.n_epochs += 1
            # compute CSD of epochs (parallelised over the epochs of the
            # block when n_jobs > 1)
            epochs = self.parallel(
                self._epoch_spectral_connectivity(
                    data=this_epoch, **self.csd_call_params
                )
                for this_epoch in epoch_block
            )
            # unpack and accumulate CSDs of epochs in connectivity methods
            for epoch in epochs:
                for method, epoch_csd in zip(con_methods, epoch[0]):
                    method.combine(epoch_csd)
        return con_methods
    def _prepare_csd_computation(self, data, con_method_types, indices):
        """Collects and returns information in preparation for computing the
        cross-spectral density.

        Instantiates and returns the connectivity estimators, and as side
        effects sets (amongst others) `epoch_blocks`, the time/frequency
        attributes, the remapped indices, and `csd_call_params` (the keyword
        arguments passed to the per-epoch CSD function).
        """
        # split the epochs into blocks of up to n_jobs epochs
        self.epoch_blocks = [
            epoch for epoch in _get_n_epochs(data, self.n_jobs)
        ]
        # fmin/fmax may be widened to one continuous band for separate GC
        fmin, fmax = self._get_fmin_fmax_for_csd(con_method_types)
        (
            _, self.times, n_times, self.times_in, self.n_times_in, tmin_idx,
            tmax_idx, self.n_freqs, freq_mask, self.freqs, freqs_bands,
            freq_idx_bands, self.n_signals, _, self.warn_times
        ) = _prepare_connectivity(
            epoch_block=self.epoch_blocks[0], times_in=self.times_in,
            tmin=self.tmin, tmax=self.tmax, fmin=fmin, fmax=fmax,
            sfreq=self.sfreq, indices=indices, mode=self.mode,
            fskip=self.fskip, n_bands=self.n_bands, cwt_freqs=self.cwt_freqs,
            faverage=self.faverage
        )
        # band info must be re-derived when the fmin/fmax above were widened
        self._store_freq_band_info(
            con_method_types, freqs_bands, freq_idx_bands
        )
        spectral_params, mt_adaptive, self.n_times_spectrum, self.n_tapers = (
            _assemble_spectral_params(
                mode=self.mode, n_times=n_times, mt_adaptive=self.mt_adaptive,
                mt_bandwidth=self.mt_bandwidth, sfreq=self.sfreq,
                mt_low_bias=self.mt_low_bias, cwt_n_cycles=self.cwt_n_cycles,
                cwt_freqs=self.cwt_freqs, freqs=self.freqs, freq_mask=freq_mask
            )
        )
        # sets sig_idx, use_n_signals, idx_map, and remapped_indices
        self._sort_con_indices(indices)
        con_methods = self._instantiate_con_estimators(con_method_types, indices)
        self.csd_call_params = dict(
            sig_idx=self.sig_idx, tmin_idx=tmin_idx, tmax_idx=tmax_idx,
            sfreq=self.sfreq, mode=self.mode, freq_mask=freq_mask,
            idx_map=self.idx_map, block_size=self.block_size, psd=None,
            accumulate_psd=False, mt_adaptive=mt_adaptive,
            con_method_types=self.con_method_types, con_methods=None,
            n_signals=self.n_signals, use_n_signals=self.use_n_signals,
            n_times=n_times, gc_n_lags=self.gc_n_lags, accumulate_inplace=False
        )
        self.csd_call_params.update(**spectral_params)
        return con_methods
def _get_fmin_fmax_for_csd(self, con_method_types):
"""Gets fmin and fmax args to use for the CSD computation."""
if (
self.present_gc_methods and self.discontinuous_freqs and
con_method_types[0] in self.separate_gc_method_types
):
# compute GC on a continuous set of freqs spanning all bands of
# interest due to the cross-freq relationship of the GC methods
return (
np.array((np.min(self.fmin), )),
np.array((np.max(self.fmax), ))
)
# use existing fmin and fmax if GC is not being computed, or if GC is
# being computed and the requested freq bands are not discontinuous
return (self.fmin, self.fmax)
def _store_freq_band_info(
self, con_method_types, freqs_bands, freq_idx_bands
):
"""Ensures the frequency band information returned from the connectivity
preparation function is correct before storing them in the object."""
if (
self.present_gc_methods and self.discontinuous_freqs and
con_method_types[0] in self.separate_gc_method_types
):
# compute fbands and indices as the freqs appear in fmin and fmax;
# required as the fmin and fmax args to the connectivity preparation
# function differ to those provided by the end user
self.freq_idx_bands = [
np.where((self.freqs >= fl) & (self.freqs <= fu))[0] for
fl, fu in zip(self.fmin, self.fmax)
]
self.freqs_bands = [
self.freqs[freq_idx] for freq_idx in self.freq_idx_bands
]
else:
# use the fband arguments returned from the connectivity preparation
# function, matching the fband args provided by the end user
self.freq_idx_bands = freq_idx_bands
self.freqs_bands = freqs_bands
def _sort_con_indices(self, indices):
"""Maps indices to the unique indices, finds the signals for which the
CSD needs to be computed (and how many used signals there are), and gets
the seed-target indices for the CSD."""
# map indices to unique indices
unique_indices = np.unique(np.concatenate(sum(indices, [])))
remapping = {ch_i: sig_i for sig_i, ch_i in enumerate(unique_indices)}
self.remapped_indices = tuple([
[[remapping[idx] for idx in idcs] for idcs in indices_group]
for indices_group in indices
])
# unique signals for which we actually need to compute CSD
self.sig_idx = self._get_unique_signals(self.remapped_indices)
self.use_n_signals = len(self.sig_idx)
# gets seed-target indices for CSD
self.idx_map = [
np.repeat(
self.sig_idx, len(self.sig_idx)),
np.tile(self.sig_idx, len(self.sig_idx)
)
]
def _get_unique_signals(self, indices):
"""Find the unique signals in a set of indices."""
return np.unique(sum(sum(indices, []), []))
def _instantiate_con_estimators(self, con_method_types, indices):
"""Create instances of the connectivity estimators and log the methods
being computed."""
con_methods = []
for mtype in con_method_types:
if "n_lags" in list(inspect.signature(mtype).parameters):
# if a GC method, provide n_lags argument
con_methods.append(
mtype(
self.use_n_signals, len(indices[0]), self.n_freqs,
self.n_times_spectrum, self.gc_n_lags, self.n_jobs
)
)
else:
# if a coherence method, do not provide n_lags argument
con_methods.append(
mtype(
self.use_n_signals, len(indices[0]), self.n_freqs,
self.n_times_spectrum, self.n_jobs
)
)
return con_methods
    def _compute_connectivity(self, con_methods, indices):
        """Computes the multivariate connectivity results.

        Returns the per-method connectivity scores and topographies (the
        latter only for MIC; `None` otherwise), band-averaged if `faverage`.
        Also updates the `freqs`, `freqs_used`, and `n_nodes` attributes.
        """
        con = [None for _ in range(len(con_methods))]
        topo = [None for _ in range(len(con_methods))]
        # add the GC results to con in the correct positions according to the
        # order of con_methods
        con = self._compute_gc_connectivity(con_methods, con, indices)
        # add the coherence results to con in the correct positions according to
        # the order of con_methods
        con, topo = self._compute_coh_connectivity(
            con_methods, con, topo, indices
        )
        method_i = 0
        for method_con, method_topo in zip(con, topo):
            # every requested method must have produced a result by now
            assert method_con is not None, (
                'A connectivity result has been missed. Please contact the '
                'mne-connectivity developers.'
            )
            self._check_correct_results_dimensions(
                con_methods, method_con, method_topo, indices
            )
            if self.faverage:
                # average the results within each frequency band
                con[method_i], topo[method_i] = self._compute_faverage(
                    con=method_con, topo=method_topo
                )
            method_i += 1
        self.freqs_used = self.freqs
        if self.faverage:
            # for each band we return the frequencies that were averaged
            self.freqs = [np.mean(band) for band in self.freqs_bands]
            # return max and min frequencies that were averaged for each band
            self.freqs_used = [
                [np.min(band), np.max(band)] for band in self.freqs_bands
            ]
        # number of nodes in the original data
        self.n_nodes = self.n_signals
        return con, topo
    def _compute_gc_connectivity(self, con_methods, con, indices):
        """Computes GC connectivity.

        Different GC methods can rely on common information, so rather than
        re-computing this information everytime a different GC method is called,
        store this information such that it can be accessed to compute the final
        GC connectivity scores when needed.

        Returns `con` with the GC methods' entries filled in; non-GC entries
        are left untouched.
        """
        self._get_gc_forms_to_compute(con_methods)
        if self.compute_gc_forms:
            # the autocovariance is shared by all forms, so compute it once
            self._compute_and_set_gc_autocov()
            gc_scores = {}
            for form_name, form_info in self.compute_gc_forms.items():
                # computes connectivity for individual GC forms
                form_info['method_class'].compute_con(
                    indices[0], indices[1], form_info['flip_seeds_targets'],
                    form_info['reverse_time'], form_name
                )
                # assigns connectivity score to their appropriate GC forms for
                # combining into the final GC method results
                gc_scores[form_name] = form_info['method_class'].con_scores
            con = self._combine_gc_forms(con_methods, con, gc_scores)
            # remove the results for frequencies not requested by the end user
            if self.discontinuous_freqs:
                con = self._make_gc_freqs_discontinuous(con)
            # set n_signals to equal the number in the non-SVD data
            if self.perform_svd:
                self.n_signals = len(self._get_unique_signals(self.indices))
        return con
def _get_gc_forms_to_compute(self, con_methods):
"""Finds the GC forms that need to be computed."""
self.compute_gc_forms = {}
for form_name, form_info in self.possible_gc_forms.items():
for method in con_methods:
if (
method.name in form_info['for_methods'] and
form_name not in self.compute_gc_forms.keys()
):
form_info.update(method_class=copy.deepcopy(method))
self.compute_gc_forms[form_name] = form_info
def _compute_and_set_gc_autocov(self):
"""Computes autocovariance once and assigns it to all GC methods."""
first_form = True
for form_info in self.compute_gc_forms.values():
if first_form:
form_info['method_class'].compute_autocov(self.n_epochs)
autocov = form_info['method_class'].autocov.copy()
first_form = False
else:
form_info['method_class'].autocov = autocov
def _combine_gc_forms(self, con_methods, con, gc_scores):
"""Combines the information from all the different GC forms so that the
final connectivity scores for the requested GC methods are returned."""
for method_i, method in enumerate(con_methods):
if method.name == 'GC':
con[method_i] = gc_scores['seeds -> targets']
elif method.name == 'Net GC':
con[method_i] = (
gc_scores['seeds -> targets'] -
gc_scores['targets -> seeds']
)
elif method.name == 'TRGC':
con[method_i] = (
gc_scores['seeds -> targets'] -
gc_scores['time-reversed[seeds -> targets]']
)
elif method.name == 'Net TRGC':
con[method_i] = (
(
gc_scores['seeds -> targets'] -
gc_scores['targets -> seeds']
) - (
gc_scores['time-reversed[seeds -> targets]'] -
gc_scores['time-reversed[targets -> seeds]']
)
)
return con
def _make_gc_freqs_discontinuous(self, con):
"""Remove the unrequested frequencies from the GC results so that the
results match the frequency bands requested by the end user."""
# find which freqs in the results are needed
requested_freqs = np.concatenate(self.freq_idx_bands)
freq_mask = [freq in requested_freqs for freq in range(self.n_freqs)]
# exclude the unwanted freqs from the results
for method_i, method_con in enumerate(con):
con[method_i] = method_con[:, freq_mask]
# set the frequency attrs to the correct, discontinuous values
self.n_freqs = len(requested_freqs)
self.freqs = self.freqs[freq_mask]
freq_idx_bands = []
freq_idx = 0
for band in self.freq_idx_bands:
freq_idx_bands.append(
np.arange(freq_idx, freq_idx + len(band), dtype=band.dtype)
)
freq_idx += len(band)
self.freq_idx_bands = freq_idx_bands
return con
def _compute_coh_connectivity(self, con_methods, con, topo, indices):
"""Computes MIC and MIM connectivity.
MIC and MIM rely on common information, so rather than re-computing this
information everytime a different coherence method is called, store this
information such that it can be accessed to compute the final MIC and
MIM connectivity scores when needed.
"""
self._get_coh_form_to_compute(con_methods)
if self.compute_coh_form:
# compute connectivity for MIC and/or MIM in a single instance
form_name = list(self.compute_coh_form.keys())[0] # only one there
form_info = self.compute_coh_form[form_name]
form_info['method_class'].compute_con(
indices[0], indices[1], self.n_components, self.n_epochs,
form_name
)
# store the MIC and/or MIM results in the right places
for method_i, method in enumerate(con_methods):
if method.name == "MIC":
con[method_i] = form_info['method_class'].mic_scores
topo[method_i] = form_info['method_class'].topographies
elif method.name == "MIM":
con[method_i] = form_info['method_class'].mim_scores
return con, topo
def _get_coh_form_to_compute(self, con_methods):
"""Finds the coherence form that need to be computed."""
method_names = [method.name for method in con_methods]
self.compute_coh_form = {}
for form_name, form_info in self.possible_coh_forms.items():
if (
all(name in method_names for name in form_info['for_methods'])
and not any(
name in method_names for name in
form_info['exclude_methods']
)
):
coh_class = con_methods[
method_names.index(form_info['for_methods'][0])
]
form_info.update(method_class=coh_class)
self.compute_coh_form[form_name] = form_info
break # only one form is possible at any one instance
def _check_correct_results_dimensions(self, con_methods, con, topo, indices):
"""Checks that the results of the connectivity computations have the
appropriate dimensions."""
n_cons = len(indices[0])
n_times = con_methods[0].n_times
assert (con.shape[0] == n_cons), (
'The first dimension of connectivity scores does not match the '
'number of connections. Please contact the mne-connectivity '
'developers.'
)
assert (con.shape[1] == self.n_freqs), (
'The second dimension of connectivity scores does not match the '
'number of frequencies. Please contact the mne-connectivity '
'developers.'
)
if n_times != 0:
assert (con.shape[2] == n_times), (
'The third dimension of connectivity scores does not match '
'the number of timepoints. Please contact the mne-connectivity '
'developers.'
)
if topo is not None:
assert (topo[0].shape[0] == n_cons and topo[1].shape[0]), (
'The first dimension of topographies does not match the number '
'of connections. Please contact the mne-connectivity '
'developers.'
)
for con_i in range(n_cons):
assert (
topo[0][con_i].shape[1] == self.n_freqs and
topo[1][con_i].shape[1] == self.n_freqs
), (
'The second dimension of topographies does not match the '
'number of frequencies. Please contact the '
'mne-connectivity developers.'
)
if n_times != 0:
assert (
topo[0][con_i].shape[2] == n_times and
topo[1][con_i].shape[2] == n_times
), (
'The third dimension of topographies does not match '
'the number of timepoints. Please contact the '
'mne-connectivity developers.'
)
def _compute_faverage(self, con, topo):
"""Computes the average connectivity across the frequency bands."""
n_cons = con.shape[0]
con_shape = (n_cons, self.n_bands) + con.shape[2:]
con_bands = np.empty(con_shape, dtype=con.dtype)
for band_idx in range(self.n_bands):
con_bands[:, band_idx] = np.mean(
con[:, self.freq_idx_bands[band_idx]], axis=1
)
if topo is not None:
topo_bands = np.empty((2, n_cons), dtype=topo.dtype)
for group_i in range(2):
for con_i in range(n_cons):
band_topo = [
np.mean(topo[group_i][con_i][:, freq_idx_band], axis=1)
for freq_idx_band in self.freq_idx_bands
]
topo_bands[group_i][con_i] = np.array(band_topo).T
else:
topo_bands = None
return con_bands, topo_bands
def _collate_connectivity_results(self):
"""Collects the connectivity results for non-GC with SVD analysis and GC
with SVD together according to the order in which the respective methods
were called."""
self.con = [*self.remaining_con, *self.separate_gc_con]
self.topo = [*self.remaining_topo, *self.separate_gc_topo]
if self.remaining_con and self.separate_gc_con: