Commit

Fix changes

khl02007 committed Dec 9, 2023
1 parent 7003c07 commit b292e8c

Showing 6 changed files with 83 additions and 44 deletions.
48 changes: 43 additions & 5 deletions notebooks/10_Spike_SortingV1.ipynb

@@ -718,6 +718,14 @@
     "sgs.CurationV1()"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "9ff6aff5-7020-40d6-832f-006d66d54a7e",
+   "metadata": {},
+   "source": [
+    "We now insert the curated spike sorting to a `Merge` table for feeding into downstream processing pipelines."
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -731,33 +739,63 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "a59b78e2-b76a-4958-b012-708ed5600be6",
+   "id": "5047f866-7435-4dea-9ed8-a9b2d8365682",
    "metadata": {},
    "outputs": [],
    "source": [
-    "SpikeSortingOutput.insert_from_source(key, source=\"CurationV1\")"
+    "SpikeSortingOutput()"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "582a40d9-ffb0-401c-ae9d-287336e47911",
+   "id": "d2702410-01e1-4af0-a987-891c42c6c099",
    "metadata": {},
    "outputs": [],
    "source": [
-    "SpikeSortingOutput()"
+    "key"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b20c2c9e-0c97-4669-b45d-4b1c50fd2fcc",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "SpikeSortingOutput.insert([key], part_name='CurationV1')"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "55b2b015-e9a3-4ab1-b2bc-4925bd8e9438",
+   "id": "184c3401-8df3-46f0-9dd0-c9fa98395c34",
    "metadata": {},
    "outputs": [],
    "source": [
     "SpikeSortingOutput.merge_view()"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e2b083a5-b700-438a-8a06-2e2eb041072d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "SpikeSortingOutput.CurationV1()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "10b8afa1-d4a6-4ac1-959b-f4e84e582f2e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "SpikeSortingOutput.CuratedSpikeSorting()"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
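The cells above replace the removed `insert_from_source` method with the generic merge-table `insert`. A minimal sketch of the updated workflow (not part of the diff), assuming `key` is the primary-key dict of an existing `CurationV1` entry:

# Sketch only: mirrors the notebook cells in the diff above.
from spyglass.spikesorting.merge import SpikeSortingOutput

SpikeSortingOutput.insert([key], part_name="CurationV1")  # route the entry via the CurationV1 part table
SpikeSortingOutput.merge_view()            # combined view across all source pipelines
SpikeSortingOutput.CurationV1()            # entries sourced from the v1 pipeline
SpikeSortingOutput.CuratedSpikeSorting()   # entries sourced from the legacy pipeline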
2 changes: 1 addition & 1 deletion src/spyglass/spikesorting/merge.py

@@ -15,7 +15,7 @@
 @schema
 class SpikeSortingOutput(_Merge):
     definition = """
-    # Output of spike sorting pipelines. Use `insert_from_source` method to insert rows.
+    # Output of spike sorting pipelines.
     merge_id: uuid
     ---
     source: varchar(32)
26 changes: 13 additions & 13 deletions src/spyglass/spikesorting/v1/__init__.py

@@ -1,22 +1,22 @@
-from .recording import (
-    SortGroup,
-    SpikeSortingPreprocessingParameters,
-    SpikeSortingRecordingSelection,
-    SpikeSortingRecording,
-)
 from .artifact import (
+    ArtifactDetection,
     ArtifactDetectionParameters,
     ArtifactDetectionSelection,
-    ArtifactDetection,
 )
-from .sorting import SpikeSorterParameters, SpikeSortingSelection, SpikeSorting
 from .curation import CurationV1
-from .figurl_curation import FigURLCuration, FigURLCurationSelection
 from .metric_curation import (
-    WaveformParameters,
-    MetricParameters,
+    MetricCuration,
     MetricCurationParameters,
     MetricCurationSelection,
-    MetricCuration,
+    MetricParameters,
+    WaveformParameters,
 )
-from .utils import get_spiking_v1_merge_ids
+from .recording import (
+    SortGroup,
+    SpikeSortingPreprocessingParameters,
+    SpikeSortingRecording,
+    SpikeSortingRecordingSelection,
+)
+from .figurl_curation import FigURLCurationSelection, FigURLCuration
+from .sorting import SpikeSorterParameters, SpikeSorting, SpikeSortingSelection
+from .utils import get_spiking_sorting_v1_merge_ids
35 changes: 18 additions & 17 deletions src/spyglass/spikesorting/v1/artifact.py

@@ -261,7 +261,7 @@ def _get_artifact_times(
     # turn ms to remove total into s to remove from either side of each detected artifact
     half_removal_window_s = removal_window_ms / 2 / 1000

-    if not artifact_frames:
+    if len(artifact_frames) == 0:
         recording_interval = np.asarray(
             [[valid_timestamps[0], valid_timestamps[-1]]]
         )
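The new `len(...) == 0` check matters because `artifact_frames` is a NumPy array: `not arr` raises a ValueError once the array holds more than one element. A standalone illustration (not from the repository):

import numpy as np

artifact_frames = np.array([120, 480])
try:
    if not artifact_frames:  # the old check
        pass
except ValueError as err:
    print(err)  # the truth value of an array with more than one element is ambiguous

if len(artifact_frames) == 0:  # the new check is valid for any length
    print("no artifacts detected")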
@@ -312,16 +312,16 @@ def _init_artifact_worker(
     amplitude_thresh_uV=None,
     proportion_above_thresh=1.0,
 ):
-    """Create a local dict per worker"""
+    # create a local dict per worker
+    worker_ctx = {}
     if isinstance(recording, dict):
-        return dict(recording=si.load_extractor(recording))
+        worker_ctx["recording"] = si.load_extractor(recording)
     else:
-        return dict(
-            recording=recording,
-            zscore_thresh=zscore_thresh,
-            amplitude_thresh_uV=amplitude_thresh_uV,
-            proportion_above_thresh=proportion_above_thresh,
-        )
+        worker_ctx["recording"] = recording
+        worker_ctx["zscore_thresh"] = zscore_thresh
+        worker_ctx["amplitude_thresh_uV"] = amplitude_thresh_uV
+        worker_ctx["proportion_above_thresh"] = proportion_above_thresh
+    return worker_ctx


 def _compute_artifact_chunk(segment_index, start_frame, end_frame, worker_ctx):
@@ -341,7 +341,7 @@ def _compute_artifact_chunk(segment_index, start_frame, end_frame, worker_ctx):
     )

     # find the artifact occurrences using one or both thresholds, across channels
-    if amplitude_thresh_uV and not zscore_thresh:
+    if (amplitude_thresh_uV is not None) and (zscore_thresh is None):
         above_a = np.abs(traces) > amplitude_thresh_uV
         above_thresh = (
             np.ravel(np.argwhere(np.sum(above_a, axis=1) >= nelect_above))
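The explicit `is not None` comparison also fixes a falsy-zero bug: an amplitude threshold of `0.0` is falsy, so the old check silently ignored it. A standalone illustration:

amplitude_thresh_uV, zscore_thresh = 0.0, None

if amplitude_thresh_uV and not zscore_thresh:  # old check
    print("old check fires")                   # never printed for 0.0

if (amplitude_thresh_uV is not None) and (zscore_thresh is None):  # new check
    print("new check fires")                   # printed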
@@ -393,13 +393,14 @@ def _check_artifact_thresholds(
         ValueError: if signal thresholds are negative
     """
     # amplitude or zscore thresholds should be negative, as they are applied to an absolute signal
-    def is_negative(value):
-        return value < 0 if value is not None else False
-
-    if is_negative(amplitude_thresh_uV) or is_negative(zscore_thresh):
-        raise ValueError(
-            "Amplitude and Z-Score thresholds must be >= 0, or None"
-        )
+    signal_thresholds = [
+        t for t in [amplitude_thresh_uV, zscore_thresh] if t is not None
+    ]
+    for t in signal_thresholds:
+        if t < 0:
+            raise ValueError(
+                "Amplitude and Z-Score thresholds must be >= 0, or None"
+            )

     # proportion_above_threshold should be in [0:1] inclusive
     if proportion_above_thresh < 0:
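The rewritten validation treats `None` as a disabled threshold and rejects only negative values. A self-contained sketch of the same logic (the `check_thresholds` name is hypothetical):

def check_thresholds(amplitude_thresh_uV=None, zscore_thresh=None):
    # None disables a threshold; any supplied threshold must be >= 0
    for t in (amplitude_thresh_uV, zscore_thresh):
        if t is not None and t < 0:
            raise ValueError("Amplitude and Z-Score thresholds must be >= 0, or None")

check_thresholds(amplitude_thresh_uV=50.0)  # passes
check_thresholds()                          # passes: both thresholds disabled
check_thresholds(zscore_thresh=-2.0)        # raises ValueError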
15 changes: 8 additions & 7 deletions src/spyglass/spikesorting/v1/curation.py

@@ -112,7 +112,7 @@ def insert_curation(
             "parent_curation_id": parent_curation_id,
             "analysis_file_name": analysis_file_name,
             "object_id": object_id,
-            "merges_applied": str(apply_merge),
+            "merges_applied": apply_merge,
             "description": description,
         }
         cls.insert1(
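Dropping `str()` keeps `merges_applied` a genuine boolean: `str(False)` is the non-empty string "False", which evaluates as truthy when read back (assuming the column is declared as a bool in the table definition). A quick illustration:

apply_merge = False
print(bool(str(apply_merge)))  # True  -- the old bug: "False" is a non-empty string
print(bool(apply_merge))       # False -- what the fix preserves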
@@ -168,7 +168,7 @@ def get_recording(cls, key: dict) -> si.BaseRecording:
             primary key of CurationV1 table
         """

-        analysis_file_abs_path = (
+        analysis_file_name = (
             SpikeSortingRecording * SpikeSortingSelection & key
         ).fetch1("analysis_file_name")
         analysis_file_abs_path = AnalysisNwbfile.get_abs_path(
@@ -343,11 +343,12 @@ def _write_sorting_to_nwb_with_curation(
         )
         # add labels, merge groups, metrics
         if labels is not None:
-            label_values = [
-                labels.get(unit_id, [])
-                for unit_id in unit_ids
-                if unit_id in labels
-            ]
+            label_values = []
+            for unit_id in unit_ids:
+                if unit_id not in labels:
+                    label_values.append([])
+                else:
+                    label_values.append(labels[unit_id])
             nwbf.add_unit_column(
                 name="curation_label",
                 description="curation label",
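The old comprehension paired `labels.get(unit_id, [])` with an `if unit_id in labels` filter, so unlabeled units were dropped entirely and `label_values` could fall out of alignment with `unit_ids`. The new loop pads them with empty lists. A standalone demonstration:

unit_ids = [0, 1, 2]
labels = {0: ["accept"], 2: ["mua"]}

old = [labels.get(u, []) for u in unit_ids if u in labels]
print(old)  # [['accept'], ['mua']] -- only two entries for three units

new = []
for u in unit_ids:
    new.append(labels[u] if u in labels else [])
print(new)  # [['accept'], [], ['mua']] -- one entry per unit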
1 change: 0 additions & 1 deletion src/spyglass/spikesorting/v1/recording.py

@@ -575,7 +575,6 @@ def _consolidate_intervals(intervals, timestamps):

     # Loop through the rest of the intervals to join them if needed
     for next_start, next_stop in zip(start_indices, stop_indices):
-
         # If the stop time of the current interval is equal to or greater than the next start time minus 1
         if stop >= next_start - 1:
             stop = max(
