- 决策树信息增益学习
- 记忆python中numpy的相关用法,强化矩阵运算的思想,比数组遍历求解方便很多(比如乘法之类的运算)。
np.where,np.log2
用法的学习
- 回忆练习
np.unique,return_counts,[:,i],np.sum
用法
import numpy as np
def calcInfoGain(feature, label, index):
'''
计算信息增益
:param feature:测试用例中字典里的feature,类型为ndarray
:param label:测试用例中字典里的label,类型为ndarray
:param index:测试用例中字典里的index,即feature部分特征列的索引。
该索引指的是feature中第几个特征,如index:0表示使用第一个特征来计算信息增益。
:return:信息增益,类型float
'''
values = np.unique(feature[:,index])
entropy_parent = entropy(label)
entropy_children = 0
for v in values:
indices = np.where(feature[:, index] == v)[0]
labels_v = label[indices]
entropy_v = entropy(labels_v)
weight_v = len(indices) / len(feature)
entropy_children += weight_v * entropy_v
info_gain = entropy_parent - entropy_children
return info_gain
def entropy(labels):
_, counts = np.unique(labels, return_counts=True)
probs = counts / len(labels)
entropy = -np.sum(probs * np.log2(probs))
return entropy
data={'feature':[[0, 1], [1, 0], [1, 2], [0, 0], [1, 1]], 'label':[0, 1, 0, 0, 1], 'index': 0}
print(calcInfoGain(np.array(data['feature']), np.array(data['label']), np.array(data['index'])))
# 0.419973