Easily calculate firing rate
edeno committed Jan 22, 2024
1 parent a730682 commit e7edaa2
Showing 3 changed files with 77 additions and 5 deletions.
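The sketch below shows how the classmethods added in this commit are meant to be called. Only the import path, method names, and signatures come from the diff; the key lookup and the 500 Hz, 10 s time grid are illustrative assumptions, not part of the commit.

import numpy as np
from spyglass.decoding.v1.sorted_spikes import SortedSpikesDecodingV1

# Assumption: at least one SortedSpikesDecodingV1 entry has already been populated.
decoding_key = SortedSpikesDecodingV1().fetch("KEY")[0]

# Evenly spaced time grid (illustrative: 500 Hz over 10 s).
time = np.arange(0.0, 10.0, 1 / 500)

# (n_time, n_units) array of spike counts per time bin.
spike_indicator = SortedSpikesDecodingV1.get_spike_indicator(decoding_key, time)

# Smoothed rate per unit (stacked along axis 1), or a single multiunit column.
firing_rate = SortedSpikesDecodingV1.get_firing_rate(decoding_key, time)
mua_rate = SortedSpikesDecodingV1.get_firing_rate(decoding_key, time, multiunit=True)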
37 changes: 37 additions & 0 deletions src/spyglass/decoding/v1/clusterless.py
@@ -17,6 +17,7 @@
 import pandas as pd
 import xarray as xr
 from non_local_detector.models.base import ClusterlessDetector
+from ripple_detection import get_multiunit_population_firing_rate
 from track_linearization import get_linearized_position
 
 from spyglass.common.common_interval import IntervalList  # noqa: F401
@@ -418,3 +419,39 @@ def load_spike_data(key, filter_by_interval=True):
             new_waveform_features.append(elec_waveform_features[is_in_interval])
 
         return new_spike_times, new_waveform_features
+
+    @classmethod
+    def get_spike_indicator(cls, key, time):
+        time = np.asarray(time)
+        min_time, max_time = time[[0, -1]]
+        spike_times = cls.load_spike_data(key)[0]
+        spike_indicator = np.zeros((len(time), len(spike_times)))
+
+        for ind, times in enumerate(spike_times):
+            times = times[np.logical_and(times >= min_time, times <= max_time)]
+            spike_indicator[:, ind] = np.bincount(
+                np.digitize(times, time[1:-1]),
+                minlength=time.shape[0],
+            )
+
+        return spike_indicator
+
+    @classmethod
+    def get_firing_rate(cls, key, time, multiunit=False):
+        spike_indicator = cls.get_spike_indicator(key, time)
+        if spike_indicator.ndim == 1:
+            spike_indicator = spike_indicator[:, np.newaxis]
+
+        sampling_frequency = 1 / np.median(np.diff(time))
+
+        if multiunit:
+            spike_indicator = spike_indicator.sum(axis=1, keepdims=True)
+        return np.stack(
+            [
+                get_multiunit_population_firing_rate(
+                    indicator[:, np.newaxis], sampling_frequency
+                )
+                for indicator in spike_indicator.T
+            ],
+            axis=1,
+        )
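The indicator construction above relies on a compact numpy pattern: np.digitize against the interior time points time[1:-1] assigns each spike to a time bin, and np.bincount with minlength=len(time) turns those assignments into per-bin counts. A small self-contained check with made-up numbers (not part of the commit):

import numpy as np

time = np.arange(5.0)                          # time points at 0, 1, 2, 3, 4 s
spikes = np.array([0.2, 1.1, 1.9, 3.5, 3.6])   # made-up spike times within the range

# Same binning pattern used per unit in get_spike_indicator.
counts = np.bincount(np.digitize(spikes, time[1:-1]), minlength=time.shape[0])
print(counts)                                  # [1 2 0 2 0]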
39 changes: 38 additions & 1 deletion src/spyglass/decoding/v1/sorted_spikes.py
@@ -17,6 +17,7 @@
 import pandas as pd
 import xarray as xr
 from non_local_detector.models.base import SortedSpikesDetector
+from ripple_detection import get_multiunit_population_firing_rate
 from track_linearization import get_linearized_position
 
 from spyglass.common.common_interval import IntervalList  # noqa: F401
@@ -404,10 +405,46 @@ def load_spike_data(key, filter_by_interval=True):
         min_time, max_time = SortedSpikesDecodingV1._get_interval_range(key)
 
         new_spike_times = []
-        for elec_spike_times in zip(spike_times):
+        for elec_spike_times in spike_times:
             is_in_interval = np.logical_and(
                 elec_spike_times >= min_time, elec_spike_times <= max_time
             )
             new_spike_times.append(elec_spike_times[is_in_interval])
 
         return new_spike_times
+
+    @classmethod
+    def get_spike_indicator(cls, key, time):
+        time = np.asarray(time)
+        min_time, max_time = time[[0, -1]]
+        spike_times = cls.load_spike_data(key)
+        spike_indicator = np.zeros((len(time), len(spike_times)))
+
+        for ind, times in enumerate(spike_times):
+            times = times[np.logical_and(times >= min_time, times <= max_time)]
+            spike_indicator[:, ind] = np.bincount(
+                np.digitize(times, time[1:-1]),
+                minlength=time.shape[0],
+            )
+
+        return spike_indicator
+
+    @classmethod
+    def get_firing_rate(cls, key, time, multiunit=False):
+        spike_indicator = cls.get_spike_indicator(key, time)
+        if spike_indicator.ndim == 1:
+            spike_indicator = spike_indicator[:, np.newaxis]
+
+        sampling_frequency = 1 / np.median(np.diff(time))
+
+        if multiunit:
+            spike_indicator = spike_indicator.sum(axis=1, keepdims=True)
+        return np.stack(
+            [
+                get_multiunit_population_firing_rate(
+                    indicator[:, np.newaxis], sampling_frequency
+                )
+                for indicator in spike_indicator.T
+            ],
+            axis=1,
+        )
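get_firing_rate, as added above, infers the sampling rate from the median time step, optionally collapses all units into a single multiunit column, and passes each column through ripple_detection's get_multiunit_population_firing_rate before stacking the results along axis 1. The sketch below mirrors only the shape logic; smooth_stub is a stand-in for the real smoother (whose parameters are not shown in this diff), and every number is made up.

import numpy as np

def smooth_stub(indicator_2d, sampling_frequency):
    # Stand-in for get_multiunit_population_firing_rate: returns one rate trace
    # per call (counts scaled to Hz here, with no actual smoothing).
    return indicator_2d.mean(axis=1) * sampling_frequency

time = 0.002 * np.arange(1000)                 # 500 Hz grid, 1000 samples
rng = np.random.default_rng(0)
spike_indicator = rng.integers(0, 2, size=(time.size, 3))
sampling_frequency = 1 / np.median(np.diff(time))

per_unit = np.stack(
    [smooth_stub(col[:, np.newaxis], sampling_frequency) for col in spike_indicator.T],
    axis=1,
)
multiunit = smooth_stub(spike_indicator.sum(axis=1, keepdims=True), sampling_frequency)

print(per_unit.shape, multiunit.shape)         # (1000, 3) (1000,)

In the real method, the multiunit trace goes through the same np.stack call, so with multiunit=True the returned array has shape (n_time, 1).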
6 changes: 2 additions & 4 deletions src/spyglass/spikesorting/merge.py
@@ -81,13 +81,11 @@ def get_spike_times(cls, key):
     def get_spike_indicator(cls, key, time):
         time = np.asarray(time)
         min_time, max_time = time[[0, -1]]
-        spike_times = cls.get_spike_times(key)
+        spike_times = cls.load_spike_data(key)
         spike_indicator = np.zeros((len(time), len(spike_times)))
 
         for ind, times in enumerate(spike_times):
-            times = times[
-                np.logical_and(spike_times >= min_time, spike_times <= max_time)
-            ]
+            times = times[np.logical_and(times >= min_time, times <= max_time)]
             spike_indicator[:, ind] = np.bincount(
                 np.digitize(times, time[1:-1]),
                 minlength=time.shape[0],
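The merge.py hunk above is a bug fix: the interval mask inside get_spike_indicator is now built from each unit's own times array rather than from the whole spike_times collection, alongside a change to the data-loading call. A tiny self-contained sketch of that per-unit masking, with made-up spike times:

import numpy as np

# Made-up spike times (seconds) for two units, and an interval of interest.
spike_times = [np.array([0.5, 1.2, 3.7]), np.array([0.1, 2.2, 2.9, 4.4])]
min_time, max_time = 1.0, 3.0

for ind, times in enumerate(spike_times):
    # Mask each unit's spikes with that unit's own times, as in the fixed code.
    kept = times[np.logical_and(times >= min_time, times <= max_time)]
    print(ind, kept)
# 0 [1.2]
# 1 [2.2 2.9]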