-
Notifications
You must be signed in to change notification settings - Fork 0
/
tree.py
182 lines (133 loc) · 5.25 KB
/
tree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import jsons
import numpy as np
from treelib import Node, Tree
class Node:
    """A binary-tree node: a value, optional left/right children, and a depth."""

    def __init__(self, value=None, left=None, right=None, depth=0):
        """Store the payload, both child links and the node's depth."""
        self.value, self.left, self.right = value, left, right
        self.depth = depth

    def is_leaf(self):
        """Return True when the node has neither a left nor a right child."""
        has_child = self.right is not None or self.left is not None
        return not has_child

    @staticmethod
    def from_dict(dict, depth=0):
        """Recursively rebuild a node (and its subtree) from a nested mapping.

        Args:
            dict: Mapping with the keys ``'value'``, ``'left'`` and
                ``'right'``, or ``None`` for a missing child.
            depth: Depth assigned to the node being built; each level of
                children gets ``depth + 1``.

        Returns:
            The reconstructed ``Node``, or ``None`` when ``dict`` is ``None``.
        """
        if dict is None:
            return None

        # Keys are read in the order value -> right -> left.
        payload = dict['value']
        below = depth + 1
        right_child = Node.from_dict(dict['right'], depth=below)
        left_child = Node.from_dict(dict['left'], depth=below)
        return Node(payload, left_child, right_child, depth)
