Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mixin method to restrict table by keys of upstream tables #930

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions src/spyglass/utils/dj_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from spyglass.utils.dj_chains import TableChain, TableChains
from spyglass.utils.dj_helper_fn import fetch_nwb, get_nwb_table
from spyglass.utils.dj_merge_tables import RESERVED_PRIMARY_KEY as MERGE_PK
from spyglass.utils.dj_merge_tables import Merge
from spyglass.utils.logging import logger

try:
Expand Down Expand Up @@ -535,3 +536,91 @@ def super_delete(self, *args, **kwargs):
logger.warning("!! Using super_delete. Bypassing cautious_delete !!")
self._log_use(start=time(), super_delete=True)
super().delete(*args, **kwargs)

from spyglass.utils.dj_merge_tables import Merge

def restrict_from_upstream(self, key, **kwargs):
"""Recursive function to restrict a table based on secondary keys of upstream tables"""
return restrict_from_upstream(self, key, **kwargs)


def restrict_from_upstream(table, key, max_recursion=3):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm thinking I could fold this into either dj_chains or dj_graph to make use of how restr are cascaded there. it might cutdown on the footprint for this feature

"""Recursive function to restrict a table based on secondary keys of upstream tables"""
print(f"table: {table.full_table_name}, key: {key}")
# Tables not to recurse through because too big or central
blacklist = [
"`common_nwbfile`.`analysis_nwbfile`",
]

# Case: MERGE table
if (table := table & key) and max_recursion:
if isinstance(table, Merge):
parts = table.parts(as_objects=True)
restricted_parts = [
restrict_from_upstream(part, key, max_recursion - 1)
for part in parts
]
# only keep entries from parts that got restricted
restricted_parts = [
r_part.proj("merge_id")
for r_part, part in zip(restricted_parts, parts)
if (
not len(r_part) == len(part)
or check_complete_restrict(r_part, key)
)
]
# return the merge of the restricted parts
merge_keys = []
for r_part in restricted_parts:
merge_keys.extend(r_part.fetch("merge_id", as_dict=True))
return table & merge_keys

# Case: regular table
upstream_tables = table.parents(as_objects=True)
# prevent a loop where call Merge master table from part
upstream_tables = [
parent
for parent in upstream_tables
if not (
isinstance(parent, Merge)
and table.full_table_name in parent.parts()
)
and (parent.full_table_name not in blacklist)
]
for parent in upstream_tables:
print(parent.full_table_name)
print(len(parent))
r_parent = restrict_from_upstream(parent, key, max_recursion - 1)
if len(r_parent) == len(parent):
continue # skip joins with uninformative tables
table = safe_join(table, r_parent)
if check_complete_restrict(table, key) or not table:
print(len(table))
break
return table


def check_complete_restrict(table, key):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dj has a builtin for this called 'assert_join_compatibility' - if adopted, it would need to be wrapped in a try/except and call it using the key as a QueryExpression, but I'm inclined to let them handle this kind of thing rather than add more for us to maintain.

"""Checks all keys in a restriction dictionary are used in a table"""
if all([k in table.heading.names for k in key.keys()]):
print("FOUND")
return all([k in table.heading.names for k in key.keys()])


# Utility Function
def safe_join(table_1, table_2):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the permissive join operator could help you here a @ b - they talk about it in the docstring here

"""enables joining of two tables with overlapping secondary keys"""
secondary_1 = [
name
for name in table_1.heading.names
if name not in table_1.primary_key
]
secondary_2 = [
name
for name in table_2.heading.names
if name not in table_2.primary_key
]
overlap = [name for name in secondary_1 if name in secondary_2]
return table_1 * table_2.proj(
*[name for name in table_2.heading.names if name not in overlap]
)
Loading