From 76fd5d19995c01d553c5dc8ce37f38ec2724d915 Mon Sep 17 00:00:00 2001
From: Sebastien
Date: Mon, 29 Sep 2025 13:34:59 +0200
Subject: [PATCH 01/13] WIP

---
 .../postprocessing/template_similarity.py     | 68 ++++++++++++-------
 1 file changed, 43 insertions(+), 25 deletions(-)

diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py
index cf0c72952b..31aeedbb24 100644
--- a/src/spikeinterface/postprocessing/template_similarity.py
+++ b/src/spikeinterface/postprocessing/template_similarity.py
@@ -208,7 +208,7 @@ def _get_data(self):
 compute_template_similarity = ComputeTemplateSimilarity.function_factory()
 
 
-def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num_shifts, mask, method):
+def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num_shifts, method, sparsity, other_sparsity, support="union"):
 
     num_templates = templates_array.shape[0]
     num_samples = templates_array.shape[1]
@@ -232,15 +232,16 @@ def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num
         tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift]
         for i in range(num_templates):
             src_template = src_sliced_templates[i]
-            overlapping_templates = np.flatnonzero(np.sum(mask[i], 1))
+            local_mask = get_mask_for_sparse_template(i, sparsity, other_sparsity, support=support)
+            overlapping_templates = np.flatnonzero(np.sum(local_mask, 1))
             tgt_templates = tgt_sliced_templates[overlapping_templates]
             for gcount, j in enumerate(overlapping_templates):
                 # symmetric values are handled later
                 if same_array and j < i:
                     # no need exhaustive looping when same template
                     continue
-                src = src_template[:, mask[i, j]].reshape(1, -1)
-                tgt = (tgt_templates[gcount][:, mask[i, j]]).reshape(1, -1)
+                src = src_template[:, local_mask[j]].reshape(1, -1)
+                tgt = (tgt_templates[gcount][:, local_mask[j]]).reshape(1, -1)
 
                 if method == "l1":
                     norm_i = np.sum(np.abs(src))
@@ -273,7 +274,7 @@ def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num
     import numba
 
     @numba.jit(nopython=True, parallel=True, fastmath=True, nogil=True)
-    def _compute_similarity_matrix_numba(templates_array, other_templates_array, num_shifts, mask, method):
+    def _compute_similarity_matrix_numba(templates_array, other_templates_array, num_shifts, method, sparsity, other_sparsity, support="union"):
         num_templates = templates_array.shape[0]
         num_samples = templates_array.shape[1]
         other_num_templates = other_templates_array.shape[0]
@@ -304,7 +305,8 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num
             tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift]
             for i in numba.prange(num_templates):
                 src_template = src_sliced_templates[i]
-                overlapping_templates = np.flatnonzero(np.sum(mask[i], 1))
+                local_mask = get_mask_for_sparse_template(i, sparsity, other_sparsity, support=support)
+                overlapping_templates = np.flatnonzero(np.sum(local_mask, 1))
                 tgt_templates = tgt_sliced_templates[overlapping_templates]
 
                 for gcount in range(len(overlapping_templates)):
@@ -313,8 +315,8 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num
                     if same_array and j < i:
                         # no need exhaustive looping when same template
                         continue
-                    src = src_template[:, mask[i, j]].flatten()
-                    tgt = (tgt_templates[gcount][:, mask[i, j]]).flatten()
+                    src = src_template[:, local_mask[j]].flatten()
+                    tgt = (tgt_templates[gcount][:, local_mask[j]]).flatten()
 
                     norm_i = 0
                     norm_j = 0
@@ -360,6 +362,34 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num
 
     _compute_similarity_matrix = _compute_similarity_matrix_numpy
 
 
+
+def get_mask_for_sparse_template(template_index,
+                                 sparsity,
+                                 other_sparsity,
+                                 support="union") -> np.ndarray:
+
+    other_num_templates = other_sparsity.shape[0]
+    num_channels = sparsity.shape[1]
+
+    mask = np.ones((other_num_templates, num_channels), dtype=bool)
+
+    if support == "intersection":
+        mask = np.logical_and(
+            sparsity[template_index, :], other_sparsity[:, :]
+        )  # shape (num_templates, other_num_templates, num_channels)
+    elif support == "union":
+        mask = np.logical_and(
+            sparsity[template_index, :], other_sparsity[:, :]
+        )  # shape (num_templates, other_num_templates, num_channels)
+        units_overlaps = np.sum(mask, axis=1) > 0
+        mask = np.logical_or(
+            sparsity[template_index, :], other_sparsity[:, :]
+        )  # shape (num_templates, other_num_templates, num_channels)
+        mask[~units_overlaps] = False
+
+    return mask
+
+
 def compute_similarity_with_templates_array(
     templates_array, other_templates_array, method, support="union", num_shifts=0, sparsity=None, other_sparsity=None
 ):
@@ -378,29 +408,17 @@ def compute_similarity_with_templates_array(
     assert (
         templates_array.shape[2] == other_templates_array.shape[2]
     ), "The number of channels in the templates should be the same for both arrays"
-    num_templates = templates_array.shape[0]
+    #num_templates = templates_array.shape[0]
     num_samples = templates_array.shape[1]
-    num_channels = templates_array.shape[2]
-    other_num_templates = other_templates_array.shape[0]
-
-    mask = np.ones((num_templates, other_num_templates, num_channels), dtype=bool)
+    #num_channels = templates_array.shape[2]
+    #other_num_templates = other_templates_array.shape[0]
 
     if sparsity is not None and other_sparsity is not None:
-
-        # make the input more flexible with either The object or the array mask
         sparsity_mask = sparsity.mask if isinstance(sparsity, ChannelSparsity) else sparsity
         other_sparsity_mask = other_sparsity.mask if isinstance(other_sparsity, ChannelSparsity) else other_sparsity
-
-        if support == "intersection":
-            mask = np.logical_and(sparsity_mask[:, np.newaxis, :], other_sparsity_mask[np.newaxis, :, :])
-        elif support == "union":
-            mask = np.logical_and(sparsity_mask[:, np.newaxis, :], other_sparsity_mask[np.newaxis, :, :])
-            units_overlaps = np.sum(mask, axis=2) > 0
-            mask = np.logical_or(sparsity_mask[:, np.newaxis, :], other_sparsity_mask[np.newaxis, :, :])
-            mask[~units_overlaps] = False
-
+
     assert num_shifts < num_samples, "max_lag is too large"
-    distances = _compute_similarity_matrix(templates_array, other_templates_array, num_shifts, mask, method)
+    distances = _compute_similarity_matrix(templates_array, other_templates_array, num_shifts, method, sparsity_mask, other_sparsity_mask, support=support)
 
     distances = np.min(distances, axis=0)
     similarity = 1 - distances
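Note on PATCH 01/13: the pairwise boolean mask of shape (num_templates, other_num_templates, num_channels) is no longer materialized up front; `get_mask_for_sparse_template` now builds one (other_num_templates, num_channels) slice per source template inside the shift loop. A rough sketch of the memory argument, with hypothetical unit and channel counts that are not from the patch:

```python
# Hypothetical probe/unit counts, only to illustrate why the 3-D mask was dropped:
# the old precomputed mask grows quadratically with the unit count, the new
# per-template slice does not.
num_templates = other_num_templates = 1000
num_channels = 384
old_mask_bytes = num_templates * other_num_templates * num_channels  # np.bool_ is 1 byte
new_slice_bytes = other_num_templates * num_channels
print(old_mask_bytes / 1e6, "MB")   # 384.0 MB held at once
print(new_slice_bytes / 1e6, "MB")  # 0.384 MB per source template
```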
From 2e1098a35b5e70cc2ff95a3ce2e00ec7d65a26e8 Mon Sep 17 00:00:00 2001
From: Sebastien
Date: Mon, 29 Sep 2025 13:39:18 +0200
Subject: [PATCH 02/13] WIP

---
 .../postprocessing/template_similarity.py     | 33 +++++++++++++++----
 1 file changed, 27 insertions(+), 6 deletions(-)

diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py
index 31aeedbb24..1de81d6b7e 100644
--- a/src/spikeinterface/postprocessing/template_similarity.py
+++ b/src/spikeinterface/postprocessing/template_similarity.py
@@ -277,6 +277,7 @@ def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num
     def _compute_similarity_matrix_numba(templates_array, other_templates_array, num_shifts, method, sparsity, other_sparsity, support="union"):
         num_templates = templates_array.shape[0]
         num_samples = templates_array.shape[1]
+        num_channels = sparsity.shape[1]
         other_num_templates = other_templates_array.shape[0]
 
         num_shifts_both_sides = 2 * num_shifts + 1
@@ -285,7 +286,6 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num
 
         # We can use the fact that dist[i,j] at lag t is equal to dist[j,i] at time -t
         # So the matrix can be computed only for negative lags and be transposed
-
         if same_array:
             # optimisation when array are the same because of symetry in shift
             shift_loop = list(range(-num_shifts, 1))
@@ -305,7 +305,28 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num
             tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift]
             for i in numba.prange(num_templates):
                 src_template = src_sliced_templates[i]
-                local_mask = get_mask_for_sparse_template(i, sparsity, other_sparsity, support=support)
+
+                ## Ideally we would like to use this but numba does not support well function with numpy and boolean arrays
+                ## So we inline the function here
+                #local_mask = get_mask_for_sparse_template(i, sparsity, other_sparsity, support=support)
+
+                local_mask = np.ones((other_num_templates, num_channels), dtype=np.bool_)
+
+                if support == "intersection":
+                    local_mask = np.logical_and(
+                        sparsity[i], other_sparsity
+                    )  # shape (num_templates, other_num_templates, num_channels)
+                elif support == "union":
+                    local_mask = np.logical_and(
+                        sparsity[i], other_sparsity
+                    )  # shape (num_templates, other_num_templates, num_channels)
+                    units_overlaps = np.sum(local_mask, axis=1) > 0
+                    local_mask = np.logical_or(
+                        sparsity[i], other_sparsity
+                    )  # shape (num_templates, other_num_templates, num_channels)
+                    local_mask[~units_overlaps] = False
+
+
                 overlapping_templates = np.flatnonzero(np.sum(local_mask, 1))
                 tgt_templates = tgt_sliced_templates[overlapping_templates]
                 for gcount in range(len(overlapping_templates)):
@@ -371,19 +392,19 @@ def get_mask_for_sparse_template(template_index,
     other_num_templates = other_sparsity.shape[0]
     num_channels = sparsity.shape[1]
 
-    mask = np.ones((other_num_templates, num_channels), dtype=bool)
+    mask = np.ones((other_num_templates, num_channels), dtype=np.bool_)
 
     if support == "intersection":
         mask = np.logical_and(
-            sparsity[template_index, :], other_sparsity[:, :]
+            sparsity[template_index], other_sparsity
         )  # shape (num_templates, other_num_templates, num_channels)
     elif support == "union":
         mask = np.logical_and(
-            sparsity[template_index, :], other_sparsity[:, :]
+            sparsity[template_index], other_sparsity
         )  # shape (num_templates, other_num_templates, num_channels)
         units_overlaps = np.sum(mask, axis=1) > 0
         mask = np.logical_or(
-            sparsity[template_index, :], other_sparsity[:, :]
+            sparsity[template_index], other_sparsity
         )  # shape (num_templates, other_num_templates, num_channels)
         mask[~units_overlaps] = False
 
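Note on PATCH 02/13: the helper is inlined into the numba kernel because, as the added comment says, nopython mode copes poorly with calling a plain Python helper that returns boolean arrays. The inlined "union" logic is worth unpacking; here is a minimal numpy rerun of it with toy 4-channel masks that are not from the patch:

```python
import numpy as np

# one source unit covering channels 0-1, two candidate target units
sparsity = np.array([[True, True, False, False]])
other_sparsity = np.array(
    [
        [False, True, True, False],   # shares channel 1 with the source
        [False, False, False, True],  # shares no channel with the source
    ]
)

# "union" as inlined in the kernel: AND first, to find which pairs overlap at all...
local_mask = np.logical_and(sparsity[0], other_sparsity)
units_overlaps = np.sum(local_mask, axis=1) > 0
# ...then OR for the channel support, blanking the non-overlapping pairs entirely
local_mask = np.logical_or(sparsity[0], other_sparsity)
local_mask[~units_overlaps] = False
print(local_mask)
# [[ True  True  True False]
#  [False False False False]]
```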
From a37b8f1f38c8cfc5f83baa527bfce3b610dcfdd6 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 29 Sep 2025 11:43:31 +0000
Subject: [PATCH 03/13] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../postprocessing/template_similarity.py     | 35 ++++++++++---------
 1 file changed, 18 insertions(+), 17 deletions(-)

diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py
index 1de81d6b7e..6ce24c2c00 100644
--- a/src/spikeinterface/postprocessing/template_similarity.py
+++ b/src/spikeinterface/postprocessing/template_similarity.py
@@ -208,7 +208,9 @@ def _get_data(self):
 compute_template_similarity = ComputeTemplateSimilarity.function_factory()
 
 
-def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num_shifts, method, sparsity, other_sparsity, support="union"):
+def _compute_similarity_matrix_numpy(
+    templates_array, other_templates_array, num_shifts, method, sparsity, other_sparsity, support="union"
+):
 
     num_templates = templates_array.shape[0]
     num_samples = templates_array.shape[1]
@@ -274,7 +276,9 @@ def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num
     import numba
 
     @numba.jit(nopython=True, parallel=True, fastmath=True, nogil=True)
-    def _compute_similarity_matrix_numba(templates_array, other_templates_array, num_shifts, method, sparsity, other_sparsity, support="union"):
+    def _compute_similarity_matrix_numba(
+        templates_array, other_templates_array, num_shifts, method, sparsity, other_sparsity, support="union"
+    ):
         num_templates = templates_array.shape[0]
         num_samples = templates_array.shape[1]
         num_channels = sparsity.shape[1]
@@ -305,11 +309,11 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num
             tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift]
             for i in numba.prange(num_templates):
                 src_template = src_sliced_templates[i]
-
+
                 ## Ideally we would like to use this but numba does not support well function with numpy and boolean arrays
                 ## So we inline the function here
-                #local_mask = get_mask_for_sparse_template(i, sparsity, other_sparsity, support=support)
-
+                # local_mask = get_mask_for_sparse_template(i, sparsity, other_sparsity, support=support)
+
                 local_mask = np.ones((other_num_templates, num_channels), dtype=np.bool_)
 
                 if support == "intersection":
@@ -325,8 +329,7 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num
                         sparsity[i], other_sparsity
                     )  # shape (num_templates, other_num_templates, num_channels)
                     local_mask[~units_overlaps] = False
-
-
+
                 overlapping_templates = np.flatnonzero(np.sum(local_mask, 1))
                 tgt_templates = tgt_sliced_templates[overlapping_templates]
                 for gcount in range(len(overlapping_templates)):
@@ -383,11 +386,7 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num
     _compute_similarity_matrix = _compute_similarity_matrix_numpy
 
 
-
-def get_mask_for_sparse_template(template_index,
-                                 sparsity,
-                                 other_sparsity,
-                                 support="union") -> np.ndarray:
+def get_mask_for_sparse_template(template_index, sparsity, other_sparsity, support="union") -> np.ndarray:
 
     other_num_templates = other_sparsity.shape[0]
     num_channels = sparsity.shape[1]
@@ -429,17 +428,19 @@ def compute_similarity_with_templates_array(
     assert (
         templates_array.shape[2] == other_templates_array.shape[2]
     ), "The number of channels in the templates should be the same for both arrays"
-    #num_templates = templates_array.shape[0]
+    # num_templates = templates_array.shape[0]
     num_samples = templates_array.shape[1]
-    #num_channels = templates_array.shape[2]
-    #other_num_templates = other_templates_array.shape[0]
+    # num_channels = templates_array.shape[2]
+    # other_num_templates = other_templates_array.shape[0]
 
     if sparsity is not None and other_sparsity is not None:
         sparsity_mask = sparsity.mask if isinstance(sparsity, ChannelSparsity) else sparsity
         other_sparsity_mask = other_sparsity.mask if isinstance(other_sparsity, ChannelSparsity) else other_sparsity
-
+
     assert num_shifts < num_samples, "max_lag is too large"
-    distances = _compute_similarity_matrix(templates_array, other_templates_array, num_shifts, method, sparsity_mask, other_sparsity_mask, support=support)
+    distances = _compute_similarity_matrix(
+        templates_array, other_templates_array, num_shifts, method, sparsity_mask, other_sparsity_mask, support=support
+    )
 
     distances = np.min(distances, axis=0)
     similarity = 1 - distances

From 40b1f6c517487e6bf5b7fb6963a0aaff42f6c311 Mon Sep 17 00:00:00 2001
From: Sebastien
Date: Mon, 29 Sep 2025 16:03:37 +0200
Subject: [PATCH 04/13] Fixing tests

---
 .../postprocessing/template_similarity.py     |  5 ++++-
 .../tests/test_template_similarity.py         | 17 +++++++++++++----
 2 files changed, 17 insertions(+), 5 deletions(-)

diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py
index 1de81d6b7e..b0a7445e2e 100644
--- a/src/spikeinterface/postprocessing/template_similarity.py
+++ b/src/spikeinterface/postprocessing/template_similarity.py
@@ -437,7 +437,10 @@ def compute_similarity_with_templates_array(
     if sparsity is not None and other_sparsity is not None:
         sparsity_mask = sparsity.mask if isinstance(sparsity, ChannelSparsity) else sparsity
         other_sparsity_mask = other_sparsity.mask if isinstance(other_sparsity, ChannelSparsity) else other_sparsity
-
+    else:
+        sparsity_mask = np.ones((templates_array.shape[0], templates_array.shape[2]), dtype=bool)
+        other_sparsity_mask = np.ones((other_templates_array.shape[0], other_templates_array.shape[2]), dtype=bool)
+
     assert num_shifts < num_samples, "max_lag is too large"
 
     distances = _compute_similarity_matrix(templates_array, other_templates_array, num_shifts, method, sparsity_mask, other_sparsity_mask, support=support)
diff --git a/src/spikeinterface/postprocessing/tests/test_template_similarity.py b/src/spikeinterface/postprocessing/tests/test_template_similarity.py
index 9a25af444c..7633e8f3b5 100644
--- a/src/spikeinterface/postprocessing/tests/test_template_similarity.py
+++ b/src/spikeinterface/postprocessing/tests/test_template_similarity.py
@@ -107,10 +107,19 @@ def test_equal_results_numba(params):
     rng = np.random.default_rng(seed=2205)
     templates_array = rng.random(size=(4, 20, 5), dtype=np.float32)
     other_templates_array = rng.random(size=(2, 20, 5), dtype=np.float32)
-    mask = np.ones((4, 2, 5), dtype=bool)
-
-    result_numpy = _compute_similarity_matrix_numba(templates_array, other_templates_array, mask=mask, **params)
-    result_numba = _compute_similarity_matrix_numpy(templates_array, other_templates_array, mask=mask, **params)
+    sparsity_mask = np.ones((4, 5), dtype=bool)
+    other_sparsity_mask = np.ones((2, 5), dtype=bool)
+
+    result_numpy = _compute_similarity_matrix_numba(templates_array,
+                                                    other_templates_array,
+                                                    sparsity=sparsity_mask,
+                                                    other_sparsity=other_sparsity_mask,
+                                                    **params)
+    result_numba = _compute_similarity_matrix_numpy(templates_array,
+                                                    other_templates_array,
+                                                    sparsity=sparsity_mask,
+                                                    other_sparsity=other_sparsity_mask,
+                                                    **params)
 
     assert np.allclose(result_numpy, result_numba, 1e-3)
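Note on PATCH 04/13: `compute_similarity_with_templates_array` gains a dense fallback (all-True masks) when no sparsity is provided, and the test now passes per-unit `(num_units, num_channels)` masks instead of the old pairwise 3-D mask. One aside: the test variables `result_numpy` and `result_numba` are swapped relative to the kernels they call, which is harmless since the assertion is symmetric. A minimal usage sketch of the public entry point after this patch, with invented shapes:

```python
import numpy as np
from spikeinterface.postprocessing.template_similarity import compute_similarity_with_templates_array

rng = np.random.default_rng(0)
templates_array = rng.random((4, 20, 5), dtype=np.float32)  # (units, samples, channels)

# no sparsity given: the new fallback builds dense all-True masks internally
similarity = compute_similarity_with_templates_array(
    templates_array, templates_array, method="cosine", num_shifts=2
)
print(similarity.shape)  # (4, 4); diagonal ~1.0 since each template matches itself at lag 0
```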
From d7c2e890ecacaf25d0646de961e3a8423e6b364e Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 29 Sep 2025 14:06:48 +0000
Subject: [PATCH 05/13] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../tests/test_template_similarity.py         | 16 ++++++----------
 1 file changed, 6 insertions(+), 10 deletions(-)

diff --git a/src/spikeinterface/postprocessing/tests/test_template_similarity.py b/src/spikeinterface/postprocessing/tests/test_template_similarity.py
index 7633e8f3b5..62d4be2318 100644
--- a/src/spikeinterface/postprocessing/tests/test_template_similarity.py
+++ b/src/spikeinterface/postprocessing/tests/test_template_similarity.py
@@ -110,16 +110,12 @@ def test_equal_results_numba(params):
     sparsity_mask = np.ones((4, 5), dtype=bool)
     other_sparsity_mask = np.ones((2, 5), dtype=bool)
 
-    result_numpy = _compute_similarity_matrix_numba(templates_array,
-                                                    other_templates_array,
-                                                    sparsity=sparsity_mask,
-                                                    other_sparsity=other_sparsity_mask,
-                                                    **params)
-    result_numba = _compute_similarity_matrix_numpy(templates_array,
-                                                    other_templates_array,
-                                                    sparsity=sparsity_mask,
-                                                    other_sparsity=other_sparsity_mask,
-                                                    **params)
+    result_numpy = _compute_similarity_matrix_numba(
+        templates_array, other_templates_array, sparsity=sparsity_mask, other_sparsity=other_sparsity_mask, **params
+    )
+    result_numba = _compute_similarity_matrix_numpy(
+        templates_array, other_templates_array, sparsity=sparsity_mask, other_sparsity=other_sparsity_mask, **params
+    )
 
     assert np.allclose(result_numpy, result_numba, 1e-3)

From b844f3e0beeb202c8cac374d4c783fd851c502ed Mon Sep 17 00:00:00 2001
From: Sebastien
Date: Tue, 30 Sep 2025 08:47:27 +0200
Subject: [PATCH 06/13] WIP

---
 src/spikeinterface/postprocessing/template_similarity.py | 7 +++++--
 src/spikeinterface/sortingcomponents/clustering/merge.py | 1 -
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py
index 0d1a8fccb5..090a91abcd 100644
--- a/src/spikeinterface/postprocessing/template_similarity.py
+++ b/src/spikeinterface/postprocessing/template_similarity.py
@@ -433,11 +433,14 @@ def compute_similarity_with_templates_array(
     # num_channels = templates_array.shape[2]
     # other_num_templates = other_templates_array.shape[0]
 
-    if sparsity is not None and other_sparsity is not None:
+    if sparsity is not None:
         sparsity_mask = sparsity.mask if isinstance(sparsity, ChannelSparsity) else sparsity
-        other_sparsity_mask = other_sparsity.mask if isinstance(other_sparsity, ChannelSparsity) else other_sparsity
     else:
         sparsity_mask = np.ones((templates_array.shape[0], templates_array.shape[2]), dtype=bool)
+
+    if other_sparsity is not None:
+        other_sparsity_mask = other_sparsity.mask if isinstance(other_sparsity, ChannelSparsity) else other_sparsity
+    else:
         other_sparsity_mask = np.ones((other_templates_array.shape[0], other_templates_array.shape[2]), dtype=bool)
 
     assert num_shifts < num_samples, "max_lag is too large"
diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py
index 9110fa37f0..b1956a0e12 100644
--- a/src/spikeinterface/sortingcomponents/clustering/merge.py
+++ b/src/spikeinterface/sortingcomponents/clustering/merge.py
@@ -541,7 +541,6 @@ def merge_peak_labels_from_templates(
     assert len(unit_ids) == templates_array.shape[0]
 
     from spikeinterface.postprocessing.template_similarity import compute_similarity_with_templates_array
-    from scipy.sparse.csgraph import connected_components
 
     similarity = compute_similarity_with_templates_array(
         templates_array,
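Note on PATCH 06/13: `sparsity` and `other_sparsity` are now handled independently, so a sparse set of templates can be compared against a dense one (before this patch, a lone sparsity argument was silently ignored by the `and` condition). The unused `scipy.sparse.csgraph` import in `merge.py` is also dropped. A sketch of the mixed case, with invented shapes and masks:

```python
import numpy as np
from spikeinterface.postprocessing.template_similarity import compute_similarity_with_templates_array

rng = np.random.default_rng(1)
templates_array = rng.random((3, 30, 8), dtype=np.float32)
other_templates_array = rng.random((5, 30, 8), dtype=np.float32)

# a plain boolean mask is accepted in place of a ChannelSparsity object
sparsity_mask = rng.random((3, 8)) < 0.5
sparsity_mask[:, 0] = True  # keep every unit on at least one channel

similarity = compute_similarity_with_templates_array(
    templates_array, other_templates_array, method="l2", support="union",
    num_shifts=2, sparsity=sparsity_mask, other_sparsity=None,  # None -> dense mask
)
print(similarity.shape)  # (3, 5)
```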
From f8e3ba9445106ad71d6a980cd44d3a2751f937fc Mon Sep 17 00:00:00 2001
From: Sebastien
Date: Tue, 30 Sep 2025 08:47:27 +0200
Subject: [PATCH 07/13] WIP

---
 .../postprocessing/template_similarity.py     | 21 +++++++++++--------
 .../tests/test_template_similarity.py         |  4 ++--
 .../sortingcomponents/clustering/merge.py     |  1 -
 3 files changed, 14 insertions(+), 12 deletions(-)

diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py
index 0d1a8fccb5..65f75bbb3d 100644
--- a/src/spikeinterface/postprocessing/template_similarity.py
+++ b/src/spikeinterface/postprocessing/template_similarity.py
@@ -209,7 +209,7 @@ def _get_data(self):
 
 
 def _compute_similarity_matrix_numpy(
-    templates_array, other_templates_array, num_shifts, method, sparsity, other_sparsity, support="union"
+    templates_array, other_templates_array, num_shifts, method, sparsity_mask, other_sparsity_mask, support="union"
 ):
 
     num_templates = templates_array.shape[0]
@@ -234,7 +234,7 @@ def _compute_similarity_matrix_numpy(
         tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift]
         for i in range(num_templates):
             src_template = src_sliced_templates[i]
-            local_mask = get_mask_for_sparse_template(i, sparsity, other_sparsity, support=support)
+            local_mask = get_mask_for_sparse_template(i, sparsity_mask, other_sparsity_mask, support=support)
             overlapping_templates = np.flatnonzero(np.sum(local_mask, 1))
             tgt_templates = tgt_sliced_templates[overlapping_templates]
             for gcount, j in enumerate(overlapping_templates):
@@ -277,11 +277,11 @@ def _compute_similarity_matrix_numpy(
 
     @numba.jit(nopython=True, parallel=True, fastmath=True, nogil=True)
     def _compute_similarity_matrix_numba(
-        templates_array, other_templates_array, num_shifts, method, sparsity, other_sparsity, support="union"
+        templates_array, other_templates_array, num_shifts, method, sparsity_mask, other_sparsity_mask, support="union"
     ):
         num_templates = templates_array.shape[0]
         num_samples = templates_array.shape[1]
-        num_channels = sparsity.shape[1]
+        num_channels = templates_array.shape[2]
         other_num_templates = other_templates_array.shape[0]
 
         num_shifts_both_sides = 2 * num_shifts + 1
@@ -318,15 +318,15 @@ def _compute_similarity_matrix_numba(
 
             if support == "intersection":
                 local_mask = np.logical_and(
-                    sparsity[i], other_sparsity
+                    sparsity_mask[i], other_sparsity_mask
                 )  # shape (num_templates, other_num_templates, num_channels)
             elif support == "union":
                 local_mask = np.logical_and(
-                    sparsity[i], other_sparsity
+                    sparsity_mask[i], other_sparsity_mask
                 )  # shape (num_templates, other_num_templates, num_channels)
                 units_overlaps = np.sum(local_mask, axis=1) > 0
                 local_mask = np.logical_or(
-                    sparsity[i], other_sparsity
+                    sparsity_mask[i], other_sparsity_mask
                 )  # shape (num_templates, other_num_templates, num_channels)
                 local_mask[~units_overlaps] = False
 
@@ -433,11 +433,14 @@ def compute_similarity_with_templates_array(
     # num_channels = templates_array.shape[2]
     # other_num_templates = other_templates_array.shape[0]
 
-    if sparsity is not None and other_sparsity is not None:
+    if sparsity is not None:
         sparsity_mask = sparsity.mask if isinstance(sparsity, ChannelSparsity) else sparsity
-        other_sparsity_mask = other_sparsity.mask if isinstance(other_sparsity, ChannelSparsity) else other_sparsity
     else:
         sparsity_mask = np.ones((templates_array.shape[0], templates_array.shape[2]), dtype=bool)
+
+    if other_sparsity is not None:
+        other_sparsity_mask = other_sparsity.mask if isinstance(other_sparsity, ChannelSparsity) else other_sparsity
+    else:
         other_sparsity_mask = np.ones((other_templates_array.shape[0], other_templates_array.shape[2]), dtype=bool)
 
     assert num_shifts < num_samples, "max_lag is too large"
diff --git a/src/spikeinterface/postprocessing/tests/test_template_similarity.py b/src/spikeinterface/postprocessing/tests/test_template_similarity.py
index 62d4be2318..c6663445f8 100644
--- a/src/spikeinterface/postprocessing/tests/test_template_similarity.py
+++ b/src/spikeinterface/postprocessing/tests/test_template_similarity.py
@@ -111,10 +111,10 @@ def test_equal_results_numba(params):
     other_sparsity_mask = np.ones((2, 5), dtype=bool)
 
     result_numpy = _compute_similarity_matrix_numba(
-        templates_array, other_templates_array, sparsity=sparsity_mask, other_sparsity=other_sparsity_mask, **params
+        templates_array, other_templates_array, sparsity_mask=sparsity_mask, other_sparsity_mask=other_sparsity_mask, **params
     )
     result_numba = _compute_similarity_matrix_numpy(
-        templates_array, other_templates_array, sparsity=sparsity_mask, other_sparsity=other_sparsity_mask, **params
+        templates_array, other_templates_array, sparsity_mask=sparsity_mask, other_sparsity_mask=other_sparsity_mask, **params
     )
 
     assert np.allclose(result_numpy, result_numba, 1e-3)
diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py
index 9110fa37f0..b1956a0e12 100644
--- a/src/spikeinterface/sortingcomponents/clustering/merge.py
+++ b/src/spikeinterface/sortingcomponents/clustering/merge.py
@@ -541,7 +541,6 @@ def merge_peak_labels_from_templates(
     assert len(unit_ids) == templates_array.shape[0]
 
     from spikeinterface.postprocessing.template_similarity import compute_similarity_with_templates_array
-    from scipy.sparse.csgraph import connected_components
 
     similarity = compute_similarity_with_templates_array(
         templates_array,

From 0aa76a3b679597f3e1fe934784f181114e967a7e Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 30 Sep 2025 07:01:33 +0000
Subject: [PATCH 08/13] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../postprocessing/tests/test_template_similarity.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/src/spikeinterface/postprocessing/tests/test_template_similarity.py b/src/spikeinterface/postprocessing/tests/test_template_similarity.py
index c6663445f8..fa7d19fcbc 100644
--- a/src/spikeinterface/postprocessing/tests/test_template_similarity.py
+++ b/src/spikeinterface/postprocessing/tests/test_template_similarity.py
@@ -111,10 +111,18 @@ def test_equal_results_numba(params):
     other_sparsity_mask = np.ones((2, 5), dtype=bool)
 
     result_numpy = _compute_similarity_matrix_numba(
-        templates_array, other_templates_array, sparsity_mask=sparsity_mask, other_sparsity_mask=other_sparsity_mask, **params
+        templates_array,
+        other_templates_array,
+        sparsity_mask=sparsity_mask,
+        other_sparsity_mask=other_sparsity_mask,
+        **params,
     )
     result_numba = _compute_similarity_matrix_numpy(
-        templates_array, other_templates_array, sparsity_mask=sparsity_mask, other_sparsity_mask=other_sparsity_mask, **params
+        templates_array,
+        other_templates_array,
+        sparsity_mask=sparsity_mask,
+        other_sparsity_mask=other_sparsity_mask,
+        **params,
    )
 
     assert np.allclose(result_numpy, result_numba, 1e-3)
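Note on PATCHES 07-08/13: the private kernels now take `sparsity_mask` and `other_sparsity_mask` (plain boolean arrays) instead of `sparsity`/`other_sparsity`, `num_channels` is read from `templates_array.shape[2]` rather than from the mask, and pre-commit reflows the test calls. The test's call pattern, condensed into one runnable snippet (dense masks, with the sizes the test itself uses):

```python
import numpy as np
from spikeinterface.postprocessing.template_similarity import _compute_similarity_matrix_numpy

rng = np.random.default_rng(2205)
templates_array = rng.random(size=(4, 20, 5), dtype=np.float32)

distances = _compute_similarity_matrix_numpy(
    templates_array, templates_array, num_shifts=2, method="l1",
    sparsity_mask=np.ones((4, 5), dtype=bool),
    other_sparsity_mask=np.ones((4, 5), dtype=bool),
)
print(distances.shape)  # (2 * num_shifts + 1, 4, 4) == (5, 4, 4), one slice per lag
```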
From 9858fc63518e37161a2d0b65cf069dc3b06b6a14 Mon Sep 17 00:00:00 2001
From: Sebastien
Date: Tue, 30 Sep 2025 10:13:54 +0200
Subject: [PATCH 09/13] WIP

---
 .../sortingcomponents/clustering/circus.py    | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py
index e1bee8e9ff..7a5297aedb 100644
--- a/src/spikeinterface/sortingcomponents/clustering/circus.py
+++ b/src/spikeinterface/sortingcomponents/clustering/circus.py
@@ -200,7 +200,16 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
                 ms_after,
                 **job_kwargs_local,
             )
-            sparse_mask2 = sparse_mask
+
+            from spikeinterface.core.sparsity import compute_sparsity
+            sparse_mask2 = compute_sparsity(
+                templates,
+                method="snr",
+                amplitude_mode="peak_to_peak",
+                noise_levels=params["noise_levels"],
+                threshold=0.25,
+            ).mask
+
         else:
             from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd

From 341d98009cd8c4bfb87816818f85df8823e58697 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 30 Sep 2025 08:16:39 +0000
Subject: [PATCH 10/13] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spikeinterface/sortingcomponents/clustering/circus.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py
index 7a5297aedb..4555de8148 100644
--- a/src/spikeinterface/sortingcomponents/clustering/circus.py
+++ b/src/spikeinterface/sortingcomponents/clustering/circus.py
@@ -200,8 +200,9 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
                 ms_after,
                 **job_kwargs_local,
             )
-
+
             from spikeinterface.core.sparsity import compute_sparsity
+
             sparse_mask2 = compute_sparsity(
                 templates,
                 method="snr",
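Note on PATCHES 09-10/13: in the circus clustering, `sparse_mask2` is no longer a copy of the SVD extraction mask; it is recomputed from the estimated templates with the same SNR criterion as the class defaults (`{"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25}`). Isolated from the diff, the call looks like this, assuming `templates` and `params["noise_levels"]` are already built as in `main_function`:

```python
from spikeinterface.core.sparsity import compute_sparsity

# keyword values taken verbatim from the patch
sparse_mask2 = compute_sparsity(
    templates,                            # Templates object estimated from the peaks
    method="snr",                         # keep channels where SNR exceeds the threshold
    amplitude_mode="peak_to_peak",        # SNR measured on peak-to-peak amplitude
    noise_levels=params["noise_levels"],  # per-channel noise, not returned in uV here
    threshold=0.25,
).mask                                    # boolean array (num_units, num_channels)
```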
From 76b9a7b2b409687b79e2b623a812c1e73167602b Mon Sep 17 00:00:00 2001
From: Sebastien
Date: Fri, 3 Oct 2025 09:04:31 +0200
Subject: [PATCH 11/13] WIP

---
 .../sortingcomponents/clustering/circus.py    | 279 ------------------
 1 file changed, 279 deletions(-)
 delete mode 100644 src/spikeinterface/sortingcomponents/clustering/circus.py

diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py
deleted file mode 100644
index 4555de8148..0000000000
--- a/src/spikeinterface/sortingcomponents/clustering/circus.py
+++ /dev/null
@@ -1,279 +0,0 @@
-from __future__ import annotations
-
-import importlib
-from pathlib import Path
-
-import numpy as np
-import random, string
-
-from spikeinterface.core import get_global_tmp_folder, Templates
-from spikeinterface.core import get_global_tmp_folder
-from .clustering_tools import remove_duplicates_via_matching
-from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances
-from spikeinterface.sortingcomponents.peak_selection import select_peaks
-from spikeinterface.sortingcomponents.tools import _get_optimal_n_jobs
-from spikeinterface.sortingcomponents.clustering.peak_svd import extract_peaks_svd
-from spikeinterface.sortingcomponents.clustering.merge import merge_peak_labels_from_templates
-from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel
-
-
-class CircusClustering:
-    """
-    Circus clustering is based on several local clustering achieved with a
-    divide-and-conquer strategy. It uses the `hdbscan` or `isosplit6` clustering algorithms to
-    perform the local clusterings with an iterative and greedy strategy.
-    More precisely, it first extracts waveforms from the recording,
-    then performs a Truncated SVD to reduce the dimensionality of the waveforms.
-    For every peak, it extracts the SVD features and performs local clustering, grouping the peaks
-    by channel indices. The clustering is done recursively, and the clusters are merged
-    based on a similarity metric. The final output is a set of labels for each peak,
-    indicating the cluster to which it belongs.
-    """
-
-    _default_params = {
-        "clusterer": "hdbscan",  # 'isosplit6', 'hdbscan', 'isosplit'
-        "clusterer_kwargs": {
-            "min_cluster_size": 20,
-            "cluster_selection_epsilon": 0.5,
-            "cluster_selection_method": "leaf",
-            "allow_single_cluster": True,
-        },
-        "cleaning_kwargs": {},
-        "remove_mixtures": False,
-        "waveforms": {"ms_before": 2, "ms_after": 2},
-        "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25},
-        "recursive_kwargs": {
-            "recursive": True,
-            "recursive_depth": 3,
-            "returns_split_count": True,
-        },
-        "split_kwargs": {"projection_mode": "tsvd", "n_pca_features": 0.9},
-        "radius_um": 100,
-        "neighbors_radius_um": 50,
-        "n_svd": 5,
-        "few_waveforms": None,
-        "ms_before": 2.0,
-        "ms_after": 2.0,
-        "seed": None,
-        "noise_threshold": 2,
-        "templates_from_svd": True,
-        "noise_levels": None,
-        "tmp_folder": None,
-        "do_merge_with_templates": True,
-        "merge_kwargs": {
-            "similarity_metric": "l1",
-            "num_shifts": 3,
-            "similarity_thresh": 0.8,
-        },
-        "verbose": True,
-        "memory_limit": 0.25,
-        "debug": False,
-    }
-
-    @classmethod
-    def main_function(cls, recording, peaks, params, job_kwargs=dict()):
-
-        clusterer = params.get("clusterer", "hdbscan")
-        assert clusterer in [
-            "isosplit6",
-            "hdbscan",
-            "isosplit",
-        ], "Circus clustering only supports isosplit6, isosplit or hdbscan"
-        if clusterer in ["isosplit6", "hdbscan"]:
-            have_dep = importlib.util.find_spec(clusterer) is not None
-            if not have_dep:
-                raise RuntimeError(f"using {clusterer} as a clusterer needs {clusterer} to be installed")
-
-        d = params
-        verbose = d["verbose"]
-
-        fs = recording.get_sampling_frequency()
-        ms_before = params["ms_before"]
-        ms_after = params["ms_after"]
-        radius_um = params["radius_um"]
-        neighbors_radius_um = params["neighbors_radius_um"]
-        nbefore = int(ms_before * fs / 1000.0)
-        nafter = int(ms_after * fs / 1000.0)
-        if params["tmp_folder"] is None:
-            name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8))
-            tmp_folder = get_global_tmp_folder() / name
-        else:
-            tmp_folder = Path(params["tmp_folder"]).absolute()
-
-        tmp_folder.mkdir(parents=True, exist_ok=True)
-
-        # SVD for time compression
-        if params["few_waveforms"] is None:
-            few_peaks = select_peaks(
-                peaks,
-                recording=recording,
-                method="uniform",
-                seed=params["seed"],
-                n_peaks=10000,
-                margin=(nbefore, nafter),
-            )
-            few_wfs = extract_waveform_at_max_channel(
-                recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs
-            )
-            wfs = few_wfs[:, :, 0]
-        else:
-            offset = int(params["waveforms"]["ms_before"] * fs / 1000)
-            wfs = params["few_waveforms"][:, offset - nbefore : offset + nafter]
-
-        # Ensure all waveforms have a positive max
-        wfs *= np.sign(wfs[:, nbefore])[:, np.newaxis]
-
-        # Remove outliers
-        valid = np.argmax(np.abs(wfs), axis=1) == nbefore
-        wfs = wfs[valid]
-
-        from sklearn.decomposition import TruncatedSVD
-
-        svd_model = TruncatedSVD(params["n_svd"], random_state=params["seed"])
-        svd_model.fit(wfs)
-        if params["debug"]:
-            features_folder = tmp_folder / "tsvd_features"
-            features_folder.mkdir(exist_ok=True)
-        else:
-            features_folder = None
-
-        peaks_svd, sparse_mask, svd_model = extract_peaks_svd(
-            recording,
-            peaks,
-            ms_before=ms_before,
-            ms_after=ms_after,
-            svd_model=svd_model,
-            radius_um=radius_um,
-            folder=features_folder,
-            seed=params["seed"],
-            **job_kwargs,
-        )
-
-        neighbours_mask = get_channel_distances(recording) <= neighbors_radius_um
-
-        if params["debug"]:
-            np.save(features_folder / "sparse_mask.npy", sparse_mask)
-            np.save(features_folder / "peaks.npy", peaks)
-
-        original_labels = peaks["channel_index"]
-        from spikeinterface.sortingcomponents.clustering.split import split_clusters
-
-        split_kwargs = params["split_kwargs"].copy()
-        split_kwargs["neighbours_mask"] = neighbours_mask
-        split_kwargs["waveforms_sparse_mask"] = sparse_mask
-        split_kwargs["seed"] = params["seed"]
-        split_kwargs["min_size_split"] = 2 * params["clusterer_kwargs"].get("min_cluster_size", 50)
-        split_kwargs["clusterer_kwargs"] = params["clusterer_kwargs"]
-        split_kwargs["clusterer"] = params["clusterer"]
-
-        if params["debug"]:
-            debug_folder = tmp_folder / "split"
-        else:
-            debug_folder = None
-
-        peak_labels, _ = split_clusters(
-            original_labels,
-            recording,
-            {"peaks": peaks, "sparse_tsvd": peaks_svd},
-            method="local_feature_clustering",
-            method_kwargs=split_kwargs,
-            debug_folder=debug_folder,
-            **params["recursive_kwargs"],
-            **job_kwargs,
-        )
-
-        if params["noise_levels"] is None:
-            params["noise_levels"] = get_noise_levels(recording, return_in_uV=False, **job_kwargs)
-
-        if not params["templates_from_svd"]:
-            from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_recording
-
-            job_kwargs_local = job_kwargs.copy()
-            unit_ids = np.unique(peak_labels)
-            ram_requested = recording.get_num_channels() * (nbefore + nafter) * len(unit_ids) * 4
-            job_kwargs_local = _get_optimal_n_jobs(job_kwargs_local, ram_requested, params["memory_limit"])
-            templates = get_templates_from_peaks_and_recording(
-                recording,
-                peaks,
-                peak_labels,
-                ms_before,
-                ms_after,
-                **job_kwargs_local,
-            )
-
-            from spikeinterface.core.sparsity import compute_sparsity
-
-            sparse_mask2 = compute_sparsity(
-                templates,
-                method="snr",
-                amplitude_mode="peak_to_peak",
-                noise_levels=params["noise_levels"],
-                threshold=0.25,
-            ).mask
-
-        else:
-            from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd
-
-            templates, sparse_mask2 = get_templates_from_peaks_and_svd(
-                recording,
-                peaks,
-                peak_labels,
-                ms_before,
-                ms_after,
-                svd_model,
-                peaks_svd,
-                sparse_mask,
-                operator="median",
-            )
-
-        if params["do_merge_with_templates"]:
-            peak_labels, merge_template_array, merge_sparsity_mask, new_unit_ids = merge_peak_labels_from_templates(
-                peaks,
-                peak_labels,
-                templates.unit_ids,
-                templates.templates_array,
-                sparse_mask2,
-                **params["merge_kwargs"],
-            )
-
-            templates = Templates(
-                templates_array=merge_template_array,
-                sampling_frequency=fs,
-                nbefore=templates.nbefore,
-                sparsity_mask=None,
-                channel_ids=recording.channel_ids,
-                unit_ids=new_unit_ids,
-                probe=recording.get_probe(),
-                is_in_uV=False,
-            )
-
-        labels = templates.unit_ids
-
-        if params["debug"]:
-            templates_folder = tmp_folder / "dense_templates"
-            templates.to_zarr(folder_path=templates_folder)
-
-        if params["remove_mixtures"]:
-            if verbose:
-                print("Found %d raw clusters, starting to clean with matching" % (len(templates.unit_ids)))
-
-            cleaning_job_kwargs = job_kwargs.copy()
-            cleaning_job_kwargs["progress_bar"] = False
-            cleaning_params = params["cleaning_kwargs"].copy()
-
-            labels, peak_labels = remove_duplicates_via_matching(
-                templates, peak_labels, job_kwargs=cleaning_job_kwargs, **cleaning_params
-            )
-
-            if verbose:
-                print("Kept %d non-duplicated clusters" % len(labels))
-        else:
-            if verbose:
-                print("Kept %d raw clusters" % len(labels))
-
-        more_outs = dict(
-            svd_model=svd_model,
-            peaks_svd=peaks_svd,
-            peak_svd_sparse_mask=sparse_mask,
-        )
-        return labels, peak_labels, more_outs

From b43628bf9e4bcfe2e82f909ae6f56c980d9f487b Mon Sep 17 00:00:00 2001
From: Pierre Yger
Date: Tue, 7 Oct 2025 09:31:22 +0200
Subject: [PATCH 12/13] WIP

---
 .../postprocessing/template_similarity.py     | 46 +++++++------------
 1 file changed, 17 insertions(+), 29 deletions(-)

diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py
index 65f75bbb3d..5d7b52c6ec 100644
--- a/src/spikeinterface/postprocessing/template_similarity.py
+++ b/src/spikeinterface/postprocessing/template_similarity.py
@@ -234,7 +234,7 @@ def _compute_similarity_matrix_numpy(
         tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift]
         for i in range(num_templates):
             src_template = src_sliced_templates[i]
-            local_mask = get_mask_for_sparse_template(i, sparsity_mask, other_sparsity_mask, support=support)
+            local_mask = get_overlapping_mask_for_one_template(i, sparsity_mask, other_sparsity_mask, support=support)
             overlapping_templates = np.flatnonzero(np.sum(local_mask, 1))
             tgt_templates = tgt_sliced_templates[overlapping_templates]
             for gcount, j in enumerate(overlapping_templates):
@@ -312,23 +312,18 @@ def _compute_similarity_matrix_numba(
 
                 ## Ideally we would like to use this but numba does not support well function with numpy and boolean arrays
                 ## So we inline the function here
-                # local_mask = get_mask_for_sparse_template(i, sparsity, other_sparsity, support=support)
-
-                local_mask = np.ones((other_num_templates, num_channels), dtype=np.bool_)
+                # local_mask = get_overlapping_mask_for_one_template(i, sparsity, other_sparsity, support=support)
 
                 if support == "intersection":
                     local_mask = np.logical_and(
-                        sparsity_mask[i], other_sparsity_mask
-                    )  # shape (num_templates, other_num_templates, num_channels)
+                        sparsity_mask[i, :], other_sparsity_mask
+                    )  # shape (other_num_templates, num_channels)
                 elif support == "union":
-                    local_mask = np.logical_and(
-                        sparsity_mask[i], other_sparsity_mask
-                    )  # shape (num_templates, other_num_templates, num_channels)
-                    units_overlaps = np.sum(local_mask, axis=1) > 0
                     local_mask = np.logical_or(
-                        sparsity_mask[i], other_sparsity_mask
-                    )  # shape (num_templates, other_num_templates, num_channels)
-                    local_mask[~units_overlaps] = False
+                        sparsity_mask[i, :], other_sparsity_mask
+                    )  # shape (other_num_templates, num_channels)
+                elif support == "dense":
+                    local_mask = np.ones((other_num_templates, num_channels), dtype=np.bool_)
 
                 overlapping_templates = np.flatnonzero(np.sum(local_mask, 1))
                 tgt_templates = tgt_sliced_templates[overlapping_templates]
@@ -386,27 +381,18 @@ def _compute_similarity_matrix_numba(
 
     _compute_similarity_matrix = _compute_similarity_matrix_numpy
 
 
-def get_mask_for_sparse_template(template_index, sparsity, other_sparsity, support="union") -> np.ndarray:
-
-    other_num_templates = other_sparsity.shape[0]
-    num_channels = sparsity.shape[1]
-
-    mask = np.ones((other_num_templates, num_channels), dtype=np.bool_)
+def get_overlapping_mask_for_one_template(template_index, sparsity, other_sparsity, support="union") -> np.ndarray:
 
     if support == "intersection":
         mask = np.logical_and(
-            sparsity[template_index], other_sparsity
-        )  # shape (num_templates, other_num_templates, num_channels)
+            sparsity[template_index, :], other_sparsity
+        )  # shape (other_num_templates, num_channels)
     elif support == "union":
-        mask = np.logical_and(
-            sparsity[template_index], other_sparsity
-        )  # shape (num_templates, other_num_templates, num_channels)
-        units_overlaps = np.sum(mask, axis=1) > 0
         mask = np.logical_or(
-            sparsity[template_index], other_sparsity
-        )  # shape (num_templates, other_num_templates, num_channels)
-        mask[~units_overlaps] = False
-
+            sparsity[template_index, :], other_sparsity
+        )  # shape (other_num_templates, num_channels)
+    elif support == "dense":
+        mask = np.ones(other_sparsity.shape, dtype=bool)
 
     return mask
@@ -419,6 +405,8 @@ def compute_similarity_with_templates_array(
 
     all_metrics = ["cosine", "l1", "l2"]
 
+    assert support in ["dense", "union", "intersection"], "support should be either dense, union or intersection"
+
     if method not in all_metrics:
         raise ValueError(f"compute_template_similarity (method {method}) not exists")

From 48769173edf2834ce9ecdd528dee8c80c3953cba Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 7 Oct 2025 07:31:59 +0000
Subject: [PATCH 13/13] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spikeinterface/postprocessing/template_similarity.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py
index 5d7b52c6ec..aed01b6a2c 100644
--- a/src/spikeinterface/postprocessing/template_similarity.py
+++ b/src/spikeinterface/postprocessing/template_similarity.py
@@ -384,13 +384,9 @@ def _compute_similarity_matrix_numba(
 def get_overlapping_mask_for_one_template(template_index, sparsity, other_sparsity, support="union") -> np.ndarray:
 
     if support == "intersection":
-        mask = np.logical_and(
-            sparsity[template_index, :], other_sparsity
-        )  # shape (other_num_templates, num_channels)
+        mask = np.logical_and(sparsity[template_index, :], other_sparsity)  # shape (other_num_templates, num_channels)
     elif support == "union":
-        mask = np.logical_or(
-            sparsity[template_index, :], other_sparsity
-        )  # shape (other_num_templates, num_channels)
+        mask = np.logical_or(sparsity[template_index, :], other_sparsity)  # shape (other_num_templates, num_channels)
     elif support == "dense":
         mask = np.ones(other_sparsity.shape, dtype=bool)
     return mask
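Note on PATCHES 12-13/13: two behavioral points worth flagging beyond the rename to `get_overlapping_mask_for_one_template`. First, "dense" becomes an explicit support mode, validated by the new assert in `compute_similarity_with_templates_array`. Second, "union" no longer blanks out pairs of units whose sparsity masks are disjoint; such pairs are now compared over the union of their channels. A toy before/after comparison, with invented masks:

```python
import numpy as np

sparsity = np.array([[True, True, False, False]])
other_sparsity = np.array([[False, False, True, True]])  # disjoint from the source unit

# before PATCH 12: OR, then blank the pairs with an empty intersection
old_mask = np.logical_or(sparsity[0], other_sparsity)
overlaps = np.sum(np.logical_and(sparsity[0], other_sparsity), axis=1) > 0
old_mask[~overlaps] = False

# after PATCH 12: plain OR, so the disjoint pair is kept
new_mask = np.logical_or(sparsity[0], other_sparsity)

print(old_mask)  # [[False False False False]] -> pair skipped by the old code
print(new_mask)  # [[ True  True  True  True]] -> pair now contributes a distance
```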