-
Notifications
You must be signed in to change notification settings - Fork 44
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Mixin method to restrict table by keys of upstream tables #930
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,7 @@ | |
from spyglass.utils.dj_chains import TableChain, TableChains | ||
from spyglass.utils.dj_helper_fn import fetch_nwb, get_nwb_table | ||
from spyglass.utils.dj_merge_tables import RESERVED_PRIMARY_KEY as MERGE_PK | ||
from spyglass.utils.dj_merge_tables import Merge | ||
from spyglass.utils.logging import logger | ||
|
||
try: | ||
|
@@ -535,3 +536,91 @@ def super_delete(self, *args, **kwargs): | |
logger.warning("!! Using super_delete. Bypassing cautious_delete !!") | ||
self._log_use(start=time(), super_delete=True) | ||
super().delete(*args, **kwargs) | ||
|
||
from spyglass.utils.dj_merge_tables import Merge | ||
|
||
def restrict_from_upstream(self, key, **kwargs): | ||
"""Recursive function to restrict a table based on secondary keys of upstream tables""" | ||
return restrict_from_upstream(self, key, **kwargs) | ||
|
||
|
||
def restrict_from_upstream(table, key, max_recursion=3): | ||
"""Recursive function to restrict a table based on secondary keys of upstream tables""" | ||
print(f"table: {table.full_table_name}, key: {key}") | ||
# Tables not to recurse through because too big or central | ||
blacklist = [ | ||
"`common_nwbfile`.`analysis_nwbfile`", | ||
] | ||
|
||
# Case: MERGE table | ||
if (table := table & key) and max_recursion: | ||
if isinstance(table, Merge): | ||
parts = table.parts(as_objects=True) | ||
restricted_parts = [ | ||
restrict_from_upstream(part, key, max_recursion - 1) | ||
for part in parts | ||
] | ||
# only keep entries from parts that got restricted | ||
restricted_parts = [ | ||
r_part.proj("merge_id") | ||
for r_part, part in zip(restricted_parts, parts) | ||
if ( | ||
not len(r_part) == len(part) | ||
or check_complete_restrict(r_part, key) | ||
) | ||
] | ||
# return the merge of the restricted parts | ||
merge_keys = [] | ||
for r_part in restricted_parts: | ||
merge_keys.extend(r_part.fetch("merge_id", as_dict=True)) | ||
return table & merge_keys | ||
|
||
# Case: regular table | ||
upstream_tables = table.parents(as_objects=True) | ||
# prevent a loop where call Merge master table from part | ||
upstream_tables = [ | ||
parent | ||
for parent in upstream_tables | ||
if not ( | ||
isinstance(parent, Merge) | ||
and table.full_table_name in parent.parts() | ||
) | ||
and (parent.full_table_name not in blacklist) | ||
] | ||
for parent in upstream_tables: | ||
print(parent.full_table_name) | ||
print(len(parent)) | ||
r_parent = restrict_from_upstream(parent, key, max_recursion - 1) | ||
if len(r_parent) == len(parent): | ||
continue # skip joins with uninformative tables | ||
table = safe_join(table, r_parent) | ||
if check_complete_restrict(table, key) or not table: | ||
print(len(table)) | ||
break | ||
return table | ||
|
||
|
||
def check_complete_restrict(table, key): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. dj has a builtin for this called 'assert_join_compatibility' - if adopted, it would need to be wrapped in a try/except and call it using the key as a QueryExpression, but I'm inclined to let them handle this kind of thing rather than add more for us to maintain. |
||
"""Checks all keys in a restriction dictionary are used in a table""" | ||
if all([k in table.heading.names for k in key.keys()]): | ||
print("FOUND") | ||
return all([k in table.heading.names for k in key.keys()]) | ||
|
||
|
||
# Utility Function | ||
def safe_join(table_1, table_2): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the permissive join operator could help you here |
||
"""enables joining of two tables with overlapping secondary keys""" | ||
secondary_1 = [ | ||
name | ||
for name in table_1.heading.names | ||
if name not in table_1.primary_key | ||
] | ||
secondary_2 = [ | ||
name | ||
for name in table_2.heading.names | ||
if name not in table_2.primary_key | ||
] | ||
overlap = [name for name in secondary_1 if name in secondary_2] | ||
return table_1 * table_2.proj( | ||
*[name for name in table_2.heading.names if name not in overlap] | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm thinking I could fold this into either
dj_chains
ordj_graph
to make use of how restr are cascaded there. it might cutdown on the footprint for this feature