Deconstructing Deep Learning + δeviations
TL;DR
A simple Variational Autoencoder (VAE) using just what we have made so far!!
Since we have already built everything this needs from scratch in earlier posts, I will be using Flux directly. (I will attempt to do optimizers from scratch in the next post.)
Okay, so what is that? It is a neural network used for things like image compression and image generation. It is cool because it is pretty cheap computationally: an encoder squeezes each image into a tiny latent vector, and a decoder turns latent vectors back into images.
So what do we need to build it?
- An encoder that maps each image to a distribution over a small latent space
- A way to sample from that latent space
- A decoder that turns latent samples back into images
- A loss (the ELBO) and a training loop
Let us break it down or I will break down.
We load all the required libraries, binarize the MNIST images, and split the data into minibatches.
using Flux, Flux.Data.MNIST, Statistics, Flux.Optimise
using Flux: throttle, params

# Load MNIST, flatten each 28×28 image into a 784-vector, and binarize it
X = (float.(hcat(vec.(MNIST.images())...)) .> 0.5)
# N images in total, minibatches of size M
N, M = size(X, 2), 100
data = [X[:, i] for i in Iterators.partition(1:N, M)]
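If you want to sanity check the shapes at this point (a quick check of my own, not part of the original script), you should see something like this:

# 784 pixels per image, 60000 training images, 600 batches of 100
size(X)        # (784, 60000)
size(data[1])  # (784, 100)
length(data)   # 600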
We need an encoder that maps an image to the parameters (mean and log standard deviation) of a Gaussian over the latent space, plus a way to pick a sample from that space. - (Dense(784, 500, tanh), Dense(500, 5), Dense(500, 5))
# Latent dimension and hidden layer size
Dz, Dh = 5, 500
# Shared hidden layer A, then separate heads for the mean μ and log std-dev logσ
A, μ, logσ = Dense(28^2, Dh, tanh), Dense(Dh, Dz), Dense(Dh, Dz)
g(X) = (h = A(X); (μ(h), logσ(h)))

# Reparameterization trick: z = μ + σ ⊙ ε with ε ~ N(0, I)
function sample_z(μ, logσ)
    ε = randn(Float32, size(μ))
    return μ + exp.(logσ) .* ε
end
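To see how these fit together, here is a tiny check (my own snippet, not from the original script) that pushes one batch through the encoder and then samples a latent code per image:

μ̂, logσ̂ = g(data[1])    # each is Dz × M, i.e. 5 × 100
z = sample_z(μ̂, logσ̂)    # one 5-dimensional latent code per image
size(z)                   # (5, 100)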
We define the decoder here. - Chain(Dense(5, 500, tanh), Dense(500, 784, σ))
# Decoder / "generator": latent code z → Bernoulli mean for each pixel
f = Chain(Dense(Dz, Dh, tanh), Dense(Dh, 28^2, σ))

# KL divergence between q(z|x) = N(μ, σ²) and the prior p(z) = N(0, I)
kl_q_p(μ, logσ) = 0.5f0 * sum(exp.(2f0 .* logσ) + μ.^2 .- 1f0 .- 2f0 .* logσ)
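For reference, this is the closed-form KL divergence between a diagonal Gaussian q(z|x) = N(μ, σ²) and the standard normal prior that this line implements, with σ = exp(logσ):

$$
D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \sigma^2)\,\|\,\mathcal{N}(0, I)\big)
= \frac{1}{2}\sum_j \left(e^{2\log\sigma_j} + \mu_j^2 - 1 - 2\log\sigma_j\right)
$$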
# Bernoulli log-likelihood of the binarized image x under the decoder output
function logp_x_z(x, z)
    p = f(z)
    ll = x .* log.(p .+ eps(Float32)) .+ (1f0 .- x) .* log.(1f0 .- p .+ eps(Float32))
    return sum(ll)
end

# Single-sample Monte Carlo estimate of the ELBO, averaged over the minibatch
L̄(X) = ((μ̂, logσ̂) = g(X); (logp_x_z(X, sample_z(μ̂, logσ̂)) - kl_q_p(μ̂, logσ̂)) * 1 // M)
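Putting the two pieces together, L̄ is a single-sample estimate of the evidence lower bound (ELBO), which training tries to maximize:

$$
\mathcal{L}(x) = \mathbb{E}_{q(z \mid x)}\left[\log p(x \mid z)\right] - D_{\mathrm{KL}}\big(q(z \mid x)\,\|\,p(z)\big)
$$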
Let us define the loss function: the negative ELBO (so minimizing the loss maximizes the ELBO) plus a small L2 penalty on the decoder weights. We will also write a function that samples fresh digits from the model.
# Negative ELBO plus a small L2 regularizer on the decoder parameters
loss(X) = -L̄(X) + 0.01f0 * sum(x -> sum(x.^2), params(f))
function modelsample()
ϕ = zeros(Float32, Dz)
p = f(sample_z(ϕ, ϕ))
u = rand(Float32, size(p))
return (u .< p)
end
Now for the actual training. I will be using ADAM (yes, cheating, I know, but I am trying really hard to get it done from scratch :) ). Also, no GPU.
# Print the negative ELBO on a random minibatch at most once every 10 seconds
evalcb = throttle(() -> @show(-L̄(X[:, rand(1:N, M)])), 10)
opt = ADAM()
# All trainable parameters: the encoder, both heads, and the decoder
ps = params(A, μ, logσ, f)

for i = 1:10
    @info "Epoch $i"
    # zip(data) wraps each batch in a tuple, so loss is called as loss(batch)
    Flux.train!(loss, ps, zip(data), opt, cb = evalcb)
end
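If you are curious what Flux.train! is doing under the hood, each epoch is essentially the loop below (a rough sketch, assuming the Zygote-based Flux API; the real implementation has more bookkeeping):

for x in data
    gs = gradient(() -> loss(x), ps)    # backprop through the negative ELBO
    Flux.Optimise.update!(opt, ps, gs)  # one ADAM step on every parameter
end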
Finally, let us visualize the outputs. Note that the model was only trained for 10 epochs, so the samples are kinda dumb, but well.
using Images

# Reshape a 784-vector back into a 28×28 grayscale image
img(x) = Gray.(reshape(x, 28, 28))
# Draw 10 samples from the model and stack them side by side
sample = hcat(img.([modelsample() for i = 1:10])...)
save("sample.png", sample)
-
Finis