
Commit

Remove delete_downstream. Update tests
CBroz1 committed Aug 27, 2024
1 parent 03b87ce commit 2a5bc74
Showing 8 changed files with 102 additions and 290 deletions.
2 changes: 1 addition & 1 deletion src/spyglass/utils/dj_merge_tables.py
@@ -29,7 +29,7 @@ def is_merge_table(table):
def trim_def(definition):
return re_sub(
r"\n\s*\n", "\n", re_sub(r"#.*\n", "\n", definition.strip())
)
).replace(" ", "")

if isinstance(table, str):
table = dj.FreeTable(dj.conn(), table)
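For context, trim_def now also strips spaces, so two table definitions compare equal regardless of comments, blank lines, or whitespace. A minimal standalone sketch of the updated helper (the example definition string is illustrative, not taken from Spyglass):

# Sketch of the normalization performed by the updated trim_def: drop comments,
# collapse blank lines, then remove all spaces before comparing definitions.
from re import sub as re_sub

def trim_def(definition: str) -> str:
    return re_sub(
        r"\n\s*\n", "\n", re_sub(r"#.*\n", "\n", definition.strip())
    ).replace(" ", "")

print(trim_def("merge_id : uuid  # merge key\n\nsource : varchar(32)"))
# merge_id:uuid
# source:varchar(32)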
232 changes: 25 additions & 207 deletions src/spyglass/utils/dj_mixin.py
@@ -6,16 +6,14 @@
from os import environ
from re import match as re_match
from time import time
from typing import Dict, List, Union
from typing import List

import datajoint as dj
from datajoint.condition import make_condition
from datajoint.errors import DataJointError
from datajoint.expression import QueryExpression
from datajoint.logging import logger as dj_logger
from datajoint.table import Table
from datajoint.utils import get_master, to_camel_case, user_choice
from networkx import NetworkXError
from datajoint.utils import to_camel_case
from packaging.version import parse as version_parse
from pandas import DataFrame
from pymysql.err import DataError
@@ -52,14 +50,6 @@ class SpyglassMixin:
Fetch NWBFile object from relevant table. Uses either a foreign key to
a NWBFile table (including AnalysisNwbfile) or a _nwb_table attribute to
determine which table to use.
delete_downstream_merge(restriction=None, dry_run=True, reload_cache=False)
Delete downstream merge table entries associated with restriction.
Requires caching of merge tables and links, which is slow on first call.
`restriction` can be set to a string to restrict the delete. `dry_run`
can be set to False to commit the delete. `reload_cache` can be set to
True to reload the merge cache.
ddp(*args, **kwargs)
Alias for delete_downstream_parts
cautious_delete(force_permission=False, *args, **kwargs)
Check user permissions before deleting table rows. Permission is granted
to users listed as admin in LabMember table or to users on a team with
@@ -68,8 +58,6 @@ class SpyglassMixin:
delete continues. If the Session has no experimenter, or if the user is
not on a team with the Session experimenter(s), a PermissionError is
raised. `force_permission` can be set to True to bypass permission check.
cdel(*args, **kwargs)
Alias for cautious_delete.
"""

# _nwb_table = None # NWBFile table class, defined at the table level
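As a usage sketch of the fetch_nwb helper summarized in this docstring (the table and key below are placeholders, not taken from this diff):

# Hypothetical use of SpyglassMixin.fetch_nwb on a restricted table; Raw is
# assumed to be a Spyglass table with a foreign key to Nwbfile, and the key
# is a placeholder.
from spyglass.common import Raw

nwb_data = (Raw & {"nwb_file_name": "example_.nwb"}).fetch_nwb()
first_entry = nwb_data[0]  # dict of NWB objects; fields depend on the table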
@@ -134,15 +122,17 @@ def file_like(self, name=None, **kwargs):

def find_insert_fail(self, key):
"""Find which parent table is causing an IntergrityError on insert."""
rets = []
for parent in self.parents(as_objects=True):
parent_key = {
k: v for k, v in key.items() if k in parent.heading.names
}
parent_name = to_camel_case(parent.table_name)
if query := parent & parent_key:
logger.info(f"{parent_name}:\n{query}")
rets.append(f"{parent_name}:\n{query}")
else:
logger.info(f"{parent_name}: MISSING")
rets.append(f"{parent_name}: MISSING")
logger.info("\n".join(rets))

@classmethod
def _safe_context(cls):
@@ -298,163 +288,6 @@ def load_shared_schemas(self, additional_prefixes: list = None) -> None:
for schema in schemas:
dj.schema(schema[0]).connection.dependencies.load()

@cached_property
def _part_masters(self) -> set:
"""Set of master tables downstream of self.
Cache of masters in self.descendants(as_objects=True) with another
foreign key reference in the part. Used for delete_downstream_parts.
"""
self.connection.dependencies.load()
part_masters = set()

def search_descendants(parent):
for desc_name in parent.descendants():
if ( # Check if has master, is part
not (master := get_master(desc_name))
or master in part_masters # already in cache
or desc_name.replace("`", "").split("_")[0]
not in SHARED_MODULES
):
continue
desc = dj.FreeTable(self.connection, desc_name)
if not set(desc.parents()) - set([master]): # no other parent
continue
part_masters.add(master)
search_descendants(dj.FreeTable(self.connection, master))

try:
_ = search_descendants(self)
except NetworkXError:
try: # Attempt to import failing schema
self.load_shared_schemas()
_ = search_descendants(self)
except NetworkXError as e:
table_name = "".join(e.args[0].split("`")[1:4])
raise ValueError(f"Please import {table_name} and try again.")

logger.info(
f"Building part-parent cache for {self.camel_name}.\n\t"
+ f"Found {len(part_masters)} downstream part tables"
)

return part_masters

def _commit_downstream_delete(self, down_fts, start=None, **kwargs):
"""
Commit delete of downstream parts via down_fts. Logs with _log_delete.
Used by both delete_downstream_parts and cautious_delete.
"""
start = start or time()

safemode = (
dj.config.get("safemode", True)
if kwargs.get("safemode") is None
else kwargs["safemode"]
)
_ = kwargs.pop("safemode", None)

ran_deletes = True
if down_fts:
for down_ft in down_fts:
dj_logger.info(
f"Spyglass: Deleting {len(down_ft)} rows from "
+ f"{down_ft.full_table_name}"
)
if (
self._test_mode
or not safemode
or user_choice("Commit deletes?", default="no") == "yes"
):
for down_ft in down_fts: # safemode off b/c already checked
down_ft.delete(safemode=False, **kwargs)
else:
logger.info("Delete aborted.")
ran_deletes = False

self._log_delete(start, del_blob=down_fts if ran_deletes else None)

return ran_deletes

def delete_downstream_parts(
self,
restriction: str = None,
dry_run: bool = True,
reload_cache: bool = False,
disable_warning: bool = False,
return_graph: bool = False,
verbose: bool = False,
**kwargs,
) -> List[dj.FreeTable]:
"""Delete downstream merge table entries associated with restriction.
Requires caching of merge tables and links, which is slow on first call.
Parameters
----------
restriction : str, optional
Restriction to apply to merge tables. Default None. Will attempt to
use table restriction if None.
dry_run : bool, optional
If True, return list of merge part entries to be deleted. Default
True.
reload_cache : bool, optional
If True, reload merge cache. Default False.
disable_warning : bool, optional
If True, do not warn if no merge tables found. Default False.
return_graph: bool, optional
If True, return RestrGraph object used to identify downstream
tables. Default False, return list of part FreeTables.
True. If False, return dictionary of merge tables and their joins.
verbose : bool, optional
If True, call RestrGraph with verbose=True. Default False.
**kwargs : Any
Passed to datajoint.table.Table.delete.
"""
RestrGraph = self._graph_deps[1]

start = time()

if reload_cache:
_ = self.__dict__.pop("_part_masters", None)

_ = self._part_masters # load cache before loading graph
restriction = restriction or self.restriction or True

restr_graph = RestrGraph(
seed_table=self,
leaves={self.full_table_name: restriction},
direction="down",
cascade=True,
verbose=verbose,
)

if return_graph:
return restr_graph

down_fts = restr_graph.ft_from_list(
self._part_masters, sort_reverse=False
)

if not down_fts and not disable_warning:
logger.warning(
f"No part deletes found w/ {self.camel_name} & "
+ f"{restriction}.\n\tIf this is unexpected, try importing "
+ " Merge table(s) and running with `reload_cache`."
)

if dry_run:
return down_fts

self._commit_downstream_delete(down_fts, start, **kwargs)

def ddp(
self, *args, **kwargs
) -> Union[List[QueryExpression], Dict[str, List[QueryExpression]]]:
"""Alias for delete_downstream_parts."""
return self.delete_downstream_parts(*args, **kwargs)
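With the part-master traversal removed, cascading into downstream part tables is left to DataJoint itself, as gated by the force_masters flag set in cautious_delete further down. A hedged sketch of the equivalent direct call; the table and restriction are placeholders:

# Hypothetical replacement for the removed delete_downstream_parts workflow.
# On DataJoint >= 0.14.2 (as required by this diff), delete(force_masters=True)
# also removes the master rows of downstream part tables during the cascade.
restricted = SomeSpyglassTable() & "nwb_file_name LIKE 'example%'"
restricted.delete(force_masters=True)  # or rely on the mixin's delete/cautious_delete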

# ---------------------------- cautious_delete ----------------------------

@cached_property
@@ -597,15 +430,10 @@ def _check_delete_permission(self) -> None:
)
logger.info(f"Queueing delete for session(s):\n{sess_summary}")

@cached_property
def _cautious_del_tbl(self):
"""Temporary inclusion for usage tracking."""
def _log_delete(self, start, del_blob=None, super_delete=False):
"""Log use of super_delete."""
from spyglass.common.common_usage import CautiousDelete

return CautiousDelete()

def _log_delete(self, start, del_blob=None, super_delete=False):
"""Log use of cautious_delete."""
safe_insert = dict(
duration=time() - start,
dj_user=dj.config["database.user"],
@@ -614,19 +442,25 @@ def _log_delete(self, start, del_blob=None, super_delete=False):
restr_str = "Super delete: " if super_delete else ""
restr_str += "".join(self.restriction) if self.restriction else "None"
try:
self._cautious_del_tbl.insert1(
CautiousDelete().insert1(
dict(
**safe_insert,
restriction=restr_str[:255],
merge_deletes=del_blob,
)
)
except (DataJointError, DataError):
self._cautious_del_tbl.insert1(
dict(**safe_insert, restriction="Unknown")
)
CautiousDelete().insert1(dict(**safe_insert, restriction="Unknown"))

@cached_property
def _has_updated_dj_version(self):
"""Return True if DataJoint version is up to date."""
target_dj = version_parse("0.14.2")
ret = version_parse(dj.__version__) >= target_dj
if not ret:
logger.warning(f"Please update DataJoint to {target_dj} or later.")
return ret
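The version gate above is a plain packaging comparison; a standalone sketch of the same pattern (the helper name here is illustrative):

# Standalone sketch of the DataJoint version check used above; the 0.14.2
# threshold comes from the diff, the function name is illustrative.
import datajoint as dj
from packaging.version import parse as version_parse

def has_updated_dj_version(target: str = "0.14.2") -> bool:
    ok = version_parse(dj.__version__) >= version_parse(target)
    if not ok:
        print(f"Please update DataJoint to {target} or later.")
    return ok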

# TODO: Intercept datajoint delete confirmation prompt for merge deletes
def cautious_delete(
self, force_permission: bool = False, dry_run=False, *args, **kwargs
):
@@ -638,10 +472,6 @@ def cautious_delete(
continues. If the Session has no experimenter, or if the user is not on
a team with the Session experimenter(s), a PermissionError is raised.
Potential downstream orphans are deleted first. These are master tables
whose parts have foreign keys to descendants of self. Then, rows from
self are deleted. Last, Nwbfile and IntervalList externals are deleted.
Parameters
----------
force_permission : bool, optional
@@ -653,47 +483,34 @@
*args, **kwargs : Any
Passed to datajoint.table.Table.delete.
"""
start = time()

if len(self) == 0:
logger.warning(f"Table is empty. No need to delete.\n{self}")
return

if self._has_updated_dj_version:
kwargs["force_masters"] = True

external, IntervalList = self._delete_deps[3], self._delete_deps[4]

if not force_permission or dry_run:
self._check_delete_permission()

down_fts = self.delete_downstream_parts(
dry_run=True,
disable_warning=True,
)

if dry_run:
return (
down_fts,
IntervalList(), # cleanup func relies on downstream deletes
external["raw"].unused(),
external["analysis"].unused(),
)

if not self._commit_downstream_delete(down_fts, start=start, **kwargs):
return # Abort delete based on user input

super().delete(*args, **kwargs) # Confirmation here

for ext_type in ["raw", "analysis"]:
external[ext_type].delete(
delete_external_files=True, display_progress=False
)

_ = IntervalList().nightly_cleanup(dry_run=False)

self._log_delete(start=start, del_blob=down_fts)

def cdel(self, *args, **kwargs):
"""Alias for cautious_delete."""
return self.cautious_delete(*args, **kwargs)
if not self._test_mode:
_ = IntervalList().nightly_cleanup(dry_run=False)

def delete(self, *args, **kwargs):
"""Alias for cautious_delete, overwrites datajoint.table.Table.delete"""
@@ -728,6 +545,7 @@ def _hash_upstream(self, keys):
RestrGraph = self._graph_deps[1]

if not (parents := self.parents(as_objects=True, primary=True)):
# Should not happen, as this is only called from populated tables
raise RuntimeError("No upstream tables found for upstream hash.")

leaves = { # Restriction on each primary parent
6 changes: 0 additions & 6 deletions tests/common/test_position.py
@@ -30,8 +30,6 @@ def param_table(common_position, default_param_key, teardown):
param_table = common_position.PositionInfoParameters()
param_table.insert1(default_param_key, skip_duplicates=True)
yield param_table
if teardown:
param_table.delete(safemode=False)


@pytest.fixture(scope="session")
@@ -61,8 +59,6 @@ def upsample_position(
)
common_position.IntervalPositionInfo.populate(interval_pos_key)
yield interval_pos_key
if teardown:
(param_table & upsample_param_key).delete(safemode=False)


@pytest.fixture(scope="session")
@@ -101,8 +97,6 @@ def upsample_position_error(
interval_pos_key, skip_duplicates=not teardown
)
yield interval_pos_key
if teardown:
(param_table & upsample_param_key).delete(safemode=False)


def test_interval_position_info_insert_error(
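After these edits, the fixtures simply yield and no longer delete entries on teardown. The resulting shape of the first fixture, as a sketch (the fixture scope and the common_position, default_param_key, and teardown conftest fixtures are assumed):

# Sketch of the param_table fixture after removing its explicit teardown delete;
# names follow the diff, the decorator scope is assumed.
import pytest

@pytest.fixture(scope="session")
def param_table(common_position, default_param_key, teardown):
    param_table = common_position.PositionInfoParameters()
    param_table.insert1(default_param_key, skip_duplicates=True)
    yield param_table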
(The remaining changed files were not loaded in this view.)
