Metropolis-Hastings MCMC#
In this tutorial, we sample a multivariate Gaussian with MCMC based on the Metropolis-Hastings algorithm.
We illustrate how to use ensemble sampling as well as parallel tempering.
import jax
import jax.numpy as jnp
import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import jexplore.tools.distributions as d
from jexplore.sampling import SamplingMH, Box
from jexplore.tools.diagnostic import auto_correlation_length, gelman_rubin_statistic
%reload_ext autoreload
%autoreload 2
---------------------------------------------------------------------------
ModuleNotFoundError Traceback (most recent call last)
Cell In[1], line 5
1 import jax
2 import jax.numpy as jnp
3 import numpy as np
4
----> 5 import seaborn as sns
6 import pandas as pd
7 import matplotlib.pyplot as plt
8
ModuleNotFoundError: No module named 'seaborn'
Preliminary definitions#
Here we define our sampling ingredients:
We make use of a pre-defined Gaussian log-likelihood from the jexplore toolbox;
We define the number of walkers of the ensemble and the temperature ladder;
We make use of the built-in stretch move as a proposal, which is the default.
# Problem size and ensemble layout.
dim = 2  # dimension of the sampled Gaussian
nwalker = 10  # number of walkers in the ensemble
# Temperature ladder for parallel tempering: arange(1, 10, 3) -> 1, 4, 7.
temperature_ladder = jnp.arange(1, 10, 3)
ntemp = len(temperature_ladder)

# Build each sampling ingredient separately, then hand them to SamplingMH.
search_space = Box(dim=dim, size=10.0)
gaussian_loglik = d.Normal(dim=dim).leval
flat_logprior = d.Uniform(dim=dim, minval=-10, maxval=10).leval

sampling = SamplingMH(
    space=search_space,
    nwalker=nwalker,
    loglik=gaussian_loglik,
    temps=temperature_ladder,
    logprior=flat_logprior,
)
# No explicit steps are given, so the sampler uses its default proposals.
sampler = sampling.get_sampler()
Run MCMC#
We start from a random draw of the normal distribution for each walker and temperature.
The JaxSampler starts from and returns an Epoch: an encapsulation of the actual samples and their metadata, such as covariance, log-likelihood, sampling statistics and so on.
In this example, we run for one epoch (nepoch=1). One also has the choice of:
running several epochs at once by setting nepoch > 1;
pausing and resuming the MCMC from the returned epoch.
# Draw one starting point per (walker, temperature) chain.
key = jax.random.key(42)
nchains = nwalker * ntemp
# NOTE(review): [1] keeps the second item returned by sample() -- presumably
# the drawn points; confirm against the jexplore distributions API.
p0 = d.Normal(dim=dim).sample(key, shape=(nchains,))[1]
epoch = sampling.get_epoch(p0)

niters = 5_000  # iterations kept for the analysis
burn = 500  # leading iterations sliced off below
total_iters = niters + burn
epoch = sampler.run(epoch, niters=total_iters, nepoch=1)

# Discard the first `burn` entries along the third axis of the samples.
chain = epoch.samples.p[:, :, burn:]
chain.shape
Convergence diagnostics#
# Worst-case deviation of sqrt(Gelman-Rubin statistic) from 1.
rhat_excess = np.abs(np.sqrt(gelman_rubin_statistic(chain, ntemp=ntemp)) - 1)
# The nested max is kept on purpose: a single np.max might reduce only one
# axis if the statistic is not a plain array.
gr = np.max(np.max(rhat_excess))

# Largest autocorrelation length found for the chain.
acl = np.max(auto_correlation_length(chain))
print(f"Gelman Rubin ratio: {gr}. Autocorrelation length: {acl}")

