Skip to content

Commit

Permalink
fix: get tree diff to handle sep that is different from default, cont…
Browse files Browse the repository at this point in the history
…aining forbidden symbols, and different for tree and other_tree
  • Loading branch information
kayjan committed Oct 16, 2024
1 parent fff2aa3 commit 5939fae
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 5 deletions.
21 changes: 19 additions & 2 deletions bigtree/tree/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ def get_tree_diff(
other_tree: node.Node,
only_diff: bool = True,
attr_list: List[str] = [],
fallback_sep: str = "/",
) -> node.Node:
"""Get difference of `tree` to `other_tree`, changes are relative to `tree`.
Expand Down Expand Up @@ -333,11 +334,25 @@ def get_tree_diff(
other_tree (Node): tree to be compared with
only_diff (bool): indicator to show all nodes or only nodes that are different (+/-), defaults to True
attr_list (List[str]): tree attributes to check for difference, defaults to empty list
fallback_sep (str): sep to fall back to if tree and other_tree has sep that clashes with symbols "+" / "-" / "~".
All node names in tree and other_tree should not contain this fallback_sep, defaults to "/"
Returns:
(Node)
"""
other_tree.sep = tree.sep
if tree.sep != other_tree.sep:
raise ValueError("`sep` must be the same for tree and other_tree")

forbidden_sep_symbols = ["+", "-", "~"]
if any(
forbidden_sep_symbol in tree.sep
for forbidden_sep_symbol in forbidden_sep_symbols
):
tree = tree.copy()
other_tree = other_tree.copy()
tree.sep = fallback_sep
other_tree.sep = fallback_sep

name_col = "name"
path_col = "PATH"
indicator_col = "Exists"
Expand Down Expand Up @@ -405,7 +420,9 @@ def get_tree_diff(
]
data_both = data_both[[path_col]]
if len(data_both):
tree_diff = construct.dataframe_to_tree(data_both, node_type=tree.__class__)
tree_diff = construct.dataframe_to_tree(
data_both, node_type=tree.__class__, sep=tree.sep
)
# Handle tree attribute difference
if len(path_changes_deque):
path_changes_list = sorted(path_changes_deque, reverse=True)
Expand Down
1 change: 1 addition & 0 deletions tests/test_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ class Constants:
ERROR_NODE_PRUNE_NOT_FOUND = (
"Cannot find any node matching path_name ending with {prune_path}"
)
ERROR_NODE_TREE_DIFF_DIFF_SEP = "`sep` must be the same for tree and other_tree"

# tree/modify
ERROR_MODIFY_PARAM_TYPE = (
Expand Down
42 changes: 39 additions & 3 deletions tests/tree/test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,10 +236,27 @@ def test_tree_diff(tree_node):
assert_print_statement(export.print_tree, expected_str, tree=tree_only_diff)

@staticmethod
def test_tree_diff_diff_sep(tree_node):
def test_tree_diff_diff_sep_error(tree_node):
other_tree_node = helper.prune_tree(tree_node, "a/c")
_ = node.Node("d", parent=other_tree_node)
other_tree_node.sep = "-"
with pytest.raises(ValueError) as exc_info:
helper.get_tree_diff(tree_node, other_tree_node)
assert str(exc_info.value) == Constants.ERROR_NODE_TREE_DIFF_DIFF_SEP

@staticmethod
def test_tree_diff_sep_clash_with_node_name_error(tree_node):
other_tree_node = helper.prune_tree(tree_node, "a/c")
_ = node.Node("/d", parent=other_tree_node)
with pytest.raises(exceptions.TreeError) as exc_info:
helper.get_tree_diff(tree_node, other_tree_node)
assert str(exc_info.value) == Constants.ERROR_NODE_NAME

@staticmethod
def test_tree_diff_sep_clash_with_node_name(tree_node):
other_tree_node = helper.prune_tree(tree_node, "a/c")
_ = node.Node("/d", parent=other_tree_node)
tree_node.sep = "."
other_tree_node.sep = "."
tree_only_diff = helper.get_tree_diff(tree_node, other_tree_node)
expected_str = (
"a\n"
Expand All @@ -248,10 +265,29 @@ def test_tree_diff_diff_sep(tree_node):
"│ └── e (-)\n"
"│ ├── g (-)\n"
"│ └── h (-)\n"
"└── d (+)\n"
"└── /d (+)\n"
)
assert_print_statement(export.print_tree, expected_str, tree=tree_only_diff)

@staticmethod
def test_tree_diff_forbidden_sep(tree_node):
other_tree_node = helper.prune_tree(tree_node, "a/c")
_ = node.Node("d", parent=other_tree_node)
for symbol in [".", "-", "+", "~"]:
tree_node.sep = symbol
other_tree_node.sep = symbol
tree_only_diff = helper.get_tree_diff(tree_node, other_tree_node)
expected_str = (
"a\n"
"├── b (-)\n"
"│ ├── d (-)\n"
"│ └── e (-)\n"
"│ ├── g (-)\n"
"│ └── h (-)\n"
"└── d (+)\n"
)
assert_print_statement(export.print_tree, expected_str, tree=tree_only_diff)

@staticmethod
def test_tree_diff_all_diff(tree_node):
other_tree_node = helper.prune_tree(tree_node, "a/c")
Expand Down

0 comments on commit 5939fae

Please sign in to comment.