MNIST

We begin by importing the required packages. We load MNIST via the MLDatasets.jl package.

import CairoMakie
import Makie
import MLDatasets
using Random: bitrand
using RestrictedBoltzmannMachines: BinaryRBM
using RestrictedBoltzmannMachines: free_energy
using RestrictedBoltzmannMachines: initialize!
using RestrictedBoltzmannMachines: log_pseudolikelihood
using RestrictedBoltzmannMachines: pcd!
using RestrictedBoltzmannMachines: sample_from_inputs
using RestrictedBoltzmannMachines: sample_v_from_v
using Statistics: mean
using Statistics: std
using Statistics: var
using ValueHistories: @trace
using ValueHistories: MVHistory

Useful function to plot grids of MNIST digits.

"""
    imggrid(A)

Given a four dimensional tensor `A` of size `(width, height, ncols, nrows)`
containing `width x height` images in a grid of `nrows x ncols`, this returns
a matrix of size `(width * ncols, height * nrows)`, that can be plotted in a heatmap
to display all images.
"""
function imggrid(A::AbstractArray{<:Any,4})
    width, height, ncols, nrows = size(A)
    # Bring each image axis next to its grid axis: (width, ncols, height, nrows).
    tiled = permutedims(A, (1, 3, 2, 4))
    # Collapse each adjacent pair of axes into one matrix dimension.
    return reshape(tiled, width * ncols, height * nrows)
end
Main.imggrid

Load the MNIST dataset. We will train an RBM with binary (0,1) visible and hidden units. Therefore we binarize the data. In addition, we consider only one kind of digit so that training is faster.

# Element type used for the binarized data (and passed to the RBM constructor below).
Float = Float32
# Load the training split once; indexing the dataset with `[:]` gives access to both
# `features` (the images) and `targets` (the digit labels).
mnist_train = MLDatasets.MNIST(split=:train)[:]
train_y = mnist_train.targets
# Keep only the images labelled 0 and binarize them by thresholding at 0.5.
train_x = Array{Float}(mnist_train.features[:, :, train_y .== 0] .≥ 0.5)

Let's visualize some random digits.

# Grid layout for the digit plots.
nrows, ncols = 10, 15

# Draw one random training image per grid cell and arrange them as a 4D stack.
idx = rand(1:size(train_x, 3), nrows * ncols)
digits = reshape(train_x[:, :, idx], 28, 28, ncols, nrows)

# 40 pixels per grid cell; the y axis is flipped so that images appear upright.
fig = Makie.Figure(resolution=(40 * ncols, 40 * nrows))
ax = Makie.Axis(fig[1, 1], yreversed=true)
# The reversed color range (1, 0) inverts the colormap.
Makie.image!(ax, imggrid(digits), colorrange=(1, 0))
Makie.hidedecorations!(ax)
Makie.hidespines!(ax)
fig
Example block output

Initialize an RBM with 400 hidden units.

# Binary RBM with element type `Float`, a 28×28 visible layer (one unit per pixel,
# matching the image size) and 400 hidden units.
rbm = BinaryRBM(Float, (28,28), 400)
# Data-dependent initialization of the RBM parameters from the training images.
initialize!(rbm, train_x) # match single-site statistics

Initially, the RBM assigns a poor pseudolikelihood to the data.

# Time the log-pseudolikelihood evaluation on the training data, then report its mean.
lpl_init = @time log_pseudolikelihood(rbm, train_x)
println("log(PL) = ", mean(lpl_init))
  2.084595 seconds (8.48 M allocations: 480.365 MiB, 5.99% gc time, 89.89% compilation time)
log(PL) = -0.25313258

Now we train the RBM on the data.

# Train the RBM with `pcd!` (presumably persistent contrastive divergence — confirm
# against the RestrictedBoltzmannMachines docs), recording the mean log-pseudolikelihood
# of the training data every 100 iterations.
batchsize = 256 # forwarded to `pcd!` as the `batchsize` keyword
iters = 10000 # forwarded to `pcd!` as the `iters` keyword
history = MVHistory() # collects the values recorded by `@trace` in the callback
@time pcd!(
    rbm, train_x; iters, batchsize,
    # The callback uses only the `iter` keyword; any other keywords are ignored.
    callback = function(; iter, _...)
        # Evaluate on the full training set only every 100th iteration.
        if iszero(iter % 100)
            lpl = mean(log_pseudolikelihood(rbm, train_x))
            # Store `lpl` at index `iter`; it is read back later with the key `:lpl`.
            @trace history iter lpl
        end
    end
)
278.398476 seconds (19.84 M allocations: 205.539 GiB, 22.00% gc time, 6.05% compilation time)

After training, the pseudolikelihood score of the data improves significantly. The following plot shows the log-pseudolikelihood of the training data during learning.

# Training curve: mean log-pseudolikelihood of the training data against the
# iteration at which it was recorded.
fig = Makie.Figure(resolution=(500,300))
# The recorded values are log-pseudolikelihoods indexed by iteration number, so the
# axes are labelled accordingly.
ax = Makie.Axis(fig[1,1], xlabel="training iteration", ylabel="log-pseudolikelihood")
# `get(history, :lpl)` yields the recorded iterations and values, splatted as x and y.
Makie.lines!(ax, get(history, :lpl)...)
fig
Example block output

Sample digits from the RBM, starting from a random initial condition.

# Number of sampling steps; one chain is run per cell of the plotting grid.
nsteps = 3000
# Random binary starting configurations, one 28×28 image per chain.
fantasy_x = bitrand(28, 28, nrows * ncols)
# Free energy of every chain (rows) at every step (columns).
fantasy_F = zeros(nrows * ncols, nsteps)
fantasy_F[:, 1] .= free_energy(rbm, fantasy_x)
@time for step in 2:nsteps
    # Advance all chains in place, then record their free energies.
    fantasy_x .= sample_v_from_v(rbm, fantasy_x)
    fantasy_F[:, step] .= free_energy(rbm, fantasy_x)
end
 17.698066 seconds (244.20 k allocations: 11.988 GiB, 2.52% gc time, 1.17% compilation time)

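Check equilibration of sampling. The plot shows the mean free energy across the sampled chains at each step, with a band of total width one standard deviation.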

# Mean and standard deviation of the free energy across chains at each sampling step.
fantasy_F_μ = vec(mean(fantasy_F; dims=1))
fantasy_F_σ = vec(std(fantasy_F; dims=1))

fig = Makie.Figure(resolution=(400, 300))
ax = Makie.Axis(fig[1, 1], xlabel="sampling time", ylabel="free energy")
# Shaded band extending half a standard deviation on each side of the mean.
Makie.band!(
    ax, 1:nsteps,
    fantasy_F_μ .- fantasy_F_σ ./ 2,
    fantasy_F_μ .+ fantasy_F_σ ./ 2,
)
Makie.lines!(ax, 1:nsteps, fantasy_F_μ)
fig
Example block output

Plot the sampled digits.

# Arrange the final state of every chain on the same grid used for the data plot.
fantasy_digits = reshape(fantasy_x, 28, 28, ncols, nrows)

# 40 pixels per grid cell; the y axis is flipped so that images appear upright.
fig = Makie.Figure(resolution=(40 * ncols, 40 * nrows))
ax = Makie.Axis(fig[1, 1], yreversed=true)
# The reversed color range (1, 0) inverts the colormap.
Makie.image!(ax, imggrid(fantasy_digits), colorrange=(1, 0))
Makie.hidedecorations!(ax)
Makie.hidespines!(ax)
fig
Example block output

This page was generated using Literate.jl.