Merge pull request #6 from dfdx/vae2
More flexible VAE
dfdx authored Mar 9, 2021
2 parents b3124d4 + 74a7146 commit 3f06d63
Showing 8 changed files with 264 additions and 47 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -5,3 +5,4 @@
_*
Manifest.toml
logs/
+.vscode/
2 changes: 0 additions & 2 deletions benchmarks/vae/vae_lilith.jl → benchmarks/vae/vae.jl
@@ -6,8 +6,6 @@ using MLDataUtils
using BenchmarkTools



# variational autoencoder with Gaussian observed and latent variables
mutable struct VAE
# encoder / recognizer
enc_i2h::Linear
18 changes: 14 additions & 4 deletions benchmarks/vae/vae_pytorch.py
@@ -26,7 +26,7 @@

torch.manual_seed(42)
BATCH_SIZE = 100
-N_EPOCHS = 10
+N_EPOCHS = 100


class VAE(nn.Module):
@@ -61,13 +61,13 @@ def forward(self, x):

# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
-    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
+    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='mean')

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
-    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
+    KLD = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD

@@ -113,6 +113,12 @@ def test(epoch):


def main():
+    # due to https://github.com/pytorch/vision/issues/1938
+    from six.moves import urllib
+    opener = urllib.request.build_opener()
+    opener.addheaders = [('User-agent', 'Mozilla/5.0')]
+    urllib.request.install_opener(opener)
+
    kwargs = {'num_workers': 1, 'pin_memory': True}
    device = torch.device("cuda")
    train_loader = torch.utils.data.DataLoader(
Expand All @@ -124,7 +130,7 @@ def main():
batch_size=BATCH_SIZE, shuffle=True, **kwargs)

    model = VAE().to(device)
-    optimizer = optim.Adam(model.parameters(), lr=1e-3)
+    optimizer = optim.Adam(model.parameters(), lr=1e-5)
    for epoch in range(1, N_EPOCHS + 1):
        train(epoch)
        # test(epoch)
@@ -133,3 +139,7 @@ def main():
# sample = model.decode(sample).cpu()
# save_image(sample.view(64, 1, 28, 28),
# 'results/sample_' + str(epoch) + '.png')


+if __name__ == "__main__":
+    main()
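A note on the training changes above: with reduction='mean' (and the matching torch.mean in KLD) the loss no longer scales with batch size and element count, so its magnitude drops by roughly a factor of BATCH_SIZE * 784 relative to the 'sum' version; it also shifts the relative weight of BCE versus KLD, since the two terms are averaged over different element counts. The learning rate (1e-3 → 1e-5) and epoch count (10 → 100) are retuned in the same commit.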
2 changes: 1 addition & 1 deletion src/Avalon.jl
@@ -3,7 +3,7 @@ module Avalon
export
    grad,
    @diffrule,
-    @diffrile_kw,
+    @diffrule_kw,
    @nodiff,
    # initialization
    init_constant!,
30 changes: 29 additions & 1 deletion src/optim.jl
@@ -3,14 +3,42 @@
import ChainRulesCore


"""Should we ignore field with this path when updating the struct?"""
ignored_field(st::S, path::Val) where S = false


_dot_path(x::Symbol) = [x]

function _dot_path(ex::Expr)
subresult = _dot_path(ex.args[1])
return [subresult; [ex.args[2].value]]
end

"""
Ignore this field when updating the struct. Example:
mutable struct MyModel
linear::Linear
end
@ignore MyModel.linear.b
"""
macro ignore(ex)
@assert Meta.isexpr(ex, :(.))
full_path = _dot_path(ex)
T, path = full_path[1], (full_path[2:end]...,)
return esc(:($ignored_field(st::$T, path::Val{$path}) = true))
end


abstract type Optimizer end


function Yota.update!(opt::Optimizer, m, gm; ignore=Set())
    # we use Zero() to designate values that need not be updated
    gm isa ChainRulesCore.Zero && return
    for (path, gx) in Yota.path_value_pairs(gm)
-        if !in(path, ignore) && !isa(gx, ChainRulesCore.Zero)
+        if !in(path, ignore) && !ignored_field(m, Val(path)) && !isa(gx, ChainRulesCore.Zero)
            x_t0 = Yota.getfield_nested(m, path)
            x_t1 = make_update!(opt, path, x_t0, gx)
            Yota.setfield_nested!(m, path, x_t1)
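A quick sketch of how the new @ignore machinery composes with Yota.update! (a hypothetical model; assumes Avalon's Linear and the field-path tuples produced by Yota.path_value_pairs):

    # a model with a nested layer
    mutable struct MyModel
        linear::Linear
    end

    # freeze the bias of the nested layer
    @ignore MyModel.linear.b

    # the macro expands to roughly
    #   ignored_field(st::MyModel, path::Val{(:linear, :b)}) = true
    # so Yota.update! skips the (:linear, :b) path: m.linear.b stays
    # fixed while m.linear.W is still updated.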
69 changes: 30 additions & 39 deletions zoo/vae/main.jl
@@ -1,61 +1,50 @@
using Avalon
import Avalon.fit!
using Distributions
-# using GradDescent
-using MLDataUtils
using MLDatasets
-# using StatsBase
+using MLDataUtils
using BenchmarkTools
using ImageView




# variational autoencoder with Gaussian observed and latent variables
mutable struct VAE
    # encoder / recognizer
-    enc_l1::Linear # encoder: layer 1
-    enc_l2::Linear # encoder: layer 2
-    enc_l3::Linear # encoder: mu
-    enc_l4::Linear # encoder: log(sigma^2)
+    enc_i2h::Linear
+    enc_h2mu::Linear
+    enc_h2logsigma2::Linear
    # decoder / generator
-    dec_l1::Linear # decoder: layer 1
-    dec_l2::Linear # decoder: layer 2
-    dec_l3::Linear # decoder: layer 3
+    dec_z2h::Linear
+    dec_h2o::Linear
end


function Base.show(io::IO, m::VAE)
-    print(io, "VAE($(size(m.enc_l1.W,2)), $(size(m.enc_l1.W,1)), $(size(m.enc_l2.W,1)), " *
-        "$(size(m.enc_l3.W,1)), $(size(m.dec_l2.W,1)), $(size(m.dec_l2.W,1)), $(size(m.dec_l3.W,1)))")
+    print(io, "VAE()")
end


-VAE(n_inp, n_he1, n_he2, n_z, n_hd1, n_hd2, n_out) =
+VAE(n_i, n_h, n_z) =
    VAE(
        # encoder
-        Linear(n_inp, n_he1),
-        Linear(n_he1, n_he2),
-        Linear(n_he2, n_z),
-        Linear(n_he2, n_z),
+        Linear(n_i, n_h),
+        Linear(n_h, n_z),
+        Linear(n_h, n_z),
        # decoder
-        Linear(n_z, n_hd1),
-        Linear(n_hd1, n_hd2),
-        Linear(n_hd2, n_out)
+        Linear(n_z, n_h),
+        Linear(n_h, n_i),
    )


function encode(m::VAE, x)
-    he1 = softplus.(m.enc_l1(x))
-    he2 = softplus.(m.enc_l2(he1))
-    mu = m.enc_l3(he2)
-    log_sigma2 = m.enc_l4(he2)
+    he1 = tanh.(m.enc_i2h(x))
+    mu = m.enc_h2mu(he1)
+    log_sigma2 = m.enc_h2logsigma2(he1)
    return mu, log_sigma2
end

function decode(m::VAE, z)
-    hd1 = softplus.(m.dec_l1(z))
-    hd2 = softplus.(m.dec_l2(hd1))
-    x_rec = logistic.(m.dec_l3(hd2))
+    hd1 = tanh.(m.dec_z2h(z))
+    x_rec = logistic.(m.dec_h2o(hd1))
    return x_rec
end

@@ -65,19 +54,20 @@ function vae_cost(m::VAE, eps, x)
    z = mu .+ sqrt.(exp.(log_sigma2)) .* eps
    x_rec = decode(m, z)
    # loss
-    rec_loss = -sum(x .* log.(1e-10 .+ x_rec) .+ (1 .- x) .* log.(1e-10 + 1.0 .- x_rec); dims=1) # BCE
-    KLD = -0.5 .* sum(1 .+ log_sigma2 .- mu .^ 2.0f0 - exp.(log_sigma2); dims=1)
+    rec_loss = -sum(x .* log.(1f-10 .+ x_rec) .+ (1 .- x) .* log.(1f-10 + 1 .- x_rec); dims=1)
+    KLD = -0.5f0 .* sum(1 .+ log_sigma2 .- mu .^ 2.0f0 - exp.(log_sigma2); dims=1)
    cost = mean(rec_loss .+ KLD)
end


function fit!(m::VAE, X::AbstractMatrix{T};
-              n_epochs=50, batch_size=100, opt=SGD(1e-4; momentum=0)) where T
+              n_epochs=50, batch_size=100, opt=SGD(1e-4; momentum=0), device=CPU()) where T
    for epoch in 1:n_epochs
        print("Epoch $epoch: ")
        epoch_cost = 0
        t = @elapsed for (i, x) in enumerate(eachbatch(X, size=batch_size))
-            eps = typeof(x)(rand(Normal(0, 1), size(m.enc_l3.W, 1), batch_size))
+            x = device(x)
+            eps = typeof(x)(rand(Normal(0, 1), size(m.enc_h2mu.W, 1), batch_size))
            cost, g = grad(vae_cost, m, eps, x)
            update!(opt, m, g[1])
            epoch_cost += cost
@@ -102,22 +92,23 @@ function show_pic(x)
end


-function show_recon(m, x)
-    x_ = reconstruct(m, x)
+function show_recon(m, x, device)
+    x_ = reconstruct(m, device(x))
    show_pic(x)
    show_pic(x_)
end


function main()
-    m = VAE(784, 500, 500, 20, 500, 500, 784)
+    device = best_available_device()
+    m = VAE(784, 500, 5) |> device

    X, _ = MNIST.traindata()
    X = convert(Matrix{Float64}, reshape(X, 784, 60000))
-    @time m = fit!(m, X)
+    @time m = fit!(m, X, device=device)

    # check reconstructed image
    for i=1:2:10
-        show_recon(m, X[:, i])
+        show_recon(m, X[:, i], device)
    end
end
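For reference, vae_cost above is the per-sample negative ELBO of Kingma & Welling (https://arxiv.org/abs/1312.6114) for a Bernoulli decoder with the reparameterization trick; the 1f-10 terms only guard the logs against zero. In LaTeX:

    \mathcal{L}(x) = -\sum_j \left[ x_j \log \hat{x}_j + (1 - x_j) \log(1 - \hat{x}_j) \right]
                     - \frac{1}{2} \sum_k \left( 1 + \log \sigma_k^2 - \mu_k^2 - \sigma_k^2 \right),
    \qquad z = \mu + \sigma \odot \varepsilon, \quad \varepsilon \sim \mathcal{N}(0, I)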
95 changes: 95 additions & 0 deletions zoo/vae/main2.jl
@@ -0,0 +1,95 @@
using MLDatasets
using MLDataUtils
using Plots
using Images
using Interact


include("vae2.jl")


function reconstruct(m::VAE, x::AbstractVector)
    x = reshape(x, length(x), 1)
    x_rec = decode(m, m(x))
    return x_rec
end


function show_pic(x)
    a = reshape(x, 28, 28)'
    return plot(Gray.(a))
end


function show_recon(m, X, device; n=5)
    subplots = []
    cpu = CPU()
    for i in rand(1:size(X, 2), n)
        x = X[:, i]
        x_ = reconstruct(m, device(x))
        p = show_pic(cpu(x))
        p_ = show_pic(cpu(x_))
        push!(subplots, p, p_)
    end
    plot(subplots..., layout=(n, 2))
end


function interpolate_latent_var(m, x, z_idx, device)
    vals = collect(-2:0.5:2)
    z = m(x)
    xs_ = []
    cpu = CPU()
    for v in vals
        z[z_idx, :] = v
        x_ = decode(m, device(z)) |> cpu
        push!(xs_, x_)
    end
    subplots = [show_pic(x_) for x_ in xs_]
    # plot(subplots..., layout=length(subplots))
    return subplots
end


function show_latent_vars(m, x, z_idxs, device)
    groups = [interpolate_latent_var(m, x, z_idx, device) for z_idx in z_idxs]
    subplots = vcat(groups...)
    n_cols = length(groups[1])
    plot(subplots..., layout=(length(z_idxs), n_cols))
end


function show_samples(m, n, device)
    z_len = length(m.enc2mu.b) # assuming Linear
    z = randn(z_len, n)
    z = device(z)
    x_ = decode(m, z) |> CPU()
    subplots = [show_pic(x_[:, i]) for i in 1:size(x_, 2)]
    plot(subplots..., layout=length(subplots))
end


function main()
    device = best_available_device()
    m = VAE(
        Sequential(
            Linear(784 => 400),
            x -> relu.(x)),
        Linear(400 => 20),
        Linear(400 => 20),
        Sequential(
            Linear(20 => 400),
            x -> relu.(x),
            Linear(400 => 784),
            x -> logistic.(x));
        beta=5)
    m = m |> device

    X, _ = MNIST.traindata()
    X = convert(Matrix{Float64}, reshape(X, 784, 60000))
    @time m = fit!(m, X, device=device, opt=Adam(; lr=1e-3), n_epochs=10)

    show_recon(m, X, device, n=5)
    show_latent_vars(m, device(X[:, 2]), 1:4, device)
    show_samples(m, 10, device)
end
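A note on main2.jl: it drives the new, more flexible VAE type from vae2.jl, which is included here but not part of this diff. Judging from the usage above, the refactored VAE takes arbitrary callables for the encoder and decoder (here Sequential stacks), two Linear heads for mu and log(sigma^2) (cf. m.enc2mu in show_samples), and a beta keyword weighting the KL term, i.e. a beta-VAE-style objective when beta != 1.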