Skip to content

Commit

Permalink
First attempt at adding a sparsity constraint and weight decay.
Browse files Browse the repository at this point in the history
  • Loading branch information
rofinn committed Jan 11, 2016
1 parent eb05cb0 commit f5fad51
Showing 1 changed file with 53 additions and 6 deletions.
59 changes: 53 additions & 6 deletions src/rbm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,24 @@ abstract AbstractRBM{V,H}
dW_prev::Matrix{Float64}
persistent_chain::Matrix{Float64}
momentum::Float64
decay_rate::Float64
sparsity_cost::Float64
sparsity_exp::Float64
end


# Outer constructor for `RBM{V,H}`.
#
# Positional arguments:
#   V, H   - types used as the two type parameters of `RBM{V,H}`
#   n_vis  - used as the length of the first zero vector and as the column
#            count of the weight matrix
#   n_hid  - used as the length of the second zero vector and as the row
#            count of the weight matrix
#
# Keyword arguments:
#   sigma         - second argument of `Normal(0, sigma)`, from which the
#                   (n_hid, n_vis) initial weights are sampled
#   momentum, decay_rate, sparsity_cost, sparsity_exp
#                 - stored unchanged as the trailing arguments of `RBM{V,H}(...)`
#
# The arguments after the two zero vectors line up, in order, with the
# fields `dW_prev`, `persistent_chain`, `momentum`, `decay_rate`,
# `sparsity_cost`, `sparsity_exp` listed at the end of the type definition:
# `dW_prev` starts as zeros of size (n_hid, n_vis) and `persistent_chain`
# starts as an empty 0x0 Float64 array.
# NOTE(review): the first three arguments are presumably the weight matrix,
# visible bias and hidden bias - confirm against the start of the type
# definition, which is not visible here.
#
# NOTE(review): this text comes from a rendered diff without +/- markers.
# The second signature line and the lone `momentum)` argument look like the
# pre-change lines; the lines that follow each of them look like their
# replacements. Taken literally the block does not parse - confirm against
# the actual file before relying on it.
function RBM(V::Type, H::Type,
n_vis::Int, n_hid::Int; sigma=0.001, momentum=0.9)
n_vis::Int, n_hid::Int; sigma=0.001, momentum=0.9,
decay_rate=0.0, sparsity_cost=0.0, sparsity_exp=0.5)
RBM{V,H}(rand(Normal(0, sigma), (n_hid, n_vis)),
zeros(n_vis), zeros(n_hid),
zeros(n_hid, n_vis),
Array(Float64, 0, 0),
momentum)
momentum,
decay_rate,
sparsity_cost,
sparsity_exp
)
end


Expand All @@ -40,11 +48,24 @@ end


# `BernoulliRBM` is an alias for `RBM{Bernoulli, Bernoulli}`.
typealias BernoulliRBM RBM{Bernoulli, Bernoulli}
# One-line convenience constructor: forwards only `sigma` and `momentum`
# to `RBM(Bernoulli, Bernoulli, n_vis, n_hid; ...)`.
# NOTE(review): this looks like the pre-change line of the diff; the
# `function` form below has the same positional signature and also forwards
# the decay/sparsity keywords - confirm which one is present in the real file.
BernoulliRBM(n_vis::Int, n_hid::Int; sigma=0.001, momentum=0.9) =
RBM(Bernoulli, Bernoulli, n_vis, n_hid, sigma=sigma, momentum=momentum)
# Convenience constructor: calls `RBM` with `Bernoulli` for both type
# arguments and passes `sigma`, `momentum`, `decay_rate`, `sparsity_cost`
# and `sparsity_exp` through unchanged. The defaults here (0.001, 0.9, 0.0,
# 0.0, 0.5) are the same values used as defaults by the `RBM` constructor.
function BernoulliRBM(n_vis::Int, n_hid::Int; sigma=0.001, momentum=0.9,
decay_rate=0.0, sparsity_cost=0.0, sparsity_exp=0.5)
RBM(
Bernoulli, Bernoulli, n_vis, n_hid;
sigma=sigma, momentum=momentum, decay_rate=decay_rate,
sparsity_cost=sparsity_cost, sparsity_exp=sparsity_exp
)
end

# `GRBM` is an alias for `RBM{Gaussian, Bernoulli}`.
typealias GRBM RBM{Gaussian, Bernoulli}
# One-line convenience constructor: forwards only `sigma` and `momentum`
# to `RBM(Gaussian, Bernoulli, n_vis, n_hid; ...)`.
# NOTE(review): this looks like the pre-change line of the diff; the
# `function` form below has the same positional signature and also forwards
# the decay/sparsity keywords - confirm which one is present in the real file.
GRBM(n_vis::Int, n_hid::Int; sigma=0.001, momentum=0.9) =
RBM(Gaussian, Bernoulli, n_vis, n_hid, sigma=sigma, momentum=momentum)
# Convenience constructor: calls `RBM` with `Gaussian` as the first type
# argument and `Bernoulli` as the second, and passes `sigma`, `momentum`,
# `decay_rate`, `sparsity_cost` and `sparsity_exp` through unchanged. The
# defaults here (0.001, 0.9, 0.0, 0.0, 0.5) are the same values used as
# defaults by the `RBM` constructor.
function GRBM(n_vis::Int, n_hid::Int; sigma=0.001, momentum=0.9,
decay_rate=0.0, sparsity_cost=0.0, sparsity_exp=0.5)
RBM(
Gaussian, Bernoulli, n_vis, n_hid;
sigma=sigma, momentum=momentum, decay_rate=decay_rate,
sparsity_cost=sparsity_cost, sparsity_exp=sparsity_exp
)
end


function logistic(x)
Expand Down Expand Up @@ -132,11 +153,37 @@ end

function update_weights!(rbm, h_pos, v_pos, h_neg, v_neg, lr, buf)
dW = buf

# Scale costs
dr = rbm.decay_rate/size(v_pos,2)
cost = rbm.sparsity_cost/size(v_pos,2)
mean_hid = mean(h_pos)

# The sparsity constraint should only drive the weights
# down when the mean activation of hidden units is higher
# than the expected (hence why it isn't squared or the abs())
sparsity_penalty = cost * (mean_hid - rbm.sparsity_exp)

# The decay penalty should drive all weights toward
# zero by some small amount on each update.
decay_penalty = dr * rbm.W

# dW = (h_pos * v_pos') - (h_neg * v_neg')
gemm!('N', 'T', lr, h_neg, v_neg, 0.0, dW)
gemm!('N', 'T', lr, h_pos, v_pos, -1.0, dW)
# rbm.dW += rbm.momentum * rbm.dW_prev
axpy!(rbm.momentum, rbm.dW_prev, dW)

# Apply our penalties to dW.
# There is probably a more efficient way to do
# this with BLAS, but I'm not very familiar with it.

This comment has been minimized.

Copy link
@rofinn

rofinn Jan 11, 2016

Author Owner

*NOT very familiar

dW -= (decay_penalty - sparsity_penalty)

# We also need to apply the sparsity penalty to the hidden bias.
# Not the best place to do this, but we're already calculating
# the penalty here.
rbm.hbias -= sparsity_penalty

# rbm.W += lr * dW
axpy!(1.0, dW, rbm.W)
# save current dW
Expand Down

0 comments on commit f5fad51

Please sign in to comment.