-
Notifications
You must be signed in to change notification settings - Fork 0
/
ID3.py
115 lines (104 loc) · 3.33 KB
/
ID3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from math import log
from readData import *
import logging
import sys
def label(feature):
    """Return the class label of a data row.

    The label is the element stored at the row's final index, i.e.
    ``feature[len(feature) - 1]``.
    """
    last_index = len(feature) - 1
    return feature[last_index]
def elog(num):
    """Return ``log(num)`` (natural log), or 0 when ``num`` is outside log's domain.

    Used for the entropy term ``p * log(p)``: ``math.log`` raises
    ``ValueError`` for ``num <= 0``, and returning 0 there implements the
    usual ``0 * log(0) == 0`` convention.  Only ``ValueError`` is absorbed;
    other errors (e.g. ``TypeError`` for a non-numeric argument, or
    ``KeyboardInterrupt``) are not caught.
    """
    try:
        return log(num)
    except ValueError:
        # math.log rejects zero and negative inputs with ValueError.
        return 0
def cond_entropy(feature_dist):
    """Return the Shannon entropy (natural log) of a label-count distribution.

    Args:
        feature_dist: mapping from class label to the number of rows carrying
            that label.

    Returns:
        ``-sum(p * ln(p))`` over the proportions ``p = count / total``.  When
        the counts sum to 0 the function returns ``sys.float_info.max``,
        presumably so that an empty partition scores as the worst possible
        split.
    """
    # .values() works on both Python 2 and 3 (itervalues is Python-2-only).
    counts = list(feature_dist.values())
    total = sum(counts)
    if total == 0:
        return sys.float_info.max
    condent = 0
    for count in counts:
        probability = float(count) / total
        # Treat 0 * log(0) as 0; log is undefined for p <= 0.
        if probability > 0:
            condent -= probability * log(probability)
    return condent
def find_split(matrix):
    """Find the threshold split of ``matrix`` with the lowest weighted entropy.

    Each row is a sequence whose last element is its class label (see
    ``label``); every other column is treated as an orderable feature.  For
    each feature the rows are sorted by that feature, and every boundary
    between two distinct consecutive values is scored as

        (n_left / n) * H(left) + (n_right / n) * H(right)

    where ``H`` is ``cond_entropy`` of the partition's label counts.

    Args:
        matrix: list of rows (label in the last column).

    Returns:
        ``(feature_index, threshold)``.  Rows with
        ``row[feature_index] <= threshold`` form exactly the left partition
        that was scored; the remaining rows form the right partition.  When no
        split exists (empty input, or no feature takes two distinct values)
        the defaults ``(0, 0.0)`` are returned and the caller must detect the
        degenerate partition.
    """
    min_ent = sys.float_info.max
    split_feat = 0
    split_thresh = float(0)
    point_count = len(matrix)
    if point_count == 0:
        return (split_feat, split_thresh)

    # Label counts over the whole set; each feature scan starts from a copy.
    total_dist = {}
    for row in matrix:
        row_label = label(row)
        total_dist[row_label] = total_dist.get(row_label, 0) + 1

    # The last column is the label, so the features are columns 0 .. len-2.
    for feat in range(len(matrix[0]) - 1):
        logging.debug('analyzing feature %d', feat)
        # Fresh distributions per feature: left starts empty, right holds all rows.
        feat_dist = {}
        rem_dist = dict(total_dist)
        ordered = sorted(matrix, key=lambda m_entry: m_entry[feat])
        prev = None
        for count, point in enumerate(ordered):
            curr = point[feat]
            # ordered[:count] is the left partition; only cut between distinct values.
            if count and curr != prev:
                left_weight = float(count) / point_count
                right_weight = float(point_count - count) / point_count
                ent = (left_weight * cond_entropy(feat_dist) +
                       right_weight * cond_entropy(rem_dist))
                if ent < min_ent:
                    logging.debug('new best split: feat=%d thresh=%r ent=%r',
                                  feat, prev, ent)
                    min_ent = ent
                    split_feat = feat
                    # prev is the largest value on the left, so the caller's
                    # "<= threshold" test reproduces this exact partition.
                    split_thresh = prev
            point_label = label(point)
            feat_dist[point_label] = feat_dist.get(point_label, 0) + 1
            rem_dist[point_label] -= 1
            prev = curr
    return (split_feat, split_thresh)
def test_purity(pointset):
    """Check whether every row of ``pointset`` carries the same class label.

    Labels are compared between consecutive rows, starting from the first
    row.  Returns ``(False, None)`` at the first pair of neighbours whose
    labels differ; otherwise returns ``(True, lbl)`` where ``lbl`` is the
    label of the last row.  An empty ``pointset`` raises ``IndexError``
    because ``pointset[0]`` is read unconditionally.
    """
    previous = label(pointset[0])
    # Lazy generator keeps the original's early exit and evaluation order.
    for current in (label(row) for row in pointset):
        if current != previous:
            return (False, None)
        previous = current
    return (True, previous)
class Tree:
    """One node of the binary tree built by ``make_decision_tree``.

    A fresh node has feature 0, threshold 0, ``isLeaf`` False, label -1 and
    no children.
    """

    def __init__(self):
        # Split parameters: rows with row[feature] <= threshold go left.
        self.threshold, self.feature = 0, 0
        # Leaf marker and class label (-1 until a label is assigned).
        self.isLeaf, self.label = False, -1
        # Child subtrees; None until attached.
        self.left = self.right = None
def _majority_label(pointset):
    """Return the most frequent label among the rows of a non-empty ``pointset``."""
    counts = {}
    for row in pointset:
        row_label = label(row)
        counts[row_label] = counts.get(row_label, 0) + 1
    return max(counts, key=counts.get)


def make_decision_tree(pointset):
    """Recursively build a binary decision tree over ``pointset`` and return its root.

    Each row's class label is its last element (see ``label``).  A node whose
    rows all share one label becomes a leaf carrying that label.  Otherwise
    the split chosen by ``find_split`` is stored on the node (``feature``,
    ``threshold``). Rows with ``row[feature] <= threshold`` build the left
    subtree and the rest build the right subtree.

    If the chosen split leaves one side empty (for example, identical feature
    vectors with different labels), recursing would never terminate.  The node
    then becomes a leaf labelled with the most common class.

    Args:
        pointset: non-empty list of rows; an empty list is not supported.

    Returns:
        The root ``Tree`` node.
    """
    (isClean, nodeLabel) = test_purity(pointset)
    node = Tree()
    if isClean:
        node.isLeaf = True
        node.label = nodeLabel
        logging.debug('made leaf!')
        return node

    (feature, threshold) = find_split(pointset)
    left_branch = [row for row in pointset if row[feature] <= threshold]
    right_branch = [row for row in pointset if row[feature] > threshold]
    if not left_branch or not right_branch:
        # Degenerate split: stop here instead of recursing on the same rows.
        node.isLeaf = True
        node.label = _majority_label(pointset)
        logging.debug('made majority leaf!')
        return node

    node.feature = feature
    node.threshold = threshold
    node.left = make_decision_tree(left_branch)
    node.right = make_decision_tree(right_branch)
    return node
if __name__ == "__main__":
    # Send debug output to stderr for command-line runs.
    logging.basicConfig(stream=sys.stderr, level=logging.DEBUG)
    if len(sys.argv) < 2:
        sys.stderr.write('usage: %s DATA_FILE\n' % sys.argv[0])
        sys.exit(2)
    data = read_to_list(sys.argv[1])
    root = make_decision_tree(data)