Skip to content

Commit

Permalink
Revert "Add intervallist to exclude artifacts"
Browse files Browse the repository at this point in the history
This reverts commit 205874a.
  • Loading branch information
edeno committed Jan 30, 2024
1 parent 274a89d commit e102646
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 19 deletions.
2 changes: 1 addition & 1 deletion src/spyglass/ripple/v1/ripple.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ class RippleTimesV1(SpyglassMixin, dj.Computed):
-> RippleLFPSelection
-> RippleParameters
-> PositionOutput.proj(pos_merge_id='merge_id')
-> IntervalList.proj(artifact_interval_list_name='interval_list_name')
---
-> AnalysisNwbfile
ripple_times_object_id : varchar(40)
Expand Down
19 changes: 1 addition & 18 deletions src/spyglass/spikesorting/analysis/v1/mua.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import numpy as np
from ripple_detection import multiunit_HSE_detector

from spyglass.common.common_interval import IntervalList
from spyglass.common.common_nwbfile import AnalysisNwbfile
from spyglass.position import PositionOutput # noqa: F401
from spyglass.spikesorting.analysis.v1.group import (
Expand Down Expand Up @@ -45,7 +44,7 @@ class MuaEventsV1(SpyglassMixin, dj.Computed):
-> MuaEventsParameters
-> SortedSpikesGroup
-> PositionOutput.proj(pos_merge_id='merge_id')
-> IntervalList.proj(artifact_interval_list_name='interval_list_name') # exclude artifact times
---
-> AnalysisNwbfile
mua_times_object_id : varchar(40)
Expand All @@ -71,22 +70,6 @@ def make(self, key):

mua_params = (MuaEventsParameters & key).fetch1("mua_param_dict")

# Exclude artifact times
# Alternatively could set to NaN and leave them out of the firing rate calculation
# in the multiunit_HSE_detector function
artifact_key = {
"nwb_file_name": key["nwb_file_name"],
"interval_list_name": key["artifact_interval_list_name"],
}
artifact_times = (IntervalList & artifact_key).fetch1("valid_times")
mean_n_spikes = np.mean(spike_indicator)
for artifact_time in artifact_times:
spike_indicator[
np.logical_and(
time >= artifact_time.start, time <= artifact_time.stop
)
] = mean_n_spikes

mua_times = multiunit_HSE_detector(
time, spike_indicator, speed, sampling_frequency, **mua_params
)
Expand Down

0 comments on commit e102646

Please sign in to comment.