diff --git a/draft_pynapple_fastplotlib.py b/draft_pynapple_fastplotlib.py new file mode 100644 index 00000000..da4c4538 --- /dev/null +++ b/draft_pynapple_fastplotlib.py @@ -0,0 +1,99 @@ +# -*- coding: utf-8 -*- +""" +Fastplotlib +=========== + +Working with calcium data. + +For the example dataset, we will be working with a recording of a freely-moving mouse imaged with a Miniscope (1-photon imaging). The area recorded for this experiment is the postsubiculum - a region that is known to contain head-direction cells, or cells that fire when the animal's head is pointing in a specific direction. + +The NWB file for the example is hosted on [OSF](https://osf.io/sbnaw). We show below how to stream it. + +See the [documentation](https://pynapple-org.github.io/pynapple/) of Pynapple for instructions on installing the package. + +This tutorial was made by Sofia Skromne Carrasco and Guillaume Viejo. + +""" +# %% +# !!! warning +# This tutorial uses seaborn and matplotlib for displaying the figure +# +# You can install all with `pip install matplotlib seaborn tqdm` +# +# mkdocs_gallery_thumbnail_number = 1 +# +# Now, import the necessary libraries: + +# %qui qt + +import pynapple as nap +import numpy as np +import fastplotlib as fpl + +import imageio.v3 as iio +import sys +# mkdocs_gallery_thumbnail_path = '../_static/fastplotlib_demo.png' + +#nwb = nap.load_file("/Users/gviejo/pynapple/Mouse32-220101.nwb") +nwb = nap.load_file("your/path/to/MyProject/sub-A2929/ses-A2929-200711/pynapplenwb/A2929-200711.nwb") + +units = nwb['units']#.getby_category("location")['adn'] + +tmp = units.to_tsd() + +tmp = np.vstack((tmp.index.values, tmp.values)).T + +# Example 1 + +fplot = fpl.Plot() + +fplot.add_scatter(tmp) + +fplot.graphics[0].cmap = "jet" + +fplot.graphics[0].cmap.values = tmp[:, 1] + +fplot.show(maintain_aspect=False) + +# Example 2 + +names = [['raster'], ['position']] + +grid_plot = fpl.GridPlot(shape=(2, 1), controller_ids="sync", names = names) + +grid_plot['raster'].add_scatter(tmp) + +grid_plot['position'].add_line(np.vstack((nwb['ry'].t, nwb['ry'].d)).T) + +grid_plot.show(maintain_aspect=False) + +grid_plot['raster'].auto_scale(maintain_aspect=False) + + +# Example 3 +#frames = iio.imread("/Users/gviejo/pynapple/A0670-221213_filtered.avi") +#frames = frames[:,:,:,0] +frames = np.random.randn(10, 100, 100) + +iw = fpl.ImageWidget(frames, cmap="gnuplot2") + +#iw.show() + +# Example 4 + +from PyQt6 import QtWidgets + + +mainwidget = QtWidgets.QWidget() + +hlayout = QtWidgets.QHBoxLayout(mainwidget) + +iw.widget.setParent(mainwidget) + +hlayout.addWidget(iw.widget) + +grid_plot.widget.setParent(mainwidget) + +hlayout.addWidget(grid_plot.widget) + +mainwidget.show() diff --git a/pynapple/core/jitted_functions.py b/pynapple/core/jitted_functions.py index 340451c5..8679adfa 100644 --- a/pynapple/core/jitted_functions.py +++ b/pynapple/core/jitted_functions.py @@ -2,7 +2,7 @@ # @Author: guillaume # @Date: 2022-10-31 16:44:31 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2023-12-12 16:50:36 +# @Last Modified time: 2024-01-25 16:43:34 import numpy as np from numba import jit, njit, prange @@ -866,7 +866,7 @@ def jitcontinuous_perievent( (np.sum(windowsize) + 1, np.sum(count[:, 1]), *data_array.shape[1:]), np.nan ) - if np.all((count[:, 0] * count[:, 1]) > 0): + if np.any((count[:, 0] * count[:, 1]) > 0): for k in prange(N_epochs): if count[k, 0] > 0 and count[k, 1] > 0: t = start_t[k, 0] @@ -891,9 +891,9 @@ def jitcontinuous_perievent( left = np.minimum(windowsize[0], t_pos - 
start_t[k, 0]) right = np.minimum(windowsize[1], maxt - t_pos - 1) center = windowsize[0] + 1 - new_data_array[ - center - left - 1 : center + right, cnt_i - ] = data_array[t_pos - left : t_pos + right + 1] + new_data_array[center - left - 1 : center + right, cnt_i] = ( + data_array[t_pos - left : t_pos + right + 1] + ) t -= 1 i += 1 @@ -902,15 +902,103 @@ def jitcontinuous_perievent( return new_data_array -# time_array = tsd.t -# time_target_array = tref.t -# data_array = tsd.d +@jit(nopython=True) +def jitperievent_trigger_average( + time_array, + count_array, + time_target_array, + data_target_array, + starts, + ends, + windows, + binsize, +): + T = time_array.shape[0] + N = count_array.shape[1] + N_epochs = len(starts) + + time_target_array, data_target_array, count = jitrestrict_with_count( + time_target_array, data_target_array, starts, ends + ) + max_count = np.cumsum(count) + + new_data_array = np.full( + (int(windows.sum()) + 1, count_array.shape[1], *data_target_array.shape[1:]), + 0.0, + ) + + t = 0 # count events + + hankel_array = np.zeros((new_data_array.shape[0], *data_target_array.shape[1:])) + + for k in range(N_epochs): + if count[k] > 0: + t_start = t + maxi = max_count[k] + i = maxi - count[k] -# for i,t in enumerate(tref.restrict(ep).t): -# plot(time_idx + t, new_data_array[:,i]+i*2.0, 'o') -# plot(tsd + i*2.0, color='grey') -# [axvspan(ep.loc[i,'start'], ep.loc[i,'end'], alpha=0.3) for i in range(len(ep))] -# [axvline(t) for t in tref.restrict(ep).t] + while t < T: + lbound = time_array[t] + rbound = np.round(lbound + binsize, 9) + + if time_target_array[i] < rbound: + i_start = i + i_stop = i + + while i_stop < maxi: + if time_target_array[i_stop] < rbound: + i_stop += 1 + else: + break + + while i_start < i_stop - 1: + if time_target_array[i_start] < lbound: + i_start += 1 + else: + break + v = np.sum(data_target_array[i_start:i_stop], 0) / float( + i_stop - i_start + ) + + checknan = np.sum(v) + if not np.isnan(checknan): + hankel_array[-1] = v + + if t - t_start >= windows[1]: + for n in range(N): + new_data_array[:, n] += ( + hankel_array * count_array[t - windows[1], n] + ) + + # hankel_array = np.roll(hankel_array, -1, axis=0) + hankel_array[0:-1] = hankel_array[1:] + hankel_array[-1] = 0.0 + + t += 1 + + i = i_start + + if t == T or time_array[t] > ends[k]: + if t - t_start > windows[1]: + for j in range(windows[1]): + for n in range(N): + new_data_array[:, n] += ( + hankel_array * count_array[t - windows[1] + j, n] + ) + + # hankel_array = np.roll(hankel_array, -1, axis=0) + hankel_array[0:-1] = hankel_array[1:] + hankel_array[-1] = 0.0 + + hankel_array *= 0.0 + break + + total = np.sum(count_array, 0) + for n in range(N): + if total[n] > 0.0: + new_data_array[:, n] /= total[n] + + return new_data_array # @jit(nopython=True) diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index f048834c..fc4ccad0 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -2,7 +2,7 @@ # @Author: gviejo # @Date: 2022-01-27 18:33:31 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2023-12-07 13:58:06 +# @Last Modified time: 2024-01-08 16:09:01 """ diff --git a/pynapple/core/time_units.py b/pynapple/core/time_units.py index b669ce31..b9137316 100644 --- a/pynapple/core/time_units.py +++ b/pynapple/core/time_units.py @@ -8,6 +8,7 @@ - 's': seconds (overall default) """ + from warnings import warn import numpy as np diff --git a/pynapple/io/phy.py b/pynapple/io/phy.py index 4c2ddc39..ae0b79a8 100644 --- a/pynapple/io/phy.py 
+++ b/pynapple/io/phy.py @@ -5,6 +5,7 @@ @author: Sara Mahallati, Guillaume Viejo """ + import os import numpy as np diff --git a/pynapple/process/perievent.py b/pynapple/process/perievent.py index d4a3cf95..ac5e5b1e 100644 --- a/pynapple/process/perievent.py +++ b/pynapple/process/perievent.py @@ -2,10 +2,9 @@ # @Author: gviejo # @Date: 2022-01-30 22:59:00 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2023-12-12 18:17:42 +# @Last Modified time: 2024-01-26 15:52:19 import numpy as np -from scipy.linalg import hankel from .. import core as nap @@ -193,59 +192,68 @@ def compute_perievent_continuous(data, tref, minmax, ep=None, time_unit="s"): def compute_event_trigger_average( - group, feature, binsize, windowsize, ep, time_unit="s" + group, + feature, + binsize, + windowsize=None, + ep=None, + time_unit="s", ): """ - Bin the spike train in binsize and compute the Event Trigger Average (ETA) within windowsize. - If C is the spike count matrix and `feature` is a Tsd array, the function computes + Bin the event timestamps within binsize and compute the Event Trigger Average (ETA) within windowsize. + If C is the event count matrix and `feature` is a Tsd array, the function computes the Hankel matrix H from windowsize=(-t1,+t2) by offseting the Tsd array. The ETA is then defined as the dot product between H and C divided by the number of events. + The object feature can be any dimensions. + Parameters ---------- group : TsGroup The group of Ts/Tsd objects that hold the trigger time. - feature : Tsd - The 1-dimensional feature to average. Can be a TsdFrame with one column only. + feature : Tsd, TsdFrame or TsdTensor + The feature to average. binsize : float or int The bin size. Default is second. If different, specify with the parameter time_unit ('s' [default], 'ms', 'us'). - windowsize : tuple or list of float - The window size. Default is second. For example (-1, 1). + windowsize : tuple of float/int or float/int + The window size. Default is second. For example windowsize = (-1, 1) is equivalent to windowsize = 1 If different, specify with the parameter time_unit ('s' [default], 'ms', 'us'). ep : IntervalSet - The epoch on which ETA are computed + The epochs on which the average is computed time_unit : str, optional The time unit of the parameters. They have to be consistent for binsize and windowsize. ('s' [default], 'ms', 'us'). - - Returns - ------- - TsdFrame - A TsdFrame of Event-Trigger Average. Each column is an element from the group. - - Raises - ------ - RuntimeError - if group is not a Ts/Tsd or TsGroup """ assert isinstance(group, nap.TsGroup), "group should be a TsGroup." assert isinstance( - windowsize, (float, int, tuple) - ), "windowsize should be a tuple or int or float." + feature, (nap.Tsd, nap.TsdFrame, nap.TsdTensor) + ), "Feature should be a Tsd, TsdFrame or TsdTensor" assert isinstance(binsize, (float, int)), "binsize should be int or float." assert isinstance(time_unit, str), "time_unit should be a str." assert time_unit in ["s", "ms", "us"], "time_unit should be 's', 'ms' or 'us'" - assert isinstance(ep, (nap.IntervalSet)), "ep should be an IntervalSet object." 
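A minimal usage sketch of the reworked compute_event_trigger_average, assuming the signature introduced in this hunk (windowsize and ep become optional, and feature may be a Tsd, TsdFrame or TsdTensor). The data is synthetic and only illustrates the call pattern; the expected output type follows the tests added later in this diff.

import numpy as np
import pynapple as nap

ep = nap.IntervalSet(0, 100)
# Two-column feature: the average should come back as a TsdTensor (time x unit x column).
feature = nap.TsdFrame(
    t=np.arange(0, 100, 0.01),
    d=np.random.randn(10000, 2),
    time_support=ep,
)
spikes = nap.TsGroup(
    {0: nap.Ts(np.sort(np.random.uniform(0, 100, 500)))},
    time_support=ep,
)

# A scalar windowsize is expanded to (-1.0, +1.0); omitting ep falls back to
# feature.time_support.
eta = nap.compute_event_trigger_average(spikes, feature, binsize=0.1, windowsize=1.0)
print(type(eta), eta.shape)  # expected: TsdTensor of shape (21, 1, 2)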
- if isinstance(feature, nap.TsdFrame): - if feature.shape[1] == 1: - feature = feature[:, 0] + if windowsize is not None: + if isinstance(windowsize, tuple): + assert ( + len(windowsize) == 2 + ), "windowsize should be a tuple of 2 elements (-t, +t)" + assert all( + [isinstance(t, (float, int)) for t in windowsize] + ), "windowsize should be a tuple of int/float" + else: + assert isinstance( + windowsize, (float, int) + ), "windowsize should be a tuple of int/float or int/float." + windowsize = (windowsize, windowsize) + else: + windowsize = (0.0, 0.0) - assert isinstance( - feature, nap.Tsd - ), "Feature should be a Tsd or a TsdFrame with one column" + if ep is not None: + assert isinstance(ep, (nap.IntervalSet)), "ep should be an IntervalSet object." + else: + ep = feature.time_support binsize = nap.TsIndex.format_timestamps( np.array([binsize], dtype=np.float64), time_unit @@ -260,29 +268,51 @@ def compute_event_trigger_average( np.array([windowsize[1]], dtype=np.float64), time_unit )[0] ) + idx1 = -np.arange(0, start + binsize, binsize)[::-1][:-1] idx2 = np.arange(0, end + binsize, binsize)[1:] time_idx = np.hstack((idx1, np.zeros(1), idx2)) - count = group.count(binsize, ep) - - tmp = feature.bin_average(binsize, ep) - - # Check for any NaNs in feature - if np.any(np.isnan(tmp)): - tmp = tmp.dropna() - count = count.restrict(tmp.time_support) - - # Build the Hankel matrix - n_p = len(idx1) - n_f = len(idx2) - pad_tmp = np.pad(tmp, (n_p, n_f)) - offset_tmp = hankel(pad_tmp, pad_tmp[-(n_p + n_f + 1) :])[0 : len(tmp)] + eta = np.zeros((time_idx.shape[0], len(group), *feature.shape[1:])) - sta = np.dot(offset_tmp.T, count.values) + windows = np.array([len(idx1), len(idx2)]) - sta = sta / np.sum(count, 0) + # Bin the spike train + count = group.count(binsize, ep) - sta = nap.TsdFrame(t=time_idx, d=sta, columns=group.index) + time_array = np.round(count.index.values - (binsize / 2), 9) + count_array = count.values + starts = ep.start.values + ends = ep.end.values - return sta + time_target_array = feature.index.values + data_target_array = feature.values + + if data_target_array.ndim == 1: + eta = nap.jitted_functions.jitperievent_trigger_average( + time_array, + count_array, + time_target_array, + np.expand_dims(data_target_array, -1), + starts, + ends, + windows, + binsize, + ) + eta = np.squeeze(eta, -1) + else: + eta = nap.jitted_functions.jitperievent_trigger_average( + time_array, + count_array, + time_target_array, + data_target_array, + starts, + ends, + windows, + binsize, + ) + + if eta.ndim == 2: + return nap.TsdFrame(t=time_idx, d=eta, columns=group.index) + else: + return nap.TsdTensor(t=time_idx, d=eta) diff --git a/pynapple/process/tuning_curves.py b/pynapple/process/tuning_curves.py index a29bbee7..3311c119 100644 --- a/pynapple/process/tuning_curves.py +++ b/pynapple/process/tuning_curves.py @@ -1,10 +1,10 @@ # -*- coding: utf-8 -*- -"""Summary +""" """ # @Author: gviejo # @Date: 2022-01-02 23:33:42 -# @Last Modified by: gviejo -# @Last Modified time: 2023-11-10 14:20:44 +# @Last Modified by: Guillaume Viejo +# @Last Modified time: 2024-01-26 15:28:51 import warnings @@ -16,7 +16,7 @@ def compute_discrete_tuning_curves(group, dict_ep): """ - Compute discrete tuning curves of a TsGroup using a dictionnary of epochs. + Compute discrete tuning curves of a TsGroup using a dictionnary of epochs. The function returns a pandas DataFrame with each row being a key of the dictionnary of epochs and each column being a neurons. 
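A short sketch of the call pattern this docstring describes, assuming the stricter dict_ep validation added in the next hunk (every value must be an IntervalSet). The epochs and spike times mirror the existing tests and are illustrative only.

import numpy as np
import pynapple as nap

tsgroup = nap.TsGroup({0: nap.Ts(t=np.arange(0, 100))})
dict_ep = {
    "0": nap.IntervalSet(start=0, end=50),
    "1": nap.IntervalSet(start=50, end=100),
}

# One row per key of dict_ep, one column per unit of the TsGroup.
tc = nap.compute_discrete_tuning_curves(tsgroup, dict_ep)
print(tc)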
@@ -52,17 +52,19 @@ def compute_discrete_tuning_curves(group, dict_ep): RuntimeError If group is not a TsGroup object. """ - if not isinstance(group, nap.TsGroup): - raise RuntimeError("Unknown format for group") - + assert isinstance(group, nap.TsGroup), "group should be a TsGroup." + assert isinstance(dict_ep, dict), "dict_ep should be a dictionnary of IntervalSet" idx = np.sort(list(dict_ep.keys())) + for k in idx: + assert isinstance( + dict_ep[k], nap.IntervalSet + ), "dict_ep argument should contain only IntervalSet. Key {} in dict_ep is not an IntervalSet".format( + k + ) - tuning_curves = pd.DataFrame(index=idx, columns=list(group.keys()), data=0) + tuning_curves = pd.DataFrame(index=idx, columns=list(group.keys()), data=0.0) for k in dict_ep.keys(): - if not isinstance(dict_ep[k], nap.IntervalSet): - raise RuntimeError("Key {} in dict_ep is not an IntervalSet".format(k)) - for n in group.keys(): tuning_curves.loc[k, n] = float(len(group[n].restrict(dict_ep[k]))) @@ -79,7 +81,7 @@ def compute_1d_tuning_curves(group, feature, nb_bins, ep=None, minmax=None): ---------- group : TsGroup The group of Ts/Tsd for which the tuning curves will be computed - feature : Tsd + feature : Tsd (or TsdFrame with 1 column only) The 1-dimensional target feature (e.g. head-direction) nb_bins : int Number of bins in the tuning curve @@ -101,13 +103,27 @@ def compute_1d_tuning_curves(group, feature, nb_bins, ep=None, minmax=None): If group is not a TsGroup object. """ - if not isinstance(group, nap.TsGroup): - raise RuntimeError("Unknown format for group") + assert isinstance(group, nap.TsGroup), "group should be a TsGroup." + assert isinstance( + feature, (nap.Tsd, nap.TsdFrame) + ), "feature should be a Tsd (or TsdFrame with 1 column only)" + if isinstance(feature, nap.TsdFrame): + assert ( + feature.shape[1] == 1 + ), "feature should be a Tsd (or TsdFrame with 1 column only)" + assert isinstance(nb_bins, int) + + if ep is None: + ep = feature.time_support + else: + assert isinstance(ep, nap.IntervalSet), "ep should be an IntervalSet" if minmax is None: bins = np.linspace(np.min(feature), np.max(feature), nb_bins + 1) else: + assert isinstance(minmax, tuple), "minmax should be a tuple of boundaries" bins = np.linspace(minmax[0], minmax[1], nb_bins + 1) + idx = bins[0:-1] + np.diff(bins) / 2 tuning_curves = pd.DataFrame(index=idx, columns=list(group.keys())) @@ -162,16 +178,19 @@ def compute_2d_tuning_curves(group, feature, nb_bins, ep=None, minmax=None): If group is not a TsGroup object or if feature is not 2 columns only. """ - if feature.shape[1] != 2: - raise RuntimeError("feature should have 2 columns only.") - - if type(group) is not nap.TsGroup: - raise RuntimeError("Unknown format for group") - - if isinstance(ep, nap.IntervalSet): - feature = feature.restrict(ep) - else: + assert isinstance(group, nap.TsGroup), "group should be a TsGroup." + assert isinstance( + feature, nap.TsdFrame + ), "feature should be a TsdFrame with 2 columns" + if isinstance(feature, nap.TsdFrame): + assert feature.shape[1] == 2, "feature should have 2 columns only." 
+ assert isinstance(nb_bins, int) + + if ep is None: ep = feature.time_support + else: + assert isinstance(ep, nap.IntervalSet), "ep should be an IntervalSet" + feature = feature.restrict(ep) cols = list(feature.columns) groups_value = {} @@ -184,6 +203,7 @@ def compute_2d_tuning_curves(group, feature, nb_bins, ep=None, minmax=None): np.min(feature.loc[c]), np.max(feature.loc[c]), nb_bins + 1 ) else: + assert isinstance(minmax, tuple), "minmax should be a tuple of 4 elements" bins = np.linspace(minmax[i + i % 2], minmax[i + 1 + i % 2], nb_bins + 1) binsxy[c] = bins diff --git a/pyproject.toml b/pyproject.toml index d9863274..7cde031d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ repository = "https://github.com/pynapple-org/pynapple" ########################################################################## [project.optional-dependencies] dev = [ - "black", # Code formatter + "black>=24.1.0", # Code formatter "isort", # Import sorter "pip-tools", # Dependency management "pytest", # Testing framework diff --git a/test_fastplotlib.py b/test_fastplotlib.py new file mode 100644 index 00000000..36dae64b --- /dev/null +++ b/test_fastplotlib.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- +# @Author: Guillaume Viejo +# @Date: 2023-10-31 18:34:19 +# @Last Modified by: Guillaume Viejo +# @Last Modified time: 2023-10-31 19:24:48 + +import fastplotlib as fpl +import pynapple as nap +import numpy as np +import sys, os +sys.path.append(os.path.expanduser("~/fastplotlib-sfn2023")) +from _video import LazyVideo +from pathlib import Path +from ipywidgets import HBox + +behavior_path = Path('/mnt/home/gviejo/fastplotlib-sfn2023/sample_data/M238Slc17a7_Chr2/20170824') + +paths_side = sorted(behavior_path.glob("*side_v*.avi")) +paths_front = sorted(behavior_path.glob("*front_v*.avi")) + + +class Concat: + def __init__(self, files): + self.files = files + self.videos = [LazyVideo(p) for p in self.files] + self._nframes_per_video = [v.shape[0] for v in self.videos] + self._cumsum = np.cumsum(self._nframes_per_video) + self.nframes = sum(self._nframes_per_video) + self.shape = (self.nframes, self.videos[0].shape[1], self.videos[0].shape[2]) + self.ndim = 3 + + self.dtype = self.videos[0].dtype + + def __len__(self) -> int: + return self.nframes + + def _get_vid_ix_sub_ix(self, key): + vid_ix = np.searchsorted(self._cumsum, key) + if vid_ix != 0: + sub_ix = key - self._cumsum[vid_ix - 1] + else: + sub_ix = key + + return vid_ix, sub_ix + + def __getitem__(self, key)-> np.ndarray: + if isinstance(key, slice): + start, stop = key.start, key.stop + vid_ix, sub_ix0 = self._get_vid_ix_sub_ix(start) + vid_ix, sub_ix1 = self._get_vid_ix_sub_ix(stop) + return self.videos[vid_ix][sub_ix0:sub_ix1] + elif isinstance(key, int): + vid_ix, sub_ix0 = self._get_vid_ix_sub_ix(key) + return self.videos[vid_ix][sub_ix0] + + + +concat = Concat(paths_side) + +# print(concat.videos) + +t = np.linspace(0, concat.nframes / 500, concat.nframes) + +tsd_video = nap.TsdTensor(t, concat) + +v = LazyVideo(concat.files[0]) + +tsd = nap.TsdTensor(t=np.arange(0, len(v)), d=v) \ No newline at end of file diff --git a/tests/test_spike_trigger_average.py b/tests/test_spike_trigger_average.py index 38733d17..f422d255 100644 --- a/tests/test_spike_trigger_average.py +++ b/tests/test_spike_trigger_average.py @@ -2,7 +2,7 @@ # @Author: gviejo # @Date: 2022-08-29 17:27:02 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2023-12-12 18:10:30 +# @Last Modified time: 2024-01-25 11:39:01 #!/usr/bin/env python """Tests of spike 
trigger average for `pynapple` package.""" @@ -14,7 +14,7 @@ # from matplotlib.pyplot import * -def test_compute_spike_trigger_average(): +def test_compute_spike_trigger_average_tsd(): ep = nap.IntervalSet(0, 100) feature = nap.Tsd( t=np.arange(0, 101, 0.01), d=np.zeros(int(101 / 0.01)), time_support=ep @@ -37,12 +37,85 @@ def test_compute_spike_trigger_average(): assert sta.shape == output.shape np.testing.assert_array_almost_equal(sta, output) +def test_compute_spike_trigger_average_tsdframe(): + ep = nap.IntervalSet(0, 100) feature = nap.TsdFrame( - t=feature.index.values, d=feature.values[:,None], time_support=ep + t=np.arange(0, 101, 0.01), d=np.zeros((int(101 / 0.01),1)), time_support=ep + ) + t1 = np.arange(1, 100) + x = np.arange(100, 10000, 100) + feature[x] = 1.0 + spikes = nap.TsGroup( + {0: nap.Ts(t1), 1: nap.Ts(t1 - 0.1), 2: nap.Ts(t1 + 0.2)}, time_support=ep + ) + + sta = nap.compute_event_trigger_average(spikes, feature, 0.2, (0.6, 0.6), ep) + + output = np.zeros((7, 3)) + output[3, 0] = 0.05 + output[4, 1] = 0.05 + output[2, 2] = 0.05 + + assert isinstance(sta, nap.TsdTensor) + assert sta.shape == (*output.shape, 1) + np.testing.assert_array_almost_equal(sta, np.expand_dims(output, 2)) + +def test_compute_spike_trigger_average_tsdtensor(): + ep = nap.IntervalSet(0, 100) + feature = nap.TsdTensor( + t=np.arange(0, 101, 0.01), d=np.zeros((int(101 / 0.01),1,1)), time_support=ep + ) + t1 = np.arange(1, 100) + x = np.arange(100, 10000, 100) + feature[x] = 1.0 + spikes = nap.TsGroup( + {0: nap.Ts(t1), 1: nap.Ts(t1 - 0.1), 2: nap.Ts(t1 + 0.2)}, time_support=ep ) + sta = nap.compute_event_trigger_average(spikes, feature, 0.2, (0.6, 0.6), ep) + + output = np.zeros((7, 3, 1, 1)) + output[3, 0] = 0.05 + output[4, 1] = 0.05 + output[2, 2] = 0.05 + + assert isinstance(sta, nap.TsdTensor) + assert sta.shape == output.shape np.testing.assert_array_almost_equal(sta, output) +def test_compute_spike_trigger_average_random_feature(): + ep = nap.IntervalSet(0, 100) + feature = nap.Tsd( + t=np.arange(0, 100, 0.001), d=np.random.randn(100000), time_support=ep + ) + t1 = np.sort(np.random.uniform(0, 100, 1000)) + spikes = nap.TsGroup( + {0: nap.Ts(t1)}, time_support=ep + ) + + group = spikes + binsize = 0.1 + windowsize = (1.0, 1.0) + + sta = nap.compute_event_trigger_average(spikes, feature, binsize, windowsize, ep) + + start, end = windowsize + idx1 = -np.arange(0, start + binsize, binsize)[::-1][:-1] + idx2 = np.arange(0, end + binsize, binsize)[1:] + time_idx = np.hstack((idx1, np.zeros(1), idx2)) + count = group.count(binsize, ep) + tmp = feature.bin_average(binsize, ep) + from scipy.linalg import hankel + # Build the Hankel matrix + n_p = len(idx1) + n_f = len(idx2) + pad_tmp = np.pad(tmp.values, (n_p, n_f)) + offset_tmp = hankel(pad_tmp, pad_tmp[-(n_p + n_f + 1) :])[0 : len(tmp)] + sta2 = np.dot(offset_tmp.T, count.values) + sta2 = sta2 / np.sum(count.values, 0) + + np.testing.assert_array_almost_equal(sta.values, sta2) + def test_compute_spike_trigger_average_add_nan(): ep = nap.IntervalSet(0, 110) feature = nap.Tsd( @@ -75,23 +148,42 @@ def test_compute_spike_trigger_average_raise_error(): ) t1 = np.arange(1, 101) + 0.01 x = np.arange(100, 10000, 100)+1 - feature[x] = 1.0 + feature[x] = 1.0 + spikes = nap.TsGroup( + {0: nap.Ts(t1), 1: nap.Ts(t1 - 0.1), 2: nap.Ts(t1 + 0.2)}, time_support=ep + ) with pytest.raises(Exception) as e_info: nap.compute_event_trigger_average(feature, feature, 0.1, (0.5, 0.5), ep) assert str(e_info.value) == "group should be a TsGroup." 
- feature = nap.TsdFrame( - t=np.arange(0, 101, 0.01), d=np.random.rand(int(101 / 0.01), 3), time_support=ep - ) - spikes = nap.TsGroup( - {0: nap.Ts(t1), 1: nap.Ts(t1 - 0.1), 2: nap.Ts(t1 + 0.2)}, time_support=ep - ) with pytest.raises(Exception) as e_info: - nap.compute_event_trigger_average(spikes, feature, 0.1, (0.5, 0.5), ep) - assert str(e_info.value) == "Feature should be a Tsd or a TsdFrame with one column" + nap.compute_event_trigger_average(spikes, np.array(10), 0.1, (0.5, 0.5), ep) + assert str(e_info.value) == "Feature should be a Tsd, TsdFrame or TsdTensor" + with pytest.raises(Exception) as e_info: + nap.compute_event_trigger_average(spikes, feature, "0.1", (0.5, 0.5), ep) + assert str(e_info.value) == "binsize should be int or float." + with pytest.raises(Exception) as e_info: + nap.compute_event_trigger_average(spikes, feature, 0.1, (0.5, 0.5), ep, time_unit=1) + assert str(e_info.value) == "time_unit should be a str." + + with pytest.raises(Exception) as e_info: + nap.compute_event_trigger_average(spikes, feature, 0.1, (0.5, 0.5), ep, time_unit="a") + assert str(e_info.value) == "time_unit should be 's', 'ms' or 'us'" + + with pytest.raises(Exception) as e_info: + nap.compute_event_trigger_average(spikes, feature, 0.1, (0.5, 0.5, 0.5), ep) + assert str(e_info.value) == "windowsize should be a tuple of 2 elements (-t, +t)" + + with pytest.raises(Exception) as e_info: + nap.compute_event_trigger_average(spikes, feature, 0.1, ('a', 'b'), ep) + assert str(e_info.value) == "windowsize should be a tuple of int/float" + + with pytest.raises(Exception) as e_info: + nap.compute_event_trigger_average(spikes, feature, 0.1, (0.5, 0.5), [1,2,3]) + assert str(e_info.value) == "ep should be an IntervalSet object." def test_compute_spike_trigger_average_time_unit(): @@ -123,24 +215,66 @@ def test_compute_spike_trigger_average_time_unit(): assert sta.shape == output.shape np.testing.assert_array_almost_equal(sta.values, output) - -def test_compute_spike_trigger_average_multiple_epochs(): - ep = nap.IntervalSet(0, 101) +@pytest.mark.filterwarnings("ignore") +def test_compute_spike_trigger_average_no_windows(): + ep = nap.IntervalSet(0, 100) feature = pd.Series(index=np.arange(0, 101, 0.01), data=np.zeros(int(101 / 0.01))) - t1 = np.arange(1, 101) + t1 = np.arange(1, 100) feature.loc[t1] = 1.0 - spikes = nap.TsGroup({0: nap.Ts(t1)}, time_support=ep) - - ep2 = nap.IntervalSet(start=[0, 40], end=[10, 60]) + spikes = nap.TsGroup( + {0: nap.Ts(t1), 1: nap.Ts(t1 - 0.1), 2: nap.Ts(t1 + 0.2)}, time_support=ep + ) feature = nap.Tsd(feature, time_support=ep) - sta = nap.compute_event_trigger_average(spikes, feature, 0.1, (0.5, 0.5), ep2) + sta = nap.compute_event_trigger_average(spikes, feature, 0.2, ep=ep) - output = np.zeros(int((0.5 / 0.1) * 2 + 1)) - count = spikes[0].count(0.1, ep2).values - feat = feature.bin_average(0.1, ep2).values - output[5] = np.dot(count, feat)/count.sum() + output = np.zeros((1, 3)) + output[0, 0] = 0.05 assert isinstance(sta, nap.TsdFrame) - np.testing.assert_array_almost_equal(sta.values.flatten(), output) + assert sta.shape == output.shape + np.testing.assert_array_almost_equal(sta, output) + + +def test_compute_spike_trigger_average_multiple_epochs(): + ep = nap.IntervalSet(start = [0, 200], end=[100,300]) + feature = nap.Tsd( + t=np.hstack((np.arange(0, 100, 0.001), np.arange(200, 300, 0.001))), + d=np.hstack((np.random.randn(100000), np.random.randn(100000))), + time_support=ep + ) + t1 = np.hstack((np.sort(np.random.uniform(0, 100, 1000)), 
np.sort(np.random.uniform(200, 300, 1000)))) + spikes = nap.TsGroup( + {0: nap.Ts(t1)}, time_support=ep + ) + + group = spikes + binsize = 0.1 + windowsize = (1.0, 1.0) + + sta = nap.compute_event_trigger_average(spikes, feature, binsize, windowsize, ep) + + start, end = windowsize + idx1 = -np.arange(0, start + binsize, binsize)[::-1][:-1] + idx2 = np.arange(0, end + binsize, binsize)[1:] + time_idx = np.hstack((idx1, np.zeros(1), idx2)) + from scipy.linalg import hankel + n_p = len(idx1) + n_f = len(idx2) + + sta2 = [] + for i in range(2): + count = group.count(binsize, ep.loc[[i]]) + tmp = feature.bin_average(binsize, ep.loc[[i]]) + + # Build the Hankel matrix + pad_tmp = np.pad(tmp.values, (n_p, n_f)) + offset_tmp = hankel(pad_tmp, pad_tmp[-(n_p + n_f + 1) :])[0 : len(tmp)] + stai = np.dot(offset_tmp.T, count.values) + stai = stai / np.sum(count.values, 0) + sta2.append(stai) + + sta2 = np.hstack(sta2).mean(1) + + np.testing.assert_array_almost_equal(sta.values[:,0], sta2) \ No newline at end of file diff --git a/tests/test_tuning_curves.py b/tests/test_tuning_curves.py index 359bb174..730e5486 100644 --- a/tests/test_tuning_curves.py +++ b/tests/test_tuning_curves.py @@ -1,8 +1,8 @@ # -*- coding: utf-8 -*- # @Author: gviejo # @Date: 2022-03-30 11:16:30 -# @Last Modified by: gviejo -# @Last Modified time: 2023-11-16 12:26:48 +# @Last Modified by: Guillaume Viejo +# @Last Modified time: 2024-01-26 15:23:20 """Tests of tuning curves for `pynapple` package.""" @@ -36,18 +36,18 @@ def test_compute_discrete_tuning_curves_with_strings(): def test_compute_discrete_tuning_curves_error(): dict_ep = { "0":nap.IntervalSet(start=0, end=50), "1":nap.IntervalSet(start=50, end=100)} - with pytest.raises(RuntimeError) as e_info: + with pytest.raises(AssertionError) as e_info: nap.compute_discrete_tuning_curves([1,2,3], dict_ep) - assert str(e_info.value) == "Unknown format for group" + assert str(e_info.value) == "group should be a TsGroup." tsgroup = nap.TsGroup({0: nap.Ts(t=np.arange(0, 100))}) dict_ep = { "0":nap.IntervalSet(start=0, end=50), "1":nap.IntervalSet(start=50, end=100)} k = [1,2,3] dict_ep["2"] = k - with pytest.raises(RuntimeError) as e_info: + with pytest.raises(AssertionError) as e_info: nap.compute_discrete_tuning_curves(tsgroup, dict_ep) - assert str(e_info.value) == "Key 2 in dict_ep is not an IntervalSet" + assert str(e_info.value) == "dict_ep argument should contain only IntervalSet. Key 2 in dict_ep is not an IntervalSet" def test_compute_1d_tuning_curves(): tsgroup = nap.TsGroup({0: nap.Ts(t=np.arange(0, 100))}) @@ -61,9 +61,9 @@ def test_compute_1d_tuning_curves(): def test_compute_1d_tuning_curves_error(): feature = nap.Tsd(t=np.arange(0, 100, 0.1), d=np.arange(0, 100, 0.1) % 1.0) - with pytest.raises(RuntimeError) as e_info: + with pytest.raises(AssertionError) as e_info: nap.compute_1d_tuning_curves([1,2,3], feature, nb_bins=10) - assert str(e_info.value) == "Unknown format for group" + assert str(e_info.value) == "group should be a TsGroup." 
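For reference, a small sketch of the compute_1d_tuning_curves call these tests exercise, assuming the updated validation shown earlier in this diff: ep is optional and defaults to the feature's time support, nb_bins must be an int, and a one-column TsdFrame is accepted in place of a Tsd. The feature values are synthetic.

import numpy as np
import pynapple as nap

tsgroup = nap.TsGroup({0: nap.Ts(t=np.arange(0, 100))})
feature = nap.Tsd(t=np.arange(0, 100, 0.1), d=np.arange(0, 100, 0.1) % 1.0)

# ep omitted: the epoch falls back to feature.time_support.
tc = nap.compute_1d_tuning_curves(tsgroup, feature, nb_bins=10)
print(tc)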
def test_compute_1d_tuning_curves_with_ep(): tsgroup = nap.TsGroup({0: nap.Ts(t=np.arange(0, 100))}) @@ -110,15 +110,15 @@ def test_compute_2d_tuning_curves_error(): (np.repeat(np.arange(0, 100), 10), np.tile(np.arange(0, 100), 10)) ).T features = nap.TsdFrame(t=np.arange(0, 200, 0.1), d=np.vstack((tmp, tmp[::-1]))) - with pytest.raises(RuntimeError) as e_info: + with pytest.raises(AssertionError) as e_info: nap.compute_2d_tuning_curves([1,2,3], features, 10) - assert str(e_info.value) == "Unknown format for group" + assert str(e_info.value) == "group should be a TsGroup." tsgroup = nap.TsGroup( {0: nap.Ts(t=np.arange(0, 100, 10)), 1: nap.Ts(t=np.array([50, 149]))} ) features = nap.TsdFrame(t=np.arange(100), d=np.random.rand(100, 3)) - with pytest.raises(RuntimeError) as e_info: + with pytest.raises(AssertionError) as e_info: nap.compute_2d_tuning_curves(tsgroup, features, 10) assert str(e_info.value) == "feature should have 2 columns only."
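To close, a usage sketch of compute_2d_tuning_curves consistent with the assertions added above: the feature must be a TsdFrame with exactly two columns, nb_bins an int, and minmax, when given, a tuple of boundaries. The synthetic trajectory copies the one used in these tests; the unpacking into tuning curves plus bin edges follows the existing pynapple API.

import numpy as np
import pynapple as nap

tsgroup = nap.TsGroup(
    {0: nap.Ts(t=np.arange(0, 100, 10)), 1: nap.Ts(t=np.array([50, 149]))}
)
tmp = np.vstack(
    (np.repeat(np.arange(0, 100), 10), np.tile(np.arange(0, 100), 10))
).T
features = nap.TsdFrame(t=np.arange(0, 200, 0.1), d=np.vstack((tmp, tmp[::-1])))

# Two-column feature; ep defaults to features.time_support.
tc, binsxy = nap.compute_2d_tuning_curves(tsgroup, features, nb_bins=10)
print(tc[0].shape)  # expected: a (10, 10) map for unit 0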