# Bring axis 1 to the front and flatten the remaining axes, giving one row
# per parameter for the corner plot.
_pdf = np.moveaxis(chain, source=1, destination=0).reshape(dim, -1)
sns.pairplot(pd.DataFrame(_pdf.T), corner=True)
## Acceptance rates
N = nwalker * ntemp
chain_ids = np.arange(N)
# Chain i sits at temperature index i % ntemp; this is the same convention the
# swap-rate selectors below rely on (offset 1 -> t1, offset 2 -> t2).
cold_ids = chain_ids[chain_ids % ntemp == 0]
# Hot chains are every chain that is NOT at temperature index 0. The former
# selector, np.arange(ntemp, N, ntemp), only picked multiples of ntemp, i.e.
# cold chains again.
hot_ids = chain_ids[chain_ids % ntemp != 0]
# NOTE(review): counters are divided by niters although the run above used
# niters + burn iterations -- confirm which count the stats cover.
# Row 1 of the stats is used for the stretch move, row 0 for temperature swaps.
print(f"stretch acceptance rate of cold chains: {np.mean(epoch.stats.stats[1, cold_ids] / niters)}")
print(f"stretch acceptance rate of hot chains: {np.mean(epoch.stats.stats[1, hot_ids] / niters)}")
print(f"swap rate between t0-t1: {np.mean(epoch.stats.stats[0, np.arange(1, N, ntemp)] / niters)}")
print(f"swap rate between t1-t2: {np.mean(epoch.stats.stats[0, np.arange(2, N, ntemp)] / niters)}")
Going further#
import h5py
from jexplore.sampler import Steps
from jexplore.steps import DEStep, Stretch, TSwap
from jexplore.steps.rwalk import GaussianRandomWalk, StudentTRandomWalk
from jexplore.backends import DefaultBackend
More proposals#
# Proposals to mix; each one is registered below as a single-entry
# {proposal builder: weight} mapping, all with weight 1.0.
proposals = (
    TSwap(permute=True),
    Stretch(permute=True),
    GaussianRandomWalk(),
    StudentTRandomWalk(),
    DEStep(permute=True),
)
sampler = sampling.get_sampler(
    steps=[{proposal.builder: 1.0} for proposal in proposals],
)
Backend customization#
We explicitly define the backend configuration and use it to retrieve all the information stored on disk.
As an illustration, we run for 2 epochs, and use one h5 file per epoch.
burn = 500  # number of samples which won't be registered

# Explicit backend configuration.
config = {
    "outdir": "./",  # where to save h5 files
    "outfile_fmt": "epochs_%d.h5",  # h5 file template name
    "epochs_per_file": 1,
    "inmem_epochs": 0,
    "save_stats": True,
}
my_backend = DefaultBackend(burn=burn, config=config)
sampler = sampling.get_sampler(backend=my_backend)

# Resume from the current epoch and run two more epochs.
iters_per_epoch = niters + burn
epoch = sampler.run(epoch, niters=iters_per_epoch, nepoch=2)

# Samples gathered back from the backend.
all_samples = my_backend.get_samples()
stacked_chain = all_samples["p"]
# 2 epochs of niters = 5 000 samples; per the original note, the burn
# (500) is removed from the first epoch and kept in the second --
# NOTE(review): confirm against the backend implementation.
stacked_chain.shape

# Return log-likelihood ("ll") and log-prior ("lp") too.
depoch = my_backend.get_samples(pars=["p", "ll", "lp"])
depoch["ll"].shape
# getting cov and statistic from file
# NOTE(review): _get_fname is a private backend method, and it is called with
# index 1 while the group read below is "epoch_0" -- confirm the file/epoch
# numbering against the backend.
fn_epoch0 = my_backend._get_fname(1)
# Open read-only: the file is only inspected here.
with h5py.File(fn_epoch0, "r") as fid:
    # Binding the group first avoids nesting double quotes inside the
    # f-strings, which is a SyntaxError before Python 3.12.
    epoch_group = fid["epoch_0"]
    print(f"cov shape: {epoch_group['covs'].shape}")
    print(f"stats info: {epoch_group['stats'].keys()}")
Cyclic parameters#
# wrapped: list of periodic dimensions indexes (here both dimensions).
periodic_dims = [0, 1]
periodic_space = Box(dim=dim, size=10.0, wrapped=periodic_dims)
periodic_loglik = d.Normal(dim=dim).leval
periodic_logprior = d.Uniform(dim=dim, minval=-10, maxval=10).leval

sampling = SamplingMH(
    space=periodic_space,
    nwalker=nwalker,
    loglik=periodic_loglik,
    temps=temperature_ladder,
    logprior=periodic_logprior,
)
sampler = sampling.get_sampler()