Skip to content

Commit

Permalink
Merge pull request #224 from pynapple-org/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
gviejo authored Jan 26, 2024
2 parents 9638d73 + 3df418e commit 73424b3
Show file tree
Hide file tree
Showing 11 changed files with 562 additions and 121 deletions.
99 changes: 99 additions & 0 deletions draft_pynapple_fastplotlib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# -*- coding: utf-8 -*-
"""
Fastplotlib
===========
Working with calcium data.
For the example dataset, we will be working with a recording of a freely-moving mouse imaged with a Miniscope (1-photon imaging). The area recorded for this experiment is the postsubiculum - a region that is known to contain head-direction cells, or cells that fire when the animal's head is pointing in a specific direction.
The NWB file for the example is hosted on [OSF](https://osf.io/sbnaw). We show below how to stream it.
See the [documentation](https://pynapple-org.github.io/pynapple/) of Pynapple for instructions on installing the package.
This tutorial was made by Sofia Skromne Carrasco and Guillaume Viejo.
"""
# %%
# !!! warning
# This tutorial uses seaborn and matplotlib for displaying the figure
#
# You can install all with `pip install matplotlib seaborn tqdm`
#
# mkdocs_gallery_thumbnail_number = 1
#
# Now, import the necessary libraries:

# %qui qt

import pynapple as nap
import numpy as np
import fastplotlib as fpl

import imageio.v3 as iio
import sys
# mkdocs_gallery_thumbnail_path = '../_static/fastplotlib_demo.png'

#nwb = nap.load_file("/Users/gviejo/pynapple/Mouse32-220101.nwb")
nwb = nap.load_file("your/path/to/MyProject/sub-A2929/ses-A2929-200711/pynapplenwb/A2929-200711.nwb")

units = nwb['units']#.getby_category("location")['adn']

tmp = units.to_tsd()

tmp = np.vstack((tmp.index.values, tmp.values)).T

# Example 1

fplot = fpl.Plot()

fplot.add_scatter(tmp)

fplot.graphics[0].cmap = "jet"

fplot.graphics[0].cmap.values = tmp[:, 1]

fplot.show(maintain_aspect=False)

# Example 2

names = [['raster'], ['position']]

grid_plot = fpl.GridPlot(shape=(2, 1), controller_ids="sync", names = names)

grid_plot['raster'].add_scatter(tmp)

grid_plot['position'].add_line(np.vstack((nwb['ry'].t, nwb['ry'].d)).T)

grid_plot.show(maintain_aspect=False)

grid_plot['raster'].auto_scale(maintain_aspect=False)


# Example 3
#frames = iio.imread("/Users/gviejo/pynapple/A0670-221213_filtered.avi")
#frames = frames[:,:,:,0]
frames = np.random.randn(10, 100, 100)

iw = fpl.ImageWidget(frames, cmap="gnuplot2")

#iw.show()

# Example 4

from PyQt6 import QtWidgets


mainwidget = QtWidgets.QWidget()

hlayout = QtWidgets.QHBoxLayout(mainwidget)

iw.widget.setParent(mainwidget)

hlayout.addWidget(iw.widget)

grid_plot.widget.setParent(mainwidget)

hlayout.addWidget(grid_plot.widget)

mainwidget.show()
114 changes: 101 additions & 13 deletions pynapple/core/jitted_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# @Author: guillaume
# @Date: 2022-10-31 16:44:31
# @Last Modified by: Guillaume Viejo
# @Last Modified time: 2023-12-12 16:50:36
# @Last Modified time: 2024-01-25 16:43:34
import numpy as np
from numba import jit, njit, prange

Expand Down Expand Up @@ -866,7 +866,7 @@ def jitcontinuous_perievent(
(np.sum(windowsize) + 1, np.sum(count[:, 1]), *data_array.shape[1:]), np.nan
)

