{ "cells": [ { "cell_type": "markdown", "id": "af8f17bc-9e2b-4af3-81b2-755fc9aa2cc9", "metadata": {}, "source": [ "# Metropolis-Hastings MCMC\n", "\n", "In this tutorial, we sample a multivariate Gaussian with an MCMC sampler based on the Metropolis-Hastings algorithm.\n", "\n", "We illustrate how to use ensemble sampling as well as parallel tempering." ] }, { "cell_type": "code", "execution_count": null, "id": "9a2a1ec7-d8fd-4e4f-8693-d8119cc5a6a0", "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "import numpy as np\n", "\n", "import seaborn as sns\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "\n", "import jexplore.tools.distributions as d\n", "from jexplore.sampling import SamplingMH, Box\n", "from jexplore.tools.diagnostic import auto_correlation_length, gelman_rubin_statistic\n", "\n", "%reload_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "markdown", "id": "fa463d23-651e-4577-9d34-8e7416516f09", "metadata": {}, "source": [ "## Preliminary definitions\n", "\n", "Here we define the ingredients of the sampler:\n", "\n", "- We use a predefined Gaussian log-likelihood from the `jexplore` toolbox;\n", "- We define the number of walkers in the ensemble and the temperature ladder;\n", "- We use the built-in `stretch` move as the proposal, which is the default move." ] }, { "cell_type": "code", "execution_count": null, "id": "832266ac-1fad-4406-a722-d33914a76971", "metadata": {}, "outputs": [], "source": [ "dim = 2\n", "nwalker = 10\n", "temperature_ladder = jnp.arange(1, 10, 3)  # temperatures 1, 4, 7\n", "ntemp = len(temperature_ladder)" ] }, { "cell_type": "code", "execution_count": null, "id": "3270bf43-3c0a-425c-af50-b0cae849c4d7", "metadata": {}, "outputs": [], "source": [ "sampling = SamplingMH(\n", " space=Box(dim=dim, size=10.0),\n", " nwalker=nwalker,\n", " loglik=d.Normal(dim=dim).leval,\n", " temps=temperature_ladder,\n", " logprior=d.Uniform(dim=dim, minval=-10, maxval=10).leval,\n", " )\n", "sampler = 
sampling.get_sampler()" ] }, { "cell_type": "markdown", "id": "ccd321b6-4bad-43cc-aa1f-ddb8a28d7577", "metadata": {}, "source": [ "## Run MCMC\n", "\n", "Each walker, at every temperature, starts from a random draw of the normal distribution.\n", "\n", "The `JaxSampler` takes an `Epoch` as input and returns a new one. An `Epoch` encapsulates the samples and their metadata, such as the covariance, the log-likelihood and the sampling statistics.\n", "\n", "In this example, we run a single epoch (`nepoch=1`). One can also:\n", "- run several epochs in a single call by setting `nepoch` > 1;\n", "- pause the MCMC and resume it from the returned epoch." ] }, { "cell_type": "code", "execution_count": null, "id": "2366a443-d3b3-490c-8a23-df0a940b24c2", "metadata": {}, "outputs": [], "source": [ "# draw a starting point for each walker at each temperature\n", "key = jax.random.key(42)\n", "p0 = d.Normal(dim=dim).sample(key, shape=(nwalker * ntemp,))[1]\n", "epoch = sampling.get_epoch(p0)\n", "\n", "niters = 5_000\n", "burn = 500  # burn-in iterations, discarded below\n", "epoch = sampler.run(epoch, niters=niters + burn, nepoch=1)" ] }, { "cell_type": "code", "execution_count": null, "id": "a1301b7a-bf8d-4093-a452-551ceec652b8", "metadata": {}, "outputs": [], "source": [ "chain = epoch.samples.p[:, :, burn:]  # drop the burn-in iterations\n", "chain.shape" ] }, { "cell_type": "markdown", "id": "5201f6c8-d38a-477c-ac69-efaa679ff731", "metadata": {}, "source": [ "## Convergence diagnostics\n", "\n", "We assess convergence with the Gelman-Rubin statistic and the autocorrelation length of the chains." ] }, { "cell_type": "code", "execution_count": null, "id": "9faa8f32-b782-4962-802f-e4ca68267e4b", "metadata": {}, "outputs": [], "source": [ "gr = np.max(np.abs(np.sqrt(gelman_rubin_statistic(chain, ntemp=ntemp)) - 1))\n", "acl = np.max(auto_correlation_length(chain))\n", "print(f\"Gelman-Rubin ratio: {gr}. 
Autocorrelation length: {acl}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "cbc3ba1b-f95c-4eff-b912-d7c6c1294398", "metadata": {}, "outputs": [], "source": [ "# move the parameter axis first, then flatten walkers and iterations\n", "_pdf = np.moveaxis(chain, [1], [0]).reshape(dim, -1)\n", "sns.pairplot(pd.DataFrame(_pdf.T), corner=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "f0a2b008-56bd-4a0e-baf2-cde0e6c3d98f", "metadata": {}, "outputs": [], "source": [ "# Acceptance rates\n", "# chains are interleaved: chain index i runs at temperature index i % ntemp\n", "N = nwalker * ntemp\n", "print(f\"stretch acceptance rate of cold chains: {np.mean(epoch.stats.stats[1,np.arange(0,N,ntemp)] / niters)}\")\n", "print(f\"stretch acceptance rate of hottest chains: {np.mean(epoch.stats.stats[1,np.arange(ntemp-1,N,ntemp)] / niters)}\")\n", "\n", "print(f\"swap rate between t0-t1: {np.mean(epoch.stats.stats[0,np.arange(1,N,ntemp)] / niters)}\")\n", "print(f\"swap rate between t1-t2: {np.mean(epoch.stats.stats[0,np.arange(2,N,ntemp)] / niters)}\")" ] }, { "cell_type": "markdown", "id": "06a0f532-fd42-42b4-9796-ee5ccbe02fc4", "metadata": {}, "source": [ "## Going further" ] }, { "cell_type": "code", "execution_count": null, "id": "56914ed7-5d3e-47a2-aab1-3ba29c9428c9", "metadata": {}, "outputs": [], "source": [ "import h5py\n", "from jexplore.sampler import Steps\n", "from jexplore.steps import DEStep, Stretch, TSwap\n", "from jexplore.steps.rwalk import GaussianRandomWalk, StudentTRandomWalk\n", "\n", "from jexplore.backends import DefaultBackend" ] }, { "cell_type": "markdown", "id": "3146a395-ba03-47a6-be5a-a8005ff1f723", "metadata": {}, "source": [ "### More proposals\n", "\n", "Several proposals can be combined by passing a list of proposal builders, each with its weight, through the `steps` argument of `get_sampler`." ] }, { "cell_type": "code", "execution_count": null, "id": "51e49a77-0ee7-4788-872e-ef98a2dff2fc", "metadata": {}, "outputs": [], "source": [ "sampler = sampling.get_sampler(steps=[\n", " {TSwap(permute=True).builder: 1.0}, # proposal builder and its weight\n", " {Stretch(permute=True).builder: 1.0},\n", " {GaussianRandomWalk().builder: 1.0},\n", " {StudentTRandomWalk().builder: 1.0},\n", " {DEStep(permute=True).builder: 1.0}, 
\n", "])" ] }, { "cell_type": "markdown", "id": "562fbfb2-53ea-46b5-8f04-5b149b7590da", "metadata": {}, "source": [ "### Backend customization\n", "\n", "We explicitly define the backend configuration and use it to retrieve all the information stored on disk.\n", "\n", "As an illustration, we run for 2 epochs and write one h5 file per epoch." ] }, { "cell_type": "code", "execution_count": null, "id": "10755597-1d13-457d-a376-358adca567c5", "metadata": {}, "outputs": [], "source": [ "config = {\n", " \"outdir\": \"./\", # where to save the h5 files\n", " \"outfile_fmt\": \"epochs_%d.h5\", # h5 file name template\n", " \"epochs_per_file\": 1,\n", " \"inmem_epochs\": 0,\n", " \"save_stats\": True,\n", "}\n", "burn = 500  # number of initial samples that are not saved\n", "my_backend = DefaultBackend(burn=burn, config=config)\n", "sampler = sampling.get_sampler(backend=my_backend)\n", "epoch = sampler.run(epoch, niters=niters + burn, nepoch=2)" ] }, { "cell_type": "code", "execution_count": null, "id": "42d8dd76-2b1d-4cb3-83bf-9dba79475228", "metadata": {}, "outputs": [], "source": [ "stacked_chain = my_backend.get_samples()[\"p\"]\n", "stacked_chain.shape  # 2 epochs of niters + burn = 5_500 iterations: the burn-in of the first epoch is removed, that of the second epoch is kept" ] }, { "cell_type": "code", "execution_count": null, "id": "50ebda0d-4b26-43bd-b371-4cda402a4a5e", "metadata": {}, "outputs": [], "source": [ "depoch = my_backend.get_samples(pars=[\"p\", \"ll\", \"lp\"])  # also return the log-likelihood and log-prior\n", "depoch[\"ll\"].shape" ] }, { "cell_type": "code", "execution_count": null, "id": "8ec3f56a-a811-4e79-a65b-514a93e7159f", "metadata": {}, "outputs": [], "source": [ "# read the covariances and statistics from file\n", "fn_epoch0 = my_backend._get_fname(1)\n", "with h5py.File(fn_epoch0, \"r\") as fid:\n", " print(f\"cov shape: {fid['epoch_0']['covs'].shape}\")\n", " print(f\"stats info: {fid['epoch_0']['stats'].keys()}\")" ] }, { "cell_type": "markdown", "id": 
"7f223813-de38-4e07-948e-42973da30e65", "metadata": {}, "source": [ "### Cyclic parameters\n", "\n", "Periodic parameters are declared with the `wrapped` argument of `Box`, which lists the indices of the periodic dimensions." ] }, { "cell_type": "code", "execution_count": null, "id": "0c2a0471-58e6-4acb-a204-96f58ea2afc4", "metadata": {}, "outputs": [], "source": [ "sampling = SamplingMH(\n", " space=Box(dim=dim, size=10.0, wrapped=[0, 1]),  # wrapped: indices of the periodic dimensions\n", " nwalker=nwalker,\n", " loglik=d.Normal(dim=dim).leval,\n", " temps=temperature_ladder,\n", " logprior=d.Uniform(dim=dim, minval=-10, maxval=10).leval,\n", " )\n", "sampler = sampling.get_sampler()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3" } }, "nbformat": 4, "nbformat_minor": 5 }