Skip to content

Commit

Permalink
Add pynapple support (#898)
Browse files Browse the repository at this point in the history
* Preliminary code

* Add retrieval of file names

* Add get_nwb_table function

* Update docstrings

* Update CHANGELOG.md
  • Loading branch information
edeno authored Mar 29, 2024
1 parent 309bde5 commit 65745f8
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 22 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- Refactor `TableChain` to include `_searched` attribute. #867
- Fix errors in config import #882
- Save current spyglass version in analysis nwb files to aid diagnosis #897
- Add pynapple support #898

### Pipelines

Expand Down
66 changes: 45 additions & 21 deletions src/spyglass/utils/dj_helper_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,49 @@ def dj_replace(original_table, new_values, key_column, replace_column):
return original_table


def get_nwb_table(query_expression, tbl, attr_name, *attrs, **kwargs):
"""Get the NWB file name and path from the given DataJoint query.
Parameters
----------
query_expression : query
A DataJoint query expression (e.g., join, restrict) or a table to call fetch on.
tbl : table
DataJoint table to fetch from.
attr_name : str
Attribute name to fetch from the table.
*attrs : list
Attributes from normal DataJoint fetch call.
**kwargs : dict
Keyword arguments from normal DataJoint fetch call.
Returns
-------
nwb_files : list
List of NWB file names.
file_path_fn : function
Function to get the absolute path to the NWB file.
"""
from spyglass.common.common_nwbfile import AnalysisNwbfile, Nwbfile

kwargs["as_dict"] = True # force return as dictionary
attrs = attrs or query_expression.heading.names # if none, all

which = "analysis" if "analysis" in attr_name else "nwb"
tbl_map = { # map to file_name_str and file_path_fn
"analysis": ["analysis_file_name", AnalysisNwbfile.get_abs_path],
"nwb": ["nwb_file_name", Nwbfile.get_abs_path],
}
file_name_str, file_path_fn = tbl_map[which]

# TODO: check that the query_expression restricts tbl - CBroz
nwb_files = (
query_expression * tbl.proj(nwb2load_filepath=attr_name)
).fetch(file_name_str)

return nwb_files, file_path_fn


def fetch_nwb(query_expression, nwb_master, *attrs, **kwargs):
"""Get an NWB object from the given DataJoint query.
Expand All @@ -127,29 +170,10 @@ def fetch_nwb(query_expression, nwb_master, *attrs, **kwargs):
nwb_objects : list
List of dicts containing fetch results and NWB objects.
"""
kwargs["as_dict"] = True # force return as dictionary
tbl, attr_name = nwb_master

if not attrs:
attrs = query_expression.heading.names

# get the list of analysis or nwb files
file_name_str = (
"analysis_file_name" if "analysis" in nwb_master[1] else "nwb_file_name"
)
# TODO: avoid this import?
from ..common.common_nwbfile import AnalysisNwbfile, Nwbfile

file_path_fn = (
AnalysisNwbfile.get_abs_path
if "analysis" in nwb_master[1]
else Nwbfile.get_abs_path
nwb_files, file_path_fn = get_nwb_table(
query_expression, tbl, attr_name, *attrs, **kwargs
)

# TODO: check that the query_expression restricts tbl - CBroz
nwb_files = (
query_expression * tbl.proj(nwb2load_filepath=attr_name)
).fetch(file_name_str)
for file_name in nwb_files:
file_path = file_path_fn(file_name)
if not os.path.exists(file_path):
Expand Down
44 changes: 43 additions & 1 deletion src/spyglass/utils/dj_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,15 @@

from spyglass.utils.database_settings import SHARED_MODULES
from spyglass.utils.dj_chains import TableChain, TableChains
from spyglass.utils.dj_helper_fn import fetch_nwb
from spyglass.utils.dj_helper_fn import fetch_nwb, get_nwb_table
from spyglass.utils.dj_merge_tables import RESERVED_PRIMARY_KEY as MERGE_PK
from spyglass.utils.logging import logger

try:
import pynapple # noqa F401
except ImportError:
pynapple = None


class SpyglassMixin:
"""Mixin for Spyglass DataJoint tables.
Expand Down Expand Up @@ -122,6 +127,43 @@ def fetch_nwb(self, *attrs, **kwargs):
"""
return fetch_nwb(self, self._nwb_table_tuple, *attrs, **kwargs)

def fetch_pynapple(self, *attrs, **kwargs):
"""Get a pynapple object from the given DataJoint query.
Parameters
----------
*attrs : list
Attributes from normal DataJoint fetch call.
**kwargs : dict
Keyword arguments from normal DataJoint fetch call.
Returns
-------
pynapple_objects : list of pynapple objects
List of dicts containing pynapple objects.
Raises
------
ImportError
If pynapple is not installed.
"""
if pynapple is None:
raise ImportError("Pynapple is not installed.")

nwb_files, file_path_fn = get_nwb_table(
self,
self._nwb_table_tuple[0],
self._nwb_table_tuple[1],
*attrs,
**kwargs,
)

return [
pynapple.load_file(file_path_fn(file_name))
for file_name in nwb_files
]

# ------------------------ delete_downstream_merge ------------------------

@cached_property
Expand Down

0 comments on commit 65745f8

Please sign in to comment.