diff --git a/bigtree/tree/helper.py b/bigtree/tree/helper.py index 2cd50451..ca20d682 100644 --- a/bigtree/tree/helper.py +++ b/bigtree/tree/helper.py @@ -242,6 +242,7 @@ def get_tree_diff( other_tree: node.Node, only_diff: bool = True, attr_list: List[str] = [], + fallback_sep: str = "/", ) -> node.Node: """Get difference of `tree` to `other_tree`, changes are relative to `tree`. @@ -333,11 +334,25 @@ def get_tree_diff( other_tree (Node): tree to be compared with only_diff (bool): indicator to show all nodes or only nodes that are different (+/-), defaults to True attr_list (List[str]): tree attributes to check for difference, defaults to empty list + fallback_sep (str): sep to fall back to if tree and other_tree has sep that clashes with symbols "+" / "-" / "~". + All node names in tree and other_tree should not contain this fallback_sep, defaults to "/" Returns: (Node) """ - other_tree.sep = tree.sep + if tree.sep != other_tree.sep: + raise ValueError("`sep` must be the same for tree and other_tree") + + forbidden_sep_symbols = ["+", "-", "~"] + if any( + forbidden_sep_symbol in tree.sep + for forbidden_sep_symbol in forbidden_sep_symbols + ): + tree = tree.copy() + other_tree = other_tree.copy() + tree.sep = fallback_sep + other_tree.sep = fallback_sep + name_col = "name" path_col = "PATH" indicator_col = "Exists" @@ -405,7 +420,9 @@ def get_tree_diff( ] data_both = data_both[[path_col]] if len(data_both): - tree_diff = construct.dataframe_to_tree(data_both, node_type=tree.__class__) + tree_diff = construct.dataframe_to_tree( + data_both, node_type=tree.__class__, sep=tree.sep + ) # Handle tree attribute difference if len(path_changes_deque): path_changes_list = sorted(path_changes_deque, reverse=True) diff --git a/tests/test_constants.py b/tests/test_constants.py index c2adaa59..20b3428d 100644 --- a/tests/test_constants.py +++ b/tests/test_constants.py @@ -169,6 +169,7 @@ class Constants: ERROR_NODE_PRUNE_NOT_FOUND = ( "Cannot find any node matching path_name ending with {prune_path}" ) + ERROR_NODE_TREE_DIFF_DIFF_SEP = "`sep` must be the same for tree and other_tree" # tree/modify ERROR_MODIFY_PARAM_TYPE = ( diff --git a/tests/tree/test_helper.py b/tests/tree/test_helper.py index 7f6096ce..0a1d5a85 100644 --- a/tests/tree/test_helper.py +++ b/tests/tree/test_helper.py @@ -236,10 +236,27 @@ def test_tree_diff(tree_node): assert_print_statement(export.print_tree, expected_str, tree=tree_only_diff) @staticmethod - def test_tree_diff_diff_sep(tree_node): + def test_tree_diff_diff_sep_error(tree_node): other_tree_node = helper.prune_tree(tree_node, "a/c") - _ = node.Node("d", parent=other_tree_node) other_tree_node.sep = "-" + with pytest.raises(ValueError) as exc_info: + helper.get_tree_diff(tree_node, other_tree_node) + assert str(exc_info.value) == Constants.ERROR_NODE_TREE_DIFF_DIFF_SEP + + @staticmethod + def test_tree_diff_sep_clash_with_node_name_error(tree_node): + other_tree_node = helper.prune_tree(tree_node, "a/c") + _ = node.Node("/d", parent=other_tree_node) + with pytest.raises(exceptions.TreeError) as exc_info: + helper.get_tree_diff(tree_node, other_tree_node) + assert str(exc_info.value) == Constants.ERROR_NODE_NAME + + @staticmethod + def test_tree_diff_sep_clash_with_node_name(tree_node): + other_tree_node = helper.prune_tree(tree_node, "a/c") + _ = node.Node("/d", parent=other_tree_node) + tree_node.sep = "." + other_tree_node.sep = "." tree_only_diff = helper.get_tree_diff(tree_node, other_tree_node) expected_str = ( "a\n" @@ -248,10 +265,29 @@ def test_tree_diff_diff_sep(tree_node): "│ └── e (-)\n" "│ ├── g (-)\n" "│ └── h (-)\n" - "└── d (+)\n" + "└── /d (+)\n" ) assert_print_statement(export.print_tree, expected_str, tree=tree_only_diff) + @staticmethod + def test_tree_diff_forbidden_sep(tree_node): + other_tree_node = helper.prune_tree(tree_node, "a/c") + _ = node.Node("d", parent=other_tree_node) + for symbol in [".", "-", "+", "~"]: + tree_node.sep = symbol + other_tree_node.sep = symbol + tree_only_diff = helper.get_tree_diff(tree_node, other_tree_node) + expected_str = ( + "a\n" + "├── b (-)\n" + "│ ├── d (-)\n" + "│ └── e (-)\n" + "│ ├── g (-)\n" + "│ └── h (-)\n" + "└── d (+)\n" + ) + assert_print_statement(export.print_tree, expected_str, tree=tree_only_diff) + @staticmethod def test_tree_diff_all_diff(tree_node): other_tree_node = helper.prune_tree(tree_node, "a/c")