-
Notifications
You must be signed in to change notification settings - Fork 447
/
Copy path8.5-StumpBagging.py
102 lines (68 loc) · 2.82 KB
/
8.5-StumpBagging.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sklearn.utils import resample
def stumpClassify(X, dim, thresh_val, thresh_inequal):
ret_array = np.ones((X.shape[0], 1))
if thresh_inequal == 'lt':
ret_array[X[:, dim] <= thresh_val] = -1
else:
ret_array[X[:, dim] > thresh_val] = -1
return ret_array
def buildStump(X, y):
m, n = X.shape
best_stump = {}
min_error = 1
for dim in range(n):
x_min = np.min(X[:, dim])
x_max = np.max(X[:, dim])
# 这里第一次尝试使用排序后的点作为分割点,效果很差,因为那样会错过一些更好的分割点;
# 所以后来切割点改成将最大值和最小值之间分割成20等份。
# sorted_x = np.sort(X[:, dim])
# split_points = [(sorted_x[i] + sorted_x[i + 1]) / 2 for i in range(m - 1)]
split_points = [(x_max - x_min) / 20 * i + x_min for i in range(20)]
for inequal in ['lt', 'gt']:
for thresh_val in split_points:
ret_array = stumpClassify(X, dim, thresh_val, inequal)
error = np.mean(ret_array != y)
if error < min_error:
best_stump['dim'] = dim
best_stump['thresh'] = thresh_val
best_stump['inequal'] = inequal
best_stump['error'] = error
min_error = error
return best_stump
def stumpBagging(X, y, nums=20):
stumps = []
seed = 16
for _ in range(nums):
X_, y_ = resample(X, y, random_state=seed) # sklearn 中自带的实现自助采样的方法
seed += 1
stumps.append(buildStump(X_, y_))
return stumps
def stumpPredict(X, stumps):
ret_arrays = np.ones((X.shape[0], len(stumps)))
for i, stump in enumerate(stumps):
ret_arrays[:, [i]] = stumpClassify(X, stump['dim'], stump['thresh'], stump['inequal'])
return np.sign(np.sum(ret_arrays, axis=1))
def pltStumpBaggingDecisionBound(X_, y_, stumps):
pos = y_ == 1
neg = y_ == -1
x_tmp = np.linspace(0, 1, 600)
y_tmp = np.linspace(-0.1, 0.7, 600)
X_tmp, Y_tmp = np.meshgrid(x_tmp, y_tmp)
Z_ = stumpPredict(np.c_[X_tmp.ravel(), Y_tmp.ravel()], stumps).reshape(X_tmp.shape)
plt.contour(X_tmp, Y_tmp, Z_, [0], colors='orange', linewidths=1)
plt.scatter(X_[pos, 0], X_[pos, 1], label='1', color='c')
plt.scatter(X_[neg, 0], X_[neg, 1], label='0', color='lightcoral')
plt.legend()
plt.show()
if __name__ == "__main__":
data_path = r'..\data\watermelon3_0a_Ch.txt'
data = pd.read_table(data_path, delimiter=' ')
X = data.iloc[:, :2].values
y = data.iloc[:, 2].values
y[y == 0] = -1
stumps = stumpBagging(X, y, 21)
print(np.mean(stumpPredict(X, stumps) == y))
pltStumpBaggingDecisionBound(X, y, stumps)