Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] #613 spike_train_generation module to handle multichannel AnalogSignal inputs #614

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@
intersphinx_mapping = {
'viziphant': ('https://viziphant.readthedocs.io/en/stable/', None),
'numpy': ('https://numpy.org/doc/stable', None),
'neo': ('https://neo.readthedocs.io/en/stable/', None),
'neo': ('https://neo.readthedocs.io/en/latest/', None),
'quantities': ('https://python-quantities.readthedocs.io/en/stable/', None),
'python': ('https://docs.python.org/3/', None),
'scipy': ('https://docs.scipy.org/doc/scipy/', None)
Expand Down
286 changes: 211 additions & 75 deletions elephant/spike_train_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,10 @@
from __future__ import division, print_function, unicode_literals

import warnings
from typing import List, Union, Optional
from typing import List, Literal, Union, Optional

import neo
from neo.core.spiketrainlist import SpikeTrainList
import numpy as np
import quantities as pq
from scipy import stats
Expand Down Expand Up @@ -83,53 +84,21 @@
]


def spike_extraction(signal, threshold=0.0 * pq.mV, sign='above',
time_stamps=None, interval=(-2 * pq.ms, 4 * pq.ms)):
"""
Return the peak times for all events that cross threshold and the
waveforms. Usually used for extracting spikes from a membrane
potential to calculate waveform properties.

Parameters
----------
signal : neo.AnalogSignal
An analog input signal.
threshold : pq.Quantity, optional
Contains a value that must be reached for an event to be detected.
Default: 0.0 * pq.mV
sign : {'above', 'below'}, optional
Determines whether to count threshold crossings that cross above or
below the threshold.
Default: 'above'
time_stamps : pq.Quantity, optional
If `spike_train` is a `pq.Quantity` array, `time_stamps` provides the
time stamps around which the waveform is extracted. If it is None, the
function `peak_detection` is used to calculate the time_stamps
from signal.
Default: None
interval : tuple of pq.Quantity
Specifies the time interval around the `time_stamps` where the waveform
is extracted.
Default: (-2 * pq.ms, 4 * pq.ms)

Returns
-------
result_st : neo.SpikeTrain
Contains the time_stamps of each of the spikes and the waveforms in
`result_st.waveforms`.

See Also
--------
elephant.spike_train_generation.peak_detection
"""
def _spike_extraction_from_single_channel(
signal: neo.core.AnalogSignal,
threshold: pq.Quantity = 0.0 * pq.mV,
sign: Literal['above', 'below'] = 'above',
time_stamps: neo.core.SpikeTrain = None,
interval: tuple = (-2 * pq.ms, 4 * pq.ms)
) -> neo.core.SpikeTrain:
# Get spike time_stamps
if time_stamps is None:
time_stamps = peak_detection(signal, threshold, sign=sign)
elif hasattr(time_stamps, 'times'):
time_stamps = time_stamps.times
elif isinstance(time_stamps, pq.Quantity):
raise TypeError("time_stamps must be None, a pq.Quantity array or" +
" expose the.times interface")
else:
raise TypeError("time_stamps must be None, a `neo.core.SpikeTrain`"
" or expose the.times interface")

