Source code for jexplore.tools.diagnostic

"""Diagnostic tool for assessing the convergence of the MCMC
sampling.
"""

import emcee
import jax.numpy as jnp


def gelman_rubin_statistic(samples, ntemp=1):
    """Return square root of Gelman-Rubin statistic R.

    Assess the convergence by comparing the variance between the chains
    and the variance in the chains.

    :param array samples: (nchain, dim, npoints) array with nchain
        equal to nwalker if ntemp=1 (default) or nchain equal
        nwalker x ntemp. In the latter case the chain index is split as
        ``walker * ntemp + temperature`` (row-major reshape).
    :param int ntemp: return R per temperature if ntemp > 1
    :return: (ndim,) array of the statistic for each dimension if
        ntemp=1, (ntemp, ndim) array (one row per temperature) otherwise
    :rtype: array
    :raises ZeroDivisionError: if there is a single chain (per
        temperature), since the between-chain variance is normalised
        by ``nchain - 1``.
    """
    samples = jnp.asarray(samples)
    ndim, nsamples = samples.shape[1:]

    if ntemp > 1:
        # Split the flat chain axis into (walker, temperature) so that
        # the statistic is computed independently for each temperature.
        nwalker = samples.shape[0] // ntemp
        chains = samples.reshape(nwalker, ntemp, ndim, nsamples)
    else:
        chains = samples

    # Bring the sample axis to the front: (nsamples, nchain, [ntemp,] ndim).
    chains = jnp.moveaxis(chains, -1, 0)
    nsteps, nchains = chains.shape[:2]

    chain_mean = jnp.mean(chains, axis=0)  # mean value of each chain
    global_mean = jnp.mean(chain_mean, axis=0)  # mean over all chains
    within = jnp.var(chains, axis=0, ddof=1)  # intra chain variance

    # Variance between the chains.
    between = (
        nsteps
        / (nchains - 1)
        * jnp.sum((chain_mean - global_mean) ** 2, axis=0)
    )
    avg_within = jnp.mean(within, axis=0)  # chain-averaged variance

    # Pooled variance estimate, normalised by the within-chain variance.
    pooled = ((nsteps - 1) / nsteps) * avg_within + between * (1 / nsteps)
    return jnp.sqrt(pooled / avg_within)
def auto_correlation_length(samples):
    """Return auto correlation length from emcee.

    The samples are rearranged to the (nsteps, nwalker, nparam) layout
    and passed to :func:`emcee.autocorr.integrated_time`.

    :param array samples: (nchain, dim, npoints) array with nchain equal
        to nwalker
    :return: integrated autocorrelation time as returned by emcee, or
        ``nan`` if emcee raises ``AutocorrError`` (i.e. it could not
        estimate the time reliably from the given chain)
    :rtype: array
    """
    _samples = jnp.permute_dims(samples, [2, 0, 1])  # nsteps, nwalker, nparam
    try:
        return emcee.autocorr.integrated_time(_samples)
    except emcee.autocorr.AutocorrError:
        # Best-effort diagnostic: signal "no estimate" rather than fail.
        # ``jnp.nan`` is used because this module does not import numpy.
        return jnp.nan