Sensitivity svm (#201)
* sensitivity reg

* local model, Plots explicit

* direct assign

* block separation

* fix block

* fix computations

* comments

* improved SVM
matbesancon authored Mar 21, 2022
1 parent ba669f5 commit 79869d8
49 changes: 25 additions & 24 deletions docs/src/examples/sensitivity-analysis-svm.jl
@@ -4,7 +4,7 @@

# This notebook illustrates sensitivity analysis of data points in a [Support Vector Machine](https://en.wikipedia.org/wiki/Support-vector_machine) (inspired by [@matbesancon](http://github.com/matbesancon)'s [SimpleSVMs](http://github.com/matbesancon/SimpleSVMs.jl).)

- # For reference, Section 10.1 of https://online.stat.psu.edu/stat508/book/export/html/792 gives an intuitive explanation of what it means to have a sensitive hyperplane or data point. The general form of the SVM training problem is given below (without regularization):
+ # For reference, Section 10.1 of https://online.stat.psu.edu/stat508/book/export/html/792 gives an intuitive explanation of what it means to have a sensitive hyperplane or data point. The general form of the SVM training problem is given below (with $\ell_2$ regularization):

# ```math
# \begin{split}
@@ -19,25 +19,27 @@
# - `X`, `y` are the `N` data points
# - `w` is the support vector
# - `b` determines the offset `b/||w||` of the hyperplane with normal `w`
- # - `ξ` is the soft-margin loss.

+ # - `ξ` is the soft-margin loss
+ # - `λ` is the $\ell_2$ regularization parameter.
+ #
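Consistent with the variables listed above, the standard soft-margin form with $\ell_2$ regularization reads as follows (a reconstruction; the file's exact `math` block sits in the collapsed part of the diff):

```math
\begin{split}
\min_{w, b, \xi} \quad & \sum_{i=1}^{N} \xi_i + \lambda \lVert w \rVert_2^2 \\
\text{s.t.} \quad & y_i (w^\top X_i + b) \ge 1 - \xi_i, \quad i = 1, \ldots, N \\
& \xi_i \ge 0, \quad i = 1, \ldots, N
\end{split}
```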
# This tutorial uses the following packages

using JuMP # The mathematical programming modelling language
import DiffOpt # JuMP extension for differentiable optimization
import Ipopt # Optimization solver that handles quadratic programs
import Plots # Graphing tool
- import LinearAlgebra: dot, norm, normalize!
+ import LinearAlgebra: dot, norm
import Random

# ## Define and solve the SVM

- # Construct separable, non-trivial data points.
+ # Construct two clusters of data points.

N = 100
D = 2

Random.seed!(62)
- X = vcat(randn(N ÷ 2, D), randn(N ÷ 2, D) .+ [4.5, 2.0]')
+ X = vcat(randn(N ÷ 2, D), randn(N ÷ 2, D) .+ [2.0, 2.0]')
y = append!(ones(N ÷ 2), -ones(N ÷ 2))
λ = 0.05;

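The model construction and solve fall inside the collapsed hunk that follows. A minimal sketch, assuming the `DiffOpt.diff_optimizer` wrapper of this era's DiffOpt 0.x API and the identifiers referenced later in the file (`model`, `cons`, `w`, `b`, `loss`), might look like:

```julia
# Sketch only: the committed file's exact code is collapsed in this diff.
# Assumes X, y, λ, N, D from the setup above and the packages already imported.
model = JuMP.Model(() -> DiffOpt.diff_optimizer(Ipopt.Optimizer))
set_silent(model)

@variable(model, w[1:D])       # hyperplane normal
@variable(model, b)            # hyperplane offset
@variable(model, ξ[1:N] >= 0)  # soft-margin slack, one per data point

# One margin constraint per data point; `cons` is what gets perturbed later.
@constraint(model, cons[j in 1:N], y[j] * (dot(X[j, :], w) + b) >= 1 - ξ[j])

@objective(model, Min, sum(ξ) + λ * dot(w, w))
optimize!(model)
loss = objective_value(model)
```

Wrapping Ipopt in `DiffOpt.diff_optimizer` is what later allows `MOI.set(model, DiffOpt.ForwardInConstraint(), ...)` and `DiffOpt.forward(model)` to run on the same model.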
@@ -86,11 +88,10 @@ wv = value.(w)

bv = value(b)

- svm_x = [0.0, 5.0] # arbitrary points
+ svm_x = [-2.0, 4.0] # arbitrary points
svm_y = (-bv .- wv[1] * svm_x )/wv[2]
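For reference, `svm_y` solves the separator equation $w^\top x + b = 0$ for the second coordinate at each of the two chosen first coordinates:

```math
x_2 = \frac{-b - w_1 x_1}{w_2}
```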

p = Plots.scatter(X[:,1], X[:,2], color = [yi > 0 ? :red : :blue for yi in y], label = "")
- Plots.yaxis!(p, (-2, 4.5))
Plots.plot!(p, svm_x, svm_y, label = "loss = $(round(loss, digits=2))", width=3)

# ## Gradient of hyperplane wrt the data point coordinates
@@ -101,25 +102,27 @@ Plots.plot!(p, svm_x, svm_y, label = "loss = $(round(loss, digits=2))", width=3)

# How does a change in the coordinates of the data points, `X`,
# affect the position of the hyperplane?
- # This is achieved by finding gradients of `w`, `b` with respect to `X[i]`,
- # 2D coordinates of the data points.
+ # This is achieved by finding gradients of `w` and `b` with respect to `X[i]`.

# Begin differentiating the model.
+ # This is analogous to varying θ in the expression:
+ # ```math
+ # y_{i} (w^T (X_{i} + \theta) + b) \ge 1 - \xi_{i}
+ # ```
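This explains the `y[j] * sum(w)` set in the loop below: a uniform perturbation θ applied to every coordinate of $X_i$ enters constraint $i$ with derivative

```math
\frac{\partial}{\partial \theta}\; y_i \left( w^\top (X_i + \theta \mathbf{1}) + b \right) = y_i \sum_{k=1}^{D} w_k,
```

so constraint `i` receives `y[i] * sum(w)` while every other constraint gets a zero perturbation.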
∇ = zeros(N)
- dX = zeros(N, D);
for i in 1:N
- dX[i, :] = ones(D) # set
for j in 1:N
- MOI.set(
- model,
- DiffOpt.ForwardInConstraint(),
- cons[j],
- y[j] * dot(dX[j,:], index.(w)),
- )
+ if i == j
+ ## we consider identical perturbations on all x_i coordinates
+ MOI.set(
+ model,
+ DiffOpt.ForwardInConstraint(),
+ cons[j],
+ y[j] * sum(w),
+ )
+ else
+ MOI.set(model, DiffOpt.ForwardInConstraint(), cons[j], 0.0 * sum(w))
+ end
end
DiffOpt.forward(model)
dw = MOI.get.(
@@ -133,19 +136,17 @@ for i in 1:N
b,
)
∇[i] = norm(dw) + norm(db)
- dX[i, :] = zeros(D) # reset the change made at the beginning of the loop
end
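The lines that read the sensitivities back are partially collapsed above (`dw = MOI.get.(` … `b,` `)`). A hedged sketch of that step, assuming DiffOpt's forward-mode output attribute `ForwardOutVariablePrimal` from the same DiffOpt 0.x API:

```julia
# Sketch of the collapsed read-back, inside the loop after DiffOpt.forward(model):
dw = MOI.get.(model, DiffOpt.ForwardOutVariablePrimal(), w)  # sensitivity of w to this perturbation
db = MOI.get(model, DiffOpt.ForwardOutVariablePrimal(), b)   # sensitivity of b
```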

- normalize!(∇);

# We can visualize the separating hyperplane sensitivity with respect to the data points.
- # Note that the norm of the gradients are normalized and all the small numbers
- # were converted into 1/10 of the largest value to show all the points of the set.
+ # Note that all the small numbers were converted into 1/10 of the
+ # largest value to show all the points of the set.

p3 = Plots.scatter(
X[:,1], X[:,2],
color = [yi > 0 ? :red : :blue for yi in y], label = "",
- markersize = 20 * max.(∇, 0.1 * maximum(∇)),
+ markersize = 2 * (max.(1.8∇, 0.2 * maximum(∇))),
)
Plots.yaxis!(p3, (-2, 4.5))
Plots.plot!(p3, svm_x, svm_y, label = "", width=3)
Plots.title!("Sensitivity of the separator to data point variations")
