@@ -19,22 +19,30 @@ def envelope_tree(
1919 # If we are at the last branch...
2020 if not levels :
2121 if isinstance (tree , list ):
22- tree = {idx : leaf for idx , leaf in enumerate (tree )}
23- leaf_acc = {}
24- for leaf_key , leaf_value in tree .items ():
25- if leaf is not None and leaf_key == leaf :
26- leaf_acc .update ({leaf_key : leaf_value })
27- else :
28- leaf_acc = tree
22+ orig_tree = tree
23+ tree = {idx : value for idx , value in enumerate (tree )}
24+ elif isinstance (tree , dict ):
25+ orig_tree = tree
26+ tree = {idx : branch [leaf ] for idx , branch in enumerate (tree .values ())}
27+ trace_tree = {idx : key for idx , key in enumerate (orig_tree .keys ())}
28+
29+ # leaf_acc = {}
30+ # for leaf_key, leaf_value in tree.items():
31+ # if leaf is not None and leaf_key == leaf:
32+ # leaf_acc.update({leaf_key: leaf_value})
33+ # else:
34+ # leaf_acc = tree
35+
2936 # ...create a dict of the enveloped value and the key
3037 # that it belongs to and return it.
31- env_values = list (leaf_acc .values ())
38+ # env_values = list(leaf_acc.values())
39+ env_values = list (tree .values ())
3240 env_value = agg_func (env_values )
33- env_keys = list (leaf_acc .keys ())
41+ # env_keys = list(leaf_acc.keys())
3442 try :
35- env_key = env_keys [env_values .index (env_value )]
43+ env_key = trace_tree [env_values .index (env_value )]
3644 except ValueError : # The value was transformed, likely due to abs()
37- env_key = env_keys [env_values .index (- 1 * env_value )]
45+ env_key = trace_tree [env_values .index (- 1 * env_value )]
3846 env_acc .update ({"key" : env_key , "value" : env_value })
3947 if with_trace :
4048 return env_acc
0 commit comments