Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pynapple support #898

Merged
merged 6 commits into from
Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- Refactor `TableChain` to include `_searched` attribute. #867
- Fix errors in config import #882
- Save current spyglass version in analysis nwb files to aid diagnosis #897
- Add pynapple support #898

### Pipelines

Expand Down
66 changes: 45 additions & 21 deletions src/spyglass/utils/dj_helper_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,49 @@ def dj_replace(original_table, new_values, key_column, replace_column):
return original_table


def get_nwb_table(query_expression, tbl, attr_name, *attrs, **kwargs):
"""Get the NWB file name and path from the given DataJoint query.

Parameters
----------
query_expression : query
A DataJoint query expression (e.g., join, restrict) or a table to call fetch on.
tbl : table
DataJoint table to fetch from.
attr_name : str
Attribute name to fetch from the table.
*attrs : list
Attributes from normal DataJoint fetch call.
**kwargs : dict
Keyword arguments from normal DataJoint fetch call.

Returns
-------
nwb_files : list
List of NWB file names.
file_path_fn : function
Function to get the absolute path to the NWB file.
"""
from spyglass.common.common_nwbfile import AnalysisNwbfile, Nwbfile

kwargs["as_dict"] = True # force return as dictionary
attrs = attrs or query_expression.heading.names # if none, all

which = "analysis" if "analysis" in attr_name else "nwb"
tbl_map = { # map to file_name_str and file_path_fn
"analysis": ["analysis_file_name", AnalysisNwbfile.get_abs_path],
"nwb": ["nwb_file_name", Nwbfile.get_abs_path],
}
file_name_str, file_path_fn = tbl_map[which]

# TODO: check that the query_expression restricts tbl - CBroz
nwb_files = (
query_expression * tbl.proj(nwb2load_filepath=attr_name)
).fetch(file_name_str)

return nwb_files, file_path_fn


def fetch_nwb(query_expression, nwb_master, *attrs, **kwargs):
"""Get an NWB object from the given DataJoint query.

Expand All @@ -127,29 +170,10 @@ def fetch_nwb(query_expression, nwb_master, *attrs, **kwargs):
nwb_objects : list
List of dicts containing fetch results and NWB objects.
"""
kwargs["as_dict"] = True # force return as dictionary
tbl, attr_name = nwb_master

if not attrs:
attrs = query_expression.heading.names

# get the list of analysis or nwb files
file_name_str = (
"analysis_file_name" if "analysis" in nwb_master[1] else "nwb_file_name"
)
# TODO: avoid this import?
from ..common.common_nwbfile import AnalysisNwbfile, Nwbfile

file_path_fn = (
AnalysisNwbfile.get_abs_path
if "analysis" in nwb_master[1]
else Nwbfile.get_abs_path
nwb_files, file_path_fn = get_nwb_table(
query_expression, tbl, attr_name, *attrs, **kwargs
)

# TODO: check that the query_expression restricts tbl - CBroz
nwb_files = (
query_expression * tbl.proj(nwb2load_filepath=attr_name)
).fetch(file_name_str)
for file_name in nwb_files:
file_path = file_path_fn(file_name)
if not os.path.exists(file_path):
Expand Down
44 changes: 43 additions & 1 deletion src/spyglass/utils/dj_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,15 @@

from spyglass.utils.database_settings import SHARED_MODULES
from spyglass.utils.dj_chains import TableChain, TableChains
from spyglass.utils.dj_helper_fn import fetch_nwb
from spyglass.utils.dj_helper_fn import fetch_nwb, get_nwb_table
from spyglass.utils.dj_merge_tables import RESERVED_PRIMARY_KEY as MERGE_PK
from spyglass.utils.logging import logger

try:
import pynapple # noqa F401
except ImportError:
pynapple = None


class SpyglassMixin:
"""Mixin for Spyglass DataJoint tables.
Expand Down Expand Up @@ -122,6 +127,43 @@ def fetch_nwb(self, *attrs, **kwargs):
"""
return fetch_nwb(self, self._nwb_table_tuple, *attrs, **kwargs)

def fetch_pynapple(self, *attrs, **kwargs):
"""Get a pynapple object from the given DataJoint query.

Parameters
----------
*attrs : list
Attributes from normal DataJoint fetch call.
**kwargs : dict
Keyword arguments from normal DataJoint fetch call.

Returns
-------
pynapple_objects : list of pynapple objects
List of dicts containing pynapple objects.

Raises
------
ImportError
If pynapple is not installed.

"""
if pynapple is None:
raise ImportError("Pynapple is not installed.")

nwb_files, file_path_fn = get_nwb_table(
self,
self._nwb_table_tuple[0],
self._nwb_table_tuple[1],
Comment on lines +156 to +157
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this could be swapped out for the following, but it's fine as is

Suggested change
self._nwb_table_tuple[0],
self._nwb_table_tuple[1],
*self._nwb_table_tuple,

*attrs,
**kwargs,
)

return [
pynapple.load_file(file_path_fn(file_name))
for file_name in nwb_files
]

# ------------------------ delete_downstream_merge ------------------------

@cached_property
Expand Down
Loading