Skip to content

Commit

Permalink
Merge pull request #4 from agramfort/pr/177
Browse files Browse the repository at this point in the history
review connectivity
  • Loading branch information
Martin Luessi committed Nov 29, 2012
2 parents 0509f65 + ded3ca6 commit 29ea838
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 71 deletions.
4 changes: 2 additions & 2 deletions examples/connectivity/plot_cwt_sensor_connectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,12 @@

# Define wavelet frequencies and number of cycles
cwt_frequencies = np.arange(7, 30, 2)
cwt_n_cycles = cwt_frequencies / float(7)
cwt_n_cycles = cwt_frequencies / 7.

# Run the connectivity analysis using 2 parallel jobs
sfreq = raw.info['sfreq'] # the sampling frequency
con, freqs, times, _, _ = spectral_connectivity(epochs, indices=indices,
method='wpli2_debiased', spectral_mode='cwt_morlet', sfreq=sfreq,
method='wpli2_debiased', mode='cwt_morlet', sfreq=sfreq,
cwt_frequencies=cwt_frequencies, cwt_n_cycles=cwt_n_cycles, n_jobs=2)

# Mark the seed channel with a value of 1.0, so we can see it in the plot
Expand Down
14 changes: 9 additions & 5 deletions examples/connectivity/plot_mne_inverse_coherence_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,13 @@
sfreq = raw.info['sfreq'] # the sampling frequency

# Now we compute connectivity. To speed things up, we use 2 parallel jobs
# and use spectral_mode='fourier', which uses a FFT with a Hanning window
# and use mode='fourier', which uses a FFT with a Hanning window
# to compute the spectra (instead of multitaper estimation, which has a
# lower variance but is slower). By using faverage=True, we directly
# average the coherence in the alpha and beta band, i.e., we will only
# get 2 frequency bins
coh, freqs, times, n_epochs, n_tapers = spectral_connectivity(stcs,
method='coh', spectral_mode='fourier', indices=indices,
method='coh', mode='fourier', indices=indices,
sfreq=sfreq, fmin=fmin, fmax=fmax, faverage=True, mt_adaptive=False,
n_jobs=2)

Expand All @@ -120,14 +120,18 @@
# save the cohrence to plot later
aud_rh_coh[band] = np.mean(coh_stc.label_stc(label_rh).data, axis=0)

# We could save the coherence, for visualization using e.g. mne_analyze
#coh_stc.save('seed_coh_%s_vertno_%d' % (band, seed_vertno))
# Save the coherence for visualization using e.g. mne_analyze
coh_stc.save('seed_coh_%s_vertno_%d' % (band, seed_vertno))

# XXX : I would save only one stc containing all the bands so it's easy
# to visualize in mne_analyze. Otherwise you have to switch between stcs
# to see how it differs between bands.

pl.figure()
width = 0.5
pos = np.arange(2) + 0.25
pl.bar(pos, [aud_rh_coh['alpha'], aud_rh_coh['beta']], width)
pl.ylabel('Coherence')
pl.title('Cohrence left-right auditory')
pl.title('Coherence left-right auditory')
pl.xticks(pos + width / 2, ('alpha', 'beta'))
pl.show()
2 changes: 1 addition & 1 deletion examples/connectivity/plot_sensor_connectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
sfreq = raw.info['sfreq'] # the sampling frequency
tmin = 0.0 # exclude the baseline period
con, freqs, times, n_epochs, n_tapers = spectral_connectivity(epochs,
method='pli', spectral_mode='multitaper', sfreq=sfreq,
method='pli', mode='multitaper', sfreq=sfreq,
fmin=fmin, fmax=fmax, faverage=True, tmin=tmin,
mt_adaptive=False, n_jobs=2)

