Skip to content

Commit 38dbd08

Browse files
Svm plot (#28)
* improved plot for 2-class SVM * added margin as well
1 parent e446479 commit 38dbd08

File tree

1 file changed

+29
-4
lines changed

1 file changed

+29
-4
lines changed

ISLP/svm.py

+29-4
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ def plot(X,
2828
'''
2929
Graphical representation of fitted support vector classifier.
3030
31+
There are two types of support vectors:
32+
33+
- Points violating the margin but correctly classified. These are marked with a black '+'.
34+
35+
- Misclassified points. These are marked with a red 'x'.
36+
3137
Parameters
3238
----------
3339
@@ -89,7 +95,7 @@ def plot(X,
8995

9096
# draw the points
9197

92-
ax.scatter(X0, X1, c=Y, cmap=scatter_cmap)
98+
ax.scatter(X0, X1, c=Y, cmap=scatter_cmap, s=200)
9399

94100
# add the contour
95101

@@ -113,8 +119,27 @@ def plot(X,
113119
cmap=decision_cmap,
114120
alpha=alpha)
115121

116-
# add the support vectors
122+
decision_val = svm.decision_function(X_pred)
117123

118-
ax.scatter(X[svm.support_,features[0]],
119-
X[svm.support_,features[1]], marker='+', c='k', s=200)
124+
# add the support vectors
120125

126+
if svm.classes_.shape[0] == 2: # 2-class problem
127+
128+
ax.contourf(xval,
129+
yval,
130+
decision_val.reshape(yval.shape),
131+
levels=[-1,1],
132+
cmap=decision_cmap,
133+
alpha=alpha)
134+
135+
D = svm.decision_function(X[svm.support_])
136+
Y_ = (2 * (Y[svm.support_] == svm.classes_[1]) - 1)
137+
violate_margin = (Y_ * D) > 0
138+
ax.scatter(X[svm.support_,features[0]][violate_margin],
139+
X[svm.support_,features[1]][violate_margin], marker='+', c='k', s=50)
140+
misclassified = ~violate_margin
141+
ax.scatter(X[svm.support_,features[0]][misclassified],
142+
X[svm.support_,features[1]][misclassified], marker='x', c='r', s=50)
143+
else:
144+
ax.scatter(X[svm.support_,features[0]],
145+
X[svm.support_,features[1]], marker='+', c='k', s=50)

0 commit comments

Comments
 (0)