82 changes: 56 additions & 26 deletions src/spikeinterface/postprocessing/template_similarity.py
@@ -208,7 +208,9 @@ def _get_data(self):
compute_template_similarity = ComputeTemplateSimilarity.function_factory()


def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num_shifts, mask, method):
def _compute_similarity_matrix_numpy(
templates_array, other_templates_array, num_shifts, method, sparsity_mask, other_sparsity_mask, support="union"
):

num_templates = templates_array.shape[0]
num_samples = templates_array.shape[1]
@@ -232,15 +234,16 @@ def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num
tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift]
for i in range(num_templates):
src_template = src_sliced_templates[i]
overlapping_templates = np.flatnonzero(np.sum(mask[i], 1))
local_mask = get_overlapping_mask_for_one_template(i, sparsity_mask, other_sparsity_mask, support=support)
overlapping_templates = np.flatnonzero(np.sum(local_mask, 1))
tgt_templates = tgt_sliced_templates[overlapping_templates]
for gcount, j in enumerate(overlapping_templates):
# symmetric values are handled later
if same_array and j < i:
# no need for exhaustive looping when the arrays are the same
continue
src = src_template[:, mask[i, j]].reshape(1, -1)
tgt = (tgt_templates[gcount][:, mask[i, j]]).reshape(1, -1)
src = src_template[:, local_mask[j]].reshape(1, -1)
tgt = (tgt_templates[gcount][:, local_mask[j]]).reshape(1, -1)

if method == "l1":
norm_i = np.sum(np.abs(src))
@@ -273,9 +276,12 @@ def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num
import numba

@numba.jit(nopython=True, parallel=True, fastmath=True, nogil=True)
def _compute_similarity_matrix_numba(templates_array, other_templates_array, num_shifts, mask, method):
def _compute_similarity_matrix_numba(
templates_array, other_templates_array, num_shifts, method, sparsity_mask, other_sparsity_mask, support="union"
):
num_templates = templates_array.shape[0]
num_samples = templates_array.shape[1]
num_channels = templates_array.shape[2]
other_num_templates = other_templates_array.shape[0]

num_shifts_both_sides = 2 * num_shifts + 1
@@ -284,7 +290,6 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num

# We can use the fact that dist[i,j] at lag t is equal to dist[j,i] at lag -t,
# so the matrix can be computed only for negative lags and then transposed

if same_array:
# optimisation when the arrays are the same, because of the symmetry in shift
shift_loop = list(range(-num_shifts, 1))
@@ -304,7 +309,23 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num
tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift]
for i in numba.prange(num_templates):
src_template = src_sliced_templates[i]
overlapping_templates = np.flatnonzero(np.sum(mask[i], 1))

## Ideally we would like to call get_overlapping_mask_for_one_template here, but numba does not
## handle functions mixing numpy and boolean arrays well, so we inline its logic instead
# local_mask = get_overlapping_mask_for_one_template(i, sparsity_mask, other_sparsity_mask, support=support)

if support == "intersection":
local_mask = np.logical_and(
sparsity_mask[i, :], other_sparsity_mask
) # shape (other_num_templates, num_channels)
elif support == "union":
local_mask = np.logical_or(
sparsity_mask[i, :], other_sparsity_mask
) # shape (other_num_templates, num_channels)
elif support == "dense":
local_mask = np.ones((other_num_templates, num_channels), dtype=np.bool_)

overlapping_templates = np.flatnonzero(np.sum(local_mask, 1))
tgt_templates = tgt_sliced_templates[overlapping_templates]
for gcount in range(len(overlapping_templates)):

@@ -313,8 +334,8 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num
if same_array and j < i:
# no need for exhaustive looping when the arrays are the same
continue
src = src_template[:, mask[i, j]].flatten()
tgt = (tgt_templates[gcount][:, mask[i, j]]).flatten()
src = src_template[:, local_mask[j]].flatten()
tgt = (tgt_templates[gcount][:, local_mask[j]]).flatten()

norm_i = 0
norm_j = 0
@@ -360,6 +381,17 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num
_compute_similarity_matrix = _compute_similarity_matrix_numpy
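For context, the unchanged assignment above appears to be the pure-numpy fallback of an optional-numba dispatch. A minimal sketch of what that guard might look like is below; the exact flag and structure used in the module are not visible in this hunk, so the try/except and `HAVE_NUMBA` name are assumptions, not the file's actual code.

```python
# Hypothetical sketch of the optional-numba dispatch (names and structure assumed).
try:
    import numba  # noqa: F401

    HAVE_NUMBA = True
except ImportError:
    HAVE_NUMBA = False

if HAVE_NUMBA:
    # the @numba.jit kernel is defined inside the guarded branch, as in the hunk above
    _compute_similarity_matrix = _compute_similarity_matrix_numba
else:
    _compute_similarity_matrix = _compute_similarity_matrix_numpy  # pure-numpy fallback
```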


def get_overlapping_mask_for_one_template(template_index, sparsity, other_sparsity, support="union") -> np.ndarray:

if support == "intersection":
mask = np.logical_and(sparsity[template_index, :], other_sparsity) # shape (other_num_templates, num_channels)
elif support == "union":
mask = np.logical_or(sparsity[template_index, :], other_sparsity) # shape (other_num_templates, num_channels)
elif support == "dense":
mask = np.ones(other_sparsity.shape, dtype=bool)
return mask
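For orientation, a small usage sketch of this helper with made-up toy masks: `sparsity` has shape (num_templates, num_channels), `other_sparsity` has shape (other_num_templates, num_channels), and the returned mask has shape (other_num_templates, num_channels).

```python
import numpy as np

# Hypothetical toy masks: 3 templates on 4 channels vs 2 "other" templates on 4 channels
sparsity = np.array([[True, True, False, False],
                     [False, True, True, False],
                     [False, False, True, True]])
other_sparsity = np.array([[True, False, False, True],
                           [False, True, True, False]])

# Union of template 0's channels with each "other" template's channels
mask = get_overlapping_mask_for_one_template(0, sparsity, other_sparsity, support="union")
print(mask.shape)  # (2, 4): one boolean row of channels per "other" template
print(mask)        # [[ True  True False  True]
                   #  [ True  True  True False]]
```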


def compute_similarity_with_templates_array(
templates_array, other_templates_array, method, support="union", num_shifts=0, sparsity=None, other_sparsity=None
):
@@ -369,6 +401,8 @@ def compute_similarity_with_templates_array(

all_metrics = ["cosine", "l1", "l2"]

assert support in ["dense", "union", "intersection"], "support should be either dense, union or intersection"

if method not in all_metrics:
raise ValueError(f"compute_template_similarity (method {method}) not exists")

@@ -378,29 +412,25 @@ def compute_similarity_with_templates_array(
assert (
templates_array.shape[2] == other_templates_array.shape[2]
), "The number of channels in the templates should be the same for both arrays"
num_templates = templates_array.shape[0]
# num_templates = templates_array.shape[0]
num_samples = templates_array.shape[1]
num_channels = templates_array.shape[2]
other_num_templates = other_templates_array.shape[0]

mask = np.ones((num_templates, other_num_templates, num_channels), dtype=bool)
# num_channels = templates_array.shape[2]
# other_num_templates = other_templates_array.shape[0]

if sparsity is not None and other_sparsity is not None:

# make the input more flexible: accept either a ChannelSparsity object or a boolean array mask
if sparsity is not None:
sparsity_mask = sparsity.mask if isinstance(sparsity, ChannelSparsity) else sparsity
other_sparsity_mask = other_sparsity.mask if isinstance(other_sparsity, ChannelSparsity) else other_sparsity
else:
sparsity_mask = np.ones((templates_array.shape[0], templates_array.shape[2]), dtype=bool)

if support == "intersection":
mask = np.logical_and(sparsity_mask[:, np.newaxis, :], other_sparsity_mask[np.newaxis, :, :])
elif support == "union":
mask = np.logical_and(sparsity_mask[:, np.newaxis, :], other_sparsity_mask[np.newaxis, :, :])
units_overlaps = np.sum(mask, axis=2) > 0
mask = np.logical_or(sparsity_mask[:, np.newaxis, :], other_sparsity_mask[np.newaxis, :, :])
mask[~units_overlaps] = False
if other_sparsity is not None:
other_sparsity_mask = other_sparsity.mask if isinstance(other_sparsity, ChannelSparsity) else other_sparsity
else:
other_sparsity_mask = np.ones((other_templates_array.shape[0], other_templates_array.shape[2]), dtype=bool)

assert num_shifts < num_samples, "num_shifts is too large"
distances = _compute_similarity_matrix(templates_array, other_templates_array, num_shifts, mask, method)
distances = _compute_similarity_matrix(
templates_array, other_templates_array, num_shifts, method, sparsity_mask, other_sparsity_mask, support=support
)

distances = np.min(distances, axis=0)
similarity = 1 - distances
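A hedged end-to-end sketch of the updated entry point follows; the random templates and all-True masks are made up for illustration, and the import path matches the one used in the merging code further down.

```python
import numpy as np
from spikeinterface.postprocessing.template_similarity import compute_similarity_with_templates_array

rng = np.random.default_rng(seed=42)
templates_array = rng.random(size=(4, 20, 5), dtype=np.float32)        # (num_templates, num_samples, num_channels)
other_templates_array = rng.random(size=(2, 20, 5), dtype=np.float32)  # (other_num_templates, num_samples, num_channels)

# Boolean channel masks; ChannelSparsity objects are also accepted, as shown above.
sparsity = np.ones((4, 5), dtype=bool)
other_sparsity = np.ones((2, 5), dtype=bool)

similarity = compute_similarity_with_templates_array(
    templates_array,
    other_templates_array,
    method="cosine",
    support="union",
    num_shifts=2,
    sparsity=sparsity,
    other_sparsity=other_sparsity,
)
print(similarity.shape)  # expected (num_templates, other_num_templates), i.e. (4, 2)
```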
@@ -107,10 +107,23 @@ def test_equal_results_numba(params):
rng = np.random.default_rng(seed=2205)
templates_array = rng.random(size=(4, 20, 5), dtype=np.float32)
other_templates_array = rng.random(size=(2, 20, 5), dtype=np.float32)
mask = np.ones((4, 2, 5), dtype=bool)

result_numpy = _compute_similarity_matrix_numba(templates_array, other_templates_array, mask=mask, **params)
result_numba = _compute_similarity_matrix_numpy(templates_array, other_templates_array, mask=mask, **params)
sparsity_mask = np.ones((4, 5), dtype=bool)
other_sparsity_mask = np.ones((2, 5), dtype=bool)

result_numba = _compute_similarity_matrix_numba(
templates_array,
other_templates_array,
sparsity_mask=sparsity_mask,
other_sparsity_mask=other_sparsity_mask,
**params,
)
result_numpy = _compute_similarity_matrix_numpy(
templates_array,
other_templates_array,
sparsity_mask=sparsity_mask,
other_sparsity_mask=other_sparsity_mask,
**params,
)

assert np.allclose(result_numpy, result_numba, 1e-3)

@@ -541,7 +541,6 @@ def merge_peak_labels_from_templates(
assert len(unit_ids) == templates_array.shape[0]

from spikeinterface.postprocessing.template_similarity import compute_similarity_with_templates_array
from scipy.sparse.csgraph import connected_components

similarity = compute_similarity_with_templates_array(
templates_array,