simple_decision_tree.py
'''A class to implement a simple decision tree (based on ID3).'''

from __future__ import print_function, division

__author__ = 'Joe McCarthy'
__email__ = 'joe@interrelativity.com'

from collections import Counter
from pprint import pprint

from simple_ml import majority_value, choose_best_attribute_index, split_instances
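
# Expected behavior of the simple_ml helpers, inferred from how they are used
# below (not from their implementations):
#   majority_value(instances, target_attribute_index)
#       -> the most common class label among the instances
#   choose_best_attribute_index(instances, candidate_indexes, target_index)
#       -> index of the candidate attribute with the highest information gain
#   split_instances(instances, attribute_index)
#       -> dict mapping each value of the attribute to its list of instances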


class SimpleDecisionTree:

    _tree = {}  # class-level default; fit() assigns an instance-level self._tree

    def __init__(self):
        # this is where we would initialize any parameters to the SimpleDecisionTree
        pass
    def fit(self,
            instances,
            candidate_attribute_indexes=None,
            target_attribute_index=0,
            default_class=None,
            trace=0):
        '''
        Build a decision tree that best fits the data in instances.

        The target_attribute_index defaults to 0 (zero).
        The candidate_attribute_indexes defaults to all other index values.
        The tree is constructed by recursively selecting & splitting instances based on
        the highest information gain among the candidate_attribute_indexes.
        The class label is found at target_attribute_index.
        The default_class is the majority value for that branch of the tree.
        A positive trace value will print trace information during tree construction.

        Derived from the simplified ID3 algorithm presented in
        Building Decision Trees in Python by Christopher Roach,
        http://www.onlamp.com/pub/a/python/2006/02/09/ai_decision_trees.html?page=3
        '''
        if candidate_attribute_indexes is None:
            candidate_attribute_indexes = [i
                                           for i in range(len(instances[0]))
                                           if i != target_attribute_index]
        self._tree = self._create_tree(instances,
                                       candidate_attribute_indexes,
                                       target_attribute_index,
                                       default_class,
                                       trace)
    def _create_tree(self,
                     instances,
                     candidate_attribute_indexes,
                     target_attribute_index=0,
                     default_class=None,
                     trace=0):
        class_labels_and_counts = Counter([instance[target_attribute_index]
                                           for instance in instances])
        # If the dataset is empty or the candidate attribute list is empty,
        # return the default class label
        if not instances or not candidate_attribute_indexes:
            if trace:
                print('{}Using default class {}'.format('< ' * trace, default_class))
            return default_class
        # If all the instances have the same class label, return that class label
        elif len(class_labels_and_counts) == 1:
            class_label = class_labels_and_counts.most_common(1)[0][0]
            if trace:
                print('{}All {} instances have label {}'.format(
                    '< ' * trace, len(instances), class_label))
            return class_label
        # Otherwise, create a new subtree and add it to the tree
        else:
            default_class = majority_value(instances, target_attribute_index)
            # Choose the next best attribute index to classify the instances
            best_index = choose_best_attribute_index(instances,
                                                     candidate_attribute_indexes,
                                                     target_attribute_index)
            if trace:
                print('{}Creating tree node for attribute index {}'.format(
                    '> ' * trace, best_index))
            # Create a new decision tree node with the best attribute index
            # and an empty dictionary object (to be filled in below)
            tree = {best_index: {}}
            # Create a new decision tree sub-node (branch)
            # for each of the values in the best attribute field
            partitions = split_instances(instances, best_index)
            # Remove that attribute from the set of candidates for further splits
            remaining_candidate_attribute_indexes = [i
                                                     for i in candidate_attribute_indexes
                                                     if i != best_index]
            for attribute_value in partitions:
                if trace:
                    print('{}Creating subtree for value {} ({}, {}, {}, {})'.format(
                        '> ' * trace,
                        attribute_value,
                        len(partitions[attribute_value]),
                        len(remaining_candidate_attribute_indexes),
                        target_attribute_index,
                        default_class))
                # Create a subtree for each value of the best attribute
                subtree = self._create_tree(
                    partitions[attribute_value],
                    remaining_candidate_attribute_indexes,
                    target_attribute_index,
                    default_class,
                    trace + 1 if trace else 0)
                # Add the new subtree to the dictionary object
                # in the new tree/node created above
                tree[best_index][attribute_value] = subtree
            return tree
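
    # The fitted tree is a nested dict; for example (hypothetical attribute
    # indexes and values, not from an actual dataset):
    #     {1: {'sunny': {2: {'high': 'no', 'normal': 'yes'}},
    #          'rainy': 'yes'}}
    # Each top-level key is an attribute index, each nested key is a value of
    # that attribute, and each leaf is a class label.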
    def predict(self, instances, default_class=None):
        '''Return the predicted class label(s) of instance(s).

        Note: a list argument is interpreted as a list of instances;
        pass a single instance as a tuple (or wrap it in a list).
        '''
        if not isinstance(instances, list):
            return self._predict(self._tree, instances, default_class)
        return [self._predict(self._tree, instance, default_class)
                for instance in instances]

    # a "protected" method implementing the recursive algorithm
    # that classifies a single instance given a tree
    def _predict(self, tree, instance, default_class=None):
        if not tree:
            return default_class
        if not isinstance(tree, dict):
            return tree
        attribute_index = list(tree.keys())[0]  # using list(dict.keys()) for Py3 compatibility
        attribute_values = list(tree.values())[0]
        instance_attribute_value = instance[attribute_index]
        if instance_attribute_value not in attribute_values:
            return default_class
        return self._predict(attribute_values[instance_attribute_value],
                             instance,
                             default_class)
    def classification_accuracy(self, instances, default_class=None):
        '''Return a tuple with
        the proportion of instances that were correctly classified,
        the number of correctly classified instances,
        the number of incorrectly classified instances
        '''
        predicted_labels = self.predict(instances, default_class)
        actual_labels = [x[0] for x in instances]  # assumes the class label is at index 0
        counts = Counter([x == y for x, y in zip(predicted_labels, actual_labels)])
        return counts[True] / len(instances), counts[True], counts[False]

    def pprint(self):
        pprint(self._tree)
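

if __name__ == '__main__':
    # A minimal usage sketch (not part of the original module). It assumes the
    # simple_ml helpers imported above are available on the path. The toy
    # dataset below is hypothetical: the class label is at index 0, with two
    # categorical attributes at indexes 1 and 2.
    training_instances = [
        ['yes', 'sunny', 'normal'],
        ['yes', 'sunny', 'normal'],
        ['no',  'sunny', 'high'],
        ['yes', 'rainy', 'normal'],
        ['no',  'rainy', 'high'],
        ['no',  'rainy', 'high'],
    ]
    tree = SimpleDecisionTree()
    tree.fit(training_instances, trace=1)
    tree.pprint()
    # A list argument is treated as a list of instances, so wrap a
    # single instance in a list (or pass it as a tuple)
    print(tree.predict([['?', 'sunny', 'high']]))
    print(tree.classification_accuracy(training_instances))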