if np.all((count[:, 0] * count[:, 1]) > 0):
if np.any((count[:, 0] * count[:, 1]) > 0):
for k in prange(N_epochs):
if count[k, 0] > 0 and count[k, 1] > 0:
t = start_t[k, 0]
Expand All @@ -891,9 +891,9 @@ def jitcontinuous_perievent(
left = np.minimum(windowsize[0], t_pos - start_t[k, 0])
right = np.minimum(windowsize[1], maxt - t_pos - 1)
center = windowsize[0] + 1
new_data_array[
center - left - 1 : center + right, cnt_i
] = data_array[t_pos - left : t_pos + right + 1]
new_data_array[center - left - 1 : center + right, cnt_i] = (
data_array[t_pos - left : t_pos + right + 1]
)

t -= 1
i += 1
Expand All @@ -902,15 +902,103 @@ def jitcontinuous_perievent(
return new_data_array


# time_array = tsd.t
# time_target_array = tref.t
# data_array = tsd.d
@jit(nopython=True)
def jitperievent_trigger_average(
time_array,
count_array,
time_target_array,
data_target_array,
starts,
ends,
windows,
binsize,
):
T = time_array.shape[0]
N = count_array.shape[1]
N_epochs = len(starts)

time_target_array, data_target_array, count = jitrestrict_with_count(
time_target_array, data_target_array, starts, ends
)
max_count = np.cumsum(count)

new_data_array = np.full(
(int(windows.sum()) + 1, count_array.shape[1], *data_target_array.shape[1:]),
0.0,
)

t = 0 # count events

hankel_array = np.zeros((new_data_array.shape[0], *data_target_array.shape[1:]))

for k in range(N_epochs):
if count[k] > 0:
t_start = t
maxi = max_count[k]
i = maxi - count[k]

# for i,t in enumerate(tref.restrict(ep).t):
# plot(time_idx + t, new_data_array[:,i]+i*2.0, 'o')
# plot(tsd + i*2.0, color='grey')
# [axvspan(ep.loc[i,'start'], ep.loc[i,'end'], alpha=0.3) for i in range(len(ep))]
# [axvline(t) for t in tref.restrict(ep).t]
while t < T:
lbound = time_array[t]
rbound = np.round(lbound + binsize, 9)

if time_target_array[i] < rbound:
i_start = i
i_stop = i

while i_stop < maxi:
if time_target_array[i_stop] < rbound:
i_stop += 1
else:
break

while i_start < i_stop - 1:
if time_target_array[i_start] < lbound:
i_start += 1
else:
break
v = np.sum(data_target_array[i_start:i_stop], 0) / float(
i_stop - i_start
)

checknan = np.sum(v)
if not np.isnan(checknan):
hankel_array[-1] = v

if t - t_start >= windows[1]:
for n in range(N):
new_data_array[:, n] += (
hankel_array * count_array[t - windows[1], n]
)

# hankel_array = np.roll(hankel_array, -1, axis=0)
hankel_array[0:-1] = hankel_array[1:]
hankel_array[-1] = 0.0

t += 1

i = i_start

if t == T or time_array[t] > ends[k]:
if t - t_start > windows[1]:
for j in range(windows[1]):
for n in range(N):
new_data_array[:, n] += (
hankel_array * count_array[t - windows[1] + j, n]
)

# hankel_array = np.roll(hankel_array, -1, axis=0)
hankel_array[0:-1] = hankel_array[1:]
hankel_array[-1] = 0.0

hankel_array *= 0.0
break

total = np.sum(count_array, 0)
for n in range(N):
if total[n] > 0.0:
new_data_array[:, n] /= total[n]

return new_data_array


# @jit(nopython=True)
Expand Down
2 changes: 1 addition & 1 deletion pynapple/core/time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# @Author: gviejo
# @Date: 2022-01-27 18:33:31
# @Last Modified by: Guillaume Viejo
# @Last Modified time: 2023-12-07 13:58:06
# @Last Modified time: 2024-01-08 16:09:01

"""
Expand Down
1 change: 1 addition & 0 deletions pynapple/core/time_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
- 's': seconds (overall default)
"""

from warnings import warn

import numpy as np
Expand Down
1 change: 1 addition & 0 deletions pynapple/io/phy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
@author: Sara Mahallati, Guillaume Viejo
"""

import os

import numpy as np
Expand Down
Loading

0 comments on commit 73424b3

Please sign in to comment.