Annealed importance sampling

We can estimate the partition function of the RBM (and hence the log-likelihood) with annealed importance sampling (AIS).

import MLDatasets
import Makie
import CairoMakie
import RestrictedBoltzmannMachines as RBMs
using Statistics: mean, std, middle
using ValueHistories: MVHistory
using RestrictedBoltzmannMachines: Binary, BinaryRBM, initialize!, pcd!,
    aise, raise, logmeanexp, logstdexp, sample_v_from_v

Load the MNIST training set, keeping only images of the digit 0 and binarizing pixel intensities at a threshold of 0.5.

# Element type used for the training data and the RBM parameters.
Float = Float32
# `MNIST(split=:train)[:]` yields the `(features, targets)` pair; it replaces the
# deprecated `MNIST.traindata()` call.
train_x, train_y = MLDatasets.MNIST(split=:train)[:]
# Keep only images of the digit 0 and binarize pixel intensities at 0.5.
train_x = Array{Float}(train_x[:, :, train_y .== 0] .> 0.5)
┌ Warning: MNIST.traindata() is deprecated, use `MNIST(split=:train)[:]` instead.
└ @ MLDatasets ~/.julia/packages/MLDatasets/0MkOE/src/datasets/vision/mnist.jl:187

Train an RBM with 128 hidden units using persistent contrastive divergence (`pcd!`).

# BinaryRBM with a (28, 28) visible layer and 128 hidden units, in `Float` precision.
rbm = BinaryRBM(Float, (28, 28), 128)
# Initialize the model using the training data.
initialize!(rbm, train_x)
# Persistent contrastive divergence; `@time` reports the training cost.
@time pcd!(rbm, train_x; batchsize = 128, iters = 10_000)
 74.041175 seconds (2.54 M allocations: 69.159 GiB, 21.62% gc time, 0.45% compilation time)

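Get some equilibrated samples from the model by running 1000 sampling steps (`sample_v_from_v`) starting from 1000 randomly chosen training images.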

# Start 1000 chains from training images drawn uniformly (with replacement), then run
# 1000 steps of `sample_v_from_v` so the samples approach the model's equilibrium.
v = sample_v_from_v(rbm, train_x[:, :, rand(1:size(train_x, 3), 1000)]; steps = 1000)

Estimate log(Z) with AIS and reverse AIS, for an increasing number of interpolating distributions.

# Number of samples drawn for each AIS / reverse AIS estimate.
nsamples = 100
# Numbers of interpolating distributions to compare.
ndists = [10, 100, 1000, 10_000, 100_000]
# One vector of per-sample log(Z) estimates per entry of `ndists`.
R_ais = Vector{Vector{Float64}}()
R_rev = Vector{Vector{Float64}}()
# `init` keyword of `aise`/`raise`: a Binary layer with zero fields, initialized from
# the samples `v` (presumably the starting end of the annealing path — confirm).
init = initialize!(Binary(; θ = zero(rbm.visible.θ)), v)

for nbetas in ndists
    # Forward AIS.
    ais = @time aise(rbm; nbetas, nsamples, init)
    push!(R_ais, ais)
    # Reverse AIS, fed `nsamples` equilibrated samples drawn with replacement from `v`.
    rev = @time raise(rbm; nbetas, init, v = v[:, :, rand(1:size(v, 3), nsamples)])
    push!(R_rev, rev)
end
  5.319742 seconds (9.03 M allocations: 518.432 MiB, 0.98% gc time, 99.35% compilation time)
  0.329384 seconds (416.62 k allocations: 81.739 MiB, 7.94% gc time, 82.68% compilation time)
  0.419613 seconds (12.82 k allocations: 703.543 MiB, 7.53% gc time)
  0.343557 seconds (12.80 k allocations: 731.346 MiB, 6.85% gc time)
  4.150691 seconds (129.82 k allocations: 6.979 GiB, 5.55% gc time)
  3.520097 seconds (129.81 k allocations: 7.269 GiB, 6.80% gc time)
 41.626537 seconds (1.30 M allocations: 69.899 GiB, 5.61% gc time)
 35.130993 seconds (1.30 M allocations: 72.818 GiB, 6.83% gc time)
416.738281 seconds (13.00 M allocations: 699.103 GiB, 5.60% gc time)
354.403215 seconds (13.00 M allocations: 728.305 GiB, 6.81% gc time)

Plots

# log(Z) estimates versus the number of interpolating distributions (log-scaled x axis).
fig = Makie.Figure()
ax = Makie.Axis(
    fig[1,1], width=700, height=400, xscale=log10,
    xlabel="interpolating distributions", ylabel="log(Z)"
)
# Shaded bands: mean ± one standard deviation of the per-sample estimates.
Makie.band!(
    ax, ndists,
    mean.(R_ais) - std.(R_ais),
    mean.(R_ais) + std.(R_ais);
    color=(:blue, 0.25)
)
Makie.band!(
    ax, ndists,
    mean.(R_rev) - std.(R_rev),
    mean.(R_rev) + std.(R_rev);
    color=(:black, 0.25)
)
# Solid lines: mean of the per-sample estimates.
Makie.lines!(ax, ndists, mean.(R_ais); color=:blue, label="AIS")
Makie.lines!(ax, ndists, mean.(R_rev); color=:black, label="reverse AIS")
# Dashed lines: `logmeanexp` of the per-sample estimates; orange is
# `-logmeanexp(-R)` of the reverse AIS estimates.
Makie.lines!(ax, ndists, logmeanexp.(R_ais); color=:blue, linestyle=:dash)
Makie.lines!(ax, ndists, logmeanexp.(R_rev); color=:black, linestyle=:dash)
Makie.lines!(ax, ndists, -logmeanexp.(-R_rev); color=:orange, linestyle=:dash)
# Reference level: midpoint of the AIS and reverse AIS means at the largest `ndists`.
Makie.hlines!(ax, middle(mean(R_ais[end]), mean(R_rev[end])), linestyle=:dash, color=:red, label="limiting estimate")
# Target `ax` explicitly rather than relying on Makie's implicit current axis.
Makie.xlims!(ax, extrema(ndists)...)
Makie.axislegend(ax, position=:rb)
Makie.resize_to_layout!(fig)
fig
Example block output

This page was generated using Literate.jl.