diff --git a/tree_diff/tree_ruleset_conversion.py b/tree_diff/tree_ruleset_conversion.py index aea8cdf..d3e7525 100644 --- a/tree_diff/tree_ruleset_conversion.py +++ b/tree_diff/tree_ruleset_conversion.py @@ -3,11 +3,13 @@ import river from collections import deque -visit = {} expected = {} def find_to_condition(visited, nodes): + """ + This function finds the conditions associated with a given node in a tree and stores them in a dictionary." + """ if nodes.is_root(): return None else: @@ -17,7 +19,6 @@ def find_to_condition(visited, nodes): index = i if index < 0: raise ValueError("Incorrect tree") - # print(nodes.label) if len(nodes.children) == 0: visited["{}:{}".format(node.parent, nodes)] = [ nodes.parent.conditions[index], @@ -31,6 +32,7 @@ def find_to_condition(visited, nodes): def traverse(tree, visited): + visit = {} for child in tree: visit = find_to_condition(visited, child) if len(child.children) > 0: @@ -39,6 +41,12 @@ def traverse(tree, visited): def link_dict_keys(d): + """ + This function takes the tree nodes and it's children to create a link between all of it's children. + It combines the conditions of the parent, child, and grandchild nodes into a single value. + Returns: A dictionary where keys are in the format "parent:child:grandchild" and + values are the sum of the conditions of the parent, child, and grandchild nodes. + """ linked_dict = {} result = {} for key, value in d.items(): @@ -66,9 +74,16 @@ def link_dict_keys(d): def tuple_tree_conversion(tree): visited = {} ruleset = [] - expected = link_dict_keys(traverse(tree.children, visited)) - for val in expected.values(): - ruleset.append(Rule(val[-1], val[0:-1])) + if len(tree.children) == 0: # Check if it's a root node + attr_name = 'root_node_tree' + visited["root"] = [f"{attr_name} <= 0", tree.label] + antecedent = visited["root"][0] + ruleset.append(Rule(visited['root'][1],[f"{antecedent}"])) # Force root node to have a (antecedent) Rule and label + return Ruleset(ruleset) + else: + expected = link_dict_keys(traverse(tree.children, visited)) # Traverse for each children and find linkage + for val in expected.values(): + ruleset.append(Rule(val[-1], val[0:-1])) return Ruleset(ruleset) @@ -108,29 +123,75 @@ def river_is_leaf(node): return node.n_leaves == 1 -def river_return_condition(node): +def river_return_condition(node,path,val_sum): if isinstance(node, river.tree.nodes.efdtc_nodes.EFDTNumericBinaryBranch): - return Condition(f"attr_{node.feature}", Operator.LE, node.threshold) + weight_value = {} + for elements in range(len(path)): + current = [] + for key, value in path[elements].stats.items(): + current.append(value) + weight_value[elements] = sum(current) # Store the current path status + all_values = list(weight_value.values()) + all_values.append(val_sum) + left,right = node.children + if left.total_weight in all_values: # Examine if either of children's weight matches with parent weight + operator = Operator.LE + elif right.total_weight in all_values: + operator = Operator.GT + return Condition(f"attr_{node.feature}", operator, node.threshold) + elif isinstance(node, tuple): # Multinomial feature = node[0].feature threshold = node[0]._r_mapping[node[1]] return Condition(f"attr_{feature}", Operator.EQ, threshold) + + elif isinstance(node, river.tree.nodes.efdtc_nodes.NumericBinaryBranch): + weight_value = {} + for elements in range(len(path)): + current = [] + keys_index = [] + for key, value in path[elements].stats.items(): + current.append(value) + keys_index.append(key) + weight_value[elements] = sum(current) # Store the current path status + all_values = list(weight_value.values()) + all_values.append(val_sum) + left,right = node.children + if left.total_weight in all_values: # Examine if either of children's weight matches with parent weight + operator = Operator.LE + elif right.total_weight in all_values: + operator = Operator.GT + else: + for elements in range(len(path)): + if left == path[elements]: + operator = Operator.LE + elif right == path[elements]: + operator = Operator.GT + return Condition(f"attr_{node.feature}", operator, node.threshold) else: raise ValueError(node) -def river_create_conditions(path_conds): - return [river_return_condition(c) for c in path_conds] +def river_create_conditions(path_conds,val_sum): + return [river_return_condition(c, path_conds, val_sum) for i, c in enumerate(path_conds)] def river_create_rule(path): a = path[-1].stats + weight = [] + for key,values in a.items(): # Storing the original weight of parent node + weight.append(values) + val_sum = sum(weight) m = (None, 0) for k, v in a.items(): if not m or m[1] < v: m = (k, v) label = m[0] - return Rule(conditions=river_create_conditions(path[0:-1]), label=f"{label}") + if len(path) == 1: # Check if the node is a root node + random_attr = "rand" + return Rule(conditions=[Condition(f"attr_{random_attr}", Operator.LE, 0)], label=f"{label}") + else: + return Rule(conditions=river_create_conditions(path[0:-1],val_sum), label=f"{label}") def river_extract_rules(tree, children, is_leaf):