Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a function to discretise spiketimes #454

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions elephant/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1246,3 +1246,64 @@ def _check_neo_spiketrain(matrix):
if isinstance(matrix, (list, tuple)):
return all(map(_check_neo_spiketrain, matrix))
return False


def discretise_spiketimes(spiketrains, sampling_rate):
"""
Rounds down all spike times in the input spike train(s)
to multiples of the sampling_rate

Parameters
----------
spiketrains : neo.SpikeTrain or list of neo.SpikeTrain
The spiketrain(s) to discretise
sampling_rate : pq.Quantity
The desired sampling rate

Returns
-------
neo.SpikeTrain or list of neo.SpikeTrain
The discretised spiketrain(s)
"""
# spiketrains type check
was_single_spiketrain = False
if isinstance(spiketrains, (neo.SpikeTrain)):
Kleinjohann marked this conversation as resolved.
Show resolved Hide resolved
spiketrains = [spiketrains]
was_single_spiketrain = True
elif isinstance(spiketrains, list):
for st in spiketrains:
if not isinstance(st, (np.ndarray, neo.SpikeTrain)):
raise TypeError(
"spiketrains must be a SpikeTrain, a numpy ndarray, or a "
"list of one of those, not %s." % type(spiketrains))
else:
raise TypeError(
"spiketrains must be a SpikeTrain or a list of SpikeTrain objects,"
" not %s." % type(spiketrains))

Kleinjohann marked this conversation as resolved.
Show resolved Hide resolved
units = spiketrains[0].times.units
mag_sampling_rate = sampling_rate.rescale(1/units).magnitude.flatten()

new_spiketrains = []
for spiketrain in spiketrains:
mag_t_start = spiketrain.t_start.rescale(units).magnitude.flatten()
mag_times = spiketrain.times.magnitude.flatten()
discrete_times = (mag_times // (1 / mag_sampling_rate)
/ mag_sampling_rate)
mask = discrete_times < mag_t_start

if np.any(mask):
warnings.warn(f'{mask.sum()} spike(s) would be before t_start '
'and are set to t_start instead.')
discrete_times[mask] = mag_t_start

discrete_times *= units
new_spiketrain = spiketrain.duplicate_with_new_data(discrete_times)
new_spiketrain.annotations = spiketrain.annotations
new_spiketrain.sampling_rate = sampling_rate
new_spiketrains.append(new_spiketrain)

if was_single_spiketrain:
new_spiketrains = new_spiketrains[0]

return new_spiketrains
38 changes: 38 additions & 0 deletions elephant/test/test_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,5 +702,43 @@ def test_binned_spiketrain_rounding(self):
np.arange(120000))


class DiscretiseSpiketrainsTestCase(unittest.TestCase):
def setUp(self):
times = (np.arange(10) + np.random.uniform(size=10)) * pq.ms
self.spiketrains = [neo.SpikeTrain(times, t_stop=10*pq.ms)] * 5

def test_list_of_spiketrains(self):
discretised_spiketrains = cv.discretise_spiketimes(self.spiketrains,
1 / pq.ms)
for idx in range(len(self.spiketrains)):
np.testing.assert_array_equal(discretised_spiketrains[idx].times,
np.arange(10) * pq.ms)

def test_single_spiketrain(self):
discretised_spiketrain = cv.discretise_spiketimes(self.spiketrains[0],
1 / pq.ms)
np.testing.assert_array_equal(discretised_spiketrain.times,
np.arange(10) * pq.ms)

def test_preserve_t_start(self):
spiketrain = neo.SpikeTrain([0.7, 5.1]*pq.ms,
t_start=0.5*pq.ms, t_stop=10*pq.ms)
with self.assertWarns(UserWarning):
discretised_spiketrain = cv.discretise_spiketimes(spiketrain,
1 / pq.ms)
np.testing.assert_array_equal(discretised_spiketrain.times,
[0.5, 5] * pq.ms)

def test_binning_consistency(self):
discretised_spiketrains = cv.discretise_spiketimes(self.spiketrains,
1 / pq.ms)
bsts = cv.BinnedSpikeTrain(self.spiketrains,
bin_size=1 * pq.ms)
bsts_discretised = cv.BinnedSpikeTrain(discretised_spiketrains,
bin_size=1 * pq.ms)
np.testing.assert_array_equal(bsts.to_array(),
bsts_discretised.to_array())


if __name__ == '__main__':
unittest.main()