Commit
Look into the template option.
JoeZiminski committed Nov 14, 2024
1 parent b5c85ff commit 536c267
Showing 1 changed file with 26 additions and 42 deletions.
68 changes: 26 additions & 42 deletions src/spikeinterface/working/load_kilosort_utils.py
@@ -30,7 +30,7 @@ def compute_spike_amplitude_and_depth(
"""
Compute the indices, amplitudes and locations for all detected spikes from the kilosort output.
This function is based on code in Nick Steinmetz's `spikes` repository,
This function is based on code in Cortex Lab's `spikes` repository,
https://github.com/cortex-lab/spikes
Parameters
@@ -119,54 +119,27 @@ def _get_locations_from_pc_features(params):
Notes
-----
The location of each individual spike is computed from its low-dimensional projection.
During sorting, kilosort computes `pc_features` (num_spikes, num_PC, num_channels), which holds the PC values for each spike.
Taking the first component, the subset of 32 channels associated with this
spike is indexed to get the actual channel locations (in um). Then, the channel
locations are weighted by their PC values.
This function is based on code in Nick Steinmetz's `spikes` repository,
My understanding so far, from the KS1 paper: the individual spike waveforms are decomposed into
'private PCs'. Let the waveform matrix W be time (t) x channel (c). PCA
decomposition is performed to compute c basis waveforms. Scores for each
channel onto the top three PCs are stored (these recover the waveform well).
This function is based on code in Cortex Lab's `spikes` repository,
https://github.com/cortex-lab/spikes
"""
# Compute spike depths

# for each spike, a PCA is computed just on that spike (n samples x n channels).
# the components are all different between spikes, so are not saved.
# This gives a (n pc = 3, num channels) set of scores.
# But then how is it possible for some spikes to have zero score on the principal channel?

breakpoint()
pc_features = params["pc_features"][:, 0, :]
pc_features = params["pc_features"][:, 0, :].copy()
pc_features[pc_features < 0] = 0

# Some spikes do not load at all onto the first PC. To avoid biasing the
# dataset by removing these, we repeat the above for the next PC,
# to compute distances for neurons that do not load onto the 1st PC.
# This is not ideal: it would be much better to find the
# max value for each channel on each of the PCs (i.e. basis vectors),
# then recompute the estimated waveform peak on each channel by
# summing the PCs weighted by their respective scores. However, the PC basis
# vectors themselves do not appear to be output by KS.

# We substitute the (n_channels, i.e. features) scores from the second PC
# into the rows of `pc_features` that have no first-PC signal. As all
# operations are per-spike (i.e. row-wise), this does not affect other rows.
no_pc1_signal_spikes = np.where(np.sum(pc_features, axis=1) == 0)

pc_features_2 = params["pc_features"][:, 1, :]
pc_features_2[pc_features_2 < 0] = 0

pc_features[no_pc1_signal_spikes] = pc_features_2[no_pc1_signal_spikes]

if np.any(np.sum(pc_features, axis=1) == 0):
# TODO: 1) handle this case for pc_features
# 2) instead use the template_features for all other versions.
raise RuntimeError(
"Some spikes do not load at all onto the first"
"or second principal component. It is necessary"
"to extend this code section to handle more components."
)

# Get the channel indices corresponding to the 32 channels from the PC.
# Get the channel indices corresponding to the channels from the PC.
spike_features_indices = params["pc_features_indices"][params["spike_templates"], :]

# Compute the spike locations as the center of mass of the PC scores
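The center-of-mass lines themselves are collapsed in this view. Purely as a sketch, assuming `params["channel_positions"]` is the standard (num_channels, 2) kilosort array with depth in the second column, and weighting by squared PC loadings as the cortex-lab MATLAB code does, the step presumably resembles:

# Sketch only, not the collapsed implementation: spike depth as the
# weighted mean of the channel positions carrying each spike's PC signal.
spike_feature_positions = params["channel_positions"][spike_features_indices, 1]
weights = pc_features**2  # squared loadings, per the reference MATLAB code
spike_depths = np.sum(spike_feature_positions * weights, axis=1) / np.sum(weights, axis=1)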
Expand Down Expand Up @@ -199,7 +172,7 @@ def get_unwhite_template_info(
Amplitude is calculated for each spike as the template amplitude
multiplied by the `template_scaling_amplitudes`.
This function is based on code in Nick Steinmetz's `spikes` repository,
This function is based on code in Cortex Lab's `spikes` repository,
https://github.com/cortex-lab/spikes
Parameters
@@ -277,7 +250,7 @@ def compute_template_amplitudes_from_spikes(templates, spike_templates, spike_am
Take the average of all spike amplitudes to get actual template amplitudes
(since tempScalingAmps have equal mean across all templates)
This function is ported from Nick Steinmetz's `spikes` repository,
This function is ported from Cortex Lab's `spikes` repository,
https://github.com/cortex-lab/spikes
"""
num_indices = templates.shape[0]
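The remainder of this function is collapsed. As a hedged sketch of the averaging described above, assuming `spike_templates` is a 1-D array of template indices aligned with `spike_amplitudes`:

# Sketch only: average spike amplitudes per template; templates with
# no spikes are left at zero amplitude.
sums = np.bincount(spike_templates, weights=spike_amplitudes, minlength=num_indices)
counts = np.bincount(spike_templates, minlength=num_indices)
template_amplitudes = np.divide(sums, counts, out=np.zeros(num_indices), where=counts > 0)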
@@ -297,7 +270,7 @@ def load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool
"""
Loads the output of Kilosort into a `params` dict.
This function was ported from Nick Steinmetz's `spikes` repository MATLAB
This function was ported from Cortex Lab's `spikes` repository MATLAB
code, https://github.com/cortex-lab/spikes
Parameters
@@ -343,8 +316,15 @@ def load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool
if load_pcs:
pc_features = np.load(sorter_output / "pc_features.npy")
pc_features_indices = np.load(sorter_output / "pc_feature_ind.npy")

if (sorter_output / "template_features.npy").is_file():
template_features = np.load(sorter_output / "template_features.npy")
template_features_indices = np.load(sorter_output / "template_feature_ind.npy")
else:
template_features = template_features_indices = None
else:
pc_features = pc_features_indices = None
template_features = template_features_indices = None

# This assumes a sorter output will never contain both .csv and .tsv versions
# of the same file (this should never happen; there should only ever be one).
@@ -364,6 +344,8 @@ def load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool

if load_pcs:
pc_features = pc_features[not_noise_clusters_by_spike, :, :]
if template_features is not None:
template_features = template_features[not_noise_clusters_by_spike]  # index along spikes (first axis)

spike_clusters = spike_clusters[not_noise_clusters_by_spike]
cluster_ids = cluster_ids[cluster_groups != 0]
@@ -378,6 +360,8 @@ def load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool
"spike_clusters": spike_clusters.squeeze(),
"pc_features": pc_features,
"pc_features_indices": pc_features_indices,
"template_features": template_features,
"template_features_indices": template_features_indices,
"temp_scaling_amplitudes": temp_scaling_amplitudes.squeeze(),
"cluster_ids": cluster_ids,
"cluster_groups": cluster_groups,
@@ -399,7 +383,7 @@ def _load_cluster_groups(cluster_path: Path) -> tuple[np.ndarray, ...]:
There are some slight formatting differences between the `.tsv` and `.csv`
versions, presumably from different kilosort versions.
This function was ported from Nick Steinmetz's `spikes` repository MATLAB code,
This function was ported from Cortex Lab's `spikes` repository MATLAB code,
https://github.com/cortex-lab/spikes
Parameters
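The body of `_load_cluster_groups` is collapsed above. Purely as illustration, a hypothetical parser for a phy-style cluster groups table, assuming a header row with `cluster_id` and `group` columns (the common convention; the collapsed implementation may differ):

import csv
from pathlib import Path

import numpy as np

def _read_cluster_groups_sketch(cluster_path: Path):
    # Hypothetical helper, not the collapsed implementation: map each
    # cluster_id to its group label (e.g. "noise", "mua", "good").
    delimiter = "\t" if cluster_path.suffix == ".tsv" else ","
    with open(cluster_path, newline="") as f:
        rows = list(csv.reader(f, delimiter=delimiter))
    cluster_ids = np.array([int(row[0]) for row in rows[1:]])
    group_labels = np.array([row[1].strip() for row in rows[1:]])
    return cluster_ids, group_labels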
