Skip to content

Commit 799f56d

Browse files
committed
fix: fine-tune comparison func
1 parent 99047d7 commit 799f56d

File tree

1 file changed

+21
-14
lines changed

1 file changed

+21
-14
lines changed

src/jsonchain/tree.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@ def compare_tree_values(
88
tree_b: dict | list,
99
levels_a: list[Hashable | None],
1010
levels_b: list[Hashable | None],
11-
leaf_a: Union[Hashable, list[Hashable]],
12-
leaf_b: Union[Hashable, list[Hashable]],
13-
compare_func: Union[str, callable],
14-
compared_key: Optional[Hashable]= None,
11+
leaf_a: Hashable,
12+
leaf_b: Hashable,
13+
compare_func: Union[str, callable, None],
14+
comparison_key: Optional[Hashable]= None,
1515
*args,
1616
**kwargs,
1717
) -> dict:
@@ -25,14 +25,20 @@ def compare_tree_values(
2525
'levels_b': The levels to iterate through in order to access the leaf keys in
2626
'leaves_b'. If a level is listed is None, then all keys at that level will
2727
be iterated over.
28-
'leaves_a': a list of leaf keys to compare. Must be same length as 'leaves_b'.
29-
'leaves_b': a list of leaf keys to compare. Must be same length as 'leaves_a'.
28+
'leaf_a': The leaf in the tree_a to compare to the leaf in tree_b
29+
'leaf_b': The leaf in the tree_b to compare to the leaf in tree_a
3030
'compare_func': Either one of
31-
{'div', 'sub', 'add', 'mult', 'ge', 'le', 'lt', 'gt', 'eq', 'ne'} or a
32-
user-supplied callable whos call signature takes the values of the individul
33-
elements of 'leaves_a' as the first param, the individual elements of 'leaves_b'
31+
{'div', 'sub', 'add', 'mult', 'ge', 'le', 'lt', 'gt', 'eq', 'ne'}, None, or a
32+
user-supplied callable whose call signature takes the values of the individual
33+
elements of 'leaf_a' as the first param, the individual elements of 'leaf_b'
3434
as the second param. Optionally, args and kwargs can be passed and they
3535
will be passed on to the callable.
36+
If compare_func is None, then no function is called and the values for
37+
comparison are both entered into the returned dictionary but without
38+
a special comparison operation performed. If compare_func is None,
39+
'comparison_key' is ignored.
40+
'comparison_key': If provided, will serve as the key for the comparison value.
41+
If None, then the name of the comparison operator will used instead.
3642
"""
3743
ops = {
3844
"div": operator.truediv,
@@ -57,14 +63,15 @@ def compare_tree_values(
5763
for trunk in branch_a.keys():
5864
value_a = branch_a[trunk]
5965
value_b = branch_b[trunk]
60-
comparison_operator = ops.get(compare_func, compare_func)
61-
compared_value = comparison_operator(value_a, value_b)
6266
env_acc.setdefault(trunk, {})
6367
env_acc[trunk].setdefault(leaf_a, value_a)
6468
env_acc[trunk].setdefault(leaf_b, value_b)
65-
if compared_key is None:
66-
compared_key = str(compare_func)
67-
env_acc[trunk].setdefault(compared_key, compared_value)
69+
comparison_operator = ops.get(compare_func, compare_func)
70+
if comparison_operator is not None:
71+
compared_value = comparison_operator(value_a, value_b)
72+
if comparison_key is None:
73+
comparison_key = comparison_operator.__name__
74+
env_acc[trunk].setdefault(comparison_key, compared_value)
6875
return env_acc
6976

7077

0 commit comments

Comments
 (0)