Expand Down
101 changes: 50 additions & 51 deletions mne/connectivity/spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,7 @@ def compute_con(self, con_idx, n_epochs, psd_xx, psd_yy):
if self.con_scores is None:
self.con_scores = np.zeros(self.csd_shape)
csd_mean = self._acc[con_idx] / n_epochs
self.con_scores[con_idx] = np.abs(csd_mean)\
/ np.sqrt(psd_xx * psd_yy)
self.con_scores[con_idx] = np.abs(csd_mean) / np.sqrt(psd_xx * psd_yy)


class _CohyEst(_CohEstBase):
Expand All @@ -101,8 +100,7 @@ def compute_con(self, con_idx, n_epochs, psd_xx, psd_yy):
self.con_scores = np.zeros(self.csd_shape,
dtype=np.complex128)
csd_mean = self._acc[con_idx] / n_epochs
self.con_scores[con_idx] = csd_mean\
/ np.sqrt(psd_xx * psd_yy)
self.con_scores[con_idx] = csd_mean / np.sqrt(psd_xx * psd_yy)


class _ImCohEst(_CohEstBase):
Expand All @@ -114,8 +112,7 @@ def compute_con(self, con_idx, n_epochs, psd_xx, psd_yy):
if self.con_scores is None:
self.con_scores = np.zeros(self.csd_shape)
csd_mean = self._acc[con_idx] / n_epochs
self.con_scores[con_idx] = np.imag(csd_mean)\
/ np.sqrt(psd_xx * psd_yy)
self.con_scores[con_idx] = np.imag(csd_mean) / np.sqrt(psd_xx * psd_yy)


class _PLVEst(_EpochMeanConEstBase):
Expand Down Expand Up @@ -192,17 +189,17 @@ def __init__(self, n_cons, n_freqs, n_times):
def accumulate(self, con_idx, csd_xy):
"""Accumulate some connections"""
im_csd = np.imag(csd_xy)
self._acc[con_idx, Ellipsis, 0] += im_csd
self._acc[con_idx, Ellipsis, 1] += np.abs(im_csd)
self._acc[con_idx, ..., 0] += im_csd
self._acc[con_idx, ..., 1] += np.abs(im_csd)

def compute_con(self, con_idx, n_epochs):
"""Compute final con. score for some connections"""
if self.con_scores is None:
self.con_scores = np.zeros(self.csd_shape)

acc_mean = self._acc[con_idx] / n_epochs
num = np.abs(acc_mean[:, Ellipsis, 0])
denom = acc_mean[:, Ellipsis, 1]
num = np.abs(acc_mean[:, ..., 0])
denom = acc_mean[:, ..., 1]

# handle zeros in denominator
z_denom = np.where(denom == 0.)
Expand All @@ -229,9 +226,9 @@ def __init__(self, n_cons, n_freqs, n_times):
def accumulate(self, con_idx, csd_xy):
"""Accumulate some connections"""
im_csd = np.imag(csd_xy)
self._acc[con_idx, Ellipsis, 0] += im_csd
self._acc[con_idx, Ellipsis, 1] += np.abs(im_csd)
self._acc[con_idx, Ellipsis, 2] += im_csd ** 2
self._acc[con_idx, ..., 0] += im_csd
self._acc[con_idx, ..., 1] += np.abs(im_csd)
self._acc[con_idx, ..., 2] += im_csd ** 2

def compute_con(self, con_idx, n_epochs):
"""Compute final con. score for some connections"""
Expand All @@ -240,18 +237,17 @@ def compute_con(self, con_idx, n_epochs):

# note: we use the trick from fieldtrip to compute the
# the estimate over all pairwise epoch combinations
sum_im_csd = self._acc[con_idx, Ellipsis, 0]
sum_abs_im_csd = self._acc[con_idx, Ellipsis, 1]
sum_sq_im_csd = self._acc[con_idx, Ellipsis, 2]
sum_im_csd = self._acc[con_idx, ..., 0]
sum_abs_im_csd = self._acc[con_idx, ..., 1]
sum_sq_im_csd = self._acc[con_idx, ..., 2]

denom = (sum_abs_im_csd ** 2 - sum_sq_im_csd)
denom = sum_abs_im_csd ** 2 - sum_sq_im_csd

