Skip to content

Commit

Permalink
Minor fix
Browse files Browse the repository at this point in the history
  • Loading branch information
khl02007 committed Dec 8, 2023
1 parent 724c504 commit 1f91313
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 12 deletions.
15 changes: 8 additions & 7 deletions src/spyglass/spikesorting/v1/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,13 +315,14 @@ def _init_artifact_worker(
"""Create a local dict per worker"""
worker_ctx = {}
if isinstance(recording, dict):
worker_ctx["recording"] = si.load_extractor(recording)
return dict(recording=si.load_extractor(recording))
else:
worker_ctx["recording"] = recording
worker_ctx["zscore_thresh"] = zscore_thresh
worker_ctx["amplitude_thresh_uV"] = amplitude_thresh_uV
worker_ctx["proportion_above_thresh"] = proportion_above_thresh
return worker_ctx
return dict(
recording=recording,
zscore_thresh=zscore_thresh,
amplitude_thresh_uV=amplitude_thresh_uV,
proportion_above_thresh=proportion_above_thresh,
)


def _compute_artifact_chunk(segment_index, start_frame, end_frame, worker_ctx):
Expand All @@ -341,7 +342,7 @@ def _compute_artifact_chunk(segment_index, start_frame, end_frame, worker_ctx):
)

# find the artifact occurrences using one or both thresholds, across channels
if (amplitude_thresh_uV is not None) and (zscore_thresh is None):
if amplitude_thresh_uV and not zscore_thresh:
above_a = np.abs(traces) > amplitude_thresh_uV
above_thresh = (
np.ravel(np.argwhere(np.sum(above_a, axis=1) >= nelect_above))
Expand Down
10 changes: 5 additions & 5 deletions src/spyglass/spikesorting/v1/curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ def _list_to_merge_dict(
Example
-------
Input: [[1,2,3],[4,5]], [1,2,3,4,5,6]
Input: [[1,2,3],[4,5]], [1,2,3,4,5,6]
Output: {1: [2, 3], 2:[1,3], 3:[1,2] 4: [5], 5: [4], 6: []}
"""
merge_group_list = _union_intersecting_lists(merge_group_list)
Expand All @@ -449,10 +449,10 @@ def _list_to_merge_dict(


def _reverse_associations(assoc_dict):
return [
[key] + values if values else [key]
for key, values in assoc_dict.items()
]
return [
[key] + values if values else [key]
for key, values in assoc_dict.items()
]


def _merge_dict_to_list(merge_groups: dict) -> List:
Expand Down

0 comments on commit 1f91313

Please sign in to comment.