1+ from copy import copy
2+ from typing import Hashable , Union , Optional , Any
3+ import operator
4+ import deepmerge
5+
6+ def compare_tree_values (
7+ tree_a : dict | list ,
8+ tree_b : dict | list ,
9+ levels_a : list [Hashable | None ],
10+ levels_b : list [Hashable | None ],
11+ leaf_a : Hashable ,
12+ leaf_b : Hashable ,
13+ compare_func : Union [str , callable , None ],
14+ comparison_key : Optional [Hashable ]= None ,
15+ * args ,
16+ ** kwargs ,
17+ ) -> dict :
18+ """
19+ Returns a dictionary tree keyed according to
20+ 'tree_a': the first tree to compare
21+ 'tree_b': the second tree to compare
22+ 'levels_a': The levels to iterate through in order to access the leaf keys in
23+ 'leaves_a'. If a level is listed is None, then all keys at that level will
24+ be iterated over.
25+ 'levels_b': The levels to iterate through in order to access the leaf keys in
26+ 'leaves_b'. If a level is listed is None, then all keys at that level will
27+ be iterated over.
28+ 'leaf_a': The leaf in the tree_a to compare to the leaf in tree_b
29+ 'leaf_b': The leaf in the tree_b to compare to the leaf in tree_a
30+ 'compare_func': Either one of
31+ {'div', 'sub', 'add', 'mult', 'ge', 'le', 'lt', 'gt', 'eq', 'ne'}, None, or a
32+ user-supplied callable whose call signature takes the values of the individual
33+ elements of 'leaf_a' as the first param, the individual elements of 'leaf_b'
34+ as the second param. Optionally, args and kwargs can be passed and they
35+ will be passed on to the callable.
36+ If compare_func is None, then no function is called and the values for
37+ comparison are both entered into the returned dictionary but without
38+ a special comparison operation performed. If compare_func is None,
39+ 'comparison_key' is ignored.
40+ 'comparison_key': If provided, will serve as the key for the comparison value.
41+ If None, then the name of the comparison operator will used instead.
42+ """
43+ ops = {
44+ "div" : operator .truediv ,
45+ "sub" : operator .sub ,
46+ "add" : operator .add ,
47+ "mul" : operator .mul ,
48+ "ge" : operator .ge ,
49+ "le" : operator .le ,
50+ "lt" : operator .lt ,
51+ "gt" : operator .gt ,
52+ "eq" : operator .eq ,
53+ "ne" : operator .ne ,
54+ }
55+ env_acc = {}
56+ # If we are at the last branch...
57+ subtree_a = retrieve_leaves (tree_a , levels_a , leaf_a )
58+ subtree_b = retrieve_leaves (tree_b , levels_b , leaf_b )
59+
60+ branch_a = trim_branches (subtree_a , levels_a )
61+ branch_b = trim_branches (subtree_b , levels_b )
62+
63+ for trunk in branch_a .keys ():
64+ value_a = branch_a [trunk ]
65+ value_b = branch_b [trunk ]
66+ env_acc .setdefault (trunk , {})
67+ env_acc [trunk ].setdefault (leaf_a , value_a )
68+ env_acc [trunk ].setdefault (leaf_b , value_b )
69+ comparison_operator = ops .get (compare_func , compare_func )
70+ if comparison_operator is not None :
71+ compared_value = comparison_operator (value_a , value_b )
72+ if comparison_key is None :
73+ comparison_key = comparison_operator .__name__
74+ env_acc [trunk ].setdefault (comparison_key , compared_value )
75+ return env_acc
76+
77+
78+ def trim_branches (
79+ tree : dict | list ,
80+ levels : list [Hashable | None ],
81+ ):
82+ """
83+ Returns a copy of the 'tree' but with the branches in
84+ 'levels' trimmed off.
85+ """
86+ trimmed = tree .copy ()
87+ for i in range (len (levels )):
88+ leaf = levels .pop ()
89+ trimmed = retrieve_leaves (trimmed , levels , leaf = leaf )
90+ return trimmed
91+
92+
93+ def retrieve_leaves (
94+ tree : dict | list ,
95+ levels : list [Hashable | None ],
96+ leaf : list [Hashable ] | Hashable | None ,
97+ ) -> dict :
98+ """
99+ Envelopes the tree at the leaf node with 'agg_func'.
100+ """
101+ env_acc = {}
102+ key_error_msg = (
103+ "Key '{level}' does not exist at this level. Available keys: {keys}. "
104+ "Perhaps not all of your tree elements have the same keys. Try enveloping over trees "
105+ "that have the same branch structure and leaf names."
106+ )
107+ # If we are at the last branch...
108+ if not levels :
109+ if leaf is None :
110+ return tree
111+ if isinstance (leaf , list ):
112+ leaf_values = {}
113+ for leaf_elem in leaf :
114+ try :
115+ tree [leaf_elem ]
116+ except KeyError :
117+ raise KeyError (key_error_msg .format (level = leaf_elem , keys = list (tree .keys ())))
118+ leaf_values .update ({leaf_elem : tree [leaf_elem ]})
119+ else :
120+ try :
121+ tree [leaf ]
122+ except KeyError :
123+ raise KeyError (key_error_msg .format (level = leaf , keys = list (tree .keys ())))
124+ leaf_values = tree [leaf ]
125+ return leaf_values
126+ else :
127+ # Otherwise, pop the next level and dive into the tree on that branch
128+ level = levels [0 ]
129+ if level is not None :
130+ try :
131+ tree [level ]
132+ except KeyError :
133+ raise KeyError (key_error_msg .format (level = level , keys = list (tree .keys ())))
134+ env_acc .update ({level : retrieve_leaves (tree [level ], levels [1 :], leaf )})
135+ return env_acc
136+ else :
137+ # If None, then walk all branches of this node of the tree
138+ if isinstance (tree , list ):
139+ tree = {idx : leaf for idx , leaf in enumerate (tree )}
140+ for k , v in tree .items ():
141+ env_acc .update ({k : retrieve_leaves (v , levels [1 :], leaf )})
142+ return env_acc
143+
144+
145+ def extract_keys (
146+ object : dict [str , Any ],
147+ key_name : str ,
148+ include_startswith : Optional [str ] = None ,
149+ exclude_startswith : Optional [str ] = None ,
150+ ) -> list [dict [str , Any ]]:
151+ """
152+ Returns a list of dicts where each dict has a key of 'key_name'
153+ and a value of one of the keys of 'object'.
154+
155+ e.g.
156+ object = {"key1": value, "key2": value, "key3": value}
157+ key_name = "label"
158+
159+ extract_keys(object, key_name) # [{"label": "key1"}, {"label": "key2"}, {"label": "key3"}]
160+
161+ 'include_startswith': If provided, will only include keys that start with this string.
162+ 'exclude_startswith': If provided, will exclude all keys that start with this string.
163+
164+ If both 'include_startswith' and 'exclude_startswith' are provided, exclude is executed
165+ first.
166+ """
167+ shortlist = []
168+ for key in object .keys ():
169+ if exclude_startswith is not None and key .startswith (exclude_startswith ):
170+ continue
171+ else :
172+ shortlist .append (key )
173+
174+ acc = []
175+ for key in shortlist :
176+ if include_startswith is not None and key .startswith (include_startswith ):
177+ acc .append ({key_name : key })
178+ elif include_startswith is None :
179+ acc .append ({key_name : key })
180+
181+ return acc
182+
183+
184+
185+ def merge_trees (trees : list [dict [str , dict ]]) -> dict [str , dict ]:
186+ """
187+ Merges all of the tress (dictionaries) in 'result_trees'.
188+
189+ This is different than a typical dictionary merge (e.g. a | b)
190+ which will merge dictionaries with different keys but will over-
191+ write values if two keys are the same.
192+
193+ Instead, it crawls each branch of the tree and merges the data
194+ within each branch, no matter how deep the branches go.
195+ """
196+ acc = {}
197+ for result_tree in trees :
198+ acc = deepmerge .always_merger .merge (acc , result_tree )
199+ return acc
0 commit comments