From e6f64faa4b078bcf4172421e97e08babc75b64b5 Mon Sep 17 00:00:00 2001
From: Sam Bray
Date: Fri, 12 Apr 2024 16:48:52 -0700
Subject: [PATCH 01/17] initial commit for restrict_from_upstream

---
 src/spyglass/utils/dj_mixin.py | 89 ++++++++++++++++++++++++++++++++++
 1 file changed, 89 insertions(+)

diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py
index 082116bf6..529422739 100644
--- a/src/spyglass/utils/dj_mixin.py
+++ b/src/spyglass/utils/dj_mixin.py
@@ -15,6 +15,7 @@
 from spyglass.utils.dj_chains import TableChain, TableChains
 from spyglass.utils.dj_helper_fn import fetch_nwb, get_nwb_table
 from spyglass.utils.dj_merge_tables import RESERVED_PRIMARY_KEY as MERGE_PK
+from spyglass.utils.dj_merge_tables import Merge
 from spyglass.utils.logging import logger

 try:
@@ -535,3 +536,91 @@ def super_delete(self, *args, **kwargs):
         logger.warning("!! Using super_delete. Bypassing cautious_delete !!")
         self._log_use(start=time(), super_delete=True)
         super().delete(*args, **kwargs)
+
+    from spyglass.utils.dj_merge_tables import Merge
+
+    def restrict_from_upstream(self, key, **kwargs):
+        """Recursive function to restrict a table based on secondary keys of upstream tables"""
+        return restrict_from_upstream(self, key, **kwargs)
+
+
+def restrict_from_upstream(table, key, max_recursion=3):
+    """Recursive function to restrict a table based on secondary keys of upstream tables"""
+    print(f"table: {table.full_table_name}, key: {key}")
+    # Tables not to recurse through because too big or central
+    blacklist = [
+        "`common_nwbfile`.`analysis_nwbfile`",
+    ]
+
+    # Case: MERGE table
+    if (table := table & key) and max_recursion:
+        if isinstance(table, Merge):
+            parts = table.parts(as_objects=True)
+            restricted_parts = [
+                restrict_from_upstream(part, key, max_recursion - 1)
+                for part in parts
+            ]
+            # only keep entries from parts that got restricted
+            restricted_parts = [
+                r_part.proj("merge_id")
+                for r_part, part in zip(restricted_parts, parts)
+                if (
+                    not len(r_part) == len(part)
+                    or check_complete_restrict(r_part, key)
+                )
+            ]
+            # return the merge of the restricted parts
+            merge_keys = []
+            for r_part in restricted_parts:
+                merge_keys.extend(r_part.fetch("merge_id", as_dict=True))
+            return table & merge_keys

+        # Case: regular table
+        upstream_tables = table.parents(as_objects=True)
+        # prevent a loop where the Merge master table is called from a part
+        upstream_tables = [
+            parent
+            for parent in upstream_tables
+            if not (
+                isinstance(parent, Merge)
+                and table.full_table_name in parent.parts()
+            )
+            and (parent.full_table_name not in blacklist)
+        ]
+        for parent in upstream_tables:
+            print(parent.full_table_name)
+            print(len(parent))
+            r_parent = restrict_from_upstream(parent, key, max_recursion - 1)
+            if len(r_parent) == len(parent):
+                continue  # skip joins with uninformative tables
+            table = safe_join(table, r_parent)
+            if check_complete_restrict(table, key) or not table:
+                print(len(table))
+                break
+    return table
+
+
+def check_complete_restrict(table, key):
+    """Check that all keys in a restriction dictionary are used in a table."""
+    if all([k in table.heading.names for k in key.keys()]):
+        print("FOUND")
+    return all([k in table.heading.names for k in key.keys()])
+
+
+# Utility Function
+def safe_join(table_1, table_2):
+    """Enable joining of two tables with overlapping secondary keys."""
+    secondary_1 = [
+        name
+        for name in table_1.heading.names
+        if name not in table_1.primary_key
+    ]
+    secondary_2 = [
+        name
+        for name in table_2.heading.names
+        if name not in table_2.primary_key
] + overlap = [name for name in secondary_1 if name in secondary_2] + return table_1 * table_2.proj( + *[name for name in table_2.heading.names if name not in overlap] + ) From ca0c44e644b363ec66857e635a53cc65181273a5 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Thu, 25 Apr 2024 12:23:40 -0500 Subject: [PATCH 02/17] Add tests for RestrGraph --- pyproject.toml | 6 +-- src/spyglass/utils/dj_graph.py | 12 ++++-- tests/utils/conftest.py | 2 + tests/utils/test_graph.py | 70 ++++++++++++++++++++++++++++++++++ 4 files changed, 84 insertions(+), 6 deletions(-) create mode 100644 tests/utils/test_graph.py diff --git a/pyproject.toml b/pyproject.toml index ffb8d0df6..45617385b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -121,10 +121,10 @@ ignore-words-list = 'nevers' minversion = "7.0" addopts = [ "-sv", - # "--sw", # stepwise: resume with next test after failure - # "--pdb", # drop into debugger on failure + "--sw", # stepwise: resume with next test after failure + "--pdb", # drop into debugger on failure "-p no:warnings", - # "--no-teardown", # don't teardown the database after tests + "--no-teardown", # don't teardown the database after tests # "--quiet-spy", # don't show logging from spyglass "--show-capture=no", "--pdbcls=IPython.terminal.debugger:TerminalPdb", # use ipython debugger diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index 59e7497d5..366ed26ad 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -22,6 +22,7 @@ def __init__( table_name: str = None, restriction: str = None, leaves: List[Dict[str, str]] = None, + cascade: bool = False, verbose: bool = False, **kwargs, ): @@ -38,6 +39,9 @@ def __init__( leaves : Dict[str, str], optional List of dictionaries with keys table_name and restriction. One entry per leaf node. Default None. + cascade : bool, optional + Whether to cascade restrictions up the graph on initialization. + Default False verbose : bool, optional Whether to print verbose output. 
Default False """ @@ -58,6 +62,9 @@ def __init__( if leaves: self.add_leaves(leaves, show_progress=verbose) + if cascade: + self.cascade() + def __repr__(self): l_str = ",\n\t".join(self.leaves) + "\n" if self.leaves else "" processed = "Cascaded" if self.cascaded else "Uncascaded" @@ -126,9 +133,6 @@ def _set_restr(self, table, restriction): ft, unique_dicts(join.fetch("KEY", as_dict=True)), set() ) - # if table == "`spikesorting_merge`.`spike_sorting_output`": - # __import__("pdb").set_trace() - self._set_node(table, "restr", restriction) def get_restr_ft(self, table: Union[int, str]): @@ -285,6 +289,8 @@ def add_leaf(self, table_name, restriction, cascade=False) -> None: restriction : str restriction to apply to leaf """ + self.cascaded = False + new_ancestors = set(self._get_ft(table_name).ancestors()) self.ancestors |= new_ancestors # Add to total ancestors self.visited -= new_ancestors # Remove from visited to revisit diff --git a/tests/utils/conftest.py b/tests/utils/conftest.py index 3503f9649..b2a2dcd2a 100644 --- a/tests/utils/conftest.py +++ b/tests/utils/conftest.py @@ -38,6 +38,8 @@ def chains(Nwbfile): ) # noqa: F401 from spyglass.position.position_merge import PositionOutput # noqa: F401 + _ = LFPOutput, LinearizedPositionOutput, PositionOutput + yield Nwbfile._get_chain("linear") diff --git a/tests/utils/test_graph.py b/tests/utils/test_graph.py new file mode 100644 index 000000000..8d3f8a699 --- /dev/null +++ b/tests/utils/test_graph.py @@ -0,0 +1,70 @@ +from pathlib import Path + +import pytest + + +@pytest.fixture(scope="session") +def leaf(lin_merge): + yield lin_merge.LinearizedPositionV1() + + +@pytest.fixture(scope="session") +def restr_graph(leaf): + from spyglass.utils.dj_graph import RestrGraph + + yield RestrGraph( + seed_table=leaf, + table_name=leaf.full_table_name, + restriction=True, + cascade=True, + verbose=True, + ) + + +def test_rg_repr(restr_graph, leaf): + """Test that the repr of a RestrGraph object is as expected.""" + repr_got = repr(restr_graph) + + assert "cascade" in repr_got.lower(), "Cascade not in repr." + assert leaf.full_table_name in repr_got, "Table name not in repr." + + +def test_rg_ft(restr_graph): + """Test FreeTable attribute of RestrGraph.""" + assert len(restr_graph.leaf_ft) == 1, "Unexpected number of leaf tables." + assert len(restr_graph.all_ft) == 9, "Unexpected number of cascaded tables." + + +def test_rg_restr_ft(restr_graph): + """Test get restricted free tables.""" + ft = restr_graph.get_restr_ft(1) + assert len(ft) == 1, "Unexpected restricted table length." + + +def test_rg_file_paths(restr_graph): + """Test collection of upstream file paths.""" + paths = [p.get("file_path") for p in restr_graph.file_paths] + assert len(paths) == 1, "Unexpected number of file paths." + assert all([Path(p).exists() for p in paths]), "Not all file paths exist." + + +@pytest.fixture(scope="session") +def restr_graph_new_leaf(restr_graph, common): + restr_graph.add_leaf( + table_name=common.common_behav.PositionSource.full_table_name, + restriction=True, + ) + + yield restr_graph + + +def test_add_leaf_cascade(restr_graph_new_leaf): + assert ( + not restr_graph_new_leaf.cascaded + ), "Cascaded flag not set when add leaf." + + +def test_add_leaf_restr_ft(restr_graph_new_leaf): + restr_graph_new_leaf.cascade() + ft = restr_graph_new_leaf.get_restr_ft("`common_interval`.`interval_list`") + assert len(ft) == 2, "Unexpected restricted table length." 
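[Editor's note] The tests added above exercise the RestrGraph entry point from patch 02. A minimal usage sketch, assuming a populated LinearizedPositionV1 table as in the test fixtures (the leaf table and its import are illustrative; only the RestrGraph API below is confirmed by these patches):

    from spyglass.utils.dj_graph import RestrGraph

    leaf = LinearizedPositionV1()  # any populated table can serve as a leaf
    restr_graph = RestrGraph(
        seed_table=leaf,                  # establishes connection + dependency graph
        table_name=leaf.full_table_name,  # single leaf node
        restriction=True,                 # accept every row of the leaf
        cascade=True,                     # propagate restrictions up on init
        verbose=False,
    )
    restr_graph.leaf_ft     # restricted FreeTable(s) for the leaf
    restr_graph.file_paths  # analysis file paths reachable from the restricted leaf
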
From f1aa8727e0f73adf142512c0cf4825fd30ae5267 Mon Sep 17 00:00:00 2001
From: CBroz1
Date: Thu, 25 Apr 2024 12:53:26 -0500
Subject: [PATCH 03/17] WIP: ABC for RestrGraph

---
 src/spyglass/utils/dj_graph.py | 186 ++++++++++++++++++---------------
 1 file changed, 103 insertions(+), 83 deletions(-)

diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py
index 366ed26ad..869ccd294 100644
--- a/src/spyglass/utils/dj_graph.py
+++ b/src/spyglass/utils/dj_graph.py
@@ -3,6 +3,7 @@
 NOTE: read `ft` as FreeTable and `restr` as restriction.
 """

+from abc import ABC, abstractmethod
 from typing import Dict, List, Union

 from datajoint import FreeTable
@@ -15,71 +16,38 @@
 from spyglass.utils.dj_helper_fn import unique_dicts


-class RestrGraph:
-    def __init__(
-        self,
-        seed_table: Table,
-        table_name: str = None,
-        restriction: str = None,
-        leaves: List[Dict[str, str]] = None,
-        cascade: bool = False,
-        verbose: bool = False,
-        **kwargs,
-    ):
-        """Use graph to cascade restrictions up from leaves to all ancestors.
+class AbstractGraph(ABC):
+    def __init__(self, seed_table: Table, verbose: bool = False, **kwargs):
+        """Abstract class for graph traversal and restriction application.

         Parameters
         ----------
         seed_table : Table
             Table to use to establish connection and graph
-        table_name : str, optional
-            Table name of single leaf, default None
-        restriction : str, optional
-            Restriction to apply to leaf. default None
-        leaves : Dict[str, str], optional
-            List of dictionaries with keys table_name and restriction. One
-            entry per leaf node. Default None.
-        cascade : bool, optional
-            Whether to cascade restrictions up the graph on initialization.
-            Default False
         verbose : bool, optional
             Whether to print verbose output. Default False
         """
-        self.connection = seed_table.connection
         self.graph = seed_table.connection.dependencies
         self.graph.load()

         self.verbose = verbose
-        self.cascaded = False
-        self.ancestors = set()
         self.visited = set()
-        self.leaves = set()
-        self.analysis_pk = AnalysisNwbfile().primary_key
-
-        if table_name and restriction:
-            self.add_leaf(table_name, restriction)
-        if leaves:
-            self.add_leaves(leaves, show_progress=verbose)
-
-        if cascade:
-            self.cascade()
-
-    def __repr__(self):
-        l_str = ",\n\t".join(self.leaves) + "\n" if self.leaves else ""
-        processed = "Cascaded" if self.cascaded else "Uncascaded"
-        return f"{processed} RestrictionGraph(\n\t{l_str})"
+        self.to_visit = set()
+        self.cascaded = False

-    @property
-    def all_ft(self):
-        """Get restricted FreeTables from all visited nodes."""
-        self.cascade()
-        return [self._get_ft(table, with_restr=True) for table in self.visited]
+    def _log_truncate(self, log_str, max_len=80):
+        """Truncate log lines to max_len and print if verbose."""
+        if not self.verbose:
+            return
+        logger.info(
+            log_str[:max_len] + "..." if len(log_str) > max_len else log_str
+        )

-    @property
-    def leaf_ft(self):
-        """Get restricted FreeTables from graph leaves."""
-        return [self._get_ft(table, with_restr=True) for table in self.leaves]
+    @abstractmethod
+    def cascade(self):
+        """Cascade restrictions through graph."""
+        raise NotImplementedError("Child class must implement `cascade` method")

     def _get_node(self, table):
         """Get node from graph."""
@@ -95,25 +63,11 @@ def _set_node(self, table, attr="ft", value=None):
         _ = self._get_node(table)  # Ensure node exists
         self.graph.nodes[table][attr] = value

-    def _get_ft(self, table, with_restr=False):
-        """Get FreeTable from graph node.
If one doesn't exist, create it.""" - table = table if isinstance(table, str) else table.full_table_name - restr = self._get_restr(table) if with_restr else True - if ft := self._get_node(table).get("ft"): - return ft & restr - ft = FreeTable(self.connection, table) - self._set_node(table, "ft", ft) - return ft & restr - def _get_restr(self, table): """Get restriction from graph node.""" table = table if isinstance(table, str) else table.full_table_name return self._get_node(table).get("restr", "False") - def _get_files(self, table): - """Get analysis files from graph node.""" - return self._get_node(table).get("files", []) - def _set_restr(self, table, restriction): """Add restriction to graph node. If one exists, merge with new.""" ft = self._get_ft(table) @@ -122,7 +76,6 @@ def _set_restr(self, table, restriction): if not isinstance(restriction, str) else restriction ) - # orig_restr = restriction if existing := self._get_restr(table): if existing == restriction: return @@ -135,10 +88,26 @@ def _set_restr(self, table, restriction): self._set_node(table, "restr", restriction) + def _get_ft(self, table, with_restr=False): + """Get FreeTable from graph node. If one doesn't exist, create it.""" + table = table if isinstance(table, str) else table.full_table_name + restr = self._get_restr(table) if with_restr else True + if ft := self._get_node(table).get("ft"): + return ft & restr + ft = FreeTable(self.connection, table) + self._set_node(table, "ft", ft) + return ft & restr + + @property + def all_ft(self): + """Get restricted FreeTables from all visited nodes.""" + self.cascade() + return [self._get_ft(table, with_restr=True) for table in self.visited] + def get_restr_ft(self, table: Union[int, str]): """Get restricted FreeTable from graph node. - Currently used. May be useful for debugging. + Currently used for testing. Parameters ---------- @@ -149,14 +118,6 @@ def get_restr_ft(self, table: Union[int, str]): table = list(self.visited)[table] return self._get_ft(table, with_restr=True) - def _log_truncate(self, log_str, max_len=80): - """Truncate log lines to max_len and print if verbose.""" - if not self.verbose: - return - logger.info( - log_str[:max_len] + "..." if len(log_str) > max_len else log_str - ) - def _child_to_parent( self, child, @@ -213,6 +174,75 @@ def _child_to_parent( return ret + @property + def as_dict(self) -> List[Dict[str, str]]: + """Return as a list of dictionaries of table_name: restriction""" + self.cascade() + return [ + {"table_name": table, "restriction": self._get_restr(table)} + for table in self.visited + if self._get_restr(table) + ] + + +class RestrGraph(AbstractGraph): + def __init__( + self, + seed_table: Table, + table_name: str = None, + restriction: str = None, + leaves: List[Dict[str, str]] = None, + cascade: bool = False, + verbose: bool = False, + **kwargs, + ): + """Use graph to cascade restrictions up from leaves to all ancestors. + + Parameters + ---------- + seed_table : Table + Table to use to establish connection and graph + table_name : str, optional + Table name of single leaf, default None + restriction : str, optional + Restriction to apply to leaf. default None + leaves : Dict[str, str], optional + List of dictionaries with keys table_name and restriction. One + entry per leaf node. Default None. + cascade : bool, optional + Whether to cascade restrictions up the graph on initialization. + Default False + verbose : bool, optional + Whether to print verbose output. 
Default False + """ + super().__init__(seed_table, verbose=verbose) + + self.ancestors = set() + self.leaves = set() + self.analysis_pk = AnalysisNwbfile().primary_key + + if table_name and restriction: + self.add_leaf(table_name, restriction) + if leaves: + self.add_leaves(leaves, show_progress=verbose) + + if cascade: + self.cascade() + + def __repr__(self): + l_str = ",\n\t".join(self.leaves) + "\n" if self.leaves else "" + processed = "Cascaded" if self.cascaded else "Uncascaded" + return f"{processed} RestrictionGraph(\n\t{l_str})" + + @property + def leaf_ft(self): + """Get restricted FreeTables from graph leaves.""" + return [self._get_ft(table, with_restr=True) for table in self.leaves] + + def _get_files(self, table): + """Get analysis files from graph node.""" + return self._get_node(table).get("files", []) + def cascade_files(self): """Set node attribute for analysis files.""" for table in self.visited: @@ -341,16 +371,6 @@ def add_leaves( self.cascade() self.cascade_files() - @property - def as_dict(self) -> List[Dict[str, str]]: - """Return as a list of dictionaries of table_name: restriction""" - self.cascade() - return [ - {"table_name": table, "restriction": self._get_restr(table)} - for table in self.ancestors - if self._get_restr(table) - ] - @property def file_dict(self) -> Dict[str, List[str]]: """Return dictionary of analysis files from all visited nodes. From 2c83c9ee7b08b10ac6744ed471cc03f615bb804f Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Fri, 26 Apr 2024 14:31:09 -0500 Subject: [PATCH 04/17] WIP: ABC for RestrGraph 2 --- pyproject.toml | 2 +- src/spyglass/utils/dj_chains.py | 119 ++++++------ src/spyglass/utils/dj_graph.py | 245 ++++-------------------- src/spyglass/utils/dj_graph_abs.py | 294 +++++++++++++++++++++++++++++ src/spyglass/utils/dj_mixin.py | 87 ++++++--- tests/conftest.py | 6 +- tests/utils/test_chains.py | 4 +- tests/utils/test_mixin.py | 7 +- 8 files changed, 454 insertions(+), 310 deletions(-) create mode 100644 src/spyglass/utils/dj_graph_abs.py diff --git a/pyproject.toml b/pyproject.toml index 45617385b..c7669020c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -122,7 +122,7 @@ minversion = "7.0" addopts = [ "-sv", "--sw", # stepwise: resume with next test after failure - "--pdb", # drop into debugger on failure + # "--pdb", # drop into debugger on failure "-p no:warnings", "--no-teardown", # don't teardown the database after tests # "--quiet-spy", # don't show logging from spyglass diff --git a/src/spyglass/utils/dj_chains.py b/src/spyglass/utils/dj_chains.py index fe9cebc02..85c2b2a9d 100644 --- a/src/spyglass/utils/dj_chains.py +++ b/src/spyglass/utils/dj_chains.py @@ -6,9 +6,11 @@ import networkx as nx from datajoint.expression import QueryExpression from datajoint.table import Table -from datajoint.utils import get_master, to_camel_case +from datajoint.utils import to_camel_case +from spyglass.utils.dj_graph_abs import AbstractGraph, _fuzzy_get from spyglass.utils.dj_merge_tables import RESERVED_PRIMARY_KEY as MERGE_PK +from spyglass.utils.dj_merge_tables import is_merge_table from spyglass.utils.logging import logger # Tables that should be excluded from the undirected graph when finding paths @@ -88,23 +90,21 @@ def max_len(self): def __getitem__(self, index: Union[int, str]): """Return FreeTable object at index.""" - if isinstance(index, str): - for i, part in enumerate(self.part_names): - if index in part: - return self.chains[i] - return self.chains[index] + return _fuzzy_get(index, self.part_names, self.chains) - def join(self, 
restriction=None) -> List[QueryExpression]: + def join( + self, restriction=None, reverse_order=False + ) -> List[QueryExpression]: """Return list of joins for each chain in self.chains.""" restriction = restriction or self.parent.restriction or True joins = [] for chain in self.chains: - if joined := chain.join(restriction): + if joined := chain.join(restriction, reverse_order=reverse_order): joins.append(joined) return joins -class TableChain: +class TableChain(AbstractGraph): """Class for representing a chain of tables. A chain is a sequence of tables from parent to child identified by @@ -117,9 +117,6 @@ class TableChain: Parent or origin of chain. child : Table Child or destination of chain. - _connection : datajoint.Connection, optional - Connection to database used to create FreeTable objects. Defaults to - parent.connection. _link_symbol : str Symbol used to represent the link between parent and child. Hardcoded to " -> ". @@ -134,10 +131,6 @@ class TableChain: Directed graph of parent's dependencies from datajoint.connection. names : List[str] List of full table names in chain. - objects : List[dj.FreeTable] - List of FreeTable objects for each table in chain. - attr_maps : List[dict] - List of attribute maps for each link in chain. path : OrderedDict[str, Dict[str, Union[dj.FreeTable,dict]]] Dictionary of full table names in chain. Keys are self.names Values are a dict of free_table (self.objects) and @@ -162,29 +155,24 @@ class TableChain: Return join of tables in chain with restriction applied to parent. """ - def __init__(self, parent: Table, child: Table, connection=None): - self._connection = connection or parent.connection - self.graph = self._connection.dependencies - self.graph.load() - - if ( # if child is a merge table - get_master(child.full_table_name) == "" - and MERGE_PK in child.heading.names - ): + def __init__( + self, + parent: Table, + child: Table, + verbose: bool = True, + ): + if is_merge_table(child): raise TypeError("Child is a merge table. Use TableChains instead.") + super().__init__(seed_table=parent, verbose=verbose) + _ = self._get_node(child.full_table_name) # ensure child is in graph + self._link_symbol = " -> " self.parent = parent self.child = child self.link_type = None self._searched = False - - if child.full_table_name not in self.graph.nodes: - logger.warning( - "Can't find item in graph. 
Try importing: " - + f"{child.full_table_name}" - ) - self._searched = True + self.undirect_graph = None def __str__(self): """Return string representation of chain: parent -> child.""" @@ -200,9 +188,7 @@ def __repr__(self): """Return full representation of chain: parent -> {links} -> child.""" if not self.has_link: return "No link" - return "Chain: " + self._link_symbol.join( - [t.table_name for t in self.objects] - ) + return "Chain: " + self._link_symbol.join(self.names) def __len__(self): """Return number of tables in chain.""" @@ -210,15 +196,8 @@ def __len__(self): return 0 return len(self.names) - def __getitem__(self, index: Union[int, str]) -> dj.FreeTable: - """Return FreeTable object at index.""" - if not self.has_link: - return None - if isinstance(index, str): - for i, name in enumerate(self.names): - if index in name: - return self.objects[i] - return self.objects[index] + def __getitem__(self, index: Union[int, str]): + return _fuzzy_get(index, self.names, self.objects) @property def has_link(self) -> bool: @@ -231,15 +210,6 @@ def has_link(self) -> bool: _ = self.path return self.link_type is not None - def pk_link(self, src, trg, data) -> float: - """Return 1 if data["primary"] else float("inf"). - - Currently unused. Preserved for future debugging. shortest_path accepts - an option weight callable parameter. - nx.shortest_path(G, source, target,weight=pk_link) - """ - return 1 if data["primary"] else float("inf") - def find_path(self, directed=True) -> OrderedDict: """Return list of full table names in chain. @@ -265,17 +235,23 @@ def find_path(self, directed=True) -> OrderedDict: source code for comments on alias nodes. """ source, target = self.parent.full_table_name, self.child.full_table_name + if not directed: - self.graph = self.graph.to_undirected() - self.graph.remove_nodes_from(PERIPHERAL_TABLES) + self.undirect_graph = self.graph.to_undirected() + self.undirect_graph.remove_nodes_from(PERIPHERAL_TABLES) + + search_graph = self.graph if directed else self.undirect_graph + try: - path = nx.shortest_path(self.graph, source, target) + path = nx.shortest_path(search_graph, source, target) except nx.NetworkXNoPath: return None except nx.NodeNotFound: self._searched = True return None + self.no_visit.update(set(self.graph.nodes) - set(path)) + ret = OrderedDict() prev_table = None for i, table in enumerate(path): @@ -288,7 +264,7 @@ def find_path(self, directed=True) -> OrderedDict: attr_map = self.graph[prev_table][table]["attr_map"] ret[prev_table]["attr_map"] = attr_map else: - free_table = dj.FreeTable(self._connection, table) + free_table = dj.FreeTable(self.connection, table) ret[table] = {"free_table": free_table, "attr_map": {}} prev_table = table return ret @@ -323,19 +299,32 @@ def objects(self) -> List[dj.FreeTable]: """ if not self.has_link: return None - return [v["free_table"] for v in self.path.values()] + return [self._get_ft(table, with_restr=False) for table in self.names] + + def cascade(self, restriction: str = None, direction: str = "up"): + if direction == "up": + start, end = self.child, self.parent + else: + start, end = self.parent, self.child + self.cascade1( + table=start.full_table_name, + restriction=start.restriction, + direction=direction, + ) + return self._get_ft(end.full_table_name, with_restr=True) - @cached_property - def attr_maps(self) -> List[dict]: - """Return list of attribute maps for each table in chain. + return self._get_ft(self.parent.full_table_name, with_restr=True) - Unused. Preserved for future debugging. 
- """ + def join( + self, restriction: str = None, reverse_order: bool = False + ) -> dj.expression.QueryExpression: if not self.has_link: return None - return [v["attr_map"] for v in self.path.values()] - def join( + direction = "down" if reverse_order else "up" + return self.cascade(restriction, direction) + + def old_join( self, restriction: str = None, reverse_order: bool = False ) -> dj.expression.QueryExpression: """Return join of tables in chain with restriction applied to parent. diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index 869ccd294..d3d8769a1 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -3,188 +3,16 @@ NOTE: read `ft` as FreeTable and `restr` as restriction. """ -from abc import ABC, abstractmethod -from typing import Dict, List, Union +from typing import Dict, List -from datajoint import FreeTable -from datajoint.condition import make_condition from datajoint.table import Table from tqdm import tqdm from spyglass.common import AnalysisNwbfile -from spyglass.utils import logger +from spyglass.utils.dj_graph_abs import AbstractGraph from spyglass.utils.dj_helper_fn import unique_dicts -class AbstractGraph(ABC): - def __init__(self, seed_table: Table, verbose: bool = False, **kwargs): - """Abstract class for graph traversal and restriction application. - - Parameters - ---------- - seed_table : Table - Table to use to establish connection and graph - verbose : bool, optional - Whether to print verbose output. Default False - """ - self.connection = seed_table.connection - self.graph = seed_table.connection.dependencies - self.graph.load() - - self.verbose = verbose - self.visited = set() - self.to_visit = set() - self.cascaded = False - - def _log_truncate(self, log_str, max_len=80): - """Truncate log lines to max_len and print if verbose.""" - if not self.verbose: - return - logger.info( - log_str[:max_len] + "..." if len(log_str) > max_len else log_str - ) - - @abstractmethod - def cascade(self): - """Cascade restrictions through graph.""" - raise NotImplementedError("Child class mut implement `cascade` method") - - def _get_node(self, table): - """Get node from graph.""" - if not (node := self.graph.nodes.get(table)): - raise ValueError( - f"Table {table} not found in graph." - + "\n\tPlease import this table and rerun" - ) - return node - - def _set_node(self, table, attr="ft", value=None): - """Set attribute on node. General helper for various attributes.""" - _ = self._get_node(table) # Ensure node exists - self.graph.nodes[table][attr] = value - - def _get_restr(self, table): - """Get restriction from graph node.""" - table = table if isinstance(table, str) else table.full_table_name - return self._get_node(table).get("restr", "False") - - def _set_restr(self, table, restriction): - """Add restriction to graph node. If one exists, merge with new.""" - ft = self._get_ft(table) - restriction = ( # Convert to condition if list or dict - make_condition(ft, restriction, set()) - if not isinstance(restriction, str) - else restriction - ) - if existing := self._get_restr(table): - if existing == restriction: - return - join = ft & [existing, restriction] - if len(join) == len(ft & existing): - return # restriction is a subset of existing - restriction = make_condition( - ft, unique_dicts(join.fetch("KEY", as_dict=True)), set() - ) - - self._set_node(table, "restr", restriction) - - def _get_ft(self, table, with_restr=False): - """Get FreeTable from graph node. 
If one doesn't exist, create it.""" - table = table if isinstance(table, str) else table.full_table_name - restr = self._get_restr(table) if with_restr else True - if ft := self._get_node(table).get("ft"): - return ft & restr - ft = FreeTable(self.connection, table) - self._set_node(table, "ft", ft) - return ft & restr - - @property - def all_ft(self): - """Get restricted FreeTables from all visited nodes.""" - self.cascade() - return [self._get_ft(table, with_restr=True) for table in self.visited] - - def get_restr_ft(self, table: Union[int, str]): - """Get restricted FreeTable from graph node. - - Currently used for testing. - - Parameters - ---------- - table : Union[int, str] - Table name or index in visited set - """ - if isinstance(table, int): - table = list(self.visited)[table] - return self._get_ft(table, with_restr=True) - - def _child_to_parent( - self, - child, - parent, - restriction, - attr_map=None, - primary=True, - **kwargs, - ) -> List[Dict[str, str]]: - """Given a child, child's restr, and parent, get parent's restr. - - Parameters - ---------- - child : str - child table name - parent : str - parent table name - restriction : str - restriction to apply to child - attr_map : dict, optional - dictionary mapping aliases across parend/child, as pulled from - DataJoint-assembled graph. Default None. Func will flip this dict - to convert from child to parent fields. - primary : bool, optional - Is parent in child's primary key? Default True. Also derived from - DataJoint-assembled graph. If True, project only primary key fields - to avoid secondary key collisions. - - Returns - ------- - List[Dict[str, str]] - List of dicts containing primary key fields for restricted parent - table. - """ - - # Need to flip attr_map to respect parent's fields - attr_reverse = ( - {v: k for k, v in attr_map.items() if k != v} if attr_map else {} - ) - child_ft = self._get_ft(child) - parent_ft = self._get_ft(parent).proj() - restr = restriction or self._get_restr(child_ft) or True - restr_child = child_ft & restr - - if primary: # Project only primary key fields to avoid collisions - join = restr_child.proj(**attr_reverse) * parent_ft - else: # Include all fields - join = restr_child.proj(..., **attr_reverse) * parent_ft - - ret = unique_dicts(join.fetch(*parent_ft.primary_key, as_dict=True)) - - if len(ret) == len(parent_ft): - self._log_truncate(f"NULL restr {parent}") - - return ret - - @property - def as_dict(self) -> List[Dict[str, str]]: - """Return as a list of dictionaries of table_name: restriction""" - self.cascade() - return [ - {"table_name": table, "restriction": self._get_restr(table)} - for table in self.visited - if self._get_restr(table) - ] - - class RestrGraph(AbstractGraph): def __init__( self, @@ -217,7 +45,6 @@ def __init__( """ super().__init__(seed_table, verbose=verbose) - self.ancestors = set() self.leaves = set() self.analysis_pk = AnalysisNwbfile().primary_key @@ -239,20 +66,7 @@ def leaf_ft(self): """Get restricted FreeTables from graph leaves.""" return [self._get_ft(table, with_restr=True) for table in self.leaves] - def _get_files(self, table): - """Get analysis files from graph node.""" - return self._get_node(table).get("files", []) - - def cascade_files(self): - """Set node attribute for analysis files.""" - for table in self.visited: - ft = self._get_ft(table) - if not set(self.analysis_pk).issubset(ft.heading.names): - continue - files = (ft & self._get_restr(table)).fetch(*self.analysis_pk) - self._set_node(table, "files", files) - - def cascade1(self, 
table, restriction): + def old_cascade1(self, table, restriction, direction="up"): """Cascade a restriction up the graph, recursively on parents. Parameters @@ -265,7 +79,13 @@ def cascade1(self, table, restriction): self._set_restr(table, restriction) self.visited.add(table) - for parent, data in self.graph.parents(table).items(): + next_nodes = ( + self.graph.parents(table) + if direction == "up" + else self.graph.children(table) + ) + + for parent, data in next_nodes.items(): if parent in self.visited: continue @@ -279,7 +99,11 @@ def cascade1(self, table, restriction): **data, ) - self.cascade1(parent, parent_restr) # Parent set on recursion + self.cascade1( + table=parent, + restriction=parent_restr, + direction=direction, + ) def cascade(self, show_progress=None) -> None: """Cascade all restrictions up the graph. @@ -301,7 +125,7 @@ def cascade(self, show_progress=None) -> None: restr = self._get_restr(table) self._log_truncate(f"Start {table}: {restr}") self.cascade1(table, restr) - if not self.visited == self.ancestors: + if not self.visited == self.to_visit: raise RuntimeError( "Cascade: FAIL - incomplete cascade. Please post issue." ) @@ -322,7 +146,7 @@ def add_leaf(self, table_name, restriction, cascade=False) -> None: self.cascaded = False new_ancestors = set(self._get_ft(table_name).ancestors()) - self.ancestors |= new_ancestors # Add to total ancestors + self.to_visit |= new_ancestors # Add to total ancestors self.visited -= new_ancestors # Remove from visited to revisit self.leaves.add(table_name) @@ -371,23 +195,29 @@ def add_leaves( self.cascade() self.cascade_files() + # ----------------------------- File Handling ----------------------------- + + def _get_files(self, table): + """Get analysis files from graph node.""" + return self._get_node(table).get("files", []) + + def cascade_files(self): + """Set node attribute for analysis files.""" + for table in self.visited: + ft = self._get_ft(table) + if not set(self.analysis_pk).issubset(ft.heading.names): + continue + files = list((ft & self._get_restr(table)).fetch(*self.analysis_pk)) + self._set_node(table, "files", files) + @property def file_dict(self) -> Dict[str, List[str]]: """Return dictionary of analysis files from all visited nodes. Currently unused, but could be useful for debugging. """ - if not self.cascaded: - logger.warning("Uncascaded graph. Using leaves only.") - table_list = self.leaves - else: - table_list = self.visited - - return { - table: self._get_files(table) - for table in table_list - if any(self._get_files(table)) - } + self.cascade() + return self._get_attr_dict("files", default_factory=lambda: []) @property def file_paths(self) -> List[str]: @@ -397,11 +227,10 @@ def file_paths(self) -> List[str]: directly by the user. 
""" self.cascade() - unique_files = set( - [file for table in self.visited for file in self._get_files(table)] - ) return [ {"file_path": AnalysisNwbfile().get_abs_path(file)} - for file in unique_files + for file in set( + [f for files in self.file_dict.values() for f in files] + ) if file is not None ] diff --git a/src/spyglass/utils/dj_graph_abs.py b/src/spyglass/utils/dj_graph_abs.py new file mode 100644 index 000000000..dd3e06f7b --- /dev/null +++ b/src/spyglass/utils/dj_graph_abs.py @@ -0,0 +1,294 @@ +from abc import ABC, abstractmethod +from typing import Dict, List, Union + +from datajoint import FreeTable, logger +from datajoint.condition import make_condition +from datajoint.table import Table +from networkx import NetworkXNoPath, NodeNotFound, shortest_path + + +class AbstractGraph(ABC): + def __init__(self, seed_table: Table, verbose: bool = False, **kwargs): + """Abstract class for graph traversal and restriction application. + + Parameters + ---------- + seed_table : Table + Table to use to establish connection and graph + verbose : bool, optional + Whether to print verbose output. Default False + """ + self.connection = seed_table.connection + self.graph = seed_table.connection.dependencies + self.graph.load() + + self.verbose = verbose + self.visited = set() + self.to_visit = set() + self.no_visit = set() + self.cascaded = False + + def _log_truncate(self, log_str, max_len=80): + """Truncate log lines to max_len and print if verbose.""" + if not self.verbose: + return + logger.info( + log_str[:max_len] + "..." if len(log_str) > max_len else log_str + ) + + @abstractmethod + def cascade(self): + """Cascade restrictions through graph.""" + raise NotImplementedError("Child class mut implement `cascade` method") + + def _get_node(self, table): + """Get node from graph.""" + if not isinstance(table, str): + table = table.full_table_name + if not (node := self.graph.nodes.get(table)): + raise ValueError( + f"Table {table} not found in graph." + + "\n\tPlease import this table and rerun" + ) + return node + + def _set_node(self, table, attr="ft", value=None): + """Set attribute on node. General helper for various attributes.""" + _ = self._get_node(table) # Ensure node exists + self.graph.nodes[table][attr] = value + + def _get_attr_dict( + self, attr, default_factory=lambda: None + ) -> List[Dict[str, str]]: + """Get given attr for each table in self.visited + + Uses factory to create default value for missing attributes. 
+ """ + return { + t: self._get_node(t).get(attr, default_factory()) + for t in self.visited + } + + def _get_attr_map_btwn(self, child, parent): + """Get attribute map between child and parent.""" + child = child if isinstance(child, str) else child.full_table_name + parent = parent if isinstance(parent, str) else parent.full_table_name + + reverse = False + try: + path = shortest_path(self.graph, child, parent) + except NetworkXNoPath: + reverse, child, parent = True, parent, child + path = shortest_path(self.graph, child, parent) + + if len(path) != 2 and not path[1].isnumeric(): + raise ValueError(f"{child} -> {parent} not direct path: {path}") + + try: + attr_map = self.graph[child][path[1]]["attr_map"] + except KeyError: + attr_map = self.graph[path[1]][child]["attr_map"] + return attr_map if not reverse else {v: k for k, v in attr_map.items()} + + def _get_restr(self, table): + """Get restriction from graph node.""" + table = table if isinstance(table, str) else table.full_table_name + return self._get_node(table).get("restr", "False") + + def _set_restr(self, table, restriction): + """Add restriction to graph node. If one exists, merge with new.""" + ft = self._get_ft(table) + restriction = ( # Convert to condition if list or dict + make_condition(ft, restriction, set()) + if not isinstance(restriction, str) + else restriction + ) + if existing := self._get_restr(table): + if existing == restriction: + return + join = ft & [existing, restriction] + if len(join) == len(ft & existing): + return # restriction is a subset of existing + restriction = make_condition( + ft, unique_dicts(join.fetch("KEY", as_dict=True)), set() + ) + + self._set_node(table, "restr", restriction) + + def _get_ft(self, table, with_restr=False): + """Get FreeTable from graph node. If one doesn't exist, create it.""" + table = table if isinstance(table, str) else table.full_table_name + restr = self._get_restr(table) if with_restr else True + if ft := self._get_node(table).get("ft"): + return ft & restr + ft = FreeTable(self.connection, table) + self._set_node(table, "ft", ft) + return ft & restr + + @property + def all_ft(self): + """Get restricted FreeTables from all visited nodes.""" + self.cascade() + return [ + self._get_ft(table, with_restr=True) + for table in self.visited + if not table.isnumeric() + ] + + def get_restr_ft(self, table: Union[int, str]): + """Get restricted FreeTable from graph node. + + Currently used for testing. 
+ + Parameters + ---------- + table : Union[int, str] + Table name or index in visited set + """ + if isinstance(table, int): + table = list(self.visited)[table] + return self._get_ft(table, with_restr=True) + + @property + def as_dict(self) -> List[Dict[str, str]]: + """Return as a list of dictionaries of table_name: restriction""" + self.cascade() + return [ + {"table_name": table, "restriction": self._get_restr(table)} + for table in self.visited + if self._get_restr(table) + ] + + def _bridge_restr( + self, + table1: str, + table2: str, + restr1: str, + attr_map: dict = None, + primary: bool = True, + ): + ft1 = self._get_ft(table1) + ft2 = self._get_ft(table2) + if attr_map is None: + attr_map = self._get_attr_map_btwn(table1, table2) + + if table1 in ft2.children(): + table1, table2 = table2, table1 + ft1, ft2 = ft2, ft1 + + if primary: + join = ft1.proj(**attr_map) * ft2 + else: + join = ft1.proj(..., **attr_map) * ft2 + + return unique_dicts(join.fetch(*ft2.primary_key, as_dict=True)) + + def _child_to_parent( + self, + child, + parent, + restriction, + attr_map=None, + primary=True, + **kwargs, + ) -> List[Dict[str, str]]: + """Given a child, child's restr, and parent, get parent's restr. + + Parameters + ---------- + child : str + child table name + parent : str + parent table name + restriction : str + restriction to apply to child + attr_map : dict, optional + dictionary mapping aliases across parend/child, as pulled from + DataJoint-assembled graph. Default None. Func will flip this dict + to convert from child to parent fields. + primary : bool, optional + Is parent in child's primary key? Default True. Also derived from + DataJoint-assembled graph. If True, project only primary key fields + to avoid secondary key collisions. + + Returns + ------- + List[Dict[str, str]] + List of dicts containing primary key fields for restricted parent + table. + """ + + # Need to flip attr_map to respect parent's fields + attr_reverse = ( + {v: k for k, v in attr_map.items() if k != v} if attr_map else {} + ) + child_ft = self._get_ft(child) + parent_ft = self._get_ft(parent).proj() + restr = restriction or self._get_restr(child_ft) or True + restr_child = child_ft & restr + + if primary: # Project only primary key fields to avoid collisions + join = restr_child.proj(**attr_reverse) * parent_ft + else: # Include all fields + join = restr_child.proj(..., **attr_reverse) * parent_ft + + ret = unique_dicts(join.fetch(*parent_ft.primary_key, as_dict=True)) + + if len(ret) == len(parent_ft): + self._log_truncate(f"NULL restr {parent}") + + return ret + + def cascade1(self, table, restriction, direction="up"): + """Cascade a restriction up the graph, recursively on parents. 
+ + Parameters + ---------- + table : str + table name + restriction : str + restriction to apply + """ + self._set_restr(table, restriction) + self.visited.add(table) + + next_nodes = ( + self.graph.parents(table) + if direction == "up" + else self.graph.children(table) + ) + + for next_table, data in next_nodes.items(): + if next_table in self.visited or next_table in self.no_visit: + continue + + if next_table.isnumeric(): + next_table, data = self.graph.parents(next_table).popitem() + + parent_restr = self._child_to_parent( + child=table, + parent=next_table, + restriction=restriction, + **data, + ) + + self.cascade1( + table=next_table, + restriction=parent_restr, + direction=direction, + ) + + +def unique_dicts(list_of_dict): + """Remove duplicate dictionaries from a list.""" + return [dict(t) for t in {tuple(d.items()) for d in list_of_dict}] + + +def _fuzzy_get(index: Union[int, str], names: List[str], sources: List[str]): + """Given lists of items/names, return item at index or by substring.""" + if isinstance(index, int): + return sources[index] + for i, part in enumerate(names): + if index in part: + return sources[i] + return None diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index 73f27c7cf..f882db126 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -13,7 +13,7 @@ from datajoint.expression import QueryExpression from datajoint.logging import logger as dj_logger from datajoint.table import Table -from datajoint.utils import get_master, user_choice +from datajoint.utils import get_master, to_camel_case, user_choice from networkx import NetworkXError from pymysql.err import DataError @@ -93,6 +93,33 @@ def __init__(self, *args, **kwargs): + self.full_table_name ) + # -------------------------- Misc helper methods -------------------------- + + @property + def camel_name(self): + """Return table name in camel case.""" + return to_camel_case(self.table_name) + + def _auto_increment(self, key, pk, *args, **kwargs): + """Auto-increment primary key.""" + if not key.get(pk): + key[pk] = (dj.U().aggr(self, n=f"max({pk})").fetch1("n") or 0) + 1 + return key + + def file_like(self, name=None, **kwargs): + """Convenience method for wildcard search on file name fields.""" + if not name: + return self & True + attr = None + for field in self.heading.names: + if "file" in field: + attr = field + break + if not attr: + logger.error(f"No file-like field found in {self.full_table_name}") + return + return self & f"{attr} LIKE '%{name}%'" + # ------------------------------- fetch_nwb ------------------------------- @cached_property @@ -203,6 +230,26 @@ def fetch_pynapple(self, *attrs, **kwargs): # ------------------------ delete_downstream_merge ------------------------ + def import_merge_tables(self): + """Import all merge tables downstream of self.""" + from spyglass.decoding.decoding_merge import DecodingOutput # noqa F401 + from spyglass.lfp.lfp_merge import LFPOutput # noqa F401 + from spyglass.linearization.merge import ( + LinearizedPositionOutput, + ) # noqa F401 + from spyglass.position.position_merge import PositionOutput # noqa F401 + from spyglass.spikesorting.spikesorting_merge import ( # noqa F401 + SpikeSortingOutput, + ) + + _ = ( + DecodingOutput(), + LFPOutput(), + LinearizedPositionOutput(), + PositionOutput(), + SpikeSortingOutput(), + ) + @cached_property def _merge_tables(self) -> Dict[str, dj.FreeTable]: """Dict of merge tables downstream of self: {full_table_name: FreeTable}. 
@@ -215,10 +262,6 @@ def _merge_tables(self) -> Dict[str, dj.FreeTable]: visited = set() def search_descendants(parent): - # TODO: Add check that parents are in the graph. If not, raise error - # asking user to import the table. - # TODO: Make a `is_merge_table` helper, and check for false - # positives in the mixin init. for desc in parent.descendants(as_objects=True): if ( MERGE_PK not in desc.heading.names @@ -235,12 +278,16 @@ def search_descendants(parent): try: _ = search_descendants(self) - except NetworkXError as e: - table_name = "".join(e.args[0].split("`")[1:4]) - raise ValueError(f"Please import {table_name} and try again.") + except NetworkXError: + try: # Attempt to import missing table + self.import_merge_tables() + _ = search_descendants(self) + except NetworkXError as e: + table_name = "".join(e.args[0].split("`")[1:4]) + raise ValueError(f"Please import {table_name} and try again.") logger.info( - f"Building merge cache for {self.table_name}.\n\t" + f"Building merge cache for {self.camel_name}.\n\t" + f"Found {len(merge_tables)} downstream merge tables" ) @@ -716,27 +763,7 @@ def fetch1(self, *args, log_fetch=True, **kwargs): self._log_fetch(*args, **kwargs) return ret - # ------------------------- Other helper methods ------------------------- - - def _auto_increment(self, key, pk, *args, **kwargs): - """Auto-increment primary key.""" - if not key.get(pk): - key[pk] = (dj.U().aggr(self, n=f"max({pk})").fetch1("n") or 0) + 1 - return key - - def file_like(self, name=None, **kwargs): - """Convenience method for wildcard search on file name fields.""" - if not name: - return self & True - attr = None - for field in self.heading.names: - if "file" in field: - attr = field - break - if not attr: - logger.error(f"No file-like field found in {self.full_table_name}") - return - return self & f"{attr} LIKE '%{name}%'" + # ------------------------------ Restrict by ------------------------------ def restrict_from_upstream(self, key, **kwargs): """Recursive function to restrict a table based on secondary keys of upstream tables""" diff --git a/tests/conftest.py b/tests/conftest.py index 0bcb4a3fd..6c7b4e18a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -418,12 +418,16 @@ def trodes_pos_v1(teardown, sgp, trodes_sel_keys): def pos_merge_tables(dj_conn): """Return the merge tables as activated.""" from spyglass.common.common_position import TrackGraph + from spyglass.lfp.lfp_merge import LFPOutput from spyglass.linearization.merge import LinearizedPositionOutput from spyglass.position.position_merge import PositionOutput # must import common_position before LinOutput to avoid circular import - _ = TrackGraph() + + # import LFPOutput to use when testing mixin cascade + _ = LFPOutput() + return [PositionOutput(), LinearizedPositionOutput()] diff --git a/tests/utils/test_chains.py b/tests/utils/test_chains.py index 7ba4b1fa2..2ce8b8677 100644 --- a/tests/utils/test_chains.py +++ b/tests/utils/test_chains.py @@ -44,9 +44,7 @@ def test_chain_str(chain): def test_chain_repr(chain): """Test that the repr of a TableChain object is as expected.""" repr_got = repr(chain) - repr_ext = "Chain: " + chain._link_symbol.join( - [t.table_name for t in chain.objects] - ) + repr_ext = "Chain: " + chain._link_symbol.join(chain.names) assert repr_got == repr_ext, "Unexpected repr of TableChain object." 
diff --git a/tests/utils/test_mixin.py b/tests/utils/test_mixin.py index ac5c74bfe..7918d1b89 100644 --- a/tests/utils/test_mixin.py +++ b/tests/utils/test_mixin.py @@ -1,7 +1,7 @@ import datajoint as dj import pytest -from tests.conftest import VERBOSE +from tests.conftest import TEARDOWN, VERBOSE @pytest.fixture(scope="module") @@ -18,7 +18,10 @@ class Mixin(SpyglassMixin, dj.Manual): Mixin().drop_quick() -@pytest.mark.skipif(not VERBOSE, reason="No logging to test when quiet-spy.") +@pytest.mark.skipif( + not VERBOSE or not TEARDOWN, + reason="Error only on verbose or new declare.", +) def test_bad_prefix(caplog, dj_conn, Mixin): schema_bad = dj.Schema("badprefix", {}, connection=dj_conn) schema_bad(Mixin) From 98244f287f328d823ab05f6c45ebae8e802732dd Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Fri, 26 Apr 2024 19:07:24 -0500 Subject: [PATCH 05/17] WIP: ABC for RestrGraph 3 --- pyproject.toml | 2 +- src/spyglass/utils/dj_chains.py | 71 ++++++-------- src/spyglass/utils/dj_graph_abs.py | 150 +++++++++++++++-------------- src/spyglass/utils/dj_mixin.py | 9 +- tests/utils/test_mixin.py | 34 ++++++- 5 files changed, 149 insertions(+), 117 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c7669020c..45617385b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -122,7 +122,7 @@ minversion = "7.0" addopts = [ "-sv", "--sw", # stepwise: resume with next test after failure - # "--pdb", # drop into debugger on failure + "--pdb", # drop into debugger on failure "-p no:warnings", "--no-teardown", # don't teardown the database after tests # "--quiet-spy", # don't show logging from spyglass diff --git a/src/spyglass/utils/dj_chains.py b/src/spyglass/utils/dj_chains.py index 85c2b2a9d..d1a9b1123 100644 --- a/src/spyglass/utils/dj_chains.py +++ b/src/spyglass/utils/dj_chains.py @@ -103,6 +103,15 @@ def join( joins.append(joined) return joins + def cascade(self, restriction: str = None, direction: str = "down"): + """Return list of cascades for each chain in self.chains.""" + restriction = restriction or self.parent.restriction or True + cascades = [] + for chain in self.chains: + if joined := chain.cascade(restriction, direction): + cascades.append(joined) + return cascades + class TableChain(AbstractGraph): """Class for representing a chain of tables. @@ -159,7 +168,7 @@ def __init__( self, parent: Table, child: Table, - verbose: bool = True, + verbose: bool = False, ): if is_merge_table(child): raise TypeError("Child is a merge table. Use TableChains instead.") @@ -188,7 +197,7 @@ def __repr__(self): """Return full representation of chain: parent -> {links} -> child.""" if not self.has_link: return "No link" - return "Chain: " + self._link_symbol.join(self.names) + return "Chain: " + self._link_symbol.join(self.path) def __len__(self): """Return number of tables in chain.""" @@ -250,46 +259,25 @@ def find_path(self, directed=True) -> OrderedDict: self._searched = True return None - self.no_visit.update(set(self.graph.nodes) - set(path)) - - ret = OrderedDict() - prev_table = None - for i, table in enumerate(path): - if table.isnumeric(): # get proj() attribute map for alias node - if not prev_table: - raise ValueError("Alias node found without prev table.") - try: - attr_map = self.graph[table][prev_table]["attr_map"] - except KeyError: # Why is this only DLCCentroid?? 
- attr_map = self.graph[prev_table][table]["attr_map"] - ret[prev_table]["attr_map"] = attr_map - else: - free_table = dj.FreeTable(self.connection, table) - ret[table] = {"free_table": free_table, "attr_map": {}} - prev_table = table - return ret + ignore_nodes = self.graph.nodes - set(path) + self.no_visit.update(ignore_nodes) + + return path @cached_property - def path(self) -> OrderedDict: + def path(self) -> list: """Return list of full table names in chain.""" if self._searched and not self.has_link: return None - link = None - if link := self.find_path(directed=True): + path = None + if path := self.find_path(directed=True): self.link_type = "directed" - elif link := self.find_path(directed=False): + elif path := self.find_path(directed=False): self.link_type = "undirected" self._searched = True - return link - - @cached_property - def names(self) -> List[str]: - """Return list of full table names in chain.""" - if not self.has_link: - return None - return list(self.path.keys()) + return path @cached_property def objects(self) -> List[dj.FreeTable]: @@ -299,22 +287,25 @@ def objects(self) -> List[dj.FreeTable]: """ if not self.has_link: return None - return [self._get_ft(table, with_restr=False) for table in self.names] + return [self._get_ft(table, with_restr=False) for table in self.path] def cascade(self, restriction: str = None, direction: str = "up"): + _ = self.path + if not self.has_link: + return None if direction == "up": start, end = self.child, self.parent else: start, end = self.parent, self.child - self.cascade1( - table=start.full_table_name, - restriction=start.restriction, - direction=direction, - ) + if not self.cascaded: + self.cascade1( + table=start.full_table_name, + restriction=restriction, + direction=direction, + ) + self.cascaded = True return self._get_ft(end.full_table_name, with_restr=True) - return self._get_ft(self.parent.full_table_name, with_restr=True) - def join( self, restriction: str = None, reverse_order: bool = False ) -> dj.expression.QueryExpression: diff --git a/src/spyglass/utils/dj_graph_abs.py b/src/spyglass/utils/dj_graph_abs.py index dd3e06f7b..abb5f4d42 100644 --- a/src/spyglass/utils/dj_graph_abs.py +++ b/src/spyglass/utils/dj_graph_abs.py @@ -70,7 +70,9 @@ def _get_attr_dict( } def _get_attr_map_btwn(self, child, parent): - """Get attribute map between child and parent.""" + """Get attribute map between child and parent. + + Currently used for debugging.""" child = child if isinstance(child, str) else child.full_table_name parent = parent if isinstance(parent, str) else parent.full_table_name @@ -88,14 +90,27 @@ def _get_attr_map_btwn(self, child, parent): attr_map = self.graph[child][path[1]]["attr_map"] except KeyError: attr_map = self.graph[path[1]][child]["attr_map"] - return attr_map if not reverse else {v: k for k, v in attr_map.items()} + + return self._parse_attr_map(attr_map, reverse=reverse) + + def _parse_attr_map(self, attr_map, reverse=False): + """Parse attribute map. Remove self-references.""" + if not attr_map: + return {} + if reverse: + return {v: k for k, v in attr_map.items() if k != v} + return {k: v for k, v in attr_map.items() if k != v} def _get_restr(self, table): - """Get restriction from graph node.""" + """Get restriction from graph node. + + Defaults to False if no restriction is set so that it doesn't appear + in attrs like `all_ft`. 
+ """ table = table if isinstance(table, str) else table.full_table_name return self._get_node(table).get("restr", "False") - def _set_restr(self, table, restriction): + def _set_restr(self, table, restriction, merge_existing=True): """Add restriction to graph node. If one exists, merge with new.""" ft = self._get_ft(table) restriction = ( # Convert to condition if list or dict @@ -103,7 +118,8 @@ def _set_restr(self, table, restriction): if not isinstance(restriction, str) else restriction ) - if existing := self._get_restr(table): + existing = self._get_restr(table) + if merge_existing and existing != "False": # False is default if existing == restriction: return join = ft & [existing, restriction] @@ -163,49 +179,28 @@ def _bridge_restr( self, table1: str, table2: str, - restr1: str, + restr: str, attr_map: dict = None, primary: bool = True, + **kwargs, ): - ft1 = self._get_ft(table1) - ft2 = self._get_ft(table2) - if attr_map is None: - attr_map = self._get_attr_map_btwn(table1, table2) + """Given two tables and a restriction, return restriction for table2. - if table1 in ft2.children(): - table1, table2 = table2, table1 - ft1, ft2 = ft2, ft1 - - if primary: - join = ft1.proj(**attr_map) * ft2 - else: - join = ft1.proj(..., **attr_map) * ft2 - - return unique_dicts(join.fetch(*ft2.primary_key, as_dict=True)) - - def _child_to_parent( - self, - child, - parent, - restriction, - attr_map=None, - primary=True, - **kwargs, - ) -> List[Dict[str, str]]: - """Given a child, child's restr, and parent, get parent's restr. + Similar to ((table1 & restr) * table2).fetch(*table2.primary_key) + but with the ability to resolve aliases across tables. One table should + be the parent of the other. Replaces previous _child_to_parent. Parameters ---------- - child : str - child table name - parent : str - parent table name - restriction : str - restriction to apply to child + table1 : str + Table name. Restriction always applied to this table. + table2 : str + Table name. Restriction pulled from this table. + restr : str + Restriction to apply to table1. attr_map : dict, optional - dictionary mapping aliases across parend/child, as pulled from - DataJoint-assembled graph. Default None. Func will flip this dict - to convert from child to parent fields. + dictionary mapping aliases across tables, as pulled from + DataJoint-assembled graph. Default None. primary : bool, optional Is parent in child's primary key? Default True. Also derived from DataJoint-assembled graph. If True, project only primary key fields @@ -214,28 +209,38 @@ def _child_to_parent( Returns ------- List[Dict[str, str]] - List of dicts containing primary key fields for restricted parent - table. + List of dicts containing primary key fields for restricted table2. 
""" + ft1 = self._get_ft(table1) & restr + ft2 = self._get_ft(table2) - # Need to flip attr_map to respect parent's fields - attr_reverse = ( - {v: k for k, v in attr_map.items() if k != v} if attr_map else {} - ) - child_ft = self._get_ft(child) - parent_ft = self._get_ft(parent).proj() - restr = restriction or self._get_restr(child_ft) or True - restr_child = child_ft & restr + if len(ft1) == 0: + logger.warning(f"Empty table {table1} with restriction {restr}") + return ["False"] + + attr_map = self._parse_attr_map(attr_map) if attr_map else {} + + if table1 in ft2.parents(): + flip = False + child, parent = ft2, ft1 + else: # table2 in ft1.children() + flip = True + child, parent = ft1, ft2 + + if primary: + join = (parent.proj(**attr_map) * child).proj() + else: + join = (parent.proj(..., **attr_map) * child).proj() - if primary: # Project only primary key fields to avoid collisions - join = restr_child.proj(**attr_reverse) * parent_ft - else: # Include all fields - join = restr_child.proj(..., **attr_reverse) * parent_ft + if set(ft2.primary_key).isdisjoint(set(join.heading.names)): + join = join.proj(**self._parse_attr_map(attr_map, reverse=True)) - ret = unique_dicts(join.fetch(*parent_ft.primary_key, as_dict=True)) + ret = unique_dicts(join.fetch(*ft2.primary_key, as_dict=True)) - if len(ret) == len(parent_ft): - self._log_truncate(f"NULL restr {parent}") + if self.verbose and len(ft2) and len(ret) == len(ft2): + self._log_truncate(f"NULL restr {table2}") + if self.verbose and attr_map: + self._log_truncate(f"attr_map {table1} -> {table2}: {flip}") return ret @@ -249,32 +254,35 @@ def cascade1(self, table, restriction, direction="up"): restriction : str restriction to apply """ + self._set_restr(table, restriction) self.visited.add(table) - next_nodes = ( - self.graph.parents(table) - if direction == "up" - else self.graph.children(table) + next_func = ( + self.graph.parents if direction == "up" else self.graph.children ) - for next_table, data in next_nodes.items(): - if next_table in self.visited or next_table in self.no_visit: - continue - + for next_table, data in next_func(table).items(): if next_table.isnumeric(): - next_table, data = self.graph.parents(next_table).popitem() + next_table, data = next_func(next_table).popitem() + + if ( + next_table in self.visited + or next_table in self.no_visit + or table == next_table + ): + continue - parent_restr = self._child_to_parent( - child=table, - parent=next_table, - restriction=restriction, + next_restr = self._bridge_restr( + table1=table, + table2=next_table, + restr=restriction, **data, ) self.cascade1( table=next_table, - restriction=parent_restr, + restriction=next_restr, direction=direction, ) diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index f882db126..50e023443 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -315,6 +315,7 @@ def _merge_chains(self) -> OrderedDict[str, List[dj.FreeTable]]: # that the merge table with the longest chain is the most downstream. # A more sophisticated approach would order by length from self to # each merge part independently, but this is a good first approximation. + return OrderedDict( sorted( merge_chains.items(), key=lambda x: x[1].max_len, reverse=True @@ -377,20 +378,20 @@ def delete_downstream_merge( Passed to datajoint.table.Table.delete. 
""" if reload_cache: - del self._merge_tables - del self._merge_chains + for attr in ["_merge_tables", "_merge_chains"]: + _ = self.__dict__.pop(attr, None) restriction = restriction or self.restriction or True merge_join_dict = {} for name, chain in self._merge_chains.items(): - join = chain.join(restriction) + join = chain.cascade(restriction, direction="down") if join: merge_join_dict[name] = join if not merge_join_dict and not disable_warning: logger.warning( - f"No merge deletes found w/ {self.table_name} & " + f"No merge deletes found w/ {self.camel_name} & " + f"{restriction}.\n\tIf this is unexpected, try importing " + " Merge table(s) and running with `reload_cache`." ) diff --git a/tests/utils/test_mixin.py b/tests/utils/test_mixin.py index 7918d1b89..5cfa9e478 100644 --- a/tests/utils/test_mixin.py +++ b/tests/utils/test_mixin.py @@ -43,6 +43,17 @@ def test_merge_detect(Nwbfile, pos_merge_tables): ), "Merges not detected by mixin." +def test_merge_chain_join(Nwbfile, pos_merge_tables): + """Test that the mixin can join merge chains.""" + all_chains = [ + chains.cascade(True, direction="down") + for chains in Nwbfile._merge_chains.values() + ] + end_len = [len(chain[0]) for chain in all_chains if chain] + + assert end_len == [1, 1, 2], "Merge chains not joined correctly." + + def test_get_chain(Nwbfile, pos_merge_tables): """Test that the mixin can get the chain of a merge.""" lin_parts = Nwbfile._get_chain("linear").part_names @@ -53,7 +64,28 @@ def test_get_chain(Nwbfile, pos_merge_tables): @pytest.mark.skipif(not VERBOSE, reason="No logging to test when quiet-spy.") def test_ddm_warning(Nwbfile, caplog): """Test that the mixin warns on empty delete_downstream_merge.""" - (Nwbfile & "nwb_file_name LIKE 'BadName'").delete_downstream_merge( + (Nwbfile.file_like("BadName")).delete_downstream_merge( reload_cache=True, disable_warnings=False ) assert "No merge deletes found" in caplog.text, "No warning issued." + + +def test_ddm_dry_run(Nwbfile, common, sgp, pos_merge_tables): + """Test that the mixin can dry run delete_downstream_merge.""" + param_field = "trodes_pos_params_name" + trodes_params = sgp.v1.TrodesPosParams() + rft = next( + iter( + (trodes_params & f'{param_field} LIKE "%ups%"').ddm( + reload_cache=True, dry_run=True, return_parts=True + ) + ) + )[0] + assert len(rft) == 1, "ddm did not return restricted table." + + table_name = pos_merge_tables[0].parts()[-1] + assert table_name == rft.full_table_name, "ddm didn't grab right table." + + assert ( + rft.fetch1(param_field) == "single_led_upsampled" + ), "ddm didn't grab right row." 
From 319a6047be81be735bcb301f0167c86717625ddc Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Sun, 28 Apr 2024 13:17:51 -0500 Subject: [PATCH 06/17] WIP: Operator for 'find upstream key' --- src/spyglass/spikesorting/imported.py | 3 +- src/spyglass/utils/dj_chains.py | 99 ++------- src/spyglass/utils/dj_graph.py | 294 +++++++++++++++++--------- src/spyglass/utils/dj_graph_abs.py | 32 +-- src/spyglass/utils/dj_helper_fn.py | 28 ++- src/spyglass/utils/dj_merge_tables.py | 76 +++---- src/spyglass/utils/dj_mixin.py | 119 +++-------- src/spyglass/utils/nwb_helper_fn.py | 2 +- tests/utils/test_chains.py | 8 +- tests/utils/test_mixin.py | 2 +- 10 files changed, 323 insertions(+), 340 deletions(-) diff --git a/src/spyglass/spikesorting/imported.py b/src/spyglass/spikesorting/imported.py index ca1bdc9d0..ccb24edd2 100644 --- a/src/spyglass/spikesorting/imported.py +++ b/src/spyglass/spikesorting/imported.py @@ -51,9 +51,8 @@ def make(self, key): self.insert1(key, skip_duplicates=True) - part_name = SpikeSortingOutput._part_name(self.table_name) SpikeSortingOutput._merge_insert( - [orig_key], part_name=part_name, skip_duplicates=True + [orig_key], part_name=self.camel_name, skip_duplicates=True ) @classmethod diff --git a/src/spyglass/utils/dj_chains.py b/src/spyglass/utils/dj_chains.py index d1a9b1123..18a3d11ac 100644 --- a/src/spyglass/utils/dj_chains.py +++ b/src/spyglass/utils/dj_chains.py @@ -4,35 +4,18 @@ import datajoint as dj import networkx as nx -from datajoint.expression import QueryExpression from datajoint.table import Table from datajoint.utils import to_camel_case -from spyglass.utils.dj_graph_abs import AbstractGraph, _fuzzy_get -from spyglass.utils.dj_merge_tables import RESERVED_PRIMARY_KEY as MERGE_PK +from spyglass.utils.dj_graph_abs import AbstractGraph +from spyglass.utils.dj_helper_fn import PERIPHERAL_TABLES, fuzzy_get from spyglass.utils.dj_merge_tables import is_merge_table -from spyglass.utils.logging import logger - -# Tables that should be excluded from the undirected graph when finding paths -# to maintain valid joins. -PERIPHERAL_TABLES = [ - "`common_interval`.`interval_list`", - "`common_nwbfile`.`__analysis_nwbfile_kachery`", - "`common_nwbfile`.`__nwbfile_kachery`", - "`common_nwbfile`.`analysis_nwbfile_kachery_selection`", - "`common_nwbfile`.`analysis_nwbfile_kachery`", - "`common_nwbfile`.`analysis_nwbfile`", - "`common_nwbfile`.`kachery_channel`", - "`common_nwbfile`.`nwbfile_kachery_selection`", - "`common_nwbfile`.`nwbfile_kachery`", - "`common_nwbfile`.`nwbfile`", -] class TableChains: """Class for representing chains from parent to Merge table via parts. - Functions as a plural version of TableChain, allowing a single `join` + Functions as a plural version of TableChain, allowing a single `cascade` call across all chains from parent -> Merge table. Attributes @@ -64,8 +47,8 @@ class TableChains: Return number of chains with links. __getitem__(index: Union[int, str]) Return TableChain object at index, or use substring of table name. - join(restriction: str = None) - Return list of joins for each chain in self.chains. + cascade(restriction: str = None) + Return list of cascade for each chain in self.chains. 
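+
+    Example
+    -------
+    A hypothetical sketch; ``MyTable`` and ``MyMergeOutput`` are placeholder
+    names, not real Spyglass tables:
+
+    >>> chains = TableChains(MyTable(), MyMergeOutput())
+    >>> results = chains.cascade("my_attr = 'val'")  # doctest: +SKIP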
""" def __init__(self, parent, child, connection=None): @@ -90,18 +73,7 @@ def max_len(self): def __getitem__(self, index: Union[int, str]): """Return FreeTable object at index.""" - return _fuzzy_get(index, self.part_names, self.chains) - - def join( - self, restriction=None, reverse_order=False - ) -> List[QueryExpression]: - """Return list of joins for each chain in self.chains.""" - restriction = restriction or self.parent.restriction or True - joins = [] - for chain in self.chains: - if joined := chain.join(restriction, reverse_order=reverse_order): - joins.append(joined) - return joins + return fuzzy_get(index, self.part_names, self.chains) def cascade(self, restriction: str = None, direction: str = "down"): """Return list of cascades for each chain in self.chains.""" @@ -160,8 +132,10 @@ class TableChain(AbstractGraph): True, uses directed graph. If False, uses undirected graph. Undirected excludes PERIPHERAL_TABLES like interval_list, nwbfile, etc. to maintain valid joins. - join(restriction: str = None) - Return join of tables in chain with restriction applied to parent. + cascade(restriction: str = None, direction: str = "up") + Given a restriction at the beginning, return a restricted FreeTable + object at the end of the chain. If direction is 'up', start at the child + and move up to the parent. If direction is 'down', start at the parent. """ def __init__( @@ -203,10 +177,10 @@ def __len__(self): """Return number of tables in chain.""" if not self.has_link: return 0 - return len(self.names) + return len(self.path) def __getitem__(self, index: Union[int, str]): - return _fuzzy_get(index, self.names, self.objects) + return fuzzy_get(index, self.path, self.all_ft) @property def has_link(self) -> bool: @@ -280,7 +254,7 @@ def path(self) -> list: return path @cached_property - def objects(self) -> List[dj.FreeTable]: + def all_ft(self) -> List[dj.FreeTable]: """Return list of FreeTable objects for each table in chain. Unused. Preserved for future debugging. @@ -302,52 +276,7 @@ def cascade(self, restriction: str = None, direction: str = "up"): table=start.full_table_name, restriction=restriction, direction=direction, + replace=True, ) self.cascaded = True return self._get_ft(end.full_table_name, with_restr=True) - - def join( - self, restriction: str = None, reverse_order: bool = False - ) -> dj.expression.QueryExpression: - if not self.has_link: - return None - - direction = "down" if reverse_order else "up" - return self.cascade(restriction, direction) - - def old_join( - self, restriction: str = None, reverse_order: bool = False - ) -> dj.expression.QueryExpression: - """Return join of tables in chain with restriction applied to parent. - - Parameters - ---------- - restriction : str, optional - Restriction to apply to first table in the order. - Defaults to self.parent.restriction. - reverse_order : bool, optional - If True, join tables in reverse order. Defaults to False. 
- """ - if not self.has_link: - return None - - restriction = restriction or self.parent.restriction or True - path = ( - OrderedDict(reversed(self.path.items())) - if reverse_order - else self.path - ).copy() - - _, first_val = path.popitem(last=False) - join = first_val["free_table"] & restriction - for i, val in enumerate(path.values()): - attr_map, free_table = val["attr_map"], val["free_table"] - try: - join = (join.proj() * free_table).proj(**attr_map) - except dj.DataJointError as e: - attribute = str(e).split("attribute ")[-1] - logger.error( - f"{str(self)} at {free_table.table_name} with {attribute}" - ) - return None - return join diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index d3d8769a1..461843018 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -3,14 +3,16 @@ NOTE: read `ft` as FreeTable and `restr` as restriction. """ -from typing import Dict, List +from typing import Dict, List, Set, Tuple, Union +from datajoint.condition import make_condition from datajoint.table import Table from tqdm import tqdm from spyglass.common import AnalysisNwbfile +from spyglass.utils import logger from spyglass.utils.dj_graph_abs import AbstractGraph -from spyglass.utils.dj_helper_fn import unique_dicts +from spyglass.utils.dj_helper_fn import PERIPHERAL_TABLES, unique_dicts class RestrGraph(AbstractGraph): @@ -45,13 +47,10 @@ def __init__( """ super().__init__(seed_table, verbose=verbose) - self.leaves = set() self.analysis_pk = AnalysisNwbfile().primary_key - if table_name and restriction: - self.add_leaf(table_name, restriction) - if leaves: - self.add_leaves(leaves, show_progress=verbose) + self.add_leaf(table_name=table_name, restriction=restriction) + self.add_leaves(leaves) if cascade: self.cascade() @@ -59,51 +58,95 @@ def __init__( def __repr__(self): l_str = ",\n\t".join(self.leaves) + "\n" if self.leaves else "" processed = "Cascaded" if self.cascaded else "Uncascaded" - return f"{processed} RestrictionGraph(\n\t{l_str})" + return f"{processed} {self.__class__.__name__}(\n\t{l_str})" @property def leaf_ft(self): """Get restricted FreeTables from graph leaves.""" return [self._get_ft(table, with_restr=True) for table in self.leaves] - def old_cascade1(self, table, restriction, direction="up"): - """Cascade a restriction up the graph, recursively on parents. + def add_leaf( + self, table_name=None, restriction=True, cascade=False + ) -> None: + """Add leaf to graph and cascade if requested. Parameters ---------- - table : str - table name - restriction : str - restriction to apply + table_name : str, optional + table name of leaf. Default None, do nothing. + restriction : str, optional + restriction to apply to leaf. Default True, no restriction. + cascade : bool, optional + Whether to cascade the restrictions up the graph. Default False. 
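+
+        Example (table name and restriction are hypothetical placeholders):
+
+        >>> rg.add_leaf("`my_schema`.`my_table`", "my_attr = 'val'",
+        ...             cascade=True)  # doctest: +SKIP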
""" - self._set_restr(table, restriction) - self.visited.add(table) + if not table_name: + return - next_nodes = ( - self.graph.parents(table) - if direction == "up" - else self.graph.children(table) - ) + self.cascaded = False - for parent, data in next_nodes.items(): - if parent in self.visited: - continue + new_ancestors = set(self._get_ft(table_name).ancestors()) - if parent.isnumeric(): - parent, data = self.graph.parents(parent).popitem() + # TODO: adjust to permit cascade down, conditional to descendants + self.to_visit |= new_ancestors # Add to total ancestors + self.visited -= new_ancestors # Remove from visited to revisit - parent_restr = self._child_to_parent( - child=table, - parent=parent, - restriction=restriction, - **data, - ) + self.leaves.add(table_name) + self._set_restr(table_name, restriction) # Redundant if cascaded - self.cascade1( - table=parent, - restriction=parent_restr, - direction=direction, + if cascade: + self.cascade1(table_name, restriction) + self.cascade_files() + self.cascaded = True + + def _process_leaves(self, leaves=None, default_restriction=True): + """Process leaves to ensure they are unique and have required keys.""" + if not leaves: + return [] + if not isinstance(leaves, list): + leaves = [leaves] + if all(isinstance(leaf, str) for leaf in leaves): + leaves = [ + {"table_name": leaf, "restriction": default_restriction} + for leaf in leaves + ] + if all(isinstance(leaf, dict) for leaf in leaves) and not all( + leaf.get("table_name") for leaf in leaves + ): + raise ValueError(f"All leaves must have table_name: {leaves}") + + return unique_dicts(leaves) + + def add_leaves( + self, + leaves: Union[str, List, List[Dict[str, str]]] = None, + default_restriction: str = None, + cascade=False, + ) -> None: + """Add leaves to graph and cascade if requested. + + Parameters + ---------- + leaves : Union[str, List, List[Dict[str, str]]], optional + Table names of leaves, either as a list of strings or a list of + dictionaries with keys table_name and restriction. One entry per + leaf node. Default None, do nothing. + default_restriction : str, optional + Default restriction to apply to each leaf. Default True, no + restriction. Only used if leaf missing restriction. + cascade : bool, optional + Whether to cascade the restrictions up the graph. Default False + """ + leaves = self._process_leaves( + leaves=leaves, default_restriction=default_restriction + ) + for leaf in leaves: + self.add_leaf( + leaf.get("table_name"), + leaf.get("restriction"), + cascade=False, ) + if cascade: + self.cascade() def cascade(self, show_progress=None) -> None: """Cascade all restrictions up the graph. @@ -115,7 +158,9 @@ def cascade(self, show_progress=None) -> None: """ if self.cascaded: return + to_visit = self.leaves - self.visited + for table in tqdm( to_visit, desc="RestrGraph: cascading restrictions", @@ -133,68 +178,6 @@ def cascade(self, show_progress=None) -> None: self.cascade_files() self.cascaded = True - def add_leaf(self, table_name, restriction, cascade=False) -> None: - """Add leaf to graph and cascade if requested. 
- - Parameters - ---------- - table_name : str - table name of leaf - restriction : str - restriction to apply to leaf - """ - self.cascaded = False - - new_ancestors = set(self._get_ft(table_name).ancestors()) - self.to_visit |= new_ancestors # Add to total ancestors - self.visited -= new_ancestors # Remove from visited to revisit - - self.leaves.add(table_name) - self._set_restr(table_name, restriction) # Redundant if cascaded - - if cascade: - self.cascade1(table_name, restriction) - self.cascade_files() - self.cascaded = True - - def add_leaves( - self, leaves: List[Dict[str, str]], cascade=False, show_progress=None - ) -> None: - """Add leaves to graph and cascade if requested. - - Parameters - ---------- - leaves : List[Dict[str, str]] - list of dictionaries containing table_name and restriction - cascade : bool, optional - Whether to cascade the restrictions up the graph. Default False - show_progress : bool, optional - Show tqdm progress bar. Default to verbose setting. - """ - - if not leaves: - return - if not isinstance(leaves, list): - leaves = [leaves] - leaves = unique_dicts(leaves) - for leaf in tqdm( - leaves, - desc="RestrGraph: adding leaves", - total=len(leaves), - disable=not (show_progress or self.verbose), - ): - if not ( - (table_name := leaf.get("table_name")) - and (restriction := leaf.get("restriction")) - ): - raise ValueError( - f"Leaf must have table_name and restriction: {leaf}" - ) - self.add_leaf(table_name, restriction, cascade=False) - if cascade: - self.cascade() - self.cascade_files() - # ----------------------------- File Handling ----------------------------- def _get_files(self, table): @@ -234,3 +217,118 @@ def file_paths(self) -> List[str]: ) if file is not None ] + + +class FindKeyGraph(RestrGraph): + def __init__( + self, + seed_table: Table, + table_name: str = None, + restriction: str = None, + leaves: List[Dict[str, str]] = None, + direction: str = "up", + cascade: bool = False, + verbose: bool = False, + **kwargs, + ): + """Graph to restrict leaf by upstream keys. + + Parameters + ---------- + seed_table : Table + Table to use to establish connection and graph + table_name : str, optional + Table name of single leaf, default seed_table.full_table_name + restriction : str, optional + Restriction to apply to leaf. default None, True + verbose : bool, optional + Whether to print verbose output. 
Default False + """ + + super().__init__(seed_table, verbose=verbose) + + if restriction and table_name: + self._set_find_restr(table_name, restriction) + self.add_leaf(table_name, True, cascade=False) + self.add_leaves(leaves, default_restriction=restriction, cascade=False) + + all_nodes = set([n for n in self.graph.nodes if not n.isnumeric()]) + self.no_visit.update(all_nodes - self.to_visit) # Skip non-ancestors + self.no_visit.update(PERIPHERAL_TABLES) + + if cascade and restriction: + self.cascade() + self.cascaded = True + + def _set_find_restr(self, table_name, restriction): + """Set restr to look for from leaf node.""" + if isinstance(restriction, dict): + logger.warning("key_from_upstream: DICT unreliable, use STR.") + restr_attrs = set(restriction.keys()) + else: + restr_attrs = set() # modified by make_condition + table_ft = self._get_ft(table_name) + _ = make_condition(table_ft, restriction, restr_attrs) + self._set_node(table_name, "find_restr", restriction) + self._set_node(table_name, "restr_attrs", restr_attrs) + + def _get_find_restr(self, table) -> Tuple[str, Set[str]]: + """Get restr and restr_attrs from leaf node.""" + node = self._get_node(table) + return node.get("find_restr", False), node.get("restr_attrs", set()) + + def add_leaves(self, leaves=None, default_restriction=None, cascade=False): + leaves = self._process_leaves( + leaves=leaves, default_restriction=default_restriction + ) + for leaf in leaves: # Multiple leaves + self._set_find_restr(**leaf) + self.add_leaf(leaf["table_name"], True, cascade=False) + + def cascade(self, show_progress=None) -> None: + for table in self.leaves: + restriction, restr_attrs = self._get_find_restr(table) + self.cascade1_search( + table=table, + restriction=restriction, + restr_attrs=restr_attrs, + replace=True, + ) + + # TODO: Add restrict_from_upstream method to Merge class, leaf per part + def cascade1_search( + self, + table: str, + restriction: str, + restr_attrs: Set[str] = None, + direction: str = "up", + replace: bool = True, + ): + next_func = ( + self.graph.parents if direction == "up" else self.graph.children + ) + + for next_table, data in next_func(table).items(): + if next_table.isnumeric(): + next_table, data = next_func(next_table).popitem() + + if next_table in self.no_visit or table == next_table: + continue + + next_ft = self._get_ft(next_table) + if restr_attrs.issubset(set(next_ft.heading.names)): + self.cascade1( + table=next_table, + restriction=restriction, + direction="down" if direction == "up" else "up", + replace=replace, + ) + return + + self.cascade1_search( + table=next_table, + restriction=restriction, + restr_attrs=restr_attrs, + direction=direction, + replace=replace, + ) diff --git a/src/spyglass/utils/dj_graph_abs.py b/src/spyglass/utils/dj_graph_abs.py index abb5f4d42..146c7500c 100644 --- a/src/spyglass/utils/dj_graph_abs.py +++ b/src/spyglass/utils/dj_graph_abs.py @@ -4,7 +4,9 @@ from datajoint import FreeTable, logger from datajoint.condition import make_condition from datajoint.table import Table -from networkx import NetworkXNoPath, NodeNotFound, shortest_path +from networkx import NetworkXNoPath, shortest_path + +from spyglass.utils.dj_helper_fn import unique_dicts class AbstractGraph(ABC): @@ -23,6 +25,7 @@ def __init__(self, seed_table: Table, verbose: bool = False, **kwargs): self.graph.load() self.verbose = verbose + self.leaves = set() self.visited = set() self.to_visit = set() self.no_visit = set() @@ -110,7 +113,7 @@ def _get_restr(self, table): table = table if isinstance(table, 
str) else table.full_table_name return self._get_node(table).get("restr", "False") - def _set_restr(self, table, restriction, merge_existing=True): + def _set_restr(self, table, restriction, replace=False): """Add restriction to graph node. If one exists, merge with new.""" ft = self._get_ft(table) restriction = ( # Convert to condition if list or dict @@ -119,7 +122,7 @@ def _set_restr(self, table, restriction, merge_existing=True): else restriction ) existing = self._get_restr(table) - if merge_existing and existing != "False": # False is default + if not replace and existing != "False": # False is default if existing == restriction: return join = ft & [existing, restriction] @@ -215,10 +218,9 @@ def _bridge_restr( ft2 = self._get_ft(table2) if len(ft1) == 0: - logger.warning(f"Empty table {table1} with restriction {restr}") return ["False"] - attr_map = self._parse_attr_map(attr_map) if attr_map else {} + attr_map = self._parse_attr_map(attr_map) if table1 in ft2.parents(): flip = False @@ -244,7 +246,7 @@ def _bridge_restr( return ret - def cascade1(self, table, restriction, direction="up"): + def cascade1(self, table, restriction, direction="up", replace=False): """Cascade a restriction up the graph, recursively on parents. Parameters @@ -255,7 +257,7 @@ def cascade1(self, table, restriction, direction="up"): restriction to apply """ - self._set_restr(table, restriction) + self._set_restr(table, restriction, replace=replace) self.visited.add(table) next_func = ( @@ -284,19 +286,5 @@ def cascade1(self, table, restriction, direction="up"): table=next_table, restriction=next_restr, direction=direction, + replace=replace, ) - - -def unique_dicts(list_of_dict): - """Remove duplicate dictionaries from a list.""" - return [dict(t) for t in {tuple(d.items()) for d in list_of_dict}] - - -def _fuzzy_get(index: Union[int, str], names: List[str], sources: List[str]): - """Given lists of items/names, return item at index or by substring.""" - if isinstance(index, int): - return sources[index] - for i, part in enumerate(names): - if index in part: - return sources[i] - return None diff --git a/src/spyglass/utils/dj_helper_fn.py b/src/spyglass/utils/dj_helper_fn.py index 7af1fb2b4..89b1950cd 100644 --- a/src/spyglass/utils/dj_helper_fn.py +++ b/src/spyglass/utils/dj_helper_fn.py @@ -2,16 +2,40 @@ import inspect import os -from typing import Type +from typing import List, Type, Union import datajoint as dj import numpy as np from datajoint.user_tables import UserTable -from spyglass.utils.dj_chains import PERIPHERAL_TABLES from spyglass.utils.logging import logger from spyglass.utils.nwb_helper_fn import get_nwb_file +# Tables that should be excluded from the undirected graph when finding paths +# for TableChain objects and searching for an upstream key. 
+PERIPHERAL_TABLES = [ + "`common_interval`.`interval_list`", + "`common_nwbfile`.`__analysis_nwbfile_kachery`", + "`common_nwbfile`.`__nwbfile_kachery`", + "`common_nwbfile`.`analysis_nwbfile_kachery_selection`", + "`common_nwbfile`.`analysis_nwbfile_kachery`", + "`common_nwbfile`.`analysis_nwbfile`", + "`common_nwbfile`.`kachery_channel`", + "`common_nwbfile`.`nwbfile_kachery_selection`", + "`common_nwbfile`.`nwbfile_kachery`", + "`common_nwbfile`.`nwbfile`", +] + + +def fuzzy_get(index: Union[int, str], names: List[str], sources: List[str]): + """Given lists of items/names, return item at index or by substring.""" + if isinstance(index, int): + return sources[index] + for i, part in enumerate(names): + if index in part: + return sources[i] + return None + def unique_dicts(list_of_dict): """Remove duplicate dictionaries from a list.""" diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py index 2b8aab5ef..621ab4541 100644 --- a/src/spyglass/utils/dj_merge_tables.py +++ b/src/spyglass/utils/dj_merge_tables.py @@ -1,4 +1,3 @@ -import re from contextlib import nullcontext from inspect import getmodule from itertools import chain as iter_chain @@ -25,23 +24,12 @@ def is_merge_table(table): - """Return True if table definition matches the default Merge table. - - Regex removes comments and blank lines before comparison. - """ + """Return True if table fields exactly match Merge table.""" if not isinstance(table, dj.Table): return False - if isinstance(table, dj.FreeTable): - fields, pk = table.heading.names, table.primary_key - return fields == [ - RESERVED_PRIMARY_KEY, - RESERVED_SECONDARY_KEY, - ] and pk == [RESERVED_PRIMARY_KEY] - return MERGE_DEFINITION == re.sub( - r"\n\s*\n", - "\n", - re.sub(r"#.*\n", "\n", getattr(table, "definition", "")), - ) + return table.primary_key == [ + RESERVED_PRIMARY_KEY + ] and table.heading.secondary_attributes == [RESERVED_SECONDARY_KEY] class Merge(dj.Manual): @@ -74,14 +62,8 @@ def __init__(self): ) self._source_class_dict = {} - def _remove_comments(self, definition): - """Use regular expressions to remove comments and blank lines""" - return re.sub( # First remove comments, then blank lines - r"\n\s*\n", "\n", re.sub(r"#.*\n", "\n", definition) - ) - @staticmethod - def _part_name(part=None): + def _part_name(part): """Return the CamelCase name of a part table""" if not isinstance(part, str): part = part.table_name @@ -141,9 +123,6 @@ def _merge_restrict_parts( cls._ensure_dependencies_loaded() - if not restriction: - restriction = True - # Normalize restriction to sql string restr_str = make_condition(cls(), restriction, set()) @@ -387,8 +366,7 @@ def _ensure_dependencies_loaded(cls) -> None: Otherwise parts returns none """ - if not dj.conn.connection.dependencies._loaded: - dj.conn.connection.dependencies.load() + dj.conn.connection.dependencies.load() def insert(self, rows: list, **kwargs): """Merges table specific insert, ensuring data exists in part parents. 
@@ -783,7 +761,7 @@ def merge_fetch(self, restriction: str = True, *attrs, **kwargs) -> list: "No merge_fetch results.\n\t" + "If not restricting, try: `M.merge_fetch(True,'attr')\n\t" + "If restricting by source, use dict: " - + "`M.merge_fetch({'source':'X'})" + + "`M.merge_fetch({'source':'X'}" ) return results[0] if len(results) == 1 else results @@ -821,6 +799,33 @@ def super_delete(self, warn=True, *args, **kwargs): self._log_use(start=time(), super_delete=True) super().delete(*args, **kwargs) + # ------------------------------ Restrict by ------------------------------ + + # TODO: TEST THIS + # TODO: Allow (Table & restriction).merge_xxx() syntax + def restrict_from_upstream(self, restriction=True, **kwargs): + """Restrict self based on upstream table.""" + from spyglass.utils.dj_graph import FindKeyGraph + + if restriction is True: + return self._merge_repr() + + graph = FindKeyGraph( + seed_table=self, + restriction=restriction, + leaves=self.parts(), + cascade=True, + verbose=False, + **kwargs, + ) + + self_restrict = [ + leaf.fetch(RESERVED_PRIMARY_KEY, as_dict=True) + for leaf in graph.leaf_ft + ] + + return self & self_restrict + _Merge = Merge @@ -830,10 +835,6 @@ def super_delete(self, warn=True, *args, **kwargs): def delete_downstream_merge( table: dj.Table, - restriction: str = None, - dry_run=True, - recurse_level=2, - disable_warning=False, **kwargs, ) -> list: """Given a table/restriction, id or delete relevant downstream merge entries @@ -858,12 +859,15 @@ def delete_downstream_merge( List[Tuple[dj.Table, dj.Table]] Entries in merge/part tables downstream of table input. """ + logger.warning( + "DEPRECATED: This function will be removed in `0.6`. " + + "Use AnyTable().delete_downstream_merge() instead." + ) + from spyglass.utils.dj_mixin import SpyglassMixin if not isinstance(table, SpyglassMixin): raise ValueError("Input must be a Spyglass Table.") table = table if isinstance(table, dj.Table) else table() - return table.delete_downstream_merge( - restriction=restriction, dry_run=dry_run, **kwargs - ) + return table.delete_downstream_merge(**kwargs) diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index 50e023443..d988a9696 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -230,7 +230,7 @@ def fetch_pynapple(self, *attrs, **kwargs): # ------------------------ delete_downstream_merge ------------------------ - def import_merge_tables(self): + def _import_merge_tables(self): """Import all merge tables downstream of self.""" from spyglass.decoding.decoding_merge import DecodingOutput # noqa F401 from spyglass.lfp.lfp_merge import LFPOutput # noqa F401 @@ -280,7 +280,7 @@ def search_descendants(parent): _ = search_descendants(self) except NetworkXError: try: # Attempt to import missing table - self.import_merge_tables() + self._import_merge_tables() _ = search_descendants(self) except NetworkXError as e: table_name = "".join(e.args[0].split("`")[1:4]) @@ -385,8 +385,7 @@ def delete_downstream_merge( merge_join_dict = {} for name, chain in self._merge_chains.items(): - join = chain.cascade(restriction, direction="down") - if join: + if join := chain.cascade(restriction, direction="down"): merge_join_dict[name] = join if not merge_join_dict and not disable_warning: @@ -766,88 +765,30 @@ def fetch1(self, *args, log_fetch=True, **kwargs): # ------------------------------ Restrict by ------------------------------ - def restrict_from_upstream(self, key, **kwargs): - """Recursive function to restrict a table based on 
secondary keys of upstream tables""" - return restrict_from_upstream(self, key, **kwargs) - - -def restrict_from_upstream(table, key, max_recursion=3): - """Recursive function to restrict a table based on secondary keys of upstream tables""" - print(f"table: {table.full_table_name}, key: {key}") - # Tables not to recurse through because too big or central - blacklist = [ - "`common_nwbfile`.`analysis_nwbfile`", - ] - - # Case: MERGE table - if (table := table & key) and max_recursion: - if isinstance(table, Merge): - parts = table.parts(as_objects=True) - restricted_parts = [ - restrict_from_upstream(part, key, max_recursion - 1) - for part in parts - ] - # only keep entries from parts that got restricted - restricted_parts = [ - r_part.proj("merge_id") - for r_part, part in zip(restricted_parts, parts) - if ( - not len(r_part) == len(part) - or check_complete_restrict(r_part, key) - ) - ] - # return the merge of the restricted parts - merge_keys = [] - for r_part in restricted_parts: - merge_keys.extend(r_part.fetch("merge_id", as_dict=True)) - return table & merge_keys - - # Case: regular table - upstream_tables = table.parents(as_objects=True) - # prevent a loop where call Merge master table from part - upstream_tables = [ - parent - for parent in upstream_tables - if not ( - isinstance(parent, Merge) - and table.full_table_name in parent.parts() - ) - and (parent.full_table_name not in blacklist) - ] - for parent in upstream_tables: - print(parent.full_table_name) - print(len(parent)) - r_parent = restrict_from_upstream(parent, key, max_recursion - 1) - if len(r_parent) == len(parent): - continue # skip joins with uninformative tables - table = safe_join(table, r_parent) - if check_complete_restrict(table, key) or not table: - print(len(table)) - break - return table - - -def check_complete_restrict(table, key): - """Checks all keys in a restriction dictionary are used in a table""" - if all([k in table.heading.names for k in key.keys()]): - print("FOUND") - return all([k in table.heading.names for k in key.keys()]) - - -# Utility Function -def safe_join(table_1, table_2): - """enables joining of two tables with overlapping secondary keys""" - secondary_1 = [ - name - for name in table_1.heading.names - if name not in table_1.primary_key - ] - secondary_2 = [ - name - for name in table_2.heading.names - if name not in table_2.primary_key - ] - overlap = [name for name in secondary_1 if name in secondary_2] - return table_1 * table_2.proj( - *[name for name in table_2.heading.names if name not in overlap] - ) + def __mod__(self, restriction): + """Restriction by upstream operator e.g. ``q1 % q2``. + + Returns + ------- + QueryExpression + A restricted copy of the query expression using the nearest upstream + table for which the restriction is valid. 
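+
+        Example (hypothetical; assumes the attribute exists on some table
+        upstream of this one):
+
+        >>> MyTable() % "upstream_attr = 'val'"  # doctest: +SKIP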
+ """ + return self.restrict_from_upstream(restriction) + + def restrict_from_upstream(self, restriction=True, **kwargs): + """Restrict self based on upstream table.""" + from spyglass.utils.dj_graph import FindKeyGraph + + if restriction is True: + return self + + graph = FindKeyGraph( + seed_table=self, + table_name=self.full_table_name, + restriction=restriction, + cascade=True, + verbose=False, + **kwargs, + ) + return graph.leaf_ft[0] diff --git a/src/spyglass/utils/nwb_helper_fn.py b/src/spyglass/utils/nwb_helper_fn.py index 43eb70aa9..de7671b42 100644 --- a/src/spyglass/utils/nwb_helper_fn.py +++ b/src/spyglass/utils/nwb_helper_fn.py @@ -513,7 +513,7 @@ def get_nwb_copy_filename(nwb_file_name): def change_group_permissions( subject_ids, set_group_name, analysis_dir="/stelmo/nwb/analysis" ): - logger.warning("This function is deprecated and will be removed soon.") + logger.warning("DEPRECATED: This function will be removed in `0.6`.") # Change to directory with analysis nwb files os.chdir(analysis_dir) # Get nwb file directories with specified subject ids diff --git a/tests/utils/test_chains.py b/tests/utils/test_chains.py index 2ce8b8677..cb8bbccc4 100644 --- a/tests/utils/test_chains.py +++ b/tests/utils/test_chains.py @@ -44,24 +44,24 @@ def test_chain_str(chain): def test_chain_repr(chain): """Test that the repr of a TableChain object is as expected.""" repr_got = repr(chain) - repr_ext = "Chain: " + chain._link_symbol.join(chain.names) + repr_ext = "Chain: " + chain._link_symbol.join(chain.path) assert repr_got == repr_ext, "Unexpected repr of TableChain object." def test_chain_len(chain): """Test that the len of a TableChain object is as expected.""" - assert len(chain) == len(chain.names), "Unexpected len of TableChain." + assert len(chain) == len(chain.path), "Unexpected len of TableChain." def test_chain_getitem(chain): """Test getitem of TableChain object.""" by_int = chain[0] - by_str = chain[chain.names[0]] + by_str = chain[chain.path[0]] assert by_int == by_str, "Getitem by int and str not equal." def test_nolink_join(no_link_chain): - assert no_link_chain.join() is None, "Unexpected join of no link chain." + assert no_link_chain.cascade() is None, "Unexpected join of no link chain." def test_chain_str_no_link(no_link_chain): diff --git a/tests/utils/test_mixin.py b/tests/utils/test_mixin.py index 5cfa9e478..8b1d41f4a 100644 --- a/tests/utils/test_mixin.py +++ b/tests/utils/test_mixin.py @@ -83,7 +83,7 @@ def test_ddm_dry_run(Nwbfile, common, sgp, pos_merge_tables): )[0] assert len(rft) == 1, "ddm did not return restricted table." - table_name = pos_merge_tables[0].parts()[-1] + table_name = [p for p in pos_merge_tables[0].parts() if "trode" in p][0] assert table_name == rft.full_table_name, "ddm didn't grab right table." 
assert ( From 05e7660822485a54ff7a3a91bc85852e8305dcff Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Mon, 29 Apr 2024 20:07:47 -0500 Subject: [PATCH 07/17] WIP: Handle all alias cases in _bridge_restr --- src/spyglass/utils/dj_graph.py | 1 - src/spyglass/utils/dj_graph_abs.py | 120 ++++++++++++++++---------- src/spyglass/utils/dj_merge_tables.py | 10 +++ 3 files changed, 83 insertions(+), 48 deletions(-) diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index 461843018..353d1e961 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -295,7 +295,6 @@ def cascade(self, show_progress=None) -> None: replace=True, ) - # TODO: Add restrict_from_upstream method to Merge class, leaf per part def cascade1_search( self, table: str, diff --git a/src/spyglass/utils/dj_graph_abs.py b/src/spyglass/utils/dj_graph_abs.py index 146c7500c..de18e4076 100644 --- a/src/spyglass/utils/dj_graph_abs.py +++ b/src/spyglass/utils/dj_graph_abs.py @@ -1,10 +1,12 @@ from abc import ABC, abstractmethod -from typing import Dict, List, Union +from itertools import chain as iter_chain +from typing import Dict, List, Tuple, Union from datajoint import FreeTable, logger from datajoint.condition import make_condition from datajoint.table import Table -from networkx import NetworkXNoPath, shortest_path +from datajoint.utils import to_camel_case +from networkx import NetworkXNoPath, all_simple_paths, shortest_path from spyglass.utils.dj_helper_fn import unique_dicts @@ -72,37 +74,37 @@ def _get_attr_dict( for t in self.visited } - def _get_attr_map_btwn(self, child, parent): - """Get attribute map between child and parent. + def _get_edge(self, child, parent) -> Tuple[bool, Dict[str, str]]: + """Get edge data between child and parent. - Currently used for debugging.""" + Returns + ------- + Tuple[bool, Dict[str, str]] + Tuple of boolean indicating direction and edge data. True if child + is child of parent. + """ child = child if isinstance(child, str) else child.full_table_name parent = parent if isinstance(parent, str) else parent.full_table_name - reverse = False - try: - path = shortest_path(self.graph, child, parent) - except NetworkXNoPath: - reverse, child, parent = True, parent, child - path = shortest_path(self.graph, child, parent) - - if len(path) != 2 and not path[1].isnumeric(): - raise ValueError(f"{child} -> {parent} not direct path: {path}") - - try: - attr_map = self.graph[child][path[1]]["attr_map"] - except KeyError: - attr_map = self.graph[path[1]][child]["attr_map"] + if edge := self.graph.get_edge_data(parent, child): + return False, edge + elif edge := self.graph.get_edge_data(child, parent): + return True, edge + + # Handle alias nodes. `shortest_path` doesn't work with aliases + p1 = all_simple_paths(self.graph, child, parent) + p2 = all_simple_paths(self.graph, parent, child) + paths = [p for p in iter_chain(p1, p2)] # list for error handling + for path in paths: + if len(path) > 3 or (len(path) > 2 and not path[1].isnumeric()): + continue + return self._get_edge(path[0], path[1]) - return self._parse_attr_map(attr_map, reverse=reverse) + raise ValueError(f"{child} -> {parent} not direct path: {paths}") - def _parse_attr_map(self, attr_map, reverse=False): + def _rev_attrs(self, attr_map): """Parse attribute map. 
Remove self-references.""" - if not attr_map: - return {} - if reverse: - return {v: k for k, v in attr_map.items() if k != v} - return {k: v for k, v in attr_map.items() if k != v} + return {v: k for k, v in attr_map.items()} def _get_restr(self, table): """Get restriction from graph node. @@ -154,6 +156,14 @@ def all_ft(self): if not table.isnumeric() ] + def _print_restr(self, leaves=False): + """Print restrictions for each table in visited set.""" + mylist = self.leaves if leaves else self.visited + for table in mylist: + self._log_truncate( + f"{table.split('.')[-1]:>35} {self._get_restr(table)}" + ) + def get_restr_ft(self, table: Union[int, str]): """Get restricted FreeTable from graph node. @@ -183,8 +193,10 @@ def _bridge_restr( table1: str, table2: str, restr: str, + direction: str = None, attr_map: dict = None, - primary: bool = True, + primary: bool = None, + aliased: bool = None, **kwargs, ): """Given two tables and a restriction, return restriction for table2. @@ -214,35 +226,48 @@ def _bridge_restr( List[Dict[str, str]] List of dicts containing primary key fields for restricted table2. """ - ft1 = self._get_ft(table1) & restr + # Direction UP: table1 -> table2, parent -> child + if not all([direction, attr_map, primary, aliased]): + dir_bool, edge = self._get_edge(table1, table2) + direction = "up" if dir_bool else "down" + attr_map = edge.get("attr_map") + primary = edge.get("primary") + aliased = edge.get("aliased") + + ft1 = self._get_ft(table1) + rt1 = ft1 & restr ft2 = self._get_ft(table2) if len(ft1) == 0: return ["False"] - attr_map = self._parse_attr_map(attr_map) - - if table1 in ft2.parents(): - flip = False - child, parent = ft2, ft1 - else: # table2 in ft1.children() - flip = True - child, parent = ft1, ft2 - - if primary: - join = (parent.proj(**attr_map) * child).proj() - else: - join = (parent.proj(..., **attr_map) * child).proj() + adjust = bool(set(attr_map.values()) - set(ft1.heading.names)) + if adjust: + attr_map = self._rev_attrs(attr_map) - if set(ft2.primary_key).isdisjoint(set(join.heading.names)): - join = join.proj(**self._parse_attr_map(attr_map, reverse=True)) + join = rt1.proj(**attr_map) * ft2 ret = unique_dicts(join.fetch(*ft2.primary_key, as_dict=True)) - if self.verbose and len(ft2) and len(ret) == len(ft2): - self._log_truncate(f"NULL restr {table2}") - if self.verbose and attr_map: - self._log_truncate(f"attr_map {table1} -> {table2}: {flip}") + null = None + if self.verbose: + dir = "Up" if direction == "up" else "Dn" + prim = "Pri" if primary else "Sec" + adjust = "Flip" if adjust else "NoFp" + aliaa = "Alias" if aliased else "NoAli" + null = ( + "NULL" + if len(ret) == 0 + else "FULL" if len(ft2) == len(ret) else "part" + ) + strt = f"{to_camel_case(table1.table_name)}" + endp = f"{to_camel_case(table2.table_name)}" + self._log_truncate( + f"{dir} {prim} {aliaa} {adjust}: {null} {strt} -> {endp}" + ) + if null and null != "part": + pass + # __import__("pdb").set_trace() return ret @@ -279,6 +304,7 @@ def cascade1(self, table, restriction, direction="up", replace=False): table1=table, table2=next_table, restr=restriction, + direction=direction, **data, ) diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py index 621ab4541..2eba4d0b0 100644 --- a/src/spyglass/utils/dj_merge_tables.py +++ b/src/spyglass/utils/dj_merge_tables.py @@ -2,6 +2,7 @@ from inspect import getmodule from itertools import chain as iter_chain from pprint import pprint +from re import sub as re_sub from time import time from typing 
import Union @@ -27,6 +28,15 @@ def is_merge_table(table): """Return True if table fields exactly match Merge table.""" if not isinstance(table, dj.Table): return False + if not table.is_declared: + if tbl_def := getattr(table, "definition", None): + return MERGE_DEFINITION == re_sub( + r"\n\s*\n", + "\n", + re_sub(r"#.*\n", "\n", tbl_def.strip()), + ) + logger.warning(f"Cannot determine merge table status for {table}") + return True return table.primary_key == [ RESERVED_PRIMARY_KEY ] and table.heading.secondary_attributes == [RESERVED_SECONDARY_KEY] From 09d6038ad5e5477d0f738413ec3e7bfab0f15e16 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 1 May 2024 12:39:47 -0500 Subject: [PATCH 08/17] WIP: Add tests --- src/spyglass/utils/dj_graph.py | 74 ++++++++++---- src/spyglass/utils/dj_graph_abs.py | 38 +++++-- src/spyglass/utils/dj_mixin.py | 38 +++++-- tests/conftest.py | 97 ++++++++++++------ tests/lfp/conftest.py | 14 --- tests/lfp/test_lfp.py | 5 - tests/utils/__init__.py | 0 tests/utils/conftest.py | 16 +++ tests/utils/schema_graph.py | 155 +++++++++++++++++++++++++++++ tests/utils/test_graph.py | 33 +++++- 10 files changed, 383 insertions(+), 87 deletions(-) create mode 100644 tests/utils/__init__.py create mode 100644 tests/utils/schema_graph.py diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index 353d1e961..76a3afda3 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -66,7 +66,7 @@ def leaf_ft(self): return [self._get_ft(table, with_restr=True) for table in self.leaves] def add_leaf( - self, table_name=None, restriction=True, cascade=False + self, table_name=None, restriction=True, cascade=False, direction="up" ) -> None: """Add leaf to graph and cascade if requested. @@ -84,11 +84,13 @@ def add_leaf( self.cascaded = False - new_ancestors = set(self._get_ft(table_name).ancestors()) + if direction == "up": + new_visits = set(self._get_ft(table_name).ancestors()) + else: + new_visits = set(self._get_ft(table_name).descendants()) - # TODO: adjust to permit cascade down, conditional to descendants - self.to_visit |= new_ancestors # Add to total ancestors - self.visited -= new_ancestors # Remove from visited to revisit + self.to_visit |= new_visits # Add to total ancestors + self.visited -= new_visits # Remove from visited to revisit self.leaves.add(table_name) self._set_restr(table_name, restriction) # Redundant if cascaded @@ -187,10 +189,10 @@ def _get_files(self, table): def cascade_files(self): """Set node attribute for analysis files.""" for table in self.visited: - ft = self._get_ft(table) + ft = self._get_ft(table, with_restr=True) if not set(self.analysis_pk).issubset(ft.heading.names): continue - files = list((ft & self._get_restr(table)).fetch(*self.analysis_pk)) + files = list(ft.fetch(*self.analysis_pk)) self._set_node(table, "files", files) @property @@ -247,10 +249,17 @@ def __init__( super().__init__(seed_table, verbose=verbose) + self.direction = direction + if restriction and table_name: self._set_find_restr(table_name, restriction) - self.add_leaf(table_name, True, cascade=False) - self.add_leaves(leaves, default_restriction=restriction, cascade=False) + self.add_leaf(table_name, True, cascade=False, direction=direction) + self.add_leaves( + leaves, + default_restriction=restriction, + cascade=False, + direction=direction, + ) all_nodes = set([n for n in self.graph.nodes if not n.isnumeric()]) self.no_visit.update(all_nodes - self.to_visit) # Skip non-ancestors @@ -264,12 +273,12 @@ def _set_find_restr(self, 
table_name, restriction): """Set restr to look for from leaf node.""" if isinstance(restriction, dict): logger.warning("key_from_upstream: DICT unreliable, use STR.") - restr_attrs = set(restriction.keys()) - else: - restr_attrs = set() # modified by make_condition - table_ft = self._get_ft(table_name) - _ = make_condition(table_ft, restriction, restr_attrs) - self._set_node(table_name, "find_restr", restriction) + + restr_attrs = set() # modified by make_condition + table_ft = self._get_ft(table_name) + restr_string = make_condition(table_ft, restriction, restr_attrs) + + self._set_node(table_name, "find_restr", restr_string) self._set_node(table_name, "restr_attrs", restr_attrs) def _get_find_restr(self, table) -> Tuple[str, Set[str]]: @@ -277,16 +286,31 @@ def _get_find_restr(self, table) -> Tuple[str, Set[str]]: node = self._get_node(table) return node.get("find_restr", False), node.get("restr_attrs", set()) - def add_leaves(self, leaves=None, default_restriction=None, cascade=False): + def add_leaves( + self, + leaves=None, + default_restriction=None, + cascade=False, + direction=None, + ): leaves = self._process_leaves( leaves=leaves, default_restriction=default_restriction ) for leaf in leaves: # Multiple leaves self._set_find_restr(**leaf) - self.add_leaf(leaf["table_name"], True, cascade=False) + self.add_leaf( + leaf["table_name"], + True, + cascade=False, + direction=direction, + ) - def cascade(self, show_progress=None) -> None: + def cascade(self, direction=None, show_progress=None) -> None: + direction = direction or self.direction + if self.cascaded: + return for table in self.leaves: + self._log_truncate(f"Start {table}: {self._get_restr(table)}") restriction, restr_attrs = self._get_find_restr(table) self.cascade1_search( table=table, @@ -294,15 +318,21 @@ def cascade(self, show_progress=None) -> None: restr_attrs=restr_attrs, replace=True, ) + self.cascaded = True def cascade1_search( self, table: str, restriction: str, restr_attrs: Set[str] = None, - direction: str = "up", + direction: str = None, replace: bool = True, ): + self._log_truncate(f"Search {table}: {restriction}") + if self.cascaded: + return + + direction = direction or self.direction next_func = ( self.graph.parents if direction == "up" else self.graph.children ) @@ -312,17 +342,21 @@ def cascade1_search( next_table, data = next_func(next_table).popitem() if next_table in self.no_visit or table == next_table: + self._log_truncate(f"Skip {next_table}: {restriction}") + reason = "no_visit" if next_table in self.no_visit else "same" + self._log_truncate(f"B/C {next_table}: {reason}") continue next_ft = self._get_ft(next_table) if restr_attrs.issubset(set(next_ft.heading.names)): + self._log_truncate(f"Found {next_table}: {restriction}") self.cascade1( table=next_table, restriction=restriction, direction="down" if direction == "up" else "up", replace=replace, ) - return + self.cascaded = True self.cascade1_search( table=next_table, diff --git a/src/spyglass/utils/dj_graph_abs.py b/src/spyglass/utils/dj_graph_abs.py index de18e4076..37756fdf2 100644 --- a/src/spyglass/utils/dj_graph_abs.py +++ b/src/spyglass/utils/dj_graph_abs.py @@ -4,9 +4,11 @@ from datajoint import FreeTable, logger from datajoint.condition import make_condition +from datajoint.dependencies import unite_master_parts from datajoint.table import Table from datajoint.utils import to_camel_case from networkx import NetworkXNoPath, all_simple_paths, shortest_path +from networkx.algorithms.dag import topological_sort from 
spyglass.utils.dj_helper_fn import unique_dicts @@ -113,7 +115,7 @@ def _get_restr(self, table): in attrs like `all_ft`. """ table = table if isinstance(table, str) else table.full_table_name - return self._get_node(table).get("restr", "False") + return self._get_node(table).get("restr") def _set_restr(self, table, restriction, replace=False): """Add restriction to graph node. If one exists, merge with new.""" @@ -124,8 +126,8 @@ def _set_restr(self, table, restriction, replace=False): else restriction ) existing = self._get_restr(table) - if not replace and existing != "False": # False is default - if existing == restriction: + if not replace and existing: + if restriction == existing: return join = ft & [existing, restriction] if len(join) == len(ft & existing): @@ -134,26 +136,44 @@ def _set_restr(self, table, restriction, replace=False): ft, unique_dicts(join.fetch("KEY", as_dict=True)), set() ) + self._log_truncate(f"Set {table.split('.')[-1]} {restriction}") + # if "#pk_node" in table: + # __import__("pdb").set_trace() self._set_node(table, "restr", restriction) def _get_ft(self, table, with_restr=False): """Get FreeTable from graph node. If one doesn't exist, create it.""" table = table if isinstance(table, str) else table.full_table_name - restr = self._get_restr(table) if with_restr else True + + if with_restr: + restr = self._get_restr(table) + if not restr: + logger.warning(f"No restriction for {table}") + restr = False + else: + restr = True + if ft := self._get_node(table).get("ft"): return ft & restr ft = FreeTable(self.connection, table) self._set_node(table, "ft", ft) return ft & restr + def topological_sort(self, nodes=None) -> List[str]: + """Get topological sort of visited nodes. From datajoint.diagram""" + nodes = nodes or self.visited + nodes = [n for n in nodes if not n.isnumeric()] + return unite_master_parts( + list(topological_sort(self.graph.subgraph(nodes))) + ) + @property def all_ft(self): """Get restricted FreeTables from all visited nodes.""" self.cascade() return [ self._get_ft(table, with_restr=True) - for table in self.visited - if not table.isnumeric() + for table in self.topological_sort() ] def _print_restr(self, leaves=False): @@ -260,14 +280,14 @@ def _bridge_restr( if len(ret) == 0 else "FULL" if len(ft2) == len(ret) else "part" ) - strt = f"{to_camel_case(table1.table_name)}" - endp = f"{to_camel_case(table2.table_name)}" + strt = f"{to_camel_case(ft1.table_name)}" + endp = f"{to_camel_case(ft2.table_name)}" self._log_truncate( f"{dir} {prim} {aliaa} {adjust}: {null} {strt} -> {endp}" ) if null and null != "part": pass - # __import__("pdb").set_trace() + # __import__("pdb").set_trace() return ret diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index d988a9696..c1d4758a6 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -765,8 +765,8 @@ def fetch1(self, *args, log_fetch=True, **kwargs): # ------------------------------ Restrict by ------------------------------ - def __mod__(self, restriction): - """Restriction by upstream operator e.g. ``q1 % q2``. + def __lshift__(self, restriction): + """Restriction by upstream operator e.g. ``q1 << q2``. Returns ------- @@ -774,21 +774,47 @@ def __mod__(self, restriction): A restricted copy of the query expression using the nearest upstream table for which the restriction is valid. 
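+
+        Example (hypothetical; assumes the attribute exists on some table
+        upstream of this one):
+
+        >>> MyTable() << "upstream_attr = 'val'"  # doctest: +SKIP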
""" - return self.restrict_from_upstream(restriction) + return self.restrict_from(restriction, direction="up") - def restrict_from_upstream(self, restriction=True, **kwargs): + def __rshift__(self, restriction): + """Restriction by downstream operator e.g. ``q1 >> q2``. + + Returns + ------- + QueryExpression + A restricted copy of the query expression using the nearest upstream + table for which the restriction is valid. + """ + return self.restrict_from(restriction, direction="down") + + def restrict_from(self, restriction=True, direction="up", **kwargs): + """Restrict self based on upstream table.""" + ret = self.restrict_graph(restriction, direction, **kwargs).leaf_ft[0] + if len(ret) == len(self): + logger.warning("Restriction did not limit table.") + return ret + + def restrict_graph(self, restriction=True, direction="up", **kwargs): """Restrict self based on upstream table.""" from spyglass.utils.dj_graph import FindKeyGraph if restriction is True: return self + try: # Save time if restriction is already valid + ret = self.restrict(restriction) + logger.warning("Restriction valid for this table. Using as is.") + return ret + except DataJointError: + pass # Could avoid try if assert_join_compatible returned a bool + logger.info("Restriction not valid. Attempting to cascade.") graph = FindKeyGraph( seed_table=self, table_name=self.full_table_name, restriction=restriction, + direction=direction, cascade=True, - verbose=False, + verbose=True, **kwargs, ) - return graph.leaf_ft[0] + return graph diff --git a/tests/conftest.py b/tests/conftest.py index 6c7b4e18a..a3d4d681a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,10 @@ +"""Configuration for pytest, including fixtures and command line options. + +Fixtures in this script are mad available to all tests in the test suite. +conftest.py files in subdirectories have fixtures that are only available to +tests in that subdirectory. 
+""" + import os import sys import warnings @@ -13,11 +20,12 @@ from .container import DockerMySQLManager -# ---------------------- CONSTANTS --------------------- +warnings.filterwarnings("ignore", category=UserWarning, module="hdmf") + +# ------------------------------- TESTS CONFIG ------------------------------- # globals in pytest_configure: # BASE_DIR, RAW_DIR, SERVER, TEARDOWN, VERBOSE, TEST_FILE, DOWNLOAD -warnings.filterwarnings("ignore", category=UserWarning, module="hdmf") def pytest_addoption(parser): @@ -131,7 +139,7 @@ def pytest_unconfigure(config): SERVER.stop() -# ------------------- FIXTURES ------------------- +# ---------------------------- FIXTURES, TEST ENV ---------------------------- @pytest.fixture(scope="session") @@ -143,6 +151,27 @@ def verbose(): @pytest.fixture(scope="session", autouse=True) def verbose_context(verbose): """Verbosity context for suppressing Spyglass logging.""" + + class QuietStdOut: + """Used to quiet all prints and logging as context manager.""" + + def __init__(self): + from spyglass.utils import logger as spyglass_logger + + self.spy_logger = spyglass_logger + self.previous_level = None + + def __enter__(self): + self.previous_level = self.spy_logger.getEffectiveLevel() + self.spy_logger.setLevel("CRITICAL") + self._original_stdout = sys.stdout + sys.stdout = open(os.devnull, "w") + + def __exit__(self, exc_type, exc_val, exc_tb): + self.spy_logger.setLevel(self.previous_level) + sys.stdout.close() + sys.stdout = self._original_stdout + yield nullcontext() if verbose else QuietStdOut() @@ -193,6 +222,9 @@ def raw_dir(base_dir): yield base_dir / "raw" +# ------------------------------- FIXTURES, DATA ------------------------------- + + @pytest.fixture(scope="session") def mini_path(raw_dir): path = raw_dir / TEST_FILE @@ -264,6 +296,8 @@ def mini_insert(mini_path, teardown, server, load_config): ) from spyglass.utils.nwb_helper_fn import close_nwb_files # noqa: E402 + _ = SpikeSortingOutput() + LabMember().insert1( ["Root User", "Root", "User"], skip_duplicates=not teardown ) @@ -287,8 +321,7 @@ def mini_insert(mini_path, teardown, server, load_config): yield close_nwb_files() - # Note: no need to run deletes in teardown, since we are using teardown - # will remove the container + # Note: no need to run deletes in teardown, bc removing the container @pytest.fixture(scope="session") @@ -301,6 +334,9 @@ def mini_dict(mini_copy_name): yield {"nwb_file_name": mini_copy_name} +# --------------------------- FIXTURES, SUBMODULES --------------------------- + + @pytest.fixture(scope="session") def common(dj_conn): from spyglass import common @@ -322,6 +358,27 @@ def settings(dj_conn): yield settings +@pytest.fixture(scope="session") +def sgp(common): + from spyglass import position + + yield position + + +@pytest.fixture(scope="session") +def lfp(common): + from spyglass import lfp + + return lfp + + +@pytest.fixture(scope="session") +def lfp_band(lfp): + from spyglass.lfp.analysis.v1 import lfp_band + + return lfp_band + + @pytest.fixture(scope="session") def populate_exception(): from spyglass.common.errors import PopulateException @@ -329,11 +386,7 @@ def populate_exception(): yield PopulateException -@pytest.fixture(scope="session") -def sgp(common): - from spyglass import position - - yield position +# ------------------------- FIXTURES, POSITION TABLES ------------------------- @pytest.fixture(scope="session") @@ -446,25 +499,9 @@ def pos_merge_key(pos_merge, trodes_pos_v1, trodes_sel_keys): yield 
pos_merge.merge_get_part(trodes_sel_keys[-1]).fetch1("KEY") -# ------------------ GENERAL FUNCTION ------------------ - - -class QuietStdOut: - """If quiet_spy, used to quiet prints, teardowns and table.delete prints""" - - def __init__(self): - from spyglass.utils import logger as spyglass_logger - - self.spy_logger = spyglass_logger - self.previous_level = None +# --------------------------- FIXTURES, LFP TABLES --------------------------- - def __enter__(self): - self.previous_level = self.spy_logger.getEffectiveLevel() - self.spy_logger.setLevel("CRITICAL") - self._original_stdout = sys.stdout - sys.stdout = open(os.devnull, "w") - def __exit__(self, exc_type, exc_val, exc_tb): - self.spy_logger.setLevel(self.previous_level) - sys.stdout.close() - sys.stdout = self._original_stdout +@pytest.fixture(scope="module") +def lfp_band_v1(lfp_band): + yield lfp_band.LFPBandV1() diff --git a/tests/lfp/conftest.py b/tests/lfp/conftest.py index 354803493..e318610ec 100644 --- a/tests/lfp/conftest.py +++ b/tests/lfp/conftest.py @@ -3,20 +3,6 @@ from pynwb import NWBHDF5IO -@pytest.fixture(scope="session") -def lfp(common): - from spyglass import lfp - - return lfp - - -@pytest.fixture(scope="session") -def lfp_band(lfp): - from spyglass.lfp.analysis.v1 import lfp_band - - return lfp_band - - @pytest.fixture(scope="session") def firfilters_table(common): return common.FirFilterParameters() diff --git a/tests/lfp/test_lfp.py b/tests/lfp/test_lfp.py index 51b2e96f4..b496ae445 100644 --- a/tests/lfp/test_lfp.py +++ b/tests/lfp/test_lfp.py @@ -37,11 +37,6 @@ def test_lfp_band_dataframe(lfp_band_analysis_raw, lfp_band, lfp_band_key): assert df_raw.equals(df_fetch), "LFPBand dataframe not match." -@pytest.fixture(scope="module") -def lfp_band_v1(lfp_band): - yield lfp_band.LFPBandV1() - - def test_lfp_band_compute_signal_invalid(lfp_band_v1): with pytest.raises(ValueError): lfp_band_v1.compute_analytic_signal([4]) diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/utils/conftest.py b/tests/utils/conftest.py index b2a2dcd2a..cfc63b5c0 100644 --- a/tests/utils/conftest.py +++ b/tests/utils/conftest.py @@ -1,6 +1,8 @@ import datajoint as dj import pytest +from . import schema_graph + @pytest.fixture(scope="module") def merge_table(pos_merge_tables): @@ -56,3 +58,17 @@ def no_link_chain(Nwbfile): from spyglass.utils.dj_chains import TableChain yield TableChain(Nwbfile, InsertError()) + + +@pytest.fixture(scope="module") +def graph_tables(dj_conn): + schema = dj.Schema(context=schema_graph.LOCALS_GRAPH) + + for table in schema_graph.LOCALS_GRAPH.values(): + schema(table) + + schema.activate("test_graph", connection=dj_conn) + + yield schema_graph.LOCALS_GRAPH + + schema.drop(force=True) diff --git a/tests/utils/schema_graph.py b/tests/utils/schema_graph.py new file mode 100644 index 000000000..d646f4f8b --- /dev/null +++ b/tests/utils/schema_graph.py @@ -0,0 +1,155 @@ +from inspect import isclass as inspect_isclass + +import datajoint as dj + +from spyglass.utils import SpyglassMixin + +# Ranges are offset from one another to create unique list of entries for each +# table while respecting the foreign key constraints. 
+ +parent_id = range(10) +parent_attr = [i + 10 for i in range(2, 12)] + +other_id = range(9) +other_attr = [i + 10 for i in range(3, 12)] + +intermediate_id = range(2, 10) +intermediate_attr = [i + 10 for i in range(4, 12)] + +pk_id = range(3, 10) +pk_attr = [i + 10 for i in range(5, 12)] + +sk_id = range(6) +sk_attr = [i + 10 for i in range(6, 12)] + +pk_sk_id = range(5) +pk_sk_attr = [i + 10 for i in range(7, 12)] + +pk_alias_id = range(4) +pk_alias_attr = [i + 10 for i in range(8, 12)] + +sk_alias_id = range(3) +sk_alias_attr = [i + 10 for i in range(9, 12)] + + +def offset(gen, offset): + return list(gen)[offset:] + + +class ParentNode(SpyglassMixin, dj.Lookup): + definition = """ + parent_id: int + --- + parent_attr : int + """ + contents = [(i, j) for i, j in zip(parent_id, parent_attr)] + + +class OtherParentNode(SpyglassMixin, dj.Lookup): + definition = """ + other_id: int + --- + other_attr : int + """ + contents = [(i, j) for i, j in zip(other_id, other_attr)] + + +class IntermediateNode(SpyglassMixin, dj.Lookup): + definition = """ + intermediate_id: int + --- + -> ParentNode + intermediate_attr : int + """ + contents = [ + (i, j, k) + for i, j, k in zip( + intermediate_id, offset(parent_id, 1), intermediate_attr + ) + ] + + +class PkNode(SpyglassMixin, dj.Lookup): + definition = """ + pk_id: int + -> IntermediateNode + --- + pk_attr : int + """ + contents = [ + (i, j, k) for i, j, k in zip(pk_id, offset(intermediate_id, 2), pk_attr) + ] + + +class SkNode(SpyglassMixin, dj.Lookup): + definition = """ + sk_id: int + --- + -> IntermediateNode + sk_attr : int + """ + contents = [ + (i, j, k) for i, j, k in zip(sk_id, offset(intermediate_id, 3), sk_attr) + ] + + +class PkSkNode(SpyglassMixin, dj.Lookup): + definition = """ + pk_sk_id: int + -> IntermediateNode + --- + -> OtherParentNode + pk_sk_attr : int + """ + contents = [ + (i, j, k, m) + for i, j, k, m in zip( + pk_sk_id, offset(intermediate_id, 4), other_id, pk_sk_attr + ) + ] + + +class PkAliasNode(SpyglassMixin, dj.Lookup): + definition = """ + pk_alias_id: int + -> PkNode.proj(fk_pk_id='pk_id') + --- + pk_alias_attr : int + """ + contents = [ + (i, j, k, m) + for i, j, k, m in zip( + pk_alias_id, + offset(pk_id, 1), + offset(intermediate_id, 3), + pk_alias_attr, + ) + ] + + +class SkAliasNode(SpyglassMixin, dj.Lookup): + definition = """ + sk_alias_id: int + --- + -> SkNode.proj(fk_sk_id='sk_id') + -> PkSkNode + sk_alias_attr : int + """ + contents = [ + (i, j, k, m, n) + for i, j, k, m, n in zip( + sk_alias_id, + offset(sk_id, 2), + offset(pk_sk_id, 1), + offset(intermediate_id, 5), + sk_alias_attr, + ) + ] + + +LOCALS_GRAPH = { + k: v + for k, v in locals().items() + if inspect_isclass(v) and k != "SpyglassMixin" +} +__all__ = list(LOCALS_GRAPH) diff --git a/tests/utils/test_graph.py b/tests/utils/test_graph.py index 8d3f8a699..d86c81b22 100644 --- a/tests/utils/test_graph.py +++ b/tests/utils/test_graph.py @@ -1,7 +1,7 @@ -from pathlib import Path - import pytest +from . import schema_graph as sg + @pytest.fixture(scope="session") def leaf(lin_merge): @@ -45,7 +45,6 @@ def test_rg_file_paths(restr_graph): """Test collection of upstream file paths.""" paths = [p.get("file_path") for p in restr_graph.file_paths] assert len(paths) == 1, "Unexpected number of file paths." - assert all([Path(p).exists() for p in paths]), "Not all file paths exist." 
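For orientation, a minimal sketch of how this toy schema is meant to be
exercised, mirroring the `graph_tables` fixture above (a sketch only: it
assumes an active DataJoint connection, and the expected count matches the
parametrized `>>` tests that follow):

```python
import datajoint as dj

from tests.utils import schema_graph as sg

# Declare every node of the toy graph on a scratch schema; dj.Lookup
# auto-inserts each table's `contents` on declaration.
schema = dj.Schema(context=sg.LOCALS_GRAPH)
for table in sg.LOCALS_GRAPH.values():
    schema(table)
schema.activate("test_graph")

# `>>` (added by SpyglassMixin) restricts ParentNode by an attribute
# that lives two foreign-key hops downstream of it.
assert len(sg.ParentNode() >> "pk_attr > 16") == 4
```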
@pytest.fixture(scope="session")
@@ -68,3 +67,31 @@ def test_add_leaf_restr_ft(restr_graph_new_leaf):
     restr_graph_new_leaf.cascade()
     ft = restr_graph_new_leaf.get_restr_ft("`common_interval`.`interval_list`")
     assert len(ft) == 2, "Unexpected restricted table length."
+
+
+@pytest.mark.parametrize(
+    "restr, expect_n, msg",
+    [
+        ("pk_attr > 16", 4, "pk down, no alias"),
+        ("sk_attr > 17", 3, "sk down, no alias"),
+        ("pk_alias_attr > 18", 3, "pk down, pk alias"),
+        ("sk_alias_attr > 19", 2, "sk down, sk alias"),
+    ],
+)
+def test_restr_from_upstream(graph_tables, restr, expect_n, msg):
+    msg = "Error in `>>` for " + msg
+    assert len(graph_tables["ParentNode"]() >> restr) == expect_n, msg
+
+
+@pytest.mark.parametrize(
+    "table, restr, expect_n, msg",
+    [
+        ("PkNode", "parent_attr > 15", 5, "pk up, no alias"),
+        ("SkNode", "parent_attr > 16", 4, "sk up, no alias"),
+        ("PkAliasNode", "parent_attr > 17", 2, "pk up, pk alias"),
+        ("SkAliasNode", "parent_attr > 18", 2, "sk up, sk alias"),
+    ],
+)
+def test_restr_from_downstream(graph_tables, table, restr, expect_n, msg):
+    msg = "Error in `<<` for " + msg
+    assert len(graph_tables[table]() << restr) == expect_n, msg

From aa47ddcf530b47b3b1b48e428f0e6e90a96d7559 Mon Sep 17 00:00:00 2001
From: CBroz1
Date: Thu, 2 May 2024 12:41:18 -0500
Subject: [PATCH 09/17] WIP: Cascade through merge tables

---
 docs/src/misc/mixin.md                |  42 ++-
 src/spyglass/utils/dj_chains.py       |   6 +-
 src/spyglass/utils/dj_graph.py        | 475 ++++++++++++++++++++++++--
 src/spyglass/utils/dj_graph_abs.py    | 336 ------------------
 src/spyglass/utils/dj_merge_tables.py |  83 ++++-
 src/spyglass/utils/dj_mixin.py        |  65 +++-
 tests/utils/conftest.py               |  15 +-
 tests/utils/schema_graph.py           |  28 +-
 tests/utils/test_graph.py             |  44 ++-
 9 files changed, 675 insertions(+), 419 deletions(-)
 delete mode 100644 src/spyglass/utils/dj_graph_abs.py

diff --git a/docs/src/misc/mixin.md b/docs/src/misc/mixin.md
index 747a12f9f..1742d875d 100644
--- a/docs/src/misc/mixin.md
+++ b/docs/src/misc/mixin.md
@@ -4,6 +4,7 @@ The Spyglass Mixin provides a way to centralize all Spyglass-specific
 functionalities that have been added to DataJoint tables. This includes...
 
 - Fetching NWB files
+- Long-distance restrictions
 - Delete functionality, including permission checks and part/master pairs
 - Export logging. See [export doc](export.md) for more information.
 
@@ -11,16 +12,14 @@ To add this functionality to your own tables, simply inherit from the mixin:
 
 ```python
 import datajoint as dj
+
 from spyglass.utils import SpyglassMixin
 
-schema = dj.schema('my_schema')
+schema = dj.schema("my_schema")
 
-@schema
-class MyOldTable(dj.Manual):
-    pass
 
 @schema
-class MyNewTable(SpyglassMixin, dj.Manual):)
+class MyOldTable(SpyglassMixin, dj.Manual):
     pass
 ```
 
@@ -44,6 +43,39 @@ should be fetched from `Nwbfile` or an analysis file should be fetched from
 `AnalysisNwbfile`. If neither is foreign-key-referenced, the function will
 refer to a `_nwb_table` attribute.
 
+## Long-Distance Restrictions
+
+In complicated pipelines like Spyglass, there are often tables that 'bury' their
+foreign keys as secondary keys. This is done to avoid having to pass a long list
+of foreign keys through the pipeline, potentially hitting SQL limits (see also
+[Merge Tables](./merge_tables.md)). This burying makes it difficult to restrict
+a given table by familiar attributes.
+
+Spyglass provides a function, `restrict_by`, to handle this. The function takes
+your restriction and checks parents/children until the restriction can be
+applied.
+Spyglass introduces `<<` as a shorthand for `restrict_by` with an upstream
+key and `>>` as a shorthand for `restrict_by` with a downstream key.
+
+```python
+from spyglass.example import AnyTable
+
+AnyTable >> 'downstream_attribute="value"'
+AnyTable << 'upstream_attribute="value"'
+
+# Equivalent to
+AnyTable.restrict_by('downstream_attribute="value"', direction="down")
+AnyTable.restrict_by('upstream_attribute="value"', direction="up")
+```
+
+Some caveats to this function:
+
+1. 'Peripheral' tables, like `IntervalList` and `AnalysisNwbfile`, make it hard
+   to determine the correct parent/child relationship and have been removed
+   from this search.
+2. This function will raise an error if it attempts to check a table that has
+   not been imported into the current namespace. It is best used for exploring
+   and debugging, not for production code.
+
 ## Delete Functionality
 
 The mixin overrides the default `delete` function to provide two additional
diff --git a/src/spyglass/utils/dj_chains.py b/src/spyglass/utils/dj_chains.py
index 18a3d11ac..214e0ce4e 100644
--- a/src/spyglass/utils/dj_chains.py
+++ b/src/spyglass/utils/dj_chains.py
@@ -7,7 +7,7 @@ from datajoint.table import Table
 from datajoint.utils import to_camel_case
 
-from spyglass.utils.dj_graph_abs import AbstractGraph
+from spyglass.utils.dj_graph import AbstractGraph
 from spyglass.utils.dj_helper_fn import PERIPHERAL_TABLES, fuzzy_get
 from spyglass.utils.dj_merge_tables import is_merge_table
 
@@ -228,9 +228,9 @@ def find_path(self, directed=True) -> OrderedDict:
         try:
             path = nx.shortest_path(search_graph, source, target)
         except nx.NetworkXNoPath:
-            return None
+            return None  # No path found, parent func may do undirected search
         except nx.NodeNotFound:
-            self._searched = True
+            self._searched = True  # No path found, don't search again
             return None
 
         ignore_nodes = self.graph.nodes - set(path)
diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py
index 76a3afda3..0d90e99aa 100644
--- a/src/spyglass/utils/dj_graph.py
+++ b/src/spyglass/utils/dj_graph.py
@@ -3,18 +3,375 @@
 NOTE: read `ft` as FreeTable and `restr` as restriction.
 """
 
-from typing import Dict, List, Set, Tuple, Union
+from abc import ABC, abstractmethod
+from collections.abc import KeysView
+from enum import Enum
+from itertools import chain as iter_chain
+from typing import Any, Dict, List, Set, Tuple, Union
 
+from datajoint import FreeTable, Table
 from datajoint.condition import make_condition
-from datajoint.table import Table
+from datajoint.dependencies import unite_master_parts
+from datajoint.utils import get_master, to_camel_case
+from networkx import all_simple_paths
+from networkx.algorithms.dag import topological_sort
 from tqdm import tqdm
 
-from spyglass.common import AnalysisNwbfile
 from spyglass.utils import logger
-from spyglass.utils.dj_graph_abs import AbstractGraph
 from spyglass.utils.dj_helper_fn import PERIPHERAL_TABLES, unique_dicts
 
 
+class Direction(Enum):
+    """Cascade direction enum."""
+
+    UP = "up"
+    DOWN = "down"
+
+
+class AbstractGraph(ABC):
+    """Abstract class for graph traversal and restriction application.
+
+    Inherited by...
+    - RestrGraph: Cascade restriction(s) through a graph
+    - FindKeyGraph: Inherits from RestrGraph. Cascades through the graph to
+        find where a restriction works, and cascades back across visited
+        nodes.
+    - TableChain: Takes parent and child nodes, finds the shortest path,
+        and applies a restriction across the path.
+
+    Methods
+    -------
+    cascade: Abstract method implemented by child classes
+    cascade1: Cascade a restriction up/down the graph, recursively
+
+    Properties
+    ----------
+    all_ft: Get all FreeTables for visited nodes with restrictions applied.
+    as_dict: Get visited nodes as a list of dictionaries of
+        {table_name: restriction}
+    """
+
+    def __init__(self, seed_table: Table, verbose: bool = False, **kwargs):
+        """Initialize graph and connection.
+
+        Parameters
+        ----------
+        seed_table : Table
+            Table to use to establish connection and graph
+        verbose : bool, optional
+            Whether to print verbose output. Default False
+        """
+        self.connection = seed_table.connection
+        self.graph = seed_table.connection.dependencies
+        self.graph.load()
+
+        self.verbose = verbose
+        self.leaves = set()
+        self.visited = set()
+        self.to_visit = set()
+        self.no_visit = set()
+        self.cascaded = False
+
+    @abstractmethod
+    def cascade(self):
+        """Cascade restrictions through graph."""
+        raise NotImplementedError("Child class must implement `cascade` method")
+
+    def _log_truncate(self, log_str: str, max_len: int = 80):
+        """Truncate log lines to max_len and print if verbose."""
+        if not self.verbose:
+            return
+        logger.info(
+            log_str[:max_len] + "..." if len(log_str) > max_len else log_str
+        )
+
+    def _ensure_name(self, table: Union[str, Table]) -> str:
+        """Ensure table is a string."""
+        return table if isinstance(table, str) else table.full_table_name
+
+    def _get_node(self, table: Union[str, Table]):
+        """Get node from graph."""
+        table = self._ensure_name(table)
+        if not (node := self.graph.nodes.get(table)):
+            raise ValueError(
+                f"Table {table} not found in graph."
+                + "\n\tPlease import this table and rerun"
+            )
+        return node
+
+    def _set_node(self, table, attr: str = "ft", value: Any = None):
+        """Set attribute on node. General helper for various attributes."""
+        _ = self._get_node(table)  # Ensure node exists
+        self.graph.nodes[table][attr] = value
+
+    def _get_edge(self, child: str, parent: str) -> Tuple[bool, Dict[str, str]]:
+        """Get edge data between child and parent.
+
+        Used as a fallback for _bridge_restr. Not required in typical use.
+
+        Returns
+        -------
+        Tuple[bool, Dict[str, str]]
+            Tuple of boolean indicating direction and edge data. True if child
+            is child of parent.
+        """
+        child = self._ensure_name(child)
+        parent = self._ensure_name(parent)
+
+        if edge := self.graph.get_edge_data(parent, child):
+            return False, edge
+        elif edge := self.graph.get_edge_data(child, parent):
+            return True, edge
+
+        # Handle alias nodes. `shortest_path` doesn't work with aliases
+        p1 = all_simple_paths(self.graph, child, parent)
+        p2 = all_simple_paths(self.graph, parent, child)
+        paths = [p for p in iter_chain(p1, p2)]  # list for error handling
+        for path in paths:  # Ignore long and non-alias paths
+            if len(path) > 3 or (len(path) > 2 and not path[1].isnumeric()):
+                continue
+            return self._get_edge(path[0], path[1])
+
+        raise ValueError(f"{child} -> {parent} not direct path: {paths}")
+
+    def _get_restr(self, table):
+        """Get restriction from graph node."""
+        return self._get_node(self._ensure_name(table)).get("restr")
+
+    def _set_restr(self, table, restriction, replace=False):
+        """Add restriction to graph node.
If one exists, merge with new.""" + ft = self._get_ft(table) + restriction = ( # Convert to condition if list or dict + make_condition(ft, restriction, set()) + if not isinstance(restriction, str) + else restriction + ) + existing = self._get_restr(table) + if not replace and existing: + if restriction == existing: + return + join = ft & [existing, restriction] + if len(join) == len(ft & existing): + return # restriction is a subset of existing + restriction = make_condition( + ft, unique_dicts(join.fetch("KEY", as_dict=True)), set() + ) + + self._set_node(table, "restr", restriction) + + def _get_ft(self, table, with_restr=False): + """Get FreeTable from graph node. If one doesn't exist, create it.""" + table = self._ensure_name(table) + if with_restr: + if not (restr := self._get_restr(table) or False): + self._log_truncate(f"No restriction for {table}") + else: + restr = True + + if not (ft := self._get_node(table).get("ft")): + ft = FreeTable(self.connection, table) + self._set_node(table, "ft", ft) + + return ft & restr + + @property + def all_ft(self): + """Get restricted FreeTables from all visited nodes. + + Topological sort logic adopted from datajoint.diagram. + """ + self.cascade() + nodes = [n for n in self.visited if not n.isnumeric()] + sorted_nodes = unite_master_parts( + list(topological_sort(self.graph.subgraph(nodes))) + ) + all_ft = [ + self._get_ft(table, with_restr=True) for table in sorted_nodes + ] + return [ft for ft in all_ft if len(ft) > 0] + + @property + def as_dict(self) -> List[Dict[str, str]]: + """Return as a list of dictionaries of table_name: restriction""" + self.cascade() + return [ + {"table_name": table, "restriction": self._get_restr(table)} + for table in self.visited + if self._get_restr(table) + ] + + def _bridge_restr( + self, + table1: str, + table2: str, + restr: str, + direction: Direction = None, + attr_map: dict = None, + aliased: bool = None, + **kwargs, + ): + """Given two tables and a restriction, return restriction for table2. + + Similar to ((table1 & restr) * table2).fetch(*table2.primary_key) + but with the ability to resolve aliases across tables. One table should + be the parent of the other. If direction or attr_map are not provided, + they will be inferred from the graph. + + Parameters + ---------- + table1 : str + Table name. Restriction always applied to this table. + table2 : str + Table name. Restriction pulled from this table. + restr : str + Restriction to apply to table1. + direction : Direction, optional + Direction to cascade. Default None. + attr_map : dict, optional + dictionary mapping aliases across tables, as pulled from + DataJoint-assembled graph. Default None. + + + Returns + ------- + List[Dict[str, str]] + List of dicts containing primary key fields for restricted table2. + """ + if not all([direction, attr_map]): + dir_bool, edge = self._get_edge(table1, table2) + direction = "up" if dir_bool else "down" + attr_map = edge.get("attr_map") + + ft1 = self._get_ft(table1) & restr + ft2 = self._get_ft(table2) + + if len(ft1) == 0: + return ["False"] + + if rev_attr := bool(set(attr_map.values()) - set(ft1.heading.names)): + attr_map = {v: k for k, v in attr_map.items()} # reverse + + join = ft1.proj(**attr_map) * ft2 + ret = unique_dicts(join.fetch(*ft2.primary_key, as_dict=True)) + + if self.verbose: # For debugging. Not required for typical use. 
+            partial = (
+                "NULL"
+                if len(ret) == 0
+                else "FULL" if len(ft2) == len(ret) else "part"
+            )
+            flipped = "Flipped" if rev_attr else "NoFlip"
+            dir = "Up" if direction == "up" else "Dn"
+            strt = f"{to_camel_case(ft1.table_name)}"
+            endp = f"{to_camel_case(ft2.table_name)}"
+            self._log_truncate(
+                f"{partial} {dir} {flipped}: {strt} -> {endp}, {len(ret)}"
+            )
+
+        return ret
+
+    def _camel(self, table):
+        if isinstance(table, KeysView):
+            table = list(table)
+        if not isinstance(table, list):
+            table = [table]
+        ret = [to_camel_case(t.split(".")[-1].strip("`")) for t in table]
+        return ret[0] if len(ret) == 1 else ret
+
+    def cascade1(
+        self,
+        table: str,
+        restriction: str,
+        direction: Direction = "up",
+        replace=False,
+        count=0,
+        **kwargs,
+    ):
+        """Cascade a restriction up the graph, recursively on parents/children.
+
+        Parameters
+        ----------
+        table : str
+            Table name
+        restriction : str
+            Restriction to apply
+        direction : Direction, optional
+            Direction to cascade. Default 'up'
+        replace : bool, optional
+            Replace existing restriction. Default False
+        """
+        if count > 100:
+            raise RecursionError("Cascade1: Recursion limit reached.")
+
+        self._set_restr(table, restriction, replace=replace)
+        self.visited.add(table)
+
+        G = self.graph
+        next_func = G.parents if direction == "up" else G.children
+        dir_dict = {"direction": direction}
+        dir_name = "Parents" if direction == "up" else "Children"
+
+        # Master/Parts added will go in opposite direction for one link.
+        # Direction is intentionally not passed to _bridge_restr in this case.
+        if direction == "up":
+            next_tables = {
+                k: {**v, **dir_dict} for k, v in next_func(table).items()
+            }
+            next_tables.update(
+                {part: {} for part in self._get_ft(table).parts()}
+            )
+        else:
+            next_tables = {
+                k: {**v, **dir_dict} for k, v in next_func(table).items()
+            }
+            if (master_name := get_master(table)) != "":
+                next_tables[master_name] = {}
+
+        log_dict = {
+            "Table ": self._camel(table),
+            f"{dir_name}": self._camel(next_func(table).keys()),
+            "Parts ": self._camel(self._get_ft(table).parts()),
+            "Master ": self._camel(get_master(table)),
+        }
+        logger.info(
+            f"Cascade1: {count}\n\t\t\t "
+            + "\n\t\t\t ".join(f"{k}: {v}" for k, v in log_dict.items())
+        )
+        for next_table, data in next_tables.items():
+            if next_table.isnumeric():  # Skip alias nodes
+                next_table, data = next_func(next_table).popitem()
+
+            if (
+                next_table in self.visited
+                or next_table in self.no_visit  # Subclasses can set this
+                or table == next_table
+            ):
+                path = f"{self._camel(table)} -> {self._camel(next_table)}"
+                if next_table in self.visited:
+                    self._log_truncate(f"SkipVisit: {path}")
+                if next_table in self.no_visit:
+                    self._log_truncate(f"NoVisit : {path}")
+                if table == next_table:
+                    self._log_truncate(f"Self : {path}")
+
+                continue
+
+            next_restr = self._bridge_restr(
+                table1=table,
+                table2=next_table,
+                restr=restriction,
+                **data,
+            )
+
+            self.cascade1(
+                table=next_table,
+                restriction=next_restr,
+                direction=direction,
+                replace=replace,
+                count=count + 1,
+            )
+
+
 class RestrGraph(AbstractGraph):
     def __init__(
         self,
@@ -22,12 +379,19 @@ def __init__(
         table_name: str = None,
         restriction: str = None,
         leaves: List[Dict[str, str]] = None,
+        direction: Direction = "up",
         cascade: bool = False,
         verbose: bool = False,
         **kwargs,
    ):
         """Use graph to cascade restrictions up from leaves to all ancestors.
 
+        'Leaves' are nodes with restrictions applied. Restrictions are cascaded
+        up/down the graph to all ancestors/descendants. If cascade is desired
+        in both directions, leaves/cascades should be added and run separately.
+        Future development could allow for direction setting on a per-leaf
+        basis.
+
         Parameters
         ----------
         seed_table : Table
@@ -39,6 +403,8 @@
         leaves : Dict[str, str], optional
             List of dictionaries with keys table_name and restriction. One
             entry per leaf node. Default None.
+        direction : Direction, optional
+            Direction to cascade. Default 'up'
         cascade : bool, optional
             Whether to cascade restrictions up the graph on initialization.
             Default False
@@ -47,13 +413,13 @@
         """
         super().__init__(seed_table, verbose=verbose)
 
-        self.analysis_pk = AnalysisNwbfile().primary_key
-
-        self.add_leaf(table_name=table_name, restriction=restriction)
+        self.add_leaf(
+            table_name=table_name, restriction=restriction, direction=direction
+        )
         self.add_leaves(leaves)
 
         if cascade:
-            self.cascade()
+            self.cascade(direction=direction)
 
     def __repr__(self):
         l_str = ",\n\t".join(self.leaves) + "\n" if self.leaves else ""
@@ -84,10 +450,11 @@ def add_leaf(
 
         self.cascaded = False
 
-        if direction == "up":
-            new_visits = set(self._get_ft(table_name).ancestors())
-        else:
-            new_visits = set(self._get_ft(table_name).descendants())
+        new_visits = (
+            set(self._get_ft(table_name).ancestors())
+            if direction == "up"
+            else set(self._get_ft(table_name).descendants())
+        )
 
         self.to_visit |= new_visits  # Add to total ancestors
         self.visited -= new_visits  # Remove from visited to revisit
@@ -150,7 +517,7 @@ def add_leaves(
         if cascade:
             self.cascade()
 
-    def cascade(self, show_progress=None) -> None:
+    def cascade(self, show_progress=None, direction="up") -> None:
         """Cascade all restrictions up the graph.
 
         Parameters
         ----------
@@ -171,7 +538,7 @@
         ):
             restr = self._get_restr(table)
             self._log_truncate(f"Start {table}: {restr}")
-            self.cascade1(table, restr)
+            self.cascade1(table, restr, direction=direction)
         if not self.visited == self.to_visit:
             raise RuntimeError(
                 "Cascade: FAIL - incomplete cascade. Please post issue."
             )
@@ -195,14 +562,26 @@ def cascade_files(self):
             files = list(ft.fetch(*self.analysis_pk))
             self._set_node(table, "files", files)
 
+    @property
+    def analysis_file_tbl(self) -> Table:
+        """Return the analysis file table. Avoids circular import."""
+        from spyglass.common import AnalysisNwbfile
+
+        return AnalysisNwbfile()
+
+    @property
+    def analysis_pk(self) -> List[str]:
+        """Return primary key fields from analysis file table."""
+        return self.analysis_file_tbl.primary_key
+
     @property
     def file_dict(self) -> Dict[str, List[str]]:
         """Return dictionary of analysis files from all visited nodes.
 
-        Currently unused, but could be useful for debugging.
+        Included for debugging, to associate files with tables.
""" self.cascade() - return self._get_attr_dict("files", default_factory=lambda: []) + return {t: self._get_node(t).get("files", []) for t in self.visited} @property def file_paths(self) -> List[str]: @@ -213,7 +592,7 @@ def file_paths(self) -> List[str]: """ self.cascade() return [ - {"file_path": AnalysisNwbfile().get_abs_path(file)} + {"file_path": self.analysis_file_tbl.get_abs_path(file)} for file in set( [f for files in self.file_dict.values() for f in files] ) @@ -228,7 +607,7 @@ def __init__( table_name: str = None, restriction: str = None, leaves: List[Dict[str, str]] = None, - direction: str = "up", + direction: Direction = "up", cascade: bool = False, verbose: bool = False, **kwargs, @@ -250,6 +629,8 @@ def __init__( super().__init__(seed_table, verbose=verbose) self.direction = direction + self.searched = set() + self.found = False if restriction and table_name: self._set_find_restr(table_name, restriction) @@ -261,8 +642,6 @@ def __init__( direction=direction, ) - all_nodes = set([n for n in self.graph.nodes if not n.isnumeric()]) - self.no_visit.update(all_nodes - self.to_visit) # Skip non-ancestors self.no_visit.update(PERIPHERAL_TABLES) if cascade and restriction: @@ -272,7 +651,7 @@ def __init__( def _set_find_restr(self, table_name, restriction): """Set restr to look for from leaf node.""" if isinstance(restriction, dict): - logger.warning("key_from_upstream: DICT unreliable, use STR.") + logger.warning("Using `>>` or `<<`: DICT unreliable, use STR.") restr_attrs = set() # modified by make_condition table_ft = self._get_ft(table_name) @@ -310,7 +689,6 @@ def cascade(self, direction=None, show_progress=None) -> None: if self.cascaded: return for table in self.leaves: - self._log_truncate(f"Start {table}: {self._get_restr(table)}") restriction, restr_attrs = self._get_find_restr(table) self.cascade1_search( table=table, @@ -319,49 +697,82 @@ def cascade(self, direction=None, show_progress=None) -> None: replace=True, ) self.cascaded = True + if not self.found: + searched = "parents" if direction == "up" else "children" + logger.warning( + f"Restriction could not be applied to any {searched}.\n\t" + + f"From: {self.leaves}\n\t" + + f"Restr: {restriction}" + ) + + def _ban_unsearched(self): + """After found match, ignore others for cascade back to leaf.""" + all_tables = set([n for n in self.graph.nodes]) + unsearched = all_tables - self.searched + camel_searched = self._camel(list(self.searched)) + logger.info(f"Searched: {camel_searched}") + self.no_visit.update(unsearched) def cascade1_search( self, - table: str, - restriction: str, + table: str = None, + restriction: str = True, restr_attrs: Set[str] = None, - direction: str = None, + direction: Direction = None, replace: bool = True, + limit: int = 100, ): - self._log_truncate(f"Search {table}: {restriction}") - if self.cascaded: + if self.found or not table or limit < 1 or table in self.searched: return + self.searched.add(table) + direction = direction or self.direction next_func = ( self.graph.parents if direction == "up" else self.graph.children ) + next_searches = set() for next_table, data in next_func(table).items(): + self._log_truncate( + f"Search: {self._camel(table)} -> {self._camel(next_table)}" + ) if next_table.isnumeric(): next_table, data = next_func(next_table).popitem() if next_table in self.no_visit or table == next_table: - self._log_truncate(f"Skip {next_table}: {restriction}") - reason = "no_visit" if next_table in self.no_visit else "same" - self._log_truncate(f"B/C {next_table}: {reason}") 
continue next_ft = self._get_ft(next_table) if restr_attrs.issubset(set(next_ft.heading.names)): - self._log_truncate(f"Found {next_table}: {restriction}") + self.searched.add(next_table) + # self.searched.add(get_master(next_table)) + self.searched.update(next_ft.parts()) + self.found = True + self._ban_unsearched() self.cascade1( table=next_table, restriction=restriction, direction="down" if direction == "up" else "up", replace=replace, + **data, ) - self.cascaded = True + return + + next_searches.update( + set([*next_ft.parts(), get_master(next_table), next_table]) + ) + for next_table in next_searches: + if not next_table: + continue # Skip None from get_master self.cascade1_search( table=next_table, restriction=restriction, restr_attrs=restr_attrs, direction=direction, replace=replace, + limit=limit - 1, ) + if self.found: + return diff --git a/src/spyglass/utils/dj_graph_abs.py b/src/spyglass/utils/dj_graph_abs.py deleted file mode 100644 index 37756fdf2..000000000 --- a/src/spyglass/utils/dj_graph_abs.py +++ /dev/null @@ -1,336 +0,0 @@ -from abc import ABC, abstractmethod -from itertools import chain as iter_chain -from typing import Dict, List, Tuple, Union - -from datajoint import FreeTable, logger -from datajoint.condition import make_condition -from datajoint.dependencies import unite_master_parts -from datajoint.table import Table -from datajoint.utils import to_camel_case -from networkx import NetworkXNoPath, all_simple_paths, shortest_path -from networkx.algorithms.dag import topological_sort - -from spyglass.utils.dj_helper_fn import unique_dicts - - -class AbstractGraph(ABC): - def __init__(self, seed_table: Table, verbose: bool = False, **kwargs): - """Abstract class for graph traversal and restriction application. - - Parameters - ---------- - seed_table : Table - Table to use to establish connection and graph - verbose : bool, optional - Whether to print verbose output. Default False - """ - self.connection = seed_table.connection - self.graph = seed_table.connection.dependencies - self.graph.load() - - self.verbose = verbose - self.leaves = set() - self.visited = set() - self.to_visit = set() - self.no_visit = set() - self.cascaded = False - - def _log_truncate(self, log_str, max_len=80): - """Truncate log lines to max_len and print if verbose.""" - if not self.verbose: - return - logger.info( - log_str[:max_len] + "..." if len(log_str) > max_len else log_str - ) - - @abstractmethod - def cascade(self): - """Cascade restrictions through graph.""" - raise NotImplementedError("Child class mut implement `cascade` method") - - def _get_node(self, table): - """Get node from graph.""" - if not isinstance(table, str): - table = table.full_table_name - if not (node := self.graph.nodes.get(table)): - raise ValueError( - f"Table {table} not found in graph." - + "\n\tPlease import this table and rerun" - ) - return node - - def _set_node(self, table, attr="ft", value=None): - """Set attribute on node. General helper for various attributes.""" - _ = self._get_node(table) # Ensure node exists - self.graph.nodes[table][attr] = value - - def _get_attr_dict( - self, attr, default_factory=lambda: None - ) -> List[Dict[str, str]]: - """Get given attr for each table in self.visited - - Uses factory to create default value for missing attributes. - """ - return { - t: self._get_node(t).get(attr, default_factory()) - for t in self.visited - } - - def _get_edge(self, child, parent) -> Tuple[bool, Dict[str, str]]: - """Get edge data between child and parent. 
- - Returns - ------- - Tuple[bool, Dict[str, str]] - Tuple of boolean indicating direction and edge data. True if child - is child of parent. - """ - child = child if isinstance(child, str) else child.full_table_name - parent = parent if isinstance(parent, str) else parent.full_table_name - - if edge := self.graph.get_edge_data(parent, child): - return False, edge - elif edge := self.graph.get_edge_data(child, parent): - return True, edge - - # Handle alias nodes. `shortest_path` doesn't work with aliases - p1 = all_simple_paths(self.graph, child, parent) - p2 = all_simple_paths(self.graph, parent, child) - paths = [p for p in iter_chain(p1, p2)] # list for error handling - for path in paths: - if len(path) > 3 or (len(path) > 2 and not path[1].isnumeric()): - continue - return self._get_edge(path[0], path[1]) - - raise ValueError(f"{child} -> {parent} not direct path: {paths}") - - def _rev_attrs(self, attr_map): - """Parse attribute map. Remove self-references.""" - return {v: k for k, v in attr_map.items()} - - def _get_restr(self, table): - """Get restriction from graph node. - - Defaults to False if no restriction is set so that it doesn't appear - in attrs like `all_ft`. - """ - table = table if isinstance(table, str) else table.full_table_name - return self._get_node(table).get("restr") - - def _set_restr(self, table, restriction, replace=False): - """Add restriction to graph node. If one exists, merge with new.""" - ft = self._get_ft(table) - restriction = ( # Convert to condition if list or dict - make_condition(ft, restriction, set()) - if not isinstance(restriction, str) - else restriction - ) - existing = self._get_restr(table) - if not replace and existing: - if restriction == existing: - return - join = ft & [existing, restriction] - if len(join) == len(ft & existing): - return # restriction is a subset of existing - restriction = make_condition( - ft, unique_dicts(join.fetch("KEY", as_dict=True)), set() - ) - - self._log_truncate(f"Set {table.split('.')[-1]} {restriction}") - # if "#pk_node" in table: - # __import__("pdb").set_trace() - self._set_node(table, "restr", restriction) - - def _get_ft(self, table, with_restr=False): - """Get FreeTable from graph node. If one doesn't exist, create it.""" - table = table if isinstance(table, str) else table.full_table_name - - if with_restr: - restr = self._get_restr(table) - if not restr: - logger.warning(f"No restriction for {table}") - restr = False - else: - restr = True - - if ft := self._get_node(table).get("ft"): - return ft & restr - ft = FreeTable(self.connection, table) - self._set_node(table, "ft", ft) - return ft & restr - - def topological_sort(self, nodes=None) -> List[str]: - """Get topological sort of visited nodes. From datajoint.diagram""" - nodes = nodes or self.visited - nodes = [n for n in nodes if not n.isnumeric()] - return unite_master_parts( - list(topological_sort(self.graph.subgraph(nodes))) - ) - - @property - def all_ft(self): - """Get restricted FreeTables from all visited nodes.""" - self.cascade() - return [ - self._get_ft(table, with_restr=True) - for table in self.topological_sort() - ] - - def _print_restr(self, leaves=False): - """Print restrictions for each table in visited set.""" - mylist = self.leaves if leaves else self.visited - for table in mylist: - self._log_truncate( - f"{table.split('.')[-1]:>35} {self._get_restr(table)}" - ) - - def get_restr_ft(self, table: Union[int, str]): - """Get restricted FreeTable from graph node. - - Currently used for testing. 
- - Parameters - ---------- - table : Union[int, str] - Table name or index in visited set - """ - if isinstance(table, int): - table = list(self.visited)[table] - return self._get_ft(table, with_restr=True) - - @property - def as_dict(self) -> List[Dict[str, str]]: - """Return as a list of dictionaries of table_name: restriction""" - self.cascade() - return [ - {"table_name": table, "restriction": self._get_restr(table)} - for table in self.visited - if self._get_restr(table) - ] - - def _bridge_restr( - self, - table1: str, - table2: str, - restr: str, - direction: str = None, - attr_map: dict = None, - primary: bool = None, - aliased: bool = None, - **kwargs, - ): - """Given two tables and a restriction, return restriction for table2. - - Similar to ((table1 & restr) * table2).fetch(*table2.primary_key) - but with the ability to resolve aliases across tables. One table should - be the parent of the other. Replaces previous _child_to_parent. - - Parameters - ---------- - table1 : str - Table name. Restriction always applied to this table. - table2 : str - Table name. Restriction pulled from this table. - restr : str - Restriction to apply to table1. - attr_map : dict, optional - dictionary mapping aliases across tables, as pulled from - DataJoint-assembled graph. Default None. - primary : bool, optional - Is parent in child's primary key? Default True. Also derived from - DataJoint-assembled graph. If True, project only primary key fields - to avoid secondary key collisions. - - Returns - ------- - List[Dict[str, str]] - List of dicts containing primary key fields for restricted table2. - """ - # Direction UP: table1 -> table2, parent -> child - if not all([direction, attr_map, primary, aliased]): - dir_bool, edge = self._get_edge(table1, table2) - direction = "up" if dir_bool else "down" - attr_map = edge.get("attr_map") - primary = edge.get("primary") - aliased = edge.get("aliased") - - ft1 = self._get_ft(table1) - rt1 = ft1 & restr - ft2 = self._get_ft(table2) - - if len(ft1) == 0: - return ["False"] - - adjust = bool(set(attr_map.values()) - set(ft1.heading.names)) - if adjust: - attr_map = self._rev_attrs(attr_map) - - join = rt1.proj(**attr_map) * ft2 - - ret = unique_dicts(join.fetch(*ft2.primary_key, as_dict=True)) - - null = None - if self.verbose: - dir = "Up" if direction == "up" else "Dn" - prim = "Pri" if primary else "Sec" - adjust = "Flip" if adjust else "NoFp" - aliaa = "Alias" if aliased else "NoAli" - null = ( - "NULL" - if len(ret) == 0 - else "FULL" if len(ft2) == len(ret) else "part" - ) - strt = f"{to_camel_case(ft1.table_name)}" - endp = f"{to_camel_case(ft2.table_name)}" - self._log_truncate( - f"{dir} {prim} {aliaa} {adjust}: {null} {strt} -> {endp}" - ) - if null and null != "part": - pass - # __import__("pdb").set_trace() - - return ret - - def cascade1(self, table, restriction, direction="up", replace=False): - """Cascade a restriction up the graph, recursively on parents. 
-
-        Parameters
-        ----------
-        table : str
-            table name
-        restriction : str
-            restriction to apply
-        """
-
-        self._set_restr(table, restriction, replace=replace)
-        self.visited.add(table)
-
-        next_func = (
-            self.graph.parents if direction == "up" else self.graph.children
-        )
-
-        for next_table, data in next_func(table).items():
-            if next_table.isnumeric():
-                next_table, data = next_func(next_table).popitem()
-
-            if (
-                next_table in self.visited
-                or next_table in self.no_visit
-                or table == next_table
-            ):
-                continue
-
-            next_restr = self._bridge_restr(
-                table1=table,
-                table2=next_table,
-                restr=restriction,
-                direction=direction,
-                **data,
-            )
-
-            self.cascade1(
-                table=next_table,
-                restriction=next_restr,
-                direction=direction,
-                replace=replace,
-            )
diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py
index 2eba4d0b0..580da9bb4 100644
--- a/src/spyglass/utils/dj_merge_tables.py
+++ b/src/spyglass/utils/dj_merge_tables.py
@@ -9,6 +9,7 @@
 import datajoint as dj
 from datajoint.condition import make_condition
 from datajoint.errors import DataJointError
+from datajoint.expression import QueryExpression
 from datajoint.preview import repr_html
 from datajoint.utils import from_camel_case, to_camel_case
 from IPython.core.display import HTML
@@ -26,15 +27,17 @@
 
 def is_merge_table(table):
     """Return True if table fields exactly match Merge table."""
+
+    def trim_def(definition):
+        return re_sub(
+            r"\n\s*\n", "\n", re_sub(r"#.*\n", "\n", definition.strip())
+        )
+
     if not isinstance(table, dj.Table):
         return False
     if not table.is_declared:
         if tbl_def := getattr(table, "definition", None):
-            return MERGE_DEFINITION == re_sub(
-                r"\n\s*\n",
-                "\n",
-                re_sub(r"#.*\n", "\n", tbl_def.strip()),
-            )
+            return trim_def(MERGE_DEFINITION) == trim_def(tbl_def)
         logger.warning(f"Cannot determine merge table status for {table}")
         return True
     return table.primary_key == [
@@ -60,8 +63,8 @@ def __init__(self):
         if not is_merge_table(self):  # Check definition
             logger.warn(
                 "Merge table with non-default definition\n"
-                + f"Expected: {MERGE_DEFINITION.strip()}\n"
-                + f"Actual : {self.definition.strip()}"
+                + f"Expected:\n{MERGE_DEFINITION.strip()}\n"
+                + f"Actual :\n{self.definition.strip()}"
             )
         for part in self.parts(as_objects=True):
             if part.primary_key != self.primary_key:
@@ -811,24 +814,82 @@ def super_delete(self, warn=True, *args, **kwargs):
 
     # ------------------------------ Restrict by ------------------------------
 
-    # TODO: TEST THIS
-    # TODO: Allow (Table & restriction).merge_xxx() syntax
-    def restrict_from_upstream(self, restriction=True, **kwargs):
-        """Restrict self based on upstream table."""
+    def __lshift__(self, restriction) -> QueryExpression:
+        """Restriction by upstream operator e.g. ``q1 << q2``.
+
+        Returns
+        -------
+        QueryExpression
+            A restricted copy of the query expression using the nearest upstream
+            table for which the restriction is valid.
+        """
+        return self.restrict_by(restriction, direction="up")
+
+    def __rshift__(self, restriction) -> QueryExpression:
+        """Restriction by downstream operator e.g. ``q1 >> q2``.
+
+        Returns
+        -------
+        QueryExpression
+            A restricted copy of the query expression using the nearest
+            downstream table for which the restriction is valid.
+        """
+        return self.restrict_by(restriction, direction="down")
+
+    def restrict_by(
+        self,
+        restriction: str = True,
+        direction: str = "up",
+        return_graph: bool = False,
+        verbose: bool = False,
+        **kwargs,
+    ) -> QueryExpression:
+        """Restrict self based on up/downstream table.
+
+        Parameters
+        ----------
+        restriction : str
+            Restriction to apply to some table up/downstream of self.
+        direction : str, optional
+            Direction to search for valid restriction. Default 'up'.
+        return_graph : bool, optional
+            If True, return FindKeyGraph object. Default False, returns
+            restricted version of present table.
+        verbose : bool, optional
+            If True, print verbose output. Default False.
+
+        Returns
+        -------
+        Union[QueryExpression, FindKeyGraph]
+            Restricted version of present table or FindKeyGraph object. If
+            return_graph, use all_ft attribute to see all tables in cascade.
+        """
         from spyglass.utils.dj_graph import FindKeyGraph
 
         if restriction is True:
             return self._merge_repr()
 
+        try:  # Save time if restriction is already valid
+            ret = self.restrict(restriction)
+            logger.warning("Restriction valid for this table. Using as is.")
+            return ret
+        except DataJointError:
+            pass  # Could avoid try if assert_join_compatible returned a bool
+        logger.debug("Restriction not valid. Attempting to cascade.")
+
         graph = FindKeyGraph(
             seed_table=self,
             restriction=restriction,
             leaves=self.parts(),
+            direction=direction,
             cascade=True,
             verbose=False,
             **kwargs,
         )
 
+        if return_graph:
+            return graph
+
         self_restrict = [
             leaf.fetch(RESERVED_PRIMARY_KEY, as_dict=True)
             for leaf in graph.leaf_ft
diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py
index c1d4758a6..6eddb9b21 100644
--- a/src/spyglass/utils/dj_mixin.py
+++ b/src/spyglass/utils/dj_mixin.py
@@ -18,7 +18,6 @@
 from pymysql.err import DataError
 
 from spyglass.utils.database_settings import SHARED_MODULES
-from spyglass.utils.dj_chains import TableChain, TableChains
 from spyglass.utils.dj_helper_fn import fetch_nwb, get_nwb_table
 from spyglass.utils.dj_merge_tables import RESERVED_PRIMARY_KEY as MERGE_PK
 from spyglass.utils.dj_merge_tables import Merge, is_merge_table
@@ -305,6 +304,8 @@ def _merge_chains(self) -> OrderedDict[str, List[dj.FreeTable]]:
         with a new restriction. To recompute, add `reload_cache=True` to
         delete_downstream_merge call.
         """
+        from spyglass.utils.dj_chains import TableChains  # noqa F401
+
         merge_chains = {}
         for name, merge_table in self._merge_tables.items():
             chains = TableChains(self, merge_table, connection=self.connection)
@@ -322,7 +323,7 @@ def _merge_chains(self) -> OrderedDict[str, List[dj.FreeTable]]:
             )
         )
 
-    def _get_chain(self, substring) -> TableChains:
+    def _get_chain(self, substring):
         """Return chain from self to merge table with substring in name."""
         for name, chain in self._merge_chains.items():
             if substring.lower() in name:
@@ -471,8 +472,10 @@ def _get_exp_summary(self):
         return exp_missing + exp_present
 
     @cached_property
-    def _session_connection(self) -> Union[TableChain, bool]:
+    def _session_connection(self):
         """Path from Session table to self. False if no connection found."""
+        from spyglass.utils.dj_chains import TableChain  # noqa F401
+
         connection = TableChain(parent=self._delete_deps[-1], child=self)
         return connection if connection.has_link else False
 
@@ -765,7 +768,7 @@ def fetch1(self, *args, log_fetch=True, **kwargs):
 
     # ------------------------------ Restrict by ------------------------------
 
-    def __lshift__(self, restriction):
+    def __lshift__(self, restriction) -> QueryExpression:
         """Restriction by upstream operator e.g. ``q1 << q2``.
 
         Returns
         -------
         QueryExpression
             A restricted copy of the query expression using the nearest upstream
             table for which the restriction is valid.
""" - return self.restrict_from(restriction, direction="up") + return self.restrict_by(restriction, direction="up") - def __rshift__(self, restriction): + def __rshift__(self, restriction) -> QueryExpression: """Restriction by downstream operator e.g. ``q1 >> q2``. Returns @@ -785,17 +788,36 @@ def __rshift__(self, restriction): A restricted copy of the query expression using the nearest upstream table for which the restriction is valid. """ - return self.restrict_from(restriction, direction="down") + return self.restrict_by(restriction, direction="down") - def restrict_from(self, restriction=True, direction="up", **kwargs): - """Restrict self based on upstream table.""" - ret = self.restrict_graph(restriction, direction, **kwargs).leaf_ft[0] - if len(ret) == len(self): - logger.warning("Restriction did not limit table.") - return ret + def restrict_by( + self, + restriction: str = True, + direction: str = "up", + return_graph: bool = False, + verbose: bool = False, + **kwargs, + ) -> QueryExpression: + """Restrict self based on up/downstream table. + + Parameters + ---------- + restriction : str + Restriction to apply to the some table up/downstream of self. + direction : str, optional + Direction to search for valid restriction. Default 'up'. + return_graph : bool, optional + If True, return FindKeyGraph object. Default False, returns + restricted version of present table. + verbose : bool, optional + If True, print verbose output. Default False. - def restrict_graph(self, restriction=True, direction="up", **kwargs): - """Restrict self based on upstream table.""" + Returns + ------- + Union[QueryExpression, FindKeyGraph] + Restricted version of present table or FindKeyGraph object. If + return_graph, use all_ft attribute to see all tables in cascade. + """ from spyglass.utils.dj_graph import FindKeyGraph if restriction is True: @@ -806,7 +828,7 @@ def restrict_graph(self, restriction=True, direction="up", **kwargs): return ret except DataJointError: pass # Could avoid try if assert_join_compatible returned a bool - logger.info("Restriction not valid. Attempting to cascade.") + logger.debug("Restriction not valid. 
Attempting to cascade.") graph = FindKeyGraph( seed_table=self, @@ -814,7 +836,14 @@ def restrict_graph(self, restriction=True, direction="up", **kwargs): restriction=restriction, direction=direction, cascade=True, - verbose=True, + verbose=verbose, **kwargs, ) - return graph + + if return_graph: + return graph + + ret = graph.leaf_ft[0] + if len(ret) == len(self): + logger.warning("Restriction did not limit table.") + return ret diff --git a/tests/utils/conftest.py b/tests/utils/conftest.py index cfc63b5c0..571bb9d26 100644 --- a/tests/utils/conftest.py +++ b/tests/utils/conftest.py @@ -62,13 +62,24 @@ def no_link_chain(Nwbfile): @pytest.fixture(scope="module") def graph_tables(dj_conn): - schema = dj.Schema(context=schema_graph.LOCALS_GRAPH) + lg = schema_graph.LOCALS_GRAPH - for table in schema_graph.LOCALS_GRAPH.values(): + schema = dj.Schema(context=lg) + + for table in lg.values(): schema(table) schema.activate("test_graph", connection=dj_conn) + merge_keys = lg["PkNode"].fetch("KEY", offset=1, as_dict=True) + lg["MergeOutput"].insert(merge_keys, skip_duplicates=True) + merge_child_keys = lg["MergeOutput"].merge_fetch(True, "merge_id", offset=1) + merge_child_inserts = [ + (i, j, k + 10) + for i, j, k in zip(merge_child_keys, range(4), range(10, 15)) + ] + lg["MergeChild"].insert(merge_child_inserts, skip_duplicates=True) + yield schema_graph.LOCALS_GRAPH schema.drop(force=True) diff --git a/tests/utils/schema_graph.py b/tests/utils/schema_graph.py index d646f4f8b..659518ceb 100644 --- a/tests/utils/schema_graph.py +++ b/tests/utils/schema_graph.py @@ -2,7 +2,7 @@ import datajoint as dj -from spyglass.utils import SpyglassMixin +from spyglass.utils import SpyglassMixin, _Merge # Ranges are offset from one another to create unique list of entries for each # table while respecting the foreign key constraints. @@ -147,9 +147,33 @@ class SkAliasNode(SpyglassMixin, dj.Lookup): ] +class MergeOutput(_Merge, SpyglassMixin): + definition = """ + merge_id: uuid + --- + source: varchar(32) + """ + + class PkNode(dj.Part): + definition = """ + -> MergeOutput + --- + -> PkNode + """ + + +class MergeChild(SpyglassMixin, dj.Manual): + definition = """ + -> MergeOutput + merge_child_id: int + --- + merge_child_attr: int + """ + + LOCALS_GRAPH = { k: v for k, v in locals().items() - if inspect_isclass(v) and k != "SpyglassMixin" + if inspect_isclass(v) and k not in ["SpyglassMixin", "_Merge"] } __all__ = list(LOCALS_GRAPH) diff --git a/tests/utils/test_graph.py b/tests/utils/test_graph.py index d86c81b22..33a8c8f1a 100644 --- a/tests/utils/test_graph.py +++ b/tests/utils/test_graph.py @@ -37,7 +37,7 @@ def test_rg_ft(restr_graph): def test_rg_restr_ft(restr_graph): """Test get restricted free tables.""" - ft = restr_graph.get_restr_ft(1) + ft = restr_graph._get_ft(list(restr_graph.visited)[1], with_restr=True) assert len(ft) == 1, "Unexpected restricted table length." @@ -65,17 +65,40 @@ def test_add_leaf_cascade(restr_graph_new_leaf): def test_add_leaf_restr_ft(restr_graph_new_leaf): restr_graph_new_leaf.cascade() - ft = restr_graph_new_leaf.get_restr_ft("`common_interval`.`interval_list`") + ft = restr_graph_new_leaf._get_ft( + "`common_interval`.`interval_list`", with_restr=True + ) assert len(ft) == 2, "Unexpected restricted table length." 
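As a usage sketch of the merge-table hop these fixtures enable (the expected
counts mirror the parametrized cases below):

```python
def test_merge_hop_sketch(graph_tables):
    # `graph_tables` yields schema_graph.LOCALS_GRAPH with MergeOutput
    # and MergeChild populated, so restrictions can cross the merge table.
    # Down: restrict the root parent by a MergeChild attribute.
    assert len(graph_tables["ParentNode"]() >> "merge_child_attr > 21") == 2
    # Up: restrict MergeChild by an attribute of the root parent.
    assert len(graph_tables["MergeChild"]() << "parent_attr > 18") == 2
```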
+
+@pytest.fixture(scope="session")
+def restr_graph_root(restr_graph, common, lfp_band):
+    from spyglass.utils.dj_graph import RestrGraph
+
+    yield RestrGraph(
+        seed_table=common.Session(),
+        table_name=common.Session.full_table_name,
+        restriction="True",
+        direction="down",
+        cascade=True,
+        verbose=False,
+    )
+
+
+def test_rg_root(restr_graph_root):
+    assert (
+        len(restr_graph_root.all_ft) == 29
+    ), "Unexpected number of cascaded tables."
+
+
 @pytest.mark.parametrize(
     "restr, expect_n, msg",
     [
-        ("pk_attr > 16", 4, "pk down, no alias"),
-        ("sk_attr > 17", 3, "sk down, no alias"),
-        ("pk_alias_attr > 18", 3, "pk down, pk alias"),
-        ("sk_alias_attr > 19", 2, "sk down, sk alias"),
+        ("pk_attr > 16", 4, "pk no alias"),
+        ("sk_attr > 17", 3, "sk no alias"),
+        ("pk_alias_attr > 18", 3, "pk pk alias"),
+        ("sk_alias_attr > 19", 2, "sk sk alias"),
+        ("merge_child_attr > 21", 2, "merge child down"),
     ],
 )
 def test_restr_from_upstream(graph_tables, restr, expect_n, msg):
@@ -86,10 +109,11 @@ def test_restr_from_upstream(graph_tables, restr, expect_n, msg):
 @pytest.mark.parametrize(
     "table, restr, expect_n, msg",
     [
-        ("PkNode", "parent_attr > 15", 5, "pk up, no alias"),
-        ("SkNode", "parent_attr > 16", 4, "sk up, no alias"),
-        ("PkAliasNode", "parent_attr > 17", 2, "pk up, pk alias"),
-        ("SkAliasNode", "parent_attr > 18", 2, "sk up, sk alias"),
+        ("PkNode", "parent_attr > 15", 5, "pk no alias"),
+        ("SkNode", "parent_attr > 16", 4, "sk no alias"),
+        ("PkAliasNode", "parent_attr > 17", 2, "pk pk alias"),
+        ("SkAliasNode", "parent_attr > 18", 2, "sk sk alias"),
+        ("MergeChild", "parent_attr > 18", 2, "merge child"),
     ],
 )
 def test_restr_from_downstream(graph_tables, table, restr, expect_n, msg):

From 366540fef79c1b19f8dc9593f8cdbad33fac7e73 Mon Sep 17 00:00:00 2001
From: CBroz1
Date: Thu, 2 May 2024 18:58:02 -0500
Subject: [PATCH 10/17] WIP: add docs

---
 docs/src/misc/mixin.md                |  24 +
 src/spyglass/utils/dj_chains.py       | 282 -----------
 src/spyglass/utils/dj_graph.py        | 701 +++++++++++++++++++-------
 src/spyglass/utils/dj_merge_tables.py |  86 ----
 src/spyglass/utils/dj_mixin.py        |  58 ++-
 5 files changed, 575 insertions(+), 576 deletions(-)
 delete mode 100644 src/spyglass/utils/dj_chains.py

diff --git a/docs/src/misc/mixin.md b/docs/src/misc/mixin.md
index 1742d875d..91e30125d 100644
--- a/docs/src/misc/mixin.md
+++ b/docs/src/misc/mixin.md
@@ -75,6 +75,30 @@ Some caveats to this function:
 2. This function will raise an error if it attempts to check a table that has
    not been imported into the current namespace. It is best used for exploring
    and debugging, not for production code.
+3. It's hard to determine the attributes in a mixed dictionary/string
+   restriction. If you are having trouble, try using a pure string
+   restriction.
+4. The most direct path to your restriction may not be the path taken,
+   especially when using Merge Tables. When the result is empty, see the
+   warning about the path used. To ban nodes from the search, try the
+   following:
+
+```python
+from spyglass.utils.dj_graph import TableChain
+
+my_chain = TableChain(
+    child=MyChildTable(),  # or parent=MyParentTable()
+    search_restr="my_str_restriction",
+    allow_merge=True,  # If child is a Merge Table
+    verbose=True,  # Detailed output will show the search history
+    banned_tables=[UnwantedTable1, UnwantedTable2],
+)
+
+my_chain.endpoint  # for the table that meets the restriction
+my_chain.all_ft  # for all restricted tables in the chain
+```
+
+When providing a restriction of the parent, use 'up' direction.
When providing a +restriction of the child, use 'down' direction. ## Delete Functionality diff --git a/src/spyglass/utils/dj_chains.py b/src/spyglass/utils/dj_chains.py deleted file mode 100644 index 214e0ce4e..000000000 --- a/src/spyglass/utils/dj_chains.py +++ /dev/null @@ -1,282 +0,0 @@ -from collections import OrderedDict -from functools import cached_property -from typing import List, Union - -import datajoint as dj -import networkx as nx -from datajoint.table import Table -from datajoint.utils import to_camel_case - -from spyglass.utils.dj_graph import AbstractGraph -from spyglass.utils.dj_helper_fn import PERIPHERAL_TABLES, fuzzy_get -from spyglass.utils.dj_merge_tables import is_merge_table - - -class TableChains: - """Class for representing chains from parent to Merge table via parts. - - Functions as a plural version of TableChain, allowing a single `cascade` - call across all chains from parent -> Merge table. - - Attributes - ---------- - parent : Table - Parent or origin of chains. - child : Table - Merge table or destination of chains. - connection : datajoint.Connection, optional - Connection to database used to create FreeTable objects. Defaults to - parent.connection. - part_names : List[str] - List of full table names of child parts. - chains : List[TableChain] - List of TableChain objects for each part in child. - has_link : bool - Cached attribute to store whether parent is linked to child via any of - child parts. False if (a) child is not in parent.descendants or (b) - nx.NetworkXNoPath is raised by nx.shortest_path for all chains. - - Methods - ------- - __init__(parent, child, connection=None) - Initialize TableChains with parent and child tables. - __repr__() - Return full representation of chains. - Multiline parent -> child for each chain. - __len__() - Return number of chains with links. - __getitem__(index: Union[int, str]) - Return TableChain object at index, or use substring of table name. - cascade(restriction: str = None) - Return list of cascade for each chain in self.chains. - """ - - def __init__(self, parent, child, connection=None): - self.parent = parent - self.child = child - self.connection = connection or parent.connection - parts = child.parts(as_objects=True) - self.part_names = [part.full_table_name for part in parts] - self.chains = [TableChain(parent, part) for part in parts] - self.has_link = any([chain.has_link for chain in self.chains]) - - def __repr__(self): - return "\n".join([str(chain) for chain in self.chains]) - - def __len__(self): - return len([c for c in self.chains if c.has_link]) - - @property - def max_len(self): - """Return length of longest chain.""" - return max([len(chain) for chain in self.chains]) - - def __getitem__(self, index: Union[int, str]): - """Return FreeTable object at index.""" - return fuzzy_get(index, self.part_names, self.chains) - - def cascade(self, restriction: str = None, direction: str = "down"): - """Return list of cascades for each chain in self.chains.""" - restriction = restriction or self.parent.restriction or True - cascades = [] - for chain in self.chains: - if joined := chain.cascade(restriction, direction): - cascades.append(joined) - return cascades - - -class TableChain(AbstractGraph): - """Class for representing a chain of tables. - - A chain is a sequence of tables from parent to child identified by - networkx.shortest_path. Parent -> Merge should use TableChains instead to - handle multiple paths to the respective parts of the Merge table. 
- - Attributes - ---------- - parent : Table - Parent or origin of chain. - child : Table - Child or destination of chain. - _link_symbol : str - Symbol used to represent the link between parent and child. Hardcoded - to " -> ". - has_link : bool - Cached attribute to store whether parent is linked to child. False if - child is not in parent.descendants or nx.NetworkXNoPath is raised by - nx.shortest_path. - link_type : str - 'directed' or 'undirected' based on whether path is found with directed - or undirected graph. None if no path is found. - graph : nx.DiGraph - Directed graph of parent's dependencies from datajoint.connection. - names : List[str] - List of full table names in chain. - path : OrderedDict[str, Dict[str, Union[dj.FreeTable,dict]]] - Dictionary of full table names in chain. Keys are self.names - Values are a dict of free_table (self.objects) and - attr_map (dict of new_name: old_name, self.attr_map). - - Methods - ------- - __str__() - Return string representation of chain: parent -> child. - __repr__() - Return full representation of chain: parent -> {links} -> child. - __len__() - Return number of tables in chain. - __getitem__(index: Union[int, str]) - Return FreeTable object at index, or use substring of table name. - find_path(directed=True) - Returns path OrderedDict of full table names in chain. If directed is - True, uses directed graph. If False, uses undirected graph. Undirected - excludes PERIPHERAL_TABLES like interval_list, nwbfile, etc. to maintain - valid joins. - cascade(restriction: str = None, direction: str = "up") - Given a restriction at the beginning, return a restricted FreeTable - object at the end of the chain. If direction is 'up', start at the child - and move up to the parent. If direction is 'down', start at the parent. - """ - - def __init__( - self, - parent: Table, - child: Table, - verbose: bool = False, - ): - if is_merge_table(child): - raise TypeError("Child is a merge table. Use TableChains instead.") - - super().__init__(seed_table=parent, verbose=verbose) - _ = self._get_node(child.full_table_name) # ensure child is in graph - - self._link_symbol = " -> " - self.parent = parent - self.child = child - self.link_type = None - self._searched = False - self.undirect_graph = None - - def __str__(self): - """Return string representation of chain: parent -> child.""" - if not self.has_link: - return "No link" - return ( - to_camel_case(self.parent.table_name) - + self._link_symbol - + to_camel_case(self.child.table_name) - ) - - def __repr__(self): - """Return full representation of chain: parent -> {links} -> child.""" - if not self.has_link: - return "No link" - return "Chain: " + self._link_symbol.join(self.path) - - def __len__(self): - """Return number of tables in chain.""" - if not self.has_link: - return 0 - return len(self.path) - - def __getitem__(self, index: Union[int, str]): - return fuzzy_get(index, self.path, self.all_ft) - - @property - def has_link(self) -> bool: - """Return True if parent is linked to child. - - If not searched, search for path. If searched and no link is found, - return False. If searched and link is found, return True. - """ - if not self._searched: - _ = self.path - return self.link_type is not None - - def find_path(self, directed=True) -> OrderedDict: - """Return list of full table names in chain. - - Parameters - ---------- - directed : bool, optional - If True, use directed graph. If False, use undirected graph. - Defaults to True. 
Undirected permits paths to traverse from merge - part-parent -> merge part -> merge table. Undirected excludes - PERIPHERAL_TABLES like interval_list, nwbfile, etc. - - Returns - ------- - OrderedDict - Dictionary of full table names in chain. Keys are full table names. - Values are free_table (dj.FreeTable representation) and attr_map - (dict of new_name: old_name). Attribute maps on the table upstream - of an alias node that can be used in .proj(). Returns None if no - path is found. - - Ignores numeric table names in paths, which are - 'gaps' or alias nodes in the graph. See datajoint.Diagram._make_graph - source code for comments on alias nodes. - """ - source, target = self.parent.full_table_name, self.child.full_table_name - - if not directed: - self.undirect_graph = self.graph.to_undirected() - self.undirect_graph.remove_nodes_from(PERIPHERAL_TABLES) - - search_graph = self.graph if directed else self.undirect_graph - - try: - path = nx.shortest_path(search_graph, source, target) - except nx.NetworkXNoPath: - return None # No path found, parent func may do undirected search - except nx.NodeNotFound: - self._searched = True # No path found, don't search again - return None - - ignore_nodes = self.graph.nodes - set(path) - self.no_visit.update(ignore_nodes) - - return path - - @cached_property - def path(self) -> list: - """Return list of full table names in chain.""" - if self._searched and not self.has_link: - return None - - path = None - if path := self.find_path(directed=True): - self.link_type = "directed" - elif path := self.find_path(directed=False): - self.link_type = "undirected" - self._searched = True - - return path - - @cached_property - def all_ft(self) -> List[dj.FreeTable]: - """Return list of FreeTable objects for each table in chain. - - Unused. Preserved for future debugging. 
- """ - if not self.has_link: - return None - return [self._get_ft(table, with_restr=False) for table in self.path] - - def cascade(self, restriction: str = None, direction: str = "up"): - _ = self.path - if not self.has_link: - return None - if direction == "up": - start, end = self.child, self.parent - else: - start, end = self.parent, self.child - if not self.cascaded: - self.cascade1( - table=start.full_table_name, - restriction=restriction, - direction=direction, - replace=True, - ) - self.cascaded = True - return self._get_ft(end.full_table_name, with_restr=True) diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index 0d90e99aa..b370e7dda 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -6,26 +6,53 @@ from abc import ABC, abstractmethod from collections.abc import KeysView from enum import Enum +from functools import cached_property from itertools import chain as iter_chain from typing import Any, Dict, List, Set, Tuple, Union +import datajoint as dj from datajoint import FreeTable, Table from datajoint.condition import make_condition from datajoint.dependencies import unite_master_parts from datajoint.utils import get_master, to_camel_case -from networkx import all_simple_paths +from networkx import ( + NetworkXNoPath, + NodeNotFound, + all_simple_paths, + shortest_path, +) from networkx.algorithms.dag import topological_sort from tqdm import tqdm from spyglass.utils import logger -from spyglass.utils.dj_helper_fn import PERIPHERAL_TABLES, unique_dicts +from spyglass.utils.dj_helper_fn import ( + PERIPHERAL_TABLES, + fuzzy_get, + unique_dicts, +) +from spyglass.utils.dj_merge_tables import is_merge_table class Direction(Enum): - """Cascade direction enum.""" + """Cascade direction enum. Calling Up returns True. Inverting flips.""" UP = "up" DOWN = "down" + NONE = None + + def __str__(self): + return self.value + + def __invert__(self) -> "Direction": + """Invert the direction.""" + if self.value is None: + logger.warning("Inverting NONE direction") + return Direction.NONE + return Direction.UP if self.value == "down" else Direction.DOWN + + def __bool__(self) -> bool: + """Return True if direction is UP.""" + return self.value is not None class AbstractGraph(ABC): @@ -33,11 +60,11 @@ class AbstractGraph(ABC): Inherited by... - RestrGraph: Cascade restriction(s) through a graph - - FindKeyGraph: Iherits from RestrGraph. Cascades through the graph to - find where a restriction works, and cascades back across visited - nodes. - TableChain: Takes parent and child nodes, finds the shortest path, - and applies a restriction across the path. + and applies a restriction across the path. If either parent or child + is a merge table, use TableChains instead. If either parent or child + are not provided, search_restr is required to find the path to the + missing table. Methods ------- @@ -61,9 +88,16 @@ def __init__(self, seed_table: Table, verbose: bool = False, **kwargs): verbose : bool, optional Whether to print verbose output. Default False """ + self.seed_table = seed_table self.connection = seed_table.connection + + # Undirected graph may not be needed, but adding FT to the graph + # prevents `to_undirected` from working. If using undirected, remove + # PERIPHERAL_TABLES from the graph. 
self.graph = seed_table.connection.dependencies
         self.graph.load()
+        self.undirect_graph = self.graph.to_undirected()
+        self.undirect_graph.remove_nodes_from(PERIPHERAL_TABLES)
 
         self.verbose = verbose
         self.leaves = set()
@@ -72,11 +106,15 @@ def __init__(self, seed_table: Table, verbose: bool = False, **kwargs):
         self.no_visit = set()
         self.cascaded = False
 
+    # --------------------------- Abstract Methods ---------------------------
+
     @abstractmethod
     def cascade(self):
         """Cascade restrictions through graph."""
         raise NotImplementedError("Child class must implement `cascade` method")
 
+    # ---------------------------- Logging Helpers ----------------------------
+
     def _log_truncate(self, log_str: str, max_len: int = 80):
         """Truncate log lines to max_len and print if verbose."""
         if not self.verbose:
@@ -85,9 +123,32 @@ def _log_truncate(self, log_str: str, max_len: int = 80):
             log_str[:max_len] + "..." if len(log_str) > max_len else log_str
         )
 
-    def _ensure_name(self, table: Union[str, Table]) -> str:
+    def _camel(self, table):
+        """Convert table name(s) to camel case."""
+        if isinstance(table, KeysView):
+            table = list(table)
+        if not isinstance(table, list):
+            table = [table]
+        ret = [to_camel_case(t.split(".")[-1].strip("`")) for t in table]
+        return ret[0] if len(ret) == 1 else ret
+
+    def _print_restr(self):
+        """Print restrictions for debugging."""
+        for table in self.visited:
+            if restr := self._get_restr(table):
+                logger.info(f"{table}: {restr}")
+
+    # ------------------------------ Graph Nodes ------------------------------
+
+    def _ensure_name(self, table: Union[str, Table] = None) -> str:
         """Ensure table is a string."""
-        return table if isinstance(table, str) else table.full_table_name
+        if table is None:
+            return None
+        if isinstance(table, str):
+            return table
+        if isinstance(table, list):
+            return [self._ensure_name(t) for t in table]
+        return getattr(table, "full_table_name", None)
 
     def _get_node(self, table: Union[str, Table]):
         """Get node from graph."""
@@ -107,7 +168,8 @@ def _set_node(self, table, attr: str = "ft", value: Any = None):
     def _get_edge(self, child: str, parent: str) -> Tuple[bool, Dict[str, str]]:
         """Get edge data between child and parent.
 
-        Used as a fallback for _bridge_restr. Not required in typical use.
+        Used as a fallback for _bridge_restr. Required for Master/Part links to
+        temporarily flip direction.
 
         Returns
         -------
@@ -174,31 +236,16 @@ def _get_ft(self, table, with_restr=False):
 
         return ft & restr
 
-    @property
-    def all_ft(self):
-        """Get restricted FreeTables from all visited nodes.
-
-        Topological sort logic adopted from datajoint.diagram.
- """ - self.cascade() - nodes = [n for n in self.visited if not n.isnumeric()] - sorted_nodes = unite_master_parts( - list(topological_sort(self.graph.subgraph(nodes))) - ) - all_ft = [ - self._get_ft(table, with_restr=True) for table in sorted_nodes - ] - return [ft for ft in all_ft if len(ft) > 0] + def _and_parts(self, table): + """Return table, its master and parts.""" + ret = [table] + if master := get_master(table): + ret.append(master) + if parts := self._get_ft(table).parts(): + ret.extend(parts) + return ret - @property - def as_dict(self) -> List[Dict[str, str]]: - """Return as a list of dictionaries of table_name: restriction""" - self.cascade() - return [ - {"table_name": table, "restriction": self._get_restr(table)} - for table in self.visited - if self._get_restr(table) - ] + # ---------------------------- Graph Traversal ----------------------------- def _bridge_restr( self, @@ -248,41 +295,70 @@ def _bridge_restr( if len(ft1) == 0: return ["False"] - if rev_attr := bool(set(attr_map.values()) - set(ft1.heading.names)): + if bool(set(attr_map.values()) - set(ft1.heading.names)): attr_map = {v: k for k, v in attr_map.items()} # reverse join = ft1.proj(**attr_map) * ft2 ret = unique_dicts(join.fetch(*ft2.primary_key, as_dict=True)) if self.verbose: # For debugging. Not required for typical use. - partial = ( - "NULL" + result = ( + "EMPTY" if len(ret) == 0 - else "FULL" if len(ft2) == len(ret) else "part" - ) - flipped = "Fliped" if rev_attr else "NoFlip" - dir = "Up" if direction == "up" else "Dn" - strt = f"{to_camel_case(ft1.table_name)}" - endp = f"{to_camel_case(ft2.table_name)}" - self._log_truncate( - f"{partial} {dir} {flipped}: {strt} -> {endp}, {len(ret)}" + else "FULL" if len(ft2) == len(ret) else "partial" ) + path = f"{self._camel(table1)} -> {self._camel(table2)}" + self._log_truncate(f"Bridge Link: {path}: result {result}") return ret - def _camel(self, table): - if isinstance(table, KeysView): - table = list(table) - if not isinstance(table, list): - table = [table] - ret = [to_camel_case(t.split(".")[-1].strip("`")) for t in table] - return ret[0] if len(ret) == 1 else ret + def _get_next_tables(self, table: str, direction: Direction) -> Tuple: + """Get next tables/func based on direction. + + Used in cascade1 and cascade1_search to add master and parts. Direction + is intentionally omitted to force _get_edge to determine the edge for + this gap before resuming desired direction. Nextfunc is used to get + relevant parent/child tables after aliast node. + + Parameters + ---------- + table : str + Table name + direction : Direction + Direction to cascade + + Returns + ------- + Tuple[Dict[str, Dict[str, str]], Callable + Tuple of next tables and next function to get parent/child tables. 
+ """ + G = self.graph + dir_dict = {"direction": direction} + + bonus = {} + direction = Direction(direction) + if direction == Direction.UP: + next_func = G.parents + bonus.update({part: {} for part in self._get_ft(table).parts()}) + elif direction == Direction.DOWN: + next_func = G.children + if (master_name := get_master(table)) != "": + bonus = {master_name: {}} + else: + raise ValueError(f"Invalid direction: {direction}") + + next_tables = { + k: {**v, **dir_dict} for k, v in next_func(table).items() + } + next_tables.update(bonus) + + return next_tables, next_func def cascade1( self, table: str, restriction: str, - direction: Direction = "up", + direction: Direction = Direction.UP, replace=False, count=0, **kwargs, @@ -306,36 +382,10 @@ def cascade1( self._set_restr(table, restriction, replace=replace) self.visited.add(table) - G = self.graph - next_func = G.parents if direction == "up" else G.children - dir_dict = {"direction": direction} - dir_name = "Parents" if direction == "up" else "Children" - - # Master/Parts added will go in opposite direction for one link. - # Direction is intentionally not passed to _bridge_restr in this case. - if direction == "up": - next_tables = { - k: {**v, **dir_dict} for k, v in next_func(table).items() - } - next_tables.update( - {part: {} for part in self._get_ft(table).parts()} - ) - else: - next_tables = { - k: {**v, **dir_dict} for k, v in next_func(table).items() - } - if (master_name := get_master(table)) != "": - next_tables[master_name] = {} + next_tables, next_func = self._get_next_tables(table, direction) - log_dict = { - "Table ": self._camel(table), - f"{dir_name}": self._camel(next_func(table).keys()), - "Parts ": self._camel(self._get_ft(table).parts()), - "Master ": self._camel(get_master(table)), - } - logger.info( - f"Cascade1: {count}\n\t\t\t " - + "\n\t\t\t ".join(f"{k}: {v}" for k, v in log_dict.items()) + self._log_truncate( + f"Checking {count:>2}: {self._camel(next_tables.keys())}" ) for next_table, data in next_tables.items(): if next_table.isnumeric(): # Skip alias nodes @@ -346,14 +396,12 @@ def cascade1( or next_table in self.no_visit # Subclasses can set this or table == next_table ): - path = f"{self._camel(table)} -> {self._camel(next_table)}" - if next_table in self.visited: - self._log_truncate(f"SkipVist: {path}") - if next_table in self.no_visit: - self._log_truncate(f"NoVisit : {path}") - if table == next_table: - self._log_truncate(f"Self : {path}") - + reason = ( + "Already saw" + if next_table in self.visited + else "Banned Tbl " + ) + self._log_truncate(f"{reason}: {self._camel(next_table)}") continue next_restr = self._bridge_restr( @@ -371,6 +419,34 @@ def cascade1( count=count + 1, ) + # ---------------------------- Graph Properties ---------------------------- + + @property + def all_ft(self): + """Get restricted FreeTables from all visited nodes. + + Topological sort logic adopted from datajoint.diagram. 
+ """ + self.cascade() + nodes = [n for n in self.visited if not n.isnumeric()] + sorted_nodes = unite_master_parts( + list(topological_sort(self.graph.subgraph(nodes))) + ) + all_ft = [ + self._get_ft(table, with_restr=True) for table in sorted_nodes + ] + return [ft for ft in all_ft if len(ft) > 0] + + @property + def as_dict(self) -> List[Dict[str, str]]: + """Return as a list of dictionaries of table_name: restriction""" + self.cascade() + return [ + {"table_name": table, "restriction": self._get_restr(table)} + for table in self.visited + if self._get_restr(table) + ] + class RestrGraph(AbstractGraph): def __init__( @@ -600,54 +676,242 @@ def file_paths(self) -> List[str]: ] -class FindKeyGraph(RestrGraph): +class TableChains: + """Class for representing chains from parent to Merge table via parts. + + Functions as a plural version of TableChain, allowing a single `cascade` + call across all chains from parent -> Merge table. + + Attributes + ---------- + parent : Table + Parent or origin of chains. + child : Table + Merge table or destination of chains. + connection : datajoint.Connection, optional + Connection to database used to create FreeTable objects. Defaults to + parent.connection. + part_names : List[str] + List of full table names of child parts. + chains : List[TableChain] + List of TableChain objects for each part in child. + has_link : bool + Cached attribute to store whether parent is linked to child via any of + child parts. False if (a) child is not in parent.descendants or (b) + nx.NetworkXNoPath is raised by nx.shortest_path for all chains. + + Methods + ------- + __init__(parent, child, connection=None) + Initialize TableChains with parent and child tables. + __repr__() + Return full representation of chains. + Multiline parent -> child for each chain. + __len__() + Return number of chains with links. + __getitem__(index: Union[int, str]) + Return TableChain object at index, or use substring of table name. + cascade(restriction: str = None) + Return list of cascade for each chain in self.chains. 
+ """ + + def __init__(self, parent, child, direction=Direction.DOWN): + self.parent = parent + self.child = child + self.connection = parent.connection + self.part_names = child.parts() + self.chains = [ + TableChain(parent, part, direction=direction) + for part in self.part_names + ] + self.has_link = any([chain.has_link for chain in self.chains]) + + # --------------------------- Dunder Properties --------------------------- + + def __repr__(self): + l_str = ",\n\t".join([str(c) for c in self.chains]) + "\n" + return f"{self.__class__.__name__}(\n\t{l_str})" + + def __len__(self): + return len([c for c in self.chains if c.has_link]) + + def __getitem__(self, index: Union[int, str]): + """Return FreeTable object at index.""" + return fuzzy_get(index, self.part_names, self.chains) + + # ---------------------------- Public Properties -------------------------- + + @property + def max_len(self): + """Return length of longest chain.""" + return max([len(chain) for chain in self.chains]) + + # ------------------------------ Graph Traversal -------------------------- + + def cascade( + self, restriction: str = None, direction: Direction = Direction.DOWN + ): + """Return list of cascades for each chain in self.chains.""" + restriction = restriction or self.parent.restriction or True + cascades = [] + for chain in self.chains: + if joined := chain.cascade(restriction, direction): + cascades.append(joined) + return cascades + + +class TableChain(RestrGraph): + """Class for representing a chain of tables. + + A chain is a sequence of tables from parent to child identified by + networkx.shortest_path. Parent -> Merge should use TableChains instead to + handle multiple paths to the respective parts of the Merge table. + + Attributes + ---------- + parent : str + Parent or origin of chain. + child : str + Child or destination of chain. + has_link : bool + Cached attribute to store whether parent is linked to child. + path : List[str] + Names of tables along the path from parent to child. + all_ft : List[dj.FreeTable] + List of FreeTable objects for each table in chain with restriction + applied. + + Methods + ------- + find_path(directed=True) + Returns path OrderedDict of full table names in chain. If directed is + True, uses directed graph. If False, uses undirected graph. Undirected + excludes PERIPHERAL_TABLES like interval_list, nwbfile, etc. to maintain + valid joins. + cascade(restriction: str = None, direction: str = "up") + Given a restriction at the beginning, return a restricted FreeTable + object at the end of the chain. If direction is 'up', start at the child + and move up to the parent. If direction is 'down', start at the parent. + """ + def __init__( self, - seed_table: Table, - table_name: str = None, - restriction: str = None, - leaves: List[Dict[str, str]] = None, - direction: Direction = "up", + parent: Table = None, + child: Table = None, + direction: Direction = Direction.NONE, + search_restr: str = None, cascade: bool = False, verbose: bool = False, + allow_merge: bool = False, + banned_tables: List[str] = None, **kwargs, ): - """Graph to restrict leaf by upstream keys. + if not allow_merge and child and is_merge_table(child): + raise TypeError("Child is a merge table. Use TableChains instead.") - Parameters - ---------- - seed_table : Table - Table to use to establish connection and graph - table_name : str, optional - Table name of single leaf, default seed_table.full_table_name - restriction : str, optional - Restriction to apply to leaf. 
default None, True - verbose : bool, optional - Whether to print verbose output. Default False - """ + self.parent = self._ensure_name(parent) + self.child = self._ensure_name(child) - super().__init__(seed_table, verbose=verbose) + if not self.parent and not self.child: + raise ValueError("Parent or child table required.") + if not search_restr and not (self.parent and self.child): + raise ValueError("Search restriction required to find path.") - self.direction = direction - self.searched = set() - self.found = False - - if restriction and table_name: - self._set_find_restr(table_name, restriction) - self.add_leaf(table_name, True, cascade=False, direction=direction) - self.add_leaves( - leaves, - default_restriction=restriction, - cascade=False, - direction=direction, - ) + super().__init__(seed_table=parent or child, verbose=verbose) self.no_visit.update(PERIPHERAL_TABLES) - - if cascade and restriction: + self.no_visit.update(self._ensure_name(banned_tables) or []) + self.searched_tables = set() + self.found_restr = False + self.link_type = None + self.searched_path = False + self._link_symbol = " -> " + + self.search_restr = search_restr + self.direction = Direction(direction) + + self.leaf = None + if search_restr and not parent: + self.direction = Direction.UP + self.leaf = self.child + if search_restr and not child: + self.direction = Direction.DOWN + self.leaf = self.parent + + if self.leaf: + self._set_find_restr(self.leaf, search_restr) + self.add_leaf(self.leaf, True, cascade=False, direction=direction) + + if cascade and search_restr: + self.cascade_search() self.cascade() self.cascaded = True + # --------------------------- Dunder Properties --------------------------- + + def __str__(self): + """Return string representation of chain: parent -> child.""" + if not self.has_link: + return "No link" + return ( + self._camel(self.parent) + + self._link_symbol + + self._camel(self.child) + ) + + def __repr__(self): + """Return full representation of chain: parent -> {links} -> child.""" + if not self.has_link: + return "No link" + return "Chain: " + self.path_str + + def __len__(self): + """Return number of tables in chain.""" + if not self.has_link: + return 0 + return len(self.path) + + def __getitem__(self, index: Union[int, str]): + return fuzzy_get(index, self.path, self.all_ft) + + # ---------------------------- Public Properties -------------------------- + + @property + def has_link(self) -> bool: + """Return True if parent is linked to child. + + If not searched, search for path. If searched and no link is found, + return False. If searched and link is found, return True. + """ + if not self.searched_path: + _ = self.path + return self.link_type is not None + + @cached_property + def all_ft(self) -> List[dj.FreeTable]: + """Return list of FreeTable objects for each table in chain. + + Unused. Preserved for future debugging. 
+ """ + if not self.has_link: + return None + return [ + self._get_ft(table, with_restr=False) + for table in self.path + if not table.isnumeric() + ] + + @property + def path_str(self) -> str: + return self._link_symbol.join([self._camel(t) for t in self.path]) + + @property + def endpoint(self) -> str: + """Return endpoint of chain.""" + return self.leaf_ft[0] + + # ------------------------------ Graph Nodes ------------------------------ + def _set_find_restr(self, table_name, restriction): """Set restr to look for from leaf node.""" if isinstance(restriction, dict): @@ -665,114 +929,161 @@ def _get_find_restr(self, table) -> Tuple[str, Set[str]]: node = self._get_node(table) return node.get("find_restr", False), node.get("restr_attrs", set()) - def add_leaves( - self, - leaves=None, - default_restriction=None, - cascade=False, - direction=None, - ): - leaves = self._process_leaves( - leaves=leaves, default_restriction=default_restriction - ) - for leaf in leaves: # Multiple leaves - self._set_find_restr(**leaf) - self.add_leaf( - leaf["table_name"], - True, - cascade=False, - direction=direction, - ) + # ---------------------------- Graph Traversal ---------------------------- - def cascade(self, direction=None, show_progress=None) -> None: - direction = direction or self.direction + def cascade_search(self) -> None: if self.cascaded: return - for table in self.leaves: - restriction, restr_attrs = self._get_find_restr(table) - self.cascade1_search( - table=table, - restriction=restriction, - restr_attrs=restr_attrs, - replace=True, + restriction, restr_attrs = self._get_find_restr(self.leaf) + self.cascade1_search( + table=self.leaf, + restriction=restriction, + restr_attrs=restr_attrs, + replace=True, + ) + if not self.found_restr: + searched = ( + "parents" if self.direction == Direction.UP else "children" ) - self.cascaded = True - if not self.found: - searched = "parents" if direction == "up" else "children" logger.warning( f"Restriction could not be applied to any {searched}.\n\t" + f"From: {self.leaves}\n\t" + f"Restr: {restriction}" ) - def _ban_unsearched(self): - """After found match, ignore others for cascade back to leaf.""" - all_tables = set([n for n in self.graph.nodes]) - unsearched = all_tables - self.searched - camel_searched = self._camel(list(self.searched)) - logger.info(f"Searched: {camel_searched}") - self.no_visit.update(unsearched) + def _set_found_vars(self, table): + """Set found_restr and searched_tables.""" + self._set_restr(table, self.search_restr, replace=True) + self.found_restr = True + self.searched_tables.update(set(self._and_parts(table))) + + if self.direction == Direction.UP: + self.parent = table + elif self.direction == Direction.DOWN: + self.child = table + + self.direction = ~self.direction + _ = self.path # Reset path def cascade1_search( self, table: str = None, restriction: str = True, restr_attrs: Set[str] = None, - direction: Direction = None, replace: bool = True, limit: int = 100, + **kwargs, ): - if self.found or not table or limit < 1 or table in self.searched: + if ( + self.found_restr + or not table + or limit < 1 + or table in self.searched_tables + ): return - self.searched.add(table) + self.searched_tables.add(table) + next_tables, next_func = self._get_next_tables(table, self.direction) - direction = direction or self.direction - next_func = ( - self.graph.parents if direction == "up" else self.graph.children - ) - - next_searches = set() - for next_table, data in next_func(table).items(): - self._log_truncate( - f"Search: 
{self._camel(table)} -> {self._camel(next_table)}" - ) + for next_table, data in next_tables.items(): if next_table.isnumeric(): next_table, data = next_func(next_table).popitem() + self._log_truncate( + f"Search Link: {self._camel(table)} -> {self._camel(next_table)}" + ) if next_table in self.no_visit or table == next_table: + reason = "Already Saw" if next_table == table else "Banned Tbl " + self._log_truncate(f"{reason}: {self._camel(next_table)}") continue next_ft = self._get_ft(next_table) if restr_attrs.issubset(set(next_ft.heading.names)): - self.searched.add(next_table) - # self.searched.add(get_master(next_table)) - self.searched.update(next_ft.parts()) - self.found = True - self._ban_unsearched() - self.cascade1( - table=next_table, - restriction=restriction, - direction="down" if direction == "up" else "up", - replace=replace, - **data, - ) + self._set_found_vars(next_table) return - next_searches.update( - set([*next_ft.parts(), get_master(next_table), next_table]) - ) - - for next_table in next_searches: - if not next_table: - continue # Skip None from get_master self.cascade1_search( table=next_table, restriction=restriction, restr_attrs=restr_attrs, - direction=direction, replace=replace, limit=limit - 1, + **data, ) - if self.found: + if self.found_restr: return + + # ------------------------------ Path Finding ------------------------------ + + def find_path(self, directed=True) -> List[str]: + """Return list of full table names in chain. + + Parameters + ---------- + directed : bool, optional + If True, use directed graph. If False, use undirected graph. + Defaults to True. Undirected permits paths to traverse from merge + part-parent -> merge part -> merge table. Undirected excludes + PERIPHERAL_TABLES like interval_list, nwbfile, etc. + + Returns + ------- + List[str] + List of names in the path. 
+ """ + source, target = self.parent, self.child + search_graph = self.graph if directed else self.undirect_graph + search_graph.remove_nodes_from(self.no_visit) + + try: + path = shortest_path(search_graph, source, target) + except NetworkXNoPath: + return None # No path found, parent func may do undirected search + except NodeNotFound: + self.searched_path = True # No path found, don't search again + return None + + ignore_nodes = self.graph.nodes - set(path) + self.no_visit.update(ignore_nodes) + + return path + + @cached_property + def path(self) -> list: + """Return list of full table names in chain.""" + if self.searched_path and not self.has_link: + return None + + path = None + if path := self.find_path(directed=True): + self.link_type = "directed" + elif path := self.find_path(directed=False): + self.link_type = "undirected" + self.searched_path = True + + return path + + def cascade(self, restriction: str = None, direction: Direction = None): + if not self.has_link: + return + + _ = self.path + + direction = direction or self.direction + if direction == Direction.UP: + start, end = self.child, self.parent + else: + start, end = self.parent, self.child + + self.cascade1( + table=start, + restriction=restriction or self._get_restr(start), + direction=direction, + replace=True, + ) + + return self._get_ft(end, with_restr=True) + + def restrict_by(self, *args, **kwargs) -> None: + """Cascade passthrough.""" + return self.cascade(*args, **kwargs) diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py index 580da9bb4..198270b27 100644 --- a/src/spyglass/utils/dj_merge_tables.py +++ b/src/spyglass/utils/dj_merge_tables.py @@ -9,7 +9,6 @@ import datajoint as dj from datajoint.condition import make_condition from datajoint.errors import DataJointError -from datajoint.expression import QueryExpression from datajoint.preview import repr_html from datajoint.utils import from_camel_case, to_camel_case from IPython.core.display import HTML @@ -812,91 +811,6 @@ def super_delete(self, warn=True, *args, **kwargs): self._log_use(start=time(), super_delete=True) super().delete(*args, **kwargs) - # ------------------------------ Restrict by ------------------------------ - - def __lshift__(self, restriction) -> QueryExpression: - """Restriction by upstream operator e.g. ``q1 << q2``. - - Returns - ------- - QueryExpression - A restricted copy of the query expression using the nearest upstream - table for which the restriction is valid. - """ - return self.restrict_by(restriction, direction="up") - - def __rshift__(self, restriction) -> QueryExpression: - """Restriction by downstream operator e.g. ``q1 >> q2``. - - Returns - ------- - QueryExpression - A restricted copy of the query expression using the nearest upstream - table for which the restriction is valid. - """ - return self.restrict_by(restriction, direction="down") - - def restrict_by( - self, - restriction: str = True, - direction: str = "up", - return_graph: bool = False, - verbose: bool = False, - **kwargs, - ) -> QueryExpression: - """Restrict self based on up/downstream table. - - Parameters - ---------- - restriction : str - Restriction to apply to the some table up/downstream of self. - direction : str, optional - Direction to search for valid restriction. Default 'up'. - return_graph : bool, optional - If True, return FindKeyGraph object. Default False, returns - restricted version of present table. - verbose : bool, optional - If True, print verbose output. Default False. 
-
-        Returns
-        -------
-        Union[QueryExpression, FindKeyGraph]
-            Restricted version of present table or FindKeyGraph object. If
-            return_graph, use all_ft attribute to see all tables in cascade.
-        """
-        from spyglass.utils.dj_graph import FindKeyGraph
-
-        if restriction is True:
-            return self._merge_repr()
-
-        try:  # Save time if restriction is already valid
-            ret = self.restrict(restriction)
-            logger.warning("Restriction valid for this table. Using as is.")
-            return ret
-        except DataJointError:
-            pass  # Could avoid try if assert_join_compatible returned a bool
-        logger.debug("Restriction not valid. Attempting to cascade.")
-
-        graph = FindKeyGraph(
-            seed_table=self,
-            restriction=restriction,
-            leaves=self.parts(),
-            direction=direction,
-            cascade=True,
-            verbose=False,
-            **kwargs,
-        )
-
-        if return_graph:
-            return graph
-
-        self_restrict = [
-            leaf.fetch(RESERVED_PRIMARY_KEY, as_dict=True)
-            for leaf in graph.leaf_ft
-        ]
-
-        return self & self_restrict
-
 
 _Merge = Merge
 
diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py
index 6eddb9b21..758b39a08 100644
--- a/src/spyglass/utils/dj_mixin.py
+++ b/src/spyglass/utils/dj_mixin.py
@@ -304,11 +304,11 @@ def _merge_chains(self) -> OrderedDict[str, List[dj.FreeTable]]:
         with a new restriction. To recompute, add `reload_cache=True` to
         delete_downstream_merge call.
         """
-        from spyglass.utils.dj_chains import TableChains  # noqa F401
+        from spyglass.utils.dj_graph import TableChains  # noqa F401
 
         merge_chains = {}
         for name, merge_table in self._merge_tables.items():
-            chains = TableChains(self, merge_table, connection=self.connection)
+            chains = TableChains(self, merge_table)
             if len(chains):
                 merge_chains[name] = chains
@@ -474,7 +474,7 @@ def _get_exp_summary(self):
     @cached_property
     def _session_connection(self):
         """Path from Session table to self. False if no connection found."""
-        from spyglass.utils.dj_chains import TableChain  # noqa F401
+        from spyglass.utils.dj_graph import TableChain  # noqa F401
 
         connection = TableChain(parent=self._delete_deps[-1], child=self)
         return connection if connection.has_link else False
@@ -800,6 +800,26 @@ def restrict_by(
     ) -> QueryExpression:
         """Restrict self based on up/downstream table.
 
+        If this fails to restrict the table, the shortest path found may not
+        have been correct. Check the output to see where the restriction was
+        applied. Try the following:
+
+        >>> from spyglass.utils.dj_graph import TableChain
+        >>>
+        >>> my_chain = TableChain(
+        >>>     child=MyChildTable(),  # or parent=MyParentTable()
+        >>>     search_restr=restriction,
+        >>>     allow_merge=True,  # If child is Merge
+        >>>     verbose=True,  # Detailed output
+        >>>     banned_tables=[UnwantedTable1, UnwantedTable2],
+        >>> )
+        >>>
+        >>> my_chain.endpoint  # for the table that meets the restriction
+        >>> my_chain.all_ft  # for all restricted tables in the chain
+
+        When providing a restriction of the parent, use 'up' direction. When
+        providing a restriction of the child, use 'down' direction.
+
         Parameters
         ----------
         restriction : str
@@ -818,23 +838,31 @@ def restrict_by(
             Restricted version of present table or FindKeyGraph object. If
             return_graph, use all_ft attribute to see all tables in cascade.
         """
-        from spyglass.utils.dj_graph import FindKeyGraph
+        from spyglass.utils.dj_graph import TableChain  # noqa: F401
 
         if restriction is True:
            return self
 
-        try:  # Save time if restriction is already valid
-            ret = self.restrict(restriction)
+        try:
+            ret = self.restrict(restriction)  # Save time trying first
            logger.warning("Restriction valid for this table.
Using as is.") return ret except DataJointError: - pass # Could avoid try if assert_join_compatible returned a bool + pass # Could avoid try/except if assert_join_compatible return bool logger.debug("Restriction not valid. Attempting to cascade.") - graph = FindKeyGraph( - seed_table=self, - table_name=self.full_table_name, - restriction=restriction, + if direction == "up": + parent, child = None, self + elif direction == "down": + parent, child = self, None + else: + raise ValueError("Direction must be 'up' or 'down'.") + + graph = TableChain( + parent=parent, + child=child, direction=direction, + search_restr=restriction, + allow_merge=True, cascade=True, verbose=verbose, **kwargs, @@ -844,6 +872,10 @@ def restrict_by( return graph ret = graph.leaf_ft[0] - if len(ret) == len(self): - logger.warning("Restriction did not limit table.") + if len(ret) == len(self) or len(ret) == 0: + logger.warning( + f"Failed to restrict with path: {graph.path_str}\n" + + "See help(YourTable.restrict_by)" + ) + return ret From a1518e4081ce6903caf5c5fb324ccd93fc9ae019 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Thu, 2 May 2024 20:30:20 -0500 Subject: [PATCH 11/17] WIP: Revise tests --- CHANGELOG.md | 12 ++++++++++++ pyproject.toml | 6 +++--- src/spyglass/spikesorting/imported.py | 3 ++- src/spyglass/utils/dj_graph.py | 2 +- src/spyglass/utils/dj_merge_tables.py | 2 +- tests/utils/conftest.py | 2 +- tests/utils/test_chains.py | 2 +- 7 files changed, 21 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d6eda918e..4d52dcb44 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,17 @@ # Change Log +## \[Unreleased\] + +### Release Nodes + + + +### Infrastructure + +- Add long-distance restrictions via `<<` and `>>` operators. #943 + ## [0.5.2] (April 22, 2024) ### Infrastructure diff --git a/pyproject.toml b/pyproject.toml index 45617385b..ffb8d0df6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -121,10 +121,10 @@ ignore-words-list = 'nevers' minversion = "7.0" addopts = [ "-sv", - "--sw", # stepwise: resume with next test after failure - "--pdb", # drop into debugger on failure + # "--sw", # stepwise: resume with next test after failure + # "--pdb", # drop into debugger on failure "-p no:warnings", - "--no-teardown", # don't teardown the database after tests + # "--no-teardown", # don't teardown the database after tests # "--quiet-spy", # don't show logging from spyglass "--show-capture=no", "--pdbcls=IPython.terminal.debugger:TerminalPdb", # use ipython debugger diff --git a/src/spyglass/spikesorting/imported.py b/src/spyglass/spikesorting/imported.py index ccb24edd2..ca1bdc9d0 100644 --- a/src/spyglass/spikesorting/imported.py +++ b/src/spyglass/spikesorting/imported.py @@ -51,8 +51,9 @@ def make(self, key): self.insert1(key, skip_duplicates=True) + part_name = SpikeSortingOutput._part_name(self.table_name) SpikeSortingOutput._merge_insert( - [orig_key], part_name=self.camel_name, skip_duplicates=True + [orig_key], part_name=part_name, skip_duplicates=True ) @classmethod diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index b370e7dda..f99df10a7 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -51,7 +51,7 @@ def __invert__(self) -> "Direction": return Direction.UP if self.value == "down" else Direction.DOWN def __bool__(self) -> bool: - """Return True if direction is UP.""" + """Return True if direction is not None.""" return self.value is not None diff --git a/src/spyglass/utils/dj_merge_tables.py 
b/src/spyglass/utils/dj_merge_tables.py
index 198270b27..646e34171 100644
--- a/src/spyglass/utils/dj_merge_tables.py
+++ b/src/spyglass/utils/dj_merge_tables.py
@@ -75,7 +75,7 @@ def __init__(self):
         self._source_class_dict = {}
 
     @staticmethod
-    def _part_name(part):
+    def _part_name(part=None):
         """Return the CamelCase name of a part table"""
         if not isinstance(part, str):
             part = part.table_name
diff --git a/tests/utils/conftest.py b/tests/utils/conftest.py
index 571bb9d26..4c6209de2 100644
--- a/tests/utils/conftest.py
+++ b/tests/utils/conftest.py
@@ -55,7 +55,7 @@ def chain(chains):
 def no_link_chain(Nwbfile):
     """Return example TableChain object with no link."""
     from spyglass.common.common_usage import InsertError
-    from spyglass.utils.dj_chains import TableChain
+    from spyglass.utils.dj_graph import TableChain
 
     yield TableChain(Nwbfile, InsertError())
diff --git a/tests/utils/test_chains.py b/tests/utils/test_chains.py
index cb8bbccc4..3cc085471 100644
--- a/tests/utils/test_chains.py
+++ b/tests/utils/test_chains.py
@@ -4,7 +4,7 @@
 
 @pytest.fixture(scope="session")
 def TableChain():
-    from spyglass.utils.dj_chains import TableChain
+    from spyglass.utils.dj_graph import TableChain
 
     return TableChain

From 263004d99243c96554189d0e5116283af95a6cbb Mon Sep 17 00:00:00 2001
From: CBroz1
Date: Fri, 3 May 2024 13:47:58 -0500
Subject: [PATCH 12/17] WIP: Add way to ban item from search

---
 docs/src/misc/mixin.md                |  21 +-
 pyproject.toml                        |   4 +-
 src/spyglass/utils/dj_graph.py        |  59 ++++--
 src/spyglass/utils/dj_merge_tables.py |   9 +-
 src/spyglass/utils/dj_mixin.py        |  54 +++--
 tests/common/test_device.py           |   2 +-
 tests/conftest.py                     | 274 +++++++++++++++++++++++++-
 tests/lfp/conftest.py                 | 119 -----------
 tests/linearization/conftest.py       | 142 -------------
 tests/linearization/test_lin.py       |   2 +-
 tests/utils/test_chains.py            |  19 +-
 tests/utils/test_graph.py             |  17 +-
 tests/utils/test_mixin.py             |  22 ++-
 13 files changed, 403 insertions(+), 341 deletions(-)
 delete mode 100644 tests/linearization/conftest.py

diff --git a/docs/src/misc/mixin.md b/docs/src/misc/mixin.md
index 91e30125d..6b3884551 100644
--- a/docs/src/misc/mixin.md
+++ b/docs/src/misc/mixin.md
@@ -80,21 +80,16 @@ Some caveats to this function:
    restriction.
 4. The most direct path to your restriction may not be the path taken,
    especially when using Merge Tables. When the result is empty, see the
-   warning about the path used. To ban nodes from the search, try the
-   following:
+   warning about the path used. Then, ban tables from the search to force a
+   different path.
 
 ```python
-from spyglass.utils.dj_graph import TableChain
-
-my_chain = TableChain(
-    child=MyChildTable(),  # or parent=MyParentTable()
-    search_restr="my_str_restriction",
-    allow_merge=True,  # If child is a Merge Table
-    verbose=True,  # Detailed output will show the search history
-    banned_tables=[UnwantedTable1, UnwantedTable2],
-)
-
-my_chain.endpoint  # for the table that meets the restriction
-my_chain.all_ft  # for all restricted tables in the chain
+my_table = MyTable()  # must be instantiated
+my_table.ban_search_table(UnwantedTable1)
+my_table.ban_search_table([UnwantedTable2, UnwantedTable3])
+my_table.unban_search_table(UnwantedTable3)
+my_table.see_banned_tables()
+
+my_table << my_restriction
 ```
 
 When providing a restriction of the parent, use 'up' direction.
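These bans feed the lower-level path search directly; a hedged sketch of the equivalent explicit call (`MyTable` and `UnwantedTable1` are placeholders):

```python
from spyglass.utils.dj_graph import TableChain

# Sketch only: mirrors what restrict_by does internally once tables are banned.
chain = TableChain(
    child=MyTable(),  # or parent=MyTable() for a downstream restriction
    search_restr="my_str_restriction",
    banned_tables=[UnwantedTable1.full_table_name],
    allow_merge=True,  # required when a Merge table is in the chain
    cascade=True,
    verbose=True,  # logs the path actually taken
)
chain.path_str    # human-readable Parent -> ... -> Child path
chain.leaf_ft[0]  # the restricted table at the end of the search
```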
When providing a diff --git a/pyproject.toml b/pyproject.toml index ffb8d0df6..f9ed16ab6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -121,8 +121,8 @@ ignore-words-list = 'nevers' minversion = "7.0" addopts = [ "-sv", - # "--sw", # stepwise: resume with next test after failure - # "--pdb", # drop into debugger on failure + "--sw", # stepwise: resume with next test after failure + "--pdb", # drop into debugger on failure "-p no:warnings", # "--no-teardown", # don't teardown the database after tests # "--quiet-spy", # don't show logging from spyglass diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index f99df10a7..d3f8fab04 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -96,8 +96,6 @@ def __init__(self, seed_table: Table, verbose: bool = False, **kwargs): # PERIPHERAL_TABLES from the graph. self.graph = seed_table.connection.dependencies self.graph.load() - self.undirect_graph = self.graph.to_undirected() - self.undirect_graph.remove_nodes_from(PERIPHERAL_TABLES) self.verbose = verbose self.leaves = set() @@ -497,16 +495,29 @@ def __init__( if cascade: self.cascade(direction=direction) + # --------------------------- Dunder Properties --------------------------- + def __repr__(self): l_str = ",\n\t".join(self.leaves) + "\n" if self.leaves else "" processed = "Cascaded" if self.cascaded else "Uncascaded" return f"{processed} {self.__class__.__name__}(\n\t{l_str})" + def __getitem__(self, index: Union[int, str]): + all_ft_names = [t.full_table_name for t in self.all_ft] + return fuzzy_get(index, all_ft_names, self.all_ft) + + def __len__(self): + return len(self.all_ft) + + # ---------------------------- Public Properties -------------------------- + @property def leaf_ft(self): """Get restricted FreeTables from graph leaves.""" return [self._get_ft(table, with_restr=True) for table in self.leaves] + # ------------------------------- Add Nodes ------------------------------- + def add_leaf( self, table_name=None, restriction=True, cascade=False, direction="up" ) -> None: @@ -593,6 +604,8 @@ def add_leaves( if cascade: self.cascade() + # ------------------------------ Graph Traversal -------------------------- + def cascade(self, show_progress=None, direction="up") -> None: """Cascade all restrictions up the graph. @@ -615,7 +628,7 @@ def cascade(self, show_progress=None, direction="up") -> None: restr = self._get_restr(table) self._log_truncate(f"Start {table}: {restr}") self.cascade1(table, restr, direction=direction) - if not self.visited == self.to_visit: + if self.to_visit - self.visited: raise RuntimeError( "Cascade: FAIL - incomplete cascade. Please post issue." ) @@ -715,13 +728,13 @@ class TableChains: Return list of cascade for each chain in self.chains. """ - def __init__(self, parent, child, direction=Direction.DOWN): + def __init__(self, parent, child, direction=Direction.DOWN, verbose=False): self.parent = parent self.child = child self.connection = parent.connection self.part_names = child.parts() self.chains = [ - TableChain(parent, part, direction=direction) + TableChain(parent, part, direction=direction, verbose=verbose) for part in self.part_names ] self.has_link = any([chain.has_link for chain in self.chains]) @@ -806,7 +819,7 @@ def __init__( banned_tables: List[str] = None, **kwargs, ): - if not allow_merge and child and is_merge_table(child): + if not allow_merge and child is not None and is_merge_table(child): raise TypeError("Child is a merge table. 
Use TableChains instead.") self.parent = self._ensure_name(parent) @@ -817,10 +830,12 @@ def __init__( if not search_restr and not (self.parent and self.child): raise ValueError("Search restriction required to find path.") - super().__init__(seed_table=parent or child, verbose=verbose) + seed_table = parent if isinstance(parent, Table) else child + super().__init__(seed_table=seed_table, verbose=verbose) self.no_visit.update(PERIPHERAL_TABLES) self.no_visit.update(self._ensure_name(banned_tables) or []) + self.no_visit.difference_update([self.parent, self.child]) self.searched_tables = set() self.found_restr = False self.link_type = None @@ -837,14 +852,13 @@ def __init__( if search_restr and not child: self.direction = Direction.DOWN self.leaf = self.parent - if self.leaf: self._set_find_restr(self.leaf, search_restr) self.add_leaf(self.leaf, True, cascade=False, direction=direction) if cascade and search_restr: self.cascade_search() - self.cascade() + self.cascade(restriction=search_restr) self.cascaded = True # --------------------------- Dunder Properties --------------------------- @@ -903,13 +917,10 @@ def all_ft(self) -> List[dj.FreeTable]: @property def path_str(self) -> str: + if not self.path: + return "No link" return self._link_symbol.join([self._camel(t) for t in self.path]) - @property - def endpoint(self) -> str: - """Return endpoint of chain.""" - return self.leaf_ft[0] - # ------------------------------ Graph Nodes ------------------------------ def _set_find_restr(self, table_name, restriction): @@ -962,6 +973,8 @@ def _set_found_vars(self, table): elif self.direction == Direction.DOWN: self.child = table + self._log_truncate(f"FVars: {self._camel(table)}") + self.direction = ~self.direction _ = self.path # Reset path @@ -999,6 +1012,7 @@ def cascade1_search( next_ft = self._get_ft(next_table) if restr_attrs.issubset(set(next_ft.heading.names)): + self._log_truncate(f"Found: {self._camel(next_table)}") self._set_found_vars(next_table) return @@ -1032,7 +1046,13 @@ def find_path(self, directed=True) -> List[str]: List of names in the path. 
""" source, target = self.parent, self.child - search_graph = self.graph if directed else self.undirect_graph + search_graph = self.graph + + if not directed: + self.connection.dependencies.load() + self.undirect_graph = self.connection.dependencies.to_undirected() + search_graph = self.undirect_graph + search_graph.remove_nodes_from(self.no_visit) try: @@ -1043,9 +1063,12 @@ def find_path(self, directed=True) -> List[str]: self.searched_path = True # No path found, don't search again return None + self._log_truncate(f"Path Found : {path}") + ignore_nodes = self.graph.nodes - set(path) self.no_visit.update(ignore_nodes) + self._log_truncate(f"Ignore : {ignore_nodes}") return path @cached_property @@ -1069,11 +1092,13 @@ def cascade(self, restriction: str = None, direction: Direction = None): _ = self.path - direction = direction or self.direction + direction = Direction(direction) or self.direction if direction == Direction.UP: start, end = self.child, self.parent - else: + elif direction == Direction.DOWN: start, end = self.parent, self.child + else: + raise ValueError(f"Invalid direction: {direction}") self.cascade1( table=start, diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py index 646e34171..d17f5f4c6 100644 --- a/src/spyglass/utils/dj_merge_tables.py +++ b/src/spyglass/utils/dj_merge_tables.py @@ -10,7 +10,7 @@ from datajoint.condition import make_condition from datajoint.errors import DataJointError from datajoint.preview import repr_html -from datajoint.utils import from_camel_case, to_camel_case +from datajoint.utils import from_camel_case, get_master, to_camel_case from IPython.core.display import HTML from spyglass.utils.logging import logger @@ -32,16 +32,21 @@ def trim_def(definition): r"\n\s*\n", "\n", re_sub(r"#.*\n", "\n", definition.strip()) ) + if isinstance(table, str): + table = dj.FreeTable(dj.conn(), table) if not isinstance(table, dj.Table): return False + if get_master(table.full_table_name): + return False # Part tables are not merge tables if not table.is_declared: if tbl_def := getattr(table, "definition", None): return trim_def(MERGE_DEFINITION) == trim_def(tbl_def) logger.warning(f"Cannot determine merge table status for {table}") return True - return table.primary_key == [ + ret = table.primary_key == [ RESERVED_PRIMARY_KEY ] and table.heading.secondary_attributes == [RESERVED_SECONDARY_KEY] + return ret class Merge(dj.Manual): diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index 758b39a08..b101a2f5b 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -70,6 +70,8 @@ class SpyglassMixin: _session_pk = None # Session primary key. Mixin is ambivalent to Session pk _member_pk = None # LabMember primary key. Mixin ambivalent table structure + _banned_search_tables = set() # Tables to avoid in restrict_by + def __init__(self, *args, **kwargs): """Initialize SpyglassMixin. 
@@ -790,6 +792,25 @@ def __rshift__(self, restriction) -> QueryExpression:
         """
         return self.restrict_by(restriction, direction="down")
 
+    def _ensure_names(self, tables) -> List[str]:
+        """Ensure table names are strings in a list."""
+        if not isinstance(tables, (list, tuple, set)):
+            tables = [tables]
+        return [getattr(t, "full_table_name", t) for t in tables]
+
+    def ban_search_table(self, table):
+        """Ban table from search in restrict_by."""
+        self._banned_search_tables.update(self._ensure_names(table))
+
+    def unban_search_table(self, table):
+        """Unban table from search in restrict_by."""
+        self._banned_search_tables.difference_update(self._ensure_names(table))
+
+    def see_banned_tables(self):
+        """Print banned tables."""
+        logger.info(f"Banned tables: {self._banned_search_tables}")
+
     def restrict_by(
         self,
         restriction: str = True,
@@ -801,24 +822,15 @@ def restrict_by(
         """Restrict self based on up/downstream table.
 
         If this fails to restrict the table, the shortest path found may not
-        have been correct. Check the output to see where the restriction was
-        applied. Try the following:
+        have been correct. If a different path should be taken, ban the
+        unwanted tables.
 
-        >>> from spyglass.utils.dj_graph import TableChain
-        >>>
-        >>> my_chain = TableChain(
-        >>>     child=MyChildTable(),  # or parent=MyParentTable()
-        >>>     search_restr=restriction,
-        >>>     allow_merge=True,  # If child is Merge
-        >>>     verbose=True,  # Detailed output
-        >>>     banned_tables=[UnwantedTable1, UnwantedTable2],
-        >>> )
+        >>> my_table = MyTable()  # must be instantiated
+        >>> my_table.ban_search_table(UnwantedTable1)
+        >>> my_table.ban_search_table([UnwantedTable2, UnwantedTable3])
+        >>> my_table.unban_search_table(UnwantedTable3)
+        >>> my_table.see_banned_tables()
         >>>
-        >>> my_chain.endpoint  # for the table that meets the restriction
-        >>> my_chain.all_ft  # for all restricted tables in the chain
-
-        When providing a restriction of the parent, use 'up' direction. When
-        providing a restriction of the child, use 'down' direction.
+        >>> my_table << my_restriction
 
         Parameters
         ----------
         restriction : str
@@ -838,7 +850,10 @@ def restrict_by(
             Restricted version of present table or FindKeyGraph object. If
             return_graph, use all_ft attribute to see all tables in cascade.
""" - from spyglass.utils.dj_graph import TableChain # noqa: F401 + from spyglass.utils.dj_graph import ( + TableChain, + TableChains, + ) # noqa: F401 if restriction is True: return self @@ -862,6 +877,7 @@ def restrict_by( child=child, direction=direction, search_restr=restriction, + banned_tables=self._banned_search_tables, allow_merge=True, cascade=True, verbose=verbose, @@ -874,8 +890,8 @@ def restrict_by( ret = graph.leaf_ft[0] if len(ret) == len(self) or len(ret) == 0: logger.warning( - f"Failed to restrict with path: {graph.path_str}\n" - + "See help(YourTable.restrict_by)" + f"Failed to restrict with path: {graph.path_str}\n\t" + + "See `help(YourTable.restrict_by)`" ) return ret diff --git a/tests/common/test_device.py b/tests/common/test_device.py index 49bbd9027..19103cf98 100644 --- a/tests/common/test_device.py +++ b/tests/common/test_device.py @@ -2,7 +2,7 @@ from numpy import array_equal -def test_invalid_device(common, populate_exception): +def test_invalid_device(common, populate_exception, mini_insert): device_dict = common.DataAcquisitionDevice.fetch(as_dict=True)[0] device_dict["other"] = "invalid" with pytest.raises(populate_exception): diff --git a/tests/conftest.py b/tests/conftest.py index a3d4d681a..03cb27c6f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,6 +14,7 @@ from time import sleep as tsleep import datajoint as dj +import numpy as np import pynwb import pytest from datajoint.logging import logger as dj_logger @@ -288,8 +289,13 @@ def load_config(dj_conn, base_dir): @pytest.fixture(autouse=True, scope="session") -def mini_insert(mini_path, teardown, server, load_config): - from spyglass.common import LabMember, Nwbfile, Session # noqa: E402 +def mini_insert(mini_path, mini_content, teardown, server, load_config): + from spyglass.common import ( # noqa: E402 + DataAcquisitionDevice, + LabMember, + Nwbfile, + Session, + ) from spyglass.data_import import insert_sessions # noqa: E402 from spyglass.spikesorting.spikesorting_merge import ( # noqa: E402 SpikeSortingOutput, @@ -313,6 +319,7 @@ def mini_insert(mini_path, teardown, server, load_config): if len(Nwbfile()) != 0: dj_logger.warning("Skipping insert, use existing data.") else: + DataAcquisitionDevice().insert_from_nwbfile(mini_content, {}) insert_sessions(mini_path.name) if len(Session()) == 0: @@ -379,6 +386,20 @@ def lfp_band(lfp): return lfp_band +@pytest.fixture(scope="session") +def sgl(common): + from spyglass import linearization + + yield linearization + + +@pytest.fixture(scope="session") +def sgpl(sgl): + from spyglass.linearization import v1 + + yield v1 + + @pytest.fixture(scope="session") def populate_exception(): from spyglass.common.errors import PopulateException @@ -499,9 +520,258 @@ def pos_merge_key(pos_merge, trodes_pos_v1, trodes_sel_keys): yield pos_merge.merge_get_part(trodes_sel_keys[-1]).fetch1("KEY") +# ---------------------- FIXTURES, LINEARIZATION TABLES ---------------------- +# ---------------------- Note: Used to test RestrGraph ----------------------- + + +@pytest.fixture(scope="session") +def pos_lin_key(trodes_sel_keys): + yield trodes_sel_keys[-1] + + +@pytest.fixture(scope="session") +def position_info(pos_merge, pos_merge_key): + yield (pos_merge & {"merge_id": pos_merge_key}).fetch1_dataframe() + + +@pytest.fixture(scope="session") +def track_graph_key(): + yield {"track_graph_name": "6 arm"} + + +@pytest.fixture(scope="session") +def track_graph(teardown, sgpl, track_graph_key): + node_positions = np.array( + [ + (79.910, 216.720), # top left well 0 + 
+            (132.031, 187.806),  # top middle intersection 1
+            (183.718, 217.713),  # top right well 2
+            (132.544, 132.158),  # middle intersection 3
+            (87.202, 101.397),  # bottom left intersection 4
+            (31.340, 126.110),  # middle left well 5
+            (180.337, 104.799),  # middle right intersection 6
+            (92.693, 42.345),  # bottom left well 7
+            (183.784, 45.375),  # bottom right well 8
+            (231.338, 136.281),  # middle right well 9
+        ]
+    )
+
+    edges = np.array(
+        [
+            (0, 1),
+            (1, 2),
+            (1, 3),
+            (3, 4),
+            (4, 5),
+            (3, 6),
+            (6, 9),
+            (4, 7),
+            (6, 8),
+        ]
+    )
+
+    linear_edge_order = [
+        (3, 6),
+        (6, 8),
+        (6, 9),
+        (3, 1),
+        (1, 2),
+        (1, 0),
+        (3, 4),
+        (4, 5),
+        (4, 7),
+    ]
+    linear_edge_spacing = 15
+
+    sgpl.TrackGraph.insert1(
+        {
+            **track_graph_key,
+            "environment": track_graph_key["track_graph_name"],
+            "node_positions": node_positions,
+            "edges": edges,
+            "linear_edge_order": linear_edge_order,
+            "linear_edge_spacing": linear_edge_spacing,
+        },
+        skip_duplicates=True,
+    )
+
+    yield sgpl.TrackGraph & {"track_graph_name": "6 arm"}
+    if teardown:
+        sgpl.TrackGraph().delete(safemode=False)
+
+
+@pytest.fixture(scope="session")
+def lin_param_key():
+    yield {"linearization_param_name": "default"}
+
+
+@pytest.fixture(scope="session")
+def lin_params(
+    teardown,
+    sgpl,
+    lin_param_key,
+):
+    param_table = sgpl.LinearizationParameters()
+    param_table.insert1(lin_param_key, skip_duplicates=True)
+    yield param_table
+
+
+@pytest.fixture(scope="session")
+def lin_sel_key(
+    pos_merge_key, track_graph_key, lin_param_key, lin_params, track_graph
+):
+    yield {
+        "pos_merge_id": pos_merge_key["merge_id"],
+        **track_graph_key,
+        **lin_param_key,
+    }
+
+
+@pytest.fixture(scope="session")
+def lin_sel(teardown, sgpl, lin_sel_key):
+    sel_table = sgpl.LinearizationSelection()
+    sel_table.insert1(lin_sel_key, skip_duplicates=True)
+    yield sel_table
+    if teardown:
+        sel_table.delete(safemode=False)
+
+
+@pytest.fixture(scope="session")
+def lin_v1(teardown, sgpl, lin_sel):
+    v1 = sgpl.LinearizedPositionV1()
+    v1.populate()
+    yield v1
+    if teardown:
+        v1.delete(safemode=False)
+
+
+@pytest.fixture(scope="session")
+def lin_merge_key(lin_merge, lin_v1, lin_sel_key):
+    yield lin_merge.merge_get_part(lin_sel_key).fetch1("KEY")
+
 
 # --------------------------- FIXTURES, LFP TABLES ---------------------------
+# ---------------- Note: LFPOutput is used to test RestrGraph ----------------
 
 
 @pytest.fixture(scope="module")
 def lfp_band_v1(lfp_band):
     yield lfp_band.LFPBandV1()
+
+
+@pytest.fixture(scope="session")
+def firfilters_table(common):
+    return common.FirFilterParameters()
+
+
+@pytest.fixture(scope="session")
+def electrodegroup_table(lfp):
+    return lfp.v1.LFPElectrodeGroup()
+
+
+@pytest.fixture(scope="session")
+def lfp_constants(common, mini_copy_name, mini_dict):
+    n_delay = 9
+    lfp_electrode_group_name = "test"
+    orig_list_name = "01_s1"
+    orig_valid_times = (
+        common.IntervalList
+        & mini_dict
+        & f"interval_list_name = '{orig_list_name}'"
+    ).fetch1("valid_times")
+    new_list_name = orig_list_name + f"_first{n_delay}"
+    new_list_key = {
+        "nwb_file_name": mini_copy_name,
+        "interval_list_name": new_list_name,
+        "valid_times": np.asarray(
+            [[orig_valid_times[0, 0], orig_valid_times[0, 0] + n_delay]]
+        ),
+    }
+
+    yield dict(
+        lfp_electrode_ids=[0],
+        lfp_electrode_group_name=lfp_electrode_group_name,
+        lfp_eg_key={
+            "nwb_file_name": mini_copy_name,
+            "lfp_electrode_group_name": lfp_electrode_group_name,
+        },
+        n_delay=n_delay,
+        orig_interval_list_name=orig_list_name,
+        orig_valid_times=orig_valid_times,
interval_list_name=new_list_name, + interval_key=new_list_key, + filter1_name="LFP 0-400 Hz", + filter_sampling_rate=30_000, + filter2_name="Theta 5-11 Hz", + lfp_band_electrode_ids=[0], # assumes we've filtered these electrodes + lfp_band_sampling_rate=100, # desired sampling rate + ) + + +@pytest.fixture(scope="session") +def add_electrode_group( + firfilters_table, + electrodegroup_table, + mini_copy_name, + lfp_constants, +): + firfilters_table.create_standard_filters() + group_name = lfp_constants.get("lfp_electrode_group_name") + electrodegroup_table.create_lfp_electrode_group( + nwb_file_name=mini_copy_name, + group_name=group_name, + electrode_list=np.array(lfp_constants.get("lfp_electrode_ids")), + ) + assert len( + electrodegroup_table & {"lfp_electrode_group_name": group_name} + ), "Failed to add LFPElectrodeGroup." + yield + + +@pytest.fixture(scope="session") +def add_interval(common, lfp_constants): + common.IntervalList.insert1( + lfp_constants.get("interval_key"), skip_duplicates=True + ) + yield lfp_constants.get("interval_list_name") + + +@pytest.fixture(scope="session") +def add_selection( + lfp, common, add_electrode_group, add_interval, lfp_constants +): + lfp_s_key = { + **lfp_constants.get("lfp_eg_key"), + "target_interval_list_name": add_interval, + "filter_name": lfp_constants.get("filter1_name"), + "filter_sampling_rate": lfp_constants.get("filter_sampling_rate"), + } + lfp.v1.LFPSelection.insert1(lfp_s_key, skip_duplicates=True) + yield lfp_s_key + + +@pytest.fixture(scope="session") +def lfp_s_key(lfp_constants, mini_copy_name): + yield { + "nwb_file_name": mini_copy_name, + "lfp_electrode_group_name": lfp_constants.get( + "lfp_electrode_group_name" + ), + "target_interval_list_name": lfp_constants.get("interval_list_name"), + } + + +@pytest.fixture(scope="session") +def populate_lfp(lfp, add_selection, lfp_s_key): + lfp.v1.LFPV1().populate(add_selection) + yield {"merge_id": (lfp.LFPOutput.LFPV1() & lfp_s_key).fetch1("merge_id")} + + +@pytest.fixture(scope="session") +def lfp_merge_key(populate_lfp): + yield populate_lfp + + +@pytest.fixture(scope="session") +def lfp_v1_key(lfp, lfp_s_key): + yield (lfp.v1.LFPV1 & lfp_s_key).fetch1("KEY") diff --git a/tests/lfp/conftest.py b/tests/lfp/conftest.py index e318610ec..e62a03dea 100644 --- a/tests/lfp/conftest.py +++ b/tests/lfp/conftest.py @@ -1,126 +1,7 @@ -import numpy as np import pytest from pynwb import NWBHDF5IO -@pytest.fixture(scope="session") -def firfilters_table(common): - return common.FirFilterParameters() - - -@pytest.fixture(scope="session") -def electrodegroup_table(lfp): - return lfp.v1.LFPElectrodeGroup() - - -@pytest.fixture(scope="session") -def lfp_constants(common, mini_copy_name, mini_dict): - n_delay = 9 - lfp_electrode_group_name = "test" - orig_list_name = "01_s1" - orig_valid_times = ( - common.IntervalList - & mini_dict - & f"interval_list_name = '{orig_list_name}'" - ).fetch1("valid_times") - new_list_name = orig_list_name + f"_first{n_delay}" - new_list_key = { - "nwb_file_name": mini_copy_name, - "interval_list_name": new_list_name, - "valid_times": np.asarray( - [[orig_valid_times[0, 0], orig_valid_times[0, 0] + n_delay]] - ), - } - - yield dict( - lfp_electrode_ids=[0], - lfp_electrode_group_name=lfp_electrode_group_name, - lfp_eg_key={ - "nwb_file_name": mini_copy_name, - "lfp_electrode_group_name": lfp_electrode_group_name, - }, - n_delay=n_delay, - orig_interval_list_name=orig_list_name, - orig_valid_times=orig_valid_times, - interval_list_name=new_list_name, - 
interval_key=new_list_key, - filter1_name="LFP 0-400 Hz", - filter_sampling_rate=30_000, - filter2_name="Theta 5-11 Hz", - lfp_band_electrode_ids=[0], # assumes we've filtered these electrodes - lfp_band_sampling_rate=100, # desired sampling rate - ) - - -@pytest.fixture(scope="session") -def add_electrode_group( - firfilters_table, - electrodegroup_table, - mini_copy_name, - lfp_constants, -): - firfilters_table.create_standard_filters() - group_name = lfp_constants.get("lfp_electrode_group_name") - electrodegroup_table.create_lfp_electrode_group( - nwb_file_name=mini_copy_name, - group_name=group_name, - electrode_list=np.array(lfp_constants.get("lfp_electrode_ids")), - ) - assert len( - electrodegroup_table & {"lfp_electrode_group_name": group_name} - ), "Failed to add LFPElectrodeGroup." - yield - - -@pytest.fixture(scope="session") -def add_interval(common, lfp_constants): - common.IntervalList.insert1( - lfp_constants.get("interval_key"), skip_duplicates=True - ) - yield lfp_constants.get("interval_list_name") - - -@pytest.fixture(scope="session") -def add_selection( - lfp, common, add_electrode_group, add_interval, lfp_constants -): - lfp_s_key = { - **lfp_constants.get("lfp_eg_key"), - "target_interval_list_name": add_interval, - "filter_name": lfp_constants.get("filter1_name"), - "filter_sampling_rate": lfp_constants.get("filter_sampling_rate"), - } - lfp.v1.LFPSelection.insert1(lfp_s_key, skip_duplicates=True) - yield lfp_s_key - - -@pytest.fixture(scope="session") -def lfp_s_key(lfp_constants, mini_copy_name): - yield { - "nwb_file_name": mini_copy_name, - "lfp_electrode_group_name": lfp_constants.get( - "lfp_electrode_group_name" - ), - "target_interval_list_name": lfp_constants.get("interval_list_name"), - } - - -@pytest.fixture(scope="session") -def populate_lfp(lfp, add_selection, lfp_s_key): - lfp.v1.LFPV1().populate(add_selection) - yield {"merge_id": (lfp.LFPOutput.LFPV1() & lfp_s_key).fetch1("merge_id")} - - -@pytest.fixture(scope="session") -def lfp_merge_key(populate_lfp): - yield populate_lfp - - -@pytest.fixture(scope="session") -def lfp_v1_key(lfp, lfp_s_key): - yield (lfp.v1.LFPV1 & lfp_s_key).fetch1("KEY") - - @pytest.fixture(scope="module") def lfp_analysis_raw(common, lfp, populate_lfp, mini_dict): abs_path = (common.AnalysisNwbfile * lfp.v1.LFPV1 & mini_dict).fetch( diff --git a/tests/linearization/conftest.py b/tests/linearization/conftest.py deleted file mode 100644 index 505dcc816..000000000 --- a/tests/linearization/conftest.py +++ /dev/null @@ -1,142 +0,0 @@ -import numpy as np -import pytest - - -@pytest.fixture(scope="session") -def sgl(common): - from spyglass import linearization - - yield linearization - - -@pytest.fixture(scope="session") -def sgpl(sgl): - from spyglass.linearization import v1 - - yield v1 - - -@pytest.fixture(scope="session") -def pos_lin_key(trodes_sel_keys): - yield trodes_sel_keys[-1] - - -@pytest.fixture(scope="session") -def position_info(pos_merge, pos_merge_key): - yield (pos_merge & {"merge_id": pos_merge_key}).fetch1_dataframe() - - -@pytest.fixture(scope="session") -def track_graph_key(): - yield {"track_graph_name": "6 arm"} - - -@pytest.fixture(scope="session") -def track_graph(teardown, sgpl, track_graph_key): - node_positions = np.array( - [ - (79.910, 216.720), # top left well 0 - (132.031, 187.806), # top middle intersection 1 - (183.718, 217.713), # top right well 2 - (132.544, 132.158), # middle intersection 3 - (87.202, 101.397), # bottom left intersection 4 - (31.340, 126.110), # middle left well 5 - (180.337, 
104.799), # middle right intersection 6 - (92.693, 42.345), # bottom left well 7 - (183.784, 45.375), # bottom right well 8 - (231.338, 136.281), # middle right well 9 - ] - ) - - edges = np.array( - [ - (0, 1), - (1, 2), - (1, 3), - (3, 4), - (4, 5), - (3, 6), - (6, 9), - (4, 7), - (6, 8), - ] - ) - - linear_edge_order = [ - (3, 6), - (6, 8), - (6, 9), - (3, 1), - (1, 2), - (1, 0), - (3, 4), - (4, 5), - (4, 7), - ] - linear_edge_spacing = 15 - - sgpl.TrackGraph.insert1( - { - **track_graph_key, - "environment": track_graph_key["track_graph_name"], - "node_positions": node_positions, - "edges": edges, - "linear_edge_order": linear_edge_order, - "linear_edge_spacing": linear_edge_spacing, - }, - skip_duplicates=True, - ) - - yield sgpl.TrackGraph & {"track_graph_name": "6 arm"} - if teardown: - sgpl.TrackGraph().delete(safemode=False) - - -@pytest.fixture(scope="session") -def lin_param_key(): - yield {"linearization_param_name": "default"} - - -@pytest.fixture(scope="session") -def lin_params( - teardown, - sgpl, - lin_param_key, -): - param_table = sgpl.LinearizationParameters() - param_table.insert1(lin_param_key, skip_duplicates=True) - yield param_table - - -@pytest.fixture(scope="session") -def lin_sel_key( - pos_merge_key, track_graph_key, lin_param_key, lin_params, track_graph -): - yield { - "pos_merge_id": pos_merge_key["merge_id"], - **track_graph_key, - **lin_param_key, - } - - -@pytest.fixture(scope="session") -def lin_sel(teardown, sgpl, lin_sel_key): - sel_table = sgpl.LinearizationSelection() - sel_table.insert1(lin_sel_key, skip_duplicates=True) - yield sel_table - if teardown: - sel_table.delete(safemode=False) - - -@pytest.fixture(scope="session") -def lin_v1(teardown, sgpl, lin_sel): - v1 = sgpl.LinearizedPositionV1() - v1.populate() - yield v1 - if teardown: - v1.delete(safemode=False) - - -@pytest.fixture(scope="session") -def lin_merge_key(lin_merge, lin_sel_key): - yield lin_merge.merge_get_part(lin_sel_key).fetch1("KEY") diff --git a/tests/linearization/test_lin.py b/tests/linearization/test_lin.py index 4225ad5bf..a5db28d9a 100644 --- a/tests/linearization/test_lin.py +++ b/tests/linearization/test_lin.py @@ -9,4 +9,4 @@ def test_fetch1_dataframe(lin_v1, lin_merge, lin_merge_key): assert hash_df == hash_exp, "Dataframe differs from expected" -## Todo: Add more tests of this pipeline, not just the fetch1_dataframe method +# TODO: Add more tests of this pipeline, not just the fetch1_dataframe method diff --git a/tests/utils/test_chains.py b/tests/utils/test_chains.py index 3cc085471..66d9772c3 100644 --- a/tests/utils/test_chains.py +++ b/tests/utils/test_chains.py @@ -9,10 +9,15 @@ def TableChain(): return TableChain +def full_to_camel(t): + return to_camel_case(t.split(".")[-1].strip("`")) + + def test_chains_repr(chains): """Test that the repr of a TableChains object is as expected.""" repr_got = repr(chains) - repr_exp = "\n".join([str(c) for c in chains.chains]) + chain_st = ",\n\t".join([str(c) for c in chains.chains]) + "\n" + repr_exp = f"TableChains(\n\t{chain_st})" assert repr_got == repr_exp, "Unexpected repr of TableChains object." 
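For reference, the `full_to_camel` helper added above drops the schema prefix and backticks from a full table name before CamelCasing it via datajoint's `to_camel_case`; a sketch of the expected behavior (example table name only assumed):

    full_to_camel("`common_nwbfile`.`analysis_nwbfile`")
    # split(".")[-1] -> "`analysis_nwbfile`"; strip("`") -> "analysis_nwbfile";
    # to_camel_case -> "AnalysisNwbfile"
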
@@ -32,11 +37,13 @@ def test_invalid_chain(Nwbfile, pos_merge_tables, TableChain): def test_chain_str(chain): """Test that the str of a TableChain object is as expected.""" chain = chain - parent = to_camel_case(chain.parent.table_name) - child = to_camel_case(chain.child.table_name) str_got = str(chain) - str_exp = parent + chain._link_symbol + child + str_exp = ( + full_to_camel(chain.parent) + + chain._link_symbol + + full_to_camel(chain.child) + ) assert str_got == str_exp, "Unexpected str of TableChain object." @@ -44,7 +51,9 @@ def test_chain_str(chain): def test_chain_repr(chain): """Test that the repr of a TableChain object is as expected.""" repr_got = repr(chain) - repr_ext = "Chain: " + chain._link_symbol.join(chain.path) + repr_ext = "Chain: " + chain._link_symbol.join( + [full_to_camel(t) for t in chain.path] + ) assert repr_got == repr_ext, "Unexpected repr of TableChain object." diff --git a/tests/utils/test_graph.py b/tests/utils/test_graph.py index 33a8c8f1a..387346163 100644 --- a/tests/utils/test_graph.py +++ b/tests/utils/test_graph.py @@ -1,7 +1,5 @@ import pytest -from . import schema_graph as sg - @pytest.fixture(scope="session") def leaf(lin_merge): @@ -9,15 +7,17 @@ def leaf(lin_merge): @pytest.fixture(scope="session") -def restr_graph(leaf): +def restr_graph(leaf, verbose, lin_merge_key): from spyglass.utils.dj_graph import RestrGraph + _ = lin_merge_key # linearization merge table populated + yield RestrGraph( seed_table=leaf, table_name=leaf.full_table_name, restriction=True, cascade=True, - verbose=True, + verbose=verbose, ) @@ -31,8 +31,9 @@ def test_rg_repr(restr_graph, leaf): def test_rg_ft(restr_graph): """Test FreeTable attribute of RestrGraph.""" - assert len(restr_graph.leaf_ft) == 1, "Unexpected number of leaf tables." - assert len(restr_graph.all_ft) == 9, "Unexpected number of cascaded tables." + assert len(restr_graph.leaf_ft) == 1, "Unexpected # of leaf tables." + assert len(restr_graph.all_ft) == 15, "Unexpected # of cascaded tables." + assert len(restr_graph["spatial"]) == 2, "Unexpected cascaded table length." def test_rg_restr_ft(restr_graph): @@ -44,7 +45,7 @@ def test_rg_restr_ft(restr_graph): def test_rg_file_paths(restr_graph): """Test collection of upstream file paths.""" paths = [p.get("file_path") for p in restr_graph.file_paths] - assert len(paths) == 1, "Unexpected number of file paths." + assert len(paths) == 2, "Unexpected number of file paths." @pytest.fixture(scope="session") @@ -87,7 +88,7 @@ def restr_graph_root(restr_graph, common, lfp_band): def test_rg_root(restr_graph_root): assert ( - len(restr_graph_root.all_ft) == 29 + len(restr_graph_root.all_ft) == 25 ), "Unexpected number of cascaded tables." diff --git a/tests/utils/test_mixin.py b/tests/utils/test_mixin.py index 8b1d41f4a..b7c12c892 100644 --- a/tests/utils/test_mixin.py +++ b/tests/utils/test_mixin.py @@ -43,15 +43,17 @@ def test_merge_detect(Nwbfile, pos_merge_tables): ), "Merges not detected by mixin." -def test_merge_chain_join(Nwbfile, pos_merge_tables): +def test_merge_chain_join(Nwbfile, pos_merge_tables, lin_v1, lfp_merge_key): """Test that the mixin can join merge chains.""" + _ = lin_v1, lfp_merge_key # merge tables populated + all_chains = [ chains.cascade(True, direction="down") for chains in Nwbfile._merge_chains.values() ] end_len = [len(chain[0]) for chain in all_chains if chain] - assert end_len == [1, 1, 2], "Merge chains not joined correctly." + assert sum(end_len) == 4, "Merge chains not joined correctly." 
def test_get_chain(Nwbfile, pos_merge_tables): @@ -70,17 +72,17 @@ def test_ddm_warning(Nwbfile, caplog): assert "No merge deletes found" in caplog.text, "No warning issued." -def test_ddm_dry_run(Nwbfile, common, sgp, pos_merge_tables): +def test_ddm_dry_run(Nwbfile, common, sgp, pos_merge_tables, lin_v1): """Test that the mixin can dry run delete_downstream_merge.""" + _ = lin_v1 # merge tables populated + pos_output_name = pos_merge_tables[0].full_table_name + param_field = "trodes_pos_params_name" trodes_params = sgp.v1.TrodesPosParams() - rft = next( - iter( - (trodes_params & f'{param_field} LIKE "%ups%"').ddm( - reload_cache=True, dry_run=True, return_parts=True - ) - ) - )[0] + + rft = (trodes_params & f'{param_field} LIKE "%ups%"').ddm( + reload_cache=True, dry_run=True, return_parts=False + )[pos_output_name][0] assert len(rft) == 1, "ddm did not return restricted table." table_name = [p for p in pos_merge_tables[0].parts() if "trode" in p][0] From 4e26df945287d8e86b56c4a5ba3455b9916bf487 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Fri, 3 May 2024 13:57:56 -0500 Subject: [PATCH 13/17] Revert pytest options --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f9ed16ab6..ffb8d0df6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -121,8 +121,8 @@ ignore-words-list = 'nevers' minversion = "7.0" addopts = [ "-sv", - "--sw", # stepwise: resume with next test after failure - "--pdb", # drop into debugger on failure + # "--sw", # stepwise: resume with next test after failure + # "--pdb", # drop into debugger on failure "-p no:warnings", # "--no-teardown", # don't teardown the database after tests # "--quiet-spy", # don't show logging from spyglass From ad7f6c2da7ae425db7d185d241b2b3bedc2724d1 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Mon, 6 May 2024 13:46:46 -0500 Subject: [PATCH 14/17] Fix failing tests --- src/spyglass/utils/dj_merge_tables.py | 7 +- tests/conftest.py | 14 +- tests/container.py | 2 +- tests/utils/conftest.py | 204 ++++++++++++++++++++++++-- tests/utils/schema_graph.py | 179 ---------------------- tests/utils/test_graph.py | 10 +- 6 files changed, 208 insertions(+), 208 deletions(-) delete mode 100644 tests/utils/schema_graph.py diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py index d17f5f4c6..9f172fbd9 100644 --- a/src/spyglass/utils/dj_merge_tables.py +++ b/src/spyglass/utils/dj_merge_tables.py @@ -41,12 +41,13 @@ def trim_def(definition): if not table.is_declared: if tbl_def := getattr(table, "definition", None): return trim_def(MERGE_DEFINITION) == trim_def(tbl_def) - logger.warning(f"Cannot determine merge table status for {table}") + logger.warning( + f"Cannot determine merge table status for {table.table_name}" + ) return True - ret = table.primary_key == [ + return table.primary_key == [ RESERVED_PRIMARY_KEY ] and table.heading.secondary_attributes == [RESERVED_SECONDARY_KEY] - return ret class Merge(dj.Manual): diff --git a/tests/conftest.py b/tests/conftest.py index 03cb27c6f..7950854d6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -284,18 +284,15 @@ def load_config(dj_conn, base_dir): from spyglass.settings import SpyglassConfig yield SpyglassConfig().load_config( - base_dir=base_dir, test_mode=True, force_reload=True + base_dir=base_dir, debug_mode=False, test_mode=True, force_reload=True ) @pytest.fixture(autouse=True, scope="session") -def mini_insert(mini_path, mini_content, teardown, server, load_config): - from 
spyglass.common import ( # noqa: E402 - DataAcquisitionDevice, - LabMember, - Nwbfile, - Session, - ) +def mini_insert( + dj_conn, mini_path, mini_content, teardown, server, load_config +): + from spyglass.common import LabMember, Nwbfile, Session # noqa: E402 from spyglass.data_import import insert_sessions # noqa: E402 from spyglass.spikesorting.spikesorting_merge import ( # noqa: E402 SpikeSortingOutput, @@ -319,7 +316,6 @@ def mini_insert(mini_path, mini_content, teardown, server, load_config): if len(Nwbfile()) != 0: dj_logger.warning("Skipping insert, use existing data.") else: - DataAcquisitionDevice().insert_from_nwbfile(mini_content, {}) insert_sessions(mini_path.name) if len(Session()) == 0: diff --git a/tests/container.py b/tests/container.py index 04e176fee..fa26f1c46 100644 --- a/tests/container.py +++ b/tests/container.py @@ -193,7 +193,7 @@ def creds(self): "database.user": self.user, "database.port": int(self.port), "safemode": "false", - "custom": {"test_mode": True}, + "custom": {"test_mode": True, "debug_mode": False}, } @property diff --git a/tests/utils/conftest.py b/tests/utils/conftest.py index 4c6209de2..52557862d 100644 --- a/tests/utils/conftest.py +++ b/tests/utils/conftest.py @@ -1,8 +1,6 @@ import datajoint as dj import pytest -from . import schema_graph - @pytest.fixture(scope="module") def merge_table(pos_merge_tables): @@ -61,25 +59,209 @@ def no_link_chain(Nwbfile): @pytest.fixture(scope="module") -def graph_tables(dj_conn): - lg = schema_graph.LOCALS_GRAPH +def _Merge(): + """Return the _Merge class.""" + from spyglass.utils import _Merge + + yield _Merge + + +@pytest.fixture(scope="module") +def SpyglassMixin(): + """Return a mixin class.""" + from spyglass.utils import SpyglassMixin + + yield SpyglassMixin + + +@pytest.fixture(scope="module") +def graph_schema(SpyglassMixin, _Merge): + """ + NOTE: Must declare tables within fixture to avoid loading config defaults. 
+ """ + parent_id = range(10) + parent_attr = [i + 10 for i in range(2, 12)] + other_id = range(9) + other_attr = [i + 10 for i in range(3, 12)] + intermediate_id = range(2, 10) + intermediate_attr = [i + 10 for i in range(4, 12)] + pk_id = range(3, 10) + pk_attr = [i + 10 for i in range(5, 12)] + sk_id = range(6) + sk_attr = [i + 10 for i in range(6, 12)] + pk_sk_id = range(5) + pk_sk_attr = [i + 10 for i in range(7, 12)] + pk_alias_id = range(4) + pk_alias_attr = [i + 10 for i in range(8, 12)] + sk_alias_id = range(3) + sk_alias_attr = [i + 10 for i in range(9, 12)] + + def offset(gen, offset): + return list(gen)[offset:] + + class ParentNode(SpyglassMixin, dj.Lookup): + definition = """ + parent_id: int + --- + parent_attr : int + """ + contents = [(i, j) for i, j in zip(parent_id, parent_attr)] + + class OtherParentNode(SpyglassMixin, dj.Lookup): + definition = """ + other_id: int + --- + other_attr : int + """ + contents = [(i, j) for i, j in zip(other_id, other_attr)] + + class IntermediateNode(SpyglassMixin, dj.Lookup): + definition = """ + intermediate_id: int + --- + -> ParentNode + intermediate_attr : int + """ + contents = [ + (i, j, k) + for i, j, k in zip( + intermediate_id, offset(parent_id, 1), intermediate_attr + ) + ] + + class PkNode(SpyglassMixin, dj.Lookup): + definition = """ + pk_id: int + -> IntermediateNode + --- + pk_attr : int + """ + contents = [ + (i, j, k) + for i, j, k in zip(pk_id, offset(intermediate_id, 2), pk_attr) + ] + + class SkNode(SpyglassMixin, dj.Lookup): + definition = """ + sk_id: int + --- + -> IntermediateNode + sk_attr : int + """ + contents = [ + (i, j, k) + for i, j, k in zip(sk_id, offset(intermediate_id, 3), sk_attr) + ] + + class PkSkNode(SpyglassMixin, dj.Lookup): + definition = """ + pk_sk_id: int + -> IntermediateNode + --- + -> OtherParentNode + pk_sk_attr : int + """ + contents = [ + (i, j, k, m) + for i, j, k, m in zip( + pk_sk_id, offset(intermediate_id, 4), other_id, pk_sk_attr + ) + ] + + class PkAliasNode(SpyglassMixin, dj.Lookup): + definition = """ + pk_alias_id: int + -> PkNode.proj(fk_pk_id='pk_id') + --- + pk_alias_attr : int + """ + contents = [ + (i, j, k, m) + for i, j, k, m in zip( + pk_alias_id, + offset(pk_id, 1), + offset(intermediate_id, 3), + pk_alias_attr, + ) + ] + + class SkAliasNode(SpyglassMixin, dj.Lookup): + definition = """ + sk_alias_id: int + --- + -> SkNode.proj(fk_sk_id='sk_id') + -> PkSkNode + sk_alias_attr : int + """ + contents = [ + (i, j, k, m, n) + for i, j, k, m, n in zip( + sk_alias_id, + offset(sk_id, 2), + offset(pk_sk_id, 1), + offset(intermediate_id, 5), + sk_alias_attr, + ) + ] + + class MergeOutput(_Merge, SpyglassMixin): + definition = """ + merge_id: uuid + --- + source: varchar(32) + """ + + class PkNode(dj.Part): + definition = """ + -> MergeOutput + --- + -> PkNode + """ + + class MergeChild(SpyglassMixin, dj.Manual): + definition = """ + -> MergeOutput + merge_child_id: int + --- + merge_child_attr: int + """ + + yield { + "ParentNode": ParentNode, + "OtherParentNode": OtherParentNode, + "IntermediateNode": IntermediateNode, + "PkNode": PkNode, + "SkNode": SkNode, + "PkSkNode": PkSkNode, + "PkAliasNode": PkAliasNode, + "SkAliasNode": SkAliasNode, + "MergeOutput": MergeOutput, + "MergeChild": MergeChild, + } + + +@pytest.fixture(scope="module") +def graph_tables(dj_conn, graph_schema): - schema = dj.Schema(context=lg) + schema = dj.Schema(context=graph_schema) - for table in lg.values(): + for table in graph_schema.values(): schema(table) schema.activate("test_graph", 
connection=dj_conn) - merge_keys = lg["PkNode"].fetch("KEY", offset=1, as_dict=True) - lg["MergeOutput"].insert(merge_keys, skip_duplicates=True) - merge_child_keys = lg["MergeOutput"].merge_fetch(True, "merge_id", offset=1) + # Merge inserts after declaring tables + merge_keys = graph_schema["PkNode"].fetch("KEY", offset=1, as_dict=True) + graph_schema["MergeOutput"].insert(merge_keys, skip_duplicates=True) + merge_child_keys = graph_schema["MergeOutput"].merge_fetch( + True, "merge_id", offset=1 + ) merge_child_inserts = [ (i, j, k + 10) for i, j, k in zip(merge_child_keys, range(4), range(10, 15)) ] - lg["MergeChild"].insert(merge_child_inserts, skip_duplicates=True) + graph_schema["MergeChild"].insert(merge_child_inserts, skip_duplicates=True) - yield schema_graph.LOCALS_GRAPH + yield graph_schema schema.drop(force=True) diff --git a/tests/utils/schema_graph.py b/tests/utils/schema_graph.py deleted file mode 100644 index 659518ceb..000000000 --- a/tests/utils/schema_graph.py +++ /dev/null @@ -1,179 +0,0 @@ -from inspect import isclass as inspect_isclass - -import datajoint as dj - -from spyglass.utils import SpyglassMixin, _Merge - -# Ranges are offset from one another to create unique list of entries for each -# table while respecting the foreign key constraints. - -parent_id = range(10) -parent_attr = [i + 10 for i in range(2, 12)] - -other_id = range(9) -other_attr = [i + 10 for i in range(3, 12)] - -intermediate_id = range(2, 10) -intermediate_attr = [i + 10 for i in range(4, 12)] - -pk_id = range(3, 10) -pk_attr = [i + 10 for i in range(5, 12)] - -sk_id = range(6) -sk_attr = [i + 10 for i in range(6, 12)] - -pk_sk_id = range(5) -pk_sk_attr = [i + 10 for i in range(7, 12)] - -pk_alias_id = range(4) -pk_alias_attr = [i + 10 for i in range(8, 12)] - -sk_alias_id = range(3) -sk_alias_attr = [i + 10 for i in range(9, 12)] - - -def offset(gen, offset): - return list(gen)[offset:] - - -class ParentNode(SpyglassMixin, dj.Lookup): - definition = """ - parent_id: int - --- - parent_attr : int - """ - contents = [(i, j) for i, j in zip(parent_id, parent_attr)] - - -class OtherParentNode(SpyglassMixin, dj.Lookup): - definition = """ - other_id: int - --- - other_attr : int - """ - contents = [(i, j) for i, j in zip(other_id, other_attr)] - - -class IntermediateNode(SpyglassMixin, dj.Lookup): - definition = """ - intermediate_id: int - --- - -> ParentNode - intermediate_attr : int - """ - contents = [ - (i, j, k) - for i, j, k in zip( - intermediate_id, offset(parent_id, 1), intermediate_attr - ) - ] - - -class PkNode(SpyglassMixin, dj.Lookup): - definition = """ - pk_id: int - -> IntermediateNode - --- - pk_attr : int - """ - contents = [ - (i, j, k) for i, j, k in zip(pk_id, offset(intermediate_id, 2), pk_attr) - ] - - -class SkNode(SpyglassMixin, dj.Lookup): - definition = """ - sk_id: int - --- - -> IntermediateNode - sk_attr : int - """ - contents = [ - (i, j, k) for i, j, k in zip(sk_id, offset(intermediate_id, 3), sk_attr) - ] - - -class PkSkNode(SpyglassMixin, dj.Lookup): - definition = """ - pk_sk_id: int - -> IntermediateNode - --- - -> OtherParentNode - pk_sk_attr : int - """ - contents = [ - (i, j, k, m) - for i, j, k, m in zip( - pk_sk_id, offset(intermediate_id, 4), other_id, pk_sk_attr - ) - ] - - -class PkAliasNode(SpyglassMixin, dj.Lookup): - definition = """ - pk_alias_id: int - -> PkNode.proj(fk_pk_id='pk_id') - --- - pk_alias_attr : int - """ - contents = [ - (i, j, k, m) - for i, j, k, m in zip( - pk_alias_id, - offset(pk_id, 1), - offset(intermediate_id, 3), - 
pk_alias_attr, - ) - ] - - -class SkAliasNode(SpyglassMixin, dj.Lookup): - definition = """ - sk_alias_id: int - --- - -> SkNode.proj(fk_sk_id='sk_id') - -> PkSkNode - sk_alias_attr : int - """ - contents = [ - (i, j, k, m, n) - for i, j, k, m, n in zip( - sk_alias_id, - offset(sk_id, 2), - offset(pk_sk_id, 1), - offset(intermediate_id, 5), - sk_alias_attr, - ) - ] - - -class MergeOutput(_Merge, SpyglassMixin): - definition = """ - merge_id: uuid - --- - source: varchar(32) - """ - - class PkNode(dj.Part): - definition = """ - -> MergeOutput - --- - -> PkNode - """ - - -class MergeChild(SpyglassMixin, dj.Manual): - definition = """ - -> MergeOutput - merge_child_id: int - --- - merge_child_attr: int - """ - - -LOCALS_GRAPH = { - k: v - for k, v in locals().items() - if inspect_isclass(v) and k not in ["SpyglassMixin", "_Merge"] -} -__all__ = list(LOCALS_GRAPH) diff --git a/tests/utils/test_graph.py b/tests/utils/test_graph.py index 387346163..d57d312b8 100644 --- a/tests/utils/test_graph.py +++ b/tests/utils/test_graph.py @@ -38,8 +38,8 @@ def test_rg_ft(restr_graph): def test_rg_restr_ft(restr_graph): """Test get restricted free tables.""" - ft = restr_graph._get_ft(list(restr_graph.visited)[1], with_restr=True) - assert len(ft) == 1, "Unexpected restricted table length." + ft = restr_graph["spatial_series"] + assert len(ft) == 2, "Unexpected restricted table length." def test_rg_file_paths(restr_graph): @@ -73,7 +73,7 @@ def test_add_leaf_restr_ft(restr_graph_new_leaf): @pytest.fixture(scope="session") -def restr_graph_root(restr_graph, common, lfp_band): +def restr_graph_root(restr_graph, common, lfp_band, lin_v1): from spyglass.utils.dj_graph import RestrGraph yield RestrGraph( @@ -88,8 +88,8 @@ def restr_graph_root(restr_graph, common, lfp_band): def test_rg_root(restr_graph_root): assert ( - len(restr_graph_root.all_ft) == 25 - ), "Unexpected number of cascaded tables." + len(restr_graph_root["trodes_pos_v1"]) == 2 + ), "Incomplete cascade from root." @pytest.mark.parametrize( From 9763b5a6e3acd66b3df7a3cea771929a82920df6 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Mon, 6 May 2024 15:08:20 -0500 Subject: [PATCH 15/17] Bail on cascade if restr empty --- src/spyglass/utils/dj_graph.py | 7 +++---- tests/utils/test_graph.py | 1 - 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index d3f8fab04..ccffd9bd4 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -409,6 +409,9 @@ def cascade1( **data, ) + if next_restr == ["False"]: # Stop cascade if empty restriction + continue + self.cascade1( table=next_table, restriction=next_restr, @@ -628,10 +631,6 @@ def cascade(self, show_progress=None, direction="up") -> None: restr = self._get_restr(table) self._log_truncate(f"Start {table}: {restr}") self.cascade1(table, restr, direction=direction) - if self.to_visit - self.visited: - raise RuntimeError( - "Cascade: FAIL - incomplete cascade. Please post issue." - ) self.cascade_files() self.cascaded = True diff --git a/tests/utils/test_graph.py b/tests/utils/test_graph.py index d57d312b8..29d1f21ad 100644 --- a/tests/utils/test_graph.py +++ b/tests/utils/test_graph.py @@ -32,7 +32,6 @@ def test_rg_repr(restr_graph, leaf): def test_rg_ft(restr_graph): """Test FreeTable attribute of RestrGraph.""" assert len(restr_graph.leaf_ft) == 1, "Unexpected # of leaf tables." - assert len(restr_graph.all_ft) == 15, "Unexpected # of cascaded tables." 
assert len(restr_graph["spatial"]) == 2, "Unexpected cascaded table length." From 1032d939dfbf41405b86b6035844428a50fb588f Mon Sep 17 00:00:00 2001 From: Chris Brozdowski Date: Thu, 9 May 2024 14:35:15 -0500 Subject: [PATCH 16/17] Update src/spyglass/utils/dj_mixin.py Co-authored-by: Samuel Bray --- src/spyglass/utils/dj_mixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index 41c4b5570..1a9a42b3a 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -877,7 +877,7 @@ def restrict_by( child=child, direction=direction, search_restr=restriction, - banned_tables=self._banned_search_tables, + banned_tables=list(self._banned_search_tables), allow_merge=True, cascade=True, verbose=verbose, From f8094572d0fb84396c5ab9cc482033d04a7ffde9 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Thu, 9 May 2024 15:52:16 -0500 Subject: [PATCH 17/17] Permit dict/list-of-dict restr on long-distance restrict --- src/spyglass/utils/dj_graph.py | 20 ++++++++++++++----- src/spyglass/utils/dj_merge_tables.py | 2 +- src/spyglass/utils/dj_mixin.py | 11 +++++------ tests/utils/conftest.py | 28 +++++++++++++++++++++++++++ tests/utils/test_graph.py | 22 +++++++++++++++++++++ 5 files changed, 71 insertions(+), 12 deletions(-) diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index ccffd9bd4..5bf3d25d0 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -925,14 +925,24 @@ def path_str(self) -> str: def _set_find_restr(self, table_name, restriction): """Set restr to look for from leaf node.""" if isinstance(restriction, dict): - logger.warning("Using `>>` or `<<`: DICT unreliable, use STR.") + restriction = [restriction] - restr_attrs = set() # modified by make_condition - table_ft = self._get_ft(table_name) - restr_string = make_condition(table_ft, restriction, restr_attrs) + if isinstance(restriction, list) and all( + [isinstance(r, dict) for r in restriction] + ): + restr_attrs = set(key for restr in restriction for key in restr) + find_restr = restriction + elif isinstance(restriction, str): + restr_attrs = set() # modified by make_condition + table_ft = self._get_ft(table_name) + find_restr = make_condition(table_ft, restriction, restr_attrs) + else: + raise ValueError( + f"Invalid restriction type, use STR: {restriction}" + ) - self._set_node(table_name, "find_restr", restr_string) self._set_node(table_name, "restr_attrs", restr_attrs) + self._set_node(table_name, "find_restr", find_restr) def _get_find_restr(self, table) -> Tuple[str, Set[str]]: """Get restr and restr_attrs from leaf node.""" diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py index 9f172fbd9..0b8f16de6 100644 --- a/src/spyglass/utils/dj_merge_tables.py +++ b/src/spyglass/utils/dj_merge_tables.py @@ -814,7 +814,7 @@ def super_delete(self, warn=True, *args, **kwargs): """ if warn: logger.warning("!! Bypassing cautious_delete !!") - self._log_use(start=time(), super_delete=True) + self._log_delete(start=time(), super_delete=True) super().delete(*args, **kwargs) diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index 1a9a42b3a..08fa377b3 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -850,17 +850,16 @@ def restrict_by( Restricted version of present table or FindKeyGraph object. If return_graph, use all_ft attribute to see all tables in cascade. 
""" - from spyglass.utils.dj_graph import ( - TableChain, - TableChains, - ) # noqa: F401 + from spyglass.utils.dj_graph import TableChain # noqa: F401 if restriction is True: return self + try: ret = self.restrict(restriction) # Save time trying first - logger.warning("Restriction valid for this table. Using as is.") - return ret + if len(ret) < len(self): + logger.warning("Restriction valid for this table. Using as is.") + return ret except DataJointError: pass # Could avoid try/except if assert_join_compatible return bool logger.debug("Restriction not valid. Attempting to cascade.") diff --git a/tests/utils/conftest.py b/tests/utils/conftest.py index 52557862d..a4bc7f900 100644 --- a/tests/utils/conftest.py +++ b/tests/utils/conftest.py @@ -265,3 +265,31 @@ def graph_tables(dj_conn, graph_schema): yield graph_schema schema.drop(force=True) + + +@pytest.fixture(scope="module") +def graph_tables_many_to_one(graph_tables): + ParentNode = graph_tables["ParentNode"] + IntermediateNode = graph_tables["IntermediateNode"] + PkSkNode = graph_tables["PkSkNode"] + + pk_sk_keys = PkSkNode().fetch(as_dict=True)[-2:] + new_inserts = [ + { + "pk_sk_id": k["pk_sk_id"] + 3, + "intermediate_id": k["intermediate_id"] + 3, + "intermediate_attr": k["intermediate_id"] + 16, + "parent_id": k["intermediate_id"] - 1, + "parent_attr": k["intermediate_id"] + 11, + "other_id": k["other_id"], # No change + "pk_sk_attr": k["pk_sk_attr"] + 10, + } + for k in pk_sk_keys + ] + + insert_kwargs = {"ignore_extra_fields": True, "skip_duplicates": True} + ParentNode.insert(new_inserts, **insert_kwargs) + IntermediateNode.insert(new_inserts, **insert_kwargs) + PkSkNode.insert(new_inserts, **insert_kwargs) + + yield graph_tables diff --git a/tests/utils/test_graph.py b/tests/utils/test_graph.py index 29d1f21ad..7d5257a36 100644 --- a/tests/utils/test_graph.py +++ b/tests/utils/test_graph.py @@ -99,6 +99,7 @@ def test_rg_root(restr_graph_root): ("pk_alias_attr > 18", 3, "pk pk alias"), ("sk_alias_attr > 19", 2, "sk sk alias"), ("merge_child_attr > 21", 2, "merge child down"), + ({"merge_child_attr": 21}, 1, "dict restr"), ], ) def test_restr_from_upstream(graph_tables, restr, expect_n, msg): @@ -114,8 +115,29 @@ def test_restr_from_upstream(graph_tables, restr, expect_n, msg): ("PkAliasNode", "parent_attr > 17", 2, "pk pk alias"), ("SkAliasNode", "parent_attr > 18", 2, "sk sk alias"), ("MergeChild", "parent_attr > 18", 2, "merge child"), + ("MergeChild", {"parent_attr": 18}, 1, "dict restr"), ], ) def test_restr_from_downstream(graph_tables, table, restr, expect_n, msg): msg = "Error in `<<` for " + msg assert len(graph_tables[table]() << restr) == expect_n, msg + + +def test_restr_many_to_one(graph_tables_many_to_one): + PK = graph_tables_many_to_one["PkSkNode"]() + OP = graph_tables_many_to_one["OtherParentNode"]() + + msg_template = "Error in `%s` for many to one." + + assert len(PK << "other_attr > 14") == 4, msg_template % "<<" + assert len(PK << {"other_attr": 15}) == 2, msg_template % "<<" + assert len(OP >> "pk_sk_attr > 19") == 2, msg_template % ">>" + assert ( + len(OP >> [{"pk_sk_attr": 19}, {"pk_sk_attr": 20}]) == 2 + ), "Error accepting list of dicts for `>>` for many to one." + + +def test_restr_invalid(graph_tables): + PkNode = graph_tables["PkNode"]() + with pytest.raises(ValueError): + len(PkNode << set(["parent_attr > 15", "parent_attr < 20"]))