# handle zeros in denominator
z_denom = np.where(denom == 0.)
denom[z_denom] = 1.

con = (sum_im_csd ** 2 - sum_sq_im_csd)\
/ denom
con = (sum_im_csd ** 2 - sum_sq_im_csd) / denom

# where we had zeros in denominator, we set con to zero
con[z_denom] = 0.
Expand Down Expand Up @@ -288,20 +284,22 @@ def compute_con(self, con_idx, n_epochs):

# note: we use the trick from fieldtrip to compute the
# the estimate over all pairwise epoch combinations
con = (self._acc[con_idx] * np.conj(self._acc[con_idx]) - n_epochs)\
/ (n_epochs * (n_epochs - 1))
con = ((self._acc[con_idx] * np.conj(self._acc[con_idx]) - n_epochs)
/ (n_epochs * (n_epochs - 1.)))

self.con_scores[con_idx] = np.real(con)


########################################################################
###############################################################################


def _epoch_spectral_connectivity(data, sfreq, spectral_mode, window_fun,
def _epoch_spectral_connectivity(data, sfreq, mode, window_fun,
eigvals, wavelets, freq_mask, mt_adaptive,
idx_map, block_size, psd, accumulate_psd,
con_method_types, con_methods,
accumulate_inplace=True):
"""Connectivity estimation for one epoch see spectral_connectivity"""

n_cons = len(idx_map[0])

if wavelets is not None:
Expand All @@ -311,14 +309,13 @@ def _epoch_spectral_connectivity(data, sfreq, spectral_mode, window_fun,
n_times_spectrum = 0
n_freqs = np.sum(freq_mask)

"""Connectivity estimation for one epoch see spectral_connectivity"""
if not accumulate_inplace:
# instantiate methods only for this epoch (used in parallel mode)
con_methods = [mtype(n_cons, n_freqs, n_times_spectrum)
for mtype in con_method_types]

# compute tapered spectra
if spectral_mode in ['multitaper', 'fourier']:
if mode in ['multitaper', 'fourier']:
x_mt, _ = _mt_spectra(data, window_fun, sfreq)

if mt_adaptive:
Expand All @@ -331,22 +328,22 @@ def _epoch_spectral_connectivity(data, sfreq, spectral_mode, window_fun,
else:
# do not use adaptive weights
x_mt = x_mt[:, :, freq_mask]
if spectral_mode == 'multitaper':
if mode == 'multitaper':
weights = np.sqrt(eigvals)[np.newaxis, :, np.newaxis]
else:
# hack to so we can sum over axis=-2
weights = np.array([1.])[:, None, None]

if accumulate_psd:
this_psd = _psd_from_mt(x_mt, weights)
elif spectral_mode == 'cwt_morlet':
elif mode == 'cwt_morlet':
# estimate spectra using CWT
x_cwt = cwt(data, wavelets, use_fft=True, mode='same')

if accumulate_psd:
this_psd = np.abs(x_cwt) ** 2
else:
raise RuntimeError('invalid spectral_mode')
raise RuntimeError('invalid mode')

# accumulate or return psd
if accumulate_psd:
Expand All @@ -362,7 +359,7 @@ def _epoch_spectral_connectivity(data, sfreq, spectral_mode, window_fun,
method.start_epoch()

# accumulate connectivity scores
if spectral_mode in ['multitaper', 'fourier']:
if mode in ['multitaper', 'fourier']:
for i in xrange(0, n_cons, block_size):
con_idx = slice(i, i + block_size)
if mt_adaptive:
Expand Down Expand Up @@ -423,7 +420,7 @@ def _check_method(method):

@verbose
def spectral_connectivity(data, method='coh', indices=None, sfreq=2 * np.pi,
spectral_mode='multitaper', fmin=None, fmax=np.inf,
mode='multitaper', fmin=None, fmax=np.inf,
fskip=0, faverage=False, tmin=None, tmax=None,
mt_bandwidth=None, mt_adaptive=False,
mt_low_bias=True, cwt_frequencies=None,
Expand All @@ -436,11 +433,11 @@ def spectral_connectivity(data, method='coh', indices=None, sfreq=2 * np.pi,
All methods are based on estimates of the cross- and power spectral
densities (CSD/PSD) Sxy and Sxx, Syy.
The spectral densities can be estimated using a multi taper method with
The spectral densities can be estimated using a multitaper method with
digital prolate spheroidal sequence (DPSS) windows, a discrete Fourier
transform with Hanning windows, or a continuous wavelet transform using
Morlet wavelets. The spectral estimation mode is specified using the
"spectral_mode" parameter.
"mode" parameter.
By default, the connectivity between all signals is computed (only
connections corresponding to the lower-triangular part of the
Expand Down Expand Up @@ -532,13 +529,14 @@ def spectral_connectivity(data, method='coh', indices=None, sfreq=2 * np.pi,
connectivity. If None, all connections are computed.
sfreq : float
The sampling frequency.
spectral_mode : str
mode : str
Spectrum estimation mode can be either: 'multitaper', 'fourier', or
'cwt_morlet'.
fmin : float | tuple of floats
The lower frequency of interest. Multiple bands are defined using
a tuple, e.g., (8., 20.) for two bands with 8Hz and 20Hz lower freq.
By default, the frequency corresponing to 5 cycles is used.
If None the frequency corresponding to an epoch length of 5 cycles
is used.
fmax : float | tuple of floats
The upper frequency of interest. Multiple bands are dedined using
a tuple, e.g. (13., 30.) for two band with 13Hz and 30Hz upper freq.
Expand All @@ -554,12 +552,14 @@ def spectral_connectivity(data, method='coh', indices=None, sfreq=2 * np.pi,
tmax : float | None
Time to end connectivity estimation.
mt_bandwidth : float
The bandwidth of the multi taper windowing function in Hz.
The bandwidth of the multitaper windowing function in Hz.
Only used in 'multitaper' mode.
mt_adaptive : bool
Use adaptive weights to combine the tapered spectra into PSD.
Only used in 'multitaper' mode.
mt_low_bias : bool
Only use tapers with more than 90% spectral concentration within
bandwidth.
bandwidth. Only used in 'multitaper' mode.
cwt_frequencies : array
Array of frequencies of interest. Only used in 'cwt_morlet' mode.
cwt_n_cycles: float | array of float
Expand Down Expand Up @@ -590,9 +590,9 @@ def spectral_connectivity(data, method='coh', indices=None, sfreq=2 * np.pi,
n_epochs : int
Number of epochs used for computation.
n_tapers : int
The number of DPSS tapers used.
The number of DPSS tapers used. Only defined in 'multitaper' mode.
Otherwise is returned None.
"""

if n_jobs > 1:
parallel, my_epoch_spectral_connectivity, _ = \
parallel_func(_epoch_spectral_connectivity, n_jobs,
Expand Down Expand Up @@ -702,12 +702,12 @@ def spectral_connectivity(data, method='coh', indices=None, sfreq=2 * np.pi,
% (tmin_true, tmax_true, n_times))

# get frequencies of interest for the different modes
if spectral_mode in ['multitaper', 'fourier']:
if mode in ['multitaper', 'fourier']:
# fmin fmax etc is only supported for these modes
# decide which frequencies to keep
freqs_all = fftfreq(n_times, 1. / sfreq)
freqs_all = freqs_all[freqs_all >= 0]
elif spectral_mode == 'cwt_morlet':
elif mode == 'cwt_morlet':
# cwt_morlet mode
if cwt_frequencies is None:
raise ValueError('define frequencies of interest using '
Expand All @@ -717,10 +717,10 @@ def spectral_connectivity(data, method='coh', indices=None, sfreq=2 * np.pi,
'larger than Nyquist (sfreq / 2)')
freqs_all = cwt_frequencies
else:
raise ValueError('spectral_mode has an invalid value')
raise ValueError('mode has an invalid value')

# check that fmin corresponds to at least 5 cycles
five_cycle_freq = 5 * sfreq / float(n_times)
five_cycle_freq = 5. * sfreq / float(n_times)

if any(np.isnan(fmin)):
if len(fmin) > 1:
Expand All @@ -735,8 +735,7 @@ def spectral_connectivity(data, method='coh', indices=None, sfreq=2 * np.pi,
# create a frequency mask for all bands
freq_mask = np.zeros(len(freqs_all), dtype=np.bool)
for f_lower, f_upper in zip(fmin, fmax):
freq_mask |= ((freqs_all >= f_lower)
& (freqs_all <= f_upper))
freq_mask |= ((freqs_all >= f_lower) & (freqs_all <= f_upper))

# possibly skip frequency points
for pos in xrange(fskip):
Expand Down Expand Up @@ -767,7 +766,7 @@ def spectral_connectivity(data, method='coh', indices=None, sfreq=2 * np.pi,
'each band')

# get the window function, wavelets, etc for different modes
if spectral_mode == 'multitaper':
if mode == 'multitaper':
# compute standardized half-bandwidth
if mt_bandwidth is not None:
half_nbw = float(mt_bandwidth) * n_times / (2 * sfreq)
Expand All @@ -790,7 +789,7 @@ def spectral_connectivity(data, method='coh', indices=None, sfreq=2 * np.pi,

n_times_spectrum = 0 # this method only uses the freq. domain
wavelets = None
elif spectral_mode == 'fourier':
elif mode == 'fourier':
logger.info(' using FFT with a Hanning window to estimate '
'spectra')

Expand All @@ -800,7 +799,7 @@ def spectral_connectivity(data, method='coh', indices=None, sfreq=2 * np.pi,
n_tapers = None
n_times_spectrum = 0 # this method only uses the freq. domain
wavelets = None
elif spectral_mode == 'cwt_morlet':
elif mode == 'cwt_morlet':
logger.info(' using CWT with Morlet wavelets to estimate '
'spectra')

Expand All @@ -815,13 +814,13 @@ def spectral_connectivity(data, method='coh', indices=None, sfreq=2 * np.pi,

# get the Morlet wavelets
wavelets = morlet(sfreq, freqs,
n_cycles=cwt_n_cycles, zero_mean=False)
n_cycles=cwt_n_cycles, zero_mean=True)
eigvals = None
n_tapers = None
window_fun = None
n_times_spectrum = n_times
else:
raise ValueError('spectral_mode has an invalid value')
raise ValueError('mode has an invalid value')

# unique signals for which we actually need to compute PSD etc.
sig_idx = np.unique(np.r_[indices_use[0], indices_use[1]])
Expand Down Expand Up @@ -868,7 +867,7 @@ def spectral_connectivity(data, method='coh', indices=None, sfreq=2 * np.pi,

# con methods and psd are updated inplace
_epoch_spectral_connectivity(this_epoch[sig_idx], sfreq,
spectral_mode, window_fun, eigvals, wavelets, freq_mask,
mode, window_fun, eigvals, wavelets, freq_mask,
mt_adaptive, idx_map, block_size, psd, accumulate_psd,
con_method_types, con_methods, accumulate_inplace=True)
epoch_idx += 1
Expand All @@ -878,7 +877,7 @@ def spectral_connectivity(data, method='coh', indices=None, sfreq=2 * np.pi,
% (epoch_idx + 1, epoch_idx + len(epoch_block)))

out = parallel(my_epoch_spectral_connectivity(this_epoch[sig_idx],
sfreq, spectral_mode, window_fun, eigvals, wavelets,
sfreq, mode, window_fun, eigvals, wavelets,
freq_mask, mt_adaptive, idx_map, block_size, psd,
accumulate_psd, con_method_types, None,
accumulate_inplace=False) for this_epoch in epoch_block)
Expand Down
Loading

0 comments on commit 29ea838

Please sign in to comment.