Skip to content

Commit

Permalink
Address failing tests from delete overwrite
Browse files Browse the repository at this point in the history
  • Loading branch information
CBroz1 committed Jan 31, 2024
1 parent fa6d70e commit 51137c0
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 17 deletions.
12 changes: 9 additions & 3 deletions src/spyglass/common/common_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from spyglass.common.common_lab import Institution, Lab, LabMember
from spyglass.common.common_nwbfile import Nwbfile
from spyglass.common.common_subject import Subject
from spyglass.settings import config, debug_mode
from spyglass.settings import config, debug_mode, test_mode
from spyglass.utils import SpyglassMixin, logger
from spyglass.utils.nwb_helper_fn import get_config, get_nwb_file

Expand Down Expand Up @@ -214,6 +214,8 @@ def add_session_to_group(
*,
skip_duplicates: bool = False,
):
if test_mode:
skip_duplicates = True
SessionGroupSession.insert1(
{
"session_group_name": session_group_name,
Expand All @@ -230,12 +232,16 @@ def remove_session_from_group(
"session_group_name": session_group_name,
"nwb_file_name": nwb_file_name,
}
(SessionGroupSession & query).delete(*args, **kwargs)
(SessionGroupSession & query).delete(
force_permission=test_mode, *args, **kwargs
)

@staticmethod
def delete_group(session_group_name: str, *args, **kwargs):
query = {"session_group_name": session_group_name}
(SessionGroup & query).delete(*args, **kwargs)
(SessionGroup & query).delete(
force_permission=test_mode, *args, **kwargs
)

@staticmethod
def get_group_sessions(session_group_name: str):
Expand Down
2 changes: 1 addition & 1 deletion src/spyglass/common/common_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ class CautiousDelete(dj.Manual):
dj_user: varchar(64)
duration: float
origin: varchar(64)
restriction: varchar(64)
restriction: varchar(255)
merge_deletes = null: blob
"""
4 changes: 2 additions & 2 deletions src/spyglass/utils/dj_chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from datajoint.table import Table
from datajoint.utils import get_master

from spyglass.utils.database_settings import SHARED_MODULES
from spyglass.utils.dj_merge_tables import RESERVED_PRIMARY_KEY as MERGE_PK
from spyglass.utils.logging import logger

Expand Down Expand Up @@ -66,7 +65,7 @@ def __repr__(self):
def __len__(self):
return len([c for c in self.chains if c.has_link])

def __getitem__(self, index: Union[int, str]) -> TableChain:
def __getitem__(self, index: Union[int, str]):
"""Return FreeTable object at index."""
if isinstance(index, str):
for i, part in enumerate(self.part_names):
Expand Down Expand Up @@ -234,6 +233,7 @@ def join(self, restricton: str = None) -> dj.expression.QueryExpression:
try:
join = join.proj() * table
except dj.DataJointError as e:
attribute = str(e).split("attribute ")[-1]
logger.error(
f"{str(self)} at {table.table_name} with {attribute}"
)
Expand Down
28 changes: 18 additions & 10 deletions src/spyglass/utils/dj_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,10 @@ def _nwb_table_tuple(self) -> tuple:
Used to determine fetch_nwb behavior. Also used in Merge.fetch_nwb.
Implemented as a cached_property to avoid circular imports."""
from spyglass.common.common_nwbfile import AnalysisNwbfile, Nwbfile # noqa F401
from spyglass.common.common_nwbfile import (
AnalysisNwbfile,
Nwbfile,
) # noqa F401

table_dict = {
AnalysisNwbfile: "analysis_file_abs_path",
Expand All @@ -71,9 +74,7 @@ def _nwb_table_tuple(self) -> tuple:
resolved = getattr(self, "_nwb_table", None) or (
AnalysisNwbfile
if "-> AnalysisNwbfile" in self.definition
else Nwbfile
if "-> Nwbfile" in self.definition
else None
else Nwbfile if "-> Nwbfile" in self.definition else None
)

if not resolved:
Expand Down Expand Up @@ -358,7 +359,7 @@ def _usage_table(self):
"""Temporary inclusion for usage tracking."""
from spyglass.common.common_usage import CautiousDelete

return CautiousDelete
return CautiousDelete()

def _log_use(self, start, merge_deletes=None):
"""Log use of cautious_delete."""
Expand All @@ -367,7 +368,9 @@ def _log_use(self, start, merge_deletes=None):
duration=time() - start,
dj_user=dj.config["database.user"],
origin=self.full_table_name,
restriction=self.restriction,
restriction=(
str(self.restriction)[:255] if self.restriction else "None"
),
merge_deletes=merge_deletes,
)
)
Expand Down Expand Up @@ -425,10 +428,15 @@ def cautious_delete(self, force_permission: bool = False, *args, **kwargs):

self._log_use(start=start, merge_deletes=merge_deletes)

def cdel(self, *args, **kwargs):
def cdel(self, force_permission=False, *args, **kwargs):
"""Alias for cautious_delete."""
self.cautious_delete(*args, **kwargs)
self.cautious_delete(force_permission=force_permission, *args, **kwargs)

def delete(self, *args, **kwargs):
def delete(self, force_permission=False, *args, **kwargs):
"""Alias for cautious_delete, overwrites datajoint.table.Table.delete"""
self.cautious_delete(*args, **kwargs)
self.cautious_delete(force_permission=force_permission, *args, **kwargs)

def super_delete(self, *args, **kwargs):
"""Alias for datajoint.table.Table.delete."""
logger.warning("!! Using super_delete. Bypassing cautious_delete !!")
super().delete(*args, **kwargs)
9 changes: 8 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,11 +243,18 @@ def mini_closed(mini_path):

@pytest.fixture(autouse=True, scope="session")
def mini_insert(mini_path, teardown, server, dj_conn):
from spyglass.common import Nwbfile, Session # noqa: E402
from spyglass.common import LabMember, Nwbfile, Session # noqa: E402
from spyglass.data_import import insert_sessions # noqa: E402
from spyglass.spikesorting.merge import SpikeSortingOutput # noqa: E402
from spyglass.utils.nwb_helper_fn import close_nwb_files # noqa: E402

LabMember().insert1(
["Root User", "Root", "User"], skip_duplicates=not teardown
)
LabMember.LabMemberInfo().insert1(
["Root User", "email", "root", 1], skip_duplicates=not teardown
)

dj_logger.info("Inserting test data.")

if not server.connected:
Expand Down

0 comments on commit 51137c0

Please sign in to comment.