if len(time_stamps) == 0:
return neo.SpikeTrain(time_stamps, units=signal.times.units,
Expand All @@ -139,6 +108,7 @@ def spike_extraction(signal, threshold=0.0 * pq.mV, sign='above',

# Unpack the extraction interval from tuple or array
extr_left, extr_right = interval

if extr_left > extr_right:
raise ValueError("interval[0] must be < interval[1]")

Expand Down Expand Up @@ -185,36 +155,90 @@ def spike_extraction(signal, threshold=0.0 * pq.mV, sign='above',
left_sweep=extr_left)


def threshold_detection(signal, threshold=0.0 * pq.mV, sign='above'):
def spike_extraction(
signal: neo.core.AnalogSignal,
threshold: pq.Quantity = 0.0 * pq.mV,
sign: Literal['above', 'below'] = 'above',
time_stamps: neo.core.SpikeTrain = None,
interval: tuple = (-2 * pq.ms, 4 * pq.ms),
always_as_list: bool = False
) -> Union[neo.core.SpikeTrain, SpikeTrainList]:
"""
Returns the times when the analog signal crosses a threshold.
Usually used for extracting spike times from a membrane potential.
Return the peak times for all events that cross threshold and the
waveforms. Usually used for extracting spikes from a membrane
potential to calculate waveform properties.

Parameters
----------
signal : neo.AnalogSignal
An analog input signal.
signal : :class:`neo.core.AnalogSignal`
An analog input signal one or more channels.
threshold : pq.Quantity, optional
Contains a value that must be reached for an event to be detected.
Default: 0.0 * pq.mV
sign : {'above', 'below'}, optional
Determines whether to count threshold crossings that cross above or
below the threshold.
Default: 'above'
time_stamps : :class:`neo.core.SpikeTrain` , optional
Provides the time stamps around which the waveform is extracted. If it
is None, the function `peak_detection` is used to calculate the
`time_stamps` from signal.
Default: None
interval : tuple of :class:`pq.Quantity`
Specifies the time interval around the `time_stamps` where the waveform
is extracted.
Default: (-2 * pq.ms, 4 * pq.ms)
always_as_list: bool, optional
If True, :class:`neo.core.spiketrainslist.SpikeTrainList` is returned.
Default: False

Returns
-------
result_st : neo.SpikeTrain
Contains the spike times of each of the events (spikes) extracted from
the signal.
"""
------- # noqa
result_st : :class:`neo.core.SpikeTrain`, :class:`neo.core.spiketrainslist.SpikeTrainList`.
Contains the time_stamps of each of the spikes and the waveforms in
`result_st.waveforms`.

if not isinstance(threshold, pq.Quantity):
raise ValueError('threshold must be a pq.Quantity')
See Also
--------
:func:`elephant.spike_train_generation.peak_detection`
"""
if isinstance(signal, neo.core.AnalogSignal):
if signal.shape[1] == 1:
if always_as_list:
return SpikeTrainList(items=(
_spike_extraction_from_single_channel(
signal,
threshold=threshold,
time_stamps=time_stamps,
interval=interval,
sign=sign),))
else:
return _spike_extraction_from_single_channel(
signal, threshold=threshold, time_stamps=time_stamps,
interval=interval, sign=sign)
elif signal.shape[1] > 1:
spiketrainlist = SpikeTrainList()
for channel in range(signal.shape[1]):
spiketrainlist.append(
_spike_extraction_from_single_channel(
neo.core.AnalogSignal(
signal[:, channel],
sampling_rate=signal.sampling_rate),
threshold=threshold, sign=sign,
time_stamps=time_stamps,
interval=interval,
))
return spiketrainlist
else:
raise TypeError(
f"Signal must be AnalogSignal, provided: {type(signal)}")

if sign not in ('above', 'below'):
raise ValueError("sign should be 'above' or 'below'")

def _threshold_detection_from_single_channel(
signal: neo.core.AnalogSignal,
threshold: pq.Quantity = 0.0 * pq.mV,
sign: str = 'above'
) -> neo.core.SpikeTrain:
if sign == 'above':
cutout = np.where(signal > threshold)[0]
else:
Expand Down Expand Up @@ -242,53 +266,88 @@ def threshold_detection(signal, threshold=0.0 * pq.mV, sign='above'):
return result_st


def peak_detection(signal, threshold=0.0 * pq.mV, sign='above',
as_array=False):
def threshold_detection(
signal: neo.core.AnalogSignal,
threshold: pq.Quantity = 0.0 * pq.mV,
sign: Literal['above', 'below'] = 'above',
always_as_list: bool = False,
) -> Union[neo.core.SpikeTrain, SpikeTrainList]:
"""
Return the peak times for all events that cross threshold.
Returns the times when the analog signal crosses a threshold.
Usually used for extracting spike times from a membrane potential.
Similar to spike_train_generation.threshold_detection.

Parameters
----------
signal : neo.AnalogSignal
An analog input signal.
threshold : pq.Quantity, optional
signal : :class:`neo.core.AnalogSignal`
An analog input signal with one or multiple channels.
threshold : :class:`pq.Quantity`, optional
Contains a value that must be reached for an event to be detected.
Default: 0.*pq.mV
Default: 0.0 * pq.mV
sign : {'above', 'below'}, optional
Determines whether to count threshold crossings that cross above or
below the threshold.
Default: 'above'
as_array : bool, optional
If True, a NumPy array of the resulting peak times is returned instead
of a (default) `neo.SpikeTrain` object.
always_as_list: bool, optional
If True, a :class:`neo.core.spiketrainslist.SpikeTrainList`.
Default: False

Returns
-------
result_st : neo.SpikeTrain
------- # noqa
result_st : :class:`neo.core.SpikeTrain`, :class:`neo.core.spiketrainslist.SpikeTrainList`
Contains the spike times of each of the events (spikes) extracted from
the signal.
the signal. If `signal` is an AnalogSignal with multiple channels, or
`always_return_list=True` , a
:class:`neo.core.spiketrainlist.SpikeTrainList` is returned.
"""
if not isinstance(threshold, pq.Quantity):
raise ValueError("threshold must be a pq.Quantity")
raise TypeError('threshold must be a pq.Quantity')

if sign not in ('above', 'below'):
raise ValueError("sign should be 'above' or 'below'")

if isinstance(signal, neo.core.AnalogSignal):
if signal.shape[1] == 1:
if always_as_list:
return SpikeTrainList(items=(
_threshold_detection_from_single_channel(
signal, threshold=threshold, sign=sign),))
else:
return _threshold_detection_from_single_channel(
signal, threshold=threshold, sign=sign)
elif signal.shape[1] > 1:
spiketrainlist = SpikeTrainList()
for channel in range(signal.shape[1]):
spiketrainlist.append(_threshold_detection_from_single_channel(
neo.core.AnalogSignal(signal[:, channel],
sampling_rate=signal.sampling_rate),
threshold=threshold, sign=sign)
)
return spiketrainlist
else:
raise TypeError(
f"Signal must be AnalogSignal, provided: {type(signal)}")


# legacy implementation of peak_detection
def _peak_detection_from_single_channel(
signal: neo.core.AnalogSignal,
threshold: pq.Quantity = 0.0 * pq.mV,
sign: str = 'above',
as_array: bool = False
) -> neo.core.SpikeTrain:
if sign == 'above':
cutout = np.where(signal > threshold)[0]
peak_func = np.argmax
else:
# sign == 'below'
elif sign == 'below':
cutout = np.where(signal < threshold)[0]
peak_func = np.argmin
else:
raise ValueError("sign should be 'above' or 'below'")

if len(cutout) == 0:
events_base = np.zeros(0)
else:
# Select thr crossings lasting at least 2 dtps, np.diff(cutout) > 2
# Select the crossings lasting at least 2 dtps, np.diff(cutout) > 2
# This avoids empty slices
border_start = np.where(np.diff(cutout) > 1)[0]
border_end = border_start + 1
Expand Down Expand Up @@ -327,6 +386,83 @@ def peak_detection(signal, threshold=0.0 * pq.mV, sign='above',
return result_st


def peak_detection(signal: neo.core.AnalogSignal,
threshold: pq.Quantity = 0.0 * pq.mV,
sign: Literal['above', 'below'] = 'above',
as_array: bool = False,
always_as_list: bool = False
) -> Union[neo.core.SpikeTrain, SpikeTrainList]:
"""
Return the peak times for all events that cross threshold.
Usually used for extracting spike times from a membrane potential.
Similar to spike_train_generation.threshold_detection.

Parameters
----------
signal : :class:`neo.core.AnalogSignal`
An analog input signal or a list of analog input signals.
threshold : :class:`pq.Quantity`, optional
Contains a value that must be reached for an event to be detected.
Default: 0.*pq.mV
sign : {'above', 'below'}, optional
Determines whether to count threshold crossings that cross above or
below the threshold.
Default: 'above'
as_array : bool, optional
If True, a NumPy array of the resulting peak times is returned instead
of a (default) `neo.SpikeTrain` object.
Default: False
always_as_list: bool, optional
If True, a :class:`neo.core.spiketrainslist.SpikeTrainList` is returned.
Default: False

Returns
------- # noqa
result_st : :class:`neo.core.SpikeTrain`, :class:`neo.core.spiketrainslist.SpikeTrainList`
:class:`np.ndarrav`, List[:class:`np.ndarrav`]
Contains the spike times of each of the events (spikes) extracted from
the signal.
If `signal` is an AnalogSignal with multiple channels or
`always_return_list=True` a list is returned.
"""
if not isinstance(threshold, pq.Quantity):
raise TypeError(
f"threshold must be a pq.Quantity, provided: {type(threshold)}")

if isinstance(signal, neo.core.AnalogSignal):
if signal.shape[1] == 1:
if always_as_list and not as_array:
return SpikeTrainList(items=(
_peak_detection_from_single_channel(
signal, threshold=threshold, sign=sign,
as_array=as_array),))
elif always_as_list and as_array:
return [_peak_detection_from_single_channel(
signal, threshold=threshold, sign=sign, as_array=as_array)]
else:
return _peak_detection_from_single_channel(
signal, threshold=threshold, sign=sign, as_array=as_array)
elif signal.shape[1] > 1 and as_array:
return [_peak_detection_from_single_channel(neo.core.AnalogSignal(
signal[:, channel], sampling_rate=signal.sampling_rate),
threshold=threshold,
sign=sign, as_array=as_array
) for channel in range(signal.shape[1])]
elif signal.shape[1] > 1 and not as_array:
spiketrainlist = SpikeTrainList()
for channel in range(signal.shape[1]):
spiketrainlist.append(_peak_detection_from_single_channel(
neo.core.AnalogSignal(signal[:, channel],
sampling_rate=signal.sampling_rate),
threshold=threshold,
sign=sign, as_array=as_array
))
return spiketrainlist
else:
raise TypeError(
f"Signal must be AnalogSignal, provided: {type(signal)}")


class AbstractPointProcess:
"""
Abstract point process to subclass from.
Expand Down
Loading
Loading