diff --git a/neurodsp/burst/utils.py b/neurodsp/burst/utils.py index 60b43b78..a80b890b 100644 --- a/neurodsp/burst/utils.py +++ b/neurodsp/burst/utils.py @@ -42,22 +42,28 @@ def compute_burst_stats(bursting, fs): tot_time = len(bursting) / fs - starts = np.array([]) - ends = np.array([]) + change = np.diff(bursting) + idcs, = change.nonzero() - for ii, index in enumerate(np.where(np.diff(bursting) != 0)[0]): + idcs += 1 # Get indices following the change. - if (ii % 2) == 0: - starts = np.append(starts, index) - else: - ends = np.append(ends, index) + if bursting[0]: + # If the first sample is part of a burst, prepend a zero. + idcs = np.r_[0, idcs] + if bursting[-1]: + # If the last sample is part of a burst, append an index corresponding + # to the length of signal. + idcs = np.r_[idcs, bursting.size] + + starts = idcs[0::2] + ends = idcs[1::2] durations = (ends - starts) / fs - stats_dict = {'n_bursts': len(starts), - 'duration_mean': np.mean(durations), - 'duration_std': np.std(durations), - 'percent_burst': 100 * np.sum(bursting) / len(bursting), - 'bursts_per_second': len(starts) / tot_time} + stats_dict = {'n_bursts': durations.size, + 'duration_mean': durations.mean(), + 'duration_std': durations.std(), + 'percent_burst': 100 * sum(bursting) / len(bursting), + 'bursts_per_second': durations.size / tot_time} return stats_dict diff --git a/neurodsp/tests/burst/test_utils.py b/neurodsp/tests/burst/test_utils.py index 560eb808..38765b4a 100644 --- a/neurodsp/tests/burst/test_utils.py +++ b/neurodsp/tests/burst/test_utils.py @@ -1,16 +1,18 @@ """Tests for burst detection functions.""" from neurodsp.burst.utils import * +import pytest ################################################################################################### ################################################################################################### -def test_compute_burst_stats(): +@pytest.mark.parametrize('bursting, n_bursts, duration_mean, percent_burst', + [(np.array([False, False, True, True, False]), 1, 2, 40), + (np.array([True, False, False, True, False, True]), 3, 1, 50)]) +def test_compute_burst_stats(bursting, n_bursts, duration_mean, percent_burst): - bursts = np.array([False, False, True, True, False]) + stats = compute_burst_stats(bursting, 1) - stats = compute_burst_stats(bursts, 1) - - assert stats['n_bursts'] == 1 - assert stats['duration_mean'] == 2 - assert stats['percent_burst'] == 40.0 + assert stats['n_bursts'] == n_bursts + assert stats['duration_mean'] == duration_mean + assert stats['percent_burst'] == percent_burst