diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py
index 082116bf6..529422739 100644
--- a/src/spyglass/utils/dj_mixin.py
+++ b/src/spyglass/utils/dj_mixin.py
@@ -15,6 +15,7 @@
 from spyglass.utils.dj_chains import TableChain, TableChains
 from spyglass.utils.dj_helper_fn import fetch_nwb, get_nwb_table
 from spyglass.utils.dj_merge_tables import RESERVED_PRIMARY_KEY as MERGE_PK
+from spyglass.utils.dj_merge_tables import Merge
 from spyglass.utils.logging import logger
 
 try:
@@ -535,3 +536,105 @@ def super_delete(self, *args, **kwargs):
         logger.warning("!! Using super_delete. Bypassing cautious_delete !!")
         self._log_use(start=time(), super_delete=True)
         super().delete(*args, **kwargs)
+
+    def restrict_from_upstream(self, key, **kwargs):
+        """Restrict this table using secondary keys of upstream tables.
+
+        Delegates to the module-level ``restrict_from_upstream``; ``kwargs``
+        (e.g. ``max_recursion``) are passed through unchanged.
+        """
+        return restrict_from_upstream(self, key, **kwargs)
+
+
+# Tables not to recurse through because too big or central
+_UPSTREAM_BLACKLIST = ("`common_nwbfile`.`analysis_nwbfile`",)
+
+
+def restrict_from_upstream(table, key, max_recursion=3):
+    """Restrict a table based on secondary keys of upstream tables.
+
+    Recursive: ``key`` is applied to ``table``, then to its parents (or, for
+    a Merge table, to its parts), and the results are folded back in.
+
+    Parameters
+    ----------
+    table
+        Table to restrict.
+    key : dict
+        Restriction; may name attributes found only in upstream tables.
+    max_recursion : int, optional
+        Levels of upstream tables still to be searched. Default 3.
+
+    Returns
+    -------
+    ``table & key``, further restricted through Merge parts or joined with
+    every informative restricted parent.
+    """
+    logger.debug("table: %s, key: %s", table.full_table_name, key)
+    table = table & key
+    if not table or not max_recursion:
+        return table  # nothing left to restrict, or depth exhausted
+
+    # Case: MERGE table
+    if isinstance(table, Merge):
+        merge_keys = []
+        for part in table.parts(as_objects=True):
+            r_part = restrict_from_upstream(part, key, max_recursion - 1)
+            # only keep entries from parts that got restricted
+            informative = len(r_part) != len(part)
+            if informative or check_complete_restrict(r_part, key):
+                merge_keys.extend(
+                    r_part.proj("merge_id").fetch("merge_id", as_dict=True)
+                )
+        # restrict the master to the merge ids of the restricted parts
+        return table & merge_keys
+
+    # Case: regular table
+    # prevent a loop where the Merge master table is called from its part
+    upstream_tables = [
+        parent
+        for parent in table.parents(as_objects=True)
+        if not (
+            isinstance(parent, Merge)
+            and table.full_table_name in parent.parts()
+        )
+        and parent.full_table_name not in _UPSTREAM_BLACKLIST
+    ]
+    for parent in upstream_tables:
+        n_parent = len(parent)  # counted once, reused in the check below
+        logger.debug(
+            "parent: %s, entries: %s", parent.full_table_name, n_parent
+        )
+        r_parent = restrict_from_upstream(parent, key, max_recursion - 1)
+        if len(r_parent) == n_parent:
+            continue  # skip joins with uninformative tables
+        table = safe_join(table, r_parent)
+        if check_complete_restrict(table, key) or not table:
+            break
+    return table
+
+
+def check_complete_restrict(table, key):
+    """Check all keys of a restriction dict are in the table heading."""
+    names = table.heading.names
+    return all(k in names for k in key)
+
+
+# Utility Function
+def safe_join(table_1, table_2):
+    """Join two tables that have overlapping secondary keys.
+
+    Attributes that are secondary in both tables are projected out of
+    ``table_2`` before the join, so ``table_1``'s copy is the one kept.
+    """
+    secondary_1 = {
+        name
+        for name in table_1.heading.names
+        if name not in table_1.primary_key
+    }
+    keep = [
+        name
+        for name in table_2.heading.names
+        if name in table_2.primary_key or name not in secondary_1
+    ]
+    return table_1 * table_2.proj(*keep)