Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Long distance restrictions #949

Merged
merged 23 commits into from
May 10, 2024
Merged
Changes from 1 commit
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
WIP: ABC for RestrGraph
CBroz1 committed Apr 25, 2024
commit f1aa8727e0f73adf142512c0cf4825fd30ae5267
186 changes: 103 additions & 83 deletions src/spyglass/utils/dj_graph.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@
NOTE: read `ft` as FreeTable and `restr` as restriction.
"""

from abc import ABC, abstractmethod
from typing import Dict, List, Union

from datajoint import FreeTable
@@ -15,71 +16,38 @@
from spyglass.utils.dj_helper_fn import unique_dicts


class RestrGraph:
def __init__(
self,
seed_table: Table,
table_name: str = None,
restriction: str = None,
leaves: List[Dict[str, str]] = None,
cascade: bool = False,
verbose: bool = False,
**kwargs,
):
"""Use graph to cascade restrictions up from leaves to all ancestors.
class AbstractGraph(ABC):
def __init__(self, seed_table: Table, verbose: bool = False, **kwargs):
"""Abstract class for graph traversal and restriction application.

Parameters
----------
seed_table : Table
Table to use to establish connection and graph
table_name : str, optional
Table name of single leaf, default None
restriction : str, optional
Restriction to apply to leaf. default None
leaves : Dict[str, str], optional
List of dictionaries with keys table_name and restriction. One
entry per leaf node. Default None.
cascade : bool, optional
Whether to cascade restrictions up the graph on initialization.
Default False
verbose : bool, optional
Whether to print verbose output. Default False
"""

self.connection = seed_table.connection
self.graph = seed_table.connection.dependencies
self.graph.load()

self.verbose = verbose
self.cascaded = False
self.ancestors = set()
self.visited = set()
self.leaves = set()
self.analysis_pk = AnalysisNwbfile().primary_key

if table_name and restriction:
self.add_leaf(table_name, restriction)
if leaves:
self.add_leaves(leaves, show_progress=verbose)

if cascade:
self.cascade()

def __repr__(self):
l_str = ",\n\t".join(self.leaves) + "\n" if self.leaves else ""
processed = "Cascaded" if self.cascaded else "Uncascaded"
return f"{processed} RestrictionGraph(\n\t{l_str})"
self.to_visit = set()
self.cascaded = False

@property
def all_ft(self):
"""Get restricted FreeTables from all visited nodes."""
self.cascade()
return [self._get_ft(table, with_restr=True) for table in self.visited]
def _log_truncate(self, log_str, max_len=80):
    """Log a message at info level, shortened to `max_len`, when verbose.

    Parameters
    ----------
    log_str : str
        Message to log.
    max_len : int, optional
        Longest message emitted unchanged. Longer messages are cut to
        this many characters and suffixed with "...". Default 80.

    Nothing is logged unless `self.verbose` is truthy.
    """
    if self.verbose:
        message = log_str
        if len(log_str) > max_len:
            message = log_str[:max_len] + "..."
        logger.info(message)

@property
def leaf_ft(self):
"""Get restricted FreeTables from graph leaves."""
return [self._get_ft(table, with_restr=True) for table in self.leaves]
@abstractmethod
def cascade(self):
    """Cascade restrictions through graph.

    Abstract hook: concrete graph classes provide the traversal.

    Raises
    ------
    NotImplementedError
        Always, if this base implementation is invoked directly.
    """
    raise NotImplementedError("Child class must implement `cascade` method")

def _get_node(self, table):
"""Get node from graph."""
@@ -95,25 +63,11 @@ def _set_node(self, table, attr="ft", value=None):
_ = self._get_node(table) # Ensure node exists
self.graph.nodes[table][attr] = value

def _get_ft(self, table, with_restr=False):
"""Get FreeTable from graph node. If one doesn't exist, create it."""
table = table if isinstance(table, str) else table.full_table_name
restr = self._get_restr(table) if with_restr else True
if ft := self._get_node(table).get("ft"):
return ft & restr
ft = FreeTable(self.connection, table)
self._set_node(table, "ft", ft)
return ft & restr

def _get_restr(self, table):
"""Get restriction from graph node."""
table = table if isinstance(table, str) else table.full_table_name
return self._get_node(table).get("restr", "False")

def _get_files(self, table):
"""Get analysis files from graph node."""
return self._get_node(table).get("files", [])

def _set_restr(self, table, restriction):
"""Add restriction to graph node. If one exists, merge with new."""
ft = self._get_ft(table)
@@ -122,7 +76,6 @@ def _set_restr(self, table, restriction):
if not isinstance(restriction, str)
else restriction
)
# orig_restr = restriction
if existing := self._get_restr(table):
if existing == restriction:
return
@@ -135,10 +88,26 @@ def _set_restr(self, table, restriction):