class BinaryDecisionTreeClassifier:
    """Binary decision tree over features encoded as ``-1`` / ``1``.

    Internal nodes store the name of the feature they split on; samples whose
    feature equals ``-1`` go to the left child and samples whose feature
    equals ``1`` go to the right child. Leaves store a label taken from ``y``.
    """

    def __init__(self, filename=None):
        """Create an empty classifier, optionally loading a tree from a file.

        Args:
            filename: Path of a JSON file to pass to :meth:`load`, or ``None``
                to start with no tree.
        """
        self.root = None
        if filename is not None:
            self.load(filename)

    def fit(self, X, y):
        """Build the tree from a feature table and its labels.

        Args:
            X: Feature table providing ``columns``, ``shape``, ``drop`` and
                ``iloc`` (used as a pandas DataFrame by the helpers).
            y: Labels, one per row of ``X``. Any array-like is accepted.
        """
        # The build indexes labels positionally (``y[0]``, ``y[index_array]``),
        # so normalise to an ndarray once; an ndarray input is passed through.
        self.root = self._build_tree(X, np.asarray(y))

    def predict(self, X):
        """Return an array with the value of the leaf reached by each row of X."""
        return np.array(
            [self._traverse_tree(X.iloc[i], self.root) for i in range(len(X))]
        )

    def print_tree(self):
        """Render the tree with treelib; do nothing when there is no tree."""
        if self.root is None:
            return
        treelibTree = Tree()
        BinaryDecisionTreeClassifier.createTreelibTree(treelibTree, self.root)
        treelibTree.show()

    @staticmethod
    def createTreelibTree(tree, node, parent=None):
        """Copy ``node`` and its descendants into a treelib ``tree``.

        Each node is registered with its value as the tag and the node object
        itself as the identifier, so children can name their parent by object.

        Args:
            tree: treelib tree being populated.
            node: Node to add (must not be ``None``).
            parent: Already-added parent node, or ``None`` for the root.
        """
        tree.create_node(node.value, node, parent=parent)
        for child in (node.left, node.right):
            if child is not None:
                BinaryDecisionTreeClassifier.createTreelibTree(tree, child, node)

    def _drop_features_with_same_values(self, X):
        """Return X without the columns that hold exactly one distinct value."""
        constant = [col for col in X.columns if len(set(X[col])) == 1]
        return X.drop(columns=constant) if constant else X

    def _build_tree(self, X, y, depth=0):
        """Recursively build the subtree for the samples ``(X, y)``.

        Args:
            X: Feature table for the samples of this subtree.
            y: Labels of those samples, indexable by position and by an
                integer index array.
            depth: Depth recorded on the node created by this call.

        Returns:
            The root ``Node`` of the subtree.
        """
        X = self._drop_features_with_same_values(X)
        n_samples, n_features = X.shape

        # No feature left to tell the remaining samples apart: chain them.
        # (A single remaining label yields a plain leaf inside the helper.)
        if n_features == 0 and len(y) > 0:
            return BinaryDecisionTreeClassifier.create_subtree_from_names_list(
                y, depth=depth)
        if len(y) == 0:
            # Empty placeholder leaf; keep its real depth in the tree.
            return Node(depth=depth)
        if n_samples == 1:
            return Node(y[0], depth=depth)

        best_feature = self._choose_split_feature(X)
        X_left, y_left, X_right, y_right = self._split_data(X, y, best_feature)
        left_subtree = self._build_tree(X_left, y_left, depth + 1)
        right_subtree = self._build_tree(X_right, y_right, depth + 1)
        return Node(best_feature, left_subtree, right_subtree, depth=depth)

    @staticmethod
    def create_subtree_from_names_list(names_list, depth=0):
        """Build a chain of yes/no question nodes, one per name.

        For more than one name, the node asks about the first name; its right
        child is the leaf for that name and its left child is the chain built
        from the remaining names.

        Args:
            names_list: Sequence of labels supporting ``len``, indexing and
                slicing.
            depth: Depth recorded on the node created by this call.

        Returns:
            An empty ``Node`` for an empty list, a leaf for a single name,
            otherwise the question node heading the chain.
        """
        if len(names_list) == 0:
            return Node(depth=depth)
        if len(names_list) == 1:
            return Node(names_list[0], depth=depth)
        left_subtree = BinaryDecisionTreeClassifier.create_subtree_from_names_list(
            names_list[1:], depth=depth + 1)
        right_subtree = BinaryDecisionTreeClassifier.create_subtree_from_names_list(
            names_list[:1], depth=depth + 1)
        return Node(f"É {names_list[0]}?", left_subtree, right_subtree, depth=depth)

    def _choose_split_feature(self, X):
        """Return the column whose values have the smallest absolute sum.

        With -1/1 encoded features this is the column closest to an even
        split. Ties keep the earliest column; ``None`` is returned when X has
        no columns.
        """
        best_feature = None
        best_score = float('inf')
        for feature in X.columns:
            score = abs(sum(X[feature].values))
            if score < best_score:
                best_feature, best_score = feature, score
        return best_feature

    def _split_data(self, X, y, best_feature):
        """Partition the samples on ``best_feature``.

        Rows where the feature equals -1 go left and rows where it equals 1
        go right; rows holding any other value end up in neither side. The
        split column is removed from both returned tables.

        Returns:
            Tuple ``(X_left, y_left, X_right, y_right)``.
        """
        left_idx, = np.where(X[best_feature] == -1)
        right_idx, = np.where(X[best_feature] == 1)
        remaining = X.drop(columns=best_feature)
        return (
            remaining.iloc[left_idx],
            y[left_idx],
            remaining.iloc[right_idx],
            y[right_idx],
        )

    def _traverse_tree(self, x, node):
        """Walk down from ``node`` following sample ``x``; return the leaf value.

        At each internal node the feature named by ``node.value`` is read
        from ``x``: -1 goes left, anything else goes right.
        """
        if node.is_leaf():
            return node.value
        if x[node.value] == -1:
            return self._traverse_tree(x, node.left)
        return self._traverse_tree(x, node.right)

    def load(self, filename='tree.json'):
        """Replace the current tree with the one stored in a JSON file.

        The file must contain an object with a ``'root'`` key holding the
        nested node mapping understood by ``Node.from_dict``.
        """
        with open(filename, 'r', encoding='utf-8') as f:
            data = jsons.loads(f.read())
        self.root = Node.from_dict(data['root'])

    def dump(self, filename='tree.json'):
        """Serialize this classifier with ``jsons`` and write it to a file."""
        payload = jsons.dumps(self)
        with open(filename, 'w', encoding='utf-8') as f:
            f.write(str(payload))

    def drop(self):
        """Discard the current tree."""
        self.root = None

    @staticmethod
    def get_max_depth_from_node(node):
        """Return the number of levels in the subtree rooted at ``node``.

        A missing child (``None``) or an empty node (value ``None``) counts
        as 0 levels; a leaf counts as 1.
        """
        if node is None or node.value is None:
            return 0
        if node.is_leaf():
            return 1
        return 1 + max(
            BinaryDecisionTreeClassifier.get_max_depth_from_node(node.left),
            BinaryDecisionTreeClassifier.get_max_depth_from_node(node.right),
        )

    @staticmethod
    def get_deepest_subtree(node):
        """Pick the deeper of ``node``'s two subtrees.

        Returns:
            ``node`` itself when it is a leaf; otherwise a tuple
            ``('left', node.left)`` or ``('right', node.right)``. The right
            side wins when both subtrees are equally deep.
        """
        if node.is_leaf():
            return node
        left = BinaryDecisionTreeClassifier.get_max_depth_from_node(node.left)
        right = BinaryDecisionTreeClassifier.get_max_depth_from_node(node.right)
        if left > right:
            return 'left', node.left
        return 'right', node.right