Skip to content

Latest commit

 

History

History
41 lines (40 loc) · 1.63 KB

2023-4-20.md

File metadata and controls

41 lines (40 loc) · 1.63 KB

Note

  • 决策树信息增益学习
  • 记忆python中numpy的相关用法,强化矩阵运算的思想,比数组遍历求解方便很多(比如乘法之类的运算)。
  • np.where,np.log2用法的学习
  • 回忆练习np.unique,return_counts,[:,i],np.sum用法

Code

决策树求信息增益

import numpy as np

def calcInfoGain(feature, label, index):
    '''
    计算信息增益
    :param feature:测试用例中字典里的feature,类型为ndarray
    :param label:测试用例中字典里的label,类型为ndarray
    :param index:测试用例中字典里的index,即feature部分特征列的索引。
    该索引指的是feature中第几个特征,如index:0表示使用第一个特征来计算信息增益。
    :return:信息增益,类型float
    '''
    values = np.unique(feature[:,index])
    entropy_parent = entropy(label)
    entropy_children = 0
    for v in values:
        indices = np.where(feature[:, index] == v)[0]
        labels_v = label[indices]
        entropy_v = entropy(labels_v)
        weight_v = len(indices) / len(feature)
        entropy_children += weight_v * entropy_v
    info_gain = entropy_parent - entropy_children
    return info_gain
    
def entropy(labels):
    _, counts = np.unique(labels, return_counts=True)
    probs = counts / len(labels)
    entropy = -np.sum(probs * np.log2(probs))
    return entropy
    
data={'feature':[[0, 1], [1, 0], [1, 2], [0, 0], [1, 1]], 'label':[0, 1, 0, 0, 1], 'index': 0}
print(calcInfoGain(np.array(data['feature']), np.array(data['label']), np.array(data['index'])))
# 0.419973