Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Address join-compatibility issue for long chains #811

Merged
merged 12 commits into from
Jan 31, 2024
12 changes: 9 additions & 3 deletions src/spyglass/common/common_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from spyglass.common.common_lab import Institution, Lab, LabMember
from spyglass.common.common_nwbfile import Nwbfile
from spyglass.common.common_subject import Subject
from spyglass.settings import config, debug_mode
from spyglass.settings import config, debug_mode, test_mode
from spyglass.utils import SpyglassMixin, logger
from spyglass.utils.nwb_helper_fn import get_config, get_nwb_file

Expand Down Expand Up @@ -214,6 +214,8 @@ def add_session_to_group(
*,
skip_duplicates: bool = False,
):
if test_mode:
skip_duplicates = True
SessionGroupSession.insert1(
{
"session_group_name": session_group_name,
Expand All @@ -230,12 +232,16 @@ def remove_session_from_group(
"session_group_name": session_group_name,
"nwb_file_name": nwb_file_name,
}
(SessionGroupSession & query).delete(*args, **kwargs)
(SessionGroupSession & query).delete(
force_permission=test_mode, *args, **kwargs
)

@staticmethod
def delete_group(session_group_name: str, *args, **kwargs):
query = {"session_group_name": session_group_name}
(SessionGroup & query).delete(*args, **kwargs)
(SessionGroup & query).delete(
force_permission=test_mode, *args, **kwargs
)

@staticmethod
def get_group_sessions(session_group_name: str):
Expand Down
2 changes: 1 addition & 1 deletion src/spyglass/common/common_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ class CautiousDelete(dj.Manual):
dj_user: varchar(64)
duration: float
origin: varchar(64)
restriction: varchar(64)
restriction: varchar(255)
merge_deletes = null: blob
"""
101 changes: 88 additions & 13 deletions src/spyglass/utils/dj_chains.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import cached_property
from typing import List
from typing import List, Union

import datajoint as dj
import networkx as nx
Expand All @@ -16,6 +16,38 @@ class TableChains:

Functions as a plural version of TableChain, allowing a single `join`
call across all chains from parent -> Merge table.

Attributes
----------
parent : Table
Parent or origin of chains.
child : Table
Merge table or destination of chains.
connection : datajoint.Connection, optional
Connection to database used to create FreeTable objects. Defaults to
parent.connection.
part_names : List[str]
List of full table names of child parts.
chains : List[TableChain]
List of TableChain objects for each part in child.
has_link : bool
Cached attribute to store whether parent is linked to child via any of
child parts. False if (a) child is not in parent.descendants or (b)
nx.NetworkXNoPath is raised by nx.shortest_path for all chains.

Methods
-------
__init__(parent, child, connection=None)
Initialize TableChains with parent and child tables.
__repr__()
Return full representation of chains.
Multiline parent -> child for each chain.
__len__()
Return number of chains with links.
__getitem__(index: Union[int, str])
Return TableChain object at index, or use substring of table name.
join(restriction: str = None)
Return list of joins for each chain in self.chains.
"""

def __init__(self, parent, child, connection=None):
Expand All @@ -33,6 +65,14 @@ def __repr__(self):
def __len__(self):
return len([c for c in self.chains if c.has_link])

def __getitem__(self, index: Union[int, str]):
"""Return FreeTable object at index."""
if isinstance(index, str):
for i, part in enumerate(self.part_names):
if index in part:
return self.chains[i]
return self.chains[index]

def join(self, restriction=None) -> List[QueryExpression]:
"""Return list of joins for each chain in self.chains."""
restriction = restriction or self.parent.restriction or True
Expand Down Expand Up @@ -79,6 +119,8 @@ class TableChain:
Return full representation of chain: parent -> {links} -> child.
__len__()
Return number of tables in chain.
__getitem__(index: Union[int, str])
Return FreeTable object at index, or use substring of table name.
join(restriction: str = None)
Return join of tables in chain with restriction applied to parent.
"""
Expand All @@ -98,16 +140,14 @@ def __init__(self, parent: Table, child: Table, connection=None):
self.parent = parent
self.child = child
self._has_link = child.full_table_name in parent.descendants()
self._errors = []

def __str__(self):
"""Return string representation of chain: parent -> child."""
if not self._has_link:
return "No link"
return (
"Chain: "
+ self.parent.table_name
+ self._link_symbol
+ self.child.table_name
self.parent.table_name + self._link_symbol + self.child.table_name
)

def __repr__(self):
Expand All @@ -123,6 +163,14 @@ def __len__(self):
"""Return number of tables in chain."""
return len(self.names)

def __getitem__(self, index: Union[int, str]) -> dj.FreeTable:
"""Return FreeTable object at index."""
if isinstance(index, str):
for i, name in enumerate(self.names):
if index in name:
return self.objects[i]
return self.objects[index]

@property
def has_link(self) -> bool:
"""Return True if parent is linked to child.
Expand All @@ -132,20 +180,34 @@ def has_link(self) -> bool:
"""
return self._has_link

def pk_link(self, src, trg, data) -> float:
"""Return 1 if data["primary"] else float("inf").

Currently unused. Preserved for future debugging."""
return 1 if data["primary"] else float("inf")

@cached_property
def names(self) -> List[str]:
"""Return list of full table names in chain.

Uses networkx.shortest_path.
Uses networkx.shortest_path. Ignores numeric table names, which are
'gaps' or alias nodes in the graph. See datajoint.Diagram._make_graph
source code for comments on alias nodes.
"""
if not self._has_link:
return None
try:
return nx.shortest_path(
self.parent.connection.dependencies,
self.parent.full_table_name,
self.child.full_table_name,
)
return [
name
for name in nx.shortest_path(
self.parent.connection.dependencies,
self.parent.full_table_name,
self.child.full_table_name,
# weight: optional callable to determine edge weight
# weight=self.pk_link,
)
if not name.isdigit()
edeno marked this conversation as resolved.
Show resolved Hide resolved
]
except nx.NetworkXNoPath:
self._has_link = False
return None
Expand All @@ -159,10 +221,23 @@ def objects(self) -> List[dj.FreeTable]:
else None
)

def errors(self) -> List[str]:
"""Return list of errors for each table in chain."""
return self._errors

def join(self, restricton: str = None) -> dj.expression.QueryExpression:
"""Return join of tables in chain with restriction applied to parent."""
if not self._has_link:
return None
restriction = restricton or self.parent.restriction or True
join = self.objects[0] & restriction
for table in self.objects[1:]:
join = join * table
return join if join else None
try:
join = join.proj() * table
except dj.DataJointError as e:
attribute = str(e).split("attribute ")[-1]
logger.error(
f"{str(self)} at {table.table_name} with {attribute}"
)
return None
return join
30 changes: 22 additions & 8 deletions src/spyglass/utils/dj_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,12 @@ def _merge_chains(self) -> Dict[str, List[dj.FreeTable]]:
merge_chains[name] = chains
return merge_chains

def _get_chain(self, substring) -> TableChains:
"""Return chain from self to merge table with substring in name."""
for name, chain in self._merge_chains.items():
if substring.lower() in name:
return chain

def _commit_merge_deletes(
self, merge_join_dict: Dict[str, List[QueryExpression]], **kwargs
) -> None:
Expand Down Expand Up @@ -206,8 +212,9 @@ def delete_downstream_merge(

if not merge_join_dict and not disable_warning:
logger.warning(
f"No merge tables found downstream of {self.full_table_name}."
+ "\n\tIf this is unexpected, try running with `reload_cache`."
f"No merge deletes found w/ {self.table_name} & "
+ f"{restriction}.\n\tIf this is unexpected, try running with "
+ "`reload_cache`."
)

if dry_run:
Expand Down Expand Up @@ -352,7 +359,7 @@ def _usage_table(self):
"""Temporary inclusion for usage tracking."""
from spyglass.common.common_usage import CautiousDelete

return CautiousDelete
return CautiousDelete()

def _log_use(self, start, merge_deletes=None):
"""Log use of cautious_delete."""
Expand All @@ -361,7 +368,9 @@ def _log_use(self, start, merge_deletes=None):
duration=time() - start,
dj_user=dj.config["database.user"],
origin=self.full_table_name,
restriction=self.restriction,
restriction=(
str(self.restriction)[:255] if self.restriction else "None"
),
merge_deletes=merge_deletes,
)
)
Expand Down Expand Up @@ -419,10 +428,15 @@ def cautious_delete(self, force_permission: bool = False, *args, **kwargs):

self._log_use(start=start, merge_deletes=merge_deletes)

def cdel(self, *args, **kwargs):
def cdel(self, force_permission=False, *args, **kwargs):
"""Alias for cautious_delete."""
self.cautious_delete(*args, **kwargs)
self.cautious_delete(force_permission=force_permission, *args, **kwargs)

def delete(self, *args, **kwargs):
def delete(self, force_permission=False, *args, **kwargs):
"""Alias for cautious_delete, overwrites datajoint.table.Table.delete"""
self.cautious_delete(*args, **kwargs)
self.cautious_delete(force_permission=force_permission, *args, **kwargs)

def super_delete(self, *args, **kwargs):
"""Alias for datajoint.table.Table.delete."""
logger.warning("!! Using super_delete. Bypassing cautious_delete !!")
super().delete(*args, **kwargs)
9 changes: 8 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,11 +243,18 @@ def mini_closed(mini_path):

@pytest.fixture(autouse=True, scope="session")
def mini_insert(mini_path, teardown, server, dj_conn):
from spyglass.common import Nwbfile, Session # noqa: E402
from spyglass.common import LabMember, Nwbfile, Session # noqa: E402
from spyglass.data_import import insert_sessions # noqa: E402
from spyglass.spikesorting.merge import SpikeSortingOutput # noqa: E402
from spyglass.utils.nwb_helper_fn import close_nwb_files # noqa: E402

LabMember().insert1(
["Root User", "Root", "User"], skip_duplicates=not teardown
)
LabMember.LabMemberInfo().insert1(
["Root User", "email", "root", 1], skip_duplicates=not teardown
)

dj_logger.info("Inserting test data.")

if not server.connected:
Expand Down
Loading