Skip to content

Commit

Permalink
Merge branch 'master' into pr/khl02007/852
Browse files Browse the repository at this point in the history
  • Loading branch information
khl02007 committed Mar 7, 2024
2 parents aa4d446 + 5a6c96b commit 613174c
Show file tree
Hide file tree
Showing 42 changed files with 1,326 additions and 242 deletions.
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
### Infrastructure

- Add user roles to `database_settings.py`. #832
- Revise `dj_chains` to permit undirected paths for paths with multiple Merge
Tables. #846

### Pipelines

Expand All @@ -14,7 +16,9 @@
- Fixes to `_convert_mp4` #834
- Replace deprecated calls to `yaml.safe_load()` #834

- Increase the required `spikeinterface` version to >=0.99.1 for `get_sorting` method associated with `SpikeSorting` and `CurationV1` tables in spike sorting V1 pipeline. Limit version to <0.100 in case there are other issues with it. #852
- Spikesorting:
- Increase the required `spikeinterface` version to >=0.99.1 for `get_sorting` method associated with `SpikeSorting` and `CurationV1` tables in spike sorting V1 pipeline. Limit version to <0.100 in case there are other issues with it. #852
- Bug fix in single artifact interval edge case #859

## [0.5.0] (February 9, 2024)

Expand Down
2 changes: 1 addition & 1 deletion config/add_dj_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
"This script is deprecated. "
+ "Use spyglass.utils.database_settings.DatabaseSettings instead."
)
DatabaseSettings(user_name=sys.argv[1]).add_dj_user()
DatabaseSettings(user_name=sys.argv[1]).add_user(check_exists=True)
8 changes: 5 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ addopts = [
# "--pdb", # drop into debugger on failure
"-p no:warnings",
# "--no-teardown", # don't teardown the database after tests
"--quiet-spy", # don't show logging from spyglass
# "--quiet-spy", # don't show logging from spyglass
"--show-capture=no",
"--pdbcls=IPython.terminal.debugger:TerminalPdb", # use ipython debugger
"--cov=spyglass",
Expand All @@ -146,13 +146,15 @@ omit = [ # which submodules have no tests
"*/decoding/*",
"*/figurl_views/*",
# "*/lfp/*",
"*/linearization/*",
# "*/linearization/*",
"*/lock/*",
"*/position/*",
"*/mua/*",
# "*/position/*",
"*/ripple/*",
"*/sharing/*",
"*/spikesorting/*",
# "*/utils/*",
"settings.py",
]

[tool.ruff] # CB: Propose replacing flake8 with ruff to delete setup.cfg
Expand Down
18 changes: 10 additions & 8 deletions src/spyglass/common/common_behav.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def populate(self, keys=None):
"""
if not isinstance(keys, list):
keys = [keys]
if isinstance(keys[0], dj.Table):
if isinstance(keys[0], (dj.Table, dj.expression.QueryExpression)):
keys = [k for tbl in keys for k in tbl.fetch("KEY", as_dict=True)]
for key in keys:
nwb_file_name = key.get("nwb_file_name")
Expand All @@ -60,10 +60,10 @@ def populate(self, keys=None):
"PositionSource.populate is an alias for a non-computed table "
+ "and must be passed a key with nwb_file_name"
)
self.insert_from_nwbfile(nwb_file_name)
self.insert_from_nwbfile(nwb_file_name, skip_duplicates=True)

@classmethod
def insert_from_nwbfile(cls, nwb_file_name):
def insert_from_nwbfile(cls, nwb_file_name, skip_duplicates=False) -> None:
"""Add intervals to ItervalList and PositionSource.
Given an NWB file name, get the spatial series and interval lists from
Expand Down Expand Up @@ -111,9 +111,11 @@ def insert_from_nwbfile(cls, nwb_file_name):
)

with cls.connection.transaction:
IntervalList.insert(intervals)
cls.insert(sources)
cls.SpatialSeries.insert(spat_series)
IntervalList.insert(intervals, skip_duplicates=skip_duplicates)
cls.insert(sources, skip_duplicates=skip_duplicates)
cls.SpatialSeries.insert(
spat_series, skip_duplicates=skip_duplicates
)

# make map from epoch intervals to position intervals
populate_position_interval_map_session(nwb_file_name)
Expand Down Expand Up @@ -305,7 +307,7 @@ def make(self, key):
"Unable to import StateScriptFile: no processing module named "
+ '"associated_files" found in {nwb_file_name}.'
)
return
return # See #849

