Skip to content

Commit

Permalink
Add figurl for ripple for debugging
Browse files Browse the repository at this point in the history
  • Loading branch information
edeno committed Feb 7, 2024
1 parent d4cf821 commit 63d6abd
Showing 1 changed file with 105 additions and 15 deletions.
120 changes: 105 additions & 15 deletions src/spyglass/ripple/v1/ripple.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sortingview.views as vv
from ripple_detection import Karlsson_ripple_detector, Kay_ripple_detector
from ripple_detection.core import gaussian_smooth, get_envelope
from scipy.stats import zscore

from spyglass.common.common_interval import (
IntervalList,
interval_list_intersect,
)
from spyglass.common.common_nwbfile import AnalysisNwbfile
from spyglass.lfp.analysis.v1.lfp_band import LFPBandSelection, LFPBandV1
from spyglass.lfp.lfp_merge import LFPOutput
from spyglass.position import PositionOutput
from spyglass.utils import SpyglassMixin, logger
from spyglass.utils.nwb_helper_fn import get_electrode_indices
Expand Down Expand Up @@ -155,9 +158,7 @@ class RippleTimesV1(SpyglassMixin, dj.Computed):
"""

def make(self, key):
nwb_file_name, interval_list_name = (LFPBandV1 & key).fetch1(
"nwb_file_name", "target_interval_list_name"
)
nwb_file_name = (LFPBandV1 & key).fetch1("nwb_file_name")

logger.info(f"Computing ripple times for: {key}")
ripple_params = (
Expand All @@ -171,9 +172,7 @@ def make(self, key):
speed,
interval_ripple_lfps,
sampling_frequency,
) = self.get_ripple_lfps_and_position_info(
key, nwb_file_name, interval_list_name
)
) = self.get_ripple_lfps_and_position_info(key)
ripple_times = RIPPLE_DETECTION_ALGORITHMS[ripple_detection_algorithm](
time=np.asarray(interval_ripple_lfps.index),
filtered_lfps=np.asarray(interval_ripple_lfps),
Expand Down Expand Up @@ -203,9 +202,7 @@ def fetch_dataframe(self):
return [data["ripple_times"] for data in self.fetch_nwb()]

@staticmethod
def get_ripple_lfps_and_position_info(
key, nwb_file_name, interval_list_name
):
def get_ripple_lfps_and_position_info(key):
ripple_params = (
RippleParameters & {"ripple_param_name": key["ripple_param_name"]}
).fetch1("ripple_param_dict")
Expand All @@ -219,7 +216,7 @@ def get_ripple_lfps_and_position_info(
# warn/validate that there is only one wire per electrode
ripple_lfp_nwb = (LFPBandV1 & key).fetch_nwb()[0]
ripple_lfp_electrodes = ripple_lfp_nwb["lfp_band"].electrodes.data[:]
elec_mask = np.full_like(ripple_lfp_electrodes, 0, dtype=bool)
elec_mask = np.zeros_like(ripple_lfp_electrodes, dtype=bool)
valid_elecs = [
elec for elec in electrode_keys if elec in ripple_lfp_electrodes
]
Expand All @@ -230,16 +227,14 @@ def get_ripple_lfps_and_position_info(
ripple_lfp = pd.DataFrame(
ripple_lfp_nwb["lfp_band"].data,
index=pd.Index(ripple_lfp_nwb["lfp_band"].timestamps, name="time"),
)
).loc[:, elec_mask]
sampling_frequency = ripple_lfp_nwb["lfp_band_sampling_rate"]

ripple_lfp = ripple_lfp.loc[:, elec_mask]

position_valid_times = (
IntervalList
& {
"nwb_file_name": nwb_file_name,
"interval_list_name": interval_list_name,
"nwb_file_name": key["nwb_file_name"],
"interval_list_name": key["target_interval_list_name"],
}
).fetch1("valid_times")
position_info = (
Expand Down Expand Up @@ -366,3 +361,98 @@ def plot_ripple(
)
ax.set_ylabel("LFPs")
ax.set_xlabel("Time [s]")

def create_figurl(
self,
zscore_ripple=False,
ripple_times_color="red",
consensus_color="black",
speed_color="black",
view_height=800,
use_ripple_filtered_lfps=False,
lfp_offset=1,
):

ripple_times = self.fetch1_dataframe()
key = self.fetch1("KEY")
(
speed,
ripple_filtered_lfps,
_,
) = self.get_ripple_lfps_and_position_info(key)

if zscore_ripple:
ripple_consensus_trace = zscore(ripple_consensus_trace)

consensus_view = vv.TimeseriesGraph()
_add_ripple_times(consensus_view, ripple_times, ripple_times_color)
consensus_name = (
"Z-Scored Consensus Trace" if zscore_ripple else "Consensus Trace"
)
consensus_view.add_line_series(
name=consensus_name,
t=np.asarray(ripple_consensus_trace.index).squeeze(),
y=np.asarray(ripple_consensus_trace, dtype=np.float32).squeeze(),
color=consensus_color,
width=1,
)

if use_ripple_filtered_lfps:
interval_ripple_lfps = ripple_filtered_lfps
else:
lfp_merge_id = (LFPBandSelection & key).fetch1("lfp_merge_id")
lfp_df = (LFPOutput & {"merge_id": lfp_merge_id}).fetch1_dataframe()
interval_ripple_lfps = lfp_df.loc[speed.index[0] : speed.index[-1]]

lfp_view = vv.TimeseriesGraph()
_add_ripple_times(lfp_view, ripple_times, ripple_times_color)
max_lfp_value = interval_ripple_lfps.to_numpy().max()
lfp_offset *= max_lfp_value

for i, lfp in enumerate(interval_ripple_lfps.to_numpy().T):
lfp_view.add_line_series(
name=f"LFP {i}",
t=np.asarray(interval_ripple_lfps.index).squeeze(),
y=np.asarray(lfp + lfp_offset * i, dtype=np.int16).squeeze(),
color="black",
width=1,
)

speed_view = vv.TimeseriesGraph().add_line_series(
name="Speed [cm/s]",
t=np.asarray(speed.index).squeeze(),
y=np.asarray(speed, dtype=np.float32).squeeze(),
color=speed_color,
width=1,
)
vertical_panel_content = [
vv.LayoutItem(consensus_view, stretch=2, title="Consensus"),
vv.LayoutItem(lfp_view, stretch=8, title="LFPs"),
vv.LayoutItem(speed_view, stretch=2, title="Speed"),
]

view = vv.Box(
direction="horizontal",
show_titles=True,
height=view_height,
items=[
vv.LayoutItem(
vv.Box(
direction="vertical",
show_titles=True,
items=vertical_panel_content,
)
),
],
)

return view.url(label="Ripple Detection")


def _add_ripple_times(view, ripple_times, ripple_times_color):
view.add_interval_series(
name="Ripple Events",
t_start=ripple_times.start_time.to_numpy(),
t_end=ripple_times.end_time.to_numpy(),
color=ripple_times_color,
)

0 comments on commit 63d6abd

Please sign in to comment.