diff --git a/elephant/conversion.py b/elephant/conversion.py
index 0c9615258..a0aae2a28 100644
--- a/elephant/conversion.py
+++ b/elephant/conversion.py
@@ -1246,3 +1246,84 @@ def _check_neo_spiketrain(matrix):
     if isinstance(matrix, (list, tuple)):
         return all(map(_check_neo_spiketrain, matrix))
     return False
+
+
+def discretise_spiketimes(spiketrains, sampling_rate):
+    """
+    Rounds down all spike times in the input spike train(s) to multiples
+    of the sampling period (i.e. 1 / ``sampling_rate``).
+
+    Parameters
+    ----------
+    spiketrains : neo.SpikeTrain or list of neo.SpikeTrain
+        The spiketrain(s) to discretise
+    sampling_rate : pq.Quantity
+        The desired sampling rate
+
+    Returns
+    -------
+    neo.SpikeTrain or list of neo.SpikeTrain
+        The discretised spiketrain(s)
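+
+    Examples
+    --------
+    A minimal usage sketch; the exact ``repr`` of the output array may
+    vary with the installed numpy version:
+
+    >>> import quantities as pq
+    >>> import neo
+    >>> from elephant.conversion import discretise_spiketimes
+
+    >>> spiketrain = neo.SpikeTrain([0.7, 4.3, 7.6] * pq.ms,
+    ...                             t_stop=10 * pq.ms)
+    >>> discretised = discretise_spiketimes(spiketrain, 1 / pq.ms)
+    >>> discretised.times.magnitude
+    array([0., 4., 7.])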
+    """
+    # spiketrains type check
+    was_single_spiketrain = False
+    if isinstance(spiketrains, neo.SpikeTrain):
+        spiketrains = [spiketrains]
+        was_single_spiketrain = True
+    elif isinstance(spiketrains, list):
+        for st in spiketrains:
+            if not isinstance(st, (np.ndarray, neo.SpikeTrain)):
+                raise TypeError(
+                    "spiketrains must be a SpikeTrain, a numpy ndarray, or a "
+                    "list of one of those, not %s." % type(st))
+    else:
+        raise TypeError(
+            "spiketrains must be a SpikeTrain or a list of SpikeTrain objects,"
+            " not %s." % type(spiketrains))
+
+    if not isinstance(sampling_rate, pq.Quantity):
+        raise TypeError(
+            "The 'sampling_rate' must be pq.Quantity.\n"
+            "Found: %s." % type(sampling_rate))
+
+    units = spiketrains[0].times.units
+    mag_sampling_rate = sampling_rate.rescale(1/units).magnitude.flatten()
+
+    new_spiketrains = []
+    for spiketrain in spiketrains:
+        mag_t_start = spiketrain.t_start.rescale(units).magnitude.flatten()
+        mag_times = spiketrain.times.magnitude.flatten()
+        discrete_times = (mag_times // (1 / mag_sampling_rate)
+                          / mag_sampling_rate)
+        mask = discrete_times < mag_t_start
+
+        if np.any(mask):
+            warnings.warn(f'{mask.sum()} spike(s) would be before t_start '
+                          'and are set to t_start instead.')
+            discrete_times[mask] = mag_t_start
+
+        discrete_times *= units
+        new_spiketrain = spiketrain.duplicate_with_new_data(discrete_times)
+        new_spiketrain.annotations = spiketrain.annotations
+        new_spiketrain.sampling_rate = sampling_rate
+        new_spiketrains.append(new_spiketrain)
+
+    if was_single_spiketrain:
+        new_spiketrains = new_spiketrains[0]
+
+    return new_spiketrains
diff --git a/elephant/test/test_conversion.py b/elephant/test/test_conversion.py
index db5cae06e..d0fde16b7 100644
--- a/elephant/test/test_conversion.py
+++ b/elephant/test/test_conversion.py
@@ -702,5 +702,43 @@ def test_binned_spiketrain_rounding(self):
                                       np.arange(120000))
 
 
+class DiscretiseSpiketrainsTestCase(unittest.TestCase):
+    def setUp(self):
+        times = (np.arange(10) + np.random.uniform(size=10)) * pq.ms
+        self.spiketrains = [neo.SpikeTrain(times, t_stop=10*pq.ms)] * 5
+
+    def test_list_of_spiketrains(self):
+        discretised_spiketrains = cv.discretise_spiketimes(self.spiketrains,
+                                                           1 / pq.ms)
+        for idx in range(len(self.spiketrains)):
+            np.testing.assert_array_equal(discretised_spiketrains[idx].times,
+                                          np.arange(10) * pq.ms)
+
+    def test_single_spiketrain(self):
+        discretised_spiketrain = cv.discretise_spiketimes(self.spiketrains[0],
+                                                          1 / pq.ms)
+        np.testing.assert_array_equal(discretised_spiketrain.times,
+                                      np.arange(10) * pq.ms)
+
+    def test_preserve_t_start(self):
+        spiketrain = neo.SpikeTrain([0.7, 5.1]*pq.ms,
+                                    t_start=0.5*pq.ms, t_stop=10*pq.ms)
+        with self.assertWarns(UserWarning):
+            discretised_spiketrain = cv.discretise_spiketimes(spiketrain,
+                                                              1 / pq.ms)
+        np.testing.assert_array_equal(discretised_spiketrain.times,
+                                      [0.5, 5] * pq.ms)
+
+    def test_binning_consistency(self):
+        discretised_spiketrains = cv.discretise_spiketimes(self.spiketrains,
+                                                           1 / pq.ms)
+        bsts = cv.BinnedSpikeTrain(self.spiketrains,
+                                   bin_size=1 * pq.ms)
+        bsts_discretised = cv.BinnedSpikeTrain(discretised_spiketrains,
+                                               bin_size=1 * pq.ms)
+        np.testing.assert_array_equal(bsts.to_array(),
+                                      bsts_discretised.to_array())
+
+
 if __name__ == '__main__':
     unittest.main()