Skip to content

Commit

Permalink
[Fix] NeuralEnsemble#613 spike_train_generation module to handle mult…
Browse files Browse the repository at this point in the history
…ichannel AnalogSignal inputs (NeuralEnsemble#614)

* fix docstring add type annotations
* fix input checks peak detection
* add tests for peak_extraction
* add handling of multichannel analogsignals to peak detection
  • Loading branch information
Moritz-Alexander-Kern authored Jul 17, 2024
1 parent bdd98ee commit 32e1199
Show file tree
Hide file tree
Showing 3 changed files with 355 additions and 104 deletions.
2 changes: 1 addition & 1 deletion doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@
intersphinx_mapping = {
'viziphant': ('https://viziphant.readthedocs.io/en/stable/', None),
'numpy': ('https://numpy.org/doc/stable', None),
'neo': ('https://neo.readthedocs.io/en/stable/', None),
'neo': ('https://neo.readthedocs.io/en/latest/', None),
'quantities': ('https://python-quantities.readthedocs.io/en/stable/', None),
'python': ('https://docs.python.org/3/', None),
'scipy': ('https://docs.scipy.org/doc/scipy/', None)
Expand Down
286 changes: 211 additions & 75 deletions elephant/spike_train_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,10 @@
from __future__ import division, print_function, unicode_literals

import warnings
from typing import List, Union, Optional
from typing import List, Literal, Union, Optional

import neo
from neo.core.spiketrainlist import SpikeTrainList
import numpy as np
import quantities as pq
from scipy import stats
Expand Down Expand Up @@ -83,53 +84,21 @@
]


def spike_extraction(signal, threshold=0.0 * pq.mV, sign='above',
time_stamps=None, interval=(-2 * pq.ms, 4 * pq.ms)):
"""
Return the peak times for all events that cross threshold and the
waveforms. Usually used for extracting spikes from a membrane
potential to calculate waveform properties.
Parameters
----------
signal : neo.AnalogSignal
An analog input signal.
threshold : pq.Quantity, optional
Contains a value that must be reached for an event to be detected.
Default: 0.0 * pq.mV
sign : {'above', 'below'}, optional
Determines whether to count threshold crossings that cross above or
below the threshold.
Default: 'above'
time_stamps : pq.Quantity, optional
If `spike_train` is a `pq.Quantity` array, `time_stamps` provides the
time stamps around which the waveform is extracted. If it is None, the
function `peak_detection` is used to calculate the time_stamps
from signal.
Default: None
interval : tuple of pq.Quantity
Specifies the time interval around the `time_stamps` where the waveform
is extracted.
Default: (-2 * pq.ms, 4 * pq.ms)
Returns
-------
result_st : neo.SpikeTrain
Contains the time_stamps of each of the spikes and the waveforms in
`result_st.waveforms`.
See Also
--------
elephant.spike_train_generation.peak_detection
"""
def _spike_extraction_from_single_channel(
signal: neo.core.AnalogSignal,
threshold: pq.Quantity = 0.0 * pq.mV,
sign: Literal['above', 'below'] = 'above',
time_stamps: neo.core.SpikeTrain = None,
interval: tuple = (-2 * pq.ms, 4 * pq.ms)
) -> neo.core.SpikeTrain:
# Get spike time_stamps
if time_stamps is None:
time_stamps = peak_detection(signal, threshold, sign=sign)
elif hasattr(time_stamps, 'times'):
time_stamps = time_stamps.times
elif isinstance(time_stamps, pq.Quantity):
raise TypeError("time_stamps must be None, a pq.Quantity array or" +
" expose the.times interface")
else:
raise TypeError("time_stamps must be None, a `neo.core.SpikeTrain`"
" or expose the.times interface")

if len(time_stamps) == 0:
return neo.SpikeTrain(time_stamps, units=signal.times.units,
Expand All @@ -139,6 +108,7 @@ def spike_extraction(signal, threshold=0.0 * pq.mV, sign='above',

# Unpack the extraction interval from tuple or array
extr_left, extr_right = interval

if extr_left > extr_right:
raise ValueError("interval[0] must be < interval[1]")

Expand Down Expand Up @@ -185,36 +155,90 @@ def spike_extraction(signal, threshold=0.0 * pq.mV, sign='above',
left_sweep=extr_left)


def threshold_detection(signal, threshold=0.0 * pq.mV, sign='above'):
def spike_extraction(
signal: neo.core.AnalogSignal,
threshold: pq.Quantity = 0.0 * pq.mV,
sign: Literal['above', 'below'] = 'above',
time_stamps: neo.core.SpikeTrain = None,
interval: tuple = (-2 * pq.ms, 4 * pq.ms),
always_as_list: bool = False
) -> Union[neo.core.SpikeTrain, SpikeTrainList]:
"""
Returns the times when the analog signal crosses a threshold.
Usually used for extracting spike times from a membrane potential.
Return the peak times for all events that cross threshold and the
waveforms. Usually used for extracting spikes from a membrane
potential to calculate waveform properties.
Parameters
----------
signal : neo.AnalogSignal
An analog input signal.
signal : :class:`neo.core.AnalogSignal`
An analog input signal one or more channels.
threshold : pq.Quantity, optional
Contains a value that must be reached for an event to be detected.
Default: 0.0 * pq.mV
sign : {'above', 'below'}, optional
Determines whether to count threshold crossings that cross above or
below the threshold.
Default: 'above'
time_stamps : :class:`neo.core.SpikeTrain` , optional
Provides the time stamps around which the waveform is extracted. If it
is None, the function `peak_detection` is used to calculate the
`time_stamps` from signal.
Default: None
interval : tuple of :class:`pq.Quantity`
Specifies the time interval around the `time_stamps` where the waveform
is extracted.
Default: (-2 * pq.ms, 4 * pq.ms)
always_as_list: bool, optional
If True, :class:`neo.core.spiketrainslist.SpikeTrainList` is returned.
Default: False
Returns
-------
result_st : neo.SpikeTrain
Contains the spike times of each of the events (spikes) extracted from
the signal.
"""
------- # noqa
result_st : :class:`neo.core.SpikeTrain`, :class:`neo.core.spiketrainslist.SpikeTrainList`.
Contains the time_stamps of each of the spikes and the waveforms in
`result_st.waveforms`.
if not isinstance(threshold, pq.Quantity):
raise ValueError('threshold must be a pq.Quantity')
See Also
--------
:func:`elephant.spike_train_generation.peak_detection`
"""
if isinstance(signal, neo.core.AnalogSignal):
if signal.shape[1] == 1:
if always_as_list:
return SpikeTrainList(items=(
_spike_extraction_from_single_channel(
signal,
threshold=threshold,
time_stamps=time_stamps,
interval=interval,
sign=sign),))
else:
return _spike_extraction_from_single_channel(
signal, threshold=threshold, time_stamps=time_stamps,
interval=interval, sign=sign)
elif signal.shape[1] > 1:
spiketrainlist = SpikeTrainList()
for channel in range(signal.shape[1]):
spiketrainlist.append(
_spike_extraction_from_single_channel(
neo.core.AnalogSignal(
signal[:, channel],
sampling_rate=signal.sampling_rate),
threshold=threshold, sign=sign,
time_stamps=time_stamps,
interval=interval,
))
return spiketrainlist
else:
raise TypeError(
f"Signal must be AnalogSignal, provided: {type(signal)}")

if sign not in ('above', 'below'):
raise ValueError("sign should be 'above' or 'below'")

def _threshold_detection_from_single_channel(
signal: neo.core.AnalogSignal,
threshold: pq.Quantity = 0.0 * pq.mV,
sign: str = 'above'
) -> neo.core.SpikeTrain:
if sign == 'above':
cutout = np.where(signal > threshold)[0]
else:
Expand Down Expand Up @@ -242,53 +266,88 @@ def threshold_detection(signal, threshold=0.0 * pq.mV, sign='above'):
return result_st


def peak_detection(signal, threshold=0.0 * pq.mV, sign='above',
as_array=False):
def threshold_detection(
signal: neo.core.AnalogSignal,
threshold: pq.Quantity = 0.0 * pq.mV,
sign: Literal['above', 'below'] = 'above',
always_as_list: bool = False,
) -> Union[neo.core.SpikeTrain, SpikeTrainList]:
"""
Return the peak times for all events that cross threshold.
Returns the times when the analog signal crosses a threshold.
Usually used for extracting spike times from a membrane potential.
Similar to spike_train_generation.threshold_detection.
Parameters
----------
signal : neo.AnalogSignal
An analog input signal.
threshold : pq.Quantity, optional
signal : :class:`neo.core.AnalogSignal`
An analog input signal with one or multiple channels.
threshold : :class:`pq.Quantity`, optional
Contains a value that must be reached for an event to be detected.
Default: 0.*pq.mV
Default: 0.0 * pq.mV
sign : {'above', 'below'}, optional
Determines whether to count threshold crossings that cross above or
below the threshold.
Default: 'above'
as_array : bool, optional
If True, a NumPy array of the resulting peak times is returned instead
of a (default) `neo.SpikeTrain` object.
always_as_list: bool, optional
If True, a :class:`neo.core.spiketrainslist.SpikeTrainList`.
Default: False
Returns
-------
result_st : neo.SpikeTrain
------- # noqa
result_st : :class:`neo.core.SpikeTrain`, :class:`neo.core.spiketrainslist.SpikeTrainList`
Contains the spike times of each of the events (spikes) extracted from
the signal.
the signal. If `signal` is an AnalogSignal with multiple channels, or
`always_return_list=True` , a
:class:`neo.core.spiketrainlist.SpikeTrainList` is returned.
"""
if not isinstance(threshold, pq.Quantity):
raise ValueError("threshold must be a pq.Quantity")
raise TypeError('threshold must be a pq.Quantity')

if sign not in ('above', 'below'):
raise ValueError("sign should be 'above' or 'below'")

if isinstance(signal, neo.core.AnalogSignal):
if signal.shape[1] == 1:
if always_as_list:
return SpikeTrainList(items=(
_threshold_detection_from_single_channel(
signal, threshold=threshold, sign=sign),))
else:
return _threshold_detection_from_single_channel(
signal, threshold=threshold, sign=sign)
elif signal.shape[1] > 1:
spiketrainlist = SpikeTrainList()
for channel in range(signal.shape[1]):
spiketrainlist.append(_threshold_detection_from_single_channel(
neo.core.AnalogSignal(signal[:, channel],
sampling_rate=signal.sampling_rate),
threshold=threshold, sign=sign)
)
return spiketrainlist
else:
raise TypeError(
f"Signal must be AnalogSignal, provided: {type(signal)}")


# legacy implementation of peak_detection
def _peak_detection_from_single_channel(
signal: neo.core.AnalogSignal,
threshold: pq.Quantity = 0.0 * pq.mV,
sign: str = 'above',
as_array: bool = False
) -> neo.core.SpikeTrain:
if sign == 'above':
cutout = np.where(signal > threshold)[0]
peak_func = np.argmax
else:
# sign == 'below'
elif sign == 'below':
cutout = np.where(signal < threshold)[0]
peak_func = np.argmin
else:
raise ValueError("sign should be 'above' or 'below'")

if len(cutout) == 0:
events_base = np.zeros(0)
else:
# Select thr crossings lasting at least 2 dtps, np.diff(cutout) > 2
# Select the crossings lasting at least 2 dtps, np.diff(cutout) > 2
# This avoids empty slices
border_start = np.where(np.diff(cutout) > 1)[0]
border_end = border_start + 1
Expand Down Expand Up @@ -327,6 +386,83 @@ def peak_detection(signal, threshold=0.0 * pq.mV, sign='above',
return result_st


def peak_detection(signal: neo.core.AnalogSignal,
threshold: pq.Quantity = 0.0 * pq.mV,
sign: Literal['above', 'below'] = 'above',
as_array: bool = False,
always_as_list: bool = False
) -> Union[neo.core.SpikeTrain, SpikeTrainList]:
"""
Return the peak times for all events that cross threshold.
Usually used for extracting spike times from a membrane potential.
Similar to spike_train_generation.threshold_detection.
Parameters
----------
signal : :class:`neo.core.AnalogSignal`
An analog input signal or a list of analog input signals.
threshold : :class:`pq.Quantity`, optional
Contains a value that must be reached for an event to be detected.
Default: 0.*pq.mV
sign : {'above', 'below'}, optional
Determines whether to count threshold crossings that cross above or
below the threshold.
Default: 'above'
as_array : bool, optional
If True, a NumPy array of the resulting peak times is returned instead
of a (default) `neo.SpikeTrain` object.
Default: False
always_as_list: bool, optional
If True, a :class:`neo.core.spiketrainslist.SpikeTrainList` is returned.
Default: False
Returns
------- # noqa
result_st : :class:`neo.core.SpikeTrain`, :class:`neo.core.spiketrainslist.SpikeTrainList`
:class:`np.ndarrav`, List[:class:`np.ndarrav`]
Contains the spike times of each of the events (spikes) extracted from
the signal.
If `signal` is an AnalogSignal with multiple channels or
`always_return_list=True` a list is returned.
"""
if not isinstance(threshold, pq.Quantity):
raise TypeError(
f"threshold must be a pq.Quantity, provided: {type(threshold)}")

if isinstance(signal, neo.core.AnalogSignal):
if signal.shape[1] == 1:
if always_as_list and not as_array:
return SpikeTrainList(items=(
_peak_detection_from_single_channel(
signal, threshold=threshold, sign=sign,
as_array=as_array),))
elif always_as_list and as_array:
return [_peak_detection_from_single_channel(
signal, threshold=threshold, sign=sign, as_array=as_array)]
else:
return _peak_detection_from_single_channel(
signal, threshold=threshold, sign=sign, as_array=as_array)
elif signal.shape[1] > 1 and as_array:
return [_peak_detection_from_single_channel(neo.core.AnalogSignal(
signal[:, channel], sampling_rate=signal.sampling_rate),
threshold=threshold,
sign=sign, as_array=as_array
) for channel in range(signal.shape[1])]
elif signal.shape[1] > 1 and not as_array:
spiketrainlist = SpikeTrainList()
for channel in range(signal.shape[1]):
spiketrainlist.append(_peak_detection_from_single_channel(
neo.core.AnalogSignal(signal[:, channel],
sampling_rate=signal.sampling_rate),
threshold=threshold,
sign=sign, as_array=as_array
))
return spiketrainlist
else:
raise TypeError(
f"Signal must be AnalogSignal, provided: {type(signal)}")


class AbstractPointProcess:
"""
Abstract point process to subclass from.
Expand Down
Loading

0 comments on commit 32e1199

Please sign in to comment.