self._set_node(table, "restr", restriction)

def _get_ft(self, table, with_restr=False):
"""Get FreeTable from graph node. If one doesn't exist, create it."""
table = table if isinstance(table, str) else table.full_table_name
restr = self._get_restr(table) if with_restr else True
if ft := self._get_node(table).get("ft"):
return ft & restr
ft = FreeTable(self.connection, table)
self._set_node(table, "ft", ft)
return ft & restr

@property
def all_ft(self):
    """Restricted FreeTables for every visited node.

    Runs `self.cascade()` first, then collects `_get_ft(..., with_restr=True)`
    for each entry of `self.visited`, in iteration order.
    """
    self.cascade()
    restricted = []
    for name in self.visited:
        restricted.append(self._get_ft(name, with_restr=True))
    return restricted

def get_restr_ft(self, table: Union[int, str]):
"""Get restricted FreeTable from graph node.

Currently used. May be useful for debugging.
Currently used for testing.

Parameters
----------
@@ -149,14 +118,6 @@ def get_restr_ft(self, table: Union[int, str]):
table = list(self.visited)[table]
return self._get_ft(table, with_restr=True)

def _log_truncate(self, log_str, max_len=80):
"""Truncate log lines to max_len and print if verbose."""
if not self.verbose:
return
logger.info(
log_str[:max_len] + "..." if len(log_str) > max_len else log_str
)

def _child_to_parent(
self,
child,
@@ -213,6 +174,75 @@ def _child_to_parent(

return ret

@property
def as_dict(self) -> List[Dict[str, str]]:
    """Return visited tables and their restrictions as a list of dicts.

    Runs `self.cascade()` first. Each entry has the keys "table_name" and
    "restriction"; tables whose restriction is falsy are omitted.
    """
    self.cascade()
    # Bind the restriction once per table so it is not looked up twice
    # (once for the filter, once for the value).
    return [
        {"table_name": table, "restriction": restr}
        for table in self.visited
        if (restr := self._get_restr(table))
    ]


class RestrGraph(AbstractGraph):
def __init__(
    self,
    seed_table: Table,
    table_name: str = None,
    restriction: str = None,
    leaves: List[Dict[str, str]] = None,
    cascade: bool = False,
    verbose: bool = False,
    **kwargs,
):
    """Use graph to cascade restrictions up from leaves to all ancestors.

    Parameters
    ----------
    seed_table : Table
        Table to use to establish connection and graph
    table_name : str, optional
        Table name of single leaf, default None. Only used when
        `restriction` is also provided.
    restriction : str, optional
        Restriction to apply to leaf. default None. Only used when
        `table_name` is also provided.
    leaves : List[Dict[str, str]], optional
        List of dictionaries with keys table_name and restriction. One
        entry per leaf node. Default None.
    cascade : bool, optional
        Whether to cascade restrictions up the graph on initialization.
        Default False
    verbose : bool, optional
        Whether to print verbose output. Default False
    **kwargs
        Accepted but not used by this constructor; not forwarded to the
        parent initializer.
    """
    super().__init__(seed_table, verbose=verbose)

    self.ancestors = set()
    self.leaves = set()
    self.analysis_pk = AnalysisNwbfile().primary_key

    # A single leaf needs both a table name and a restriction.
    if table_name and restriction:
        self.add_leaf(table_name, restriction)
    if leaves:
        self.add_leaves(leaves, show_progress=verbose)

    if cascade:
        self.cascade()

def __repr__(self):
l_str = ",\n\t".join(self.leaves) + "\n" if self.leaves else ""
processed = "Cascaded" if self.cascaded else "Uncascaded"
return f"{processed} RestrictionGraph(\n\t{l_str})"

@property
def leaf_ft(self):
    """Restricted FreeTables for the graph's leaves.

    Collects `_get_ft(..., with_restr=True)` for each entry of `self.leaves`,
    in iteration order. Does not trigger a cascade.
    """
    restricted = []
    for leaf in self.leaves:
        restricted.append(self._get_ft(leaf, with_restr=True))
    return restricted

def _get_files(self, table):
"""Get analysis files from graph node."""
return self._get_node(table).get("files", [])

def cascade_files(self):
"""Set node attribute for analysis files."""
for table in self.visited:
@@ -341,16 +371,6 @@ def add_leaves(
self.cascade()
self.cascade_files()

@property
def as_dict(self) -> List[Dict[str, str]]:
"""Return as a list of dictionaries of table_name: restriction"""
self.cascade()
return [
{"table_name": table, "restriction": self._get_restr(table)}
for table in self.ancestors
if self._get_restr(table)
]

@property
def file_dict(self) -> Dict[str, List[str]]:
"""Return dictionary of analysis files from all visited nodes.