Metropolis–Hastings MCMC#

In this tutorial, we sample a multivariate Gaussian with MCMC based on the Metropolis–Hastings algorithm.

We illustrate how to use ensemble sampling as well as parallel tempering.

import jax
import jax.numpy as jnp
import numpy as np

import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

import jexplore.tools.distributions as d
from jexplore.sampling import SamplingMH, Box
from jexplore.tools.diagnostic import auto_correlation_length, gelman_rubin_statistic

%reload_ext autoreload
%autoreload 2 
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
Cell In[1], line 5
      1 import jax
      2 import jax.numpy as jnp
      3 import numpy as np
      4 
----> 5 import seaborn as sns
      6 import pandas as pd
      7 import matplotlib.pyplot as plt
      8 

ModuleNotFoundError: No module named 'seaborn'

Preliminary definitions#

Here we define our sampling ingredients:

  • We use a predefined Gaussian log-likelihood from the jexplore toolbox;

  • We define the number of walkers in the ensemble and the temperature ladder;

  • We use the built-in stretch move as the proposal, which is the default.

# Problem size and parallel-tempering ladder.
dim = 2
nwalker = 10
temperature_ladder = jnp.arange(1, 10, 3)  # temperatures 1, 4, 7
ntemp = len(temperature_ladder)

# Metropolis-Hastings sampling of a standard normal likelihood under a
# uniform prior on [-10, 10]^dim; the default proposal is used.
sampling = SamplingMH(
    space=Box(dim=dim, size=10.0),
    nwalker=nwalker,
    loglik=d.Normal(dim=dim).leval,
    temps=temperature_ladder,
    logprior=d.Uniform(dim=dim, minval=-10, maxval=10).leval,
)
sampler = sampling.get_sampler()

Run MCMC#

We start from a random draw of the normal distribution for each walker and temperature.

The JaxSampler starts from and returns an Epoch: an encapsulation of the actual samples and their metadata, such as covariance, log-likelihood, sampling statistics and so on.

In this example, we run a single epoch (nepoch=1); one also has the choice of:

  • running several epochs at once by setting nepoch > 1

  • pausing and resuming MCMC from the returned epoch

# Draw a starting point for every walker at every temperature
# (nwalker * ntemp chains in total).
key = jax.random.key(42)
# NOTE(review): element [1] of `sample`'s result is taken as the drawn points — confirm.
p0 = d.Normal(dim=dim).sample(key, shape=(nwalker * ntemp,))[1]
epoch = sampling.get_epoch(p0)

# One epoch of `niters` useful iterations plus `burn` burn-in iterations;
# burn-in is discarded afterwards by slicing the third axis (presumably iterations).
niters = 5_000
burn = 500
epoch = sampler.run(epoch, niters=niters + burn, nepoch=1)
chain = epoch.samples.p[:, :, burn:]
chain.shape

Convergence diagnostics#

# Gelman-Rubin: worst deviation of sqrt(R) from 1 over all returned entries
# (a single np.max already reduces over every axis).
gr = np.max(np.abs(np.sqrt(gelman_rubin_statistic(chain, ntemp=ntemp)) - 1))
acl = np.max(auto_correlation_length(chain))
print(f"Gelman Rubin ratio: {gr}. Autocorrelation length: {acl}")

# Only the cold (T=1) chains target the distribution of interest; hotter
# chains presumably sample a tempered (flattened) version of it, so they are
# excluded from the corner plot. Cold chains are assumed to sit at indices
# 0, ntemp, 2*ntemp, ... (temperature index = chain index % ntemp), consistent
# with the cold-chain stats indexing in this notebook — TODO confirm.
cold_chain = chain[::ntemp]
_pdf = np.moveaxis(cold_chain, [1], [0]).reshape(dim, -1)
sns.pairplot(pd.DataFrame(_pdf.T), corner=True)
## Acceptance rates
# Chain index i is assumed to carry temperature index i % ntemp (walker-major
# layout): offset 0 in each group of ntemp chains is the cold (T=1) chain and
# offset ntemp - 1 is the hottest — TODO confirm against jexplore's Epoch layout.
N = nwalker * ntemp
cold_idx = np.arange(0, N, ntemp)
hot_idx = np.arange(ntemp - 1, N, ntemp)  # hottest rung of every walker

