diff --git a/CHANGELOG.md b/CHANGELOG.md index da03d58c1..9af02f4e2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ - Refactor `TableChain` to include `_searched` attribute. #867 - Fix errors in config import #882 - Save current spyglass version in analysis nwb files to aid diagnosis #897 +- Add pynapple support #898 ### Pipelines diff --git a/src/spyglass/utils/dj_helper_fn.py b/src/spyglass/utils/dj_helper_fn.py index 4a0495778..c5fd82276 100644 --- a/src/spyglass/utils/dj_helper_fn.py +++ b/src/spyglass/utils/dj_helper_fn.py @@ -105,6 +105,49 @@ def dj_replace(original_table, new_values, key_column, replace_column): return original_table +def get_nwb_table(query_expression, tbl, attr_name, *attrs, **kwargs): + """Get the NWB file name and path from the given DataJoint query. + + Parameters + ---------- + query_expression : query + A DataJoint query expression (e.g., join, restrict) or a table to call fetch on. + tbl : table + DataJoint table to fetch from. + attr_name : str + Attribute name to fetch from the table. + *attrs : list + Attributes from normal DataJoint fetch call. + **kwargs : dict + Keyword arguments from normal DataJoint fetch call. + + Returns + ------- + nwb_files : list + List of NWB file names. + file_path_fn : function + Function to get the absolute path to the NWB file. + """ + from spyglass.common.common_nwbfile import AnalysisNwbfile, Nwbfile + + kwargs["as_dict"] = True # force return as dictionary + attrs = attrs or query_expression.heading.names # if none, all + + which = "analysis" if "analysis" in attr_name else "nwb" + tbl_map = { # map to file_name_str and file_path_fn + "analysis": ["analysis_file_name", AnalysisNwbfile.get_abs_path], + "nwb": ["nwb_file_name", Nwbfile.get_abs_path], + } + file_name_str, file_path_fn = tbl_map[which] + + # TODO: check that the query_expression restricts tbl - CBroz + nwb_files = ( + query_expression * tbl.proj(nwb2load_filepath=attr_name) + ).fetch(file_name_str) + + return nwb_files, file_path_fn + + def fetch_nwb(query_expression, nwb_master, *attrs, **kwargs): """Get an NWB object from the given DataJoint query. @@ -127,29 +170,10 @@ def fetch_nwb(query_expression, nwb_master, *attrs, **kwargs): nwb_objects : list List of dicts containing fetch results and NWB objects. """ - kwargs["as_dict"] = True # force return as dictionary tbl, attr_name = nwb_master - - if not attrs: - attrs = query_expression.heading.names - - # get the list of analysis or nwb files - file_name_str = ( - "analysis_file_name" if "analysis" in nwb_master[1] else "nwb_file_name" - ) - # TODO: avoid this import? - from ..common.common_nwbfile import AnalysisNwbfile, Nwbfile - - file_path_fn = ( - AnalysisNwbfile.get_abs_path - if "analysis" in nwb_master[1] - else Nwbfile.get_abs_path + nwb_files, file_path_fn = get_nwb_table( + query_expression, tbl, attr_name, *attrs, **kwargs ) - - # TODO: check that the query_expression restricts tbl - CBroz - nwb_files = ( - query_expression * tbl.proj(nwb2load_filepath=attr_name) - ).fetch(file_name_str) for file_name in nwb_files: file_path = file_path_fn(file_name) if not os.path.exists(file_path): diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index 29978ae88..082116bf6 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -13,10 +13,15 @@ from spyglass.utils.database_settings import SHARED_MODULES from spyglass.utils.dj_chains import TableChain, TableChains -from spyglass.utils.dj_helper_fn import fetch_nwb +from spyglass.utils.dj_helper_fn import fetch_nwb, get_nwb_table from spyglass.utils.dj_merge_tables import RESERVED_PRIMARY_KEY as MERGE_PK from spyglass.utils.logging import logger +try: + import pynapple # noqa F401 +except ImportError: + pynapple = None + class SpyglassMixin: """Mixin for Spyglass DataJoint tables. @@ -122,6 +127,43 @@ def fetch_nwb(self, *attrs, **kwargs): """ return fetch_nwb(self, self._nwb_table_tuple, *attrs, **kwargs) + def fetch_pynapple(self, *attrs, **kwargs): + """Get a pynapple object from the given DataJoint query. + + Parameters + ---------- + *attrs : list + Attributes from normal DataJoint fetch call. + **kwargs : dict + Keyword arguments from normal DataJoint fetch call. + + Returns + ------- + pynapple_objects : list of pynapple objects + List of dicts containing pynapple objects. + + Raises + ------ + ImportError + If pynapple is not installed. + + """ + if pynapple is None: + raise ImportError("Pynapple is not installed.") + + nwb_files, file_path_fn = get_nwb_table( + self, + self._nwb_table_tuple[0], + self._nwb_table_tuple[1], + *attrs, + **kwargs, + ) + + return [ + pynapple.load_file(file_path_fn(file_name)) + for file_name in nwb_files + ] + # ------------------------ delete_downstream_merge ------------------------ @cached_property