Annealed importance sampling
We can compute the partition function of the RBM (and hence the log-likelihood) with annealed importance sampling (AIS).
import MLDatasets
import Makie
import CairoMakie
import RestrictedBoltzmannMachines as RBMs
using Statistics: mean, std, middle
using ValueHistories: MVHistory
using RestrictedBoltzmannMachines: Binary, BinaryRBM, initialize!, pcd!,
    aise, raise, logmeanexp, logstdexp, sample_v_from_v

# Load MNIST (0 digit only).
# Element type used for the data and for the RBM parameters.
Float = Float32

# Load the MNIST training split through the dataset-object API. The previously used
# `MNIST.traindata()` accessor is deprecated in MLDatasets; its deprecation warning
# names `MNIST(split=:train)[:]` as the replacement.
mnist_train = MLDatasets.MNIST(split=:train)[:]
train_y = mnist_train.targets

# Keep only the images labelled 0 and binarize the pixels by thresholding at 0.5.
train_x = Array{Float}(mnist_train.features[:, :, train_y .== 0] .> 0.5)

# Train an RBM
# Build a BinaryRBM with element type `Float` and a (28, 28) visible layer, matching
# the first two dimensions of `train_x`. The trailing 128 is presumably the number of
# hidden units — confirm against the RestrictedBoltzmannMachines documentation.
rbm = BinaryRBM(Float, (28, 28), 128)

# Initialize the model from the training data before fitting.
initialize!(rbm, train_x)

# Fit with `pcd!` and report the wall-clock time and allocations of the call.
@time pcd!(rbm, train_x; iters=10000, batchsize=128)
# Recorded output of the run above:
#  80.473956 seconds (2.61 M allocations: 69.162 GiB, 27.94% gc time, 0.24% compilation time)

# Get some equilibrated samples from model
# Pick 1000 starting configurations from the training images. Indices are drawn
# independently along the third (sample) dimension, so repeats are possible.
v = train_x[:, :, rand(axes(train_x, 3), 1000)]

# Run `sample_v_from_v` for 1000 steps from those starting points; the result replaces
# `v` and is used below as the set of model samples.
v = sample_v_from_v(rbm, v; steps=1000)

# Estimate Z with AIS and reverse AIS.
# Number of samples per estimate, and the numbers of interpolating distributions
# (`nbetas`) to compare.
nsamples = 100
ndists = [10, 100, 1000, 10_000, 100_000]

# One vector of estimates per entry of `ndists`, for AIS and for reverse AIS.
R_ais = Vector{Float64}[]
R_rev = Vector{Float64}[]

# `init` passed to both estimators: a `Binary` layer whose `θ` is a zero array shaped
# like `rbm.visible.θ`, then passed through `initialize!` with the samples `v`.
init = initialize!(Binary(; θ = zero(rbm.visible.θ)), v)

for nbetas in ndists
    # Each call is timed individually; the timings are listed below in call order.
    ais_estimates = @time aise(rbm; nbetas, nsamples, init)
    push!(R_ais, ais_estimates)

    # Reverse AIS is given `nsamples` of the equilibrated samples, drawn at random
    # from `v`. The indexing stays inside the timed expression, as in the original.
    rev_estimates = @time raise(rbm; nbetas, init, v=v[:, :, rand(axes(v, 3), nsamples)])
    push!(R_rev, rev_estimates)
end
# Recorded output of the loop above (two lines per entry of `ndists`: `aise`, then
# `raise`):
#   3.364007 seconds (5.52 M allocations: 328.881 MiB, 1.25% gc time, 98.99% compilation time)
#   0.298726 seconds (440.21 k allocations: 81.901 MiB, 90.64% compilation time)
#   0.446922 seconds (13.12 k allocations: 703.575 MiB, 15.91% gc time)
#   0.353686 seconds (13.10 k allocations: 731.379 MiB, 9.33% gc time)
#   3.997330 seconds (132.82 k allocations: 6.979 GiB, 7.79% gc time)
#   3.585369 seconds (132.80 k allocations: 7.269 GiB, 8.43% gc time)
#  40.000132 seconds (1.33 M allocations: 69.903 GiB, 7.69% gc time)
#  36.323125 seconds (1.33 M allocations: 72.821 GiB, 8.70% gc time)
# 405.110259 seconds (13.30 M allocations: 699.135 GiB, 7.98% gc time)
# 355.180481 seconds (13.30 M allocations: 728.338 GiB, 8.40% gc time)

# Plots
# Plot the AIS and reverse-AIS estimates of log(Z) against the number of interpolating
# distributions (log-scaled x axis).
fig = Makie.Figure()
ax = Makie.Axis(
    fig[1, 1];
    width=700,
    height=400,
    xscale=log10,
    xlabel="interpolating distributions",
    ylabel="log(Z)",
)

# Per-`ndists` mean and standard deviation of the estimates, computed once and reused
# by the bands and the solid lines.
ais_mean, ais_std = mean.(R_ais), std.(R_ais)
rev_mean, rev_std = mean.(R_rev), std.(R_rev)

# Shaded bands: mean ± one standard deviation.
Makie.band!(ax, ndists, ais_mean .- ais_std, ais_mean .+ ais_std; color=(:blue, 0.25))
Makie.band!(ax, ndists, rev_mean .- rev_std, rev_mean .+ rev_std; color=(:black, 0.25))

# Solid lines: plain means of the estimates.
Makie.lines!(ax, ndists, ais_mean; color=:blue, label="AIS")
Makie.lines!(ax, ndists, rev_mean; color=:black, label="reverse AIS")

# Dashed lines: `logmeanexp` of the estimates; the orange curve applies it to the
# negated reverse-AIS estimates and negates the result.
Makie.lines!(ax, ndists, logmeanexp.(R_ais); color=:blue, linestyle=:dash)
Makie.lines!(ax, ndists, logmeanexp.(R_rev); color=:black, linestyle=:dash)
Makie.lines!(
    ax, ndists, [-logmeanexp(-r) for r in R_rev]; color=:orange, linestyle=:dash
)

# Reference level: midpoint between the two mean estimates at the largest `ndists`.
Makie.hlines!(
    ax,
    middle(mean(R_ais[end]), mean(R_rev[end]));
    linestyle=:dash,
    color=:red,
    label="limiting estimate",
)

# Target `ax` explicitly instead of relying on the implicit current axis, consistent
# with every other plotting call here.
Makie.xlims!(ax, extrema(ndists)...)
Makie.axislegend(ax, position=:rb)
Makie.resize_to_layout!(fig)
fig
This page was generated using Literate.jl.