# The run above performed niters + burn iterations (burn-in was only removed
# from `chain` by slicing), so the counters presumably cover all of them.
n_iter_total = niters + burn
stats = epoch.stats.stats  # row 0: swap counts, row 1: stretch acceptances

print(f"stretch acceptance rate of cold chains: {np.mean(stats[1, cold_idx] / n_iter_total)}")
print(f"stretch acceptance rate of hot chains: {np.mean(stats[1, hot_idx] / n_iter_total)}")

# One swap rate per adjacent pair of temperatures; the count for pair
# (k-1, k) is read from the chains at temperature offset k.
for k in range(1, ntemp):
    swap_counts = stats[0, np.arange(k, N, ntemp)]
    print(f"swap rate between t{k - 1}-t{k}: {np.mean(swap_counts / n_iter_total)}")

Going further#

import h5py 
from jexplore.sampler import Steps
from jexplore.steps import DEStep, Stretch, TSwap
from jexplore.steps.rwalk import GaussianRandomWalk, StudentTRandomWalk

from jexplore.backends import DefaultBackend

More proposals#

sampler = sampling.get_sampler(
    steps=[
        # Each entry maps a proposal builder to its relative weight;
        # all five proposals are weighted equally here.
        {TSwap(permute=True).builder: 1.0},  # presumably temperature swaps
        {Stretch(permute=True).builder: 1.0},
        {GaussianRandomWalk().builder: 1.0},
        {StudentTRandomWalk().builder: 1.0},
        {DEStep(permute=True).builder: 1.0},
    ]
)

Backend customization#

We explicitly define the backend configuration and use it to retrieve all the information stored on disk.

As an illustration, we run for 2 epochs, and use one h5 file per epoch.

# Backend configuration: one HDF5 file per epoch, written to the current directory.
config = {
    "outdir": "./",  # where to save h5 files
    "outfile_fmt": "epochs_%d.h5",  # h5 file template name
    "epochs_per_file": 1,
    "inmem_epochs": 0,  # presumably: keep no epoch in memory, read back from disk
    "save_stats": True,
}
burn = 500  # number of samples which won't be registered
my_backend = DefaultBackend(burn=burn, config=config)
sampler = sampling.get_sampler(backend=my_backend)
# Resume from the previous epoch and run 2 more epochs of niters + burn iterations each.
epoch = sampler.run(epoch, niters=niters + burn, nepoch=2)
stacked_chain = my_backend.get_samples()["p"]
# 2 epochs of niters + burn = 5 500 iterations each; the first `burn` = 500 samples
# are dropped (first epoch only, the second epoch's burn-in is kept),
# i.e. 2 * 5 500 - 500 = 10 500 samples per chain.
stacked_chain.shape
depoch = my_backend.get_samples(pars=["p", "ll", "lp"])  # return log-likelihood and log-prior too
depoch["ll"].shape
# getting cov and statistic from file
# NOTE(review): the file returned by `_get_fname(1)` is read as containing the
# group "epoch_0"; the index-to-group mapping is not visible here — confirm.
fn_epoch0 = my_backend._get_fname(1)
# Explicit read-only mode. Inner quotes differ from the f-string delimiters so
# the cell also parses on Python < 3.12 (nested same quotes need PEP 701).
with h5py.File(fn_epoch0, "r") as fid:
    epoch_group = fid["epoch_0"]
    print(f"cov shape: {epoch_group['covs'].shape}")
    print(f"stats info: {epoch_group['stats'].keys()}")

Cyclic parameters#

# Same problem as above, but with both dimensions declared periodic.
sampling = SamplingMH(
    space=Box(
        dim=dim,
        size=10.0,
        wrapped=[0, 1],  # indexes of the periodic (wrapped) dimensions
    ),
    nwalker=nwalker,
    loglik=d.Normal(dim=dim).leval,
    temps=temperature_ladder,
    logprior=d.Uniform(dim=dim, minval=-10, maxval=10).leval,
)
sampler = sampling.get_sampler()