@@ -28,6 +28,12 @@ def plot(X,
28
28
'''
29
29
Graphical representation of fitted support vector classifier.
30
30
31
+ There are two types of support vectors:
32
+
33
+ - Points violating the margin but correctly classified. These are marked with a black '+'.
34
+
35
+ - Misclassified points. These are marked with a red 'x'.
36
+
31
37
Parameters
32
38
----------
33
39
@@ -89,7 +95,7 @@ def plot(X,
89
95
90
96
# draw the points
91
97
92
- ax .scatter (X0 , X1 , c = Y , cmap = scatter_cmap )
98
+ ax .scatter (X0 , X1 , c = Y , cmap = scatter_cmap , s = 200 )
93
99
94
100
# add the contour
95
101
@@ -113,8 +119,27 @@ def plot(X,
113
119
cmap = decision_cmap ,
114
120
alpha = alpha )
115
121
116
- # add the support vectors
122
+ decision_val = svm . decision_function ( X_pred )
117
123
118
- ax .scatter (X [svm .support_ ,features [0 ]],
119
- X [svm .support_ ,features [1 ]], marker = '+' , c = 'k' , s = 200 )
124
+ # add the support vectors
120
125
126
+ if svm .classes_ .shape [0 ] == 2 : # 2-class problem
127
+
128
+ ax .contourf (xval ,
129
+ yval ,
130
+ decision_val .reshape (yval .shape ),
131
+ levels = [- 1 ,1 ],
132
+ cmap = decision_cmap ,
133
+ alpha = alpha )
134
+
135
+ D = svm .decision_function (X [svm .support_ ])
136
+ Y_ = (2 * (Y [svm .support_ ] == svm .classes_ [1 ]) - 1 )
137
+ violate_margin = (Y_ * D ) > 0
138
+ ax .scatter (X [svm .support_ ,features [0 ]][violate_margin ],
139
+ X [svm .support_ ,features [1 ]][violate_margin ], marker = '+' , c = 'k' , s = 50 )
140
+ misclassified = ~ violate_margin
141
+ ax .scatter (X [svm .support_ ,features [0 ]][misclassified ],
142
+ X [svm .support_ ,features [1 ]][misclassified ], marker = 'x' , c = 'r' , s = 50 )
143
+ else :
144
+ ax .scatter (X [svm .support_ ,features [0 ]],
145
+ X [svm .support_ ,features [1 ]], marker = '+' , c = 'k' , s = 50 )
0 commit comments