for associated_file_obj in associated_files.data_interfaces.values():
if not isinstance(
Expand Down Expand Up @@ -545,7 +547,7 @@ def _no_transaction_make(self, key):
# Check that each pos interval was matched to only one epoch
if len(matching_pos_intervals) != 1:
# TODO: Now that populate_all accept errors, raise here?
logger.error(
logger.warning(
f"Found {len(matching_pos_intervals)} pos intervals for {key}; "
+ f"{no_pop_msg}\n{matching_pos_intervals}"
)
Expand Down
2 changes: 1 addition & 1 deletion src/spyglass/common/common_dio.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def make(self, key):
"No conforming behavioral events data interface found in "
+ f"{nwb_file_name}\n"
)
return
return # See #849

# Times for these events correspond to the valid times for the raw data
key["interval_list_name"] = (
Expand Down
4 changes: 2 additions & 2 deletions src/spyglass/common/common_ephys.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def create_from_config(cls, nwb_file_name: str):
nwbf = get_nwb_file(nwb_file_abspath)
config = get_config(nwb_file_abspath)
if "Electrode" not in config:
return
return # See #849

# map electrode id to dictof electrode information from config YAML
electrode_dicts = {
Expand Down Expand Up @@ -341,7 +341,7 @@ def make(self, key):
"Unable to import SampleCount: no data interface named "
+ f'"sample_count" found in {nwb_file_name}.'
)
return
return # see #849
key["sample_count_object_id"] = sample_count.object_id
self.insert1(key)

Expand Down
24 changes: 16 additions & 8 deletions src/spyglass/common/signal_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,22 @@


def hilbert_decomp(lfp_band_object, sampling_rate=1):
"""generates the analytical decomposition of the signals in the lfp_band_object
:param lfp_band_object: bandpass filtered LFP
:type lfp_band_object: pynwb electrical series
:param sampling_rate: bandpass filtered LFP sampling rate (defaults to 1; only used for instantaneous frequency)
:type sampling_rate: int
:return: envelope, phase, frequency
:rtype: pynwb electrical series objects
"""Generates analytical decomposition of signals in the lfp_band_object
NOTE: This function is not currently used in the pipeline.
Parameters
----------
lfp_band_object : pynwb.ecephys.ElectricalSeries
bandpass filtered LFP
sampling_rate : int, optional
bandpass filtered LFP sampling rate
(defaults to 1; only used for instantaneous frequency)
Returns
-------
envelope : pynwb.ecephys.ElectricalSeries
envelope of the signal
"""
analytical_signal = signal.hilbert(lfp_band_object.data, axis=0)

Expand Down
44 changes: 21 additions & 23 deletions src/spyglass/lfp/analysis/v1/lfp_band.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,26 +80,24 @@ def set_lfp_band_electrodes(
available_electrodes = query.fetch("electrode_id")
if not np.all(np.isin(electrode_list, available_electrodes)):
raise ValueError(
"All elements in electrode_list must be valid electrode_ids in the LFPElectodeGroup table"
"All elements in electrode_list must be valid electrode_ids in"
+ " the LFPElectodeGroup table: "
+ f"{electrode_list} not in {available_electrodes}"
)
# sampling rate
lfp_sampling_rate = LFPOutput.merge_get_parent(lfp_key).fetch1(
"lfp_sampling_rate"
)
decimation = lfp_sampling_rate // lfp_band_sampling_rate
if lfp_sampling_rate // decimation != lfp_band_sampling_rate:
raise ValueError(
f"lfp_band_sampling rate {lfp_band_sampling_rate} is not an integer divisor of lfp "
f"samping rate {lfp_sampling_rate}"
)
# filter
filter_query = FirFilterParameters() & {
"filter_name": filter_name,
"filter_sampling_rate": lfp_sampling_rate,
}
if not filter_query:
raise ValueError(
f"filter {filter_name}, sampling rate {lfp_sampling_rate} is not in the FirFilterParameters table"
f"Filter {filter_name}, sampling rate {lfp_sampling_rate} is "
+ "not in the FirFilterParameters table"
)
# interval_list
interval_query = IntervalList() & {
Expand All @@ -108,22 +106,23 @@ def set_lfp_band_electrodes(
}
if not interval_query:
raise ValueError(
f"interval list {interval_list_name} is not in the IntervalList table; the list must be "
"added before this function is called"
f"interval list {interval_list_name} is not in the IntervalList"
" table; the list must be added before this function is called"
)
# reference_electrode_list
if len(reference_electrode_list) != 1 and len(
reference_electrode_list
) != len(electrode_list):
raise ValueError(
"reference_electrode_list must contain either 1 or len(electrode_list) elements"
"reference_electrode_list must contain either 1 or "
+ "len(electrode_list) elements"
)
# add a -1 element to the list to allow for the no reference option
available_electrodes = np.append(available_electrodes, [-1])
if not np.all(np.isin(reference_electrode_list, available_electrodes)):
raise ValueError(
"All elements in reference_electrode_list must be valid electrode_ids in the LFPSelection "
"table"
"All elements in reference_electrode_list must be valid "
"electrode_ids in the LFPSelection table"
)

# make a list of all the references
Expand Down Expand Up @@ -204,8 +203,8 @@ def make(self, key):
"interval_list_name": interval_list_name,
}
).fetch1("valid_times")
# the valid_times for this interval may be slightly beyond the valid times for the lfp itself,
# so we have to intersect the two lists
# the valid_times for this interval may be slightly beyond the valid
# times for the lfp itself, so we have to intersect the two lists
lfp_valid_times = (
IntervalList()
& {
Expand All @@ -228,7 +227,8 @@ def make(self, key):

# load in the timestamps
timestamps = np.asarray(lfp_object.timestamps)
# get the indices of the first timestamp and the last timestamp that are within the valid times
# get the indices of the first timestamp and the last timestamp that
# are within the valid times
included_indices = interval_list_contains_ind(
lfp_band_valid_times, timestamps
)
Expand Down Expand Up @@ -267,11 +267,6 @@ def make(self, key):
& {"filter_name": filter_name}
& {"filter_sampling_rate": filter_sampling_rate}
).fetch(as_dict=True)
if len(filter) == 0:
raise ValueError(
f"Filter {filter_name} and sampling_rate {lfp_band_sampling_rate} does not exit in the "
"FirFilterParameters table"
)

filter_coeff = filter[0]["filter_coeff"]
if len(filter_coeff) == 0:
Expand Down Expand Up @@ -378,7 +373,9 @@ def fetch1_dataframe(self, *attrs, **kwargs):
)

def compute_analytic_signal(self, electrode_list: list[int], **kwargs):
"""Computes the hilbert transform of a given LFPBand signal using scipy.signal.hilbert
"""Computes the hilbert transform of a given LFPBand signal
Uses scipy.signal.hilbert to compute the hilbert transform of the signal
Parameters
----------
Expand All @@ -393,7 +390,7 @@ def compute_analytic_signal(self, electrode_list: list[int], **kwargs):
Raises
------
ValueError
If any electrodes passed to electrode_list are invalid for the dataset
If items in electrode_list are invalid for the dataset
"""

filtered_band = self.fetch_nwb()[0]["lfp_band"]
Expand All @@ -402,7 +399,8 @@ def compute_analytic_signal(self, electrode_list: list[int], **kwargs):
)
if len(electrode_list) != np.sum(electrode_index):
raise ValueError(
"Some of the electrodes specified in electrode_list are missing in the current LFPBand table."
"Some of the electrodes specified in electrode_list are missing"
+ " in the current LFPBand table."
)
analytic_signal_df = pd.DataFrame(
hilbert(filtered_band.data[:, electrode_index], axis=0),
Expand Down
8 changes: 4 additions & 4 deletions src/spyglass/lfp/lfp_imported.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import datajoint as dj

from spyglass.common.common_interval import IntervalList
from spyglass.common.common_nwbfile import AnalysisNwbfile
from spyglass.common.common_session import Session
from spyglass.lfp.lfp_electrode import LFPElectrodeGroup
from spyglass.common.common_interval import IntervalList # noqa: F401
from spyglass.common.common_nwbfile import AnalysisNwbfile # noqa: F401
from spyglass.common.common_session import Session # noqa: F401
from spyglass.lfp.lfp_electrode import LFPElectrodeGroup # noqa: F401
from spyglass.utils.dj_mixin import SpyglassMixin

schema = dj.schema("lfp_imported")
Expand Down
7 changes: 4 additions & 3 deletions src/spyglass/lfp/v1/lfp.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def make(self, key):
"target_sampling_rate"
)

# to get the list of valid times, we need to combine those from the user with those from the
# raw data
# to get the list of valid times, we need to combine those from the
# user with those from the raw data
orig_key = copy.deepcopy(key)
orig_key["interval_list_name"] = key["target_interval_list_name"]
user_valid_times = (IntervalList() & orig_key).fetch1("valid_times")
Expand Down Expand Up @@ -120,7 +120,8 @@ def make(self, key):
"LFP: no filter found with data sampling rate of "
+ f"{sampling_rate}"
)
return None
return None # See #849

# get the list of selected LFP Channels from LFPElectrode
electrode_keys = (LFPElectrodeGroup.LFPElectrode & key).fetch("KEY")
electrode_id_list = list(k["electrode_id"] for k in electrode_keys)
Expand Down
4 changes: 3 additions & 1 deletion src/spyglass/lfp/v1/lfp_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@
class LFPArtifactDetectionParameters(SpyglassMixin, dj.Manual):
definition = """
# Parameters for detecting LFP artifact times within a LFP group.
artifact_params_name: varchar(200)
artifact_params_name: varchar(64)
---
artifact_params: blob # dictionary of parameters
"""

# See #630, #664. Excessive key length.

def insert_default(self):
"""Insert the default artifact parameters."""
diff_params = [
Expand Down
Loading

0 comments on commit 613174c

Please sign in to comment.